diff --git a/README.md b/README.md
index 36700c92..5851cc5b 100644
--- a/README.md
+++ b/README.md
@@ -42,15 +42,15 @@ Indices can be defined through yaml specification that corresponds directly to t
 ```yaml
 index:
-    name: users
+    name: user_index
     storage_type: hash
-    prefix: "user:"
-    key_field: "id"
+    prefix: users
+    key_field: user
 fields:
     # define tag fields
     tag:
-        - name: users
+        - name: user
         - name: job
         - name: credit_store
     # define numeric fields
@@ -65,7 +65,7 @@ fields:
 
 This would correspond to a dataset that looked something like
 
-| users | age | job        | credit_score | user_embedding                    |
+| user  | age | job        | credit_score | user_embedding                    |
 |-------|-----|------------|--------------|-----------------------------------|
 | john  | 1   | engineer   | high         | \x3f\x8c\xcc\x3f\x8c\xcc?@        |
 | mary  | 2   | doctor     | low          | \x3f\x8c\xcc\x3f\x8c\xcc?@        |
@@ -74,6 +74,8 @@ This would correspond to a dataset that looked something like
 
 With the schema, the RedisVL library can be used to create, load vectors and perform vector searches
 
 ```python
+import pandas as pd
+
 from redisvl.index import SearchIndex
 from redisvl.query import create_vector_query
@@ -82,10 +84,10 @@ index = SearchIndex.from_yaml("./users_schema.yml"))
 index.connect("redis://localhost:6379")
 index.create()
 
-index.load(pd.read_csv("./users.csv").to_records())
+index.load(pd.read_csv("./users.csv").to_dict("records"))
 
 query = create_vector_query(
-    ["users", "age", "job", "credit_score"],
+    ["user", "age", "job", "credit_score"],
     number_of_results=2,
     vector_field_name="user_embedding",
 )
@@ -93,6 +95,7 @@ query = create_vector_query(
 query_vector = np.array([0.1, 0.1, 0.5]).tobytes()
 
 results = index.search(query, query_params={"vector": query_vector})
+
 ```
 
 ### Semantic cache
@@ -100,29 +103,26 @@ results = index.search(query, query_params={"vector": query_vector})
 
 The ``LLMCache`` Interface in RedisVL can be used as follows.
 ```python
-# init open ai client
-import openai
-openai.api_key = "sk-xxx"
-
 from redisvl.llmcache.semantic import SemanticCache
-cache = SemanticCache(redis_host="localhost", redis_port=6379, redis_password=None)
-
-def ask_gpt3(question):
-    response = openai.Completion.create(
-        engine="text-davinci-003",
-        prompt=question,
-        max_tokens=100
-    )
-    return response.choices[0].text.strip()
-
-def answer_question(question: str):
-    results = cache.check(question)
-    if results:
-        return results[0]
-    else:
-        answer = ask_gpt3(question)
-        cache.store(question, answer)
-        return answer
+cache = SemanticCache(
+    redis_url="redis://localhost:6379",
+    threshold=0.9, # semantic similarity threshold
+)
+
+# check if the cache has a result for a given query
+cache.check("What is the capital of France?")
+[]
+
+# store a result for a given query
+cache.store("What is the capital of France?", "Paris")
+
+# Cache will now have the query
+cache.check("What is the capital of France?")
+["Paris"]
+
+# Cache will return the result if the query is similar enough
+cache.check("What really is the capital of France?")
+["Paris"]
 ```
diff --git a/conftest.py b/conftest.py
index 452c6db2..2f817991 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+import asyncio
 
 from redisvl.utils.connection import (
     get_async_redis_connection,
@@ -23,3 +24,13 @@ def client():
 @pytest.fixture
 def openai_key():
     return os.getenv("OPENAI_KEY")
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+    yield loop
+    loop.close()
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index 9d90df67..39b5e207 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -102,6 +102,7 @@ add_module_names = False
 
 nbsphinx_execute = 'never'
+jupyter_execute_notebooks = "off"
 
 # -- Options for autosummary/autodoc output ------------------------------------
 autosummary_generate = True
@@ -129,4 +130,4 @@
     "android-chrome-192x192.png",
     # apple icons
     {"rel": "apple-touch-icon", "href": "apple-touch-icon.png"},
-]
\ No newline at end of file
+]
diff --git a/docs/user_guide/getting_started_01.ipynb b/docs/user_guide/getting_started_01.ipynb
index b9e66f01..ab23a8ed 100644
--- a/docs/user_guide/getting_started_01.ipynb
+++ b/docs/user_guide/getting_started_01.ipynb
@@ -16,7 +16,7 @@
    "4. Performing queries\n",
    "\n",
    "Before running this notebook, be sure to\n",
-    "1. Gave installed ``rvl`` and have that environment active for this notebook.\n",
+    "1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
    "2. Have a running Redis instance with RediSearch > 2.4 running."
   ]
  },
diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md
index 63053e12..0a55008a 100644
--- a/docs/user_guide/index.md
+++ b/docs/user_guide/index.md
@@ -26,8 +26,7 @@ embedding_creation
 ```
 
 ```{toctree}
-:maxdepth: 2
 :caption: LLMCache
 
-llm_cache
+llmcache_03
 ```
diff --git a/docs/user_guide/llm_cache.rst b/docs/user_guide/llm_cache.rst
deleted file mode 100644
index 65117102..00000000
--- a/docs/user_guide/llm_cache.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-
-
-=========
-LLM Cache
-=========
-
-LLM Caching is the process by which interactions with the LLM are cached in
-a database capable of performing vector search. This allows for the LLM to
-find similar interactions to those that have been previously seen.
-
diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb
new file mode 100644
index 00000000..1a5d22c1
--- /dev/null
+++ b/docs/user_guide/llmcache_03.ipynb
@@ -0,0 +1,313 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Semantic Caching\n",
+    "\n",
+    "RedisVL provides the ``LLMCache`` interface to turn Redis, with its vector search capability, into a semantic cache to store query results, thereby reducing the number of requests and tokens sent to the Large Language Models (LLM) service. This decreases expenses and enhances performance by reducing the time taken to generate responses.\n",
+    "\n",
+    "This notebook will go over how to use ``LLMCache`` for your applications."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, we will import OpenAI to use their API for responding to prompts."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import openai\n",
+    "openai.api_key = \"sk-\"\n",
+    "\n",
+    "def ask_openai(question):\n",
+    "    response = openai.Completion.create(\n",
+    "        engine=\"text-davinci-003\",\n",
+    "        prompt=question,\n",
+    "        max_tokens=200\n",
+    "    )\n",
+    "    return response.choices[0].text.strip()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The capital of France is Paris.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# test it\n",
+    "print(ask_openai(\"What is the capital of France?\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Initializing and using ``LLMCache``\n",
+    "\n",
+    "``LLMCache`` will automatically create an index for the semantic cache within Redis upon initialization. The same ``SearchIndex`` class used in the previous tutorials is used here to perform index creation and manipulation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/sam.partee/.virtualenvs/rvl/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from redisvl.llmcache.semantic import SemanticCache\n",
+    "cache = SemanticCache(\n",
+    "    redis_url=\"redis://localhost:6379\",\n",
+    "    threshold=0.9, # semantic similarity threshold\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[]"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# check the cache\n",
+    "cache.check(\"What is the capital of France?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# store the question and answer\n",
+    "cache.store(\"What is the capital of France?\", \"Paris\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['Paris']"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# check the cache again\n",
+    "cache.check(\"What is the capital of France?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[]"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# check for a semantically similar result\n",
+    "cache.check(\"What really is the capital of France?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['Paris']"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# decrease the semantic similarity threshold\n",
+    "cache.set_threshold(0.7)\n",
+    "cache.check(\"What really is the capital of France?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[]"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# adversarial example (not semantically similar enough)\n",
+    "cache.check(\"What is the capital of Spain?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Performance\n",
+    "\n",
+    "Next, we will measure the speedup obtained by using ``LLMCache``. We will use the ``time`` module to measure the time taken to generate responses with and without ``LLMCache``."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def answer_question(question: str):\n",
+    "    results = cache.check(question)\n",
+    "    if results:\n",
+    "        return results[0]\n",
+    "    else:\n",
+    "        answer = ask_openai(question)\n",
+    "        cache.store(question, answer)\n",
+    "        return answer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Time taken without cache 0.7418899536132812\n"
+     ]
+    }
+   ],
+   "source": [
+    "import time\n",
+    "start = time.time()\n",
+    "answer = answer_question(\"What is the capital of France?\")\n",
+    "end = time.time()\n",
+    "print(f\"Time taken without cache {time.time() - start}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Time Taken with cache: 0.07415914535522461\n",
+      "Percentage of time saved: 90.0%\n"
+     ]
+    }
+   ],
+   "source": [
+    "cached_start = time.time()\n",
+    "cached_answer = answer_question(\"What is the capital of France?\")\n",
+    "cached_end = time.time()\n",
+    "print(f\"Time Taken with cache: {cached_end - cached_start}\")\n",
+    "print(f\"Percentage of time saved: {round(((end - start) - (cached_end - cached_start)) / (end - start) * 100, 2)}%\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# remove the index and all cached items\n",
+    "cache.index.delete()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "rvl",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/user_guide/vector_search.rst b/docs/user_guide/vector_search.rst
deleted file mode 100644
index 3a8a2b9b..00000000
--- a/docs/user_guide/vector_search.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-
-=============
-Vector Search
-=============
-
-This example shows how to search for vectors in a vector index with RedisVL
-
diff --git a/pyproject.toml b/pyproject.toml
index 9fef0171..2872554c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ exclude = '''
 
 [tool.pytest.ini_options]
 log_cli = true
+asyncio_mode = "auto"
 
 [tool.coverage.run]
 source = ["redisvl"]
diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py
index 8ae1be59..bf020a12 100644
--- a/redisvl/cli/index.py
+++ b/redisvl/cli/index.py
@@ -3,10 +3,10 @@
 from argparse import Namespace
 from pprint import pprint
 
+from redisvl.cli.log import get_logger
 from redisvl.cli.utils import create_redis_url
 from redisvl.index import SearchIndex
 from redisvl.utils.connection import get_redis_connection
-from redisvl.utils.log import get_logger
 from redisvl.utils.utils import convert_bytes
 
 logger = get_logger(__name__)
diff --git a/redisvl/utils/log.py b/redisvl/cli/log.py
similarity index 100%
rename from redisvl/utils/log.py
rename to redisvl/cli/log.py
diff --git a/redisvl/cli/main.py b/redisvl/cli/main.py
index 3fff8fb9..5374357a 100644
--- a/redisvl/cli/main.py
+++ b/redisvl/cli/main.py
@@ -2,7 +2,7 @@
 import sys
 
 from redisvl.cli.index import Index
-from redisvl.utils.log import get_logger
+from redisvl.cli.log import get_logger
 
 logger = get_logger(__name__)
diff --git a/redisvl/index.py b/redisvl/index.py
index 5bacc5f1..f0e2afcd 100644
--- a/redisvl/index.py
+++ b/redisvl/index.py
@@ -35,8 +35,9 @@ def __init__(
     def set_client(self, client: redis.Redis):
         self._redis_conn = client
 
+    @property
     @check_connected("_redis_conn")
-    def get_client(self) -> redis.Redis:
+    def client(self) -> redis.Redis:
         return self._redis_conn  # type: ignore
 
     @check_connected("_redis_conn")
@@ -190,6 +191,9 @@ def create(self, overwrite: Optional[bool] = False):
         """
         check_redis_modules_exist(self._redis_conn)
 
+        if not self._fields:
+            raise ValueError("No fields defined for index")
+
         if self._index_exists() and overwrite:
             self.delete()
 
@@ -228,8 +232,14 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs):
         raises:
             redis.exceptions.ResponseError: If the index does not exist
         """
+        if not data:
+            return
+        if not isinstance(data, Iterable):
+            if not isinstance(data[0], dict):
+                raise TypeError("data must be an iterable of dictionaries")
 
         for record in data:
+            # TODO don't use colon if no prefix
            key = f"{self._prefix}:{str(record[self._key_field])}"
             self._redis_conn.hset(key, mapping=record)  # type: ignore
 
@@ -320,7 +330,7 @@ async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10):
 
         async def load(d: dict):
             async with semaphore:
-                key = self._prefix + str(d[self._key_field])
+                key = self._prefix + ":" + str(d[self._key_field])
                 await self._redis_conn.hset(key, mapping=d)  # type: ignore
 
         # gather with concurrency
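Taken together, the `redisvl/index.py` changes above expose a `client` property in place of `get_client()`, make `create()` raise a `ValueError` when no fields are defined, and use the same `"{prefix}:{key_field}"` key scheme in both the sync and async `load()`. A rough usage sketch under those assumptions (the index name, prefix, fields, and record below are illustrative placeholders, not part of this diff):

```python
from redis.commands.search.field import TagField

from redisvl.index import SearchIndex

# Hypothetical index definition; names and fields are placeholders.
index = SearchIndex("users", prefix="user", key_field="id", fields=[TagField("job")])
index.connect("redis://localhost:6379")
index.create(overwrite=True)  # now raises ValueError if no fields were defined

# load() builds each key as "<prefix>:<key_field value>",
# so this record is written to the Redis hash "user:1".
index.load([{"id": "1", "job": "engineer"}])

# get_client() is replaced by the client property.
index.client.hgetall("user:1")
```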
diff --git a/redisvl/llmcache/base.py b/redisvl/llmcache/base.py
index a097abba..aa6af20b 100644
--- a/redisvl/llmcache/base.py
+++ b/redisvl/llmcache/base.py
@@ -39,6 +39,9 @@ def wrapper(*args, **kwargs):
                 return response
             # Otherwise execute the llm callable here
             response = llm_callable(*args, **kwargs)
+            args = list(args)
+            args.append(response)
+            self.store(*args, **kwargs)
             return response
         return wrapper
diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py
index 70279bb5..90d93485 100644
--- a/redisvl/llmcache/semantic.py
+++ b/redisvl/llmcache/semantic.py
@@ -7,11 +7,8 @@
 from redisvl.providers import HuggingfaceProvider
 from redisvl.providers.base import BaseProvider
 from redisvl.query import create_vector_query
-from redisvl.utils.log import get_logger
 from redisvl.utils.utils import array_to_buffer
 
-_logger = get_logger(__name__)
-
 
 class SemanticCache(BaseLLMCache):
     """Cache for Large Language Models."""
@@ -30,12 +27,12 @@ def __init__(
         index_name: str = "cache",
         prefix: str = "llmcache",
         threshold: float = 0.9,
+        ttl: Optional[int] = None,
         provider: Optional[BaseProvider] = None,
         redis_url: Optional[str] = "redis://localhost:6379",
         connection_args: Optional[dict] = None,
-        ttl: Optional[int] = None,
     ):
-        self.ttl = ttl
+        self._ttl = ttl
         self._provider = provider or self._default_provider
         self._threshold = threshold
 
@@ -44,9 +41,18 @@ def __init__(
             index_name, prefix=prefix, fields=self._default_fields
         )
         connection_args = connection_args or {}
-        self._index.connect(redis_url=redis_url, **connection_args)
+        self._index.connect(url=redis_url, **connection_args)
         self._index.create()
 
+    @property
+    def ttl(self) -> Optional[int]:
+        """Returns the TTL for the cache."""
+        return self._ttl
+
+    def set_ttl(self, ttl: int):
+        """Sets the TTL for the cache."""
+        self._ttl = int(ttl)
+
     @property
     def index(self) -> SearchIndex:
         """Returns the index for the cache."""
@@ -85,6 +91,7 @@ def check(
         cache_hits = []
         for doc in results.docs:
+            self._refresh_ttl(doc.id)
             sim = similarity(doc.vector_score)
             if sim > self.threshold:
                 cache_hits.append(doc.response)
@@ -113,9 +120,10 @@ def store(
 
     def _refresh_ttl(self, key: str):
         """Refreshes the TTL for the specified key."""
-        client = self._index.get_client()
+        client = self._index.client
         if client:
-            client.expire(key, self.ttl)
+            if self.ttl:
+                client.expire(key, self.ttl)
         else:
             raise RuntimeError("LLMCache is not connected to a Redis instance.")
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 3f8f0820..783cb2c9 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -53,3 +53,16 @@ def test_set_threshold(cache):
     cache.set_threshold(0.9)
     assert cache.threshold == 0.9
     cache.index.delete(drop=True)
+
+
+def test_wrapper(cache):
+    @cache.cache_response
+    def test_function(prompt):
+        return "This is a test response."
+
+    # Check that the wrapper works
+    test_function("This is a test prompt.")
+    check_result = cache.check("This is a test prompt.")
+    assert len(check_result) >= 1
+    assert "This is a test response." in check_result
+    cache.index.delete(drop=True)
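The new `test_wrapper` test exercises the `cache_response` decorator updated in `redisvl/llmcache/base.py`, which now stores the prompt/response pair after a cache miss. A minimal sketch of how an application might use it (the cache settings and the stub LLM function are illustrative, not part of this diff):

```python
from redisvl.llmcache.semantic import SemanticCache

# Placeholder connection settings.
cache = SemanticCache(redis_url="redis://localhost:6379", threshold=0.9)

@cache.cache_response
def ask_llm(prompt: str) -> str:
    # Stand-in for a real LLM call, e.g. the ask_openai helper in the notebook.
    return "This is a test response."

ask_llm("What is the capital of France?")  # miss: runs the function, stores the response
ask_llm("What is the capital of France?")  # hit: answered from the semantic cache
```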
diff --git a/tests/test_index.py b/tests/test_index.py
new file mode 100644
index 00000000..8a117495
--- /dev/null
+++ b/tests/test_index.py
@@ -0,0 +1,130 @@
+import pytest
+import redis
+from redis.commands.search.field import TagField
+
+from redisvl.index import AsyncSearchIndex, SearchIndex
+from redisvl.utils.utils import convert_bytes
+
+fields = [TagField("test")]
+
+
+def test_search_index_client(client):
+    si = SearchIndex("my_index", fields=fields)
+    si.set_client(client)
+
+    assert si.client is not None
+
+
+def test_search_index_create(client):
+    si = SearchIndex("my_index", fields=fields)
+    si.set_client(client)
+    si.create(overwrite=True)
+
+    assert "my_index" in convert_bytes(si.client.execute_command("FT._LIST"))
+
+    s1_2 = SearchIndex.from_existing(client, "my_index")
+    assert s1_2.info()["index_name"] == si.info()["index_name"]
+
+
+def test_search_index_delete(client):
+    si = SearchIndex("my_index", fields=fields)
+    si.set_client(client)
+    si.create(overwrite=True)
+    si.delete()
+
+    assert "my_index" not in convert_bytes(si.client.execute_command("FT._LIST"))
+
+
+def test_search_index_load(client):
+    si = SearchIndex("my_index", fields=fields)
+    si.set_client(client)
+    si.create(overwrite=True)
+    data = [{"id": "1", "value": "test"}]
+    si.load(data)
+
+    assert convert_bytes(client.hget(":1", "value")) == "test"
+
+
+@pytest.mark.asyncio
+async def test_async_search_index_creation(async_client):
+    asi = AsyncSearchIndex("my_index", fields=fields)
+    asi.set_client(async_client)
+
+    assert asi.client == async_client
+
+
+@pytest.mark.asyncio
+async def test_async_search_index_create(async_client):
+    asi = AsyncSearchIndex("my_index", fields=fields)
+    asi.set_client(async_client)
+    await asi.create(overwrite=True)
+
+    indices = await asi.client.execute_command("FT._LIST")
+    assert "my_index" in convert_bytes(indices)
+
+
+@pytest.mark.asyncio
+async def test_async_search_index_delete(async_client):
+    asi = AsyncSearchIndex("my_index", fields=fields)
+    asi.set_client(async_client)
+    await asi.create(overwrite=True)
+    await asi.delete()
+
+    indices = await asi.client.execute_command("FT._LIST")
+    assert "my_index" not in convert_bytes(indices)
+
+
+@pytest.mark.asyncio
+async def test_async_search_index_load(async_client):
+    asi = AsyncSearchIndex("my_index", fields=fields)
+    asi.set_client(async_client)
+    await asi.create(overwrite=True)
+    data = [{"id": "1", "value": "test"}]
+    await asi.load(data)
+    result = await async_client.hget(":1", "value")
+    assert convert_bytes(result) == "test"
+    await asi.delete()
+
+
+# --- Index Errors ----
+
+
+def test_search_index_delete_nonexistent(client):
+    si = SearchIndex("my_index")
+    si.set_client(client)
+    with pytest.raises(redis.exceptions.ResponseError):
+        si.delete()
+
+
+@pytest.mark.asyncio
+async def test_async_search_index_delete_nonexistent(async_client):
+    asi = AsyncSearchIndex("my_index")
+    asi.set_client(async_client)
+    with pytest.raises(redis.exceptions.ResponseError):
+        await asi.delete()
+
+
+# --- Data Errors ----
+
+
+def test_no_key_field(client):
+    si = SearchIndex("my_index", fields=fields, key_field="key")
+    si.set_client(client)
+    si.create(overwrite=True)
+    bad_data = [{"wrong_key": "1", "value": "test"}]
+
+    # TODO make a better error
+    with pytest.raises(KeyError):
+        si.load(bad_data)
+
+
+@pytest.mark.asyncio
+async def test_async_search_index_load_bad_data(async_client):
+    asi = AsyncSearchIndex("my_index", fields=fields)
+    asi.set_client(async_client)
+    await asi.create(overwrite=True)
+
+    # dictionary not list of dictionaries
+    bad_data = {"wrong_key": "1", "value": "test"}
+    with pytest.raises(TypeError):
+        await asi.load(bad_data)
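Finally, the `semantic.py` changes thread an optional TTL through the cache: `ttl` becomes a constructor argument, `set_ttl()` can change it later, and `check()` refreshes the TTL of every entry it returns via `_refresh_ttl()`. A rough sketch of the intended behavior, with placeholder URL and values:

```python
from redisvl.llmcache.semantic import SemanticCache

# Placeholder settings; only ttl is the new argument introduced by this diff.
cache = SemanticCache(
    redis_url="redis://localhost:6379",
    threshold=0.9,
    ttl=300,  # seconds; when left unset, expire() is never called on entries
)

cache.store("What is the capital of France?", "Paris")

# Every hit refreshes the entry's TTL, so frequently asked prompts
# stay cached while unused ones age out of Redis.
cache.check("What is the capital of France?")

# The TTL can also be adjusted after construction.
cache.set_ttl(3600)
```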