Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
Expand All @@ -64,20 +64,20 @@ jobs:
- name: Install Poetry
run: python -m pip install --upgrade "poetry==${{ env.POETRY_VERSION }}"
- name: Install dependencies
run: poetry install --no-interaction
run: |
poetry install --no-interaction || poetry lock && poetry install --no-interaction
- name: Run lint with tests
uses: chartboost/ruff-action@v1
with:
args: check --fix-only
- name: Run tests with pytest
run: poetry run pytest tests/

build_poetry:
name: Build Poetry
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v4
- name: Load cached Poetry installation
Expand All @@ -103,7 +103,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v4
- name: Load cached Poetry installation
Expand Down
13 changes: 9 additions & 4 deletions dspy/predict/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ def __init__(
self,
num_threads: int = 32,
max_errors: int = 10,
access_examples: bool = True,
return_failed_examples: bool = False,
provide_traceback: bool = False,
disable_progress_bar: bool = False,
):
super().__init__()
self.num_threads = num_threads
self.max_errors = max_errors
self.access_examples = access_examples
self.return_failed_examples = return_failed_examples
self.provide_traceback = provide_traceback
self.disable_progress_bar = disable_progress_bar
Expand All @@ -28,7 +30,6 @@ def __init__(
self.failed_examples = []
self.exceptions = []


def forward(self, exec_pairs: List[Tuple[Any, Example]], num_threads: int = None) -> List[Any]:
num_threads = num_threads if num_threads is not None else self.num_threads

Expand All @@ -44,15 +45,20 @@ def process_pair(pair):
module, example = pair

if isinstance(example, Example):
result = module(**example.inputs())
if self.access_examples:
result = module(**example.inputs())
else:
result = module(example)
elif isinstance(example, dict):
result = module(**example)
elif isinstance(example, list) and module.__class__.__name__ == "Parallel":
result = module(example)
elif isinstance(example, tuple):
result = module(*example)
else:
raise ValueError(f"Invalid example type: {type(example)}, only supported types are Example, dict, list and tuple")
raise ValueError(
f"Invalid example type: {type(example)}, only supported types are Example, dict, list and tuple"
)
return result

# Execute the processing function over the execution pairs
Expand All @@ -63,6 +69,5 @@ def process_pair(pair):
else:
return results


def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kwargs)
1 change: 1 addition & 0 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
self.cancel_jobs = threading.Event()

def execute(self, function, data):
tqdm.tqdm._instances.clear()
wrapped = self._wrap_function(function)
return self._execute_parallel(wrapped, data)

Expand Down
Loading
Loading