Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def infer_model(
from .cohere import CohereEmbeddingModel

return CohereEmbeddingModel(model_name, provider=provider)
elif model_kind == 'sentence-transformers':
from .sentence_transformers import SentenceTransformerEmbeddingModel

return SentenceTransformerEmbeddingModel(model_name)
else:
raise UserError(f'Unknown embeddings model: {model}') # pragma: no cover

Expand Down
115 changes: 115 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/sentence_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, overload

from pydantic_ai.embeddings.base import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings

# Optional dependency: sentence-transformers
try: # pragma: no cover - depends on optional install
from sentence_transformers import SentenceTransformer
except ImportError as _import_error: # pragma: no cover - depends on optional install
SentenceTransformer = None
raise ImportError(
'Please install `sentence-transformers` to use the Sentence-Transformers embeddings model, '
'you can use the `sentence-transformers` optional group — '
'pip install "pydantic-ai-slim[sentence-transformers]"'
) from _import_error


class SentenceTransformersEmbeddingSettings(EmbeddingSettings, total=False):
"""Settings used for a Sentence-Transformers embedding model request.

All fields are `sentence_transformers_`-prefixed so settings can be merged across providers safely.
"""

# Device to run inference on, e.g. "cpu", "cuda", "cuda:0", "mps".
sentence_transformers_device: str

# Whether to L2-normalize embeddings. Mirrors `normalize_embeddings` in SentenceTransformer.encode.
sentence_transformers_normalize_embeddings: bool

# Batch size to use during encoding.
sentence_transformers_batch_size: int


@dataclass(init=False)
class SentenceTransformerEmbeddingModel(EmbeddingModel):
"""Local embeddings using `sentence-transformers` models.

Example models include "all-MiniLM-L6-v2" and many others hosted on Hugging Face.
"""

_model_name: str = field(repr=False)
_model: Any = field(repr=False)

def __init__(self, model_name: str, *, settings: EmbeddingSettings | None = None) -> None:
"""Initialize a Sentence-Transformers embedding model.

Args:
model_name: The model name or local path to load with `SentenceTransformer`.
settings: Model-specific settings that will be used as defaults for this model.
"""
self._model_name = model_name
if SentenceTransformer is None: # pragma: no cover - depends on optional install
raise ImportError(
'Please install `sentence-transformers` to use this embeddings model, '
'you can use the `sentence-transformers` optional group — '
'pip install "pydantic-ai-slim[sentence-transformers]"'
)
# Defer device selection to encode() where we can override via settings
self._model = SentenceTransformer(model_name)

super().__init__(settings=settings)

@property
def base_url(self) -> str | None:
"""No base URL — runs locally."""
return None

@property
def model_name(self) -> str:
"""The embedding model name."""
return self._model_name

@property
def system(self) -> str:
"""The embedding model provider/system identifier."""
return 'sentence-transformers'

@overload
async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]: ...

@overload
async def embed(
self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None
) -> list[list[float]]: ...

async def embed(
self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
) -> list[float] | list[list[float]]:
docs, is_single_document, settings = self.prepare_embed(documents, settings)
embeddings = await self._embed(docs, settings)
return embeddings[0] if is_single_document else embeddings

async def _embed(
self, documents: Sequence[str], settings: SentenceTransformersEmbeddingSettings
) -> list[list[float]]:
device = settings.get('sentence_transformers_device', None)
normalize = settings.get('sentence_transformers_normalize_embeddings', False)
batch_size = settings.get('sentence_transformers_batch_size')

encode_kwargs: dict[str, Any] = {
'show_progress_bar': False,
'convert_to_numpy': True,
'convert_to_tensor': False,
'device': device,
'normalize_embeddings': normalize,
}
if batch_size is not None:
encode_kwargs['batch_size'] = batch_size

np_embeddings = self._model.encode(list(documents), **encode_kwargs)
return np_embeddings.tolist()
1 change: 1 addition & 0 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ groq = ["groq>=0.25.0"]
mistral = ["mistralai>=1.9.10"]
bedrock = ["boto3>=1.40.14"]
huggingface = ["huggingface-hub[inference]>=0.33.5"]
sentence-transformers = ["sentence-transformers"]
outlines-transformers = ["outlines[transformers]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')", "transformers>=4.0.0", "pillow", "torch; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
outlines-llamacpp = ["outlines[llamacpp]>=1.0.0, <1.3.0"]
outlines-mlxlm = ["outlines[mlxlm]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
Expand Down
Loading