diff --git a/README.md b/README.md
index 0c3f0669e2..a69bdbf4fa 100644
--- a/README.md
+++ b/README.md
@@ -72,11 +72,11 @@ Or open our intro notebook in Google Colab: [
dspy.Prediction`
+
+Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated with the default `e5-base-v2` model or the `embeddings_model` specified at initialization.
+
+**Parameters:**
+
+- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
+- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.
+
+**Returns:**
+
+- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"long_text": str}]`
+
+### Quickstart
+
+To support passage retrieval, `SnowflakeRM` assumes that a Snowflake table has been created and populated with the passages in a text column (`embeddings_text_field`) and their embeddings in a vector column (`embeddings_field`).
+
+`SnowflakeRM` uses the `e5-base-v2` embeddings model by default, but any Snowflake Cortex-supported embeddings model can be passed via `embeddings_model`. The required `snowflake-snowpark-python` dependency is available through the `dspy-ai[snowflake]` extra.
+
+#### Default Embeddings (`e5-base-v2`)
+
+```python
+from dspy.retrieve.snowflake_rm import SnowflakeRM
+import os
+
+connection_parameters = {
+    "account": os.getenv('SNOWFLAKE_ACCOUNT'),
+    "user": os.getenv('SNOWFLAKE_USER'),
+    "password": os.getenv('SNOWFLAKE_PASSWORD'),
+    "role": os.getenv('SNOWFLAKE_ROLE'),
+    "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
+    "database": os.getenv('SNOWFLAKE_DATABASE'),
+    "schema": os.getenv('SNOWFLAKE_SCHEMA'),
+}
+
+retriever_model = SnowflakeRM(
+ snowflake_table_name="",
+ snowflake_credentials=connection_parameters,
+ embeddings_field="",
+ embeddings_text_field= ""
+ )
+
+results = retriever_model("Explore the meaning of life", k=5)
+
+for result in results:
+ print("Document:", result.long_text, "\n")
+```
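+
+#### Using the Snowflake Cortex LM
+
+A `dspy.Snowflake` language model wrapper around Snowflake Cortex `COMPLETE` is available as well. The snippet below is a minimal sketch that reuses the `connection_parameters` dict and `retriever_model` from above; the model name and prompt are only illustrative.
+
+```python
+import dspy
+
+# "mixtral-8x7b" is the default; any Cortex-supported model from the wrapper's list works.
+lm = dspy.Snowflake(model="mixtral-8x7b", credentials=connection_parameters)
+dspy.settings.configure(lm=lm, rm=retriever_model)
+
+# Calling the LM directly returns a list of completion strings.
+completions = lm("Summarize what Snowflake Cortex provides.")
+print(completions[0])
+```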
diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py
index 829fd655b1..d480490a78 100644
--- a/dsp/modules/__init__.py
+++ b/dsp/modules/__init__.py
@@ -22,4 +22,6 @@
from .pyserini import *
from .sbert import *
from .sentence_vectorizer import *
+from .snowflake import *
from .watsonx import *
+
diff --git a/dsp/modules/snowflake.py b/dsp/modules/snowflake.py
new file mode 100644
index 0000000000..09bffd787a
--- /dev/null
+++ b/dsp/modules/snowflake.py
@@ -0,0 +1,164 @@
+"""Module for interacting with Snowflake Cortex."""
+import json
+from typing import Any
+
+import backoff
+from pydantic_core import PydanticCustomError
+
+from dsp.modules.lm import LM
+
+try:
+ from snowflake.snowpark import Session
+ from snowflake.snowpark import functions as snow_func
+
+except ImportError:
+ pass
+
+
+def backoff_hdlr(details) -> None:
+ """Handler from https://pypi.org/project/backoff ."""
+ print(
+ f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries ",
+ f"calling function {details['target']} with kwargs",
+ f"{details['kwargs']}",
+ )
+
+
+def giveup_hdlr(details) -> bool:
+ """Wrapper function that decides when to give up on retry."""
+ if "rate limits" in str(details):
+ return False
+ return True
+
+
+class Snowflake(LM):
+ """Wrapper around Snowflake's CortexAPI.
+
+ Currently supported models include 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b',
+ 'llama2-70b-chat','mistral-7b','gemma-7b','llama3-8b','llama3-70b','reka-core'.
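+
+    Example (a minimal sketch; assumes `connection_parameters` is a dict of
+    Snowpark session credentials):
+
+        lm = Snowflake(model="mixtral-8x7b", credentials=connection_parameters)
+        completions = lm("Write a haiku about data pipelines.")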
+ """
+
+ def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
+ """Parameters
+
+ ----------
+ model : str
+ Which pre-trained model from Snowflake to use?
+ Choices are 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b','llama2-70b-chat','mistral-7b','gemma-7b'
+ Full list of supported models is available here: https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#complete
+ credentials: dict
+ Snowflake credentials required to initialize the session.
+ Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session
+ **kwargs: dict
+ Additional arguments to pass to the API provider.
+ """
+ super().__init__(model)
+
+ self.model = model
+ cortex_models = [
+ "llama3-8b",
+ "llama3-70b",
+ "reka-core",
+ "snowflake-arctic",
+ "mistral-large",
+ "reka-flash",
+ "mixtral-8x7b",
+ "llama2-70b-chat",
+ "mistral-7b",
+ "gemma-7b",
+ ]
+
+ if model in cortex_models:
+ self.available_args = {
+ "max_tokens",
+ "temperature",
+ "top_p",
+ }
+ else:
+            raise PydanticCustomError(
+                "model",
+                'model name is not valid, got "{model_name}"',
+                {"model_name": model},
+            )
+
+ self.client = self._init_cortex(credentials=credentials)
+ self.provider = "Snowflake"
+ self.history: list[dict[str, Any]] = []
+ self.kwargs = {
+ **self.kwargs,
+ "temperature": 0.7,
+ "max_output_tokens": 1024,
+ "top_p": 1.0,
+ "top_k": 1,
+ **kwargs,
+ }
+
+ @classmethod
+    def _init_cortex(cls, credentials: dict) -> "Session":
+ session = Session.builder.configs(credentials).create()
+ session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}
+
+ return session
+
+ def _prepare_params(
+ self,
+ parameters: Any,
+ ) -> dict:
+        # Map alternative argument names, then keep only the options Cortex COMPLETE supports.
+        params_mapping = {"n": "candidate_count"}
+ params = {params_mapping.get(k, k): v for k, v in parameters.items()}
+ params = {**self.kwargs, **params}
+ return {k: params[k] for k in set(params.keys()) & self.available_args}
+
+ def _cortex_complete_request(self, prompt: str, **kwargs) -> dict:
+ complete = snow_func.builtin("snowflake.cortex.complete")
+ cortex_complete_args = complete(
+ snow_func.lit(self.model),
+ snow_func.lit([{"role": "user", "content": prompt}]),
+ snow_func.lit(kwargs),
+ )
+ raw_response = self.client.range(1).withColumn("complete_cal", cortex_complete_args).collect()
+
+ if len(raw_response) > 0:
+ return json.loads(raw_response[0].COMPLETE_CAL)
+
+ else:
+ return json.loads('{"choices": [{"messages": "None"}]}')
+
+ def basic_request(self, prompt: str, **kwargs) -> list:
+ raw_kwargs = kwargs
+ kwargs = self._prepare_params(raw_kwargs)
+
+ response = self._cortex_complete_request(prompt, **kwargs)
+
+ history = {
+ "prompt": prompt,
+ "response": {
+ "prompt": prompt,
+ "choices": [{"text": c} for c in response["choices"]],
+ },
+ "kwargs": kwargs,
+ "raw_kwargs": raw_kwargs,
+ }
+
+ self.history.append(history)
+
+ return [i["text"]["messages"] for i in history["response"]["choices"]]
+
+ @backoff.on_exception(
+ backoff.expo,
+ (Exception),
+ max_time=1000,
+ on_backoff=backoff_hdlr,
+ giveup=giveup_hdlr,
+ )
+ def _request(self, prompt: str, **kwargs):
+ """Handles retrieval of completions from Snowflake Cortex whilst handling API errors."""
+ return self.basic_request(prompt, **kwargs)
+
+ def __call__(
+ self,
+ prompt: str,
+ only_completed: bool = True,
+ return_sorted: bool = False,
+ **kwargs,
+ ):
+ return self._request(prompt, **kwargs)
diff --git a/dspy/__init__.py b/dspy/__init__.py
index 957e4755db..da659a2f45 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -25,6 +25,7 @@
Google = dsp.Google
GoogleVertexAI = dsp.GoogleVertexAI
GROQ = dsp.GroqLM
+Snowflake = dsp.Snowflake
Claude = dsp.Claude
HFClientTGI = dsp.HFClientTGI
diff --git a/dspy/retrieve/snowflake_rm.py b/dspy/retrieve/snowflake_rm.py
new file mode 100644
index 0000000000..40aac2b59f
--- /dev/null
+++ b/dspy/retrieve/snowflake_rm.py
@@ -0,0 +1,116 @@
+from typing import Optional, Union
+
+import dspy
+from dsp.utils import dotdict
+
+try:
+ from snowflake.snowpark import Session
+ from snowflake.snowpark import functions as snow_fn
+ from snowflake.snowpark.functions import col, function, lit
+ from snowflake.snowpark.types import VectorType
+
+except ImportError:
+ raise ImportError(
+ "The snowflake-snowpark-python library is required to use SnowflakeRM. Install it with dspy-ai[snowflake]",
+ )
+
+
+class SnowflakeRM(dspy.Retrieve):
+ """A retrieval module that uses Weaviate to return the top passages for a given query.
+
+    Assumes that a Snowflake table has been created and populated with the passages
+    in a text column (embeddings_text_field) and their embeddings in a vector column (embeddings_field).
+
+ Args:
+ snowflake_credentials: connection parameters for initializing Snowflake client.
+ snowflake_table_name (str): The name of the Snowflake table containing document embeddings.
+ embeddings_field (str): The field in the Snowflake table with the content embeddings
+ embeddings_text_field (str): The field in the Snowflake table with the content.
+ k (int, optional): The default number of top passages to retrieve. Defaults to 3.
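+
+    Example (a minimal sketch; assumes the table, columns, and `connection_parameters`
+    credentials dict already exist):
+
+        retriever = SnowflakeRM(
+            snowflake_table_name="my_passages",
+            snowflake_credentials=connection_parameters,
+            embeddings_field="chunk_vec",
+            embeddings_text_field="chunk",
+        )
+        passages = retriever("Explore the meaning of life", k=5)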
+ """
+
+ def __init__(
+ self,
+ snowflake_table_name: str,
+ snowflake_credentials: dict,
+ k: int = 3,
+ embeddings_field: str = "chunk_vec",
+ embeddings_text_field: str = "chunk",
+ embeddings_model: str = "e5-base-v2",
+ ):
+ self.snowflake_table_name = snowflake_table_name
+ self.embeddings_field = embeddings_field
+ self.embeddings_text_field = embeddings_text_field
+ self.embeddings_model = embeddings_model
+ self.client = self._init_cortex(credentials=snowflake_credentials)
+
+ super().__init__(k=k)
+
+ def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
+ """Search Snowflake document embeddings table for self.k top passages for query.
+
+ Args:
+ query_or_queries (Union[str, List[str]]): The query or queries to search for.
+ k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
+
+ Returns:
+ dspy.Prediction: An object containing the retrieved passages.
+ """
+ k = k if k is not None else self.k
+ queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
+ queries = [q for q in queries if q]
+ passages = []
+
+ for query in queries:
+ query_embeddings = self._get_embeddings(query)
+ top_k_chunks = self._top_k_similar_chunks(query_embeddings, k)
+
+ passages.extend(dotdict({"long_text": passage[0]}) for passage in top_k_chunks)
+
+ return passages
+
+ def _top_k_similar_chunks(self, query_embeddings, k):
+ """Search Snowflake table for self.k top passages for query.
+
+ Args:
+ query_embeddings(List[float]]): the embeddings for the query of interest
+ doc_table
+ k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
+
+ Returns:
+ dspy.Prediction: An object containing the retrieved passages.
+ """
+ doc_table_value = self.embeddings_field
+ doc_table_key = self.embeddings_text_field
+
+ doc_embeddings = self.client.table(self.snowflake_table_name)
+ cosine_similarity = function("vector_cosine_similarity")
+
+ top_k = (
+ doc_embeddings.select(
+ doc_table_value,
+ doc_table_key,
+ cosine_similarity(
+ doc_embeddings.col(doc_table_value),
+ lit(query_embeddings).cast(VectorType(float, len(query_embeddings))),
+ ).as_("dist"),
+ )
+ .sort("dist", ascending=False)
+ .limit(k)
+ )
+
+ return top_k.select(doc_table_key).to_pandas().values
+
+ @classmethod
+    def _init_cortex(cls, credentials: dict) -> Session:
+ session = Session.builder.configs(credentials).create()
+ session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}
+
+ return session
+
+ def _get_embeddings(self, query: str) -> list[float]:
+ # create embeddings for the query
+ embed = snow_fn.builtin("snowflake.cortex.embed_text_768")
+ cortex_embed_args = embed(snow_fn.lit(self.embeddings_model), snow_fn.lit(query))
+
+ return self.client.range(1).withColumn("complete_cal", cortex_embed_args).collect()[0].COMPLETE_CAL
diff --git a/poetry.lock b/poetry.lock
index e4e0345594..7d9f6b8516 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
@@ -282,6 +282,17 @@ typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""}
[package.extras]
tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"]
+[[package]]
+name = "asn1crypto"
+version = "1.5.1"
+description = "Fast ASN.1 parser and serializer with definitions for private keys, public keys, certificates, CRL, OCSP, CMS, PKCS#3, PKCS#7, PKCS#8, PKCS#12, PKCS#5, X.509 and TSP"
+optional = true
+python-versions = "*"
+files = [
+ {file = "asn1crypto-1.5.1-py2.py3-none-any.whl", hash = "sha256:db4e40728b728508912cbb3d44f19ce188f218e9eba635821bb4b68564f8fd67"},
+ {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"},
+]
+
[[package]]
name = "asttokens"
version = "2.4.1"
@@ -902,6 +913,17 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
+[[package]]
+name = "cloudpickle"
+version = "2.2.1"
+description = "Extended pickling support for Python objects"
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"},
+ {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"},
+]
+
[[package]]
name = "colorama"
version = "0.4.6"
@@ -4531,6 +4553,23 @@ files = [
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
+[[package]]
+name = "pyjwt"
+version = "2.8.0"
+description = "JSON Web Token implementation in Python"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
+ {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
+]
+
+[package.extras]
+crypto = ["cryptography (>=3.4.0)"]
+dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
+
[[package]]
name = "pymdown-extensions"
version = "10.8.1"
@@ -4577,6 +4616,24 @@ ujson = ">=2.0.0"
model = ["milvus-model (>=0.1.0)"]
test = ["black", "grpcio-testing", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>=0.3.3)"]
+[[package]]
+name = "pyopenssl"
+version = "24.1.0"
+description = "Python wrapper module around the OpenSSL library"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
+ {file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
+]
+
+[package.dependencies]
+cryptography = ">=41.0.5,<43"
+
+[package.extras]
+docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
+test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
+
[[package]]
name = "pyparsing"
version = "3.1.2"
@@ -5503,6 +5560,105 @@ files = [
{file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
]
+[[package]]
+name = "snowflake-connector-python"
+version = "3.10.0"
+description = "Snowflake Connector for Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "snowflake_connector_python-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e2afca4bca70016519d1a7317c498f1d9c56140bf3e40ea40bddcc95fe827ca"},
+ {file = "snowflake_connector_python-3.10.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:d19bde29f89b226eb22af4c83134ecb5c229da1d5e960a01b8f495df78dcdc36"},
+ {file = "snowflake_connector_python-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bfe013ed97b4dd2e191fd6770a14030d29dd0108817d6ce76b9773250dd2d560"},
+ {file = "snowflake_connector_python-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0917c9f9382d830907e1a18ee1208537b203618700a9c671c2a20167b30f574"},
+ {file = "snowflake_connector_python-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:7e828bc99240433e6552ac4cc4e37f223ae5c51c7880458ddb281668503c7491"},
+ {file = "snowflake_connector_python-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0d3d06d758455c50b998eabc1fd972a1f67faa5c85ef250fd5986f5a41aab0b"},
+ {file = "snowflake_connector_python-3.10.0-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:4602cb19b204bb03e03d65c6d5328467c9efc0fec53ca56768c3747c8dc8a70f"},
+ {file = "snowflake_connector_python-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb1a04b496bbd3e1e2e926df82b2369887b2eea958f535fb934c240bfbabf6c5"},
+ {file = "snowflake_connector_python-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c889f9f60f915d657e0a0ad2e6cc52cdcafd9bcbfa95a095aadfd8bcae62b819"},
+ {file = "snowflake_connector_python-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:8e441484216ed416a6ed338133e23bd991ac4ba2e46531f4d330f61568c49314"},
+ {file = "snowflake_connector_python-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bb4aced19053c67513cecc92311fa9d3b507b2277698c8e987d404f6f3a49fb2"},
+ {file = "snowflake_connector_python-3.10.0-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:858315a2feff86213b079c6293ad8d850a778044c664686802ead8bb1337e1bc"},
+ {file = "snowflake_connector_python-3.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:adf16e1ca9f46d3bdf68e955ffa42075ebdb251e3b13b59003d04e4fea7d579a"},
+ {file = "snowflake_connector_python-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4c5c2a08b39086a5348502652ad4fdf24871d7ab30fd59f6b7b57249158468c"},
+ {file = "snowflake_connector_python-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:05011286f42c52eb3e5a6db59ee3eaf79f3039f3a19d7ffac6f4ee143779c637"},
+ {file = "snowflake_connector_python-3.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:569301289ada5b0d72d0bd8432b7ca180220335faa6d9a0f7185f60891db6f2c"},
+ {file = "snowflake_connector_python-3.10.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:4e5641c70a12da9804b74f350b8cbbdffdc7aca5069b096755abd2a1fdcf5d1b"},
+ {file = "snowflake_connector_python-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ff767a1b8c48431549ac28884f8bd9647e63a23f470b05f6ab8d143c4b1475"},
+ {file = "snowflake_connector_python-3.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e52bbc1e2e7bda956525b4229d7f87579f8cabd7d5506b12aa754c4bcdc8c8d7"},
+ {file = "snowflake_connector_python-3.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:280a8dcca0249e864419564e38764c08f8841900d9872fec2f2855fda494b29f"},
+ {file = "snowflake_connector_python-3.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:67bf570230b0cf818e6766c17245c7355a1f5ea27778e54ab8d09e5bb3536ad9"},
+ {file = "snowflake_connector_python-3.10.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:aa1e26f9c571d2c4206da5c978c1b345ffd798d3db1f9ae91985e6243c6bf94b"},
+ {file = "snowflake_connector_python-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73e9baa531d5156a03bfe5af462cf6193ec2a01cbb575edf7a2dd3b2a35254c7"},
+ {file = "snowflake_connector_python-3.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e03361c4749e4d65bf0d223fdea1c2d7a33af53b74e873929a6085d150aff17e"},
+ {file = "snowflake_connector_python-3.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:e8cddd4357e70ab55d7aeeed144cbbeb1ff658b563d7d8d307afc06178a367ec"},
+ {file = "snowflake_connector_python-3.10.0.tar.gz", hash = "sha256:7c7438e958753bd1174b73581d77c92b0b47a86c38d8ea0ba1ea23c442eb8e75"},
+]
+
+[package.dependencies]
+asn1crypto = ">0.24.0,<2.0.0"
+certifi = ">=2017.4.17"
+cffi = ">=1.9,<2.0.0"
+charset-normalizer = ">=2,<4"
+cryptography = ">=3.1.0,<43.0.0"
+filelock = ">=3.5,<4"
+idna = ">=2.5,<4"
+packaging = "*"
+platformdirs = ">=2.6.0,<5.0.0"
+pyjwt = "<3.0.0"
+pyOpenSSL = ">=16.2.0,<25.0.0"
+pytz = "*"
+requests = "<3.0.0"
+sortedcontainers = ">=2.4.0"
+tomlkit = "*"
+typing-extensions = ">=4.3,<5"
+urllib3 = {version = ">=1.21.1,<2.0.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+development = ["Cython", "coverage", "more-itertools", "numpy (<1.27.0)", "pendulum (!=2.1.1)", "pexpect", "pytest (<7.5.0)", "pytest-cov", "pytest-rerunfailures", "pytest-timeout", "pytest-xdist", "pytzdata"]
+pandas = ["pandas (>=1.0.0,<3.0.0)", "pyarrow"]
+secure-local-storage = ["keyring (>=23.1.0,<25.0.0)"]
+
+[[package]]
+name = "snowflake-snowpark-python"
+version = "1.16.0"
+description = "Snowflake Snowpark for Python"
+optional = true
+python-versions = "<3.12,>=3.8"
+files = [
+ {file = "snowflake_snowpark_python-1.16.0-py3-none-any.whl", hash = "sha256:3b3713235644bfa463f41a72e368e0007667c4efb91d770e9a5681164e495aee"},
+ {file = "snowflake_snowpark_python-1.16.0.tar.gz", hash = "sha256:b6c25fa37878f250ee8dca40c83bf556bc6d983be85818fd0767fcee893f9112"},
+]
+
+[package.dependencies]
+cloudpickle = [
+ {version = ">=1.6.0,<2.1.0 || >2.1.0,<2.2.0 || >2.2.0,<=2.2.1", markers = "python_version < \"3.11\""},
+ {version = "2.2.1", markers = "python_version ~= \"3.11\""},
+]
+pyyaml = "*"
+setuptools = ">=40.6.0"
+snowflake-connector-python = ">=3.10.0,<4.0.0"
+typing-extensions = ">=4.1.0,<5.0.0"
+wheel = "*"
+
+[package.extras]
+development = ["cachetools", "coverage", "pre-commit", "pytest (<8.0.0)", "pytest-cov", "pytest-timeout", "sphinx (==5.0.2)"]
+localtest = ["pandas", "pyarrow", "requests"]
+opentelemetry = ["opentelemetry-api (>=1.0.0,<2.0.0)", "opentelemetry-sdk (>=1.0.0,<2.0.0)"]
+pandas = ["snowflake-connector-python[pandas] (>=3.10.0,<4.0.0)"]
+secure-local-storage = ["snowflake-connector-python[secure-local-storage] (>=3.10.0,<4.0.0)"]
+
+[[package]]
+name = "sortedcontainers"
+version = "2.4.0"
+description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set"
+optional = true
+python-versions = "*"
+files = [
+ {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"},
+ {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
+]
+
[[package]]
name = "soupsieve"
version = "2.5"
@@ -6092,6 +6248,17 @@ files = [
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
+[[package]]
+name = "tomlkit"
+version = "0.12.5"
+description = "Style preserving TOML library"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "tomlkit-0.12.5-py3-none-any.whl", hash = "sha256:af914f5a9c59ed9d0762c7b64d3b5d5df007448eb9cd2edc8a46b1eafead172f"},
+ {file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"},
+]
+
[[package]]
name = "torch"
version = "2.3.0"
@@ -6421,23 +6588,6 @@ brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotl
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
-[[package]]
-name = "urllib3"
-version = "2.2.1"
-description = "HTTP library with thread-safe connection pooling, file post, and more."
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
- {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
-]
-
-[package.extras]
-brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-h2 = ["h2 (>=4,<5)"]
-socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
-zstd = ["zstandard (>=0.18.0)"]
-
[[package]]
name = "uvicorn"
version = "0.29.0"
@@ -6810,6 +6960,20 @@ files = [
{file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"},
]
+[[package]]
+name = "wheel"
+version = "0.43.0"
+description = "A built-package format for Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "wheel-0.43.0-py3-none-any.whl", hash = "sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81"},
+ {file = "wheel-0.43.0.tar.gz", hash = "sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85"},
+]
+
+[package.extras]
+test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
+
[[package]]
name = "win32-setctime"
version = "1.1.0"
@@ -7153,4 +7317,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
-content-hash = "dfb44251298e064041c90ee3a63a1cf5baaf2c6ce5c4bbaa4a036247c74852a8"
+content-hash = "d5fc4db9e32e22b358c93b3e026dbecbfd1f32fc4c4106112ba95e37c0aa259d"
diff --git a/pyproject.toml b/pyproject.toml
index 962df57c7f..7557fdce05 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -120,9 +120,11 @@ rich = "^13.7.1"
psycopg2 = { version = "^2.9.9", optional = true }
pgvector = { version = "^0.2.5", optional = true }
structlog = "^24.1.0"
+snowflake-snowpark-python = { version = "*", optional = true, python = ">=3.9,<3.12" }
jinja2 = "^3.1.3"
+
[tool.poetry.group.dev.dependencies]
pytest = "^6.2.5"
transformers = "^4.38.2"
diff --git a/setup.py b/setup.py
index 9918a8cd88..7d1e973653 100644
--- a/setup.py
+++ b/setup.py
@@ -1,12 +1,12 @@
from setuptools import find_packages, setup
-# Read the content of the README file
-with open('README.md', encoding='utf-8') as f:
- long_description = f.read()
+# Read the content of the README file
+with open("README.md", encoding="utf-8") as f:
+ long_description = f.read()
-# Read the content of the requirements.txt file
-with open('requirements.txt', encoding='utf-8') as f:
- requirements = f.read().splitlines()
+# Read the content of the requirements.txt file
+with open("requirements.txt", encoding="utf-8") as f:
+ requirements = f.read().splitlines()
setup(
name="dspy-ai",
@@ -21,16 +21,18 @@
packages=find_packages(include=['dsp.*', 'dspy.*', 'dsp', 'dspy']),
python_requires='>=3.9',
install_requires=requirements,
+
extras_require={
"chromadb": ["chromadb~=0.4.14"],
"qdrant": ["qdrant-client", "fastembed"],
"marqo": ["marqo~=3.1.0"],
- "mongodb": ["pymongo~=3.12.0"],
- "pinecone": ["pinecone-client~=2.2.4"],
- "weaviate": ["weaviate-client~=3.26.1"],
+ "mongodb": ["pymongo~=3.12.0"],
+ "pinecone": ["pinecone-client~=2.2.4"],
+ "weaviate": ["weaviate-client~=3.26.1"],
"faiss-cpu": ["sentence_transformers", "faiss-cpu"],
"milvus": ["pymilvus~=2.3.7"],
"google-vertex-ai": ["google-cloud-aiplatform==1.43.0"],
+ "snowflake": ["snowflake-snowpark-python"],
"fastembed": ["fastembed"],
},
classifiers=[