Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions dsp/utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import copy
import threading

from contextlib import contextmanager
from contextvars import ContextVar
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
Expand Down Expand Up @@ -31,8 +29,14 @@
# Global base configuration
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)

# Initialize the context variable with an empty dict as default
dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())

class ThreadLocalOverrides(threading.local):
    """Per-thread holder for configuration overrides.

    Subclassing ``threading.local`` gives every thread its own independent
    ``overrides`` attribute. ``__init__`` is run separately in each thread
    the first time that thread accesses the instance, so a new thread
    always starts with an empty override mapping.
    """

    def __init__(self):
        super().__init__()
        # Start each thread with no overrides of its own.
        self.overrides = dotdict()


# Create the thread-local storage
thread_local_overrides = ThreadLocalOverrides()


class Settings:
Expand All @@ -53,7 +57,7 @@ def __new__(cls):
return cls._instance

def __getattr__(self, name):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
if name in overrides:
return overrides[name]
elif name in main_thread_config:
Expand All @@ -76,7 +80,7 @@ def __setitem__(self, key, value):
self.__setattr__(key, value)

def __contains__(self, key):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
return key in overrides or key in main_thread_config

def get(self, key, default=None):
Expand All @@ -86,45 +90,49 @@ def get(self, key, default=None):
return default

def copy(self):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
return dotdict({**main_thread_config, **overrides})

@property
def config(self):
config = self.copy()
del config['lock']
if 'lock' in config:
del config['lock']
return config

# Configuration methods

def configure(self, return_token=False, **kwargs):
def configure(self, **kwargs):
global main_thread_config
overrides = dspy_ctx_overrides.get()
new_overrides = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs})
token = dspy_ctx_overrides.set(new_overrides)

# Get or initialize thread-local overrides
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
thread_local_overrides.overrides = dotdict(
{**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
)

# Update main_thread_config, in the main thread only
if threading.current_thread() is threading.main_thread():
main_thread_config = new_overrides

if return_token:
return token
main_thread_config = thread_local_overrides.overrides

@contextmanager
def context(self, **kwargs):
"""Context manager for temporary configuration changes."""
token = self.configure(return_token=True, **kwargs)
global main_thread_config
original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
original_main_thread_config = main_thread_config.copy()

self.configure(**kwargs)
try:
yield
finally:
dspy_ctx_overrides.reset(token)
thread_local_overrides.overrides = original_overrides

if threading.current_thread() is threading.main_thread():
global main_thread_config
main_thread_config = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **dspy_ctx_overrides.get()})
main_thread_config = original_main_thread_config

def __repr__(self):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)

Expand Down
39 changes: 27 additions & 12 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
import threading
import traceback
import contextlib

from contextvars import copy_context
from tqdm.contrib.logging import logging_redirect_tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)


class ParallelExecutor:
def __init__(
self,
Expand Down Expand Up @@ -80,10 +77,16 @@ def _execute_isolated_single_thread(self, function, data):
if self.cancel_jobs.is_set():
break

# Create an isolated context for each task
task_ctx = copy_context()
result = task_ctx.run(function, item)
results.append(result)
# Create an isolated context for each task using thread-local overrides
from dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = thread_local_overrides.overrides.copy()

try:
result = function(item)
results.append(result)
finally:
thread_local_overrides.overrides = original_overrides

if self.compare_results:
# Assumes score is the last element of the result tuple
Expand Down Expand Up @@ -137,18 +140,30 @@ def interrupt_handler(sig, frame):
# If not in the main thread, skip setting signal handlers
yield

def cancellable_function(index_item):
def cancellable_function(parent_overrides, index_item):
index, item = index_item
if self.cancel_jobs.is_set():
return index, job_cancelled
return index, function(item)

# Create an isolated context for each task using thread-local overrides
from dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()

try:
return index, function(item)
finally:
thread_local_overrides.overrides = original_overrides

with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
# Capture the parent thread's overrides
from dsp.utils.settings import thread_local_overrides
parent_overrides = thread_local_overrides.overrides.copy()

futures = {}
for pair in enumerate(data):
# Capture the context for each task
task_ctx = copy_context()
future = executor.submit(task_ctx.run, cancellable_function, pair)
# Pass the parent thread's overrides to each thread
future = executor.submit(cancellable_function, parent_overrides, pair)
futures[future] = pair

pbar = tqdm.tqdm(
Expand Down
Loading