Add dspy.Embedding #1735
Conversation
Force-pushed from d33a487 to 7c51351.
```python
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
litellm.telemetry = False

if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
```
I think this needs to be done before LiteLLM is imported anywhere in DSPy, for it to have an effect?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I searched their code, and this env var is read at runtime: https://github.com/BerriAI/litellm/blob/5652c375b3e22bab6704e93058c868620c72d6ee/litellm/__init__.py#L309, so our current order should be okay.
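Since the variable is read at call time rather than at import time, setting it after LiteLLM is imported still takes effect. A minimal illustration of that pattern (the function here is a made-up stand-in, not LiteLLM's actual code):

```python
import os

def read_flag_at_runtime():
    # Hypothetical stand-in for litellm's check: the env var is read
    # when this function runs, not when the module is imported.
    return os.environ.get("LITELLM_LOCAL_MODEL_COST_MAP", "False") == "True"

# Setting the variable after "import" time still changes the result.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
print(read_flag_at_runtime())  # True
```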
dspy/clients/embedding.py (outdated)
```python
    kwargs: Additional keyword arguments to pass to the embedding model.

Returns:
    A list of embeddings, one for each input, in the same order as the inputs. Or the output of the custom
```
Can we ensure the output of this is a numpy tensor or something? Both for litellm and for callables.
sounds good!
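One way to guarantee a numpy array for both the litellm path and custom callables is to coerce whatever the backend returns. A sketch of that idea (the helper name is illustrative, not the actual DSPy code):

```python
import numpy as np

def to_embedding_array(embeddings):
    # Coerce a list of lists, list of arrays, or a single vector
    # into a 2-D float32 numpy array, one row per input.
    arr = np.asarray(embeddings, dtype=np.float32)
    if arr.ndim == 1:
        arr = arr[None, :]  # promote a single vector to shape (1, dim)
    return arr

print(to_embedding_array([[0.1, 0.2], [0.3, 0.4]]).shape)  # (2, 2)
```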
An ideal version of this PR would involve improving the docs at this page: https://dspy-docs.vercel.app/quick-start/getting-started-02/ (see the second cell, or see below):

```python
import functools
import torch
import ujson  # needed for ujson.loads below
from litellm import embedding as Embed

with open("test_collection.jsonl") as f:
    corpus = [ujson.loads(line) for line in f]

index = torch.load('index.pt', weights_only=True)
max_characters = 4000  # >98th percentile of document lengths

@functools.lru_cache(maxsize=None)
def search(query, k=5):
    query_embedding = torch.tensor(Embed(input=query, model="text-embedding-3-small").data[0]['embedding'])
    topk_scores, topk_indices = torch.matmul(index, query_embedding).topk(k)
    topK = [dict(score=score.item(), **corpus[idx]) for idx, score in zip(topk_indices, topk_scores)]
    return [doc['text'][:max_characters] for doc in topK]
```

I'd love to get the same functionality but without that complexity...
Force-pushed from 76f162b to 40bfd8b.
Very simple. dspy.Embedding supports litellm-backed embedding models as well as custom callables passed to dspy.Embedding; for a custom callable, the output is just the custom callable's output. Added unit tests for both scenarios.
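A minimal sketch of the custom-callable path described above (the class shape and names here are assumptions for illustration, not the merged implementation):

```python
class EmbeddingSketch:
    """Illustrative only: if the model is a callable, invoke it and
    return its output unchanged, as the PR description says."""

    def __init__(self, model):
        self.model = model

    def __call__(self, inputs, **kwargs):
        if callable(self.model):
            # Custom callable: its output is passed through as-is.
            return self.model(inputs, **kwargs)
        # The litellm-backed model-name path is omitted from this sketch.
        raise NotImplementedError

embedder = EmbeddingSketch(lambda inputs: [[0.0] * 3 for _ in inputs])
print(embedder(["a", "b"]))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```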
Confirmed that the cache works: