diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py index 821f80e03b..79295ea029 100644 --- a/dspy/dsp/utils/settings.py +++ b/dspy/dsp/utils/settings.py @@ -19,26 +19,34 @@ async_max_workers=8, ) -# Global base configuration +# Global base configuration and owner tracking main_thread_config = copy.deepcopy(DEFAULT_CONFIG) +config_owner_thread_id = None +# Global lock for settings configuration +global_lock = threading.Lock() class ThreadLocalOverrides(threading.local): def __init__(self): - self.overrides = dotdict() # Initialize thread-local overrides + self.overrides = dotdict() - -# Create the thread-local storage thread_local_overrides = ThreadLocalOverrides() class Settings: """ A singleton class for DSPy configuration settings. - - This is thread-safe. User threads are supported both through ParallelExecutor and native threading. - - If native threading is used, the thread inherits the initial config from the main thread. - - If ParallelExecutor is used, the thread inherits the initial config from its parent thread. + Thread-safe global configuration. + - 'configure' can be called by only one 'owner' thread (the first thread that calls it). + - Other threads see the configured global values from 'main_thread_config'. + - 'context' sets thread-local overrides. These overrides propagate to threads spawned + inside that context block, when (and only when!) using a ParallelExecutor that copies overrides. + + 1. Only one unique thread (which can be any thread!) can call dspy.configure. + 2. It affects a global state, visible to all. As a result, user threads work, but they shouldn't be + mixed with concurrent changes to dspy.configure from the "main" thread. + (TODO: In the future, add warnings: if there are near-in-time user-thread reads followed by .configure calls.) + 3. Any thread can use dspy.context. It propagates to child threads created with DSPy primitives: Parallel, asyncify, etc. """ _instance = None @@ -46,9 +54,12 @@ class Settings: def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) - cls._instance.lock = threading.Lock() # maintained here for DSPy assertions.py return cls._instance + @property + def lock(self): + return global_lock + def __getattr__(self, name): overrides = getattr(thread_local_overrides, 'overrides', dotdict()) if name in overrides: @@ -64,8 +75,6 @@ def __setattr__(self, name, value): else: self.configure(**{name: value}) - # Dictionary-like access - def __getitem__(self, key): return self.__getattr__(key) @@ -88,42 +97,40 @@ def copy(self): @property def config(self): - config = self.copy() - if 'lock' in config: - del config['lock'] - return config - - # Configuration methods + return self.copy() def configure(self, **kwargs): - global main_thread_config + global main_thread_config, config_owner_thread_id + current_thread_id = threading.get_ident() - # Get or initialize thread-local overrides - overrides = getattr(thread_local_overrides, 'overrides', dotdict()) - thread_local_overrides.overrides = dotdict( - {**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs} - ) + with self.lock: + # First configuration: establish ownership. If ownership established, only that thread can configure. + if config_owner_thread_id in [None, current_thread_id]: + config_owner_thread_id = current_thread_id + else: + raise RuntimeError("dspy.settings can only be changed by the thread that initially configured it.") - # Update main_thread_config, in the main thread only - if threading.current_thread() is threading.main_thread(): - main_thread_config = thread_local_overrides.overrides + # Update global config + for k, v in kwargs.items(): + main_thread_config[k] = v @contextmanager def context(self, **kwargs): - """Context manager for temporary configuration changes.""" - global main_thread_config + """ + Context manager for temporary configuration changes at the thread level. + Does not affect global configuration. Changes only apply to the current thread. + If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides. + """ + original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy() - original_main_thread_config = main_thread_config.copy() + new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs}) + thread_local_overrides.overrides = new_overrides - self.configure(**kwargs) try: yield finally: thread_local_overrides.overrides = original_overrides - if threading.current_thread() is threading.main_thread(): - main_thread_config = original_main_thread_config - def __repr__(self): overrides = getattr(thread_local_overrides, 'overrides', dotdict()) combined_config = {**main_thread_config, **overrides} diff --git a/dspy/utils/asyncify.py b/dspy/utils/asyncify.py index 9f55df81a3..3ef10efc7b 100644 --- a/dspy/utils/asyncify.py +++ b/dspy/utils/asyncify.py @@ -10,7 +10,6 @@ def get_async_max_workers(): import dspy - return dspy.settings.async_max_workers @@ -31,28 +30,31 @@ def asyncify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]: Wraps a DSPy program so that it can be called asynchronously. This is useful for running a program in parallel with another task (e.g., another DSPy program). + This implementation propagates the current thread's configuration context to the worker thread. + Args: program: The DSPy program to be wrapped for asynchronous execution. Returns: - A function that takes the same arguments as the program, but returns an awaitable that - resolves to the program's output. - - Example: - >>> class TestSignature(dspy.Signature): - >>> input_text: str = dspy.InputField() - >>> output_text: str = dspy.OutputField() - >>> - >>> # Create the program and wrap it for asynchronous execution - >>> program = dspy.asyncify(dspy.Predict(TestSignature)) - >>> - >>> # Use the program asynchronously - >>> async def get_prediction(): - >>> prediction = await program(input_text="Test") - >>> print(prediction) # Handle the result of the asynchronous execution + An async function that, when awaited, runs the program in a worker thread. The current + thread's configuration context is inherited for each call. """ - import threading - - assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread" - # NOTE: To allow this to be nested, we'd need behavior with contextvars like parallelizer.py - return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter()) + async def async_program(*args, **kwargs) -> Any: + # Capture the current overrides at call-time. + from dspy.dsp.utils.settings import thread_local_overrides + parent_overrides = thread_local_overrides.overrides.copy() + + def wrapped_program(*a, **kw): + from dspy.dsp.utils.settings import thread_local_overrides + original_overrides = thread_local_overrides.overrides + thread_local_overrides.overrides = parent_overrides.copy() + try: + return program(*a, **kw) + finally: + thread_local_overrides.overrides = original_overrides + + # Create a fresh asyncified callable each time, ensuring the latest context is used. + call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter()) + return await call_async(*args, **kwargs) + + return async_program diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index 7f0ce758f7..08dae84597 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) + class ParallelExecutor: def __init__( self, @@ -20,7 +21,6 @@ def __init__( compare_results=False, ): """Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1.""" - self.num_threads = num_threads self.disable_progress_bar = disable_progress_bar self.max_errors = max_errors @@ -72,15 +72,17 @@ def _execute_isolated_single_thread(self, function, data): file=sys.stdout ) + from dspy.dsp.utils.settings import thread_local_overrides + original_overrides = thread_local_overrides.overrides + for item in data: with logging_redirect_tqdm(): if self.cancel_jobs.is_set(): break - # Create an isolated context for each task using thread-local overrides - from dspy.dsp.utils.settings import thread_local_overrides - original_overrides = thread_local_overrides.overrides - thread_local_overrides.overrides = thread_local_overrides.overrides.copy() + # Create an isolated context for each task by copying current overrides + # This way, even if an iteration modifies the overrides, it won't affect subsequent iterations + thread_local_overrides.overrides = original_overrides.copy() try: result = function(item) @@ -122,6 +124,8 @@ def _execute_multi_thread(self, function, data): @contextlib.contextmanager def interrupt_handler_manager(): """Sets the cancel_jobs event when a SIGINT is received, only in the main thread.""" + + # TODO: Is this check conducive to nested usage of ParallelExecutor? if threading.current_thread() is threading.main_thread(): default_handler = signal.getsignal(signal.SIGINT) @@ -145,7 +149,7 @@ def cancellable_function(parent_overrides, index_item): if self.cancel_jobs.is_set(): return index, job_cancelled - # Create an isolated context for each task using thread-local overrides + # Create an isolated context for each task by copying parent's overrides from dspy.dsp.utils.settings import thread_local_overrides original_overrides = thread_local_overrides.overrides thread_local_overrides.overrides = parent_overrides.copy() @@ -156,7 +160,6 @@ def cancellable_function(parent_overrides, index_item): thread_local_overrides.overrides = original_overrides with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager(): - # Capture the parent thread's overrides from dspy.dsp.utils.settings import thread_local_overrides parent_overrides = thread_local_overrides.overrides.copy()