Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions dsp/utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import threading
from contextlib import contextmanager
from copy import deepcopy

from contextlib import contextmanager
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
Expand Down Expand Up @@ -49,17 +49,16 @@ def __new__(cls):
# TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
# eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
# downstream operations like dsp.retrieve would use configs from the defined pipeline.

# make a deepcopy of the default config to avoid modifying the default config
cls._instance.__append(deepcopy(DEFAULT_CONFIG))
config = copy.deepcopy(DEFAULT_CONFIG)
cls._instance.__append(config)

return cls._instance

@property
def config(self):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
# if thread_id not in self.stack_by_thread:
# self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
return self.stack_by_thread[thread_id][-1]

def __getattr__(self, name):
Expand All @@ -73,14 +72,14 @@ def __getattr__(self, name):

def __append(self, config):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
# if thread_id not in self.stack_by_thread:
# self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
self.stack_by_thread[thread_id].append(config)

def __pop(self):
thread_id = threading.get_ident()
if thread_id in self.stack_by_thread:
self.stack_by_thread[thread_id].pop()
# if thread_id in self.stack_by_thread:
self.stack_by_thread[thread_id].pop()

def configure(self, inherit_config: bool = True, **kwargs):
"""Set configuration settings.
Expand Down
8 changes: 6 additions & 2 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,20 @@ def _wrap_function(self, function):
def wrapped(item, parent_id=None):
thread_stacks = dspy.settings.stack_by_thread
current_thread_id = threading.get_ident()

creating_new_thread = current_thread_id not in thread_stacks

assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid

if creating_new_thread:
# If we have a parent thread ID, copy its stack
# If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
if parent_id and parent_id in thread_stacks:
thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
else:
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)

# TODO: Consider the behavior below.
# import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))

try:
return function(item)
except Exception as e:
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import dspy
import copy

from dsp.utils.settings import DEFAULT_CONFIG


Expand All @@ -10,7 +11,7 @@ def clear_settings():

yield

dspy.settings.configure(**DEFAULT_CONFIG, inherit_config=False)
dspy.settings.configure(**copy.deepcopy(DEFAULT_CONFIG), inherit_config=False)


@pytest.fixture
Expand Down
Loading