Skip to content

Commit

Permalink
feat: Logical and Selector filters (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
adilansari committed May 23, 2023
1 parent bd93355 commit 03876ff
Show file tree
Hide file tree
Showing 12 changed files with 294 additions and 49 deletions.
6 changes: 2 additions & 4 deletions tests/test_search_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tests import StubRpcError
from tigrisdb.errors import TigrisServerError
from tigrisdb.search_index import SearchIndex
from tigrisdb.types import ClientConfig, Document
from tigrisdb.types import ClientConfig, Document, RFC3339_format
from tigrisdb.types.search import Query as SearchQuery
from tigrisdb.utils import marshal, unmarshal

Expand Down Expand Up @@ -204,9 +204,7 @@ def test_get_many(self, grpc_search):
{"id": "2", "title": "reliable systems 🙏", "tags": ["it"]},
]
ts, proto_ts = (
datetime.datetime.strptime(
"2023-05-05T10:00:00+00:00", "%Y-%m-%dT%H:%M:%S%z"
),
datetime.datetime.strptime("2023-05-05T10:00:00+00:00", RFC3339_format),
ProtoTimestamp(),
)
proto_ts.FromDatetime(ts)
Expand Down
81 changes: 81 additions & 0 deletions tests/test_types_filters_logical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from unittest import TestCase

from tigrisdb.types.filters import GT, GTE, LTE, And, Contains, Eq, Not, Or, Regex


class LogicalFiltersTest(TestCase):
def test_or(self):
f = Or(Eq("f1", True), LTE("f2", 25), Contains("f3", "v1"))
self.assertEqual(
{"$or": [{"f1": True}, {"f2": {"$lte": 25}}, {"f3": {"$contains": "v1"}}]},
f.query(),
)

def test_and(self):
f = And(Eq("f1", True), LTE("f2", 25), Contains("f3", "v1"))
self.assertEqual(
{"$and": [{"f1": True}, {"f2": {"$lte": 25}}, {"f3": {"$contains": "v1"}}]},
f.query(),
)

def test_empty(self):
self.assertEqual({}, And().query())
self.assertEqual({}, Or().query())

def test_single(self):
self.assertEqual({"f1": {"$gt": 10.5}}, Or(GT("f1", 10.5)).query())

def test_complex_or_filter(self):
f = Or(
Or(Eq("name", "alice"), GTE("rank", 2)),
Or(Eq("name", "emma"), LTE("rank", 6), Not("city", "sfo")),
And(GTE("f1", 1.5), LTE("f1", 3.14), Contains("f2", "hello")),
Not("f3", False),
Regex("f4", "/andy/i"),
)
self.assertEqual(
{
"$or": [
{"$or": [{"name": "alice"}, {"rank": {"$gte": 2}}]},
{
"$or": [
{"name": "emma"},
{"rank": {"$lte": 6}},
{"city": {"$not": "sfo"}},
]
},
{
"$and": [
{"f1": {"$gte": 1.5}},
{"f1": {"$lte": 3.14}},
{"f2": {"$contains": "hello"}},
]
},
{"f3": {"$not": False}},
{"f4": {"$regex": "/andy/i"}},
]
},
f.query(),
)

def test_complex_and_filter(self):
f = And(
Or(Eq("name", "alice"), GTE("rank", 2)),
Or(Eq("name", "emma"), LTE("rank", 6), Not("city", "sfo")),
)

self.assertEqual(
{
"$and": [
{"$or": [{"name": "alice"}, {"rank": {"$gte": 2}}]},
{
"$or": [
{"name": "emma"},
{"rank": {"$lte": 6}},
{"city": {"$not": "sfo"}},
]
},
]
},
f.query(),
)
40 changes: 40 additions & 0 deletions tests/test_types_filters_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import datetime
from unittest import TestCase

from tigrisdb.types import RFC3339_format
from tigrisdb.types.filters import GT, GTE, LT, LTE, Contains, Eq, Not, Regex


class SelectorTestCase(TestCase):
def test_equals(self):
f = Eq("f1", 25)
self.assertEqual({"f1": 25}, f.query())

def test_gte(self):
f = GTE("f1", 25)
self.assertEqual({"f1": {"$gte": 25}}, f.query())

def test_gt(self):
dt = datetime.datetime.strptime("2023-05-05T10:00:00+00:00", RFC3339_format)
f = GT("f1", dt)
self.assertEqual({"f1": {"$gt": dt}}, f.query())

def test_lte(self):
f = LTE("f1", 25)
self.assertEqual({"f1": {"$lte": 25}}, f.query())

def test_lt(self):
f = LT("f1", 25)
self.assertEqual({"f1": {"$lt": 25}}, f.query())

def test_regex(self):
f = Regex("f1", "emma*")
self.assertEqual({"f1": {"$regex": "emma*"}}, f.query())

def test_contains(self):
f = Contains("f1", "cars")
self.assertEqual({"f1": {"$contains": "cars"}}, f.query())

def test_not(self):
f = Not("f1", "Alex")
self.assertEqual({"f1": {"$not": "Alex"}}, f.query())
64 changes: 36 additions & 28 deletions tests/test_types_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,38 +56,42 @@ def test_default_fields(self):
self.assertEqual(proto_req.page_size, 20)

def test_with_q(self):
query, proto_req = Query(), SearchIndexRequest()
query.q = "hello world"
query, proto_req = Query(q="hello world"), SearchIndexRequest()
query.__build__(proto_req)

self.assertEqual("hello world", proto_req.q)

def test_with_search_fields(self):
query, proto_req = Query(), SearchIndexRequest()
query.search_fields = ["name.first", "balance"]
query, proto_req = (
Query(search_fields=["name.first", "balance"]),
SearchIndexRequest(),
)
query.__build__(proto_req)

self.assertEqual(["name.first", "balance"], proto_req.search_fields)

def test_with_vector_query(self):
query, proto_req = Query(), SearchIndexRequest()
query.vector_query = VectorField("embedding", [1.1, 2.456, 34.88])
query, proto_req = (
Query(vector_query=VectorField("embedding", [1.1, 2.456, 34.88])),
SearchIndexRequest(),
)
query.__build__(proto_req)

self.assertEqual(
'{"embedding": [1.1, 2.456, 34.88]}'.encode(), proto_req.vector
)

def test_with_facet_by(self):
query, proto_req = Query(), SearchIndexRequest()
query.facet_by = "field_1"
query, proto_req = Query(facet_by="field_1"), SearchIndexRequest()
query.__build__(proto_req)
self.assertEqual(
'{"field_1": {"size": 10, "type": "value"}}'.encode(), proto_req.facet
)

query, proto_req = Query(), SearchIndexRequest()
query.facet_by = ["f1", FacetSize("f2", 25), FacetSize("f3")]
query, proto_req = (
Query(facet_by=["f1", FacetSize("f2", 25), FacetSize("f3")]),
SearchIndexRequest(),
)
query.__build__(proto_req)
self.assertEqual(
'{"f1": {"size": 10, "type": "value"}, "f2": {"size": 25, "type": "value"},'
Expand All @@ -96,48 +100,52 @@ def test_with_facet_by(self):
)

def test_with_sort_by(self):
query, proto_req = Query(), SearchIndexRequest()
query.sort_by = sort.Ascending("f1")
query, proto_req = Query(sort_by=sort.Ascending("f1")), SearchIndexRequest()
query.__build__(proto_req)
self.assertEqual('[{"f1": "$asc"}]'.encode(), proto_req.sort)

query, proto_req = Query(), SearchIndexRequest()
query.sort_by = [
sort.Descending("f2"),
sort.Ascending("f1"),
sort.Ascending("f3"),
]
query, proto_req = (
Query(
sort_by=[
sort.Descending("f2"),
sort.Ascending("f1"),
sort.Ascending("f3"),
]
),
SearchIndexRequest(),
)
query.__build__(proto_req)
self.assertEqual(
'[{"f2": "$desc"}, {"f1": "$asc"}, {"f3": "$asc"}]'.encode(), proto_req.sort
)

def test_with_group_by(self):
query, proto_req = Query(), SearchIndexRequest()
query.group_by = "f1"
query, proto_req = Query(group_by="f1"), SearchIndexRequest()
query.__build__(proto_req)
self.assertEqual('["f1"]'.encode(), proto_req.group_by)

query, proto_req = Query(), SearchIndexRequest()
query.group_by = ["f1", "f2", "f3"]
query, proto_req = Query(group_by=["f1", "f2", "f3"]), SearchIndexRequest()
query.__build__(proto_req)
self.assertEqual('["f1", "f2", "f3"]'.encode(), proto_req.group_by)

def test_with_include_fields(self):
query, proto_req = Query(), SearchIndexRequest()
query.include_fields = ["f1", "f2", "f3"]
query, proto_req = (
Query(include_fields=["f1", "f2", "f3"]),
SearchIndexRequest(),
)
query.__build__(proto_req)
self.assertEqual(["f1", "f2", "f3"], proto_req.include_fields)

def test_with_exclude_fields(self):
query, proto_req = Query(), SearchIndexRequest()
query.exclude_fields = ["f1", "f2", "f3"]
query, proto_req = (
Query(exclude_fields=["f1", "f2", "f3"]),
SearchIndexRequest(),
)
query.__build__(proto_req)
self.assertEqual(["f1", "f2", "f3"], proto_req.exclude_fields)

def test_with_page_size(self):
query, proto_req = Query(), SearchIndexRequest()
query.hits_per_page = 25
query, proto_req = Query(hits_per_page=25), SearchIndexRequest()
query.__build__(proto_req)
self.assertEqual(25, proto_req.page_size)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_types_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from unittest import TestCase

from tigrisdb.types.sort import Ascending, Descending, Sort


class SortTests(TestCase):
def test_ascending(self):
sort = Ascending("f1")
self.assertEqual({"f1": "$asc"}, sort.query())

def test_descending(self):
sort = Descending("f1")
self.assertEqual({"f1": "$desc"}, sort.query())

def test_base_sort_error(self):
with self.assertRaisesRegex(TypeError, "abstract class"):
Sort()
9 changes: 8 additions & 1 deletion tigrisdb/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,12 @@ class ClientConfig:

class Serializable(abc.ABC):
@abc.abstractmethod
def as_obj(self) -> Dict:
def query(self) -> Dict:
raise NotImplementedError()


class BaseOperator(abc.ABC):
@property
@abc.abstractmethod
def operator(self):
raise NotImplementedError()
2 changes: 2 additions & 0 deletions tigrisdb/types/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .logical import And, Or # noqa: F401
from .selector import GT, GTE, LT, LTE, Contains, Eq, Not, Regex # noqa: F401
31 changes: 31 additions & 0 deletions tigrisdb/types/filters/logical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import abc
from typing import Any, Dict, Union

from tigrisdb.types import BaseOperator, Serializable

from .selector import SelectorFilter


class LogicalFilter(Serializable, BaseOperator, abc.ABC):
def __init__(self, *args: Union[SelectorFilter, Any]):
self.filters = args

def query(self) -> Dict:
if not self.filters:
return {}
if len(self.filters) == 1:
return self.filters[0].query()
gen = [f.query() for f in self.filters]
return {self.operator: gen}


class And(LogicalFilter):
@property
def operator(self):
return "$and"


class Or(LogicalFilter):
@property
def operator(self):
return "$or"
64 changes: 64 additions & 0 deletions tigrisdb/types/filters/selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import abc
from typing import Any, Dict

from tigrisdb.types import BaseOperator, Serializable


class SelectorFilter(Serializable, BaseOperator, abc.ABC):
def __init__(self, field: str, value: Any):
self.field = field
self.value = value

def query(self) -> Dict:
return {self.field: {self.operator: self.value}}


class Eq(SelectorFilter):
@property
def operator(self):
return ""

def query(self) -> Dict:
return {self.field: self.value}


class Not(SelectorFilter):
@property
def operator(self):
return "$not"


class GT(SelectorFilter):
@property
def operator(self):
return "$gt"


class GTE(SelectorFilter):
@property
def operator(self):
return "$gte"


class LT(SelectorFilter):
@property
def operator(self):
return "$lt"


class LTE(SelectorFilter):
@property
def operator(self):
return "$lte"


class Regex(SelectorFilter):
@property
def operator(self):
return "$regex"


class Contains(SelectorFilter):
@property
def operator(self):
return "$contains"
Loading

0 comments on commit 03876ff

Please sign in to comment.