feature(dspy): Add MyScale in Retrieve #791
Merged: arnavsinghvi11 merged 19 commits into stanfordnlp:main from usamajamil43:feature/add-MyScale-to-retrieve on Jun 13, 2024.

Commits (19):
- `7510b12` feature(dspy): Add MyScale in Retrieve (usamajamil43)
- `e150392` feature(dspy): Add MyScale in Retrieve (usamajamil43)
- `47b45d2` feature(dspy): Add MyScale in Retrieve (usamajamil43)
- `9e92c6e` Merge branch 'main' into feature/add-MyScale-to-retrieve (arnavsinghvi11)
- `b466b3c` Update README.md (arnavsinghvi11)
- `1e09778` Add documentation and cache. (usamajamil43)
- `c976552` Added Cache to the embedding methods (usamajamil43)
- `ab5a99a` Merge branch 'main' into feature/add-MyScale-to-retrieve (arnavsinghvi11)
- `18b9e6b` The last commit (usamajamil43)
- `6def440` Merge branch 'feature/add-MyScale-to-retrieve' of https://github.com/… (usamajamil43)
- `bbadff7` feature(dspy): Add MyScale in Retrieve (usamajamil43)
- `b1d53e3` feature(dspy): Add MyScale in Retrieve (usamajamil43)
- `5346766` feature(dspy): Add MyScale in Retrieve (usamajamil43)
- `e2f3db0` Update README.md (arnavsinghvi11)
- `13a61a3` Add documentation and cache. (usamajamil43)
- `7e2c1b9` Added Cache to the embedding methods (usamajamil43)
- `c60148e` run ruff (ahsansaeed878)
- `88a720c` Merge branch 'feature/add-MyScale-to-retrieve' of usama:usamajamil43/… (ahsansaeed878)
- `a0463cb` Merge branch 'main' into feature/add-MyScale-to-retrieve (arnavsinghvi11)
New documentation file (`@@ -0,0 +1,79 @@`):

---
sidebar_position: 8
---

# retrieve.MyScaleRM

## Constructor

Initializes an instance of the `MyScaleRM` class, which is designed to use MyScaleDB (a ClickHouse fork optimized for vector similarity and full-text search) to retrieve documents based on query embeddings. This class supports embedding generation through either local models or OpenAI's API and manages database interactions efficiently.

### Syntax
```python
MyScaleRM(
    client: clickhouse_connect.driver.client.Client,
    table: str,
    database: str = 'default',
    metadata_columns: List[str] = ['text'],
    vector_column: str = 'vector',
    k: int = 3,
    openai_api_key: Optional[str] = None,
    openai_model: Optional[str] = None,
    local_embed_model: Optional[str] = None
)
```

## Parameters for `MyScaleRM` Constructor

- `client` (_clickhouse_connect.driver.client.Client_): A client connection to the MyScaleDB database, used to execute queries and manage interactions with the database.
- `table` (_str_): The table within MyScaleDB from which data will be retrieved. This table must have a vector column for conducting similarity searches.
- `database` (_str_, optional): The name of the database where the table is located, defaulting to `"default"`.
- `metadata_columns` (_List[str]_, optional): Columns to include as metadata in the output, defaulting to `["text"]`.
- `vector_column` (_str_, optional): The column that contains vector data, used for similarity searches, defaulting to `"vector"`.
- `k` (_int_, optional): The number of closest matches to return for each query, defaulting to 3.
- `openai_api_key` (_str_, optional): API key for accessing OpenAI services, required if using OpenAI for embedding generation.
- `openai_model` (_str_, optional): The OpenAI model to use for embeddings, required if an OpenAI API key is provided.
- `local_embed_model` (_str_, optional): A local model for embedding generation, used if local computation is preferred.
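The `client` argument is an ordinary `clickhouse_connect` client. A minimal connection sketch is shown below; the host and credentials are placeholders to substitute with your own MyScaleDB endpoint:

```python
import clickhouse_connect

# Hypothetical connection details; replace with your own MyScaleDB host and credentials.
client = clickhouse_connect.get_client(
    host="your-cluster.myscale.com",
    port=443,
    username="default",
    password="your-password",
)
```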
## Methods

### `forward`

Executes a retrieval operation based on a user's query and returns the top `k` relevant results using the embeddings generated by the specified method.

### Syntax
```python
def forward(self, user_query: str, k: Optional[int] = None) -> dspy.Prediction
```

## Parameters

- `user_query` (_str_): The query for which to retrieve matching passages.
- `k` (_Optional[int]_, optional): The number of top matches to retrieve. If not provided, it defaults to the `k` value set during class initialization.

## Returns

- `dspy.Prediction`: Contains the retrieved passages, formatted as a list of `dotdict` objects. Each entry includes:
  - **long_text (str)**: The text content of the retrieved passage.

## Description

The `forward` method leverages MyScaleDB's vector search capabilities to find the top `k` passages that best match the provided query, using embeddings generated either by a local model or through the OpenAI API.
## Quickstart

This section shows how to instantiate and use the `MyScaleRM` class to retrieve data from MyScaleDB using text embeddings.

```python
from dspy.retrieve.myscaledb_rm import MyScaleRM

MyScale_model = MyScaleRM(client=client,
                          table="table_name",
                          openai_api_key="sk-***",
                          openai_model="embeddings_model",
                          vector_column="vector_column_name",
                          metadata_columns=["add_your_columns_here"],
                          k=6)

results = MyScale_model("Please suggest some funny movies")

passages = results.passages

# Loop through each passage and print the 'long_text'
for passage in passages:
    print(passage['long_text'], "\n")
```
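The constructor also supports local embedding models. A minimal sketch, assuming `torch` and `transformers` are installed and using `sentence-transformers/all-MiniLM-L6-v2` as an illustrative (not prescribed) model name:

```python
# Hypothetical local-model setup; any Hugging Face encoder loadable via transformers applies.
local_rm = MyScaleRM(client=client,
                     table="table_name",
                     local_embed_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative choice
                     vector_column="vector_column_name",
                     metadata_columns=["text"],
                     k=3)

results = local_rm("Please suggest some funny movies")
```

Once instantiated, the retriever can typically be registered globally with `dspy.settings.configure(rm=local_rm)` so that `dspy.Retrieve` modules use it by default.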
New source file `dspy/retrieve/myscaledb_rm.py` (`@@ -0,0 +1,226 @@`):

```python
import functools
import os
from typing import List, Optional

import openai

import dspy
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.utils import dotdict

# Check for necessary libraries and suggest installation if not found.
try:
    import clickhouse_connect
except ImportError:
    raise ImportError(
        "The 'myscale' extra is required to use MyScaleRM. Install it with `pip install dspy-ai[myscale]`",
    )

# Verify the compatibility of the OpenAI library version installed.
try:
    major, minor, _ = map(int, openai.__version__.split('.'))
    OPENAI_VERSION_COMPATIBLE = (major, minor) >= (1, 16)
except Exception:
    OPENAI_VERSION_COMPATIBLE = False

if not OPENAI_VERSION_COMPATIBLE:
    raise ImportError(
        "An incompatible OpenAI library version is installed. Ensure you have version 1.16.1 or later.",
    )

# Attempt to handle specific OpenAI errors; fall back to general ones if necessary.
try:
    import openai.error
    ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
except Exception:
    ERRORS = (openai.RateLimitError, openai.APIError)


class MyScaleRM(dspy.Retrieve):
    """
    A retrieval module that uses MyScaleDB to return the top passages for a given query.

    MyScaleDB is a fork of ClickHouse that focuses on vector similarity search and full
    text search. MyScaleRM is designed to facilitate easy retrieval of information from
    MyScaleDB using embeddings. It supports embedding generation through either a local
    model or the OpenAI API. This class abstracts away the complexities of connecting to
    MyScaleDB, managing API keys, and processing queries to return semantically
    relevant results.

    Assumes that a table named `database.table` exists in MyScaleDB, that the table has
    a column named `vector_column` that stores vector data, and that a vector index has
    been created on this column. Other metadata are stored in `metadata_columns`.

    Args:
        client (clickhouse_connect.driver.client.Client): A client connection to the MyScaleDB.
        table (str): Name of the table within the database to perform queries against.
        database (str, optional): Name of the database to query within MyScaleDB.
        metadata_columns (List[str], optional): A list of columns to include in the results.
        vector_column (str, optional): The name of the column in the table that stores vector data.
        k (int, optional): The number of closest matches to retrieve for a given query.
        openai_api_key (str, optional): The API key for accessing OpenAI's services.
        openai_model (str, optional): The OpenAI model to use for embedding generation.
        local_embed_model (str, optional): A local Hugging Face model to use for embedding generation.
    """

    def __init__(self,
                 client: clickhouse_connect.driver.client.Client,
                 table: str,
                 database: str = "default",
                 metadata_columns: List[str] = ["text"],
                 vector_column: str = "vector",
                 k: int = 3,
                 openai_api_key: Optional[str] = None,
                 openai_model: Optional[str] = None,
                 local_embed_model: Optional[str] = None):
        self.client = client
        self.database = database
        self.table = table
        if not metadata_columns:
            raise ValueError("metadata_columns is required")
        self.metadata_columns = metadata_columns
        self.vector_column = vector_column
        self.k = k
        self.openai_api_key = openai_api_key
        self.model = openai_model
        self.use_local_model = False

        if local_embed_model:
            self.setup_local_model(local_embed_model)
        elif openai_api_key:
            os.environ['OPENAI_API_KEY'] = self.openai_api_key

    def setup_local_model(self, model_name: str):
        """
        Configures a local model for embedding generation, including model and tokenizer loading.

        Args:
            model_name: The name or path of the pre-trained model to load.

        Raises:
            ModuleNotFoundError: If the necessary libraries (torch or transformers) are not installed.
        """
        try:
            import torch
            from transformers import AutoModel, AutoTokenizer
        except ImportError as exc:
            raise ModuleNotFoundError(
                """You need to install PyTorch and Hugging Face's transformers library to use a local embedding model.
                Install PyTorch with `pip install torch` and transformers with `pip install transformers`.""",
            ) from exc

        try:
            self._local_embed_model = AutoModel.from_pretrained(model_name)
            self._local_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.use_local_model = True
        except Exception as e:
            raise ValueError(f"Failed to load model or tokenizer. Error: {str(e)}") from e

        # Prefer CUDA, then Apple's MPS backend, falling back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        self._local_embed_model.to(self.device)

    @functools.lru_cache(maxsize=None if cache_turn_on else 0)
    @NotebookCacheMemory.cache
    def get_embeddings(self, query: str) -> List[float]:
        """
        Determines the appropriate source (OpenAI or local model) for embedding generation
        based on class configuration, and retrieves an embedding for the provided query.

        Note: the argument is a single string rather than a list so that
        `functools.lru_cache`, which requires hashable arguments, can cache the call.

        Args:
            query: The text query to generate an embedding for.

        Returns:
            The embedding of the query as a list of floats.

        Raises:
            ValueError: If neither an OpenAI API key nor a local model has been configured.
        """
        if self.openai_api_key and self.model:
            return self._get_embeddings_from_openai(query)
        elif self.use_local_model:
            return self._get_embedding_from_local_model(query)
        else:
            raise ValueError("No valid method for obtaining embeddings is configured.")

    # TODO: Move this method to a utility module outside MyScaleRM.
    @CacheMemory.cache
    def _get_embeddings_from_openai(self, query: str) -> List[float]:
        """
        Uses the OpenAI API to generate an embedding for the given query.

        Args:
            query: The string to generate an embedding for.

        Returns:
            The embedding of the query as a list of floats.
        """
        response = openai.embeddings.create(
            model=self.model,
            input=query)
        return response.data[0].embedding

    # TODO: Move this method to a utility module outside MyScaleRM.
    @CacheMemory.cache
    def _get_embedding_from_local_model(self, query: str) -> List[float]:
        """
        Generates an embedding for a single query using the configured local model.

        Args:
            query: The text query to generate an embedding for.

        Returns:
            A list of floats representing the query's embedding.
        """
        import torch
        self._local_embed_model.eval()  # Ensure the model is in evaluation mode.

        inputs = self._local_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(self.device)
        with torch.no_grad():
            output = self._local_embed_model(**inputs)
        # Mean-pool the last hidden states into a single fixed-size vector.
        embedding = output.last_hidden_state.mean(dim=1).cpu().numpy().tolist()[0]

        return embedding

    def forward(self, user_query: str, k: Optional[int] = None) -> dspy.Prediction:
        """
        Executes a retrieval operation based on a user's query and returns the top k relevant results.

        Args:
            user_query: The query text to search for.
            k: Optional; the number of top matches to return. Defaults to the class's configured k value.

        Returns:
            A dspy.Prediction object containing the formatted retrieval results.

        Raises:
            ValueError: If the user_query is None.
        """
        if user_query is None:
            raise ValueError("Query is required")
        k = k if k is not None else self.k
        embedding = self.get_embeddings(user_query)
        columns_string = ', '.join(self.metadata_columns)
        result = self.client.query(f"""
        SELECT {columns_string},
        distance({self.vector_column}, {embedding}) as dist FROM {self.database}.{self.table} ORDER BY dist LIMIT {k}
        """)

        # Convert the metadata into strings to pass to dspy.Prediction.
        results = []
        for row in result.named_results():
            if len(self.metadata_columns) == 1:
                results.append(row[self.metadata_columns[0]])
            else:
                row_strings = [f"{column}: {row[column]}" for column in self.metadata_columns]  # Format row data.
                row_string = "\n".join(row_strings)  # Combine formatted data.
                results.append(row_string)  # Append to results.

        # Return the results wrapped in a dspy.Prediction.
        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage in results])
```
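For context on what `forward` queries against: the retriever assumes a table with a vector column and a vector index (see the class docstring). A purely illustrative setup sketch follows; the table name, embedding dimension, and index type are assumptions, not values from this PR:

```python
# Hypothetical schema: 384-dimensional embeddings in `vector`, passage text in `text`.
client.command("""
    CREATE TABLE IF NOT EXISTS default.documents (
        id UInt64,
        text String,
        vector Array(Float32),
        CONSTRAINT vector_len CHECK length(vector) = 384
    ) ENGINE = MergeTree ORDER BY id
""")

# MyScaleDB adds its vector index via ALTER TABLE; MSTG is one of its index types.
client.command("ALTER TABLE default.documents ADD VECTOR INDEX vec_idx vector TYPE MSTG")
```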
Review comment:

> can these requests be cached? could you adapt from DSPy's caching functionalities?