Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.format = with_callbacks(cls.format)
cls.parse = with_callbacks(cls.parse)

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
inputs_ = self.format(signature, demos, inputs)
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)

Expand All @@ -27,7 +27,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
if isinstance(output, dict):
output, output_logprobs = output["text"], output["logprobs"]

value = self.parse(signature, output, _parse_values=_parse_values)
value = self.parse(signature, output)

assert set(value.keys()) == set(signature.output_fields.keys()), \
f"Expected {signature.output_fields.keys()} but got {value.keys()}"
Expand All @@ -41,16 +41,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):

except Exception as e:
from .json_adapter import JSONAdapter
if _parse_values and not isinstance(self, JSONAdapter):
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
if not isinstance(self, JSONAdapter):
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
raise e

@abstractmethod
def format(self, signature, demos, inputs):
raise NotImplementedError

@abstractmethod
def parse(self, signature, completion, _parse_values):
def parse(self, signature, completion):
raise NotImplementedError

def format_finetune_data(self, signature, demos, inputs, outputs):
Expand Down
4 changes: 2 additions & 2 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
messages.append(format_turn(signature, inputs, role="user"))
return messages

def parse(self, signature, completion, _parse_values=True):
def parse(self, signature, completion):
sections = [(None, [])]

for line in completion.splitlines():
Expand All @@ -74,7 +74,7 @@ def parse(self, signature, completion, _parse_values=True):
for k, v in sections:
if (k not in fields) and (k in signature.output_fields):
try:
fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
fields[k] = parse_value(v, signature.output_fields[k].annotation)
except Exception as e:
raise ValueError(
f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
Expand Down
6 changes: 3 additions & 3 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class JSONAdapter(Adapter):
def __init__(self):
pass

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)

Expand All @@ -58,7 +58,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
values = []

for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
value = self.parse(signature, output)
assert set(value.keys()) == set(
signature.output_fields.keys()
), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
Expand Down Expand Up @@ -90,7 +90,7 @@ def format(self, signature, demos, inputs):

return messages

def parse(self, signature, completion, _parse_values=True):
def parse(self, signature, completion):
fields = json_repair.loads(completion)
fields = {k: v for k, v in fields.items() if k in signature.output_fields}

Expand Down
7 changes: 0 additions & 7 deletions dspy/dsp/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,11 @@
adapter=None,
rm=None,
branch_idx=0,
reranker=None,
compiled_lm=None,
force_reuse_cached_compilation=False,
compiling=False,
skip_logprobs=False,
trace=[],
release=0,
bypass_assert=False,
bypass_suggest=False,
assert_failures=0,
suggest_failures=0,
langchain_history=[],
experimental=False,
backoff_time=10,
callbacks=[],
Expand Down
27 changes: 5 additions & 22 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import logging
import random
from functools import lru_cache

from pydantic import BaseModel

Expand All @@ -12,18 +10,12 @@
from dspy.utils.callback import with_callbacks


@lru_cache(maxsize=None)
def warn_once(msg: str):
logging.warning(msg)


class Predict(Module, Parameter):
def __init__(self, signature, _parse_values=True, callbacks=None, **config):
def __init__(self, signature, callbacks=None, **config):
self.stage = random.randbytes(8).hex()
self.signature = ensure_signature(signature)
self.config = config
self.callbacks = callbacks or []
self._parse_values = _parse_values
self.reset()

def reset(self):
Expand Down Expand Up @@ -80,7 +72,7 @@ def load_state(self, state):
self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state: # legacy, up to and including 2.5, for CoT.
self.signature = self.signature.load_state(state["extended_signature"])
raise NotImplementedError("Loading extended_signature is no longer supported in DSPy 2.6+")

return self

Expand All @@ -90,7 +82,6 @@ def __call__(self, **kwargs):

def forward(self, **kwargs):
import dspy
assert not dspy.settings.compiling, "It's no longer ever the case that .compiling is True"

# Extract the three privileged keyword arguments.
assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument."
Expand All @@ -115,7 +106,9 @@ def forward(self, **kwargs):
missing = [k for k in signature.input_fields if k not in kwargs]
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")

completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values)
import dspy
adapter = dspy.settings.adapter or dspy.ChatAdapter()
completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)

pred = Prediction.from_completions(completions, signature=signature)

Expand All @@ -135,16 +128,6 @@ def __repr__(self):
return f"{self.__class__.__name__}({self.signature})"


def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
import dspy

adapter = dspy.settings.adapter or dspy.ChatAdapter()

return adapter(
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
)


# TODO: get some defaults during init from the context window?
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
# affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates.
Expand Down
6 changes: 1 addition & 5 deletions dspy/teleprompt/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,6 @@ def _train(self):
sample_size = max(0, sample_size)

raw_demos = rng.sample(raw_demos, sample_size)

if dspy.settings.release >= 20230928:
predictor.demos = raw_demos + augmented_demos
else:
predictor.demos = augmented_demos + raw_demos
predictor.demos = augmented_demos + raw_demos

return self.student
Loading