diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 4dffc17e1c..e4d3b7b4d8 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -4,7 +4,7 @@
 import threading
 import uuid
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Callable
 
 import litellm
 import ujson
@@ -84,6 +84,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
+        callable_kwargs = {}
+        for k, v in list(kwargs.items()):
+            if isinstance(v, Callable):
+                callable_kwargs[k] = kwargs.pop(k)
 
         # Make the request and handle LRU & disk caching.
         if self.model_type == "chat":
@@ -94,6 +98,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
 
         response = completion(
             request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
             num_retries=self.num_retries,
+            **callable_kwargs,
         )
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -213,16 +218,18 @@ def copy(self, **kwargs):
 
 
 @functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request, num_retries: int):
+def cached_litellm_completion(request, num_retries: int, **kwargs):
     return litellm_completion(
         request,
         cache={"no-cache": False, "no-store": False},
         num_retries=num_retries,
+        **kwargs,
     )
 
 
-def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
+def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}, **kwargs):
+    req_kwargs = ujson.loads(request)
+    kwargs = {**req_kwargs, **kwargs}
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
@@ -231,17 +238,19 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s
 
 
 @functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request, num_retries: int):
+def cached_litellm_text_completion(request, num_retries: int, **kwargs):
     return litellm_text_completion(
         request,
         num_retries=num_retries,
         cache={"no-cache": False, "no-store": False},
+        **kwargs,
     )
 
 
-def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
-
+def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}, **kwargs):
+    req_kwargs = ujson.loads(request)
+    kwargs = {**req_kwargs, **kwargs}
+
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
     model = kwargs.pop("model").split("/", 1)
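
For context on the change above: callable kwargs cannot pass through ujson.dumps, so __call__ now strips them out of the request before it is serialized for caching and forwards them to litellm.completion as ordinary keyword arguments. Below is a minimal usage sketch of the kind of call this enables; the token provider and model name are assumptions for illustration, not taken from the patch:

import dspy

def azure_ad_token_provider():
    # Hypothetical provider: fetch or refresh an Azure AD token here.
    return "<token>"

# Before this patch, a callable kwarg like this would reach ujson.dumps
# and fail; now __call__ pops it into callable_kwargs and passes it
# straight to litellm.completion, keeping the request JSON-serializable.
lm = dspy.LM(
    "azure/gpt-4o-mini",  # assumed model name
    azure_ad_token_provider=azure_ad_token_provider,
)
dspy.configure(lm=lm)
lm("Hello!")

One consequence of routing callables around the serialized request: they still reach the functools.lru_cache-wrapped helpers as keyword arguments, so they must be hashable. Plain functions and bound methods are (by identity), but a fresh lambda constructed per call produces a new cache key each time and defeats the LRU cache.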