-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Add dspy.Embedding
#1735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add dspy.Embedding
#1735
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,15 @@ | ||
| from .lm import LM | ||
| from .base_lm import BaseLM, inspect_history | ||
| from .embedding import Embedding | ||
| import litellm | ||
| import os | ||
| from pathlib import Path | ||
| from litellm.caching import Cache | ||
|
|
||
| DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") | ||
| litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") | ||
| litellm.telemetry = False | ||
|
|
||
| if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: | ||
| # accessed at run time by litellm; i.e., fine to keep after import | ||
| os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| import litellm | ||
| import numpy as np | ||
|
|
||
|
|
||
| class Embedding: | ||
| """DSPy embedding class. | ||
|
|
||
| The class for computing embeddings for text inputs. This class provides a unified interface for both: | ||
|
|
||
| 1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration | ||
| 2. Custom embedding functions that you provide | ||
|
|
||
| For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use | ||
| litellm to handle the API calls and caching. | ||
|
|
||
| For custom embedding models, pass a callable function that: | ||
| - Takes a list of strings as input. | ||
| - Returns embeddings as either: | ||
| - A 2D numpy array of float32 values | ||
| - A 2D list of float32 values | ||
| - Each row should represent one embedding vector | ||
|
|
||
| Args: | ||
| model: The embedding model to use. This can be either a string (representing the name of the hosted embedding | ||
| model, must be an embedding model supported by litellm) or a callable that represents a custom embedding | ||
| model. | ||
|
|
||
| Examples: | ||
| Example 1: Using a hosted model. | ||
|
|
||
| ```python | ||
| import dspy | ||
|
|
||
| embedder = dspy.Embedding("openai/text-embedding-3-small") | ||
| embeddings = embedder(["hello", "world"]) | ||
|
|
||
| assert embeddings.shape == (2, 1536) | ||
| ``` | ||
|
|
||
| Example 2: Using a custom function. | ||
|
|
||
| ```python | ||
| import dspy | ||
|
|
||
| def my_embedder(texts): | ||
| return np.random.rand(len(texts), 10) | ||
|
|
||
| embedder = dspy.Embedding(my_embedder) | ||
| embeddings = embedder(["hello", "world"]) | ||
|
|
||
| assert embeddings.shape == (2, 10) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__(self, model): | ||
| self.model = model | ||
|
|
||
| def __call__(self, inputs, caching=True, **kwargs): | ||
| """Compute embeddings for the given inputs. | ||
|
|
||
| Args: | ||
| inputs: The inputs to compute embeddings for, can be a single string or a list of strings. | ||
| caching: Whether to cache the embedding response, only valid when using a hosted embedding model. | ||
| kwargs: Additional keyword arguments to pass to the embedding model. | ||
|
|
||
| Returns: | ||
| A 2-D numpy array of embeddings, one embedding per row. | ||
| """ | ||
| if isinstance(inputs, str): | ||
| inputs = [inputs] | ||
| if isinstance(self.model, str): | ||
| embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs) | ||
| return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32) | ||
| elif callable(self.model): | ||
| return np.array(self.model(inputs, **kwargs), dtype=np.float32) | ||
| else: | ||
| raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import pytest | ||
| from unittest.mock import Mock, patch | ||
| import numpy as np | ||
|
|
||
| from dspy.clients.embedding import Embedding | ||
|
|
||
|
|
||
| # Mock response format similar to litellm's embedding response. | ||
| class MockEmbeddingResponse: | ||
| def __init__(self, embeddings): | ||
| self.data = [{"embedding": emb} for emb in embeddings] | ||
| self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} | ||
| self.model = "mock_model" | ||
| self.object = "list" | ||
|
|
||
|
|
||
| def test_litellm_embedding(): | ||
| model = "text-embedding-ada-002" | ||
| inputs = ["hello", "world"] | ||
| mock_embeddings = [ | ||
| [0.1, 0.2, 0.3], # embedding for "hello" | ||
| [0.4, 0.5, 0.6], # embedding for "world" | ||
| ] | ||
|
|
||
| with patch("litellm.embedding") as mock_litellm: | ||
| # Configure mock to return proper response format. | ||
| mock_litellm.return_value = MockEmbeddingResponse(mock_embeddings) | ||
|
|
||
| # Create embedding instance and call it. | ||
| embedding = Embedding(model) | ||
| result = embedding(inputs) | ||
|
|
||
| # Verify litellm was called with correct parameters. | ||
| mock_litellm.assert_called_once_with(model=model, input=inputs, caching=True) | ||
|
|
||
| assert len(result) == len(inputs) | ||
| np.testing.assert_allclose(result, mock_embeddings) | ||
|
|
||
|
|
||
| def test_callable_embedding(): | ||
| inputs = ["hello", "world", "test"] | ||
|
|
||
| expected_embeddings = [ | ||
| [0.1, 0.2, 0.3], # embedding for "hello" | ||
| [0.4, 0.5, 0.6], # embedding for "world" | ||
| [0.7, 0.8, 0.9], # embedding for "test" | ||
| ] | ||
|
|
||
| def mock_embedding_fn(texts): | ||
| # Simple callable that returns random embeddings. | ||
| return expected_embeddings | ||
|
|
||
| # Create embedding instance with callable | ||
| embedding = Embedding(mock_embedding_fn) | ||
| result = embedding(inputs) | ||
|
|
||
| np.testing.assert_allclose(result, expected_embeddings) | ||
|
|
||
|
|
||
| def test_invalid_model_type(): | ||
| # Test that invalid model type raises ValueError | ||
| with pytest.raises(ValueError): | ||
| embedding = Embedding(123) # Invalid model type | ||
| embedding(["test"]) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this needs to be done before LiteLLM is imported anywhere in DSPy, for it to have an effect?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I searched their code, and this env var is read at runtime: https://github.com/BerriAI/litellm/blob/5652c375b3e22bab6704e93058c868620c72d6ee/litellm/__init__.py#L309, so our current order should be okay.