diff --git a/Makefile b/Makefile index 98787ac0..6266d0d0 100644 --- a/Makefile +++ b/Makefile @@ -77,17 +77,17 @@ servedocs: # help: test - Run all tests .PHONY: test test: - @python -m pytest + @python -m pytest --log-level=CRITICAL # help: test-verbose - Run all tests verbosely .PHONY: test-verbose test-verbose: - @python -m pytest -vv -s + @python -m pytest -vv -s --log-level=CRITICAL # help: test-cov - Run all tests with coverage .PHONY: test-cov test-cov: - @python -m pytest -vv --cov=./redisvl + @python -m pytest -vv --cov=./redisvl --log-level=CRITICAL # help: cov - generate html coverage report .PHONY: cov diff --git a/conftest.py b/conftest.py index 36d4e79e..622ae345 100644 --- a/conftest.py +++ b/conftest.py @@ -2,25 +2,20 @@ import pytest import asyncio -from redisvl.utils.connection import RedisConnection +from redisvl.redis.connection import RedisConnectionFactory -REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379") - -aredis = RedisConnection.get_async_redis_connection(REDIS_URL) -redis = RedisConnection.get_redis_connection(REDIS_URL) @pytest.fixture() def redis_url(): - return REDIS_URL + return os.getenv("REDIS_URL", "redis://localhost:6379") @pytest.fixture -def async_client(): - return aredis +def async_client(redis_url): + return RedisConnectionFactory.get_async_redis_connection(redis_url) @pytest.fixture -def client(): - return redis - +def client(redis_url): + return RedisConnectionFactory.get_redis_connection(redis_url) @pytest.fixture def openai_key(): @@ -38,6 +33,68 @@ def gcp_location(): def gcp_project_id(): return os.getenv("GCP_PROJECT_ID") + +@pytest.fixture +def sample_data(): + return [ + { + "user": "john", + "age": 18, + "job": "engineer", + "credit_score": "high", + "location": "-122.4194,37.7749", + "user_embedding": [0.1, 0.1, 0.5] + }, + { + "user": "mary", + "age": 14, + "job": "doctor", + "credit_score": "low", + "location": "-122.4194,37.7749", + "user_embedding": [0.1, 0.1, 0.5] + }, + { + "user": "nancy", + "age": 94, + "job": "doctor", + "credit_score": "high", + "location": "-122.4194,37.7749", + "user_embedding": [0.7, 0.1, 0.5] + }, + { + "user": "tyler", + "age": 100, + "job": "engineer", + "credit_score": "high", + "location": "-110.0839,37.3861", + "user_embedding": [0.1, 0.4, 0.5] + }, + { + "user": "tim", + "age": 12, + "job": "dermatologist", + "credit_score": "high", + "location": "-110.0839,37.3861", + "user_embedding": [0.4, 0.4, 0.5] + }, + { + "user": "taimur", + "age": 15, + "job": "CEO", + "credit_score": "low", + "location": "-110.0839,37.3861", + "user_embedding": [0.6, 0.1, 0.5] + }, + { + "user": "joe", + "age": 35, + "job": "dentist", + "credit_score": "medium", + "location": "-110.0839,37.3861", + "user_embedding": [0.9, 0.9, 0.1] + }, +] + @pytest.fixture(scope="session") def event_loop(): try: @@ -48,7 +105,7 @@ def event_loop(): loop.close() @pytest.fixture -def clear_db(): - redis.flushall() +def clear_db(client): + client.flushall() yield - redis.flushall() + client.flushall() \ No newline at end of file diff --git a/docs/_static/gallery.yaml b/docs/_static/gallery.yaml index 1206a088..855dbdbd 100644 --- a/docs/_static/gallery.yaml +++ b/docs/_static/gallery.yaml @@ -1,6 +1,6 @@ - title: Arxiv Paper Search - website: https://docsearch.redisventures.com + website: https://docsearch.redisvl.com img-bottom: ../_static/gallery-images/arxiv-search.png - title: Real-Time Embeddings with Redis and Bytewax website: https://github.com/awmatheson/real-time-embeddings diff --git a/docs/_static/js/sidebar.js b/docs/_static/js/sidebar.js index 33c004ab..5f04b2dc 
100644 --- a/docs/_static/js/sidebar.js +++ b/docs/_static/js/sidebar.js @@ -2,7 +2,7 @@ const toc = [ { header: "Overview", toc: [ { title: "RedisVL", path: "/index.html" }, { title: "Install", path: "/overview/installation.html" }, - { title: "CLI", path: "/user_guide/cli.html" }, + { title: "CLI", path: "/overview/cli.html" }, ]}, { header: "User Guides", toc: [ { title: "Getting Started", path: "/user_guide/getting_started_01.html" }, @@ -10,11 +10,10 @@ const toc = [ { title: "JSON vs Hash Storage", path: "/user_guide/hash_vs_json_05.html" }, { title: "Vectorizers", path: "/user_guide/vectorizers_04.html" }, { title: "Semantic Caching", path: "/user_guide/llmcache_03.html" }, - ]}, { header: "API", toc: [ - { title: "Schema", path: "/api/indexschema.html"}, - { title: "Index", path: "/api/searchindex.html" }, + { title: "Schema", path: "/api/schema.html"}, + { title: "Search Index", path: "/api/searchindex.html" }, { title: "Query", path: "/api/query.html" }, { title: "Filter", path: "/api/filter.html" }, { title: "Vectorizers", path: "/api/vectorizer.html" }, diff --git a/docs/api/index.md b/docs/api/index.md index 29d6844b..613e5542 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -11,7 +11,7 @@ myst: :caption: RedisVL :maxdepth: 2 -indexschema +schema searchindex query filter diff --git a/docs/api/indexschema.rst b/docs/api/schema.rst similarity index 100% rename from docs/api/indexschema.rst rename to docs/api/schema.rst diff --git a/docs/api/searchindex.rst b/docs/api/searchindex.rst index b01235da..11a5f930 100644 --- a/docs/api/searchindex.rst +++ b/docs/api/searchindex.rst @@ -1,6 +1,17 @@ -*********** -SearchIndex -*********** +******************** +Search Index Classes +******************** + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Class + - Description + * - `SearchIndex <#searchindex_api>`_ + - Primary class to write, read, and search across data structures in Redis. + * - `AsyncSearchIndex <#asyncsearchindex_api>`_ + - Async version of the SearchIndex to write, read, and search across data structures in Redis. SearchIndex =========== @@ -9,36 +20,19 @@ SearchIndex .. currentmodule:: redisvl.index -.. autosummary:: - - SearchIndex.from_yaml - SearchIndex.from_dict - SearchIndex.client - SearchIndex.name - SearchIndex.prefix - SearchIndex.key_separator - SearchIndex.storage_type - SearchIndex.connect - SearchIndex.disconnect - SearchIndex.set_client - SearchIndex.create - SearchIndex.acreate - SearchIndex.exists - SearchIndex.aexists - SearchIndex.load - SearchIndex.aload - SearchIndex.search - SearchIndex.asearch - SearchIndex.query - SearchIndex.aquery - SearchIndex.query_batch - SearchIndex.aquery_batch - SearchIndex.delete - SearchIndex.adelete - SearchIndex.info - SearchIndex.ainfo - .. autoclass:: SearchIndex :show-inheritance: :inherited-members: :members: + +AsyncSearchIndex +================ + +.. _asyncsearchindex_api: + +.. currentmodule:: redisvl.index + +.. autoclass:: AsyncSearchIndex + :show-inheritance: + :inherited-members: + :members: diff --git a/docs/examples/openai_qna.ipynb b/docs/examples/openai_qna.ipynb index da2ac347..9607947c 100644 --- a/docs/examples/openai_qna.ipynb +++ b/docs/examples/openai_qna.ipynb @@ -13,7 +13,7 @@ "2. Create embeddings for each article\n", "3. Create a RedisVL index and store the embeddings with metadata\n", "4. Construct a simple QnA system using the index and GPT-3\n", - "5. Improve the QnA system with LLMCache\n", + "5. 
Improve the QnA system with LLM caching\n", "\n", "\n", "The image below shows the architecture of the system we will create in this notebook.\n", @@ -27,10 +27,10 @@ "source": [ "## Setup\n", "\n", - "In order to run this example, you will need to have a Redis instance with RediSearch running locally. You can do this by running the following command in your terminal:\n", + "In order to run this example, you will need to have a Redis Stack running locally (or spin up for free on [Redis Cloud](https://redis.com/try-free)). You can do this by running the following command in your terminal:\n", "\n", "```bash\n", - "docker run --name redis-vecdb -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest\n", + "docker run --name redis -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest\n", "```\n", "\n", "This will also provide the RedisInsight GUI at http://localhost:8001\n", @@ -40,41 +40,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: pandas in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (2.1.4)\n", - "Requirement already satisfied: wget in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (3.2)\n", - "Requirement already satisfied: tenacity in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (8.2.2)\n", - "Requirement already satisfied: tiktoken in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (0.5.2)\n", - "Requirement already satisfied: openai==0.28.1 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (0.28.1)\n", - "Requirement already satisfied: requests>=2.20 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from openai==0.28.1) (2.31.0)\n", - "Requirement already satisfied: tqdm in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from openai==0.28.1) (4.66.1)\n", - "Requirement already satisfied: aiohttp in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from openai==0.28.1) (3.9.1)\n", - "Requirement already satisfied: numpy<2,>=1.22.4 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from pandas) (1.26.2)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from pandas) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from pandas) (2023.3.post1)\n", - "Requirement already satisfied: tzdata>=2022.1 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from pandas) (2023.3)\n", - "Requirement already satisfied: regex>=2022.1.18 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from tiktoken) (2023.10.3)\n", - "Requirement already satisfied: six>=1.5 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from requests>=2.20->openai==0.28.1) (3.3.2)\n", - "Requirement already satisfied: idna<4,>=2.5 in 
/Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from requests>=2.20->openai==0.28.1) (3.6)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from requests>=2.20->openai==0.28.1) (2.1.0)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from requests>=2.20->openai==0.28.1) (2023.11.17)\n", - "Requirement already satisfied: attrs>=17.3.0 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from aiohttp->openai==0.28.1) (23.1.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from aiohttp->openai==0.28.1) (6.0.4)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from aiohttp->openai==0.28.1) (1.9.4)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from aiohttp->openai==0.28.1) (1.4.1)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from aiohttp->openai==0.28.1) (1.3.1)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0 in /Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages (from aiohttp->openai==0.28.1) (4.0.3)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], + "outputs": [], "source": [ "# first we need to install a few things\n", "\n", @@ -83,20 +51,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'wikipedia_articles_2000.csv'" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import wget\n", "import pandas as pd\n", @@ -108,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -194,7 +151,7 @@ "4 German Empire The German Empire (\"Deutsches Reich\" or \"Deuts... 
" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -224,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -280,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -295,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -394,7 +351,7 @@ "4 0 " ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -415,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -509,7 +466,7 @@ " Alanis Morissette\n", " Title: Alanis Morissette;\\nTwin people from Ca...\n", " 1\n", - " b'\\xa4\\xc15\\xbcuw\\xe7\\xbc\\xea\\xe4\\x16\\xbb\\x87\\...\n", + " b'Ii4\\xbc\\x8e>\\xe0\\xbc\\x18]\\x07\\xbb%\\xa0\\x92\\x...\n", " \n", " \n", " 2689\n", @@ -518,7 +475,7 @@ " Brontosaurus\n", " Title: Brontosaurus;\\nBrontosaurus is a genus...\n", " 0\n", - " b'3\\xf0\\xda\\xbcY\\xc0\\xb4:\\x1cN\\x81\\xbc\\xe9\\xcc...\n", + " b'\\xad\\xa5\\xdb\\xbc\\xa5\\xa5\\xba:\\xb4\"\\x81\\xbc\\x...\n", " \n", " \n", " 2690\n", @@ -599,8 +556,8 @@ "3 b'\\xa4\\xba\\xf5\\xbcS\\xf3\\x02\\xbc\\xa1\\x15O\\xbc\\x... \n", "4 b'0(\\xfa\\xbb\\x81\\xd2\\xd9;\\xaf\\x92\\x9a;\\xd3FL\\x... \n", "... ... \n", - "2688 b'\\xa4\\xc15\\xbcuw\\xe7\\xbc\\xea\\xe4\\x16\\xbb\\x87\\... \n", - "2689 b'3\\xf0\\xda\\xbcY\\xc0\\xb4:\\x1cN\\x81\\xbc\\xe9\\xcc... \n", + "2688 b'Ii4\\xbc\\x8e>\\xe0\\xbc\\x18]\\x07\\xbb%\\xa0\\x92\\x... \n", + "2689 b'\\xad\\xa5\\xdb\\xbc\\xa5\\xa5\\xba:\\xb4\"\\x81\\xbc\\x... \n", "2690 b'\\x97\\x82\\xb9\\xbbL\\x90d\\xbc\\xb7G\\x9c\\xba\\x94g... \n", "2691 b'\\xe4\\xa3\\x1c:\\x83g\\x90<\\x99=s;*[E\\xbb\\x10 \"\\... \n", "2692 b'T,-\\xbbS\\xe5\\x87;\\x1c\\x0f\\x9d:\\xc4\\xd4\\xcd:\\... \n", @@ -608,7 +565,7 @@ "[2693 rows x 6 columns]" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -644,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -681,33 +638,35 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import redis.asyncio as redis\n", "\n", - "from redisvl.index import SearchIndex\n", + "from redisvl.index import AsyncSearchIndex\n", + "from redisvl.schema import IndexSchema\n", + "\n", "\n", "client = redis.Redis.from_url(\"redis://localhost:6379\")\n", + "schema = IndexSchema.from_yaml(\"wiki_schema.yaml\")\n", "\n", - "index = SearchIndex.from_yaml(\"wiki_schema.yaml\")\n", - "index.set_client(client)\n", + "index = AsyncSearchIndex(schema, client)\n", "\n", - "await index.acreate()" + "await index.create()" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m21:51:38\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m21:51:38\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. wiki\n" + "\u001b[32m16:00:26\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m16:00:26\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. 
wikipedia\n" ] } ], @@ -724,11 +683,11 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "await index.aload(chunked_data.to_dict(orient=\"records\"))" + "keys = await index.load(chunked_data.to_dict(orient=\"records\"))" ] }, { @@ -746,7 +705,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -757,7 +716,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -779,9 +738,9 @@ " '''\n", " return retrieval_prompt\n", "\n", - "async def retrieve_context(index: SearchIndex, query: str):\n", + "async def retrieve_context(index: AsyncSearchIndex, query: str):\n", " # Embed the query\n", - " query_embedding = oaip.embed(query)\n", + " query_embedding = await oaip.aembed(query)\n", "\n", " # Get the top result from the index\n", " vector_query = VectorQuery(\n", @@ -791,22 +750,22 @@ " num_results=1\n", " )\n", "\n", - " results = await index.aquery(vector_query)\n", + " results = await index.query(vector_query)\n", " content = \"\"\n", " if len(results) > 1:\n", " content = results[0][\"content\"]\n", " return content\n", "\n", - "async def answer_question(index: SearchIndex, query: str):\n", + "async def answer_question(index: AsyncSearchIndex, query: str):\n", " # Retrieve the context\n", " content = await retrieve_context(index, query)\n", "\n", " prompt = make_prompt(query, content)\n", " retrieval = await openai.ChatCompletion.acreate(\n", " model=CHAT_MODEL,\n", - " messages=[{'role':\"user\",\n", - " 'content': prompt}],\n", - " max_tokens=500)\n", + " messages=[{'role':\"user\", 'content': prompt}],\n", + " max_tokens=50\n", + " )\n", "\n", " # Response provided by GPT-3.5\n", " return retrieval['choices'][0]['message']['content']" @@ -814,20 +773,18 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['A Brontosaurus, also known as Apatosaurus, is a dinosaur that lived during the',\n", - " 'Late Jurassic Period, around 150 million years ago. It was a herbivorous',\n", - " 'dinosaur with a long neck, small head, and a large, thick body. Brontosaurus is',\n", - " 'one of the largest land animals known to have existed, reaching lengths of over',\n", - " '70 feet and weighing up to 30 tons.']" + "['A Brontosaurus, also known as Apatosaurus, is a type of large, long-necked',\n", + " 'dinosaur that lived during the Late Jurassic Period, about 150 million years',\n", + " 'ago. They were herbivores and belonged to the saurop']" ] }, - "execution_count": 15, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -841,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -850,7 +807,7 @@ "\"I don't know.\"" ] }, - "execution_count": 16, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -863,26 +820,19 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['Alanis Morissette is a Canadian-American singer-songwriter. She was',\n", - " 'born on June 1, 1974, in Ottawa, Canada. Morissette began her career',\n", - " 'as a child actress and later gained fame as a pop-rock artist in the',\n", - " '1990s. 
Her breakthrough album, \"Jagged Little Pill,\" was released in',\n", - " '1995 and became a global hit, earning numerous awards and accolades.',\n", - " 'Morissette is known for her introspective and honest lyrics,',\n", - " 'addressing themes of love, heartbreak, and self-discovery. She has',\n", - " 'released several successful albums and has continued to evolve her',\n", - " 'musical style over the years. In addition to her music career,',\n", - " 'Morissette is also an actress and has made appearances in film and',\n", - " 'television.']" + "['Alanis Morissette is a Canadian-American singer-songwriter and',\n", + " 'actress. She gained international fame with her third studio album,',\n", + " '\"Jagged Little Pill,\" released in 1995. The album went on to become a',\n", + " 'massive success, selling over']" ] }, - "execution_count": 17, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -898,12 +848,14 @@ "source": [ "## Improve the QnA System with LLM caching\n", "\n", - "The QnA system we built above is pretty good, but it can be improved. We can use the `SemanticCache` to improve the system. The ``SemanticCache`` will store the results of previous queries and return them if the query is similar enough to a previous query. This will reduce the number of queries we need to send to the OpenAI API and increase the overall QPS of the system assuming we expect similar queries to be asked." + "The QnA system we built above is pretty good, but we can use the `SemanticCache` to improve the throughput and stability. The ``SemanticCache`` will store the results of previous queries and return them if the query is similar enough to a previous query. This will reduce the number of round trip queries we need to send to the OpenAI API.\n", + "\n", + "> Note this technique will work assuming we expect a similar profile of queries to be asked." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -914,15 +866,14 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ - "async def answer_question(index: SearchIndex, query: str):\n", + "async def answer_question(index: AsyncSearchIndex, query: str):\n", "\n", " # check the cache\n", - " result = cache.check(prompt=query)\n", - " if result:\n", + " if result := cache.check(prompt=query):\n", " return result[0]['response']\n", "\n", " # Retrieve the context\n", @@ -931,9 +882,9 @@ " prompt = make_prompt(query, content)\n", " retrieval = await openai.ChatCompletion.acreate(\n", " model=CHAT_MODEL,\n", - " messages=[{'role':\"user\",\n", - " 'content': prompt}],\n", - " max_tokens=500)\n", + " messages=[{'role':\"user\", 'content': prompt}],\n", + " max_tokens=500\n", + " )\n", "\n", " # Response provided by GPT-3.5\n", " answer = retrieval['choices'][0]['message']['content']\n", @@ -945,40 +896,43 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Time taken: 10.225465774536133\n", + "Time taken: 6.253775119781494\n", "\n" ] }, { "data": { "text/plain": [ - "['Alanis Morissette is a Canadian-American singer, songwriter, and actress. She',\n", - " 'rose to fame in the 1990s with her breakthrough album \"Jagged Little Pill,\"',\n", - " 'which became one of the best-selling albums of all time. 
Born on June 1, 1974,',\n", - " 'in Ottawa, Ontario, Morissette began her career as a teen pop star in Canada',\n", - " 'before transitioning to alternative rock. Throughout her career, Morissette has',\n", - " 'released several successful albums and has won numerous awards, including',\n", - " 'multiple Grammy Awards. Her music often explores themes of female empowerment,',\n", - " 'personal introspection, and social commentary. Some of her notable songs include',\n", - " '\"Ironic,\" \"You Oughta Know,\" and \"Hand in My Pocket.\" In addition to her music',\n", - " 'career, Morissette has also acted in various films and television shows. She is',\n", - " 'known for her roles in movies such as \"Dogma\" and \"Jay and Silent Bob Strike',\n", - " 'Back.\" Morissette has been transparent about her personal struggles, including',\n", - " 'her experiences with eating disorders, depression, and postpartum depression.',\n", - " 'She has used her platform to advocate for mental health awareness and has been',\n", - " 'involved in various charitable causes. Overall, Alanis Morissette has had a',\n", - " 'successful and influential career in the music industry while also making an',\n", - " 'impact beyond music.']" + "['Alanis Morissette is a Canadian singer, songwriter, and actress. She was born on',\n", + " 'June 1, 1974, in Ottawa, Ontario, Canada. Morissette began her career in the',\n", + " 'music industry as a child, releasing her first album \"Alanis\" in 1991. However,',\n", + " 'it was her third studio album, \"Jagged Little Pill,\" released in 1995, that',\n", + " 'brought her international fame and critical acclaim. The album sold over 33',\n", + " 'million copies worldwide and produced hit singles such as \"You Oughta Know,\"',\n", + " '\"Ironic,\" and \"Hand in My Pocket.\" Throughout her career, Morissette has',\n", + " 'continued to release successful albums and has received numerous awards,',\n", + " 'including Grammy Awards, Juno Awards, and Billboard Music Awards. Her music',\n", + " 'often explores themes of love, relationships, self-discovery, and spirituality.',\n", + " 'Some of her other notable albums include \"Supposed Former Infatuation Junkie,\"',\n", + " '\"Under Rug Swept,\" and \"Flavors of Entanglement.\" In addition to her music',\n", + " 'career, Alanis Morissette has also ventured into acting. She has appeared in',\n", + " 'films such as \"Dogma\" and \"Radio Free Albemuth,\" as well as on television shows',\n", + " 'like \"Weeds\" and \"Sex and the City.\" Offstage, Morissette has been open about',\n", + " 'her struggles with mental health and has become an advocate for mental wellness.',\n", + " 'She has also expressed her views on feminism and spirituality in her music and',\n", + " 'interviews. 
Overall, Alanis Morissette has had a successful and influential',\n", + " 'career in the music industry, with her powerful and emotional songs resonating',\n", + " 'with audiences around the world.']" ] }, - "execution_count": 20, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -995,7 +949,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1043,7 +997,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1091,6 +1045,16 @@ "print(f\"Time taken with the cache: {time.time() - start}\\n\")\n", "textwrap.wrap(answer, width=80)" ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# Cleanup\n", + "await index.delete()" + ] } ], "metadata": { diff --git a/docs/index.md b/docs/index.md index 36970c4c..cb15750a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,17 +15,19 @@ RedisVL provides a powerful, dedicated Python client library for using Redis as :grid-columns: 1 2 2 3 - header: "{fab}`bootstrap;pst-color-primary` Index Management" - content: "Manipulate Redis search indices in Python or from CLI." -- header: "{fas}`bolt;pst-color-primary` Vector Similarity Search" - content: "Perform powerful vector similarity search with filtering support." + content: "Design search schema and indices with ease from YAML, with Python, or from the CLI." +- header: "{fas}`bolt;pst-color-primary` Advanced Vector Search" + content: "Perform powerful vector search queries with complex filtering support." - header: "{fas}`circle-half-stroke;pst-color-primary` Embedding Creation" content: "Use OpenAI or any of the other supported vectorizers to create embeddings." + link: "user_guide/vectorizers_04" - header: "{fas}`palette;pst-color-primary` CLI" - content: "Interact with RedisVL using a Command line interface (CLI) for ease of use." + content: "Interact with RedisVL using a Command Line Interface (CLI) for ease of use." - header: "{fab}`python;pst-color-primary` Semantic Caching" - content: "Use RedisVL to cache LLM results, increasing QPS and decreasing cost." + content: "Extend RedisVL to cache LLM results, increasing QPS and decreasing system cost." + link: "user_guide/llmcache_03" - header: "{fas}`lightbulb;pst-color-primary` Example Gallery" - content: "Explore our gallery of examples to get started." + content: "Explore the gallery of examples to get started." link: "examples/index" ``` @@ -40,16 +42,19 @@ pip install redisvl Then make sure to have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/): ```bash -docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest +docker run -d --name redis -p 6379:6379 -p 8001:8001 redis/redis-stack:latest ``` This will also spin up the [Redis Insight GUI](https://redis.com/redis-enterprise/redis-insight/) at `http://localhost:8001`. 
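Before going further, it can help to verify that the freshly started container is reachable and that the search module is loaded. The following is a minimal sketch using redis-py (which redisvl builds on); the URL assumes the default local setup from the command above:

```python
import redis

# Connect to the local Redis Stack container started above
client = redis.Redis.from_url("redis://localhost:6379")
assert client.ping()  # raises if the server is unreachable

# Redis Stack bundles the "search" module that redisvl relies on
print([m[b"name"] for m in client.module_list()])
```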
+> Read more about `redisvl` installation [here](https://redisvl.com/overview/installation.html) + ## Table of Contents ```{toctree} :maxdepth: 2 +Overview User Guides Example Gallery API diff --git a/docs/user_guide/cli.ipynb b/docs/overview/cli.ipynb similarity index 99% rename from docs/user_guide/cli.ipynb rename to docs/overview/cli.ipynb index 3a8245ec..299e6e1e 100644 --- a/docs/user_guide/cli.ipynb +++ b/docs/overview/cli.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Redis Vector Library CLI (``rvl``)\n", + "# ``rvl`` ~ The RedisVL CLI\n", "\n", "RedisVL is a Python library with a dedicated CLI to help load and create vector search indices within Redis.\n", "\n", diff --git a/docs/overview/index.md b/docs/overview/index.md new file mode 100644 index 00000000..a4d290cd --- /dev/null +++ b/docs/overview/index.md @@ -0,0 +1,17 @@ +--- +myst: + html_meta: + "description lang=en": | + User Guides for RedisVL +--- + +# Overview + + +```{toctree} +:maxdepth: 2 + +installation +cli +``` + diff --git a/docs/overview/installation.md b/docs/overview/installation.md index 9ecd41fc..5416c23d 100644 --- a/docs/overview/installation.md +++ b/docs/overview/installation.md @@ -9,6 +9,8 @@ myst: There are a few ways to install RedisVL. The easiest way is to use pip. +## Install RedisVL with Pip + Install `redisvl` into your Python (>=3.8) environment using `pip`: ```bash diff --git a/docs/user_guide/getting_started_01.ipynb b/docs/user_guide/getting_started_01.ipynb index 37464af8..9d08a236 100644 --- a/docs/user_guide/getting_started_01.ipynb +++ b/docs/user_guide/getting_started_01.ipynb @@ -215,7 +215,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -249,7 +249,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -304,8 +304,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m20:11:48\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m20:11:48\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_simple\n" + "\u001b[32m16:13:33\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m16:13:33\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_simple\n" ] } ], @@ -365,7 +365,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['user_simple_docs:71cb417a3675404889a8a22255f482d0', 'user_simple_docs:3eda3f6f640144a086149ad36d2e8419', 'user_simple_docs:aa9195acd07f41a485477eb3cb333bb8']\n" + "['user_simple_docs:297be8ec3c6444a4b73c10e77daadb4a', 'user_simple_docs:ac0cc4c7ee4d4cd18e9002dbaf1b5cbc', 'user_simple_docs:6c746e3f02d94d9087e0d207cfed5701']\n" ] } ], @@ -379,7 +379,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - ">By default, `load` will create a unique Redis \"key\" as a combination of the index key `prefix` and a UUID. You can also customize the key by providing direct keys or pointing to a specified key_field on load." + ">By default, `load` will create a unique Redis \"key\" as a combination of the index key `prefix` and a UUID. You can also customize the key by providing direct keys or pointing to a specified `key_field` on load." 
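As a concrete illustration of the second option, here is a hedged sketch reusing the `data` records from this guide, where each record's unique `user` value serves as the key suffix:

```python
# Keys become "<prefix><key_separator><user>" instead of "<prefix><key_separator><uuid>"
keys = index.load(data, key_field="user")
# e.g. ['user_simple_docs:john', 'user_simple_docs:mary', ...]
```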
] }, { @@ -400,7 +400,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Fetching data for user 71cb417a3675404889a8a22255f482d0\n" + "Fetching data for user 297be8ec3c6444a4b73c10e77daadb4a\n" ] }, { @@ -442,7 +442,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['user_simple_docs:2747d2e5355d4b0fbb994a9b37518bcc']\n" + "['user_simple_docs:714e5ec6d4a946c082fe006d311e8d49']\n" ] } ], @@ -550,14 +550,12 @@ } ], "source": [ - "index = SearchIndex.from_dict(\n", - " schema,\n", - " redis_url=\"redis://localhost:6379\",\n", - " use_async=True\n", - ")\n", + "from redisvl.index import AsyncSearchIndex\n", + "\n", + "index = AsyncSearchIndex.from_dict(schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# execute the vector query async\n", - "results = await index.aquery(query)\n", + "results = await index.query(query)\n", "result_print(results)" ] }, @@ -569,16 +567,6 @@ "In some scenarios, it makes sense to update the index schema. With Redis and `redisvl`, this is easy because Redis can keep the underlying data in place while you change or make updates to the index configuration." ] }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# First we will clean up the existing index yet keep docs in place\n", - "await index.adelete(drop=False)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -590,7 +578,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -609,7 +597,7 @@ " 'datatype': 'float32'}}]}" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -621,7 +609,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -647,23 +635,31 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16:13:34 redisvl.index INFO Index already exists, overwriting.\n" + ] + } + ], "source": [ - "# Run the index update\n", - "await index.acreate()" + "# Run the index update but keep underlying data in place\n", + "await index.create(overwrite=True, drop=False)" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>age</th><th>job</th><th>credit_score</th></tr><tr><td>0</td><td>john</td><td>1</td><td>engineer</td><td>high</td></tr><tr><td>0</td><td>mary</td><td>2</td><td>doctor</td><td>low</td></tr><tr><td>0.0566299557686</td><td>tyler</td><td>9</td><td>engineer</td><td>high</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>age</th><th>job</th><th>credit_score</th></tr><tr><td>0</td><td>mary</td><td>2</td><td>doctor</td><td>low</td></tr><tr><td>0</td><td>john</td><td>1</td><td>engineer</td><td>high</td></tr><tr><td>0.0566299557686</td><td>tyler</td><td>9</td><td>engineer</td><td>high</td></tr></table>
" ], "text/plain": [ "" @@ -675,7 +671,7 @@ ], "source": [ "# Execute the vector query async\n", - "results = await index.aquery(query)\n", + "results = await index.query(query)\n", "result_print(results)" ] }, @@ -689,7 +685,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -717,7 +713,7 @@ "│ offsets_per_term_avg │ 0 │\n", "│ records_per_doc_avg │ 5 │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 0.624 │\n", + "│ total_indexing_time │ 0.138 │\n", "│ total_inverted_index_blocks │ 11 │\n", "│ vector_index_sz_mb │ 0.0201416 │\n", "╰─────────────────────────────┴─────────────╯\n" @@ -737,12 +733,12 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# clean up the index\n", - "await index.adelete()" + "await index.delete()" ] } ], diff --git a/docs/user_guide/hash_vs_json_05.ipynb b/docs/user_guide/hash_vs_json_05.ipynb index c486d4a9..b8eb3d26 100644 --- a/docs/user_guide/hash_vs_json_05.ipynb +++ b/docs/user_guide/hash_vs_json_05.ipynb @@ -34,7 +34,7 @@ "# import necessary modules\n", "import pickle\n", "\n", - "from redisvl.utils.utils import buffer_to_array\n", + "from redisvl.redis.utils import buffer_to_array\n", "from jupyterutils import result_print, table_print\n", "from redisvl.index import SearchIndex\n", "\n", diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index d4c48867..74133c42 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -101,8 +101,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m20:12:24\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m20:12:24\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m16:13:58\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m16:13:58\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. 
user_queries\n" ] } ], @@ -1111,10 +1111,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'id': 'user_queries_docs:9f8ae1d270e642d89e41b5f512e35cc7', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:45ab38080206444f994d59ee11d13a9c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:5b4b0b33e88447108eabd3b0f54a1fb2', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:7bf2ecb23e314a3f98245f2c07418f64', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'user_queries_docs:6ae49de28548476ea4896f6cdb35f617', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:14af692b215a402580991c1ea464df9d', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:435f00b6002c4ce4b08303bebacfda76', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:22bcd56418c744009052c3391c9b6a78', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], @@ -1163,6 +1163,16 @@ "# Using the str() method, you can see what Redis Query this will emit.\n", "str(v)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cleanup\n", + "index.delete()" + ] } ], "metadata": { diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 458093e8..26b85e9c 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -9,7 +9,7 @@ myst: ```{toctree} -:maxdepth: 5 +:maxdepth: 2 getting_started_01 hybrid_queries_02 diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb index 288c7b0b..a33100d9 100644 --- a/docs/user_guide/llmcache_03.ipynb +++ b/docs/user_guide/llmcache_03.ipynb @@ -6,7 +6,7 @@ "source": [ "# Semantic Caching for LLMs\n", "\n", - "RedisVL provides an ``SemanticCache`` interface to turn Redis into a semantic cache to store responses to previously asked questions. This reduces the number of requests and tokens sent to the Large Language Models (LLM) service, decreasing costs and enhancing application throughput (by reducing the time taken to generate responses).\n", + "RedisVL provides a ``SemanticCache`` interface to utilize Redis' built-in caching capabilities AND vector search in order to store responses from previously-answered questions. 
This reduces the number of requests and tokens sent to the Large Language Models (LLM) service, decreasing costs and enhancing application throughput (by reducing the time taken to generate responses).\n", "\n", "This notebook will go over how to use Redis as a Semantic Cache for your applications" ] }, @@ -27,6 +27,9 @@ "import os\n", "import openai\n", "import getpass\n", + "import time\n", + "\n", + "import numpy as np\n", "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"False\"\n", "\n", @@ -72,18 +75,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/tyler.hutcherson/RedisVentures/redisvl/.venv/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", - " return self.fget.__get__(instance, owner)()\n" - ] - } - ], + "outputs": [], "source": [ "from redisvl.llmcache import SemanticCache\n", "\n", @@ -198,6 +192,13 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will check the cache again with the same question and with a semantically similar question:" + ] + }, { "cell_type": "code", "execution_count": 8, "metadata": {}, @@ -241,6 +242,18 @@ "llmcache.check(prompt=question)[0]['response']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Customize the Distance Threshold\n", + "\n", + "For most use cases, the right semantic similarity threshold is not a fixed quantity. Depending on the choice of embedding model,\n", + "the properties of the input query, and even the business use case -- the threshold might need to change. \n", + "\n", + "Fortunately, you can seamlessly adjust the threshold at any point like below:" + ] + }, { "cell_type": "code", "execution_count": 10, @@ -302,9 +315,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Simple performance testing\n", + "## Utilize TTL\n", "\n", - "Next, we will measure the speedup obtained by using ``SemanticCache``. We will use the ``time`` module to measure the time taken to generate responses with and without ``SemanticCache``." + "Redis supports optional TTL policies to expire individual keys at points in time in the future.\n", + "This allows you to focus on your data flow and business logic without bothering with complex cleanup tasks.\n", + "\n", + "A TTL policy set on the `SemanticCache` allows you to temporarily hold onto cache entries. Below, we will set the TTL policy to 5 seconds."
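Under the hood this is ordinary Redis key expiry applied to each cache entry. A rough sketch of the equivalent raw commands, given any redis-py `client`; the `llmcache:<entry-id>` key shape mirrors the hashed cache keys shown later in this notebook and is illustrative only:

```python
# What a 5 second TTL amounts to for a stored cache entry
client.expire("llmcache:<entry-id>", 5)  # entry is evicted ~5s after being written
client.ttl("llmcache:<entry-id>")        # seconds remaining; -2 once the key is gone
```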
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "import time\n", "\n", + "llmcache.set_ttl(5) # 5 seconds" ] }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "llmcache.store(\"This is a TTL test\", \"This is a TTL test response\")\n", "\n", + "time.sleep(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n" + ] + } + ], + "source": [ + "# confirm that the cache has cleared by now on its own\n", + "result = llmcache.check(\"This is a TTL test\")\n", "\n", + "print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Reset the TTL to null (long lived data)\n", + "llmcache.set_ttl()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple Performance Testing\n", + "\n", + "Next, we will measure the speedup obtained by using ``SemanticCache``. We will use the ``time`` module to measure the time taken to generate responses with and without ``SemanticCache``." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def answer_question(question: str) -> str:\n", " \"\"\"Helper function to answer a simple question using OpenAI with a wrapper\n", " check for the answer in the semantic cache first.\n", @@ -336,71 +408,72 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Without caching, a call to openAI to answer this simple question took 0.7188658714294434 seconds.\n" + "Without caching, a call to openAI to answer this simple question took 0.984370231628418 seconds.\n" ] - } - ], - "source": [ - "start = time.time()\n", - "# asking a question -- openai response time\n", - "answer = answer_question(\"What is the capital of France?\")\n", - "end = time.time()\n", - "\n", - "print(f\"Without caching, a call to openAI to answer this simple question took {end-start} seconds.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ + }, { "data": { "text/plain": [ - "'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545'" + "'llmcache:67e0f6e28fe2a61c0022fd42bf734bb8ffe49d3e375fd69d692574295a20fc1a'" ] }, - "execution_count": 15, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "llmcache.store(prompt=\"What is the capital of France?\", response=\"Paris\")" + "start = time.time()\n", + "# asking a question -- openai response time\n", + "question = \"What was the name of the first US President?\"\n", + "answer = answer_question(question)\n", + "end = time.time()\n", + "\n", + "print(f\"Without caching, a call to openAI to answer this simple question took {end-start} seconds.\")\n", + "\n", + "# add the entry to our LLM cache\n", + "llmcache.store(prompt=question, response=\"George Washington\")" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Time Taken with cache enabled: 0.4411182403564453\n", - "Percentage of time saved: 38.64%\n" + "Avg time taken with LLM cache enabled: 0.5094501972198486\n", + "Percentage of time saved: 48.25%\n" ] } ], "source": [ - "cached_start = time.time()\n", - "cached_answer = answer_question(\"What 
is the capital of France?\")\n", - "cached_end = time.time()\n", - "print(f\"Time Taken with cache enabled: {cached_end - cached_start}\")\n", - "print(f\"Percentage of time saved: {round(((end - start) - (cached_end - cached_start)) / (end - start) * 100, 2)}%\")" + "# Calculate the avg latency for caching over LLM usage\n", + "times = []\n", + "\n", + "for _ in range(10):\n", + " cached_start = time.time()\n", + " cached_answer = answer_question(question)\n", + " cached_end = time.time()\n", + " times.append(cached_end-cached_start)\n", + "\n", + "avg_time_with_cache = np.mean(times)\n", + "print(f\"Avg time taken with LLM cache enabled: {avg_time_with_cache}\")\n", + "print(f\"Percentage of time saved: {round(((end - start) - avg_time_with_cache) / (end - start) * 100, 2)}%\")" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -412,25 +485,25 @@ "╭─────────────────────────────┬─────────────╮\n", "│ Stat Key │ Value │\n", "├─────────────────────────────┼─────────────┤\n", - "│ num_docs │ 1 │\n", - "│ num_terms │ 7 │\n", - "│ max_doc_id │ 2 │\n", - "│ num_records │ 16 │\n", + "│ num_docs │ 0 │\n", + "│ num_terms │ 19 │\n", + "│ max_doc_id │ 5 │\n", + "│ num_records │ 36 │\n", "│ percent_indexed │ 1 │\n", "│ hash_indexing_failures │ 0 │\n", - "│ number_of_uses │ 9 │\n", - "│ bytes_per_record_avg │ 5.25 │\n", - "│ doc_table_size_mb │ 0.000134468 │\n", - "│ inverted_sz_mb │ 8.01086e-05 │\n", - "│ key_table_size_mb │ 2.76566e-05 │\n", + "│ number_of_uses │ 40 │\n", + "│ bytes_per_record_avg │ 5.27778 │\n", + "│ doc_table_size_mb │ 0 │\n", + "│ inverted_sz_mb │ 0.000181198 │\n", + "│ key_table_size_mb │ 0 │\n", "│ offset_bits_per_record_avg │ 8 │\n", - "│ offset_vectors_sz_mb │ 1.33514e-05 │\n", - "│ offsets_per_term_avg │ 0.875 │\n", - "│ records_per_doc_avg │ 16 │\n", + "│ offset_vectors_sz_mb │ 3.33786e-05 │\n", + "│ offsets_per_term_avg │ 0.972222 │\n", + "│ records_per_doc_avg │ inf │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 0.996 │\n", - "│ total_inverted_index_blocks │ 25 │\n", - "│ vector_index_sz_mb │ 3.0161 │\n", + "│ total_indexing_time │ 3.074 │\n", + "│ total_inverted_index_blocks │ 19 │\n", + "│ vector_index_sz_mb │ 0.000389099 │\n", "╰─────────────────────────────┴─────────────╯\n" ] } @@ -442,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py index 9dd633af..c5b350b3 100644 --- a/redisvl/cli/index.py +++ b/redisvl/cli/index.py @@ -6,10 +6,10 @@ from redisvl.cli.utils import add_index_parsing_options, create_redis_url from redisvl.index import SearchIndex +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.redis.utils import convert_bytes, make_dict from redisvl.schema.schema import IndexSchema -from redisvl.utils.connection import RedisConnection from redisvl.utils.log import get_logger -from redisvl.utils.utils import convert_bytes, make_dict logger = get_logger("[RedisVL]") @@ -81,7 +81,7 @@ def listall(self, args: Namespace): rvl index listall """ redis_url = create_redis_url(args) - conn = RedisConnection.get_redis_connection(redis_url) + conn = RedisConnectionFactory.get_redis_connection(redis_url) indices = convert_bytes(conn.execute_command("FT._LIST")) logger.info("Indices:") for i, index in enumerate(indices): @@ -109,7 +109,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex: # connect to redis try: redis_url 
= create_redis_url(args) - conn = RedisConnection.get_redis_connection(url=redis_url) + conn = RedisConnectionFactory.get_redis_connection(url=redis_url) except ValueError: logger.error( "Must set REDIS_URL environment variable or provide host and port" diff --git a/redisvl/index.py b/redisvl/index.py index 54bb1449..5f83e463 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -23,14 +23,13 @@ from redis.commands.search.indexDefinition import IndexDefinition from redisvl.query.query import BaseQuery, CountQuery, FilterQuery +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema, StorageType from redisvl.storage import HashStorage, JsonStorage -from redisvl.utils.connection import RedisConnection -from redisvl.utils.utils import ( - check_async_redis_modules_exist, - check_redis_modules_exist, - convert_bytes, -) +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) def process_results( @@ -85,26 +84,26 @@ def _process(doc: "Document") -> Dict[str, Any]: return [_process(doc) for doc in results.docs] -def check_modules_present(client_variable_name: str): +def check_modules_present(): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): - _redis_conn = getattr(self, client_variable_name) - check_redis_modules_exist(_redis_conn.client) - return func(self, *args, **kwargs) + result = func(self, *args, **kwargs) + RedisConnectionFactory.validate_redis_modules(self._redis_client) + return result return wrapper return decorator -def check_async_modules_present(client_variable_name: str): +def check_async_modules_present(): def decorator(func): @wraps(func) - async def wrapper(self, *args, **kwargs): - _redis_conn = getattr(self, client_variable_name) - await check_async_redis_modules_exist(_redis_conn.client) - return await func(self, *args, **kwargs) + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + RedisConnectionFactory.validate_async_redis_modules(self._redis_client) + return result return wrapper @@ -130,7 +129,7 @@ def check_async_index_exists(): def decorator(func): @wraps(func) async def wrapper(self, *args, **kwargs): - if not await self.aexists(): + if not await self.exists(): raise ValueError( f"Index has not been created. Must be created before calling {func.__name__}" ) @@ -141,34 +140,8 @@ async def wrapper(self, *args, **kwargs): return decorator -class SearchIndex: - """A class for interacting with Redis as a vector database. - - This class is a wrapper around the redis-py client that provides - purpose-built methods for interacting with Redis as a vector database. - - .. code-block:: python - - from redisvl.index import SearchIndex - - # initialize the index object with schema from file - index = SearchIndex.from_yaml("schema.yaml", redis_url="redis://localhost:6379") - - # create the index - index.create(overwrite=True) - - # data is an iterable of dictionaries - index.load(data) - - # delete index and data - index.delete(drop=True) - - # Do the same with an an async connection - index = SearchIndex.from_yaml("schema.yaml", redis_url="redis://localhost:6379", use_async=True) - await index.acreate(overwrite=True) - await index.aload(data) - - """ +class BaseSearchIndex: + """Base search engine class""" _STORAGE_MAP = { StorageType.HASH: HashStorage, @@ -197,18 +170,19 @@ def __init__( args. 
""" # final validation on schema object - if not schema or not isinstance(schema, IndexSchema): - raise ValueError("Must provide a valid schema object") + if not isinstance(schema, IndexSchema): + raise ValueError("Must provide a valid IndexSchema object") + + self.schema = schema # set up redis connection - self._redis_conn = RedisConnection() + self._redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None if redis_client is not None: self.set_client(redis_client) elif redis_url is not None: - self.connect(redis_url, **kwargs, **connection_args) - - self.schema = schema + self.connect(redis_url, **connection_args) + # set up index storage layer self._storage = self._STORAGE_MAP[self.schema.index.storage_type]( prefix=self.schema.index.prefix, key_separator=self.schema.index.key_separator, @@ -239,25 +213,14 @@ def storage_type(self) -> StorageType: @property def client(self) -> Optional[Union[redis.Redis, aredis.Redis]]: """The underlying redis-py client object.""" - return self._redis_conn.client + return self._redis_client @classmethod - def from_existing(cls): - raise DeprecationWarning( - "This method is deprecated since 0.0.5. Use the from_yaml or\ - from_dict constructors with an IndexSchema instead." - ) - - @classmethod - def from_yaml( - cls, schema_path: str, connection_args: Dict[str, Any] = {}, **kwargs - ): + def from_yaml(cls, schema_path: str, **kwargs): """Create a SearchIndex from a YAML schema file. Args: schema_path (str): Path to the YAML schema file. - connection_args (Dict[str, Any], optional): Redis client connection - args. Returns: SearchIndex: A RedisVL SearchIndex object. @@ -265,16 +228,14 @@ def from_yaml( .. code-block:: python from redisvl.index import SearchIndex - index = SearchIndex.from_yaml("schema.yaml", redis_url="redis://localhost:6379") - index.create(overwrite=True) + index = SearchIndex.from_yaml("schemas/schema.yaml") + index.connect(redis_url="redis://localhost:6379") """ schema = IndexSchema.from_yaml(schema_path) - return cls(schema=schema, connection_args=connection_args, **kwargs) + return cls(schema=schema, **kwargs) @classmethod - def from_dict( - cls, schema_dict: Dict[str, Any], connection_args: Dict[str, Any] = {}, **kwargs - ): + def from_dict(cls, schema_dict: Dict[str, Any], **kwargs): """Create a SearchIndex from a dictionary. Args: @@ -297,23 +258,76 @@ def from_dict( "fields": [ {"name": "doc-id", "type": "tag"} ] - }, redis_url="redis://localhost:6379") - index.create(overwrite=True) + }) + index.connect(redis_url="redis://localhost:6379") """ schema = IndexSchema.from_dict(schema_dict) - return cls(schema=schema, connection_args=connection_args, **kwargs) + return cls(schema=schema, **kwargs) - def connect( - self, redis_url: Optional[str] = None, use_async: bool = False, **kwargs - ): - """Connect to a Redis instance. + def connect(self, redis_url: Optional[str] = None, **kwargs): + """Connect to Redis at a given URL.""" + raise NotImplementedError + + def set_client(self, client: Union[redis.Redis, aredis.Redis]): + """Manually set the Redis client to use with the search index.""" + raise NotImplementedError + + def disconnect(self): + """Reset the Redis connection.""" + self._redis_client = None + return self + + def key(self, id: str) -> str: + """Create a redis key as a combination of an index key prefix (optional) + and specified id. + + The id is typically either a unique identifier, or + derived from some domain-specific metadata combination (like a document + id or chunk id). 
+ + Args: + id (str): The specified unique identifier for a particular + document indexed in Redis. + + Returns: + str: The full Redis key including key prefix and value as a string. + """ + return self._storage._key( + id=id, + prefix=self.schema.index.prefix, + key_separator=self.schema.index.key_separator, + ) + + +class SearchIndex(BaseSearchIndex): + """A class for interacting with Redis as a vector database. + + This class is a wrapper around the redis-py client that provides + purpose-built methods for interacting with Redis as a vector database. + + .. code-block:: python + + from redisvl.index import SearchIndex - This method establishes a connection to a Redis server. If `redis_url` - is provided, it will be used as the connection endpoint. Otherwise, the - method attempts to use the `REDIS_URL` environment variable as the - connection URL. The `use_async` parameter determines whether the - connection should be asynchronous. + # initialize the index object with schema from file + index = SearchIndex.from_yaml("schemas/schema.yaml") + index.connect(redis_url="redis://localhost:6379") + + # create the index + index.create(overwrite=True) + + # data is an iterable of dictionaries + index.load(data) + + # delete index and data + index.delete(drop=True) + + """ + + def connect(self, redis_url: Optional[str] = None, **kwargs): + """Connect to a Redis instance using the provided `redis_url`, falling + back to the `REDIS_URL` environment variable (if available). Note: Additional keyword arguments (`**kwargs`) can be used to provide extra options specific to the Redis connection. @@ -322,8 +336,6 @@ def connect( redis_url (Optional[str], optional): The URL of the Redis server to connect to. If not provided, the method defaults to using the `REDIS_URL` environment variable. - use_async (bool): If `True`, establishes a connection with an async - Redis client. Defaults to `False`. Raises: redis.exceptions.ConnectionError: If the connection to the Redis @@ -333,21 +345,14 @@ def connect( .. code-block:: python - # standard sync Redis connection index.connect(redis_url="redis://localhost:6379") - # async Redis connection - index.connect(redis_url="redis://localhost:6379", use_async=True) """ - self._redis_conn.connect(redis_url, use_async, **kwargs) - return self + client = RedisConnectionFactory.connect(redis_url, use_async=False, **kwargs) + return self.set_client(client) - def disconnect(self): - """Reset the Redis connection.""" - self._redis_conn = RedisConnection() - return self - - def set_client(self, client: Union[redis.Redis, aredis.Redis]): + @check_modules_present() + def set_client(self, client: redis.Redis): """Manually set the Redis client to use with the search index. This method configures the search index to use a specific Redis or @@ -355,7 +360,7 @@ def set_client(self, client: Union[redis.Redis, aredis.Redis]): custom-configured client is preferred instead of creating a new one. Args: - client (Union[redis.Redis, aredis.Redis]): A Redis or Async Redis + client (redis.Redis): A Redis or Async Redis client instance to be used for the connection. Raises: @@ -363,47 +368,44 @@ def set_client(self, client: Union[redis.Redis, aredis.Redis]): .. 
code-block:: python - r = redis.Redis.from_url("redis://localhost:6379") - index.set_client(r) - - # async Redis client - import redis.asyncio as aredis + import redis + from redisvl.index import SearchIndex - r = aredis.Redis.from_url("redis://localhost:6379") - index.set_client(r) + client = redis.Redis.from_url("redis://localhost:6379") + index = SearchIndex.from_yaml("schemas/schema.yaml") + index.set_client(client) """ - self._redis_conn.set_client(client) - return self - - def key(self, id: str) -> str: - """Create a redis key as a combination of an index key prefix (optional) - and specified id. The id is typically either a unique identifier, or - derived from some domain-specific metadata combination (like a document - id or chunk id). + if not isinstance(client, redis.Redis): + raise TypeError("Invalid Redis client instance") - Args: - id (str): The specified unique identifier for a particular - document indexed in Redis. + self._redis_client = client - Returns: - str: The full Redis key including key prefix and value as a string. - """ - return self._storage._key( - id, self.schema.index.prefix, self.schema.index.key_separator - ) + return self - @check_modules_present("_redis_conn") - def create(self, overwrite: bool = False) -> None: - """Create an index in Redis from this SearchIndex object. + def create(self, overwrite: bool = False, drop: bool = False) -> None: + """Create an index in Redis with the given schema and properties. Args: overwrite (bool, optional): Whether to overwrite the index if it already exists. Defaults to False. + drop (bool, optional): Whether to drop all keys associated with the + index in the case of overwriting. Defaults to False. Raises: RuntimeError: If the index already exists and 'overwrite' is False. ValueError: If no fields are defined for the index. + + .. code-block:: python + + # create an index in Redis; only if one does not exist with given name + index.create() + + # overwrite an index in Redis without dropping associated data + index.create(overwrite=True) + + # overwrite an index in Redis; drop associated data (clean slate) + index.create(overwrite=True, drop=True) """ # Check that fields are defined. redis_fields = self.schema.redis_fields @@ -414,35 +416,41 @@ def create(self, overwrite: bool = False) -> None: if self.exists(): if not overwrite: - print("Index already exists, not overwriting.") + logger.info("Index already exists, not overwriting.") return None - print("Index already exists, overwriting.") - self.delete() - - # Create the index with the specified fields and settings. - self._redis_conn.client.ft(self.name).create_index( # type: ignore - fields=redis_fields, - definition=IndexDefinition( - prefix=[self.schema.index.prefix], index_type=self._storage.type - ), - ) + logger.info("Index already exists, overwriting.") + self.delete(drop=drop) + + try: + self._redis_client.ft(self.name).create_index( # type: ignore + fields=redis_fields, + definition=IndexDefinition( + prefix=[self.schema.index.prefix], index_type=self._storage.type + ), + ) + except: + logger.exception("Error while trying to create the index") + raise - @check_modules_present("_redis_conn") @check_index_exists() def delete(self, drop: bool = True): - """Delete the search index. + """Delete the search index while optionally dropping all keys associated + with the index. Args: - drop (bool, optional): Delete the documents in the index. - Defaults to True. + drop (bool, optional): Delete the key / documents pairs in the + index. Defaults to True. 
raises: redis.exceptions.ResponseError: If the index does not exist. """ - # Delete the search index - self._redis_conn.client.ft(self.schema.index.name).dropindex(delete_documents=drop) # type: ignore + try: + self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore + delete_documents=drop + ) + except: + logger.exception("Error while deleting index") - @check_modules_present("_redis_conn") def load( self, data: Iterable[Any], @@ -480,15 +488,19 @@ def load( keys = index.load([{"test": "foo"}, {"test": "bar"}]) """ - return self._storage.write( - self._redis_conn.client, # type: ignore - objects=data, - key_field=key_field, - keys=keys, - ttl=ttl, - preprocess=preprocess, - batch_size=batch_size, - ) + try: + return self._storage.write( + self._redis_client, # type: ignore + objects=data, + key_field=key_field, + keys=keys, + ttl=ttl, + preprocess=preprocess, + batch_size=batch_size, + ) + except: + logger.exception("Error while loading data to Redis") + raise def fetch(self, id: str) -> Dict[str, Any]: """Fetch an object from Redis by id. @@ -504,35 +516,34 @@ def fetch(self, id: str) -> Dict[str, Any]: Returns: Dict[str, Any]: The fetched object. """ - return convert_bytes(self._redis_conn.client.hgetall(self.key(id))) # type: ignore + return convert_bytes(self._redis_client.hgetall(self.key(id))) # type: ignore - @check_modules_present("_redis_conn") @check_index_exists() - def search(self, *args, **kwargs) -> Union["Result", Any]: - """Perform a search on this index. + def search(self, *args, **kwargs) -> "Result": + """Perform a search against the index. Wrapper around redis.search.Search that adds the index name to the search query and passes along the rest of the arguments to the redis-py ft.search() method. Returns: - Union["Result", Any]: Search results. + Result: Raw Redis search results. """ - results = self._redis_conn.client.ft(self.schema.index.name).search( # type: ignore - *args, **kwargs - ) - return results + try: + return self._redis_client.ft(self.schema.index.name).search( # type: ignore + *args, **kwargs + ) + except: + logger.exception("Error while searching") + raise def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Execute a query and process results.""" results = self.search(query.query, query_params=query.params) - # post process the results return process_results( results, query=query, storage_type=self.schema.index.storage_type ) - @check_modules_present("_redis_conn") - @check_index_exists() def query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Execute a query on the index. @@ -552,10 +563,8 @@ def query(self, query: BaseQuery) -> List[Dict[str, Any]]: """ return self._query(query) - @check_modules_present("_redis_conn") - @check_index_exists() def query_batch(self, query: BaseQuery, batch_size: int = 30) -> Generator: - """Execute a query on the index with batching. + """Execute a query on the index while batching results. This method takes a BaseQuery object directly, handles optional paging support, and post-processing of the search results. @@ -594,16 +603,14 @@ def query_batch(self, query: BaseQuery, batch_size: int = 30) -> Generator: # increment the pagination tracker first += batch_size - @check_modules_present("_redis_conn") def listall(self) -> List[str]: """List all search indices in Redis database. Returns: List[str]: The list of indices in the database. 
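+
+        A short usage sketch (assuming the index is already connected):
+
+        .. code-block:: python
+
+            indices = index.listall()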
""" - return convert_bytes(self._redis_conn.client.execute_command("FT._LIST")) # type: ignore + return convert_bytes(self._redis_client.execute_command("FT._LIST")) # type: ignore - @check_modules_present("_redis_conn") def exists(self) -> bool: """Check if the index exists in Redis. @@ -612,7 +619,6 @@ def exists(self) -> bool: """ return self.schema.index.name in self.listall() - @check_modules_present("_redis_conn") @check_index_exists() def info(self) -> Dict[str, Any]: """Get information about the index. @@ -620,20 +626,125 @@ def info(self) -> Dict[str, Any]: Returns: dict: A dictionary containing the information about the index. """ - return convert_bytes( - self._redis_conn.client.ft(self.schema.index.name).info() # type: ignore - ) + try: + return convert_bytes( + self._redis_client.ft(self.schema.index.name).info() # type: ignore + ) + except: + logger.exception( + f"Error while fetching {self.schema.index.name} index info" + ) + raise + + +class AsyncSearchIndex(BaseSearchIndex): + """A class for interacting with Redis as a vector database in async mode. + + This class is a wrapper around the redis-py async client that provides + purpose-built methods for interacting with Redis as a vector database. + + .. code-block:: python + + from redisvl.index import AsyncSearchIndex + + # initialize the index object with schema from file + index = AsyncSearchIndex.from_yaml("schemas/schema.yaml") + index.connect(redis_url="redis://localhost:6379") + + # create the index + await index.create(overwrite=True) + + # data is an iterable of dictionaries + await index.load(data) + + # delete index and data + await index.delete(drop=True) + + """ + + def connect(self, redis_url: Optional[str] = None, **kwargs): + """Connect to a Redis instance using the provided `redis_url`, falling + back to the `REDIS_URL` environment variable (if available). + + Note: Additional keyword arguments (`**kwargs`) can be used to provide + extra options specific to the Redis connection. + + Args: + redis_url (Optional[str], optional): The URL of the Redis server to + connect to. If not provided, the method defaults to using the + `REDIS_URL` environment variable. + + Raises: + redis.exceptions.ConnectionError: If the connection to the Redis + server fails. + ValueError: If the Redis URL is not provided nor accessible + through the `REDIS_URL` environment variable. + + .. code-block:: python + + index.connect(redis_url="redis://localhost:6379") + + """ + client = RedisConnectionFactory.connect(redis_url, use_async=True, **kwargs) + return self.set_client(client) + + @check_async_modules_present() + def set_client(self, client: aredis.Redis): + """Manually set the Redis client to use with the search index. + + This method configures the search index to use a specific + Async Redis client. It is useful for cases where an external, + custom-configured client is preferred instead of creating a new one. + + Args: + client (aredis.Redis): An Async Redis + client instance to be used for the connection. + + Raises: + TypeError: If the provided client is not valid. - @check_async_modules_present("_redis_conn") - async def acreate(self, overwrite: bool = False) -> None: - """Asynchronously create an index in Redis from this SearchIndex object. + .. 
code-block:: python + + import redis.asyncio as aredis + from redisvl.index import AsyncSearchIndex + + # async Redis client and index + client = aredis.Redis.from_url("redis://localhost:6379") + index = AsyncSearchIndex.from_yaml("schemas/schema.yaml") + index.set_client(client) + + """ + if not isinstance(client, aredis.Redis): + raise TypeError("Invalid Redis client instance") + + self._redis_client = client + + return self + + async def create(self, overwrite: bool = False, drop: bool = False) -> None: + """Asynchronously create an index in Redis with the given schema + and properties. Args: overwrite (bool, optional): Whether to overwrite the index if it already exists. Defaults to False. + drop (bool, optional): Whether to drop all keys associated with the + index in the case of overwriting. Defaults to False. Raises: RuntimeError: If the index already exists and 'overwrite' is False. + ValueError: If no fields are defined for the index. + + .. code-block:: python + + # create an index in Redis; only if one does not exist with given name + await index.create() + + # overwrite an index in Redis without dropping associated data + await index.create(overwrite=True) + + # overwrite an index in Redis; drop associated data (clean slate) + await index.create(overwrite=True, drop=True) """ redis_fields = self.schema.redis_fields if not redis_fields: @@ -641,24 +752,26 @@ async def acreate(self, overwrite: bool = False) -> None: if not isinstance(overwrite, bool): raise TypeError("overwrite must be of type bool") - if await self.aexists(): + if await self.exists(): if not overwrite: - print("Index already exists, not overwriting.") + logger.info("Index already exists, not overwriting.") return None - print("Index already exists, overwriting.") - await self.adelete() - - # Create Index with proper IndexType - await self._redis_conn.client.ft(self.schema.index.name).create_index( # type: ignore - fields=redis_fields, - definition=IndexDefinition( - prefix=[self.schema.index.prefix], index_type=self._storage.type - ), - ) + logger.info("Index already exists, overwriting.") + await self.delete(drop) + + try: + await self._redis_client.ft(self.schema.index.name).create_index( # type: ignore + fields=redis_fields, + definition=IndexDefinition( + prefix=[self.schema.index.prefix], index_type=self._storage.type + ), + ) + except: + logger.exception("Error while trying to create the index") + raise - @check_async_modules_present("_redis_conn") @check_async_index_exists() - async def adelete(self, drop: bool = True): + async def delete(self, drop: bool = True): """Delete the search index. Args: @@ -668,11 +781,15 @@ async def adelete(self, drop: bool = True): Raises: redis.exceptions.ResponseError: If the index does not exist. 
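+
+        A short sketch of both modes (assuming the index exists):
+
+        .. code-block:: python
+
+            # drop the index along with all of its data
+            await index.delete(drop=True)
+
+            # drop only the index definition, keeping the data
+            await index.delete(drop=False)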
""" - # Delete the search index - await self._redis_conn.client.ft(self.schema.index.name).dropindex(delete_documents=drop) # type: ignore - - @check_async_modules_present("_redis_conn") - async def aload( + try: + await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore + delete_documents=drop + ) + except: + logger.exception("Error while deleting index") + raise + + async def load( self, data: Iterable[Any], key_field: Optional[str] = None, @@ -710,17 +827,21 @@ async def aload( keys = await index.aload([{"test": "foo"}, {"test": "bar"}]) """ - return await self._storage.awrite( - self._redis_conn.client, # type: ignore - objects=data, - key_field=key_field, - keys=keys, - ttl=ttl, - preprocess=preprocess, - concurrency=concurrency, - ) - - async def afetch(self, id: str) -> Dict[str, Any]: + try: + return await self._storage.awrite( + self._redis_client, # type: ignore + objects=data, + key_field=key_field, + keys=keys, + ttl=ttl, + preprocess=preprocess, + concurrency=concurrency, + ) + except: + logger.exception("Error while loading data to Redis") + raise + + async def fetch(self, id: str) -> Dict[str, Any]: """Asynchronously etch an object from Redis by id. The id is typically either a unique identifier, or derived from some domain-specific metadata combination (like a document id or chunk id). @@ -732,11 +853,10 @@ async def afetch(self, id: str) -> Dict[str, Any]: Returns: Dict[str, Any]: The fetched object. """ - return convert_bytes(await self._redis_conn.client.hgetall(self.key(id))) # type: ignore + return convert_bytes(await self._redis_client.hgetall(self.key(id))) # type: ignore - @check_async_modules_present("_redis_conn") @check_async_index_exists() - async def asearch(self, *args, **kwargs) -> Union["Result", Any]: + async def search(self, *args, **kwargs) -> "Result": """Perform a search on this index. Wrapper around redis.search.Search that adds the index name @@ -744,24 +864,24 @@ async def asearch(self, *args, **kwargs) -> Union["Result", Any]: to the redis-py ft.search() method. Returns: - Union["Result", Any]: Search results. + Result: Raw Redis search results. """ - results = await self._redis_conn.client.ft(self.schema.index.name).search( # type: ignore - *args, **kwargs - ) - return results - - async def _aquery(self, query: BaseQuery) -> List[Dict[str, Any]]: + try: + return await self._redis_client.ft(self.schema.index.name).search( # type: ignore + *args, **kwargs + ) + except: + logger.exception("Error while searching") + raise + + async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Asynchronously execute a query and process results.""" - results = await self.asearch(query.query, query_params=query.params) - # post process the results + results = await self.search(query.query, query_params=query.params) return process_results( results, query=query, storage_type=self.schema.index.storage_type ) - @check_async_modules_present("_redis_conn") - @check_async_index_exists() - async def aquery(self, query: BaseQuery) -> List[Dict[str, Any]]: + async def query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Asynchronously execute a query on the index. 
This method takes a BaseQuery object directly, runs the search, and @@ -777,11 +897,9 @@ async def aquery(self, query: BaseQuery) -> List[Dict[str, Any]]: results = await aindex.query(query) """ - return await self._aquery(query) + return await self._query(query) - @check_async_modules_present("_redis_conn") - @check_async_index_exists() - async def aquery_batch( + async def query_batch( self, query: BaseQuery, batch_size: int = 30 ) -> AsyncGenerator: """Execute a query on the index with batching. @@ -802,7 +920,7 @@ async def aquery_batch( .. code-block:: python - async for batch in index.aquery_batch(query, batch_size=10): + async for batch in index.query_batch(query, batch_size=10): # process batched results pass """ @@ -815,41 +933,44 @@ async def aquery_batch( first = 0 while True: query.set_paging(first, batch_size) - batch_results = await self._aquery(query) + batch_results = await self._query(query) if not batch_results: break yield batch_results # increment the pagination tracker first += batch_size - @check_async_modules_present("_redis_conn") - async def alistall(self) -> List[str]: + async def listall(self) -> List[str]: """List all search indices in Redis database. Returns: List[str]: The list of indices in the database. """ return convert_bytes( - await self._redis_conn.client.execute_command("FT._LIST") # type: ignore + await self._redis_client.execute_command("FT._LIST") # type: ignore ) - @check_async_modules_present("_redis_conn") - async def aexists(self) -> bool: + async def exists(self) -> bool: """Check if the index exists in Redis. Returns: bool: True if the index exists, False otherwise. """ - return self.schema.index.name in await self.alistall() + return self.schema.index.name in await self.listall() - @check_async_modules_present("_redis_conn") @check_async_index_exists() - async def ainfo(self) -> Dict[str, Any]: + async def info(self) -> Dict[str, Any]: """Get information about the index. Returns: dict: A dictionary containing the information about the index. """ - return convert_bytes( - await self._redis_conn.client.ft(self.schema.index.name).info() # type: ignore - ) + try: + return convert_bytes( + await self._redis_client.ft(self.schema.index.name).info() # type: ignore + ) + except: + logger.exception( + f"Error while fetching {self.schema.index.name} index info" + ) + raise diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index f3da56b2..f2c716e3 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, List, Optional from redis import Redis @@ -6,8 +5,8 @@ from redisvl.index import SearchIndex from redisvl.llmcache.base import BaseLLMCache from redisvl.query import RangeQuery +from redisvl.redis.utils import array_to_buffer from redisvl.schema.schema import IndexSchema -from redisvl.utils.utils import array_to_buffer from redisvl.vectorize.base import BaseVectorizer from redisvl.vectorize.text import HFTextVectorizer @@ -64,29 +63,7 @@ def __init__( """ super().__init__(ttl) - # Check for index_name in kwargs - if "index_name" in kwargs: - name = kwargs.pop("index_name") - warnings.warn( - message="index_name kwarg is deprecated in favor of name.", - category=DeprecationWarning, - stacklevel=2, - ) - - # Check for threshold in kwargs - if "threshold" in kwargs: - distance_threshold = 1 - kwargs.pop("threshold") - warnings.warn( - message="threshold kwarg is deprecated in favor of distance_threshold. 
" - + "Setting distance_threshold to 1 - threshold.", - category=DeprecationWarning, - stacklevel=2, - ) - - if not isinstance(name, str): - raise ValueError("A valid index name must be provided.") - - # use the index name as the key prefix by default + # Use the index name as the key prefix by default if prefix is None: prefix = name @@ -110,8 +87,10 @@ def __init__( ] ) - # build search index and connect + # build search index self._index = SearchIndex(schema=schema) + + # handle redis connection if redis_client: self._index.set_client(redis_client) else: diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 2d3af642..82111d04 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -4,7 +4,7 @@ from redis.commands.search.query import Query from redisvl.query.filter import FilterExpression -from redisvl.utils.utils import array_to_buffer +from redisvl.redis.utils import array_to_buffer class BaseQuery: diff --git a/redisvl/redis/__init__.py b/redisvl/redis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py new file mode 100644 index 00000000..c88efaf0 --- /dev/null +++ b/redisvl/redis/connection.py @@ -0,0 +1,154 @@ +import os +from typing import Optional + +from redis import ConnectionPool, Redis +from redis.asyncio import Redis as AsyncRedis + +from redisvl.redis.constants import REDIS_REQUIRED_MODULES +from redisvl.redis.utils import convert_bytes + + +def get_address_from_env() -> str: + """Get a redis connection from environment variables. + + Returns: + str: Redis URL + """ + if "REDIS_URL" not in os.environ: + raise ValueError("REDIS_URL env var not set") + return os.environ["REDIS_URL"] + + +class RedisConnectionFactory: + """Builds connections to a Redis database, supporting both synchronous and + asynchronous clients. + + This class allows for establishing and handling Redis connections using + either standard Redis or async Redis clients, based on the provided + configuration. + """ + + @classmethod + def connect( + cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs + ) -> None: + """Create a connection to the Redis database based on a URL and some + connection kwargs. + + This method sets up either a synchronous or asynchronous Redis client + based on the provided parameters. + + Args: + redis_url (Optional[str]): The URL of the Redis server to connect + to. If not provided, the environment variable REDIS_URL is used. + use_async (bool): If True, an asynchronous client is created. + Defaults to False. + **kwargs: Additional keyword arguments to be passed to the Redis + client constructor. + + Raises: + ValueError: If redis_url is not provided and REDIS_URL environment + variable is not set. + """ + redis_url = redis_url or get_address_from_env() + connection_func = ( + cls.get_async_redis_connection if use_async else cls.get_redis_connection + ) + return connection_func(redis_url, **kwargs) # type: ignore + + @staticmethod + def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: + """Creates and returns a synchronous Redis client. + + Args: + url (Optional[str]): The URL of the Redis server. If not provided, + the environment variable REDIS_URL is used. + **kwargs: Additional keyword arguments to be passed to the Redis + client constructor. + + Returns: + Redis: A synchronous Redis client instance. + + Raises: + ValueError: If url is not provided and REDIS_URL environment + variable is not set. 
+ """ + if url: + return Redis.from_url(url, **kwargs) + # fallback to env var REDIS_URL + return Redis.from_url(get_address_from_env(), **kwargs) + + @staticmethod + def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedis: + """Creates and returns an asynchronous Redis client. + + Args: + url (Optional[str]): The URL of the Redis server. If not provided, + the environment variable REDIS_URL is used. + **kwargs: Additional keyword arguments to be passed to the async + Redis client constructor. + + Returns: + AsyncRedis: An asynchronous Redis client instance. + + Raises: + ValueError: If url is not provided and REDIS_URL environment + variable is not set. + """ + if url: + return AsyncRedis.from_url(url, **kwargs) + # fallback to env var REDIS_URL + return AsyncRedis.from_url(get_address_from_env(), **kwargs) + + @staticmethod + def validate_redis_modules(client: Redis) -> None: + """Validates if the required Redis modules are installed. + + Args: + client (Redis): Synchronous Redis client. + + Raises: + ValueError: If required Redis modules are not installed. + """ + RedisConnectionFactory._validate_redis_modules( + convert_bytes(client.module_list()) + ) + + @staticmethod + def validate_async_redis_modules(client: AsyncRedis) -> None: + """ + Validates if the required Redis modules are installed. + + Args: + client (AsyncRedis): Asynchronous Redis client. + + Raises: + ValueError: If required Redis modules are not installed. + """ + temp_client = Redis( + connection_pool=ConnectionPool(**client.connection_pool.connection_kwargs) + ) + RedisConnectionFactory.validate_redis_modules(temp_client) + + @staticmethod + def _validate_redis_modules(installed_modules) -> None: + """ + Validates if required Redis modules are installed. + + Args: + installed_modules: List of installed modules. + + Raises: + ValueError: If required Redis modules are not installed. + """ + installed_modules = {module["name"]: module for module in installed_modules} + for required_module in REDIS_REQUIRED_MODULES: + if required_module["name"] in installed_modules: + installed_version = installed_modules[required_module["name"]]["ver"] + if int(installed_version) >= int(required_module["ver"]): # type: ignore + return + + raise ValueError( + f"Required Redis database module {required_module['name']} with version >= {required_module['ver']} not installed. 
" + "Refer to Redis Stack documentation: https://redis.io/docs/stack/" + ) diff --git a/redisvl/redis/constants.py b/redisvl/redis/constants.py new file mode 100644 index 00000000..43ec7394 --- /dev/null +++ b/redisvl/redis/constants.py @@ -0,0 +1,8 @@ +# required modules +REDIS_REQUIRED_MODULES = [ + {"name": "search", "ver": 20600}, + {"name": "searchlight", "ver": 20600}, +] + +# default tag separator +REDIS_TAG_SEPARATOR = "," diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py new file mode 100644 index 00000000..e8108bbb --- /dev/null +++ b/redisvl/redis/utils.py @@ -0,0 +1,39 @@ +from typing import Any, Dict, List + +import numpy as np + + +def make_dict(values: List[Any]) -> Dict[Any, Any]: + """Convert a list of objects into a dictionary""" + i = 0 + di = {} + while i < len(values) - 1: + di[values[i]] = values[i + 1] + i += 2 + return di + + +def convert_bytes(data: Any) -> Any: + """Convert bytes data back to string""" + if isinstance(data, bytes): + try: + return data.decode("utf-8") + except: + return data + if isinstance(data, dict): + return dict(map(convert_bytes, data.items())) + if isinstance(data, list): + return list(map(convert_bytes, data)) + if isinstance(data, tuple): + return map(convert_bytes, data) + return data + + +def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: + """Convert a list of floats into a numpy byte string.""" + return np.array(array).astype(dtype).tobytes() + + +def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: + """Convert bytes into into a list of floats.""" + return np.frombuffer(buffer, dtype=dtype).tolist() diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py index b01c3d56..7dd85bea 100644 --- a/redisvl/schema/fields.py +++ b/redisvl/schema/fields.py @@ -1,3 +1,11 @@ +""" +RedisVL Fields, FieldAttributes, and Enums + +Reference Redis search source documentation as needed: https://redis.io/commands/ft.create/ +Reference Redis vector search documentation as needed: https://redis.io/docs/interact/search-and-query/advanced-concepts/vectors/ +""" + +from enum import Enum from typing import Any, Dict, Optional, Tuple, Type, Union from pydantic.v1 import BaseModel, Field, validator @@ -7,47 +15,95 @@ from redis.commands.search.field import TagField as RedisTagField from redis.commands.search.field import TextField as RedisTextField from redis.commands.search.field import VectorField as RedisVectorField -from typing_extensions import Literal + +### Attribute Enums ### + + +class VectorDistanceMetric(str, Enum): + COSINE = "COSINE" + L2 = "L2" + IP = "IP" + + +class VectorDataType(str, Enum): + FLOAT32 = "FLOAT32" + FLOAT64 = "FLOAT64" + + +class VectorIndexAlgorithm(str, Enum): + FLAT = "FLAT" + HNSW = "HNSW" + + +### Field Attributes ### class BaseFieldAttributes(BaseModel): - sortable: Optional[bool] = False + """Base field attributes shared by other lexical fields""" + + sortable: bool = Field(default=False) + """Enable faster result sorting on the field at runtime""" class TextFieldAttributes(BaseFieldAttributes): - weight: Optional[float] = 1 - no_stem: Optional[bool] = False + """Full text field attributes""" + + weight: float = Field(default=1) + """Declares the importance of this field when calculating results""" + no_stem: bool = Field(default=False) + """Disable stemming on the text field during indexing""" + withsuffixtrie: bool = Field(default=False) + """Keep a suffix trie with all terms which match the suffix to optimize certain queries""" phonetic_matcher: 
Optional[str] = None - withsuffixtrie: Optional[bool] = False + """Used to perform phonetic matching during search""" class TagFieldAttributes(BaseFieldAttributes): - separator: Optional[str] = "," - case_sensitive: Optional[bool] = False + """Tag field attributes""" + + separator: str = Field(default=",") + """Indicates how the text in the original attribute is split into individual tags""" + case_sensitive: bool = Field(default=False) + """Treat text as case sensitive or not. By default, tag characters are converted to lowercase""" + withsuffixtrie: bool = Field(default=False) + """Keep a suffix trie with all terms which match the suffix to optimize certain queries""" class NumericFieldAttributes(BaseFieldAttributes): + """Numeric field attributes""" + pass class GeoFieldAttributes(BaseFieldAttributes): + """Numeric field attributes""" + pass class BaseVectorFieldAttributes(BaseModel): - dims: int = Field(...) - algorithm: object = Field(...) - datatype: str = Field(default="FLOAT32") - distance_metric: str = Field(default="COSINE") + """Base vector field attributes shared by both FLAT and HNSW fields""" + + dims: int + """Dimensionality of the vector embeddings field""" + algorithm: VectorIndexAlgorithm + """The indexing algorithm for the field: HNSW or FLAT""" + datatype: VectorDataType = Field(default=VectorDataType.FLOAT32) + """The float datatype for the vector embeddings""" + distance_metric: VectorDistanceMetric = Field(default=VectorDistanceMetric.COSINE) + """The distance metric used to measure query relevance""" initial_cap: Optional[int] = None + """Initial vector capacity in the index affecting memory allocation size of the index""" @validator("algorithm", "datatype", "distance_metric", pre=True) @classmethod def uppercase_strings(cls, v): + """Validate that provided values are cast to uppercase""" return v.upper() @property def field_data(self) -> Dict[str, Any]: + """Select attributes required by the Redis API""" field_data = { "TYPE": self.datatype, "DIM": self.dims, @@ -58,17 +114,32 @@ def field_data(self) -> Dict[str, Any]: return field_data +class FlatVectorFieldAttributes(BaseVectorFieldAttributes): + """FLAT vector field attributes""" + + algorithm: VectorIndexAlgorithm = Field( + default=VectorIndexAlgorithm.FLAT, const=True + ) + """The indexing algorithm for the vector field""" + block_size: Optional[int] = None + """Block size to hold amount of vectors in a contiguous array. 
This is useful when the index is dynamic with respect to addition and deletion"""


 class HNSWVectorFieldAttributes(BaseVectorFieldAttributes):
-    algorithm: Literal["HNSW"] = "HNSW"
+    """HNSW vector field attributes"""
+
+    algorithm: VectorIndexAlgorithm = Field(
+        default=VectorIndexAlgorithm.HNSW, const=True
+    )
+    """The indexing algorithm for the vector field"""
     m: int = Field(default=16)
+    """Number of max outgoing edges for each graph node in each layer"""
     ef_construction: int = Field(default=200)
+    """Number of max allowed potential outgoing edges candidates for each node in the graph during build time"""
     ef_runtime: int = Field(default=10)
+    """Number of maximum top candidates to hold during KNN search"""
     epsilon: float = Field(default=0.01)
-
-
-class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
-    algorithm: Literal["FLAT"] = "FLAT"
-    block_size: Optional[int] = None
+    """Relative factor that sets the boundaries in which a range query may search for candidates"""


 ### Field Classes ###
@@ -114,7 +185,7 @@ def as_redis_field(self) -> RedisField:


 class TagField(BaseField):
-    """Tag field for simple boolean filtering"""
+    """Tag field for simple boolean-style filtering"""

     type: str = Field(default="tag", const=True)
     attrs: TagFieldAttributes = Field(default_factory=TagFieldAttributes)
diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py
index e69b36f2..60d72bae 100644
--- a/redisvl/schema/schema.py
+++ b/redisvl/schema/schema.py
@@ -1,5 +1,4 @@
 import re
-import warnings
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List
@@ -9,7 +8,9 @@
 from redis.commands.search.field import Field as RedisField

 from redisvl.schema.fields import BaseField, FieldFactory
+from redisvl.utils.log import get_logger

+logger = get_logger(__name__)

 SCHEMA_VERSION = "0.1.0"
@@ -147,8 +148,8 @@ def _make_field(storage_type, **field_inputs) -> BaseField:
             field.path = field.path if field.path else f"$.{field.name}"
         else:
             if field.path is not None:
-                warnings.warn(
-                    message=f"Path attribute for field '{field.name}' will be ignored for HASH storage type."
+                logger.warning(
+                    f"Path attribute for field '{field.name}' will be ignored for HASH storage type."
                 )
                 field.path = None
         return field
@@ -351,7 +352,7 @@ def remove_field(self, field_name: str):
             field_name (str): The name of the field to be removed.
         """
         if field_name not in self.fields:
-            warnings.warn(message=f"Field '{field_name}' does not exist in the schema")
+            logger.warning(f"Field '{field_name}' does not exist in the schema")
             return
         del self.fields[field_name]
@@ -400,7 +401,7 @@ def generate_fields(
                 if strict:
                     raise
                 else:
-                    warnings.warn(
-                        message=f"Error inferring field type for {field_name}: {e}"
+                    logger.warning(
+                        f"Error inferring field type for {field_name}: {e}"
                     )
     return fields
diff --git a/redisvl/storage.py b/redisvl/storage.py
index c2b8d9bd..e0e9cec6 100644
--- a/redisvl/storage.py
+++ b/redisvl/storage.py
@@ -7,7 +7,7 @@
 from redis.asyncio import Redis as AsyncRedis
 from redis.commands.search.indexDefinition import IndexType

-from redisvl.utils.utils import convert_bytes
+from redisvl.redis.utils import convert_bytes


 class BaseStorage(BaseModel):
@@ -18,19 +18,24 @@
     validation, and basic read/write operations (both sync and async).
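+
+    A rough usage sketch (assuming `client` is an existing synchronous Redis
+    client; HashStorage is one concrete subclass):
+
+    .. code-block:: python
+
+        from redisvl.storage import HashStorage
+
+        storage = HashStorage(prefix="doc", key_separator=":")
+        keys = storage.write(
+            client,
+            objects=[{"id": "1", "text": "foo"}],
+            key_field="id",
+        )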
""" - type: IndexType # Type of index used in storage - prefix: str # Prefix for Redis keys - key_separator: str # Separator between prefix and key value - default_batch_size: int = 200 # Default size for batch operations - default_write_concurrency: int = 20 # Default concurrency for async ops + type: IndexType + """Type of index used in storage""" + prefix: str + """Prefix for Redis keys""" + key_separator: str + """Separator between prefix and key value""" + default_batch_size: int = 200 + """Default size for batch operations""" + default_write_concurrency: int = 20 + """Default concurrency for async ops""" @staticmethod - def _key(key_value: str, prefix: str, key_separator: str) -> str: + def _key(id: str, prefix: str, key_separator: str) -> str: """Create a Redis key using a combination of a prefix, separator, and - the key value. + the identifider. Args: - key_value (str): The unique identifier for the Redis entry. + id (str): The unique identifier for the Redis entry. prefix (str): A prefix to append before the key value. key_separator (str): A separator to insert between prefix and key value. @@ -39,9 +44,9 @@ def _key(key_value: str, prefix: str, key_separator: str) -> str: str: The fully formed Redis key. """ if not prefix: - return key_value + return id else: - return f"{prefix}{key_separator}{key_value}" + return f"{prefix}{key_separator}{id}" def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> str: """Construct a Redis key for a given object, optionally using a @@ -384,6 +389,7 @@ class HashStorage(BaseStorage): """ type: IndexType = IndexType.HASH + """Hash data type for the index""" def _validate(self, obj: Dict[str, Any]): """Validate that the given object is a dictionary, suitable for storage @@ -456,6 +462,7 @@ class JsonStorage(BaseStorage): """ type: IndexType = IndexType.JSON + """JSON data type for the index""" def _validate(self, obj: Dict[str, Any]): """Validate that the given object is a dictionary, suitable for JSON diff --git a/redisvl/utils/connection.py b/redisvl/utils/connection.py deleted file mode 100644 index aa8a1f6f..00000000 --- a/redisvl/utils/connection.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -from typing import Optional, Union - -from redis import Redis -from redis.asyncio import Redis as ARedis - -# TODO: handle connection errors. - - -def get_address_from_env(): - """Get a redis connection from environment variables. - - Returns: - str: Redis URL - """ - addr = os.getenv("REDIS_URL", None) - if not addr: - raise ValueError("REDIS_URL env var not set") - return addr - - -class RedisConnection: - """Manages connections to a Redis database, supporting both synchronous and - asynchronous clients. - - This class allows for establishing and handling Redis connections using - either standard Redis or async Redis clients, based on the provided - configuration. - """ - - def __init__(self): - self._redis_url = None - self._kwargs = None - self.client: Optional[Union[Redis, ARedis]] = None - - def connect( - self, redis_url: Optional[str] = None, use_async: bool = False, **kwargs - ) -> None: - """Establishes a connection to the Redis database. - - This method sets up either a synchronous or asynchronous Redis client - based on the provided parameters. - - Args: - redis_url (Optional[str]): The URL of the Redis server to connect - to. If not provided, the environment variable REDIS_URL is used. - use_async (bool): If True, an asynchronous client is created. - Defaults to False. 
- **kwargs: Additional keyword arguments to be passed to the Redis - client constructor. - - Raises: - ValueError: If redis_url is not provided and REDIS_URL environment - variable is not set. - """ - self._redis_url = redis_url - self._kwargs = kwargs - if use_async: - self.client = self.get_async_redis_connection( - self._redis_url, **self._kwargs - ) - else: - self.client = self.get_redis_connection(self._redis_url, **self._kwargs) - - def set_client(self, client: Union[Redis, ARedis]) -> None: - """Sets the Redis client instance for the connection. - - This method allows setting a pre-configured Redis client, either - synchronous or asynchronous. - - Args: - client (Union[Redis, ARedis]): The Redis client instance to be set. - - Raises: - TypeError: If the provided client is not a valid Redis client - instance. - """ - if not (isinstance(client, Redis) or isinstance(client, ARedis)): - raise TypeError("Must provide a valid Redis client instance") - self.client = client - - @staticmethod - def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: - """Creates and returns a synchronous Redis client. - - Args: - url (Optional[str]): The URL of the Redis server. If not provided, - the environment variable REDIS_URL is used. - **kwargs: Additional keyword arguments to be passed to the Redis - client constructor. - - Returns: - Redis: A synchronous Redis client instance. - - Raises: - ValueError: If url is not provided and REDIS_URL environment - variable is not set. - """ - if url: - client = Redis.from_url(url, **kwargs) - else: - try: - client = Redis.from_url(get_address_from_env()) - except ValueError: - raise ValueError("No Redis URL provided and REDIS_URL env var not set") - return client - - @staticmethod - def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> ARedis: - """Creates and returns an asynchronous Redis client. - - Args: - url (Optional[str]): The URL of the Redis server. If not provided, - the environment variable REDIS_URL is used. - **kwargs: Additional keyword arguments to be passed to the async - Redis client constructor. - - Returns: - ARedis: An asynchronous Redis client instance. - - Raises: - ValueError: If url is not provided and REDIS_URL environment - variable is not set. 
- """ - if url: - client = ARedis.from_url(url, **kwargs) - else: - try: - client = ARedis.from_url(get_address_from_env()) - except ValueError: - raise ValueError("No Redis URL provided and REDIS_URL env var not set") - return client diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py deleted file mode 100644 index 1bcc203a..00000000 --- a/redisvl/utils/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Any, List - -import numpy as np - -# required modules -REDIS_REQUIRED_MODULES = [ - {"name": "search", "ver": 20400}, - {"name": "searchlight", "ver": 20400}, -] - - -def make_dict(values: List[Any]): - # TODO make this a real function - i = 0 - di = {} - while i < len(values) - 1: - di[values[i]] = values[i + 1] - i += 2 - return di - - -def convert_bytes(data: Any) -> Any: - if isinstance(data, bytes): - try: - return data.decode("utf-8") - except: - return data - if isinstance(data, dict): - return dict(map(convert_bytes, data.items())) - if isinstance(data, list): - return list(map(convert_bytes, data)) - if isinstance(data, tuple): - return map(convert_bytes, data) - return data - - -def check_redis_modules_exist(client) -> None: - """Check if the correct Redis modules are installed.""" - installed_modules = client.module_list() - installed_modules = { - module[b"name"].decode("utf-8"): module for module in installed_modules - } - for module in REDIS_REQUIRED_MODULES: - if module["name"] in installed_modules and int( - installed_modules[module["name"]][b"ver"] - ) >= int( - module["ver"] - ): # type: ignore[call-overload] - return - # otherwise raise error - error_message = ( - "You must add the RediSearch (>= 2.4) module from Redis Stack. " - "Please refer to Redis Stack docs: https://redis.io/docs/stack/" - ) - raise ValueError(error_message) - - -async def check_async_redis_modules_exist(client) -> None: - """Check if the correct Redis modules are installed.""" - installed_modules = await client.module_list() - installed_modules = { - module[b"name"].decode("utf-8"): module for module in installed_modules - } - for module in REDIS_REQUIRED_MODULES: - if module["name"] in installed_modules and int( - installed_modules[module["name"]][b"ver"] - ) >= int( - module["ver"] - ): # type: ignore[call-overload] - return - # otherwise raise error - error_message = ( - "You must add the RediSearch (>= 2.4) module from Redis Stack. 
" - "Please refer to Redis Stack docs: https://redis.io/docs/stack/" - ) - raise ValueError(error_message) - - -def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: - """Convert a list of floats into a numpy byte string.""" - return np.array(array).astype(dtype).tobytes() - - -def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: - """Convert bytes into into a list of floats.""" - return np.frombuffer(buffer, dtype=dtype).tolist() diff --git a/redisvl/vectorize/base.py b/redisvl/vectorize/base.py index f4359639..46ba955d 100644 --- a/redisvl/vectorize/base.py +++ b/redisvl/vectorize/base.py @@ -2,7 +2,7 @@ from pydantic.v1 import BaseModel, validator -from redisvl.utils.utils import array_to_buffer +from redisvl.redis.utils import array_to_buffer class BaseVectorizer(BaseModel): diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py new file mode 100644 index 00000000..e9c69c14 --- /dev/null +++ b/tests/integration/test_connection.py @@ -0,0 +1,54 @@ +import os + +import pytest +from redis import Redis +from redis.asyncio import Redis as AsyncRedis +from redis.exceptions import ConnectionError + +from redisvl.redis.connection import RedisConnectionFactory, get_address_from_env + + +def test_get_address_from_env(redis_url): + assert get_address_from_env() == redis_url + + +def test_sync_redis_connection(redis_url): + client = RedisConnectionFactory.connect(redis_url) + assert client is not None + assert isinstance(client, Redis) + # Perform a simple operation + assert client.ping() + + +@pytest.mark.asyncio +async def test_async_redis_connection(redis_url): + client = RedisConnectionFactory.connect(redis_url, use_async=True) + assert client is not None + assert isinstance(client, AsyncRedis) + # Perform a simple operation + assert await client.ping() + + +def test_missing_env_var(): + redis_url = os.getenv("REDIS_URL") + if redis_url: + del os.environ["REDIS_URL"] + with pytest.raises(ValueError): + RedisConnectionFactory.connect() + os.environ["REDIS_URL"] = redis_url + + +def test_invalid_url_format(): + with pytest.raises(ValueError): + RedisConnectionFactory.connect(redis_url="invalid_url_format") + + +def test_unknown_redis(): + bad_client = RedisConnectionFactory.connect(redis_url="redis://fake:1234") + with pytest.raises(ConnectionError): + bad_client.ping() + + +def test_required_modules(client): + RedisConnectionFactory.validate_redis_modules(client) + RedisConnectionFactory.validate_async_redis_modules(client) diff --git a/tests/integration/test_simple.py b/tests/integration/test_flow.py similarity index 67% rename from tests/integration/test_simple.py rename to tests/integration/test_flow.py index 39919f56..76359872 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_flow.py @@ -1,38 +1,9 @@ -from pprint import pprint - import pytest from redisvl.index import SearchIndex from redisvl.query import VectorQuery +from redisvl.redis.utils import array_to_buffer from redisvl.schema import StorageType -from redisvl.utils.utils import array_to_buffer - -data = [ - { - "id": 1, - "user": "john", - "age": 1, - "job": "engineer", - "credit_score": "high", - "user_embedding": [0.1, 0.1, 0.5], - }, - { - "id": 2, - "user": "mary", - "age": 2, - "job": "doctor", - "credit_score": "low", - "user_embedding": [0.1, 0.1, 0.5], - }, - { - "id": 3, - "user": "joe", - "age": 3, - "job": "dentist", - "credit_score": "medium", - "user_embedding": [0.9, 0.9, 0.1], - }, -] fields_spec = [ {"name": 
"credit_score", "type": "tag"}, @@ -51,7 +22,6 @@ }, ] - hash_schema = { "index": { "name": "user_index_hash", @@ -72,28 +42,27 @@ @pytest.mark.parametrize("schema", [hash_schema, json_schema]) -def test_simple(client, schema): +def test_simple(client, schema, sample_data): index = SearchIndex.from_dict(schema) # assign client (only for testing) index.set_client(client) # create the index - index.create(overwrite=True) + index.create(overwrite=True, drop=True) # Prepare and load the data based on storage type def hash_preprocess(item: dict) -> dict: return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} if index.storage_type == StorageType.HASH: - index.load(data, preprocess=hash_preprocess) + index.load(sample_data, preprocess=hash_preprocess) else: - # Load the prepared data into the index - print("DATA", data, flush=True) - index.load(data) + index.load(sample_data) + return_fields = ["user", "age", "job", "credit_score"] query = VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", - return_fields=["user", "age", "job", "credit_score"], + return_fields=return_fields, num_results=3, ) @@ -105,7 +74,6 @@ def hash_preprocess(item: dict) -> dict: # users = list(results.docs) # print(len(users)) users = [doc for doc in results.docs] - pprint(users) assert users[0].user in ["john", "mary"] assert users[1].user in ["john", "mary"] @@ -116,9 +84,8 @@ def hash_preprocess(item: dict) -> dict: assert float(users[1].vector_distance) == 0.0 assert float(users[2].vector_distance) > 0 - print() - for doc in results.docs: - print("Score:", doc.vector_distance) - pprint(doc) + for doc1, doc2 in zip(results.docs, results_2): + for field in return_fields: + assert getattr(doc1, field) == doc2[field] index.delete() diff --git a/tests/integration/test_simple_async.py b/tests/integration/test_flow_async.py similarity index 52% rename from tests/integration/test_simple_async.py rename to tests/integration/test_flow_async.py index d35b9252..f9debae6 100644 --- a/tests/integration/test_simple_async.py +++ b/tests/integration/test_flow_async.py @@ -1,40 +1,11 @@ import time -from pprint import pprint -import numpy as np import pytest -from redisvl.index import SearchIndex +from redisvl.index import AsyncSearchIndex from redisvl.query import VectorQuery - -data = [ - { - "id": 1, - "user": "john", - "age": 1, - "job": "engineer", - "credit_score": "high", - "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), - }, - { - "id": 2, - "user": "mary", - "age": 2, - "job": "doctor", - "credit_score": "low", - "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), - }, - { - "id": 3, - "user": "joe", - "age": 3, - "job": "dentist", - "credit_score": "medium", - "user_embedding": np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(), - }, -] - -query_vector = np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes() +from redisvl.redis.utils import array_to_buffer +from redisvl.schema import StorageType fields_spec = [ {"name": "credit_score", "type": "tag"}, @@ -55,37 +26,54 @@ hash_schema = { "index": { - "name": "user_index", - "prefix": "users", + "name": "user_index_hash", + "prefix": "users_hash", "storage_type": "hash", }, "fields": fields_spec, } +json_schema = { + "index": { + "name": "user_index_json", + "prefix": "users_json", + "storage_type": "json", + }, + "fields": fields_spec, +} + @pytest.mark.asyncio -async def test_simple(async_client): - index = SearchIndex.from_dict(hash_schema) +@pytest.mark.parametrize("schema", [hash_schema, 
json_schema]) +async def test_simple(async_client, schema, sample_data): + index = AsyncSearchIndex.from_dict(schema) # assign client (only for testing) index.set_client(async_client) # create the index - await index.acreate(overwrite=True) + await index.create(overwrite=True, drop=True) + + # Prepare and load the data based on storage type + async def hash_preprocess(item: dict) -> dict: + return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} - # load data into the index in Redis - await index.aload(data) + if index.storage_type == StorageType.HASH: + await index.load(sample_data, preprocess=hash_preprocess) + else: + await index.load(sample_data) # wait for async index to create time.sleep(1) + return_fields = ["user", "age", "job", "credit_score"] query = VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", - return_fields=["user", "age", "job", "credit_score"], + return_fields=return_fields, num_results=3, ) - results = await index.asearch(query.query, query_params=query.params) - results_2 = await index.aquery(query) + results = await index.search(query.query, query_params=query.params) + results_2 = await index.query(query) assert len(results.docs) == len(results_2) # make sure correct users returned @@ -102,9 +90,8 @@ async def test_simple(async_client): assert float(users[1].vector_distance) == 0.0 assert float(users[2].vector_distance) > 0 - print() - for doc in results.docs: - print("Score:", doc.vector_distance) - pprint(doc) + for doc1, doc2 in zip(results.docs, results_2): + for field in return_fields: + assert getattr(doc1, field) == doc2[field] - await index.adelete() + await index.delete() diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index c0d902ca..fb5b3f74 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -1,73 +1,43 @@ -import numpy as np import pytest from redis.commands.search.result import Result from redisvl.index import SearchIndex from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text +from redisvl.redis.utils import array_to_buffer -data = [ - { - "user": "john", - "age": 18, - "job": "engineer", - "credit_score": "high", - "location": "-122.4194,37.7749", - "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), - }, - { - "user": "derrick", - "age": 14, - "job": "doctor", - "credit_score": "low", - "location": "-122.4194,37.7749", - "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), - }, - { - "user": "nancy", - "age": 94, - "job": "doctor", - "credit_score": "high", - "location": "-122.4194,37.7749", - "user_embedding": np.array([0.7, 0.1, 0.5], dtype=np.float32).tobytes(), - }, - { - "user": "tyler", - "age": 100, - "job": "engineer", - "credit_score": "high", - "location": "-110.0839,37.3861", - "user_embedding": np.array([0.1, 0.4, 0.5], dtype=np.float32).tobytes(), - }, - { - "user": "tim", - "age": 12, - "job": "dermatologist", - "credit_score": "high", - "location": "-110.0839,37.3861", - "user_embedding": np.array([0.4, 0.4, 0.5], dtype=np.float32).tobytes(), - }, - { - "user": "taimur", - "age": 15, - "job": "CEO", - "credit_score": "low", - "location": "-110.0839,37.3861", - "user_embedding": np.array([0.6, 0.1, 0.5], dtype=np.float32).tobytes(), - }, - { - "user": "joe", - "age": 35, - "job": "dentist", - "credit_score": "medium", - "location": "-110.0839,37.3861", - "user_embedding": np.array([0.9, 0.9, 
0.1], dtype=np.float32).tobytes(), - }, -] - - -@pytest.fixture(scope="module") -def index(): +# TODO expand to multiple schema types and sync + async + + +@pytest.fixture +def vector_query(): + return VectorQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], + ) + + +@pytest.fixture +def filter_query(): + return FilterQuery( + return_fields=["user", "credit_score", "age", "job", "location"], + filter_expression=Tag("credit_score") == "high", + ) + + +@pytest.fixture +def range_query(): + return RangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], + distance_threshold=0.2, + ) + + +@pytest.fixture +def index(sample_data): # construct a search index from the schema index = SearchIndex.from_dict( { @@ -101,7 +71,11 @@ def index(): # create the index (no data yet) index.create(overwrite=True) - index.load(data) + # Prepare and load the data + def hash_preprocess(item: dict) -> dict: + return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + + index.load(sample_data, preprocess=hash_preprocess) # run the test yield index @@ -123,7 +97,16 @@ def test_search_and_query(index): assert len(results.docs) == 7 for doc in results.docs: # ensure all return fields present - assert doc.user in ["john", "derrick", "nancy", "tyler", "tim", "taimur", "joe"] + assert doc.user in [ + "john", + "derrick", + "nancy", + "tyler", + "tim", + "taimur", + "joe", + "mary", + ] assert int(doc.age) in [18, 14, 94, 100, 12, 15, 35] assert doc.job in ["engineer", "doctor", "dermatologist", "CEO", "dentist"] assert doc.credit_score in ["high", "low", "medium"] @@ -158,43 +141,16 @@ def test_range_query(index): assert len(results) == 2 -def test_count_query(index): +def test_count_query(index, sample_data): c = CountQuery(FilterExpression("*")) results = index.query(c) - assert results == len(data) + assert results == len(sample_data) c = CountQuery(Tag("credit_score") == "high") results = index.query(c) assert results == 4 -@pytest.fixture -def vector_query(): - return VectorQuery( - vector=[0.1, 0.1, 0.5], - vector_field_name="user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], - ) - - -@pytest.fixture -def filter_query(): - return FilterQuery( - return_fields=["user", "credit_score", "age", "job", "location"], - filter_expression=Tag("credit_score") == "high", - ) - - -@pytest.fixture -def range_query(): - return RangeQuery( - vector=[0.1, 0.1, 0.5], - vector_field_name="user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], - distance_threshold=0.2, - ) - - def search( query, index, @@ -344,14 +300,14 @@ def test_filter_combinations(index, query): search(query, index, n & t & g, 1, age_range=(18, 99), location="-122.4194,37.7749") -def test_query_batch_vector_query(index, vector_query): +def test_query_batch_vector_query(index, vector_query, sample_data): batch_size = 2 all_results = [] for i, batch in enumerate(index.query_batch(vector_query, batch_size), start=1): all_results.extend(batch) assert len(batch) <= batch_size - expected_total_results = len(data) + expected_total_results = len(sample_data) expected_iterations = -(-expected_total_results // batch_size) # Ceiling division assert len(all_results) == expected_total_results assert i == expected_iterations diff --git a/tests/unit/test_async_search_index.py b/tests/unit/test_async_search_index.py new file 
diff --git a/tests/unit/test_async_search_index.py b/tests/unit/test_async_search_index.py
new file mode 100644
index 00000000..e33a64bf
--- /dev/null
+++ b/tests/unit/test_async_search_index.py
@@ -0,0 +1,139 @@
+import pytest
+
+from redisvl.index import AsyncSearchIndex
+from redisvl.redis.utils import convert_bytes
+from redisvl.schema import IndexSchema, StorageType
+
+fields = [{"name": "test", "type": "tag"}]
+
+
+@pytest.fixture
+def index_schema():
+    return IndexSchema.from_dict({"index": {"name": "my_index"}, "fields": fields})
+
+
+@pytest.fixture
+def async_index(index_schema):
+    return AsyncSearchIndex(schema=index_schema)
+
+
+def test_search_index_properties(index_schema, async_index):
+    assert async_index.schema == index_schema
+    # custom settings
+    assert async_index.name == index_schema.index.name == "my_index"
+    assert async_index.client == None
+    # default settings
+    assert async_index.prefix == index_schema.index.prefix == "rvl"
+    assert async_index.key_separator == index_schema.index.key_separator == ":"
+    assert (
+        async_index.storage_type == index_schema.index.storage_type == StorageType.HASH
+    )
+    assert async_index.key("foo").startswith(async_index.prefix)
+
+
+def test_search_index_no_prefix(index_schema):
+    # specify an explicitly empty prefix...
+    index_schema.index.prefix = ""
+    async_index = AsyncSearchIndex(schema=index_schema)
+    assert async_index.prefix == ""
+    assert async_index.key("foo") == "foo"
+
+
+def test_search_index_redis_url(redis_url, index_schema):
+    async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url)
+    assert async_index.client
+
+    async_index.disconnect()
+    assert async_index.client == None
+
+
+def test_search_index_client(async_client, index_schema):
+    async_index = AsyncSearchIndex(schema=index_schema, redis_client=async_client)
+    assert async_index.client == async_client
+
+
+def test_search_index_set_client(async_client, client, async_index):
+    async_index.set_client(async_client)
+    assert async_index.client == async_client
+    # should not be able to set the sync client here
+    with pytest.raises(TypeError):
+        async_index.set_client(client)
+
+    async_index.disconnect()
+    assert async_index.client == None
+
+
+@pytest.mark.asyncio
+async def test_search_index_create(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    assert await async_index.exists()
+    assert async_index.name in convert_bytes(
+        await async_index.client.execute_command("FT._LIST")
+    )
+
+
+@pytest.mark.asyncio
+async def test_search_index_delete(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    await async_index.delete(drop=True)
+    assert not await async_index.exists()
+    assert async_index.name not in convert_bytes(
+        await async_index.client.execute_command("FT._LIST")
+    )
+
+
+@pytest.mark.asyncio
+async def test_search_index_load_and_fetch(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}]
+    await async_index.load(data, key_field="id")
+
+    res = await async_index.fetch("1")
+    assert (
+        res["test"]
+        == convert_bytes(await async_index.client.hget("rvl:1", "test"))
+        == "foo"
+    )
+
+    await async_index.delete(drop=True)
+    assert not await async_index.exists()
+    assert not await async_index.fetch("1")
+
+
+@pytest.mark.asyncio
+async def test_search_index_load_preprocess(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}]
+
+    async def preprocess(record):
+        record["test"] = "bar"
+        return record
+
+    await async_index.load(data, key_field="id", preprocess=preprocess)
+    res = await async_index.fetch("1")
+    assert (
+        res["test"]
+        == convert_bytes(await async_index.client.hget("rvl:1", "test"))
+        == "bar"
+    )
+
+    async def bad_preprocess(record):
+        return 1
+
+    with pytest.raises(TypeError):
+        await async_index.load(data, key_field="id", preprocess=bad_preprocess)
+
+
+@pytest.mark.asyncio
+async def test_no_key_field(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    bad_data = [{"wrong_key": "1", "value": "test"}]
+
+    # catch missing / invalid key_field
+    with pytest.raises(ValueError):
+        await async_index.load(bad_data, key_field="key")
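This new suite reflects the renamed API surface: the a-prefixed coroutines (acreate, aload, aquery, adelete) are gone, and AsyncSearchIndex exposes the same method names as SearchIndex, just awaitable. A condensed sketch of the lifecycle these tests exercise; the schema dict and method calls are taken from the tests above, while the URL and asyncio scaffolding are illustrative:

import asyncio

from redisvl.index import AsyncSearchIndex
from redisvl.schema import IndexSchema

schema = IndexSchema.from_dict(
    {"index": {"name": "my_index"}, "fields": [{"name": "test", "type": "tag"}]}
)

async def main():
    # connect, create, load, fetch, and tear down -- all awaitable now
    index = AsyncSearchIndex(schema=schema, redis_url="redis://localhost:6379")
    await index.create(overwrite=True, drop=True)
    await index.load([{"id": "1", "test": "foo"}], key_field="id")
    print(await index.fetch("1"))
    await index.delete(drop=True)

asyncio.run(main())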
[{"id": "1", "test": "foo"}] + + async def preprocess(record): + record["test"] = "bar" + return record + + await async_index.load(data, key_field="id", preprocess=preprocess) + res = await async_index.fetch("1") + assert ( + res["test"] + == convert_bytes(await async_index.client.hget("rvl:1", "test")) + == "bar" + ) + + async def bad_preprocess(record): + return 1 + + with pytest.raises(TypeError): + await async_index.load(data, key_field="id", preprocess=bad_preprocess) + + +@pytest.mark.asyncio +async def test_no_key_field(async_client, async_index): + async_index.set_client(async_client) + await async_index.create(overwrite=True, drop=True) + bad_data = [{"wrong_key": "1", "value": "test"}] + + # catch missing / invalid key_field + with pytest.raises(ValueError): + await async_index.load(bad_data, key_field="key") diff --git a/tests/unit/test_index.py b/tests/unit/test_index.py deleted file mode 100644 index 7a693752..00000000 --- a/tests/unit/test_index.py +++ /dev/null @@ -1,183 +0,0 @@ -import pytest - -from redisvl.index import SearchIndex -from redisvl.schema import IndexSchema -from redisvl.schema.fields import TagField -from redisvl.utils.utils import convert_bytes - -fields = [{"name": "test", "type": "tag"}] - - -@pytest.fixture -def index_schema(): - return IndexSchema.from_dict({"index": {"name": "my_index"}, "fields": fields}) - - -@pytest.fixture -def index(index_schema): - return SearchIndex(schema=index_schema) - - -def test_search_index_get_key(index): - si = index - key = si.key("foo") - assert key.startswith(si.prefix) - assert "foo" in key - key = si._storage._create_key({"id": "foo"}) - assert key.startswith(si.prefix) - assert "foo" not in key - - -def test_search_index_no_prefix(index_schema): - # specify None as the prefix... 
-    si = index_schema.index.prefix = ""
-    si = SearchIndex(schema=index_schema)
-    key = si.key("foo")
-    assert not si.prefix
-    assert key == "foo"
-
-
-def test_search_index_client(client, index_schema):
-    si = index_schema.index.prefix = ""
-    si = SearchIndex(schema=index_schema)
-
-    si.set_client(client)
-    assert si.client == client == si._redis_conn.client
-
-
-def test_search_index_create(client, index, index_schema):
-    si = index
-    si.set_client(client)
-    si.create(overwrite=True)
-    assert si.exists()
-    assert "my_index" in convert_bytes(si.client.execute_command("FT._LIST"))
-
-    s1_2 = SearchIndex(schema=index_schema).set_client(client)
-    assert s1_2.info()["index_name"] == si.info()["index_name"]
-
-    si.create(overwrite=False)
-    assert si.exists()
-    assert "my_index" in convert_bytes(si.client.execute_command("FT._LIST"))
-
-
-def test_search_index_delete(client, index):
-    si = index
-    si.set_client(client)
-    si.create(overwrite=True)
-    si.delete()
-    assert "my_index" not in convert_bytes(si.client.execute_command("FT._LIST"))
-
-
-def test_search_index_load(client, index):
-    si = index
-    si.set_client(client)
-    si.create(overwrite=True)
-    data = [{"id": "1", "value": "test"}]
-    si.load(data, key_field="id")
-
-    assert convert_bytes(client.hget("rvl:1", "value")) == "test"
-
-
-# def test_search_index_load_preprocess(client, index_schema):
-#     si = SearchIndex("my_index", fields=fields)
-#     si.set_client(client)
-#     si.create(overwrite=True)
-#     data = [{"id": "1", "value": "test"}]
-
-#     def preprocess(record):
-#         record["test"] = "foo"
-#         return record
-
-#     si.load(data, key_field="id", preprocess=preprocess)
-#     assert convert_bytes(client.hget("rvl:1", "test")) == "foo"
-
-#     def bad_preprocess(record):
-#         return 1
-
-#     with pytest.raises(TypeError):
-#         si.load(data, key_field="id", preprocess=bad_preprocess)
-
-
-@pytest.mark.asyncio
-async def test_async_search_index_creation(async_client, index):
-    asi = index
-    asi.set_client(async_client)
-
-    assert asi.client == async_client
-
-
-@pytest.mark.asyncio
-async def test_async_search_index_create(async_client, index):
-    asi = index
-    asi.set_client(async_client)
-    await asi.acreate(overwrite=True)
-
-    indices = await asi.client.execute_command("FT._LIST")
-    assert "my_index" in convert_bytes(indices)
-
-
-@pytest.mark.asyncio
-async def test_async_search_index_delete(async_client, index):
-    asi = index
-    asi.set_client(async_client)
-    await asi.acreate(overwrite=True)
-    await asi.adelete()
-
-    indices = await asi.client.execute_command("FT._LIST")
-    assert "my_index" not in convert_bytes(indices)
-
-
-# @pytest.mark.asyncio
-# async def test_async_search_index_load(async_client):
-#     asi = SearchIndex("my_index", fields=fields)
-#     asi.set_client(async_client)
-#     await asi.acreate(overwrite=True)
-#     data = [{"id": "1", "value": "test"}]
-#     await asi.aload(data, key_field="id")
-#     result = await async_client.hget("rvl:1", "value")
-#     assert convert_bytes(result) == "test"
-#     await asi.adelete()
-
-
-# # --- Index Errors ----
-
-
-# def test_search_index_delete_nonexistent(client):
-#     si = SearchIndex("my_index", fields=fields)
-#     si.set_client(client)
-#     with pytest.raises(ValueError):
-#         si.delete()
-
-
-# @pytest.mark.asyncio
-# async def test_async_search_index_delete_nonexistent(async_client):
-#     asi = SearchIndex("my_index", fields=fields)
-#     asi.set_client(async_client)
-#     with pytest.raises(ValueError):
-#         await asi.adelete()
-
-
-# # --- Data Errors ----
-
-
-# def test_no_key_field(client):
-#     si = SearchIndex("my_index", fields=fields)
-#     si.set_client(client)
-#     si.create(overwrite=True)
-#     bad_data = [{"wrong_key": "1", "value": "test"}]
-
-#     # TODO make a better error
-#     with pytest.raises(ValueError):
-#         si.load(bad_data, key_field="key")
-
-
-# @pytest.mark.asyncio
-# async def test_async_search_index_load_bad_data(async_client):
-#     asi = SearchIndex("my_index", fields=fields)
-#     asi.set_client(async_client)
-#     await asi.acreate(overwrite=True)
-
-#     # dictionary not list of dictionaries
-#     bad_data = {"wrong_key": "1", "value": "test"}
-#     with pytest.raises(TypeError):
-#         await asi.aload(bad_data, key_field="id")
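The deleted monolith above and the sync suite below pin down the same key-construction contract: a key is the index prefix joined to the record id by key_separator, and an empty prefix yields the bare id. An illustrative reduction of what the assertions demand (not the library's actual implementation):

def make_key(prefix: str, separator: str, id_: str) -> str:
    # mirrors the behavior asserted around index.key("foo") in both suites
    return f"{prefix}{separator}{id_}" if prefix else id_

assert make_key("rvl", ":", "1") == "rvl:1"  # default prefix and separator
assert make_key("", ":", "foo") == "foo"     # explicit empty prefix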
diff --git a/tests/unit/test_search_index.py b/tests/unit/test_search_index.py
new file mode 100644
index 00000000..01c6788f
--- /dev/null
+++ b/tests/unit/test_search_index.py
@@ -0,0 +1,120 @@
+import pytest
+
+from redisvl.index import SearchIndex
+from redisvl.redis.utils import convert_bytes
+from redisvl.schema import IndexSchema, StorageType
+
+fields = [{"name": "test", "type": "tag"}]
+
+
+@pytest.fixture
+def index_schema():
+    return IndexSchema.from_dict({"index": {"name": "my_index"}, "fields": fields})
+
+
+@pytest.fixture
+def index(index_schema):
+    return SearchIndex(schema=index_schema)
+
+
+def test_search_index_properties(index_schema, index):
+    assert index.schema == index_schema
+    # custom settings
+    assert index.name == index_schema.index.name == "my_index"
+    assert index.client == None
+    # default settings
+    assert index.prefix == index_schema.index.prefix == "rvl"
+    assert index.key_separator == index_schema.index.key_separator == ":"
+    assert index.storage_type == index_schema.index.storage_type == StorageType.HASH
+    assert index.key("foo").startswith(index.prefix)
+
+
+def test_search_index_no_prefix(index_schema):
+    # specify an explicitly empty prefix...
+    index_schema.index.prefix = ""
+    index = SearchIndex(schema=index_schema)
+    assert index.prefix == ""
+    assert index.key("foo") == "foo"
+
+
+def test_search_index_redis_url(redis_url, index_schema):
+    index = SearchIndex(schema=index_schema, redis_url=redis_url)
+    assert index.client
+
+    index.disconnect()
+    assert index.client == None
+
+
+def test_search_index_client(client, index_schema):
+    index = SearchIndex(schema=index_schema, redis_client=client)
+    assert index.client == client
+
+
+def test_search_index_set_client(async_client, client, index):
+    index.set_client(client)
+    assert index.client == client
+    # should not be able to set the async client here
+    with pytest.raises(TypeError):
+        index.set_client(async_client)
+
+    index.disconnect()
+    assert index.client == None
+
+
+def test_search_index_create(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    assert index.exists()
+    assert index.name in convert_bytes(index.client.execute_command("FT._LIST"))
+
+
+def test_search_index_delete(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    index.delete(drop=True)
+    assert not index.exists()
+    assert index.name not in convert_bytes(index.client.execute_command("FT._LIST"))
+
+
+def test_search_index_load_and_fetch(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}]
+    index.load(data, key_field="id")
+
+    res = index.fetch("1")
+    assert res["test"] == convert_bytes(client.hget("rvl:1", "test")) == "foo"
+
+    index.delete(drop=True)
+    assert not index.exists()
+    assert not index.fetch("1")
+
+
+def test_search_index_load_preprocess(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    data = [{"id": "1", "test": "foo"}]
+
+    def preprocess(record):
+        record["test"] = "bar"
+        return record
+
+    index.load(data, key_field="id", preprocess=preprocess)
+    res = index.fetch("1")
+    assert res["test"] == convert_bytes(client.hget("rvl:1", "test")) == "bar"
+
+    def bad_preprocess(record):
+        return 1
+
+    with pytest.raises(TypeError):
+        index.load(data, key_field="id", preprocess=bad_preprocess)
+
+
+def test_no_key_field(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    bad_data = [{"wrong_key": "1", "value": "test"}]
+
+    # catch missing / invalid key_field
+    with pytest.raises(ValueError):
+        index.load(bad_data, key_field="key")
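Both new suites close by pinning down load's error contract: a preprocess callback must hand back a dict-like record (anything else raises TypeError), and every record must contain the declared key_field (otherwise ValueError). Restated as a compact sketch, with the client and index fixtures from the tests above assumed:

import pytest

def test_load_contract(client, index):
    index.set_client(client)
    index.create(overwrite=True, drop=True)

    # a preprocess callback returning a non-dict is rejected
    with pytest.raises(TypeError):
        index.load([{"id": "1"}], key_field="id", preprocess=lambda record: 1)

    # a record missing the declared key_field is rejected
    with pytest.raises(ValueError):
        index.load([{"wrong_key": "1"}], key_field="key")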