-
Notifications
You must be signed in to change notification settings - Fork 2.4k
DSPy Support For Snowflake LLM's #859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
arnavsinghvi11
merged 27 commits into
stanfordnlp:main
from
sfc-gh-alherrera:dspy-snowflake
May 11, 2024
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
d7b41cd
feat(dspy): testing initial commit
sfc-gh-alherrera dfae177
feat(dspy): initial cleanup of Snowflake (LM)support
sfc-gh-alherrera e492cfa
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera f2ce44f
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera 6af2d47
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera f42df77
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera 8ce1378
feat(dspy): added Snowflake Cortex (LM) and Snowflake RM support
sfc-gh-alherrera bbb6204
feat(dspy): added Snowflake Cortex (LM) and Snowflake RM support
sfc-gh-alherrera bf9247f
fix(dspy): removing unnecessary ipykernel dependency
sfc-gh-alherrera 4ed0753
fix(dspy): updating import error handling language for SnowflakeRM
sfc-gh-alherrera 315d71a
fix(dspy): updating README to include Snowflake as optional supported…
sfc-gh-alherrera cfa67d6
fix(dspy): updating SnowflakeRM typos
sfc-gh-alherrera 75cd65b
fix(dspy): updating sort order to be descending for retriever results
sfc-gh-alherrera c663444
docs(dspy): Adding documentation for Snowflake LM
sfc-gh-alherrera 2cb77df
docs(dspy): Adding documentation for Snowflake RM
sfc-gh-alherrera daef34e
docs(dspy): Adding documentation for Snowflake RM
sfc-gh-alherrera 4f10658
fix(dspy): Adding self.history definition to Snowflake LM
sfc-gh-alherrera e6afb61
Merge branch 'main' into dspy-snowflake
sfc-gh-alherrera fad2f11
fix(dspy): Adding missing LLMs supported in Snowflake and tag for ses…
sfc-gh-alherrera 441fd43
fix(dspy): Adding missing LLMs supported in Snowflake and tag for ses…
sfc-gh-alherrera 8775674
fix(dspy): adding language for recently released snowflake retriever …
sfc-gh-alherrera a048a51
fix(dspy): updating syntax for embeddings method call which will be d…
sfc-gh-alherrera b3af69c
fix(dspy): updating syntax for Snowflake cos similarity method which …
sfc-gh-alherrera 53630da
fix(dspy): updating syntax for Snowflake cos similarity method which …
sfc-gh-alherrera b0b18a7
fix(dspy): updating syntax for Snowflake cos similarity method which …
sfc-gh-alherrera 8dedd00
fix(dspy): solving for null response bug in Cortex API
sfc-gh-alherrera 23cd780
fix(dspy): adding LM connection parameter update to docs, resolving r…
sfc-gh-alherrera File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| --- | ||
| sidebar_position: | ||
| --- | ||
|
|
||
| # dspy.Snowflake | ||
|
|
||
| ### Usage | ||
|
|
||
| ```python | ||
| import dspy | ||
| import os | ||
|
|
||
| connection_parameters = { | ||
|
|
||
| "account": os.getenv('SNOWFLAKE_ACCOUNT'), | ||
| "user": os.getenv('SNOWFLAKE_USER'), | ||
| "password": os.getenv('SNOWFLAKE_PASSWORD'), | ||
| "role": os.getenv('SNOWFLAKE_ROLE'), | ||
| "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'), | ||
| "database": os.getenv('SNOWFLAKE_DATABASE'), | ||
| "schema": os.getenv('SNOWFLAKE_SCHEMA')} | ||
|
|
||
| lm = dspy.Snowflake(model="mixtral-8x7b",credentials=connection_parameters) | ||
| ``` | ||
|
|
||
| ### Constructor | ||
|
|
||
| The constructor inherits from the base class `LM` and verifies the `credentials` for using Snowflake API. | ||
|
|
||
| ```python | ||
| class Snowflake(LM): | ||
| def __init__( | ||
| self, | ||
| model, | ||
| credentials, | ||
| **kwargs): | ||
| ``` | ||
|
|
||
| **Parameters:** | ||
| - `model` (_str_): model hosted by [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability). | ||
| - `credentials` (_dict_): connection parameters required to initialize a [snowflake snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session) | ||
|
|
||
| ### Methods | ||
|
|
||
| Refer to [`dspy.Snowflake`](https://dspy-docs.vercel.app/api/language_model_clients/Snowflake) documentation. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| --- | ||
| sidebar_position: | ||
| --- | ||
|
|
||
| # retrieve.SnowflakeRM | ||
|
|
||
| ### Constructor | ||
|
|
||
| Initialize an instance of the `SnowflakeRM` class, with the option to use `e5-base-v2` or `snowflake-arctic-embed-m` embeddings or any other Snowflake Cortex supported embeddings model. | ||
|
|
||
| ```python | ||
| SnowflakeRM( | ||
| snowflake_table_name: str, | ||
| snowflake_credentials: dict, | ||
| k: int = 3, | ||
| embeddings_field: str, | ||
| embeddings_text_field:str, | ||
| embeddings_model: str = "e5-base-v2", | ||
| ) | ||
| ``` | ||
|
|
||
| **Parameters:** | ||
|
|
||
| - `snowflake_table_name (str)`: The name of the Snowflake table containing embeddings. | ||
| - `snowflake_credentials (dict)`: The connection parameters needed to initialize a Snowflake Snowpark Session. | ||
| - `k (int, optional)`: The number of top passages to retrieve. Defaults to 3. | ||
| - `embeddings_field (str)`: The name of the column in the Snowflake table containing the embeddings. | ||
| - `embeddings_text_field (str)`: The name of the column in the Snowflake table containing the passages. | ||
| - `embeddings_model (str)`: The model to be used to convert text to embeddings | ||
|
|
||
| ### Methods | ||
|
|
||
| #### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction` | ||
|
|
||
| Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated via the default `e5-base-v2` model or the specified `embedding_model`. | ||
|
|
||
| **Parameters:** | ||
|
|
||
| - `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for. | ||
| - `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization. | ||
|
|
||
| **Returns:** | ||
|
|
||
| - `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]` | ||
|
|
||
| ### Quickstart | ||
|
|
||
| To support passage retrieval, it assumes that a Snowflake table has been created and populated with the passages in a column `embeddings_text_field` and the embeddings in another column `embeddings_field` | ||
|
|
||
| SnowflakeRM uses `e5-base-v2` embeddings model by default or any Snowflake Cortex supported embeddings model. | ||
|
|
||
| #### Default OpenAI Embeddings | ||
|
|
||
| ```python | ||
| from dspy.retrieve.snowflake_rm import SnowflakeRM | ||
| import os | ||
|
|
||
| connection_parameters = { | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah you can add this to the LM documentation to address that comment!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
| "account": os.getenv('SNOWFLAKE_ACCOUNT'), | ||
| "user": os.getenv('SNOWFLAKE_USER'), | ||
| "password": os.getenv('SNOWFLAKE_PASSWORD'), | ||
| "role": os.getenv('SNOWFLAKE_ROLE'), | ||
| "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'), | ||
| "database": os.getenv('SNOWFLAKE_DATABASE'), | ||
| "schema": os.getenv('SNOWFLAKE_SCHEMA')} | ||
|
|
||
| retriever_model = SnowflakeRM( | ||
| snowflake_table_name="<YOUR_SNOWFLAKE_TABLE_NAME>", | ||
| snowflake_credentials=connection_parameters, | ||
| embeddings_field="<YOUR_EMBEDDINGS_COLUMN_NAME>", | ||
| embeddings_text_field= "<YOUR_PASSAGE_COLUMN_NAME>" | ||
| ) | ||
|
|
||
| results = retriever_model("Explore the meaning of life", k=5) | ||
|
|
||
| for result in results: | ||
| print("Document:", result.long_text, "\n") | ||
| ``` | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| """Module for interacting with Snowflake Cortex.""" | ||
| import json | ||
| from typing import Any | ||
|
|
||
| import backoff | ||
| from pydantic_core import PydanticCustomError | ||
|
|
||
| from dsp.modules.lm import LM | ||
|
|
||
| try: | ||
| from snowflake.snowpark import Session | ||
| from snowflake.snowpark import functions as snow_func | ||
|
|
||
| except ImportError: | ||
| pass | ||
|
|
||
|
|
||
| def backoff_hdlr(details) -> None: | ||
| """Handler from https://pypi.org/project/backoff .""" | ||
| print( | ||
| f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries ", | ||
| f"calling function {details['target']} with kwargs", | ||
| f"{details['kwargs']}", | ||
| ) | ||
|
|
||
|
|
||
| def giveup_hdlr(details) -> bool: | ||
| """Wrapper function that decides when to give up on retry.""" | ||
| if "rate limits" in str(details): | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| class Snowflake(LM): | ||
| """Wrapper around Snowflake's CortexAPI. | ||
|
|
||
| Currently supported models include 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b', | ||
| 'llama2-70b-chat','mistral-7b','gemma-7b','llama3-8b','llama3-70b','reka-core'. | ||
| """ | ||
|
|
||
| def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs): | ||
| """Parameters | ||
|
|
||
| ---------- | ||
| model : str | ||
| Which pre-trained model from Snowflake to use? | ||
| Choices are 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b','llama2-70b-chat','mistral-7b','gemma-7b' | ||
| Full list of supported models is available here: https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#complete | ||
| credentials: dict | ||
| Snowflake credentials required to initialize the session. | ||
| Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session | ||
| **kwargs: dict | ||
| Additional arguments to pass to the API provider. | ||
| """ | ||
| super().__init__(model) | ||
|
|
||
| self.model = model | ||
| cortex_models = [ | ||
| "llama3-8b", | ||
| "llama3-70b", | ||
| "reka-core", | ||
| "snowflake-arctic", | ||
| "mistral-large", | ||
| "reka-flash", | ||
| "mixtral-8x7b", | ||
| "llama2-70b-chat", | ||
| "mistral-7b", | ||
| "gemma-7b", | ||
| ] | ||
|
|
||
| if model in cortex_models: | ||
| self.available_args = { | ||
| "max_tokens", | ||
| "temperature", | ||
| "top_p", | ||
| } | ||
| else: | ||
| raise PydanticCustomError( | ||
| "model", | ||
| 'model name is not valid, got "{model_name}"', | ||
| ) | ||
|
|
||
| self.client = self._init_cortex(credentials=credentials) | ||
| self.provider = "Snowflake" | ||
| self.history: list[dict[str, Any]] = [] | ||
| self.kwargs = { | ||
| **self.kwargs, | ||
| "temperature": 0.7, | ||
| "max_output_tokens": 1024, | ||
| "top_p": 1.0, | ||
| "top_k": 1, | ||
| **kwargs, | ||
| } | ||
arnavsinghvi11 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @classmethod | ||
| def _init_cortex(cls, credentials: dict) -> None: | ||
| session = Session.builder.configs(credentials).create() | ||
| session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}} | ||
|
|
||
| return session | ||
|
|
||
| def _prepare_params( | ||
| self, | ||
| parameters: Any, | ||
| ) -> dict: | ||
| params_mapping = {"n": "candidate_count", "max_tokens": "max_output_tokens"} | ||
| params = {params_mapping.get(k, k): v for k, v in parameters.items()} | ||
| params = {**self.kwargs, **params} | ||
| return {k: params[k] for k in set(params.keys()) & self.available_args} | ||
|
|
||
| def _cortex_complete_request(self, prompt: str, **kwargs) -> dict: | ||
| complete = snow_func.builtin("snowflake.cortex.complete") | ||
| cortex_complete_args = complete( | ||
| snow_func.lit(self.model), | ||
| snow_func.lit([{"role": "user", "content": prompt}]), | ||
| snow_func.lit(kwargs), | ||
| ) | ||
| raw_response = self.client.range(1).withColumn("complete_cal", cortex_complete_args).collect() | ||
|
|
||
| if len(raw_response) > 0: | ||
| return json.loads(raw_response[0].COMPLETE_CAL) | ||
|
|
||
| else: | ||
| return json.loads('{"choices": [{"messages": "None"}]}') | ||
|
|
||
| def basic_request(self, prompt: str, **kwargs) -> list: | ||
| raw_kwargs = kwargs | ||
| kwargs = self._prepare_params(raw_kwargs) | ||
|
|
||
| response = self._cortex_complete_request(prompt, **kwargs) | ||
|
|
||
| history = { | ||
| "prompt": prompt, | ||
| "response": { | ||
| "prompt": prompt, | ||
| "choices": [{"text": c} for c in response["choices"]], | ||
| }, | ||
| "kwargs": kwargs, | ||
| "raw_kwargs": raw_kwargs, | ||
| } | ||
|
|
||
| self.history.append(history) | ||
|
|
||
| return [i["text"]["messages"] for i in history["response"]["choices"]] | ||
|
|
||
| @backoff.on_exception( | ||
| backoff.expo, | ||
| (Exception), | ||
| max_time=1000, | ||
| on_backoff=backoff_hdlr, | ||
| giveup=giveup_hdlr, | ||
| ) | ||
| def _request(self, prompt: str, **kwargs): | ||
| """Handles retrieval of completions from Snowflake Cortex whilst handling API errors.""" | ||
| return self.basic_request(prompt, **kwargs) | ||
|
|
||
| def __call__( | ||
| self, | ||
| prompt: str, | ||
| only_completed: bool = True, | ||
| return_sorted: bool = False, | ||
| **kwargs, | ||
| ): | ||
| return self._request(prompt, **kwargs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think it would be useful to define some required
connection_parametersfrom Snowflake based on the documentation so users know what to expect when configuringdspy.Snowflake.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. added the same as in the RM docs