123 changes: 67 additions & 56 deletions dsp/utils/settings.py
@@ -1,7 +1,8 @@
import copy
import threading
from contextlib import contextmanager
from copy import deepcopy

from contextlib import contextmanager
from contextvars import ContextVar
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
@@ -27,85 +28,95 @@
async_max_workers=8,
)

# Global base configuration
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)

# Initialize the context variable with an empty dotdict as default
dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())


class Settings:
"""DSP configuration settings."""
"""
A singleton class for DSPy configuration settings.

This is thread-safe. User threads are supported through both ParallelExecutor and native threading.
- If native threading is used, the thread inherits the initial config from the main thread.
- If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
"""

_instance = None

def __new__(cls):
"""
Singleton Pattern. See https://python-patterns.guide/gang-of-four/singleton/
"""

if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.lock = threading.Lock()
cls._instance.main_tid = threading.get_ident()
cls._instance.main_stack = []
cls._instance.stack_by_thread = {}
cls._instance.stack_by_thread[threading.get_ident()] = cls._instance.main_stack
cls._instance.lock = threading.Lock() # maintained here for assertions
return cls._instance

# TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
# eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
# downstream operations like dsp.retrieve would use configs from the defined pipeline.
def __getattr__(self, name):
overrides = dspy_ctx_overrides.get()
if name in overrides:
return overrides[name]
elif name in main_thread_config:
return main_thread_config[name]
else:
raise AttributeError(f"'Settings' object has no attribute '{name}'")

# Deep-copy the default config to avoid mutating DEFAULT_CONFIG
cls._instance.__append(deepcopy(DEFAULT_CONFIG))
def __setattr__(self, name, value):
if name in ('_instance',):
super().__setattr__(name, value)
else:
self.configure(**{name: value})

return cls._instance
# Dictionary-like access

@property
def config(self):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
return self.stack_by_thread[thread_id][-1]
def __getitem__(self, key):
return self.__getattr__(key)

def __getattr__(self, name):
if hasattr(self.config, name):
return getattr(self.config, name)
def __setitem__(self, key, value):
self.__setattr__(key, value)

if name in self.config:
return self.config[name]
def __contains__(self, key):
overrides = dspy_ctx_overrides.get()
return key in overrides or key in main_thread_config

super().__getattr__(name)
def get(self, key, default=None):
try:
return self[key]
except AttributeError:
return default

def __append(self, config):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
self.stack_by_thread[thread_id].append(config)
def copy(self):
overrides = dspy_ctx_overrides.get()
return dotdict({**main_thread_config, **overrides})

def __pop(self):
thread_id = threading.get_ident()
if thread_id in self.stack_by_thread:
self.stack_by_thread[thread_id].pop()
# Configuration methods

def configure(self, inherit_config: bool = True, **kwargs):
"""Set configuration settings.
def configure(self, return_token=False, **kwargs):
global main_thread_config
overrides = dspy_ctx_overrides.get()
new_overrides = dotdict({**main_thread_config, **overrides, **kwargs})
token = dspy_ctx_overrides.set(new_overrides)

Args:
inherit_config (bool, optional): If True, merge the given settings into the existing configuration; otherwise use only the given settings. Defaults to True.
"""
if inherit_config:
config = {**self.config, **kwargs}
else:
config = {**kwargs}
# Update main_thread_config in the main thread only
if threading.current_thread() is threading.main_thread():
main_thread_config = new_overrides

self.__append(config)
if return_token:
return token

@contextmanager
def context(self, inherit_config=True, **kwargs):
self.configure(inherit_config=inherit_config, **kwargs)

def context(self, **kwargs):
"""Context manager for temporary configuration changes."""
token = self.configure(return_token=True, **kwargs)
try:
yield
finally:
self.__pop()
dspy_ctx_overrides.reset(token)

def __repr__(self) -> str:
return repr(self.config)
def __repr__(self):
overrides = dspy_ctx_overrides.get()
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)


settings = Settings()
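For orientation, a minimal usage sketch of the new ContextVar-based settings, written against this diff (it assumes the module path dsp/utils/settings.py and the async_max_workers key from DEFAULT_CONFIG):

import threading
from dsp.utils.settings import settings

# configure() merges kwargs into the ContextVar overrides; in the main
# thread it also refreshes the global base config (main_thread_config).
settings.configure(async_max_workers=4)
assert settings.async_max_workers == 4

# A natively spawned thread starts with an empty ContextVar, so attribute
# lookups fall back to main_thread_config -- i.e., it inherits the main
# thread's base config, as the class docstring describes.
def worker():
    assert settings.async_max_workers == 4

t = threading.Thread(target=worker)
t.start()
t.join()

# context() applies a temporary override and resets the ContextVar token
# on exit, restoring the previous overrides.
with settings.context(async_max_workers=16):
    assert settings.async_max_workers == 16
assert settings.async_max_workers == 4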
21 changes: 3 additions & 18 deletions dspy/utils/asyncify.py
@@ -24,22 +24,7 @@ def get_limiter():


def asyncify(program):
import dspy
import threading

assert threading.get_ident() == dspy.settings.main_tid, "asyncify can only be called from the main thread"

def wrapped(*args, **kwargs):
thread_stacks = dspy.settings.stack_by_thread
current_thread_id = threading.get_ident()
creating_new_thread = current_thread_id not in thread_stacks

assert creating_new_thread
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)

try:
return program(*args, **kwargs)
finally:
del thread_stacks[threading.get_ident()]

return asyncer.asyncify(wrapped, abandon_on_cancel=True, limiter=get_limiter())
assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread"
# NOTE: To allow this to be nested, we'd need contextvars-based propagation like parallelizer.py's
return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter())
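A hypothetical usage sketch of the simplified asyncify (the Predict signature and the question are illustrative, and an LM is assumed to be configured already):

import asyncio
import dspy
from dspy.utils.asyncify import asyncify

# Hypothetical program; any callable DSPy module works. Assumes an LM was
# configured beforehand, e.g. via dspy.settings.configure(lm=...).
predict = dspy.Predict("question -> answer")

async def main():
    # Must be called from the main thread, per the assertion above.
    async_predict = asyncify(predict)
    answer = await async_predict(question="What is DSPy?")
    print(answer)

asyncio.run(main())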
98 changes: 54 additions & 44 deletions dspy/utils/parallelizer.py
@@ -1,16 +1,15 @@
import logging
import sys
import tqdm
import dspy
import signal
import logging
import threading
import traceback
import contextlib

from contextvars import copy_context
from tqdm.contrib.logging import logging_redirect_tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed


logger = logging.getLogger(__name__)


@@ -23,6 +22,8 @@ def __init__(
provide_traceback=False,
compare_results=False,
):
"""Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""

self.num_threads = num_threads
self.disable_progress_bar = disable_progress_bar
self.max_errors = max_errors
@@ -33,34 +34,18 @@
self.error_lock = threading.Lock()
self.cancel_jobs = threading.Event()


def execute(self, function, data):
wrapped_function = self._wrap_function(function)
if self.num_threads == 1:
return self._execute_single_thread(wrapped_function, data)
return self._execute_isolated_single_thread(wrapped_function, data)
else:
return self._execute_multi_thread(wrapped_function, data)


def _wrap_function(self, function):
# Wrap the function with threading context and error handling
def wrapped(item, parent_id=None):
thread_stacks = dspy.settings.stack_by_thread
current_thread_id = threading.get_ident()
creating_new_thread = current_thread_id not in thread_stacks

assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid

if creating_new_thread:
# If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
if parent_id and parent_id in thread_stacks:
thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
else:
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)

# TODO: Consider the behavior below.
# import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))

# Wrap the function with error handling
def wrapped(item):
if self.cancel_jobs.is_set():
return None
try:
return function(item)
except Exception as e:
@@ -79,45 +64,53 @@ def wrapped(item, parent_id=None):
f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
)
return None
finally:
if creating_new_thread:
del thread_stacks[threading.get_ident()]
return wrapped


def _execute_single_thread(self, function, data):
def _execute_isolated_single_thread(self, function, data):
results = []
pbar = tqdm.tqdm(
total=len(data),
dynamic_ncols=True,
disable=self.disable_progress_bar,
file=sys.stdout,
file=sys.stdout
)

for item in data:
with logging_redirect_tqdm():
if self.cancel_jobs.is_set():
break
result = function(item)

# Create an isolated context for each task
task_ctx = copy_context()
result = task_ctx.run(function, item)
results.append(result)

if self.compare_results:
# Assumes score is the last element of the result tuple
self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in data if r is not None]))
self._update_progress(
pbar,
sum([r[-1] for r in results if r is not None]),
len([r for r in results if r is not None]),
)
else:
self._update_progress(pbar, len(results), len(data))

pbar.close()

if self.cancel_jobs.is_set():
logger.warning("Execution was cancelled due to errors.")
raise Exception("Execution was cancelled due to errors.")
return results

return results

def _update_progress(self, pbar, nresults, ntotal):
if self.compare_results:
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({round(100 * nresults / ntotal, 1) if ntotal > 0 else 0}%)")
percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
else:
pbar.set_description(f"Processed {nresults} / {ntotal} examples")
pbar.update()

pbar.update()

def _execute_multi_thread(self, function, data):
results = [None] * len(data) # Pre-allocate results list to maintain order
@@ -132,6 +125,7 @@ def interrupt_handler_manager():
def interrupt_handler(sig, frame):
self.cancel_jobs.set()
logger.warning("Received SIGINT. Cancelling execution.")
# Delegate to the default handler (raises KeyboardInterrupt)
default_handler(sig, frame)

signal.signal(signal.SIGINT, interrupt_handler)
@@ -143,37 +137,53 @@ def interrupt_handler(sig, frame):
# If not in the main thread, skip setting signal handlers
yield

def cancellable_function(index_item, parent_id=None):
def cancellable_function(index_item):
index, item = index_item
if self.cancel_jobs.is_set():
return index, job_cancelled
return index, function(item, parent_id)

parent_id = threading.get_ident() if threading.current_thread() is not threading.main_thread() else None
return index, function(item)

with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
futures = {executor.submit(cancellable_function, pair, parent_id): pair for pair in enumerate(data)}
futures = {}
for pair in enumerate(data):
# Capture the context for each task
task_ctx = copy_context()
future = executor.submit(task_ctx.run, cancellable_function, pair)
futures[future] = pair

pbar = tqdm.tqdm(
total=len(data),
dynamic_ncols=True,
disable=self.disable_progress_bar,
file=sys.stdout,
file=sys.stdout
)

for future in as_completed(futures):
index, result = future.result()

if result is job_cancelled:
continue

results[index] = result

if self.compare_results:
# Assumes score is the last element of the result tuple
self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in results if r is not None]))
self._update_progress(
pbar,
sum([r[-1] for r in results if r is not None]),
len([r for r in results if r is not None]),
)
else:
self._update_progress(pbar, len([r for r in results if r is not None]), len(data))
self._update_progress(
pbar,
len([r for r in results if r is not None]),
len(data),
)

pbar.close()

if self.cancel_jobs.is_set():
logger.warning("Execution was cancelled due to errors.")
raise Exception("Execution was cancelled due to errors.")

return results
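A brief sketch of how ParallelExecutor is used after this change (parameter values and the task function are illustrative, not from this PR):

from dspy.utils.parallelizer import ParallelExecutor

# Each task runs under copy_context(), so dspy.settings overrides from the
# submitting context are visible to workers while per-task mutations stay
# isolated -- the same behavior whether num_threads == 1 or > 1.
executor = ParallelExecutor(num_threads=8, max_errors=5, provide_traceback=True)

def square(x):
    return x * x

results = executor.execute(square, list(range(10)))  # order-preserving
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]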