From 95801b8924f98416c86c86e4f20965da6f813b09 Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Thu, 3 Aug 2023 17:59:58 -0700 Subject: [PATCH 1/3] take out query --- redisvl/cli/query.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 redisvl/cli/query.py diff --git a/redisvl/cli/query.py b/redisvl/cli/query.py deleted file mode 100644 index e69de29b..00000000 From 656f87f6f231d3f9e61b430d9f5e6c7408244b7f Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Thu, 3 Aug 2023 19:28:59 -0700 Subject: [PATCH 2/3] Add GeoFilter --- docs/api/filter.rst | 17 + docs/user_guide/hybrid_example_data.pkl | Bin 0 -> 494 bytes docs/user_guide/hybrid_queries_02.ipynb | 351 +++++++++--------- docs/user_guide/jupyterutils.py | 42 +++ ...roviders_03.ipynb => vectorizers_03.ipynb} | 0 redisvl/query.py | 55 +++ tests/integration/test_llmcache.py | 7 + tests/test_filter.py | 16 +- 8 files changed, 321 insertions(+), 167 deletions(-) create mode 100644 docs/user_guide/hybrid_example_data.pkl create mode 100644 docs/user_guide/jupyterutils.py rename docs/user_guide/{providers_03.ipynb => vectorizers_03.ipynb} (100%) diff --git a/docs/api/filter.rst b/docs/api/filter.rst index 92786a34..d88bcd9c 100644 --- a/docs/api/filter.rst +++ b/docs/api/filter.rst @@ -57,3 +57,20 @@ NumericFilter :show-inheritance: :members: :inherited-members: + + +GeoFilter +========= + +.. currentmodule:: redisvl.query + +.. autosummary:: + + GeoFilter.__init__ + GeoFilter.to_string + + +.. autoclass:: GeoFilter + :show-inheritance: + :members: + :inherited-members: \ No newline at end of file diff --git a/docs/user_guide/hybrid_example_data.pkl b/docs/user_guide/hybrid_example_data.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b5928b917e3263b461fe62779d384c1ea63d0b61 GIT binary patch literal 494 zcmZvZ%T59@6o%EgfQq7SjB7V6ln6rv7Q%wKV&M}g89JRg#hEr4Ix#Fve1I;G3w8k(N|2yrA`}(=PRjzw%JY&PS#hmA8+?fK2#t9RS;}he|)D%sX%S1?S z&uc8_k8oRU8Y(C#hng7K$!7zHGJI%}%S|xd^(=Znx7%b7(6jJe%?; z5;DP2?PIyTv71DOSda$nm`cR+T(Cz3cYFw*@gv2w$`LBYN!1g=F6_bn|EwoN$rhG; z01xU=f$DcW2!i47Do-y`8j9i(#r;)$t$x)D`c^&YUAJ(fZ?Ng=XpvdGMxF=corPMZ zDiL!{x@Gt28XZvg57gJ8$cnPOPP8++y|c^_`IgN4+u0vq`VDUD&IzDrl6qBS{EYsj RAd@v4$MMi^J1-Wz`30w4sp\\x00\\x00\\x00?'},\n", - " {'age': 12,\n", - " 'credit_score': 'high',\n", - " 'job': 'dermatologist',\n", - " 'user': 'tim',\n", - " 'user_embedding': b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'},\n", - " {'age': 15,\n", - " 'credit_score': 'low',\n", - " 'job': 'CEO',\n", - " 'user': 'taimur',\n", - " 'user_embedding': b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", - " {'age': 35,\n", - " 'credit_score': 'medium',\n", - " 'job': 'dentist',\n", - " 'user': 'joe',\n", - " 'user_embedding': b'fff?fff?\\xcd\\xcc\\xcc='}]\n" - ] + "data": { + "text/html": [ + "
useragejobcredit_scoreoffice_locationuser_embedding
john18engineerhigh-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
derrick14doctorlow-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
nancy94doctorhigh-122.4194,37.7749b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
tyler100engineerhigh-122.0839,37.3861b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
tim12dermatologisthigh-122.0839,37.3861b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
taimur15CEOlow-122.0839,37.3861b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
joe35dentistmedium-122.0839,37.3861b'fff?fff?\\xcd\\xcc\\xcc='
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "# converted to bytes for redis\n", - "vectors = [\n", - " np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.7, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.1, 0.4, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.4, 0.4, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.6, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(),\n", - "]\n", - "\n", - "for record, vector in zip(data, vectors):\n", - " record[\"user_embedding\"] = vector\n", - "\n", - "pprint(data)" + "import pickle\n", + "from jupyterutils import table_print, result_print\n", + "\n", + "# load in the example data and printing utils\n", + "data = pickle.load(open(\"hybrid_example_data.pkl\", \"rb\"))\n", + "table_print(data)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -114,6 +56,7 @@ " \"tag\": [{\"name\": \"credit_score\"}],\n", " \"text\": [{\"name\": \"job\"}],\n", " \"numeric\": [{\"name\": \"age\"}],\n", + " \"geo\": [{\"name\": \"office_location\"}],\n", " \"vector\": [{\n", " \"name\": \"user_embedding\",\n", " \"dims\": 3,\n", @@ -127,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -145,15 +88,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m17:46:27\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[8815]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m17:46:27\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[8815]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" + "\u001b[32m19:22:36\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[21909]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m19:22:36\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[21909]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" ] } ], @@ -164,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -193,18 +136,20 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158809006214timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -214,13 +159,12 @@ "\n", "v = VectorQuery([0.1, 0.1, 0.5],\n", " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\"],\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", " hybrid_filter=t)\n", "\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -234,18 +178,20 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -254,8 +200,7 @@ "v.set_filter(n)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -269,16 +214,20 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:derrick', 'payload': None, 'vector_distance': '0', 'user': 'derrick', 'credit_score': 'low', 'age': '14', 'job': 'doctor'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -288,8 +237,72 @@ "v.set_filter(text_filter)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Geographic Filters\n", + "\n", + "Geographic filters are filters that are applied to geographic fields. These filters are used to find results that are within a certain distance of a given point. The distance is specified in kilometers, miles, meters, or feet. A radius can also be specified to find results within a certain radius of a given point." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query import GeoFilter\n", + "\n", + "# within 10 km of San Francisco office\n", + "geo_filter = GeoFilter(\"office_location\", -122.4194, 37.7749 , 10, \"km\")\n", + "v.set_filter(geo_filter)\n", + "\n", + "results = index.query(v)\n", + "result_print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158809006214timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# change to 100 km of San Francisco office\n", + "geo_filter = GeoFilter(\"office_location\", -122.4194, 37.7749 , 100, \"km\")\n", + "v.set_filter(geo_filter)\n", + "\n", + "results = index.query(v)\n", + "result_print(results)" ] }, { @@ -303,17 +316,20 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -323,13 +339,12 @@ "\n", "v = VectorQuery([0.1, 0.1, 0.5],\n", " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\"],\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", " hybrid_filter=t)\n", "\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -343,15 +358,20 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0.158809006214timhigh12dermatologist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -362,8 +382,7 @@ "v.set_filter(t)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -377,19 +396,20 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158809006214timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -400,8 +420,7 @@ "v.set_filter(t)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -419,19 +438,20 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'age': '100', 'user': 'tyler', 'credit_score': 'high', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'age': '94', 'user': 'nancy', 'credit_score': 'high', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'age': '35', 'user': 'joe', 'credit_score': 'medium', 'job': 'dentist'}\n", - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'age': '18', 'user': 'john', 'credit_score': 'high', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'age': '12', 'user': 'tim', 'credit_score': 'high', 'job': 'dermatologist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
0.26666665077294nancyhighdoctor-122.4194,37.7749
0.65330135822335joemediumdentist-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
0.15880900621412timhighdermatologist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -443,8 +463,7 @@ "\n", "# run the query with the ``SearchIndex.search`` method\n", "result = index.search(redis_py_query, v.params)\n", - "for doc in result.docs:\n", - " print(doc)" + "result_print(result)" ] }, { @@ -458,7 +477,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -467,7 +486,7 @@ "'(@credit_score:{high})'" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -479,24 +498,24 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Document {'id': 'v1:john', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'v1:38bfee0253ca452e96b4b3fdcb2798f7', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'v1:747bce550564443199ae1118cf03b5e3', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'v1:da7b6b0bf94f4c40a2ea23e20035ca73', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'v1:abcdf6be4fb042389a93a9b27d6cce5c', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], "source": [ "results = index.search(str(t))\n", - "for doc in results.docs:\n", - " print(doc)" + "for r in results.docs:\n", + " print(r.__dict__)" ] }, { @@ -512,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -521,7 +540,7 @@ "'(@credit_score:{high} @age:[18 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 5 user credit_score age job vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10'" ] }, - "execution_count": 16, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/user_guide/jupyterutils.py b/docs/user_guide/jupyterutils.py new file mode 100644 index 00000000..8e4954b7 --- /dev/null +++ b/docs/user_guide/jupyterutils.py @@ -0,0 +1,42 @@ +from IPython.display import display, HTML + +def table_print(dict_list): + # If there's nothing in the list, there's nothing to print + if len(dict_list) == 0: + return + + # Getting column names (dictionary keys) using the first dictionary + columns = dict_list[0].keys() + + # HTML table header + html = '' + + # HTML table content + for dictionary in dict_list: + html += '' + + # HTML table footer + html += '
' + html += ''.join(columns) + html += '
' + html += ''.join(str(dictionary[column]) for column in columns) + html += '
' + + # Displaying the table + display(HTML(html)) + + +def result_print(results): + # If there's nothing in the list, there's nothing to print + if len(results.docs) == 0: + return + + data = [doc.__dict__ for doc in results.docs] + + to_remove = ["id", "payload"] + for doc in data: + for key in to_remove: + if key in doc: + del doc[key] + + table_print(data) \ No newline at end of file diff --git a/docs/user_guide/providers_03.ipynb b/docs/user_guide/vectorizers_03.ipynb similarity index 100% rename from docs/user_guide/providers_03.ipynb rename to docs/user_guide/vectorizers_03.ipynb diff --git a/redisvl/query.py b/redisvl/query.py index d1eaf149..51903183 100644 --- a/redisvl/query.py +++ b/redisvl/query.py @@ -66,6 +66,61 @@ def to_string(self) -> str: ) +class GeoFilter(Filter): + GEO_UNITS = ["m", "km", "mi", "ft"] + + def __init__(self, field, longitude, latitude, radius, unit="km"): + """Filter for Geo fields. + + Args: + field (str): The field to filter on. + longitude (float): The longitude. + latitude (float): The latitude. + radius (float): The radius. + unit (str, optional): The unit of the radius. Defaults to "km". + + Raises: + ValueError: If the unit is not one of ["m", "km", "mi", "ft"]. + + Examples: + >>> # looking for Chinese restaurants near San Francisco + >>> # (within a 5km radius) would be + >>> # + >>> from redisvl.query import GeoFilter + >>> gf = GeoFilter("location", -122.4194, 37.7749, 5) + """ + super().__init__(field) + self._longitude = longitude + self._latitude = latitude + self._radius = radius + self._unit = self._set_unit(unit) + + def _set_unit(self, unit): + if unit.lower() not in self.GEO_UNITS: + raise ValueError(f"Unit must be one of {self.GEO_UNITS}") + return unit.lower() + + def to_string(self) -> str: + """Converts the geo filter to a string. + + Returns: + str: The geo filter as a string. + """ + return ( + "@" + + self._field + + ":[" + + str(self._longitude) + + " " + + str(self._latitude) + + " " + + str(self._radius) + + " " + + self._unit + + "]" + ) + + class NumericFilter(Filter): def __init__(self, field, minval, maxval, min_exclusive=False, max_exclusive=False): """Filter for Numeric fields. diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index c808e051..a3d5aa02 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -60,11 +60,14 @@ def test_check_no_match(cache, vector): cache._index.delete(True) +<<<<<<< HEAD def test_check_failure(cache): with pytest.raises(ValueError): cache.check(num_results=1) +======= +>>>>>>> 5956bf2 (Add GeoFilter) def test_store_with_vector_and_metadata(cache, vector): # Test storing a response with a vector and metadata prompt = "This is another test prompt." @@ -85,6 +88,7 @@ def test_set_threshold(cache): cache._index.delete(True) +<<<<<<< HEAD def test_from_index(client, vector): # Create customer index index = SearchIndex(name="test", fields=SemanticCache._default_fields) @@ -104,6 +108,9 @@ def test_from_index(client, vector): def test_from_existing_cache(cache, vector, vectorizer): +======= +def test_from_existing(cache, vector, vectorizer): +>>>>>>> 5956bf2 (Add GeoFilter) prompt = "This is another test prompt." response = "This is another test response." metadata = {"source": "test"} diff --git a/tests/test_filter.py b/tests/test_filter.py index e3d48351..030afec9 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,6 +1,13 @@ import pytest -from redisvl.query import Filter, NumericFilter, TagFilter, TextFilter, VectorQuery +from redisvl.query import ( + Filter, + GeoFilter, + NumericFilter, + TagFilter, + TextFilter, + VectorQuery, +) from redisvl.utils.utils import TokenEscaper @@ -25,6 +32,13 @@ def test_text_filter(self): txt_f = TextFilter("text_field", "text") assert txt_f.to_string() == "@text_field:text" + def test_geo_filter(self): + geo_f = GeoFilter("geo_field", 1, 2, 3) + assert geo_f.to_string() == "@geo_field:[1 2 3 km]" + + geo_f = GeoFilter("geo_field", 1, 2, 3, unit="m") + assert geo_f.to_string() == "@geo_field:[1 2 3 m]" + def test_filters_combination(self): tf1 = TagFilter("tag_field", ["tag1", "tag2"]) tf2 = TagFilter("tag_field", ["tag3"]) From 61f83eafd957e19b83e8ea055d3070f8fe723b2e Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Thu, 3 Aug 2023 19:35:10 -0700 Subject: [PATCH 3/3] resolve merge conflict --- tests/integration/test_llmcache.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index a3d5aa02..c808e051 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -60,14 +60,11 @@ def test_check_no_match(cache, vector): cache._index.delete(True) -<<<<<<< HEAD def test_check_failure(cache): with pytest.raises(ValueError): cache.check(num_results=1) -======= ->>>>>>> 5956bf2 (Add GeoFilter) def test_store_with_vector_and_metadata(cache, vector): # Test storing a response with a vector and metadata prompt = "This is another test prompt." @@ -88,7 +85,6 @@ def test_set_threshold(cache): cache._index.delete(True) -<<<<<<< HEAD def test_from_index(client, vector): # Create customer index index = SearchIndex(name="test", fields=SemanticCache._default_fields) @@ -108,9 +104,6 @@ def test_from_index(client, vector): def test_from_existing_cache(cache, vector, vectorizer): -======= -def test_from_existing(cache, vector, vectorizer): ->>>>>>> 5956bf2 (Add GeoFilter) prompt = "This is another test prompt." response = "This is another test response." metadata = {"source": "test"}