Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import inspect
import os
import openai
Expand All @@ -7,8 +8,9 @@
from typing import Annotated, List, Tuple # noqa: UP035
from dsp.templates import passages2text
import json
from dspy.primitives.prediction import Prediction

from dspy.signatures.signature import ensure_signature
from dspy.signatures.signature import ensure_signature, make_signature


MAX_RETRIES = 3
Expand Down Expand Up @@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802
class TypedPredictor(dspy.Module):
def __init__(self, signature):
super().__init__()
self.signature = signature
self.signature = ensure_signature(signature)
self.predictor = dspy.Predict(signature)

def copy(self) -> "TypedPredictor":
Expand All @@ -81,7 +83,7 @@ def copy(self) -> "TypedPredictor":
def _make_example(type_) -> str:
# Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
json_object = dspy.Predict(
dspy.Signature(
make_signature(
"json_schema -> json_object",
"Make a very succinct json object that validates with the following schema",
),
Expand Down Expand Up @@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature:
name,
desc=field.json_schema_extra.get("desc", "")
+ (
". Respond with a single JSON object. JSON Schema: "
+ json.dumps(type_.model_json_schema())
". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema())
),
format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)),
parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)),
Expand All @@ -152,13 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction:
for try_i in range(MAX_RETRIES):
result = self.predictor(**modified_kwargs, new_signature=signature)
errors = {}
parsed_results = {}
parsed_results = []
# Parse the outputs
for name, field in signature.output_fields.items():
for i, completion in enumerate(result.completions):
try:
value = getattr(result, name)
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed_results[name] = parser(value)
parsed = {}
for name, field in signature.output_fields.items():
value = completion[name]
parser = field.json_schema_extra.get("parser", lambda x: x)
completion[name] = parser(value)
parsed[name] = parser(value)
# Instantiate the actual signature with the parsed values.
# This allow pydantic to validate the fields defined in the signature.
_dummy = self.signature(**kwargs, **parsed)
parsed_results.append(parsed)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
Expand All @@ -168,11 +176,14 @@ def forward(self, **kwargs) -> dspy.Prediction:
continue # Only add examples to JSON objects
suffix, current_desc = current_desc[i:], current_desc[:i]
prefix = "You MUST use this format: "
if try_i + 1 < MAX_RETRIES \
and prefix not in current_desc \
and (example := self._make_example(field.annotation)):
if (
try_i + 1 < MAX_RETRIES
and prefix not in current_desc
and (example := self._make_example(field.annotation))
):
signature = signature.with_updated_fields(
name, desc=current_desc + "\n" + prefix + example + "\n" + suffix,
name,
desc=current_desc + "\n" + prefix + example + "\n" + suffix,
)
if errors:
# Add new fields for each error
Expand All @@ -187,11 +198,12 @@ def forward(self, **kwargs) -> dspy.Prediction:
)
else:
# If there are no errors, we return the parsed results
for name, value in parsed_results.items():
setattr(result, name, value)
return result
return Prediction.from_completions(
{key: [r[key] for r in parsed_results] for key in signature.output_fields}
)
raise ValueError(
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors,
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.",
errors,
)


Expand Down
20 changes: 10 additions & 10 deletions dspy/primitives/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@
class Prediction(Example):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

del self._demos
del self._input_keys

self._completions = None

@classmethod
def from_completions(cls, list_or_dict, signature=None):
obj = cls()
obj._completions = Completions(list_or_dict, signature=signature)
obj._store = {k: v[0] for k, v in obj._completions.items()}

return obj

def __repr__(self):
store_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._store.items())
store_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._store.items())

if self._completions is None or len(self._completions) == 1:
return f"Prediction(\n {store_repr}\n)"

num_completions = len(self._completions)
return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)"

def __str__(self):
return self.__repr__()

Expand Down Expand Up @@ -62,15 +62,15 @@ def __getitem__(self, key):
if isinstance(key, int):
if key < 0 or key >= len(self):
raise IndexError("Index out of range")

return Prediction(**{k: v[key] for k, v in self._completions.items()})

return self._completions[key]

def __getattr__(self, name):
if name in self._completions:
return self._completions[name]

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __len__(self):
Expand All @@ -82,7 +82,7 @@ def __contains__(self, key):
return key in self._completions

def __repr__(self):
items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items())
items_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._completions.items())
return f"Completions(\n {items_repr}\n)"

def __str__(self):
Expand Down
Loading