Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d7b41cd
feat(dspy): testing initial commit
sfc-gh-alherrera Apr 17, 2024
dfae177
feat(dspy): initial cleanup of Snowflake (LM)support
sfc-gh-alherrera Apr 17, 2024
e492cfa
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera Apr 17, 2024
f2ce44f
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera Apr 17, 2024
6af2d47
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera Apr 17, 2024
f42df77
feat(dspy): added SnowflakeRM retriever and cleanup of Snowflake (LM)…
sfc-gh-alherrera Apr 17, 2024
8ce1378
feat(dspy): added Snowflake Cortex (LM) and Snowflake RM support
sfc-gh-alherrera Apr 17, 2024
bbb6204
feat(dspy): added Snowflake Cortex (LM) and Snowflake RM support
sfc-gh-alherrera Apr 17, 2024
bf9247f
fix(dspy): removing unnecessary ipykernel dependency
sfc-gh-alherrera May 6, 2024
4ed0753
fix(dspy): updating import error handling language for SnowflakeRM
sfc-gh-alherrera May 6, 2024
315d71a
fix(dspy): updating README to include Snowflake as optional supported…
sfc-gh-alherrera May 6, 2024
cfa67d6
fix(dspy): updating SnowflakeRM typos
sfc-gh-alherrera May 6, 2024
75cd65b
fix(dspy): updating sort order to be descending for retriever results
sfc-gh-alherrera May 6, 2024
c663444
docs(dspy): Adding documentation for Snowflake LM
sfc-gh-alherrera May 6, 2024
2cb77df
docs(dspy): Adding documentation for Snowflake RM
sfc-gh-alherrera May 6, 2024
daef34e
docs(dspy): Adding documentation for Snowflake RM
sfc-gh-alherrera May 6, 2024
4f10658
fix(dspy): Adding self.history definition to Snowflake LM
sfc-gh-alherrera May 6, 2024
e6afb61
Merge branch 'main' into dspy-snowflake
sfc-gh-alherrera May 6, 2024
fad2f11
fix(dspy): Adding missing LLMs supported in Snowflake and tag for ses…
sfc-gh-alherrera May 9, 2024
441fd43
fix(dspy): Adding missing LLMs supported in Snowflake and tag for ses…
sfc-gh-alherrera May 9, 2024
8775674
fix(dspy): adding language for recently released snowflake retriever …
sfc-gh-alherrera May 9, 2024
a048a51
fix(dspy): updating syntax for embeddings method call which will be d…
sfc-gh-alherrera May 9, 2024
b3af69c
fix(dspy): updating syntax for Snowflake cos similarity method which …
sfc-gh-alherrera May 9, 2024
53630da
fix(dspy): updating syntax for Snowflake cos similarity method which …
sfc-gh-alherrera May 9, 2024
b0b18a7
fix(dspy): updating syntax for Snowflake cos similarity method which …
sfc-gh-alherrera May 10, 2024
8dedd00
fix(dspy): solving for null response bug in Cortex API
sfc-gh-alherrera May 10, 2024
23cd780
fix(dspy): adding LM connection parameter update to docs, resolving r…
sfc-gh-alherrera May 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ Or open our intro notebook in Google Colab: [<img align="center" src="https://co

By default, DSPy installs the latest `openai` from pip. However, if you install old version before OpenAI changed their API `openai~=0.28.1`, the library will use that just fine. Both are supported.

For the optional (alphabetically sorted) [Chromadb](https://github.com/chroma-core/chroma), [Qdrant](https://github.com/qdrant/qdrant), [Marqo](https://github.com/marqo-ai/marqo), Pinecone, [Weaviate](https://github.com/weaviate/weaviate),
For the optional (alphabetically sorted) [Chromadb](https://github.com/chroma-core/chroma), [Qdrant](https://github.com/qdrant/qdrant), [Marqo](https://github.com/marqo-ai/marqo), Pinecone, [Snowflake](https://github.com/snowflakedb/snowpark-python), [Weaviate](https://github.com/weaviate/weaviate),
or [Milvus](https://github.com/milvus-io/milvus) retrieval integration(s), include the extra(s) below:

```
pip install dspy-ai[chromadb] # or [qdrant] or [marqo] or [mongodb] or [pinecone] or [weaviate] or [milvus]
pip install dspy-ai[chromadb] # or [qdrant] or [marqo] or [mongodb] or [pinecone] or [snowflake] or [weaviate] or [milvus]
```

## 2) Documentation
Expand Down
45 changes: 45 additions & 0 deletions docs/api/language_model_clients/Snowflake.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
---
sidebar_position:
---

# dspy.Snowflake

### Usage

```python
import dspy
import os

connection_parameters = {

"account": os.getenv('SNOWFLAKE_ACCOUNT'),
"user": os.getenv('SNOWFLAKE_USER'),
"password": os.getenv('SNOWFLAKE_PASSWORD'),
"role": os.getenv('SNOWFLAKE_ROLE'),
"warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
"database": os.getenv('SNOWFLAKE_DATABASE'),
"schema": os.getenv('SNOWFLAKE_SCHEMA')}

lm = dspy.Snowflake(model="mixtral-8x7b",credentials=connection_parameters)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it would be useful to define some required connection_parameters from Snowflake based on the documentation so users know what to expect when configuring dspy.Snowflake.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. added the same as in the RM docs

```

### Constructor

The constructor inherits from the base class `LM` and uses the `credentials` to open the Snowflake Snowpark session through which Snowflake Cortex is called.

```python
class Snowflake(LM):
def __init__(
self,
model,
credentials,
**kwargs):
```

**Parameters:**
- `model` (_str_): model hosted by [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability).
- `credentials` (_dict_): connection parameters required to initialize a [snowflake snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session)

### Methods

Refer to [`dspy.Snowflake`](https://dspy-docs.vercel.app/api/language_model_clients/Snowflake) documentation.
79 changes: 79 additions & 0 deletions docs/api/retrieval_model_clients/SnowflakeRM.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
---
sidebar_position:
---

# retrieve.SnowflakeRM

### Constructor

Initialize an instance of the `SnowflakeRM` class, with the option to use `e5-base-v2` or `snowflake-arctic-embed-m` embeddings or any other Snowflake Cortex supported embeddings model.

```python
SnowflakeRM(
snowflake_table_name: str,
snowflake_credentials: dict,
k: int = 3,
embeddings_field: str,
embeddings_text_field:str,
embeddings_model: str = "e5-base-v2",
)
```

**Parameters:**

- `snowflake_table_name (str)`: The name of the Snowflake table containing embeddings.
- `snowflake_credentials (dict)`: The connection parameters needed to initialize a Snowflake Snowpark Session.
- `k (int, optional)`: The number of top passages to retrieve. Defaults to 3.
- `embeddings_field (str)`: The name of the column in the Snowflake table containing the embeddings.
- `embeddings_text_field (str)`: The name of the column in the Snowflake table containing the passages.
- `embeddings_model (str, optional)`: The model to be used to convert text to embeddings. Defaults to `e5-base-v2`.

### Methods

#### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction`

Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated via the default `e5-base-v2` model or the specified `embeddings_model`.

**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]`

### Quickstart

To support passage retrieval, SnowflakeRM assumes that a Snowflake table has been created and populated with the passages in the column named by `embeddings_text_field` and the embeddings in the column named by `embeddings_field`.

SnowflakeRM uses `e5-base-v2` embeddings model by default or any Snowflake Cortex supported embeddings model.

#### Default Snowflake Cortex Embeddings (`e5-base-v2`)

```python
from dspy.retrieve.snowflake_rm import SnowflakeRM
import os

connection_parameters = {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah you can add this to the LM documentation to address that comment!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

"account": os.getenv('SNOWFLAKE_ACCOUNT'),
"user": os.getenv('SNOWFLAKE_USER'),
"password": os.getenv('SNOWFLAKE_PASSWORD'),
"role": os.getenv('SNOWFLAKE_ROLE'),
"warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
"database": os.getenv('SNOWFLAKE_DATABASE'),
"schema": os.getenv('SNOWFLAKE_SCHEMA')}

retriever_model = SnowflakeRM(
snowflake_table_name="<YOUR_SNOWFLAKE_TABLE_NAME>",
snowflake_credentials=connection_parameters,
embeddings_field="<YOUR_EMBEDDINGS_COLUMN_NAME>",
embeddings_text_field= "<YOUR_PASSAGE_COLUMN_NAME>"
)

results = retriever_model("Explore the meaning of life", k=5)

for result in results:
print("Document:", result.long_text, "\n")
```
2 changes: 2 additions & 0 deletions dsp/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@
from .pyserini import *
from .sbert import *
from .sentence_vectorizer import *
from .snowflake import *
from .watsonx import *

164 changes: 164 additions & 0 deletions dsp/modules/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Module for interacting with Snowflake Cortex."""
import json
from typing import Any

import backoff
from pydantic_core import PydanticCustomError

from dsp.modules.lm import LM

try:
from snowflake.snowpark import Session
from snowflake.snowpark import functions as snow_func

except ImportError:
pass


def backoff_hdlr(details) -> None:
    """Print a notice that a call is about to be retried.

    Adapted from the example handler at https://pypi.org/project/backoff .

    Parameters
    ----------
    details : mapping
        Must provide the keys ``wait`` (formatted with one decimal place),
        ``tries``, ``target`` and ``kwargs``; each is interpolated into the
        printed message.
    """
    # The pieces are handed to print() as separate arguments, so print's
    # default single-space separator is inserted between them.
    message_parts = (
        f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries ",
        f"calling function {details['target']} with kwargs",
        f"{details['kwargs']}",
    )
    print(*message_parts)


def giveup_hdlr(details) -> bool:
    """Decide whether retrying should stop.

    Parameters
    ----------
    details : object
        Whatever the retry machinery passes in; only its ``str()`` form is
        inspected.

    Returns
    -------
    bool
        ``False`` (keep retrying) when the text mentions "rate limits",
        ``True`` (give up) for anything else.
    """
    return "rate limits" not in str(details)


class Snowflake(LM):
    """Wrapper around Snowflake's CortexAPI.

    Currently supported models include 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b',
    'llama2-70b-chat','mistral-7b','gemma-7b','llama3-8b','llama3-70b','reka-core'.
    """

    # Model names accepted by the constructor; anything else is rejected.
    _CORTEX_MODELS = (
        "llama3-8b",
        "llama3-70b",
        "reka-core",
        "snowflake-arctic",
        "mistral-large",
        "reka-flash",
        "mixtral-8x7b",
        "llama2-70b-chat",
        "mistral-7b",
        "gemma-7b",
    )

    def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
        """Parameters

        ----------
        model : str
            Which pre-trained model from Snowflake to use?
            Choices are 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b','llama2-70b-chat',
            'mistral-7b','gemma-7b','llama3-8b','llama3-70b','reka-core'
            Full list of supported models is available here: https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#complete
        credentials: dict
            Snowflake credentials required to initialize the session.
            Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session
        **kwargs: dict
            Additional arguments to pass to the API provider. Only the names in
            ``available_args`` (``max_tokens``, ``temperature``, ``top_p``) are
            forwarded to Cortex; ``max_output_tokens`` is accepted as an alias
            for ``max_tokens``.

        Raises
        ------
        PydanticCustomError
            If ``model`` is not one of the supported Cortex models.
        """
        super().__init__(model)

        self.model = model

        if model not in self._CORTEX_MODELS:
            # The context dict fills the {model_name} placeholder; without it
            # the rejected name never appears in the rendered message.
            raise PydanticCustomError(
                "model",
                'model name is not valid, got "{model_name}"',
                {"model_name": model},
            )

        # Option names that are forwarded to the Cortex call.
        self.available_args = {
            "max_tokens",
            "temperature",
            "top_p",
        }

        self.client = self._init_cortex(credentials=credentials)
        self.provider = "Snowflake"
        self.history: list[dict[str, Any]] = []
        # The output-length default is stored under "max_tokens", the only
        # spelling that survives the available_args filter in _prepare_params.
        self.kwargs = {
            **self.kwargs,
            "temperature": 0.7,
            "max_tokens": 1024,
            "top_p": 1.0,
            "top_k": 1,
            **self._canonical_options(kwargs),
        }

    @staticmethod
    def _canonical_options(options: Any) -> dict:
        """Return a copy of ``options`` with ``max_output_tokens`` folded into ``max_tokens``.

        An explicit ``max_tokens`` in the same mapping takes precedence over
        the alias.
        """
        canonical = dict(options)
        if "max_output_tokens" in canonical:
            alias_value = canonical.pop("max_output_tokens")
            canonical.setdefault("max_tokens", alias_value)
        return canonical

    @classmethod
    def _init_cortex(cls, credentials: dict) -> "Session":
        """Create a Snowpark session from ``credentials``, tag it, and return it."""
        session = Session.builder.configs(credentials).create()
        session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}

        return session

    def _prepare_params(
        self,
        parameters: Any,
    ) -> dict:
        """Merge per-call ``parameters`` over ``self.kwargs`` and keep only forwardable options.

        Per-call values win over the instance defaults. Only keys listed in
        ``self.available_args`` are returned.
        """
        params = {**self.kwargs, **self._canonical_options(parameters)}
        return {k: v for k, v in params.items() if k in self.available_args}

    def _cortex_complete_request(self, prompt: str, **kwargs) -> dict:
        """Run ``snowflake.cortex.complete`` for ``prompt`` and return the decoded JSON response.

        ``kwargs`` is passed as the options argument of the Cortex call. When
        no row comes back, or the result column is NULL, a placeholder response
        whose single choice has the message ``"None"`` is returned instead.
        """
        complete = snow_func.builtin("snowflake.cortex.complete")
        cortex_complete_args = complete(
            snow_func.lit(self.model),
            snow_func.lit([{"role": "user", "content": prompt}]),
            snow_func.lit(kwargs),
        )
        raw_response = self.client.range(1).withColumn("complete_cal", cortex_complete_args).collect()

        # json.loads(None) would raise TypeError, so a NULL result is treated
        # the same way as an empty result set.
        if len(raw_response) > 0 and raw_response[0].COMPLETE_CAL is not None:
            return json.loads(raw_response[0].COMPLETE_CAL)

        return {"choices": [{"messages": "None"}]}

    def basic_request(self, prompt: str, **kwargs) -> list:
        """Send ``prompt`` to Cortex, record the exchange in ``self.history``, and return the completions.

        Returns
        -------
        list
            The ``"messages"`` value of each choice in the response.
        """
        raw_kwargs = kwargs
        kwargs = self._prepare_params(raw_kwargs)

        response = self._cortex_complete_request(prompt, **kwargs)

        history = {
            "prompt": prompt,
            "response": {
                "prompt": prompt,
                "choices": [{"text": c} for c in response["choices"]],
            },
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }

        self.history.append(history)

        return [i["text"]["messages"] for i in history["response"]["choices"]]

    @backoff.on_exception(
        backoff.expo,
        (Exception),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def _request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Snowflake Cortex whilst handling API errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Return the completions for ``prompt``.

        ``only_completed`` and ``return_sorted`` are accepted but not used by
        this implementation; ``kwargs`` are per-call generation options.
        """
        return self._request(prompt, **kwargs)
1 change: 1 addition & 0 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Google = dsp.Google
GoogleVertexAI = dsp.GoogleVertexAI
GROQ = dsp.GroqLM
Snowflake = dsp.Snowflake
Claude = dsp.Claude

HFClientTGI = dsp.HFClientTGI
Expand Down
Loading