diff --git a/docs/retrieval_models_client.md b/docs/retrieval_models_client.md index 2632961b87..53dbc4d1db 100644 --- a/docs/retrieval_models_client.md +++ b/docs/retrieval_models_client.md @@ -9,6 +9,7 @@ This documentation provides an overview of the DSPy Retrieval Model Clients. | ColBERTv2 | [ColBERTv2 Section](#ColBERTv2) | | AzureCognitiveSearch | [AzureCognitiveSearch Section](#AzureCognitiveSearch) | | ChromadbRM | [ChromadbRM Section](#ChromadbRM) | +| PineconeRM | [PineconeRM Section](#PineconeRM) | ## ColBERTv2 @@ -138,6 +139,7 @@ ChromadbRM( ``` **Parameters:** + - `collection_name` (_str_): The name of the chromadb collection. - `persist_directory` (_str_): Path to the directory where chromadb data is persisted. - `embedding_function` (_Optional[EmbeddingFunction[Embeddable]]_, _optional_): The function used for embedding documents and queries. Defaults to `DefaultEmbeddingFunction()` if not specified. @@ -150,8 +152,68 @@ ChromadbRM( Search the chromadb collection for the top `k` passages matching the given query or queries, using embeddings generated via the specified `embedding_function`. **Parameters:** + +- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for. +- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization. + +**Returns:** + +- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with a `long_text` attribute. + +## PineconeRM + +### Quickstart with Cohere Embeddings + +PineconeRM is a retrieval module that uses Pinecone to return the top passages for a given query. It supports OpenAI and Cohere as cloud-based embedding providers, as well as local models such as all-mpnet-base-v2. This example showcases the use of Cohere embeddings. + +```python +from dspy.retrieve.pinecone_rm import PineconeRM, CohereEmbed + +retriever_model = PineconeRM( + pinecone_index_name="your_index_name", + cloud_emded_provider=CohereEmbed(), +) + +results = retriever_model("How does machine learning work?", k=5) + +for result in results: + print("Passage:", result.long_text, "\n") +``` + +### Constructor + +Initialize an instance of the PineconeRM class, choosing either a local embedding model or a cloud embedding provider. + +```python +PineconeRM( + pinecone_index_name: str, + pinecone_api_key: Optional[str] = None, + local_embed_model: Optional[str] = None, + cloud_emded_provider: Optional[CloudEmbedProvider] = None, + k: int = 3, +) +``` + +**Parameters:** + +- `pinecone_index_name` (_str_): The name of the Pinecone index to query against. +- `pinecone_api_key` (_Optional[str]_, _optional_): The Pinecone API key. Defaults to None. +- `local_embed_model` (_Optional[str]_, _optional_): The local embedding model to use. Defaults to None. +- `cloud_emded_provider` (_Optional[CloudEmbedProvider]_, _optional_): The cloud embedding provider to use. Defaults to None. +- `k` (_int_, _optional_): The number of top passages to retrieve. Defaults to 3. + +### Methods + +### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction` + +Searches the Pinecone index for the top `k` passages matching the given query or queries, using embeddings generated via the specified embedding provider. + +**Parameters:** + - `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for. - `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization. **Returns:** -- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with a `long_text` attribute. \ No newline at end of file + +- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with a `long_text` attribute. +
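+Once PineconeRM is configured as the default retrieval model, it can also be used through `dspy.Retrieve` inside a DSPy program, as in the module docstring and the accompanying notebook. The snippet below is a minimal illustrative sketch, not an additional API: the index name is a placeholder, the index is assumed to already be populated, and credentials are assumed to be configured for the Pinecone and Cohere clients (e.g., via environment variables).
+
+```python
+import dspy
+from dspy.retrieve.pinecone_rm import PineconeRM, CohereEmbed
+
+# Assumptions: "your_index_name" already exists and stores passages under the
+# "text" metadata field, and the Pinecone/Cohere credentials are picked up
+# from the environment.
+retriever_model = PineconeRM(
+    pinecone_index_name="your_index_name",
+    cloud_emded_provider=CohereEmbed(),
+)
+dspy.settings.configure(rm=retriever_model)
+
+# dspy.Retrieve delegates to the configured retrieval model's forward().
+retrieve = dspy.Retrieve(k=3)
+top_passages = retrieve("When was the first FIFA World Cup held?").passages
+
+for passage in top_passages:
+    print(passage, "\n")
+```
+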
diff --git a/dspy/retrieve/pinecone_rm.py b/dspy/retrieve/pinecone_rm.py index 65645f8cd7..321274d10a 100644 --- a/dspy/retrieve/pinecone_rm.py +++ b/dspy/retrieve/pinecone_rm.py @@ -3,17 +3,16 @@ Author: Dhar Rawal (@drawal1) """ +import os from dsp.utils import dotdict -from typing import Optional, List, Union +from typing import Optional, List, Union, Any import dspy import backoff +from abc import ABC, abstractmethod try: import pinecone except ImportError: - pinecone = None - -if pinecone is None: raise ImportError( "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`" ) @@ -30,6 +29,64 @@ except Exception: ERRORS = (openai.RateLimitError, openai.APIError) + +class CloudEmbedProvider(ABC): + def __init__(self, model, api_key=None): + self.model = model + self.api_key = api_key + + @abstractmethod + def get_embeddings(self, queries: List[str]) -> List[List[float]]: + pass + +class OpenAIEmbed(CloudEmbedProvider): + def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None, org: Optional[str]=None): + super().__init__(model, api_key) + self.org = org + if self.api_key: + openai.api_key = self.api_key + if self.org: + openai.organization = self.org + + + @backoff.on_exception( + backoff.expo, + ERRORS, + max_time=15, + ) + def get_embeddings(self, queries: List[str]) -> List[List[float]]: + if OPENAI_LEGACY: + response = openai.Embedding.create( + input=queries, model=self.model + ) + else: + response = openai.embeddings.create( + input=queries, model=self.model + ).model_dump() + return [embedding["embedding"] for embedding in response["data"]] + +class CohereEmbed(CloudEmbedProvider): + def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None): + try: + import cohere + except ImportError: + raise ImportError( + "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`" + ) + super().__init__(model, api_key) + self.client = cohere.Client(api_key) + + @backoff.on_exception( + backoff.expo, + ERRORS, + max_time=15, + ) + def get_embeddings(self, queries: List[str]) -> List[List[float]]: + embeddings = self.client.embed(texts=queries, model=self.model).embeddings + return embeddings + + + class PineconeRM(dspy.Retrieve): """ A retrieval module that uses Pinecone to return the top passages for a given query. @@ -40,11 +97,8 @@ class PineconeRM(dspy.Retrieve): Args: pinecone_index_name (str): The name of the Pinecone index to query against. pinecone_api_key (str, optional): The Pinecone API key. Defaults to None. - pinecone_env (str, optional): The Pinecone environment. Defaults to None. local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2". - openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002". - openai_api_key (str, optional): The API key for OpenAI. Defaults to None. - openai_org (str, optional): The organization for OpenAI. Defaults to None. + cloud_emded_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
k (int, optional): The number of top passages to retrieve. Defaults to 3. Returns: @@ -54,6 +108,7 @@ class PineconeRM(dspy.Retrieve): Below is a code snippet that shows how to use this as the default retriver: ```python llm = dspy.OpenAI(model="gpt-3.5-turbo") + retriever_model = PineconeRM(index_name, cloud_emded_provider=OpenAIEmbed()) retriever_model = PineconeRM(openai.api_key) dspy.settings.configure(lm=llm, rm=retriever_model) ``` @@ -68,11 +123,8 @@ def __init__( self, pinecone_index_name: str, pinecone_api_key: Optional[str] = None, - pinecone_env: Optional[str] = None, local_embed_model: Optional[str] = None, - openai_embed_model: Optional[str] = "text-embedding-ada-002", - openai_api_key: Optional[str] = None, - openai_org: Optional[str] = None, + cloud_emded_provider: Optional[CloudEmbedProvider] = None, k: int = 3, ): if local_embed_model is not None: @@ -92,69 +144,25 @@ def __init__( 'mps' if torch.backends.mps.is_available() else 'cpu' ) - elif openai_embed_model is not None: - self._openai_embed_model = openai_embed_model + + elif cloud_emded_provider is not None: self.use_local_model = False - # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION - if openai_api_key: - openai.api_key = openai_api_key - if openai_org: - openai.organization = openai_org + self.cloud_emded_provider = cloud_emded_provider + else: raise ValueError( - "Either local_embed_model or openai_embed_model must be provided." + "Either local_embed_model or cloud_embed_provider must be provided." ) - self._pinecone_index = self._init_pinecone( - pinecone_index_name, pinecone_api_key, pinecone_env - ) - - super().__init__(k=k) - - def _init_pinecone( - self, - index_name: str, - api_key: Optional[str] = None, - environment: Optional[str] = None, - dimension: Optional[int] = None, - distance_metric: Optional[str] = None, - ) -> pinecone.Index: - """Initialize pinecone and return the loaded index. - - Args: - index_name (str): The name of the index to load. If the index is not does not exist, it will be created. - api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided. - environment (str, optional): The environment (ie. `us-west1-gcp` or `gcp-starter`. Defaults to env PINECONE_ENVIRONMENT. - - Raises: - ValueError: If api_key or environment is not provided and not set as an environment variable. - - Returns: - pinecone.Index: The loaded index. - """ + if pinecone_api_key is None: + self.pinecone_client = pinecone.Pinecone() + else: + self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key) - # Pinecone init overrides default if kwargs are present, so we need to exclude if None - kwargs = {} - if api_key: - kwargs["api_key"] = api_key - if environment: - kwargs["environment"] = environment - pinecone.init(**kwargs) - - active_indexes = pinecone.list_indexes() - if index_name not in active_indexes: - if dimension is None and distance_metric is None: - raise ValueError( - "dimension and distance_metric must be provided since the index provided does not exist." 
- ) + self._pinecone_index = self.pinecone_client.Index(pinecone_index_name) - pinecone.create_index( - name=index_name, - dimension=dimension, - metric=distance_metric, - ) + super().__init__(k=k) - return pinecone.Index(index_name) def _mean_pooling( self, @@ -172,11 +180,7 @@ def _mean_pooling( input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) - @backoff.on_exception( - backoff.expo, - ERRORS, - max_time=15, - ) + def _get_embeddings( self, queries: List[str] @@ -189,6 +193,9 @@ def _get_embeddings( Returns: List[List[float]]: List of embeddings corresponding to each query. """ + if not self.use_local_model: + return self.cloud_emded_provider.get_embeddings(queries) + try: import torch except ImportError as exc: @@ -196,17 +203,6 @@ def _get_embeddings( "You need to install torch to use a local embedding model with PineconeRM." ) from exc - if not self.use_local_model: - if OPENAI_LEGACY: - embedding = openai.Embedding.create( - input=queries, model=self._openai_embed_model - ) - else: - embedding = openai.embeddings.create( - input=queries, model=self._openai_embed_model - ).model_dump() - return [embedding["embedding"] for embedding in embedding["data"]] - # Use local model encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device) with torch.no_grad(): @@ -219,15 +215,17 @@ def _get_embeddings( # we need a pooling strategy to get a single vector representation of the input # so the default is to take the mean of the hidden states - def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction: + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction: """Search with pinecone for self.k top passages for query Args: query_or_queries (Union[str, List[str]]): The query or queries to search for. + k (Optional[int]): The number of top passages to retrieve. Defaults to self.k Returns: dspy.Prediction: An object containing the retrieved passages. """ + k = k if k is not None else self.k queries = ( [query_or_queries] if isinstance(query_or_queries, str) @@ -235,35 +233,37 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction: ) queries = [q for q in queries if q] # Filter empty queries embeddings = self._get_embeddings(queries) - # For single query, just look up the top k passages if len(queries) == 1: results_dict = self._pinecone_index.query( - embeddings[0], top_k=self.k, include_metadata=True + vector=embeddings[0], top_k=k, include_metadata=True ) # Sort results by score sorted_results = sorted( results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True ) + passages = [result["metadata"]["text"] for result in sorted_results] - passages = [dotdict({"long_text": passage for passage in passages})] - return dspy.Prediction(passages=passages) + passages = [dotdict({"long_text": passage}) for passage in passages] + return passages # For multiple queries, query each and return the highest scoring passages # If a passage is returned multiple times, the score is accumulated. 
For this reason we increase top_k by 3x passage_scores = {} for embedding in embeddings: results_dict = self._pinecone_index.query( - embedding, top_k=self.k * 3, include_metadata=True + vector=embedding, top_k=k * 3, include_metadata=True ) for result in results_dict["matches"]: passage_scores[result["metadata"]["text"]] = ( passage_scores.get(result["metadata"]["text"], 0.0) + result["score"] ) - + sorted_passages = sorted( passage_scores.items(), key=lambda x: x[1], reverse=True - )[: self.k] - return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]) + )[: k] + + passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages] + return passages diff --git a/examples/integrations/pinecone/pinecone_wikipedia_example.ipynb b/examples/integrations/pinecone/pinecone_wikipedia_example.ipynb new file mode 100644 index 0000000000..d522032bdb --- /dev/null +++ b/examples/integrations/pinecone/pinecone_wikipedia_example.ipynb @@ -0,0 +1,625 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2031537/1159854371.py:20: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n", + " import pkg_resources # Install the package if it's not installed\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import sys\n", + "import os\n", + "\n", + "try: # When on google Colab, let's clone the notebook so we download the cache.\n", + " import google.colab\n", + " repo_path = 'dspy'\n", + " !git -C $repo_path pull origin || git clone https://github.com/stanfordnlp/dspy $repo_path\n", + "except:\n", + " repo_path = '.'\n", + "\n", + "if repo_path not in sys.path:\n", + " sys.path.append(repo_path)\n", + "\n", + "# Set up the cache for this notebook\n", + "os.environ[\"DSP_NOTEBOOK_CACHEDIR\"] = os.path.join(repo_path, 'cache')\n", + "\n", + "import pkg_resources # Install the package if it's not installed\n", + "if not \"dspy-ai\" in {pkg.key for pkg in pkg_resources.working_set}:\n", + " !pip install -U pip\n", + " !pip install dspy-ai\n", + " # !pip install -e $repo_path\n", + "\n", + "import dspy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset, Dataset\n", + "docs = load_dataset(f\"Cohere/wikipedia-22-12-simple-embeddings\", \"en\", split=\"train[:5%]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['id', 'title', 'text', 'url', 'wiki_id', 'views', 'paragraph_id', 'langs', 'emb'],\n", + " num_rows: 24293\n", + "})" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "docs" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import os\n", + "from pinecone.grpc import PineconeGRPC as Pinecone\n", + "from pinecone import PodSpec\n", + "\n", + "api_key = os.getenv(\"PINECONE_API_KEY\")\n", + "pc = Pinecone(\n", + " api_key=os.environ.get(\"PINECONE_API_KEY\")\n", + ")\n", + "\n", + "# Pick a name for the new index\n", + "index_name = 'wikipedia-articles'\n", + "\n", + "# Check whether 
the index with the same name already exists - if so, delete it\n", + "if index_name in pc.list_indexes():\n", + " pc.delete_index(index_name)\n", + " \n", + "# Creates new index\n", + "if index_name not in pc.list_indexes().names():\n", + " pc.create_index(\n", + " name=index_name, \n", + " dimension=768, \n", + " metric='dotproduct',\n", + " spec=PodSpec(\n", + " environment=\"gcp-starter\",\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'dimension': 768,\n", + " 'index_fullness': 0.0,\n", + " 'namespaces': {'': {'vector_count': 0}},\n", + " 'total_vector_count': 0}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index = pc.Index(index_name)\n", + "index.describe_index_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " 
upserted_count: 100,\n", + " ... (repeated 'upserted_count: 100' entries omitted) ...\n", + " upserted_count: 
100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 100,\n", + " upserted_count: 53]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "def chunker(data, batch_size):\n", + " data_iter = iter(data)\n", + " # end = False\n", + " for i in range(0, len(data), batch_size):\n", + " chunk = []\n", + " # if i + batch_size >= len(data):\n", + " # batch_size += len(data) - i\n", + " for x in data_iter:\n", + " if len(chunk) == batch_size:\n", + " break\n", + " chunk.append(x)\n", + " \n", + " chunk_to_insert = []\n", + " for x in chunk:\n", + " item = {}\n", + " item['id'] = str(x['id'])\n", + " item['values'] = x['emb']\n", + " item['metadata'] = {}\n", + " item['metadata']['text'] = x['text']\n", + " chunk_to_insert.append(item)\n", + "\n", + " yield chunk_to_insert\n", + "\n", + "async_results = [\n", + " index.upsert(vectors=chunk, async_req=True)\n", + " for chunk in chunker(docs, batch_size=100) if len(chunk) > 0\n", + "]\n", + "\n", + "# Wait for and retrieve responses (in case of error)\n", + "results = [async_result.result() for async_result in async_results]\n", + "results" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'dimension': 768,\n", + " 'index_fullness': 0.24053,\n", + " 'namespaces': {'': {'vector_count': 24053}},\n", + " 'total_vector_count': 24053}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.describe_index_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from dspy.retrieve.pinecone_rm import PineconeRM, CohereEmbed\n", + "\n", + "cohere_embed = CohereEmbed()\n", + "\n", + "llm = dspy.OllamaLocal(model=\"openhermes2.5-mistral:7b-q5_K_M\", model_type=\"chat\")\n", + "retriever_model = PineconeRM(index_name, cloud_emded_provider=cohere_embed)\n", + "dspy.settings.configure(lm=llm, rm=retriever_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(20, 50)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dspy.datasets import HotPotQA\n", + "\n", + "# Load the dataset.\n", + "dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)\n", + "\n", + "# Tell DSPy that the 'question' field is the input. 
Any other fields are labels and/or metadata.\n", + "trainset = [x.with_inputs('question') for x in dataset.train]\n", + "devset = [x.with_inputs('question') for x in dataset.dev]\n", + "\n", + "len(trainset), len(devset)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "class GenerateAnswer(dspy.Signature):\n", + " \"\"\"Answer questions with short factoid answers.\"\"\"\n", + "\n", + " context = dspy.InputField(desc=\"may contain relevant facts\")\n", + " question = dspy.InputField()\n", + " answer = dspy.OutputField(desc=\"often between 1 and 5 words\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "class RAG(dspy.Module):\n", + " def __init__(self, num_passages=3):\n", + " super().__init__()\n", + "\n", + " self.retrieve = dspy.Retrieve(k=num_passages)\n", + " self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n", + " \n", + " def forward(self, question):\n", + " context = self.retrieve(question).passages\n", + " prediction = self.generate_answer(context=context, question=question)\n", + " return dspy.Prediction(context=context, answer=prediction.answer)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/20 [00:00