diff --git a/docs/api/filter.rst b/docs/api/filter.rst index 92786a34..d88bcd9c 100644 --- a/docs/api/filter.rst +++ b/docs/api/filter.rst @@ -57,3 +57,20 @@ NumericFilter :show-inheritance: :members: :inherited-members: + + +GeoFilter +========= + +.. currentmodule:: redisvl.query + +.. autosummary:: + + GeoFilter.__init__ + GeoFilter.to_string + + +.. autoclass:: GeoFilter + :show-inheritance: + :members: + :inherited-members: \ No newline at end of file diff --git a/docs/user_guide/hybrid_example_data.pkl b/docs/user_guide/hybrid_example_data.pkl new file mode 100644 index 00000000..b5928b91 Binary files /dev/null and b/docs/user_guide/hybrid_example_data.pkl differ diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index 04982d8a..38c4f9d9 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -18,90 +18,32 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "from pprint import pprint\n", - "\n", - "data = [\n", - " {'user': 'john', 'age': 18, 'job': 'engineer', 'credit_score': 'high'},\n", - " {'user': 'derrick', 'age': 14, 'job': 'doctor', 'credit_score': 'low'},\n", - " {'user': 'nancy', 'age': 94, 'job': 'doctor', 'credit_score': 'high'},\n", - " {'user': 'tyler', 'age': 100, 'job': 'engineer', 'credit_score': 'high'},\n", - " {'user': 'tim', 'age': 12, 'job': 'dermatologist', 'credit_score': 'high'},\n", - " {'user': 'taimur', 'age': 15, 'job': 'CEO', 'credit_score': 'low'},\n", - " {'user': 'joe', 'age': 35, 'job': 'dentist', 'credit_score': 'medium'}\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "[{'age': 18,\n", - " 'credit_score': 'high',\n", - " 'job': 'engineer',\n", - " 'user': 'john',\n", - " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", - " {'age': 14,\n", - " 'credit_score': 'low',\n", - " 'job': 'doctor',\n", - " 'user': 'derrick',\n", - " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", - " {'age': 94,\n", - " 'credit_score': 'high',\n", - " 'job': 'doctor',\n", - " 'user': 'nancy',\n", - " 'user_embedding': b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", - " {'age': 100,\n", - " 'credit_score': 'high',\n", - " 'job': 'engineer',\n", - " 'user': 'tyler',\n", - " 'user_embedding': b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'},\n", - " {'age': 12,\n", - " 'credit_score': 'high',\n", - " 'job': 'dermatologist',\n", - " 'user': 'tim',\n", - " 'user_embedding': b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'},\n", - " {'age': 15,\n", - " 'credit_score': 'low',\n", - " 'job': 'CEO',\n", - " 'user': 'taimur',\n", - " 'user_embedding': b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'},\n", - " {'age': 35,\n", - " 'credit_score': 'medium',\n", - " 'job': 'dentist',\n", - " 'user': 'joe',\n", - " 'user_embedding': b'fff?fff?\\xcd\\xcc\\xcc='}]\n" - ] + "data": { + "text/html": [ + "
useragejobcredit_scoreoffice_locationuser_embedding
john18engineerhigh-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
derrick14doctorlow-122.4194,37.7749b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
nancy94doctorhigh-122.4194,37.7749b'333?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
tyler100engineerhigh-122.0839,37.3861b'\\xcd\\xcc\\xcc=\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
tim12dermatologisthigh-122.0839,37.3861b'\\xcd\\xcc\\xcc>\\xcd\\xcc\\xcc>\\x00\\x00\\x00?'
taimur15CEOlow-122.0839,37.3861b'\\x9a\\x99\\x19?\\xcd\\xcc\\xcc=\\x00\\x00\\x00?'
joe35dentistmedium-122.0839,37.3861b'fff?fff?\\xcd\\xcc\\xcc='
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "# converted to bytes for redis\n", - "vectors = [\n", - " np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.1, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.7, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.1, 0.4, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.4, 0.4, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.6, 0.1, 0.5], dtype=np.float32).tobytes(),\n", - " np.array([0.9, 0.9, 0.1], dtype=np.float32).tobytes(),\n", - "]\n", - "\n", - "for record, vector in zip(data, vectors):\n", - " record[\"user_embedding\"] = vector\n", - "\n", - "pprint(data)" + "import pickle\n", + "from jupyterutils import table_print, result_print\n", + "\n", + "# load in the example data and printing utils\n", + "data = pickle.load(open(\"hybrid_example_data.pkl\", \"rb\"))\n", + "table_print(data)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -114,6 +56,7 @@ " \"tag\": [{\"name\": \"credit_score\"}],\n", " \"text\": [{\"name\": \"job\"}],\n", " \"numeric\": [{\"name\": \"age\"}],\n", + " \"geo\": [{\"name\": \"office_location\"}],\n", " \"vector\": [{\n", " \"name\": \"user_embedding\",\n", " \"dims\": 3,\n", @@ -127,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -145,15 +88,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m17:46:27\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[8815]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m17:46:27\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[8815]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" + "\u001b[32m19:22:36\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[21909]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m19:22:36\u001b[0m \u001b[35msam.partee-NW9MQX5Y74\u001b[0m \u001b[34mredisvl.cli.index[21909]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" ] } ], @@ -164,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -193,18 +136,20 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158809006214timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -214,13 +159,12 @@ "\n", "v = VectorQuery([0.1, 0.1, 0.5],\n", " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\"],\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", " hybrid_filter=t)\n", "\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -234,18 +178,20 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -254,8 +200,7 @@ "v.set_filter(n)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -269,16 +214,20 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:derrick', 'payload': None, 'vector_distance': '0', 'user': 'derrick', 'credit_score': 'low', 'age': '14', 'job': 'doctor'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -288,8 +237,72 @@ "v.set_filter(text_filter)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Geographic Filters\n", + "\n", + "Geographic filters are filters that are applied to geographic fields. These filters are used to find results that are within a certain distance of a given point. The distance is specified in kilometers, miles, meters, or feet. A radius can also be specified to find results within a certain radius of a given point." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query import GeoFilter\n", + "\n", + "# within 10 km of San Francisco office\n", + "geo_filter = GeoFilter(\"office_location\", -122.4194, 37.7749 , 10, \"km\")\n", + "v.set_filter(geo_filter)\n", + "\n", + "results = index.query(v)\n", + "result_print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0derricklow14doctor-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158809006214timhigh12dermatologist-122.0839,37.3861
0.217882037163taimurlow15CEO-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# change to 100 km of San Francisco office\n", + "geo_filter = GeoFilter(\"office_location\", -122.4194, 37.7749 , 100, \"km\")\n", + "v.set_filter(geo_filter)\n", + "\n", + "results = index.query(v)\n", + "result_print(results)" ] }, { @@ -303,17 +316,20 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -323,13 +339,12 @@ "\n", "v = VectorQuery([0.1, 0.1, 0.5],\n", " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\"],\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", " hybrid_filter=t)\n", "\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -343,15 +358,20 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0.158809006214timhigh12dermatologist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -362,8 +382,7 @@ "v.set_filter(t)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -377,19 +396,20 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'user': 'john', 'credit_score': 'high', 'age': '18', 'job': 'engineer'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'user': 'tyler', 'credit_score': 'high', 'age': '100', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'user': 'tim', 'credit_score': 'high', 'age': '12', 'job': 'dermatologist'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'user': 'nancy', 'credit_score': 'high', 'age': '94', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'user': 'joe', 'credit_score': 'medium', 'age': '35', 'job': 'dentist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceusercredit_scoreagejoboffice_location
0johnhigh18engineer-122.4194,37.7749
0.109129190445tylerhigh100engineer-122.0839,37.3861
0.158809006214timhigh12dermatologist-122.0839,37.3861
0.266666650772nancyhigh94doctor-122.4194,37.7749
0.653301358223joemedium35dentist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -400,8 +420,7 @@ "v.set_filter(t)\n", "\n", "results = index.query(v)\n", - "for doc in results.docs:\n", - " print(doc)" + "result_print(results)" ] }, { @@ -419,19 +438,20 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Document {'id': 'v1:tyler', 'payload': None, 'vector_distance': '0.109129190445', 'age': '100', 'user': 'tyler', 'credit_score': 'high', 'job': 'engineer'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'vector_distance': '0.266666650772', 'age': '94', 'user': 'nancy', 'credit_score': 'high', 'job': 'doctor'}\n", - "Document {'id': 'v1:joe', 'payload': None, 'vector_distance': '0.653301358223', 'age': '35', 'user': 'joe', 'credit_score': 'medium', 'job': 'dentist'}\n", - "Document {'id': 'v1:john', 'payload': None, 'vector_distance': '0', 'age': '18', 'user': 'john', 'credit_score': 'high', 'job': 'engineer'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'vector_distance': '0.158809006214', 'age': '12', 'user': 'tim', 'credit_score': 'high', 'job': 'dermatologist'}\n" - ] + "data": { + "text/html": [ + "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
0.26666665077294nancyhighdoctor-122.4194,37.7749
0.65330135822335joemediumdentist-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
0.15880900621412timhighdermatologist-122.0839,37.3861
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -443,8 +463,7 @@ "\n", "# run the query with the ``SearchIndex.search`` method\n", "result = index.search(redis_py_query, v.params)\n", - "for doc in result.docs:\n", - " print(doc)" + "result_print(result)" ] }, { @@ -458,7 +477,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -467,7 +486,7 @@ "'(@credit_score:{high})'" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -479,24 +498,24 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Document {'id': 'v1:john', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "Document {'id': 'v1:nancy', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "Document {'id': 'v1:tyler', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "Document {'id': 'v1:tim', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'v1:38bfee0253ca452e96b4b3fdcb2798f7', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'v1:747bce550564443199ae1118cf03b5e3', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'v1:da7b6b0bf94f4c40a2ea23e20035ca73', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'v1:abcdf6be4fb042389a93a9b27d6cce5c', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], "source": [ "results = index.search(str(t))\n", - "for doc in results.docs:\n", - " print(doc)" + "for r in results.docs:\n", + " print(r.__dict__)" ] }, { @@ -512,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -521,7 +540,7 @@ "'(@credit_score:{high} @age:[18 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 5 user credit_score age job vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10'" ] }, - "execution_count": 16, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/user_guide/jupyterutils.py b/docs/user_guide/jupyterutils.py new file mode 100644 index 00000000..8e4954b7 --- /dev/null +++ b/docs/user_guide/jupyterutils.py @@ -0,0 +1,42 @@ +from IPython.display import display, HTML + +def table_print(dict_list): + # If there's nothing in the list, there's nothing to print + if len(dict_list) == 0: + return + + # Getting column names (dictionary keys) using the first dictionary + columns = dict_list[0].keys() + + # HTML table header + html = '' + + # HTML table content + for dictionary in dict_list: + html += '' + + # HTML table footer + html += '
' + html += ''.join(columns) + html += '
' + html += ''.join(str(dictionary[column]) for column in columns) + html += '
' + + # Displaying the table + display(HTML(html)) + + +def result_print(results): + # If there's nothing in the list, there's nothing to print + if len(results.docs) == 0: + return + + data = [doc.__dict__ for doc in results.docs] + + to_remove = ["id", "payload"] + for doc in data: + for key in to_remove: + if key in doc: + del doc[key] + + table_print(data) \ No newline at end of file diff --git a/docs/user_guide/providers_03.ipynb b/docs/user_guide/vectorizers_03.ipynb similarity index 100% rename from docs/user_guide/providers_03.ipynb rename to docs/user_guide/vectorizers_03.ipynb diff --git a/redisvl/cli/query.py b/redisvl/cli/query.py deleted file mode 100644 index e69de29b..00000000 diff --git a/redisvl/query.py b/redisvl/query.py index d1eaf149..51903183 100644 --- a/redisvl/query.py +++ b/redisvl/query.py @@ -66,6 +66,61 @@ def to_string(self) -> str: ) +class GeoFilter(Filter): + GEO_UNITS = ["m", "km", "mi", "ft"] + + def __init__(self, field, longitude, latitude, radius, unit="km"): + """Filter for Geo fields. + + Args: + field (str): The field to filter on. + longitude (float): The longitude. + latitude (float): The latitude. + radius (float): The radius. + unit (str, optional): The unit of the radius. Defaults to "km". + + Raises: + ValueError: If the unit is not one of ["m", "km", "mi", "ft"]. + + Examples: + >>> # looking for Chinese restaurants near San Francisco + >>> # (within a 5km radius) would be + >>> # + >>> from redisvl.query import GeoFilter + >>> gf = GeoFilter("location", -122.4194, 37.7749, 5) + """ + super().__init__(field) + self._longitude = longitude + self._latitude = latitude + self._radius = radius + self._unit = self._set_unit(unit) + + def _set_unit(self, unit): + if unit.lower() not in self.GEO_UNITS: + raise ValueError(f"Unit must be one of {self.GEO_UNITS}") + return unit.lower() + + def to_string(self) -> str: + """Converts the geo filter to a string. + + Returns: + str: The geo filter as a string. + """ + return ( + "@" + + self._field + + ":[" + + str(self._longitude) + + " " + + str(self._latitude) + + " " + + str(self._radius) + + " " + + self._unit + + "]" + ) + + class NumericFilter(Filter): def __init__(self, field, minval, maxval, min_exclusive=False, max_exclusive=False): """Filter for Numeric fields. diff --git a/tests/test_filter.py b/tests/test_filter.py index e3d48351..030afec9 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,6 +1,13 @@ import pytest -from redisvl.query import Filter, NumericFilter, TagFilter, TextFilter, VectorQuery +from redisvl.query import ( + Filter, + GeoFilter, + NumericFilter, + TagFilter, + TextFilter, + VectorQuery, +) from redisvl.utils.utils import TokenEscaper @@ -25,6 +32,13 @@ def test_text_filter(self): txt_f = TextFilter("text_field", "text") assert txt_f.to_string() == "@text_field:text" + def test_geo_filter(self): + geo_f = GeoFilter("geo_field", 1, 2, 3) + assert geo_f.to_string() == "@geo_field:[1 2 3 km]" + + geo_f = GeoFilter("geo_field", 1, 2, 3, unit="m") + assert geo_f.to_string() == "@geo_field:[1 2 3 m]" + def test_filters_combination(self): tf1 = TagFilter("tag_field", ["tag1", "tag2"]) tf2 = TagFilter("tag_field", ["tag3"])