181 changes: 90 additions & 91 deletions dspy/retrieve/pinecone_rm.py
@@ -3,6 +3,7 @@
Author: Dhar Rawal (@drawal1)
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Union

import backoff
@@ -13,9 +14,6 @@
try:
import pinecone
except ImportError:
pinecone = None

if pinecone is None:
raise ImportError(
"The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
)
@@ -33,6 +31,64 @@
except Exception:
ERRORS = (openai.RateLimitError, openai.APIError)


class CloudEmbedProvider(ABC):
def __init__ (self, model, api_key=None):
self.model = model
self.api_key = api_key

@abstractmethod
def get_embeddings(self, queries: List[str]) -> List[List[float]]:
pass

class OpenAIEmbed(CloudEmbedProvider):
def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None, org: Optional[str]=None):
super().__init__(model, api_key)
self.org = org
if self.api_key:
openai.api_key = self.api_key
if self.org:
openai.organization = org


@backoff.on_exception(
backoff.expo,
ERRORS,
max_time=15,
)
def get_embeddings(self, queries: List[str]) -> List[List[float]]:
Review comment (Collaborator): could the caching functionality included in this PR be used here? I realize we probably need an abstracted Retriever class to handle this but for now, feel free to check if this suffices.
if OPENAI_LEGACY:
embedding = openai.Embedding.create(
input=queries, model=self.model,
)
else:
embedding = openai.embeddings.create(
input=queries, model=self.model,
).model_dump()
return [embedding["embedding"] for embedding in embedding["data"]]
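
On the reviewer's caching question above: the PR's caching helper is not visible in this file, so the following is only a rough sketch of how per-query memoization could wrap the embedding call, using `functools.lru_cache` as a stand-in. The function names below are hypothetical, not part of the PR.

```python
# Rough sketch only: lru_cache stands in for the PR's caching helper, which is
# not shown in this file. Names here are hypothetical.
from functools import lru_cache
from typing import List, Tuple

import openai


@lru_cache(maxsize=4096)
def _cached_embedding(query: str, model: str) -> Tuple[float, ...]:
    # Embed one query at a time so the (query, model) pair is a hashable cache key.
    response = openai.embeddings.create(input=[query], model=model).model_dump()
    return tuple(response["data"][0]["embedding"])


def get_cached_embeddings(queries: List[str], model: str = "text-embedding-ada-002") -> List[List[float]]:
    # Repeated queries hit the cache instead of calling the OpenAI API again.
    return [list(_cached_embedding(q, model)) for q in queries]
```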

class CohereEmbed(CloudEmbedProvider):
def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
try:
import cohere
except ImportError:
raise ImportError(
"The cohere library is required to use CohereEmbed. Install it with `pip install cohere`",
)
super().__init__(model, api_key)
self.client = cohere.Client(api_key)

@backoff.on_exception(
backoff.expo,
ERRORS,
max_time=15,
)
def get_embeddings(self, queries: List[str]) -> List[List[float]]:
embeddings = self.client.embed(texts=queries, model=self.model).embeddings
return embeddings
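
For illustration, both providers expose the same `get_embeddings` interface; a minimal usage sketch with placeholder API keys (keys may instead be read from environment variables, depending on the client configuration):

```python
# Illustrative only; the API key strings are placeholders.
openai_embedder = OpenAIEmbed()                                # text-embedding-ada-002 by default
cohere_embedder = CohereEmbed(api_key="YOUR_COHERE_API_KEY")   # multilingual-22-12 by default

vectors = openai_embedder.get_embeddings(["What is DSPy?", "What is Pinecone?"])
print(len(vectors), len(vectors[0]))  # 2 queries -> 2 embedding vectors
```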



class PineconeRM(dspy.Retrieve):
"""
A retrieval module that uses Pinecone to return the top passages for a given query.
@@ -43,11 +99,8 @@ class PineconeRM(dspy.Retrieve):
Args:
pinecone_index_name (str): The name of the Pinecone index to query against.
pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
pinecone_env (str, optional): The Pinecone environment. Defaults to None.
local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
openai_org (str, optional): The organization for OpenAI. Defaults to None.
cloud_emded_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
k (int, optional): The number of top passages to retrieve. Defaults to 3.

Returns:
@@ -57,6 +110,7 @@ class PineconeRM(dspy.Retrieve):
Below is a code snippet that shows how to use this as the default retriever:
```python
llm = dspy.OpenAI(model="gpt-3.5-turbo")
retriever_model = PineconeRM(index_name, cloud_emded_provider=OpenAIEmbed())
retriever_model = PineconeRM(openai.api_key)
dspy.settings.configure(lm=llm, rm=retriever_model)
```
@@ -71,11 +125,8 @@ def __init__(
self,
pinecone_index_name: str,
pinecone_api_key: Optional[str] = None,
pinecone_env: Optional[str] = None,
local_embed_model: Optional[str] = None,
openai_embed_model: Optional[str] = "text-embedding-ada-002",
openai_api_key: Optional[str] = None,
openai_org: Optional[str] = None,
cloud_emded_provider: Optional[CloudEmbedProvider] = None,
k: int = 3,
):
if local_embed_model is not None:
@@ -95,69 +146,25 @@ def __init__(
'mps' if torch.backends.mps.is_available()
else 'cpu',
)
elif openai_embed_model is not None:
self._openai_embed_model = openai_embed_model

elif cloud_emded_provider is not None:
self.use_local_model = False
# If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
if openai_api_key:
openai.api_key = openai_api_key
if openai_org:
openai.organization = openai_org
self.cloud_emded_provider = cloud_emded_provider

else:
raise ValueError(
"Either local_embed_model or openai_embed_model must be provided.",
"Either local_embed_model or cloud_embed_provider must be provided.",
)

self._pinecone_index = self._init_pinecone(
pinecone_index_name, pinecone_api_key, pinecone_env,
)

super().__init__(k=k)

def _init_pinecone(
self,
index_name: str,
api_key: Optional[str] = None,
environment: Optional[str] = None,
dimension: Optional[int] = None,
distance_metric: Optional[str] = None,
) -> pinecone.Index:
"""Initialize pinecone and return the loaded index.

Args:
index_name (str): The name of the index to load. If the index does not exist, it will be created.
api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
environment (str, optional): The environment (e.g. `us-west1-gcp` or `gcp-starter`). Defaults to env PINECONE_ENVIRONMENT.

Raises:
ValueError: If api_key or environment is not provided and not set as an environment variable.

Returns:
pinecone.Index: The loaded index.
"""
if pinecone_api_key is None:
self.pinecone_client = pinecone.Pinecone()
else:
self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)

# Pinecone init overrides default if kwargs are present, so we need to exclude if None
kwargs = {}
if api_key:
kwargs["api_key"] = api_key
if environment:
kwargs["environment"] = environment
pinecone.init(**kwargs)

active_indexes = pinecone.list_indexes()
if index_name not in active_indexes:
if dimension is None and distance_metric is None:
raise ValueError(
"dimension and distance_metric must be provided since the index provided does not exist.",
)
self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)

pinecone.create_index(
name=index_name,
dimension=dimension,
metric=distance_metric,
)
super().__init__(k=k)

return pinecone.Index(index_name)
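
The removed `_init_pinecone` helper above used the v2-style `pinecone.init(...)` / `pinecone.Index(...)` flow; the new code constructs a `pinecone.Pinecone` client instead. A minimal standalone sketch of that pattern, with a placeholder key and index name:

```python
# Minimal sketch of the Pinecone client pattern the new code adopts.
# The API key may be passed explicitly or read from PINECONE_API_KEY.
import pinecone

client = pinecone.Pinecone(api_key="YOUR_PINECONE_API_KEY")
index = client.Index("my-index")        # assumes the index already exists

stats = index.describe_index_stats()    # quick sanity check on the connection
print(stats)
```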

def _mean_pooling(
self,
@@ -175,11 +182,7 @@ def _mean_pooling(
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

@backoff.on_exception(
backoff.expo,
ERRORS,
max_time=15,
)

def _get_embeddings(
self,
queries: List[str],
@@ -192,24 +195,16 @@ def _get_embeddings(
Returns:
List[List[float]]: List of embeddings corresponding to each query.
"""
if not self.use_local_model:
return self.cloud_emded_provider.get_embeddings(queries)

try:
import torch
except ImportError as exc:
raise ModuleNotFoundError(
"You need to install torch to use a local embedding model with PineconeRM.",
) from exc

if not self.use_local_model:
if OPENAI_LEGACY:
embedding = openai.Embedding.create(
input=queries, model=self._openai_embed_model,
)
else:
embedding = openai.embeddings.create(
input=queries, model=self._openai_embed_model,
).model_dump()
return [embedding["embedding"] for embedding in embedding["data"]]

# Use local model
encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
with torch.no_grad():
@@ -222,51 +217,55 @@ def _get_embeddings(
# we need a pooling strategy to get a single vector representation of the input
# so the default is to take the mean of the hidden states

def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
"""Search with pinecone for self.k top passages for query

Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k

Returns:
dspy.Prediction: An object containing the retrieved passages.
"""
k = k if k is not None else self.k
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
queries = [q for q in queries if q] # Filter empty queries
embeddings = self._get_embeddings(queries)

# For single query, just look up the top k passages
if len(queries) == 1:
results_dict = self._pinecone_index.query(
embeddings[0], top_k=self.k, include_metadata=True,
vector=embeddings[0], top_k=k, include_metadata=True,
)

# Sort results by score
sorted_results = sorted(
results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True,
)

passages = [result["metadata"]["text"] for result in sorted_results]
passages = [dotdict({"long_text": passage for passage in passages})]
return dspy.Prediction(passages=passages)
passages = [dotdict({"long_text": passage}) for passage in passages]
return passages

# For multiple queries, query each and return the highest scoring passages
# If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x
passage_scores = {}
for embedding in embeddings:
results_dict = self._pinecone_index.query(
embedding, top_k=self.k * 3, include_metadata=True,
vector=embedding, top_k=k * 3, include_metadata=True,
)
for result in results_dict["matches"]:
passage_scores[result["metadata"]["text"]] = (
passage_scores.get(result["metadata"]["text"], 0.0)
+ result["score"]
)

sorted_passages = sorted(
passage_scores.items(), key=lambda x: x[1], reverse=True,
)[: self.k]
return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
)[: k]

passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]
return passages
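
Putting the changes together, a hedged end-to-end sketch of the updated retriever, assuming the Pinecone index already exists with embeddings of matching dimension and that API keys are set in the environment; the index name and k values are placeholders:

```python
# Illustrative usage of the updated PineconeRM; "my-index" is a placeholder.
import dspy

llm = dspy.OpenAI(model="gpt-3.5-turbo")
retriever_model = PineconeRM(
    pinecone_index_name="my-index",
    cloud_emded_provider=OpenAIEmbed(),  # or CohereEmbed(api_key="YOUR_COHERE_API_KEY")
    k=3,
)
dspy.settings.configure(lm=llm, rm=retriever_model)

# forward() now accepts an optional per-call k and returns a list of dotdicts.
passages = retriever_model("What is DSPy?", k=5)
for passage in passages:
    print(passage.long_text)
```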