dspy/clients/lm.py (76 changes: 41 additions & 35 deletions)
@@ -1,51 +1,45 @@
-import os
-import uuid
-import ujson
 import functools
-from pathlib import Path
+import os
+import uuid
 from datetime import datetime
+from pathlib import Path
 
-try:
-    import warnings
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=UserWarning)
-        if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
-            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
-        import litellm
-        litellm.telemetry = False
+import litellm
+import ujson
+from litellm.caching import Cache
 
-        from litellm.caching import Cache
-        disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
-        litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
+litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+litellm.telemetry = False
 
-except ImportError:
-    class LitellmPlaceholder:
-        def __getattr__(self, _): raise ImportError("The LiteLLM package is not installed. Run `pip install litellm`.")
+if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
 
-    litellm = LitellmPlaceholder()
 
 class LM:
-    def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
+    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs):
         self.model = model
         self.model_type = model_type
         self.cache = cache
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
 
         if "o1-" in model:
-            assert max_tokens >= 5000 and temperature == 1.0, \
-                "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+            assert (
+                max_tokens >= 5000 and temperature == 1.0
+            ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
 
     def __call__(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat": completion = cached_litellm_completion if cache else litellm_completion
-        else: completion = cached_litellm_text_completion if cache else litellm_text_completion
+        if self.model_type == "chat":
+            completion = cached_litellm_completion if cache else litellm_completion
+        else:
+            completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -63,8 +57,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
+
         return outputs
 
     def inspect_history(self, n: int = 1):
         _inspect_history(self, n)
 
@@ -73,14 +68,17 @@ def inspect_history(self, n: int = 1):
 def cached_litellm_completion(request):
     return litellm_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
     return litellm.completion(cache=cache, **kwargs)
 
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request):
     return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
 
@@ -93,32 +91,40 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
     api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = '\n\n'.join([x['content'] for x in kwargs.pop("messages")] + ['BEGIN RESPONSE:'])
+    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    return litellm.text_completion(cache=cache, model=f'text-completion-openai/{model}', api_key=api_key,
-                                   api_base=api_base, prompt=prompt, **kwargs)
+    return litellm.text_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        **kwargs,
+    )
 
+
 def _green(text: str, end: str = "\n"):
     return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
 
+
 def _red(text: str, end: str = "\n"):
     return "\x1b[31m" + str(text) + "\x1b[0m" + end
 
 
 def _inspect_history(lm, n: int = 1):
     """Prints the last n prompts and their completions."""
 
     for item in lm.history[-n:]:
-        messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
+        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
         outputs = item["outputs"]
         timestamp = item.get("timestamp", "Unknown time")
 
         print("\n\n\n")
         print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
 
         for msg in messages:
             print(_red(f"{msg['role'].capitalize()} message:"))
-            print(msg['content'].strip())
+            print(msg["content"].strip())
             print("\n")
 
         print(_red("Response:"))
@@ -127,5 +133,5 @@ def _inspect_history(lm, n: int = 1):
         if len(outputs) > 1:
             choices_text = f" \t (and {len(outputs)-1} other completions)"
             print(_red(choices_text, end=""))
-        print("\n\n\n")
 
+    print("\n\n\n")