diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 76943784a4..e973f47052 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from dsp.utils.utils import dotdict +from functools import lru_cache DEFAULT_CONFIG = dotdict( lm=None, @@ -27,6 +28,12 @@ async_max_workers=8, ) +@lru_cache(maxsize=None) +def warn_once(msg: str): + import logging + logger = logging.getLogger(__name__) + logger.warning(msg) + class Settings: """DSP configuration settings.""" @@ -59,7 +66,11 @@ def config(self): thread_id = threading.get_ident() # if thread_id not in self.stack_by_thread: # self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()] - return self.stack_by_thread[thread_id][-1] + try: + return self.stack_by_thread[thread_id][-1] + except Exception: + warn_once("Warning: You seem to be creating DSPy threads in an unsupported way.") + return self.main_stack[-1] def __getattr__(self, name): if hasattr(self.config, name): @@ -74,6 +85,8 @@ def __append(self, config): thread_id = threading.get_ident() # if thread_id not in self.stack_by_thread: # self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()] + + assert thread_id in self.stack_by_thread, "Error: You seem to be creating DSPy threads in an unsupported way." self.stack_by_thread[thread_id].append(config) def __pop(self): diff --git a/dspy/utils/asyncify.py b/dspy/utils/asyncify.py index 8cd8062987..ca801e12a1 100644 --- a/dspy/utils/asyncify.py +++ b/dspy/utils/asyncify.py @@ -24,4 +24,22 @@ def get_limiter(): def asyncify(program): - return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter()) + import dspy + import threading + + assert threading.get_ident() == dspy.settings.main_tid, "asyncify can only be called from the main thread" + + def wrapped(*args, **kwargs): + thread_stacks = dspy.settings.stack_by_thread + current_thread_id = threading.get_ident() + creating_new_thread = current_thread_id not in thread_stacks + + assert creating_new_thread + thread_stacks[current_thread_id] = list(dspy.settings.main_stack) + + try: + return program(*args, **kwargs) + finally: + del thread_stacks[threading.get_ident()] + + return asyncer.asyncify(wrapped, abandon_on_cancel=True, limiter=get_limiter())