From 1b62c12a9957d23567bd55e3e0c7747c21b265dc Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Tue, 25 Jul 2023 00:39:49 -0700 Subject: [PATCH 1/3] Update query interface --- docs/api/filter.rst | 59 +++ docs/api/index.md | 5 +- docs/api/query.rst | 25 + docs/api/{redisvl_api.rst => searchindex.rst} | 7 +- docs/user_guide/getting_started_01.ipynb | 45 +- docs/user_guide/hybrid_queries_02.ipynb | 445 ++++++++++++++++++ docs/user_guide/index.md | 4 + docs/user_guide/llmcache_03.ipynb | 63 ++- redisvl/llmcache/semantic.py | 17 +- redisvl/query.py | 240 ++++++++-- 10 files changed, 803 insertions(+), 107 deletions(-) create mode 100644 docs/api/filter.rst create mode 100644 docs/api/query.rst rename docs/api/{redisvl_api.rst => searchindex.rst} (96%) create mode 100644 docs/user_guide/hybrid_queries_02.ipynb diff --git a/docs/api/filter.rst b/docs/api/filter.rst new file mode 100644 index 00000000..92786a34 --- /dev/null +++ b/docs/api/filter.rst @@ -0,0 +1,59 @@ +****** +Filter +****** + +.. _filter_api: + +TagFilter +========= + + +.. currentmodule:: redisvl.query + +.. autosummary:: + + TagFilter.__init__ + TagFilter.to_string + + +.. autoclass:: TagFilter + :show-inheritance: + :members: + :inherited-members: + + + +TextFilter +========== + + +.. currentmodule:: redisvl.query + +.. autosummary:: + + TextFilter.__init__ + TextFilter.to_string + + +.. autoclass:: TextFilter + :show-inheritance: + :members: + :inherited-members: + + +NumericFilter +============= + + +.. currentmodule:: redisvl.query + +.. autosummary:: + + NumericFilter.__init__ + NumericFilter.to_string + + +.. autoclass:: NumericFilter + :show-inheritance: + :members: + :inherited-members: diff --git a/docs/api/index.md b/docs/api/index.md index aef6f340..3fe601da 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -9,7 +9,10 @@ myst: ```{toctree} :caption: RedisVL +:maxdepth: 2 -redisvl_api +searchindex +query +filter ``` diff --git a/docs/api/query.rst b/docs/api/query.rst new file mode 100644 index 00000000..31943cb3 --- /dev/null +++ b/docs/api/query.rst @@ -0,0 +1,25 @@ + +***** +Query +***** + +VectorQuery +=========== + +.. _query_api: + +.. currentmodule:: redisvl.query + +.. autosummary:: + + VectorQuery.__init__ + VectorQuery.set_filter + VectorQuery.get_filter + VectorQuery.query + VectorQuery.params + + +.. autoclass:: VectorQuery + :show-inheritance: + :members: + :inherited-members: diff --git a/docs/api/redisvl_api.rst b/docs/api/searchindex.rst similarity index 96% rename from docs/api/redisvl_api.rst rename to docs/api/searchindex.rst index 4a459e4b..a5dec643 100644 --- a/docs/api/redisvl_api.rst +++ b/docs/api/searchindex.rst @@ -1,8 +1,7 @@ -*********** -RedisVL API -*********** - +***** +Index +***** SearchIndex =========== diff --git a/docs/user_guide/getting_started_01.ipynb b/docs/user_guide/getting_started_01.ipynb index ab23a8ed..2be9642e 100644 --- a/docs/user_guide/getting_started_01.ipynb +++ b/docs/user_guide/getting_started_01.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -131,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -189,15 +189,16 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m14:26:12\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[17001]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m14:26:12\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[17001]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" + "\u001b[32m00:29:48\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[40909]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m00:29:48\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[40909]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n", + "\u001b[32m00:29:48\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[40909]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. my_index\n" ] } ], @@ -217,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -237,30 +238,28 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "from redisvl.query import create_vector_query\n", + "from redisvl.query import VectorQuery\n", "\n", "# create a vector query returning a number of results\n", "# with specific fields to return.\n", - "query = create_vector_query(\n", - " return_fields=[\"users\", \"age\", \"job\", \"credit_score\", \"vector_score\"],\n", - " number_of_results=3,\n", - " vector_field_name=\"user_embedding\"\n", + "query = VectorQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"age\", \"job\", \"credit_score\", \"vector_distance\"],\n", + " num_results=3\n", ")\n", "\n", - "# establish a query vector to search against the data in Redis\n", - "query_vector = np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes()\n", - "\n", "# use the SearchIndex instance (or Redis client) to execute the query\n", - "results = index.search(query, query_params={\"vector\": query_vector})" + "results = index.search(query.query, query_params=query.params)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -268,17 +267,17 @@ "output_type": "stream", "text": [ "Score: 0\n", - "Document {'id': 'v1:john', 'payload': None, 'vector_score': '0', 'age': '1', 'job': 'engineer', 'credit_score': 'high'}\n", + "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'age': '1', 'job': 'engineer', 'credit_score': 'high'}\n", "Score: 0\n", - "Document {'id': 'v1:mary', 'payload': None, 'vector_score': '0', 'age': '2', 'job': 'doctor', 'credit_score': 'low'}\n", + "Document {'id': 'v1:mary', 'payload': None, 'vector_distance': '0', 'user': 'mary', 'age': '2', 'job': 'doctor', 'credit_score': 'low'}\n", "Score: 0.653301358223\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_score': '0.653301358223', 'age': '3', 'job': 'dentist', 'credit_score': 'medium'}\n" + "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'age': '3', 'job': 'dentist', 'credit_score': 'medium'}\n" ] } ], "source": [ "for doc in results.docs:\n", - " print(\"Score:\", doc.vector_score)\n", + " print(\"Score:\", doc.vector_distance)\n", " print(doc)\n" ] }, diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb new file mode 100644 index 00000000..17d6d9f3 --- /dev/null +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Complex Queries\n", + "\n", + "In this notebook, we will explore more complex queries that can be performed with ``redisvl``\n", + "\n", + "Before running this notebook, be sure to\n", + "1. Have installed ``redisvl`` and have that environment active for this notebook.\n", + "2. Have a running Redis instance with RediSearch > 2.4 running." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from pprint import pprint\n", + "\n", + "data = [\n", + " {'user': 'john', 'age': 18, 'job': 'engineer', 'credit_score': 'high'},\n", + " {'user': 'derrick', 'age': 14, 'job': 'doctor', 'credit_score': 'low'},\n", + " {'user': 'nancy', 'age': 94, 'job': 'doctor', 'credit_score': 'high'},\n", + " {'user': 'tyler', 'age': 100, 'job': 'engineer', 'credit_score': 'high'},\n", + " {'user': 'tim', 'age': 12, 'job': 'dermatologist', 'credit_score': 'high'},\n", + " {'user': 'taimur', 'age': 15, 'job': 'CEO', 'credit_score': 'low'},\n", + " {'user': 'joe', 'age': 35, 'job': 'dentist', 'credit_score': 'medium'}\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'age': 18,\n", + " 'credit_score': 'high',\n", + " 'job': 'engineer',\n", + " 'user': 'john',\n", + " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", + " {'age': 14,\n", + " 'credit_score': 'low',\n", + " 'job': 'doctor',\n", + " 'user': 'derrick',\n", + " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", + " {'age': 94,\n", + " 'credit_score': 'high',\n", + " 'job': 'doctor',\n", + " 'user': 'nancy',\n", + " 'user_embedding': b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", + " {'age': 100,\n", + " 'credit_score': 'high',\n", + " 'job': 'engineer',\n", + " 'user': 'tyler',\n", + " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'},\n", + " {'age': 12,\n", + " 'credit_score': 'high',\n", + " 'job': 'dermatologist',\n", + " 'user': 'tim',\n", + " 'user_embedding': b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'},\n", + " {'age': 15,\n", + " 'credit_score': 'low',\n", + " 'job': 'CEO',\n", + " 'user': 'taimur',\n", + " 'user_embedding': b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", + " {'age': 35,\n", + " 'credit_score': 'medium',\n", + " 'job': 'dentist',\n", + " 'user': 'joe',\n", + " 'user_embedding': b'fff?fff?\\xcd\\xcc\\xcc='}]\n" + ] + } + ], + "source": [ + "# converted to bytes for redis\n", + "vectors = [\n", + " np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),\n", + " np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),\n", + " np.array([0.7, 0.1, 0.5], dtype=np.float32).tobytes(),\n", + " np.array([0.1, 0.4, 0.5], dtype=np.float32).tobytes(),\n", + " np.array([0.4, 0.4, 0.5], dtype=np.float32).tobytes(),\n", + " np.array([0.6, 0.1, 0.5], dtype=np.float32).tobytes(),\n", + " np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(),\n", + "]\n", + "\n", + "for record, vector in zip(data, vectors):\n", + " record[\"user_embedding\"] = vector\n", + "\n", + "pprint(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "schema = {\n", + " \"index\": {\n", + " \"name\": \"user_index\",\n", + " \"prefix\": \"v1\",\n", + " \"key_field\": \"user\",\n", + " \"storage_type\": \"hash\",\n", + " },\n", + " \"fields\": {\n", + " \"tag\": [{\"name\": \"credit_score\"}],\n", + " \"text\": [{\"name\": \"job\"}],\n", + " \"numeric\": [{\"name\": \"age\"}],\n", + " \"vector\": [{\n", + " \"name\": \"user_embedding\",\n", + " \"dims\": 3,\n", + " \"distance_metric\": \"cosine\",\n", + " \"algorithm\": \"flat\",\n", + " \"datatype\": \"float32\"}\n", + " ]\n", + " },\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.index import SearchIndex\n", + "\n", + "# construct a search index from the schema\n", + "index = SearchIndex.from_dict(schema)\n", + "\n", + "# connect to local redis instance\n", + "index.connect(\"redis://localhost:6379\")\n", + "\n", + "# create the index (no data yet)\n", + "index.create(overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m21:10:34\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[39050]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m21:10:34\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[39050]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n", + "\u001b[32m21:10:34\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[39050]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. my_index\n" + ] + } + ], + "source": [ + "# use the CLI to see the created index\n", + "!rvl index listall" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# load expects an iterable of dictionaries\n", + "index.load(data)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Executing Hybrid Queries\n", + "\n", + "Hybrid queries are queries that combine multiple types of filters. For example, you may want to search for a user that is a certain age, has a certain job, and is within a certain distance of a location. This is a hybrid query that combines numeric, tag, and geographic filters." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tag Filters\n", + "\n", + "Tag filters are filters that are applied to tag fields. These are fields that are not tokenized and are used to store a single categorical value." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document {'id': 'v1:john', 'payload': None, 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", + "Document {'id': 'v1:tyler', 'payload': None, 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", + "Document {'id': 'v1:tim', 'payload': None, 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" + ] + } + ], + "source": [ + "from redisvl.query import VectorQuery, TagFilter, NumericFilter\n", + "\n", + "t = TagFilter(\"credit_score\", \"high\")\n", + "\n", + "v = VectorQuery([0.1, 0.1, 0.5],\n", + " \"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\"],\n", + " hybrid_filter=t)\n", + "\n", + "\n", + "results = index.search(v.query, query_params=v.params)\n", + "for doc in results.docs:\n", + " print(doc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Numeric Filters\n", + "\n", + "Numeric filters are filters that are applied to numeric fields and can be used to isolate a range of values for a given field." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document {'id': 'v1:john', 'payload': None, 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", + "Document {'id': 'v1:tyler', 'payload': None, 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", + "Document {'id': 'v1:joe', 'payload': None, 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" + ] + } + ], + "source": [ + "n = NumericFilter(\"age\", 18, 100)\n", + "\n", + "v.set_filter(n)\n", + "\n", + "results = index.search(v.query, query_params=v.params)\n", + "for doc in results.docs:\n", + " print(doc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Text Filters\n", + "\n", + "Text filters are filters that are applied to text fields. These filters are applied to the entire text field. For example, if you have a text field that contains the text \"The quick brown fox jumps over the lazy dog\", a text filter of \"quick\" will match this text field." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document {'id': 'v1:derrick', 'payload': None, 'user': 'derrick', 'credit_score': 'low', 'age': '14', 'job': 'doctor'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" + ] + } + ], + "source": [ + "from redisvl.query import TextFilter\n", + "\n", + "text_filter = TextFilter(\"job\", \"doctor\")\n", + "v.set_filter(text_filter)\n", + "\n", + "results = index.search(v.query, query_params=v.params)\n", + "for doc in results.docs:\n", + " print(doc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining Filters\n", + "\n", + "In this example, we will combine a numeric filter with a tag filter. We will search for users that are between the ages of 20 and 30 and have a job of \"engineer\"." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", + "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" + ] + } + ], + "source": [ + "t = TagFilter(\"credit_score\", \"high\")\n", + "n = NumericFilter(\"age\", 18, 100)\n", + "t += n\n", + "\n", + "v = VectorQuery([0.1, 0.1, 0.5],\n", + " \"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"vector_distance\"],\n", + " hybrid_filter=t)\n", + "\n", + "\n", + "results = index.search(v.query, query_params=v.params)\n", + "for doc in results.docs:\n", + " print(doc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Negation\n", + "\n", + "The next example will combine the tag field with a negation. We will search for users that are in a numeric range." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n" + ] + } + ], + "source": [ + "t = TagFilter(\"credit_score\", \"high\")\n", + "n = NumericFilter(\"age\", 18, 100)\n", + "t -= n\n", + "\n", + "v.set_filter(t)\n", + "\n", + "results = index.search(v.query, query_params=v.params)\n", + "for doc in results.docs:\n", + " print(doc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Union of Filters\n", + "\n", + "This example will show how to combine multiple filters with a union. We will search for users that are either between the ages of 18 to 100 and have a high credit score." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", + "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", + "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", + "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" + ] + } + ], + "source": [ + "t = TagFilter(\"credit_score\", \"high\")\n", + "n = NumericFilter(\"age\", 18, 100)\n", + "t &= n\n", + "\n", + "v.set_filter(t)\n", + "\n", + "results = index.search(v.query, query_params=v.params)\n", + "for doc in results.docs:\n", + " print(doc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13 ('redisvl2')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 0a55008a..6ce2f868 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -15,18 +15,22 @@ RedisVL is still under active development and is subject to change at any time. ```{toctree} :caption: Introduction +:maxdepth: 3 getting_started_01 +hybrid_queries_02 ``` ```{toctree} :caption: Providers +:maxdepth: 3 embedding_creation ``` ```{toctree} :caption: LLMCache +:maxdepth: 3 llmcache_03 ``` diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb index 1a5d22c1..c4c33aa6 100644 --- a/docs/user_guide/llmcache_03.ipynb +++ b/docs/user_guide/llmcache_03.ipynb @@ -20,11 +20,10 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ - "import os\n", "import openai\n", "openai.api_key = \"sk-\"\n", "\n", @@ -39,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -66,18 +65,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 46, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/sam.partee/.virtualenvs/rvl/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "from redisvl.llmcache.semantic import SemanticCache\n", "cache = SemanticCache(\n", @@ -88,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -97,7 +87,7 @@ "[]" ] }, - "execution_count": 4, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -109,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -128,7 +118,7 @@ "['Paris']" ] }, - "execution_count": 6, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -140,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -149,7 +139,7 @@ "[]" ] }, - "execution_count": 7, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -161,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -170,7 +160,7 @@ "['Paris']" ] }, - "execution_count": 8, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -183,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -192,7 +182,7 @@ "[]" ] }, - "execution_count": 10, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -202,6 +192,15 @@ "cache.check(\"What is the capital of Spain?\")" ] }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "cache.index.delete()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -213,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -229,14 +228,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Time taken without cache 0.7418899536132812\n" + "Time taken without cache 0.6192829608917236\n" ] } ], @@ -250,15 +249,15 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Time Taken with cache: 0.07415914535522461\n", - "Percentage of time saved: 90.0%\n" + "Time Taken with cache: 0.05961775779724121\n", + "Percentage of time saved: 90.37%\n" ] } ], @@ -272,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index 90d93485..af177379 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -6,7 +6,7 @@ from redisvl.llmcache.base import BaseLLMCache from redisvl.providers import HuggingfaceProvider from redisvl.providers.base import BaseProvider -from redisvl.query import create_vector_query +from redisvl.query import VectorQuery from redisvl.utils.utils import array_to_buffer @@ -87,12 +87,23 @@ def check( prompt_vector = array_to_buffer(vector) else: prompt_vector = array_to_buffer(self._provider.embed(prompt)) # type: ignore - results = self._index.search(query, query_params={"vector": prompt_vector}) + + # TODO: Come back if vector_distance is changed + fields.append("vector_distance") + + v = VectorQuery( + vector=prompt_vector, + vector_field_name="prompt_vector", + return_fields=fields, + number_of_results=num_results + ) + + results = self._index.search(v.query, query_params=v.params) cache_hits = [] for doc in results.docs: self._refresh_ttl(doc.id) - sim = similarity(doc.vector_score) + sim = similarity(doc.vector_distance) if sim > self.threshold: cache_hits.append(doc.response) return cache_hits diff --git a/redisvl/query.py b/redisvl/query.py index 57deca79..87c0a6e3 100644 --- a/redisvl/query.py +++ b/redisvl/query.py @@ -1,48 +1,200 @@ -import typing as t +from typing import Any, Dict, List, Optional +import numpy as np from redis.commands.search.query import Query +from redisvl.utils.utils import TokenEscaper, array_to_buffer -def create_vector_query( - return_fields: t.List[str], - search_type: str = "KNN", - number_of_results: int = 20, - vector_field_name: str = "vector", - vector_param_name: str = "vector", - return_score: bool = True, - sort: bool = True, - tags: str = "*", -) -> Query: - """Create a vector query for use with a SearchIndex - - Args: - return_fields (t.List[str]): A list of fields to return in the query results - search_type (str, optional): The type of search to perform. Defaults to "KNN". - number_of_results (int, optional): The number of results to return. Defaults to 20. - vector_field_name (str, optional): The name of the vector field in the index. Defaults to "vector". - vector_param_name (str, optional): The name of the query param for searches. Defaults to "vector". - return_score (bool, optional): Whether to return the score in the query results. Defaults to True. - sort (bool, optional): Whether to sort the results by score. Defaults to True. - tags (str, optional): tag string to filter the results by. Defaults to "*". - - example usage: - vector_param = "user_vector" - query = create_vector_query( - return_fields=["users", "age", "job", "credit_score"], - search_type="KNN", - number_of_results=3, - vector_field_name="user_embedding", - vector_param_name="user_vector", - tags="*") - index.search(query, query_params={"user_vector": query_vector}) - - Returns: - Query: A Query object that can be used with SearchIndex.search - """ - base_query = f"{tags}=>[{search_type} {number_of_results} @{vector_field_name} ${vector_param_name} AS vector_score]" - if return_score: - return_fields.append("vector_score") - query = Query(base_query).return_fields(*return_fields).dialect(2) - if sort: - query.sort_by("vector_score") - return query + +class Filter: + escaper = TokenEscaper() + + def __init__(self, field): + self._field = field + self._filters = [] + + def __str__(self): + base = "(" + self.to_string() + if self._filters: + base += " ".join(self._filters) + return base + ")" + + def __iadd__(self, other): + "intersection '+='" + self._filters.append(f" {other.to_string()}") + return self + + def __iand__(self, other): + "union '&='" + self._filters.append(f" |{other.to_string()}") + return self + + def __isub__(self, other): + "subtract '-='" + self._filters.append(f" -{other.to_string()}") + return self + + def __ixor__(self, other): + "With optional '^='" + self._filters.append(f" ~{other.to_string()}") + return self + + def to_string(self) -> str: + raise NotImplementedError + + +class TagFilter(Filter): + def __init__(self, field, tags: List[str]): + super().__init__(field) + self.tags = tags + + def to_string(self) -> str: + """Converts the tag filter to a string. + + Returns: + str: The tag filter as a string. + """ + if not isinstance(self.tags, list): + self.tags = [self.tags] + return ( + "@" + + self._field + + ":{" + + " | ".join([self.escaper.escape(tag) for tag in self.tags]) + + "}" + ) + + +class NumericFilter(Filter): + + def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): + """Filter for Numeric fields. + + Args: + field (str): The field to filter on. + minval (int): The minimum value. + maxval (int): The maximum value. + minExclusive (bool, optional): Whether the minimum value is exclusive. Defaults to False. + maxExclusive (bool, optional): Whether the maximum value is exclusive. Defaults to False. + """ + self.top = maxval if not maxExclusive else f"({maxval}" + self.bottom = minval if not minExclusive else f"{minval})" + super().__init__(field) + + def to_string(self): + return "@" + self._field + ":[" + str(self.bottom) + " " + str(self.top) + "]" + + +class TextFilter(Filter): + + def __init__(self, field, text: str): + """Filter for Text fields. + Args: + field (str): The field to filter on. + text (str): The text to filter on. + """ + super().__init__(field) + self.text = text + + def to_string(self) -> str: + """Converts the filter to a string. + + Returns: + str: The filter as a string. + """ + return "@" + self._field + ":" + self.escaper.escape(self.text) + + +class BaseQuery: + def __init__( + self, return_fields: Optional[List[str]] = None, num_results: Optional[int] = 10 + ): + self._return_fields = return_fields + self._num_results = num_results + + @property + def query(self): + pass + + @property + def params(self): + pass + + +class VectorQuery(BaseQuery): + dtypes = { + "float32": np.float32, + "float64": np.float64, + } + + def __init__( + self, + vector: List[float], + vector_field_name: str, + return_fields: List[str], + hybrid_filter: Filter = None, + dtype: str = "float32", + num_results: Optional[int] = 10, + ): + """Query for vector fields + + Args: + vector (List[float]): The vector to query for. + vector_field_name (str): The name of the vector field + return_fields (List[str]): The fields to return. + hybrid_filter (Filter, optional): A filter to apply to the query. Defaults to None. + dtype (str, optional): The dtype of the vector. Defaults to "float32". + num_results (Optional[int], optional): The number of results to return. Defaults to 10. + """ + super().__init__(return_fields, num_results) + self._vector = vector + self._field = vector_field_name + self._dtype = dtype.lower() + if hybrid_filter: + self.set_filter(hybrid_filter) + else: + self._filter = "*" + + def set_filter(self, hybrid_filter: Filter): + """Set the filter for the query. + + Args: + hybrid_filter (Filter): The filter to apply to the query. + """ + if not isinstance(hybrid_filter, Filter): + raise TypeError("hybrid_filter must be of type redisvl.query.Filter") + self._filter = str(hybrid_filter) + + def get_filter(self): + """Get the filter for the query. + + Returns: + Filter: The filter for the query. + """ + return self._filter + + @property + def query(self): + """Return a Redis-Py Query object representing the query. + + Returns: + redis.commands.search.query.Query: The query object. + """ + base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} $vector AS vector_distance]" + query = ( + Query(base_query) + .return_fields(*self._return_fields) + .sort_by("vector_distance") + .paging(0, self._num_results) + .dialect(2) + ) + return query + + @property + def params(self): + """Return the parameters for the query. + + Returns: + Dict[str, Any]: The parameters for the query. + """ + return {"vector": array_to_buffer(self._vector, dtype=self.dtypes[self._dtype])} From 65046a5e8a924bf9fdb1539926a3dc20bbd3d0cd Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Tue, 25 Jul 2023 16:42:00 -0700 Subject: [PATCH 2/3] tests for query and filter --- conftest.py | 8 +- docs/api/cache.rst | 28 +++ docs/api/index.md | 1 + docs/user_guide/hybrid_queries_02.ipynb | 26 +-- redisvl/index.py | 61 +++++-- redisvl/llmcache/base.py | 19 -- redisvl/llmcache/semantic.py | 131 ++++++++++---- redisvl/query.py | 31 ++-- redisvl/schema.py | 3 +- tests/integration/test_llmcache.py | 13 -- tests/integration/test_query.py | 220 ++++++++++++++++++++++++ tests/integration/test_simple.py | 20 +-- tests/integration/test_simple_async.py | 19 +- tests/test_filter.py | 51 ++++++ tests/test_index.py | 28 +-- 15 files changed, 515 insertions(+), 144 deletions(-) create mode 100644 docs/api/cache.rst create mode 100644 tests/integration/test_query.py create mode 100644 tests/test_filter.py diff --git a/conftest.py b/conftest.py index 2f817991..619e308b 100644 --- a/conftest.py +++ b/conftest.py @@ -33,4 +33,10 @@ def event_loop(): except RuntimeError: loop = asyncio.new_event_loop() yield loop - loop.close() \ No newline at end of file + loop.close() + +@pytest.fixture +def clear_db(): + redis.flushall() + yield + redis.flushall() \ No newline at end of file diff --git a/docs/api/cache.rst b/docs/api/cache.rst new file mode 100644 index 00000000..e08df9bb --- /dev/null +++ b/docs/api/cache.rst @@ -0,0 +1,28 @@ + +******** +LLMCache +******** + +SemanticCache +============= + +.. _semantic_cache_api: + +.. currentmodule:: redisvl.llmcache.semantic + +.. autosummary:: + + SemanticCache.__init__ + SemanticCache.check + SemanticCache.store + SemanticCache.set_threshold + SemanticCache.threshold + SemanticCache.index + SemanticCache.ttl + SemanticCache.set_ttl + + +.. autoclass:: SemanticCache + :show-inheritance: + :members: + :inherited-members: diff --git a/docs/api/index.md b/docs/api/index.md index 3fe601da..cbda4091 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -14,5 +14,6 @@ myst: searchindex query filter +cache ``` diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index 17d6d9f3..43612c11 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -154,9 +154,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m21:10:34\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[39050]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m21:10:34\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[39050]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n", - "\u001b[32m21:10:34\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[39050]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. my_index\n" + "\u001b[32m16:36:51\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[74676]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m16:36:51\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[74676]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n", + "\u001b[32m16:36:51\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[74676]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. my_index\n" ] } ], @@ -203,10 +203,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Document {'id': 'v1:john', 'payload': None, 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" + "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", + "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", + "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" ] } ], @@ -244,10 +244,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Document {'id': 'v1:john', 'payload': None, 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" + "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", + "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", + "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" ] } ], @@ -279,8 +279,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Document {'id': 'v1:derrick', 'payload': None, 'user': 'derrick', 'credit_score': 'low', 'age': '14', 'job': 'doctor'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" + "Document {'id': 'v1:derrick', 'payload': None, 'vector_distance': '0', 'user': 'derrick', 'credit_score': 'low', 'age': '14', 'job': 'doctor'}\n", + "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" ] } ], diff --git a/redisvl/index.py b/redisvl/index.py index f0e2afcd..e2040fde 100644 --- a/redisvl/index.py +++ b/redisvl/index.py @@ -1,5 +1,6 @@ import asyncio from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional +from uuid import uuid4 if TYPE_CHECKING: from redis.commands.search.field import Field @@ -20,17 +21,17 @@ class SearchIndexBase: def __init__( self, name: str, - storage_type: str = "hash", - key_field: str = "id", - prefix: str = "", + prefix: str = "rvl", + storage_type: Optional[str] = "hash", + key_field: Optional[str] = None, fields: Optional[List["Field"]] = None, ): self._name = name - self._key_field = key_field - self._storage = storage_type self._prefix = prefix + self._storage = storage_type self._fields = fields self._redis_conn: Optional[redis.Redis] = None + self._key_field = key_field def set_client(self, client: redis.Redis): self._redis_conn = client @@ -111,6 +112,28 @@ def disconnect(self): """Disconnect from the Redis instance""" self._redis_conn = None + def _get_key_field(self, record: Dict[str, Any]): + """Get the key field for this index + + Args: + record (Dict[str, Any]): A dictionary containing the record to be indexed + + Returns: + str: The key to be used for a given record + + Raises: + ValueError: If the key field is not found in the record + """ + if self._key_field is None: + return uuid4().hex + else: + try: + return record[self._key_field] # type: ignore + except KeyError: + raise ValueError( + f"Key field {self._key_field} not found in record {record}" + ) + @check_connected("_redis_conn") def info(self) -> Dict[str, Any]: """Get information about the index @@ -159,12 +182,12 @@ class SearchIndex(SearchIndexBase): def __init__( self, name: str, - storage_type: str = "hash", - key_field: str = "id", - prefix: str = "", + prefix: str = "rvl", + storage_type: Optional[str] = "hash", + key_field: Optional[str] = None, fields: Optional[List["Field"]] = None, ): - super().__init__(name, storage_type, key_field, prefix, fields) + super().__init__(name, prefix, storage_type, key_field, fields) def connect(self, url: Optional[str] = None, **kwargs): """Connect to a Redis instance @@ -194,7 +217,7 @@ def create(self, overwrite: Optional[bool] = False): if not self._fields: raise ValueError("No fields defined for index") - if self._index_exists() and overwrite: + if self.exists() and overwrite: self.delete() # set storage_type, default to hash @@ -240,11 +263,11 @@ def load(self, data: Iterable[Dict[str, Any]], **kwargs): for record in data: # TODO don't use colon if no prefix - key = f"{self._prefix}:{str(record[self._key_field])}" + key = f"{self._prefix}:{self._get_key_field(record)}" self._redis_conn.hset(key, mapping=record) # type: ignore @check_connected("_redis_conn") - def _index_exists(self) -> bool: + def exists(self) -> bool: """Check if the index exists in Redis Returns: @@ -258,12 +281,12 @@ class AsyncSearchIndex(SearchIndexBase): def __init__( self, name: str, - storage_type: str = "hash", - key_field: str = "id", - prefix: str = "", + prefix: str = "rvl", + storage_type: Optional[str] = "hash", + key_field: Optional[str] = None, fields: Optional[List["Field"]] = None, ): - super().__init__(name, storage_type, key_field, prefix, fields) + super().__init__(name, prefix, storage_type, key_field, fields) def connect(self, url: Optional[str] = None, **kwargs): """Connect to a Redis instance @@ -287,7 +310,7 @@ async def create(self, overwrite: Optional[bool] = False): Raises: redis.exceptions.ResponseError: If the index already exists """ - exists = await self._index_exists() + exists = await self.exists() if exists and overwrite: await self.delete() @@ -330,14 +353,14 @@ async def load(self, data: Iterable[Dict[str, Any]], concurrency: int = 10): async def load(d: dict): async with semaphore: - key = self._prefix + ":" + str(d[self._key_field]) + key = f"{self._prefix}:{self._get_key_field(d)}" await self._redis_conn.hset(key, mapping=d) # type: ignore # gather with concurrency await asyncio.gather(*[load(d) for d in data]) @check_connected("_redis_conn") - async def _index_exists(self) -> bool: + async def exists(self) -> bool: """Check if the index exists in Redis Returns: diff --git a/redisvl/llmcache/base.py b/redisvl/llmcache/base.py index aa6af20b..c0223ddd 100644 --- a/redisvl/llmcache/base.py +++ b/redisvl/llmcache/base.py @@ -26,22 +26,3 @@ def _refresh_ttl(self, key: str): def hash_input(self, prompt: str): """Hashes the input using SHA256.""" return hashlib.sha256(prompt.encode("utf-8")).hexdigest() - - def cache_response(self, llm_callable: Callable): - """Decorator method for wrapping custom callables""" - - def wrapper(*args, **kwargs): - # Check LLM Cache first - key = self.hash_input(*args, **kwargs) - response = self.check(*args, **kwargs) - if response: - self._refresh_ttl(key) - return response - # Otherwise execute the llm callable here - response = llm_callable(*args, **kwargs) - args = list(args) - args.append(response) - self.store(*args, **kwargs) - return response - - return wrapper diff --git a/redisvl/llmcache/semantic.py b/redisvl/llmcache/semantic.py index af177379..5d20bf75 100644 --- a/redisvl/llmcache/semantic.py +++ b/redisvl/llmcache/semantic.py @@ -13,6 +13,7 @@ class SemanticCache(BaseLLMCache): """Cache for Large Language Models.""" + # TODO allow for user to change default fields _default_fields = [ VectorField( "prompt_vector", @@ -20,7 +21,6 @@ class SemanticCache(BaseLLMCache): {"DIM": 768, "TYPE": "FLOAT32", "DISTANCE_METRIC": "COSINE"}, ), ] - _default_provider = HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2") def __init__( self, @@ -28,34 +28,72 @@ def __init__( prefix: str = "llmcache", threshold: float = 0.9, ttl: Optional[int] = None, - provider: Optional[BaseProvider] = None, + provider: Optional[BaseProvider] = HuggingfaceProvider( + "sentence-transformers/all-mpnet-base-v2" + ), redis_url: Optional[str] = "redis://localhost:6379", connection_args: Optional[dict] = None, ): + """Semantic Cache for Large Language Models. + + Args: + index_name (str, optional): The name of the index. Defaults to "cache". + prefix (str, optional): The prefix for the index. Defaults to "llmcache". + threshold (float, optional): Semantic threshold for the cache. Defaults to 0.9. + ttl (Optional[int], optional): The TTL for the cache. Defaults to None. + provider (Optional[BaseProvider], optional): The provider for the cache. + Defaults to HuggingfaceProvider("sentence-transformers/all-mpnet-base-v2"). + redis_url (Optional[str], optional): The redis url. Defaults to "redis://localhost:6379". + connection_args (Optional[dict], optional): The connection arguments for the redis client. Defaults to None. + + Raises: + ValueError: If the threshold is not between 0 and 1. + + """ self._ttl = ttl - self._provider = provider or self._default_provider - self._threshold = threshold + self._provider = provider + self.set_threshold(threshold) - # TODO - configure logging based on verbosity - self._index = SearchIndex( - index_name, prefix=prefix, fields=self._default_fields - ) + index = SearchIndex(name=index_name, prefix=prefix, fields=self._default_fields) connection_args = connection_args or {} - self._index.connect(url=redis_url, **connection_args) - self._index.create() + index.connect(url=redis_url, **connection_args) + + # create index or connect to existing index + if not index.exists(): + index.create() + self._index = index + else: + # TODO check prefix and fields are the same + client = index.client + self._index = SearchIndex.from_existing(client, index_name) @property def ttl(self) -> Optional[int]: - """Returns the TTL for the cache.""" + """Returns the TTL for the cache. + + Returns: + Optional[int]: The TTL for the cache. + """ return self._ttl def set_ttl(self, ttl: int): - """Sets the TTL for the cache.""" + """Sets the TTL for the cache. + + Args: + ttl (int): The TTL for the cache. + + Raises: + ValueError: If the TTL is not an integer. + """ self._ttl = int(ttl) @property def index(self) -> SearchIndex: - """Returns the index for the cache.""" + """Returns the index for the cache. + + Returns: + SearchIndex: The index for the cache. + """ return self._index @property @@ -64,8 +102,17 @@ def threshold(self) -> float: return self._threshold def set_threshold(self, threshold: float): - """Sets the threshold for the cache.""" - self._threshold = threshold + """Sets the threshold for the cache. + + Args: + threshold (float): The threshold for the cache. + + Raises: + ValueError: If the threshold is not between 0 and 1. + """ + if not 0 <= float(threshold) <= 1: + raise ValueError("Threshold must be between 0 and 1.") + self._threshold = float(threshold) def check( self, @@ -74,28 +121,32 @@ def check( num_results: int = 1, fields: List[str] = ["response"], ) -> Optional[List[str]]: - """Checks whether the cache contains the specified key.""" + """Checks whether the cache contains the specified prompt or vector. + + Args: + prompt (Optional[str], optional): The prompt to check. Defaults to None. + vector (Optional[List[float]], optional): The vector to check. Defaults to None. + num_results (int, optional): The number of results to return. Defaults to 1. + fields (List[str], optional): The fields to return. Defaults to ["response"]. + + Raises: + ValueError: If neither prompt nor vector is specified. + + Returns: + Optional[List[str]]: The response(s) if the cache contains the prompt or vector. + """ if not prompt and not vector: raise ValueError("Either prompt or vector must be specified.") - query = create_vector_query( - return_fields=fields, - vector_field_name="prompt_vector", - number_of_results=num_results, - ) - if vector: - prompt_vector = array_to_buffer(vector) - else: - prompt_vector = array_to_buffer(self._provider.embed(prompt)) # type: ignore - - # TODO: Come back if vector_distance is changed - fields.append("vector_distance") + if not vector: + vector = self._provider.embed(prompt) # type: ignore v = VectorQuery( - vector=prompt_vector, + vector=vector, vector_field_name="prompt_vector", return_fields=fields, - number_of_results=num_results + num_results=num_results, + return_score=True, ) results = self._index.search(v.query, query_params=v.params) @@ -116,15 +167,27 @@ def store( metadata: Optional[dict] = {}, key: Optional[str] = None, ) -> None: - """Stores the specified key-value pair in the cache along with metadata.""" + """Stores the specified key-value pair in the cache along with metadata. + + Args: + prompt (str): The prompt to store. + response (str): The response to store. + vector (Optional[List[float]], optional): The vector to store. Defaults to None + metadata (Optional[dict], optional): The metadata to store. Defaults to {}. + key (Optional[str], optional): The key to store. Defaults to None. + + Raises: + ValueError: If neither prompt nor vector is specified. + """ if not key: key = self.hash_input(prompt) + if vector: - prompt_vector = array_to_buffer(vector) + vector = array_to_buffer(vector) else: - prompt_vector = array_to_buffer(self._provider.embed(prompt)) + vector = self._provider.embed(prompt) # type: ignore - payload = {"id": key, "prompt_vector": prompt_vector, "response": response} + payload = {"id": key, "prompt_vector": vector, "response": response} if metadata: payload.update(metadata) self._index.load([payload]) diff --git a/redisvl/query.py b/redisvl/query.py index 87c0a6e3..b849c058 100644 --- a/redisvl/query.py +++ b/redisvl/query.py @@ -16,7 +16,7 @@ def __init__(self, field): def __str__(self): base = "(" + self.to_string() if self._filters: - base += " ".join(self._filters) + base += "".join(self._filters) return base + ")" def __iadd__(self, other): @@ -26,7 +26,7 @@ def __iadd__(self, other): def __iand__(self, other): "union '&='" - self._filters.append(f" |{other.to_string()}") + self._filters.append(f" | {other.to_string()}") return self def __isub__(self, other): @@ -66,19 +66,18 @@ def to_string(self) -> str: class NumericFilter(Filter): - - def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): + def __init__(self, field, minval, maxval, min_exclusive=False, max_exclusive=False): """Filter for Numeric fields. Args: field (str): The field to filter on. minval (int): The minimum value. maxval (int): The maximum value. - minExclusive (bool, optional): Whether the minimum value is exclusive. Defaults to False. - maxExclusive (bool, optional): Whether the maximum value is exclusive. Defaults to False. + min_exclusive (bool, optional): Whether the minimum value is exclusive. Defaults to False. + max_exclusive (bool, optional): Whether the maximum value is exclusive. Defaults to False. """ - self.top = maxval if not maxExclusive else f"({maxval}" - self.bottom = minval if not minExclusive else f"{minval})" + self.top = maxval if not max_exclusive else f"({maxval}" + self.bottom = minval if not min_exclusive else f"({minval}" super().__init__(field) def to_string(self): @@ -86,7 +85,6 @@ def to_string(self): class TextFilter(Filter): - def __init__(self, field, text: str): """Filter for Text fields. Args: @@ -127,6 +125,8 @@ class VectorQuery(BaseQuery): "float64": np.float64, } + DISTANCE_ID = "vector_distance" + def __init__( self, vector: List[float], @@ -135,6 +135,7 @@ def __init__( hybrid_filter: Filter = None, dtype: str = "float32", num_results: Optional[int] = 10, + return_score: bool = True, ): """Query for vector fields @@ -145,6 +146,11 @@ def __init__( hybrid_filter (Filter, optional): A filter to apply to the query. Defaults to None. dtype (str, optional): The dtype of the vector. Defaults to "float32". num_results (Optional[int], optional): The number of results to return. Defaults to 10. + return_score (bool, optional): Whether to return the score. Defaults to True. + + Raises: + TypeError: If hybrid_filter is not of type redisvl.query.Filter + """ super().__init__(return_fields, num_results) self._vector = vector @@ -155,6 +161,9 @@ def __init__( else: self._filter = "*" + if return_score: + self._return_fields.append(self.DISTANCE_ID) + def set_filter(self, hybrid_filter: Filter): """Set the filter for the query. @@ -180,11 +189,11 @@ def query(self): Returns: redis.commands.search.query.Query: The query object. """ - base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} $vector AS vector_distance]" + base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} $vector AS {self.DISTANCE_ID}]" query = ( Query(base_query) .return_fields(*self._return_fields) - .sort_by("vector_distance") + .sort_by(self.DISTANCE_ID) .paging(0, self._num_results) .dialect(2) ) diff --git a/redisvl/schema.py b/redisvl/schema.py index 65597d32..a958fc01 100644 --- a/redisvl/schema.py +++ b/redisvl/schema.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import List, Optional, Union +from uuid import uuid4 import yaml from pydantic import BaseModel, Field, field_validator @@ -116,7 +117,7 @@ def as_field(self): class IndexModel(BaseModel): name: str = Field(...) prefix: str = Field(...) - key_field: str = Field(...) + key_field: Optional[str] = Field(default=None) storage_type: str = Field(default="hash") diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 783cb2c9..3f8f0820 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -53,16 +53,3 @@ def test_set_threshold(cache): cache.set_threshold(0.9) assert cache.threshold == 0.9 cache.index.delete(drop=True) - - -def test_wrapper(cache): - @cache.cache_response - def test_function(prompt): - return "This is a test response." - - # Check that the wrapper works - test_function("This is a test prompt.") - check_result = cache.check("This is a test prompt.") - assert len(check_result) >= 1 - assert "This is a test response." in check_result - cache.index.delete(drop=True) diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py new file mode 100644 index 00000000..8b2e83d4 --- /dev/null +++ b/tests/integration/test_query.py @@ -0,0 +1,220 @@ +from pprint import pprint + +import numpy as np +import pytest + +from redisvl.index import SearchIndex +from redisvl.query import NumericFilter, TagFilter, VectorQuery + +data = [ + { + "user": "john", + "age": 18, + "job": "engineer", + "credit_score": "high", + "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), + }, + { + "user": "derrick", + "age": 14, + "job": "doctor", + "credit_score": "low", + "user_embedding": np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(), + }, + { + "user": "nancy", + "age": 94, + "job": "doctor", + "credit_score": "high", + "user_embedding": np.array([0.7, 0.1, 0.5], dtype=np.float32).tobytes(), + }, + { + "user": "tyler", + "age": 100, + "job": "engineer", + "credit_score": "high", + "user_embedding": np.array([0.1, 0.4, 0.5], dtype=np.float32).tobytes(), + }, + { + "user": "tim", + "age": 12, + "job": "dermatologist", + "credit_score": "high", + "user_embedding": np.array([0.4, 0.4, 0.5], dtype=np.float32).tobytes(), + }, + { + "user": "taimur", + "age": 15, + "job": "CEO", + "credit_score": "low", + "user_embedding": np.array([0.6, 0.1, 0.5], dtype=np.float32).tobytes(), + }, + { + "user": "joe", + "age": 35, + "job": "dentist", + "credit_score": "medium", + "user_embedding": np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(), + }, +] + +schema = { + "index": { + "name": "user_index", + "prefix": "v1", + "storage_type": "hash", + }, + "fields": { + "tag": [{"name": "credit_score"}], + "text": [{"name": "job"}], + "numeric": [{"name": "age"}], + "vector": [ + { + "name": "user_embedding", + "dims": 3, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float32", + } + ], + }, +} + + +@pytest.fixture(scope="module") +def index(): + # construct a search index from the schema + index = SearchIndex.from_dict(schema) + + # connect to local redis instance + index.connect("redis://localhost:6379") + + # create the index (no data yet) + index.create(overwrite=True) + + index.load(data) + + # run the test + yield index + + # clean up + index.delete() + + +def test_simple(index): + # *=>[KNN 7 @user_embedding $vector AS vector_distance] + v = VectorQuery( + [0.1, 0.1, 0.5], + "user_embedding", + return_fields=["user", "credit_score", "age", "job"], + num_results=7, + ) + results = index.search(v.query, query_params=v.params) + assert len(results.docs) == 7 + for doc in results.docs: + # ensure all return fields present + assert doc.user in ["john", "derrick", "nancy", "tyler", "tim", "taimur", "joe"] + assert int(doc.age) in [18, 14, 94, 100, 12, 15, 35] + assert doc.job in ["engineer", "doctor", "dermatologist", "CEO", "dentist"] + assert doc.credit_score in ["high", "low", "medium"] + + +def test_simple_tag_filter(index): + # (@credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] + t = TagFilter("credit_score", "high") + v = VectorQuery( + [0.1, 0.1, 0.5], + "user_embedding", + return_fields=["user", "credit_score", "age", "job"], + hybrid_filter=t, + ) + + results = index.search(v.query, query_params=v.params) + assert len(results.docs) == 4 + + +def test_simple_numeric_filter(index): + # (@age:[18 101])=>[KNN 10 @user_embedding $vector AS vector_distance] + n = NumericFilter("age", 18, 100) + v = VectorQuery( + [0.1, 0.1, 0.5], + "user_embedding", + return_fields=["user", "credit_score", "age", "job"], + hybrid_filter=n, + ) + + results = index.search(v.query, query_params=v.params) + assert len(results.docs) == 4 + + +def test_numeric_filter_exclusive(index): + n = NumericFilter("age", 18, 100, min_exclusive=True) + v = VectorQuery( + [0.1, 0.1, 0.5], + "user_embedding", + return_fields=["user", "credit_score", "age", "job"], + hybrid_filter=n, + ) + + results = index.search(v.query, query_params=v.params) + assert len(results.docs) == 3 + + n_both_exclusive = NumericFilter( + "age", 18, 100, min_exclusive=True, max_exclusive=True + ) + v.set_filter(n_both_exclusive) + results = index.search(v.query, query_params=v.params) + assert len(results.docs) == 2 + + +def test_combinations(index): + # (@age:[18 100] @credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] + t = TagFilter("credit_score", "high") + n = NumericFilter("age", 18, 100) + t += n + v = VectorQuery( + [0.1, 0.1, 0.5], + "user_embedding", + return_fields=["user", "credit_score", "age", "job"], + hybrid_filter=t, + ) + + results = index.search(v.query, query_params=v.params) + for doc in results.docs: + assert doc.credit_score == "high" + assert 18 <= int(doc.age) <= 100 + assert len(results.docs) == 3 + + # (@credit_score:{high} -@age:[18 100])=>[KNN 10 @user_embedding $vector AS vector_distance] + t = TagFilter("credit_score", "high") + n = NumericFilter("age", 18, 100) + t -= n + v.set_filter(t) + + results = index.search(v.query, query_params=v.params) + for doc in results.docs: + assert doc.credit_score == "high" + assert int(doc.age) not in range(18, 101) + assert len(results.docs) == 1 + + # (@credit_score:{high} | @age:[18 100])=>[KNN 10 @user_embedding $vector AS vector_distance] + t = TagFilter("credit_score", "high") + n = NumericFilter("age", 18, 100) + t &= n + v.set_filter(t) + + results = index.search(v.query, query_params=v.params) + for doc in results.docs: + assert (doc.credit_score == "high") or (18 <= int(doc.age) <= 100) + assert len(results.docs) == 5 + + # (@credit_score:{high} ~@age:[18 100])=>[KNN 10 @user_embedding $vector AS vector_distance] + t = TagFilter("credit_score", "high") + n = NumericFilter("age", 18, 100) + t ^= n + v.set_filter(t) + + results = index.search(v.query, query_params=v.params) + for doc in results.docs: + assert doc.credit_score == "high" + assert len(results.docs) == 4 diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 053339b2..3d5a2556 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -3,7 +3,7 @@ import numpy as np from redisvl.index import SearchIndex -from redisvl.query import create_vector_query +from redisvl.query import VectorQuery data = [ { @@ -31,7 +31,6 @@ "user_embedding": np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(), }, ] -query_vector = np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes() schema = { "index": { @@ -67,13 +66,14 @@ def test_simple(client): # load data into the index in Redis index.load(data) - query = create_vector_query( - ["user", "age", "job", "credit_score"], - number_of_results=3, + query = VectorQuery( + vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", + return_fields=["user", "age", "job", "credit_score"], + num_results=3, ) - results = index.search(query, query_params={"vector": query_vector}) + results = index.search(query.query, query_params=query.params) # make sure correct users returned # users = list(results.docs) @@ -85,13 +85,13 @@ def test_simple(client): # make sure vector scores are correct # query vector and first two are the same vector. # third is different (hence should be positive difference) - assert float(users[0].vector_score) == 0.0 - assert float(users[1].vector_score) == 0.0 - assert float(users[2].vector_score) > 0 + assert float(users[0].vector_distance) == 0.0 + assert float(users[1].vector_distance) == 0.0 + assert float(users[2].vector_distance) > 0 print() for doc in results.docs: - print("Score:", doc.vector_score) + print("Score:", doc.vector_distance) pprint(doc) index.delete() diff --git a/tests/integration/test_simple_async.py b/tests/integration/test_simple_async.py index 9f3bae00..26e543c1 100644 --- a/tests/integration/test_simple_async.py +++ b/tests/integration/test_simple_async.py @@ -5,7 +5,7 @@ import pytest from redisvl.index import AsyncSearchIndex -from redisvl.query import create_vector_query +from redisvl.query import VectorQuery data = [ { @@ -73,13 +73,14 @@ async def test_simple(async_client): # wait for async index to create time.sleep(1) - query = create_vector_query( - ["user", "age", "job", "credit_score", "vector_score"], - number_of_results=3, + query = VectorQuery( + vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", + return_fields=["user", "age", "job", "credit_score"], + num_results=3, ) - results = await index.search(query, query_params={"vector": query_vector}) + results = await index.search(query.query, query_params=query.params) # make sure correct users returned # users = list(results.docs) @@ -91,13 +92,13 @@ async def test_simple(async_client): # make sure vector scores are correct # query vector and first two are the same vector. # third is different (hence should be positive difference) - assert float(users[0].vector_score) == 0.0 - assert float(users[1].vector_score) == 0.0 - assert float(users[2].vector_score) > 0 + assert float(users[0].vector_distance) == 0.0 + assert float(users[1].vector_distance) == 0.0 + assert float(users[2].vector_distance) > 0 print() for doc in results.docs: - print("Score:", doc.vector_score) + print("Score:", doc.vector_distance) pprint(doc) await index.delete() diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 00000000..e3d48351 --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,51 @@ +import pytest + +from redisvl.query import Filter, NumericFilter, TagFilter, TextFilter, VectorQuery +from redisvl.utils.utils import TokenEscaper + + +class TestFilters: + def test_tag_filter(self): + tf = TagFilter("tag_field", ["tag1", "tag2"]) + assert tf.to_string() == "@tag_field:{tag1 | tag2}" + + def test_numeric_filter(self): + nf = NumericFilter( + "numeric_field", 1, 10, min_exclusive=True, max_exclusive=True + ) + assert nf.to_string() == "@numeric_field:[(1 (10]" + + def test_numeric_filter_2(self): + nf = NumericFilter( + "numeric_field", 1, 10, min_exclusive=False, max_exclusive=False + ) + assert nf.to_string() == "@numeric_field:[1 10]" + + def test_text_filter(self): + txt_f = TextFilter("text_field", "text") + assert txt_f.to_string() == "@text_field:text" + + def test_filters_combination(self): + tf1 = TagFilter("tag_field", ["tag1", "tag2"]) + tf2 = TagFilter("tag_field", ["tag3"]) + tf1 += tf2 + assert str(tf1) == "(@tag_field:{tag1 | tag2} @tag_field:{tag3})" + tf1 &= tf2 + assert ( + str(tf1) + == "(@tag_field:{tag1 | tag2} @tag_field:{tag3} | @tag_field:{tag3})" + ) + tf1 -= tf2 + assert ( + str(tf1) + == "(@tag_field:{tag1 | tag2} @tag_field:{tag3} | @tag_field:{tag3} -@tag_field:{tag3})" + ) + tf1 ^= tf2 + assert ( + str(tf1) + == "(@tag_field:{tag1 | tag2} @tag_field:{tag3} | @tag_field:{tag3} -@tag_field:{tag3} ~@tag_field:{tag3})" + ) + + def test_filter_raise_not_implemented_error(self): + with pytest.raises(NotImplementedError): + Filter("field").to_string() diff --git a/tests/test_index.py b/tests/test_index.py index 8a117495..cd99a21d 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -9,14 +9,14 @@ def test_search_index_client(client): - si = SearchIndex("my_index", fields=fields) + si = SearchIndex("my_index", key_field="id", fields=fields) si.set_client(client) assert si.client is not None def test_search_index_create(client): - si = SearchIndex("my_index", fields=fields) + si = SearchIndex("my_index", key_field="id", fields=fields) si.set_client(client) si.create(overwrite=True) @@ -27,7 +27,7 @@ def test_search_index_create(client): def test_search_index_delete(client): - si = SearchIndex("my_index", fields=fields) + si = SearchIndex("my_index", key_field="id", fields=fields) si.set_client(client) si.create(overwrite=True) si.delete() @@ -36,18 +36,18 @@ def test_search_index_delete(client): def test_search_index_load(client): - si = SearchIndex("my_index", fields=fields) + si = SearchIndex("my_index", key_field="id", fields=fields) si.set_client(client) si.create(overwrite=True) data = [{"id": "1", "value": "test"}] si.load(data) - assert convert_bytes(client.hget(":1", "value")) == "test" + assert convert_bytes(client.hget("rvl:1", "value")) == "test" @pytest.mark.asyncio async def test_async_search_index_creation(async_client): - asi = AsyncSearchIndex("my_index", fields=fields) + asi = AsyncSearchIndex("my_index", key_field="id", fields=fields) asi.set_client(async_client) assert asi.client == async_client @@ -55,7 +55,7 @@ async def test_async_search_index_creation(async_client): @pytest.mark.asyncio async def test_async_search_index_create(async_client): - asi = AsyncSearchIndex("my_index", fields=fields) + asi = AsyncSearchIndex("my_index", key_field="id", fields=fields) asi.set_client(async_client) await asi.create(overwrite=True) @@ -65,7 +65,7 @@ async def test_async_search_index_create(async_client): @pytest.mark.asyncio async def test_async_search_index_delete(async_client): - asi = AsyncSearchIndex("my_index", fields=fields) + asi = AsyncSearchIndex("my_index", key_field="id", fields=fields) asi.set_client(async_client) await asi.create(overwrite=True) await asi.delete() @@ -76,12 +76,12 @@ async def test_async_search_index_delete(async_client): @pytest.mark.asyncio async def test_async_search_index_load(async_client): - asi = AsyncSearchIndex("my_index", fields=fields) + asi = AsyncSearchIndex("my_index", key_field="id", fields=fields) asi.set_client(async_client) await asi.create(overwrite=True) data = [{"id": "1", "value": "test"}] await asi.load(data) - result = await async_client.hget(":1", "value") + result = await async_client.hget("rvl:1", "value") assert convert_bytes(result) == "test" await asi.delete() @@ -90,7 +90,7 @@ async def test_async_search_index_load(async_client): def test_search_index_delete_nonexistent(client): - si = SearchIndex("my_index") + si = SearchIndex("my_index", key_field="id", fields=fields) si.set_client(client) with pytest.raises(redis.exceptions.ResponseError): si.delete() @@ -98,7 +98,7 @@ def test_search_index_delete_nonexistent(client): @pytest.mark.asyncio async def test_async_search_index_delete_nonexistent(async_client): - asi = AsyncSearchIndex("my_index") + asi = AsyncSearchIndex("my_index", key_field="id", fields=fields) asi.set_client(async_client) with pytest.raises(redis.exceptions.ResponseError): await asi.delete() @@ -114,13 +114,13 @@ def test_no_key_field(client): bad_data = [{"wrong_key": "1", "value": "test"}] # TODO make a better error - with pytest.raises(KeyError): + with pytest.raises(ValueError): si.load(bad_data) @pytest.mark.asyncio async def test_async_search_index_load_bad_data(async_client): - asi = AsyncSearchIndex("my_index", fields=fields) + asi = AsyncSearchIndex("my_index", key_field="id", fields=fields) asi.set_client(async_client) await asi.create(overwrite=True) From a24de6477a3ca1433c48b36249f1d663988f4f55 Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Tue, 25 Jul 2023 22:56:57 -0700 Subject: [PATCH 3/3] Edit README --- README.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 5851cc5b..0199346e 100644 --- a/README.md +++ b/README.md @@ -74,27 +74,27 @@ This would correspond to a dataset that looked something like With the schema, the RedisVL library can be used to create, load vectors and perform vector searches ```python -import pandas as pd from redisvl.index import SearchIndex -from redisvl.query import create_vector_query +from redisvl.query import VectorQuery -# define and create the index -index = SearchIndex.from_yaml("./users_schema.yml")) +# initialize the index and connect to Redis +index = SearchIndex.from_dict(schema) index.connect("redis://localhost:6379") -index.create() -index.load(pd.read_csv("./users.csv").to_dict("records")) +# create the index in Redis +index.create(overwrite=True) -query = create_vector_query( - ["user", "age", "job", "credit_score"], - number_of_results=2, +# load data into the index in Redis (list of dicts) +index.load(data) + +query = VectorQuery( + vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", + return_fields=["user", "age", "job", "credit_score"], + num_results=3, ) - -query_vector = np.array([0.1, 0.1, 0.5]).tobytes() -results = index.search(query, query_params={"vector": query_vector}) - +results = index.search(query.query, query_params=query.params) ```