Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/api/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,20 @@ NumericFilter
:show-inheritance:
:members:
:inherited-members:


GeoFilter
=========

.. currentmodule:: redisvl.query

.. autosummary::

GeoFilter.__init__
GeoFilter.to_string


.. autoclass:: GeoFilter
:show-inheritance:
:members:
:inherited-members:
Binary file added docs/user_guide/hybrid_example_data.pkl
Binary file not shown.
351 changes: 185 additions & 166 deletions docs/user_guide/hybrid_queries_02.ipynb

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions docs/user_guide/jupyterutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from IPython.display import display, HTML

def table_print(dict_list):
# If there's nothing in the list, there's nothing to print
if len(dict_list) == 0:
return

# Getting column names (dictionary keys) using the first dictionary
columns = dict_list[0].keys()

# HTML table header
html = '<table><tr><th>'
html += '</th><th>'.join(columns)
html += '</th></tr>'

# HTML table content
for dictionary in dict_list:
html += '<tr><td>'
html += '</td><td>'.join(str(dictionary[column]) for column in columns)
html += '</td></tr>'

# HTML table footer
html += '</table>'

# Displaying the table
display(HTML(html))


def result_print(results):
# If there's nothing in the list, there's nothing to print
if len(results.docs) == 0:
return

data = [doc.__dict__ for doc in results.docs]

to_remove = ["id", "payload"]
for doc in data:
for key in to_remove:
if key in doc:
del doc[key]

table_print(data)
Empty file removed redisvl/cli/query.py
Empty file.
55 changes: 55 additions & 0 deletions redisvl/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,61 @@ def to_string(self) -> str:
)


class GeoFilter(Filter):
GEO_UNITS = ["m", "km", "mi", "ft"]

def __init__(self, field, longitude, latitude, radius, unit="km"):
"""Filter for Geo fields.

Args:
field (str): The field to filter on.
longitude (float): The longitude.
latitude (float): The latitude.
radius (float): The radius.
unit (str, optional): The unit of the radius. Defaults to "km".

Raises:
ValueError: If the unit is not one of ["m", "km", "mi", "ft"].

Examples:
>>> # looking for Chinese restaurants near San Francisco
>>> # (within a 5km radius) would be
>>> #
>>> from redisvl.query import GeoFilter
>>> gf = GeoFilter("location", -122.4194, 37.7749, 5)
"""
super().__init__(field)
self._longitude = longitude
self._latitude = latitude
self._radius = radius
self._unit = self._set_unit(unit)

def _set_unit(self, unit):
if unit.lower() not in self.GEO_UNITS:
raise ValueError(f"Unit must be one of {self.GEO_UNITS}")
return unit.lower()

def to_string(self) -> str:
"""Converts the geo filter to a string.

Returns:
str: The geo filter as a string.
"""
return (
"@"
+ self._field
+ ":["
+ str(self._longitude)
+ " "
+ str(self._latitude)
+ " "
+ str(self._radius)
+ " "
+ self._unit
+ "]"
)


class NumericFilter(Filter):
def __init__(self, field, minval, maxval, min_exclusive=False, max_exclusive=False):
"""Filter for Numeric fields.
Expand Down
16 changes: 15 additions & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import pytest

from redisvl.query import Filter, NumericFilter, TagFilter, TextFilter, VectorQuery
from redisvl.query import (
Filter,
GeoFilter,
NumericFilter,
TagFilter,
TextFilter,
VectorQuery,
)
from redisvl.utils.utils import TokenEscaper


Expand All @@ -25,6 +32,13 @@ def test_text_filter(self):
txt_f = TextFilter("text_field", "text")
assert txt_f.to_string() == "@text_field:text"

def test_geo_filter(self):
geo_f = GeoFilter("geo_field", 1, 2, 3)
assert geo_f.to_string() == "@geo_field:[1 2 3 km]"

geo_f = GeoFilter("geo_field", 1, 2, 3, unit="m")
assert geo_f.to_string() == "@geo_field:[1 2 3 m]"

def test_filters_combination(self):
tf1 = TagFilter("tag_field", ["tag1", "tag2"])
tf2 = TagFilter("tag_field", ["tag3"])
Expand Down