Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dsp/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
langchain_history=[],
experimental=False,
backoff_time=10,
callbacks=[],
)


Expand Down
14 changes: 14 additions & 0 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
from dspy.utils.callback import with_callbacks


class Adapter:
def __init__(self, callbacks=None):
self.callbacks = callbacks or []

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
from dspy.utils.callback import with_callbacks

# Decorate format() and parse() method with with_callbacks
cls.format = with_callbacks(cls.format)
cls.parse = with_callbacks(cls.parse)

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
Expand Down
4 changes: 4 additions & 0 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
execute_finetune_job,
)

from dspy.utils.callback import with_callbacks
import litellm
from litellm.caching import Cache

Expand All @@ -35,6 +36,7 @@ def __init__(
max_tokens=1000,
cache=True,
launch_kwargs=None,
callbacks=None,
**kwargs
):
# Remember to update LM.copy() if you modify the constructor!
Expand All @@ -44,6 +46,7 @@ def __init__(
self.launch_kwargs = launch_kwargs or {}
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
self.history = []
self.callbacks = callbacks or []

# TODO: Arbitrary model strings could include the substring "o1-". We
# should find a more robust way to check for the "o1-" family models.
Expand All @@ -52,6 +55,7 @@ def __init__(
max_tokens >= 5000 and temperature == 1.0
), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"

@with_callbacks
def __call__(self, prompt=None, messages=None, **kwargs):
# Build the request.
cache = kwargs.pop("cache", self.cache)
Expand Down
5 changes: 4 additions & 1 deletion dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dspy.primitives.prediction import Prediction
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature, signature_to_template
from dspy.utils.callback import with_callbacks


@lru_cache(maxsize=None)
Expand All @@ -17,10 +18,11 @@ def warn_once(msg: str):


class Predict(Module, Parameter):
def __init__(self, signature, _parse_values=True, **config):
def __init__(self, signature, _parse_values=True, callbacks=None, **config):
self.stage = random.randbytes(8).hex()
self.signature = ensure_signature(signature)
self.config = config
self.callbacks = callbacks or []
self._parse_values = _parse_values
self.reset()

Expand Down Expand Up @@ -114,6 +116,7 @@ def _load_state_legacy(self, state):
*_, last_key = self.extended_signature.fields.keys()
self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix)

@with_callbacks
def __call__(self, **kwargs):
return self.forward(**kwargs)

Expand Down
5 changes: 4 additions & 1 deletion dspy/primitives/program.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dspy.utils.callback import with_callbacks
import magicattr

import dspy
Expand All @@ -13,9 +14,11 @@ class Module(BaseModule, metaclass=ProgramMeta):
def _base_init(self):
self._compiled = False

def __init__(self):
def __init__(self, callbacks=None):
self.callbacks = callbacks or []
self._compiled = False

@with_callbacks
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

Expand Down
5 changes: 4 additions & 1 deletion dspy/retrieve/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dsp
from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.utils.callback import with_callbacks


def single_query_passage(passages):
Expand All @@ -21,9 +22,10 @@ class Retrieve(Parameter):
input_variable = "query"
desc = "takes a search query and returns one or more potentially relevant passages from a corpus"

def __init__(self, k=3):
def __init__(self, k=3, callbacks=None):
self.stage = random.randbytes(8).hex()
self.k = k
self.callbacks = callbacks or []

def reset(self):
pass
Expand All @@ -37,6 +39,7 @@ def load_state(self, state):
for name, value in state.items():
setattr(self, name, value)

@with_callbacks
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions dspy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .callback import *
from .dummies import *
from .logging import *
Loading
Loading