Skip to content

Commit

Permalink
Add extra tests for GEO search (#3244)
Browse files Browse the repository at this point in the history
Add more tests for GEO search, to cover the query operators `within`,
`contains`, `intersects` and `disjoint`, for POINT and POLYGON, i.e. the
currently supported shapes and operators.
  • Loading branch information
gerzse committed Jun 12, 2024
1 parent 9d85723 commit 29b861b
Showing 1 changed file with 69 additions and 13 deletions.
82 changes: 69 additions & 13 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2105,16 +2105,60 @@ def test_geo_params(client):
params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"}
q = Query("@g:[$lon $lat $radius $units]").dialect(2)
res = client.ft().search(q, query_params=params_dict)
if is_resp2_connection(client):
assert 3 == res.total
assert "doc1" == res.docs[0].id
assert "doc2" == res.docs[1].id
assert "doc3" == res.docs[2].id
else:
assert 3 == res["total_results"]
assert "doc1" == res["results"][0]["id"]
assert "doc2" == res["results"][1]["id"]
assert "doc3" == res["results"][2]["id"]
_assert_geosearch_result(client, res, ["doc1", "doc2", "doc3"])


@pytest.mark.redismod
def test_geoshapes_query_intersects_and_disjoint(client):
client.ft().create_index((GeoShapeField("g", coord_system=GeoShapeField.FLAT)))
client.hset("doc_point1", mapping={"g": "POINT (10 10)"})
client.hset("doc_point2", mapping={"g": "POINT (50 50)"})
client.hset("doc_polygon1", mapping={"g": "POLYGON ((20 20, 25 35, 35 25, 20 20))"})
client.hset(
"doc_polygon2", mapping={"g": "POLYGON ((60 60, 65 75, 70 70, 65 55, 60 60))"}
)

intersection = client.ft().search(
Query("@g:[intersects $shape]").dialect(3),
query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"},
)
_assert_geosearch_result(client, intersection, ["doc_point2", "doc_polygon1"])

disjunction = client.ft().search(
Query("@g:[disjoint $shape]").dialect(3),
query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"},
)
_assert_geosearch_result(client, disjunction, ["doc_point1", "doc_polygon2"])


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.0", "search")
def test_geoshapes_query_contains_and_within(client):
client.ft().create_index((GeoShapeField("g", coord_system=GeoShapeField.FLAT)))
client.hset("doc_point1", mapping={"g": "POINT (10 10)"})
client.hset("doc_point2", mapping={"g": "POINT (50 50)"})
client.hset("doc_polygon1", mapping={"g": "POLYGON ((20 20, 25 35, 35 25, 20 20))"})
client.hset(
"doc_polygon2", mapping={"g": "POLYGON ((60 60, 65 75, 70 70, 65 55, 60 60))"}
)

contains_a = client.ft().search(
Query("@g:[contains $shape]").dialect(3),
query_params={"shape": "POINT(25 25)"},
)
_assert_geosearch_result(client, contains_a, ["doc_polygon1"])

contains_b = client.ft().search(
Query("@g:[contains $shape]").dialect(3),
query_params={"shape": "POLYGON((24 24, 24 26, 25 25, 24 24))"},
)
_assert_geosearch_result(client, contains_b, ["doc_polygon1"])

within = client.ft().search(
Query("@g:[within $shape]").dialect(3),
query_params={"shape": "POLYGON((15 15, 75 15, 50 70, 20 40, 15 15))"},
)
_assert_geosearch_result(client, within, ["doc_point2", "doc_polygon1"])


@pytest.mark.redismod
Expand Down Expand Up @@ -2278,7 +2322,19 @@ def test_geoshape(client: redis.Redis):
q2 = Query("@geom:[CONTAINS $poly]").dialect(3)
qp2 = {"poly": "POLYGON((2 2, 2 50, 50 50, 50 2, 2 2))"}
result = client.ft().search(q1, query_params=qp1)
assert len(result.docs) == 1
assert result.docs[0]["id"] == "small"
_assert_geosearch_result(client, result, ["small"])
result = client.ft().search(q2, query_params=qp2)
assert len(result.docs) == 2
_assert_geosearch_result(client, result, ["small", "large"])


def _assert_geosearch_result(client, result, expected_doc_ids):
"""
Make sure the result of a geo search is as expected, taking into account the RESP
version being used.
"""
if is_resp2_connection(client):
assert set([doc.id for doc in result.docs]) == set(expected_doc_ids)
assert result.total == len(expected_doc_ids)
else:
assert set([doc["id"] for doc in result["results"]]) == set(expected_doc_ids)
assert result["total_results"] == len(expected_doc_ids)

0 comments on commit 29b861b

Please sign in to comment.