diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index 154dc66dfb..9cb4682fe2 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -224,6 +224,39 @@ def zset_score_pairs(response, **options): return list(zip(it, map(score_cast_func, it))) +def zset_score_for_rank(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscore"): + return response + score_cast_func = options.get("score_cast_func", float) + return [response[0], score_cast_func(response[1])] + + +def zset_score_pairs_resp3(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + return [[name, score_cast_func(val)] for name, val in response] + + +def zset_score_for_rank_resp3(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscore"): + return response + score_cast_func = options.get("score_cast_func", float) + return [response[0], score_cast_func(response[1])] + + def sort_return_tuples(response, **options): """ If ``groups`` is specified, return the response as a list of @@ -797,10 +830,14 @@ def string_keys_to_dict(key_string, callback): "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() ), **string_keys_to_dict( - "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE " - "ZREVRANGEBYSCORE ZREVRANK ZUNION", + "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE " + "ZREVRANGEBYSCORE ZUNION", zset_score_pairs, ), + **string_keys_to_dict( + "ZREVRANK ZRANK", + zset_score_for_rank, + ), **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none), **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), @@ -844,10 +881,17 @@ def string_keys_to_dict(key_string, callback): "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() ), **string_keys_to_dict( - "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " - "ZUNION HGETALL XREADGROUP", + "ZRANGE ZINTER ZPOPMAX ZPOPMIN HGETALL XREADGROUP", lambda r, **kwargs: r, ), + **string_keys_to_dict( + "ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE ZUNION", + zset_score_pairs_resp3, + ), + **string_keys_to_dict( + "ZREVRANK ZRANK", + zset_score_for_rank_resp3, + ), **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), "ACL LOG": lambda r: ( [ diff --git a/redis/commands/core.py b/redis/commands/core.py index ad21885b3d..90352c233e 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -4778,6 +4778,7 @@ def zrank( name: KeyT, value: EncodableT, withscore: bool = False, + score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ Returns a 0-based value indicating the rank of ``value`` in sorted set @@ -4785,11 +4786,17 @@ def zrank( The optional WITHSCORE argument supplements the command's reply with the score of the element returned. + ``score_cast_func`` a callable used to cast the score return value + For more information, see https://redis.io/commands/zrank """ + pieces = ["ZRANK", name, value] if withscore: - return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) - return self.execute_command("ZRANK", name, value, keys=[name]) + pieces.append("WITHSCORE") + + options = {"withscore": withscore, "score_cast_func": score_cast_func} + + return self.execute_command(*pieces, **options) def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -4837,6 +4844,7 @@ def zrevrank( name: KeyT, value: EncodableT, withscore: bool = False, + score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ Returns a 0-based value indicating the descending rank of @@ -4844,13 +4852,17 @@ def zrevrank( The optional ``withscore`` argument supplements the command's reply with the score of the element returned. + ``score_cast_func`` a callable used to cast the score return value + For more information, see https://redis.io/commands/zrevrank """ + pieces = ["ZREVRANK", name, value] if withscore: - return self.execute_command( - "ZREVRANK", name, value, "WITHSCORE", keys=[name] - ) - return self.execute_command("ZREVRANK", name, value, keys=[name]) + pieces.append("WITHSCORE") + + options = {"withscore": withscore, "score_cast_func": score_cast_func} + + return self.execute_command(*pieces, **options) def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -4865,6 +4877,7 @@ def zunion( keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], aggregate: Optional[str] = None, withscores: bool = False, + score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ Return the union of multiple sorted sets specified by ``keys``. @@ -4872,9 +4885,18 @@ def zunion( Scores will be aggregated based on the ``aggregate``, or SUM if none is provided. + ``score_cast_func`` a callable used to cast the score return value + For more information, see https://redis.io/commands/zunion """ - return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores) + return self._zaggregate( + "ZUNION", + None, + keys, + aggregate, + withscores=withscores, + score_cast_func=score_cast_func, + ) def zunionstore( self, diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index dda4dc2a1c..14690ef6e6 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -23,6 +23,7 @@ from redis.commands.json.path import Path from redis.commands.search.field import TextField from redis.commands.search.query import Query +from redis.utils import safe_str from tests.conftest import ( assert_resp_response, assert_resp_response_in, @@ -2071,11 +2072,14 @@ async def test_zrange(self, r: redis.Redis): r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]] ) - # custom score function - # assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - # (b"a1", 1), - # (b"a2", 2), - # ] + # custom score cast function + response = await r.zrange("a", 0, 1, withscores=True, score_cast_func=safe_str) + assert_resp_response( + r, + response, + [(b"a1", "1"), (b"a2", "2")], + [[b"a1", "1.0"], [b"a2", "2.0"]], + ) @skip_if_server_version_lt("2.8.9") async def test_zrangebylex(self, r: redis.Redis): @@ -2127,6 +2131,15 @@ async def test_zrangebyscore(self, r: redis.Redis): [(b"a2", 2), (b"a3", 3), (b"a4", 4)], [[b"a2", 2], [b"a3", 3], [b"a4", 4]], ) + response = await r.zrangebyscore( + "a", 2, 4, withscores=True, score_cast_func=safe_str + ) + assert_resp_response( + r, + response, + [(b"a2", "2"), (b"a3", "3"), (b"a4", "4")], + [[b"a2", "2.0"], [b"a3", "3.0"], [b"a4", "4.0"]], + ) async def test_zrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2141,10 +2154,14 @@ async def test_zrank_withscore(self, r: redis.Redis): assert await r.zrank("a", "a2") == 1 assert await r.zrank("a", "a6") is None assert_resp_response( - r, await r.zrank("a", "a3", withscore=True), [2, b"3"], [2, 3.0] + r, await r.zrank("a", "a3", withscore=True), [2, 3.0], [2, 3.0] ) assert await r.zrank("a", "a6", withscore=True) is None + # custom score cast function + response = await r.zrank("a", "a3", withscore=True, score_cast_func=safe_str) + assert_resp_response(r, response, [2, "3"], [2, "3.0"]) + async def test_zrem(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zrem("a", "a2") == 1 @@ -2200,6 +2217,19 @@ async def test_zrevrange(self, r: redis.Redis): r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]] ) + # custom score cast function + # should be applied to resp2 and resp3 + # responses + response = await r.zrevrange( + "a", 0, 1, withscores=True, score_cast_func=safe_str + ) + assert_resp_response( + r, + response, + [(b"a3", "3"), (b"a2", "2")], + [[b"a3", "3.0"], [b"a2", "2.0"]], + ) + async def test_zrevrangebyscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) assert await r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] @@ -2240,7 +2270,7 @@ async def test_zrevrank_withscore(self, r: redis.Redis): assert await r.zrevrank("a", "a2") == 3 assert await r.zrevrank("a", "a6") is None assert_resp_response( - r, await r.zrevrank("a", "a3", withscore=True), [2, b"3"], [2, 3.0] + r, await r.zrevrank("a", "a3", withscore=True), [2, 3.0], [2, 3.0] ) assert await r.zrevrank("a", "a6", withscore=True) is None diff --git a/tests/test_commands.py b/tests/test_commands.py index 2871dc45bf..c5974ea0fc 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -21,6 +21,7 @@ from redis.commands.json.path import Path from redis.commands.search.field import TextField from redis.commands.search.query import Query +from redis.utils import safe_str from tests.test_utils import redis_server_time from .conftest import ( @@ -3039,11 +3040,13 @@ def test_zrange(self, r): [[b"a2", 2.0], [b"a3", 3.0]], ) - # # custom score function - # assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - # (b"a1", 1), - # (b"a2", 2), - # ] + # custom score cast function + assert_resp_response( + r, + r.zrange("a", 0, 1, withscores=True, score_cast_func=safe_str), + [(b"a1", "1"), (b"a2", "2")], + [[b"a1", "1.0"], [b"a2", "2.0"]], + ) def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -3153,6 +3156,13 @@ def test_zrangebyscore(self, r): [(b"a2", 2), (b"a3", 3), (b"a4", 4)], [[b"a2", 2], [b"a3", 3], [b"a4", 4]], ) + # custom score cast function + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=safe_str), + [(b"a2", "2"), (b"a3", "3"), (b"a4", "4")], + [[b"a2", "2.0"], [b"a3", "3.0"], [b"a4", "4.0"]], + ) def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -3166,9 +3176,17 @@ def test_zrank_withscore(self, r: redis.Redis): assert r.zrank("a", "a1") == 0 assert r.zrank("a", "a2") == 1 assert r.zrank("a", "a6") is None - assert_resp_response(r, r.zrank("a", "a3", withscore=True), [2, b"3"], [2, 3.0]) + assert_resp_response(r, r.zrank("a", "a3", withscore=True), [2, 3.0], [2, 3.0]) assert r.zrank("a", "a6", withscore=True) is None + # custom score cast function + assert_resp_response( + r, + r.zrank("a", "a3", withscore=True, score_cast_func=safe_str), + [2, "3"], + [2, "3.0"], + ) + def test_zrem(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert r.zrem("a", "a2") == 1 @@ -3222,11 +3240,15 @@ def test_zrevrange(self, r): [[b"a2", 2.0], [b"a1", 1.0]], ) - # # custom score function - # assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - # (b"a3", 3.0), - # (b"a2", 2.0), - # ] + # custom score cast function + # should be applied to resp2 and resp3 + # responses + assert_resp_response( + r, + r.zrevrange("a", 0, 1, withscores=True, score_cast_func=safe_str), + [(b"a3", "3"), (b"a2", "2")], + [[b"a3", "3.0"], [b"a2", "2.0"]], + ) def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -3241,13 +3263,20 @@ def test_zrevrangebyscore(self, r): [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], ) - # custom score function + # custom score type cast function assert_resp_response( r, r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int), [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], ) + # custom score cast function + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=safe_str), + [(b"a4", "4"), (b"a3", "3"), (b"a2", "2")], + [[b"a4", "4.0"], [b"a3", "3.0"], [b"a2", "2.0"]], + ) def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -3262,10 +3291,18 @@ def test_zrevrank_withscore(self, r): assert r.zrevrank("a", "a2") == 3 assert r.zrevrank("a", "a6") is None assert_resp_response( - r, r.zrevrank("a", "a3", withscore=True), [2, b"3"], [2, 3.0] + r, r.zrevrank("a", "a3", withscore=True), [2, 3.0], [2, 3.0] ) assert r.zrevrank("a", "a6", withscore=True) is None + # custom score cast function + assert_resp_response( + r, + r.zrevrank("a", "a3", withscore=True, score_cast_func=safe_str), + [2, "3"], + [2, "3.0"], + ) + def test_zscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert r.zscore("a", "a1") == 1.0 @@ -3307,6 +3344,13 @@ def test_zunion(self, r): [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], ) + # with custom score cast function + assert_resp_response( + r, + r.zunion(["a", "b", "c"], withscores=True, score_cast_func=safe_str), + [(b"a2", "3"), (b"a4", "4"), (b"a3", "8"), (b"a1", "9")], + [[b"a2", "3.0"], [b"a4", "4.0"], [b"a3", "8.0"], [b"a1", "9.0"]], + ) @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r):