diff --git a/docs/source/conf.py b/docs/source/conf.py index 19f777d..6fdd630 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # @@ -18,7 +17,7 @@ extensions = [ "m2r2", "sphinx.ext.autodoc", - 'sphinx_autodoc_typehints', + "sphinx_autodoc_typehints", "sphinx.ext.autosectionlabel", "sphinx.ext.intersphinx", "sphinx.ext.napoleon", @@ -26,7 +25,7 @@ templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] html_theme = "sphinx_rtd_theme" html_static_path = ["_static"] diff --git a/src/pytest_mock_resources/container/postgres.py b/src/pytest_mock_resources/container/postgres.py index c81a913..2a9bc7a 100644 --- a/src/pytest_mock_resources/container/postgres.py +++ b/src/pytest_mock_resources/container/postgres.py @@ -135,7 +135,9 @@ def detect_driver(drivername: Optional[str] = None, async_: bool = False) -> str if any(Distribution.discover(name="asyncpg")): return "postgresql+asyncpg" else: - if any(Distribution.discover(name="psycopg2")) or any(Distribution.discover(name="psycopg2-binary")): + if any(Distribution.discover(name="psycopg2")) or any( + Distribution.discover(name="psycopg2-binary") + ): return "postgresql+psycopg2" raise ValueError( # pragma: no cover diff --git a/src/pytest_mock_resources/container/redis.py b/src/pytest_mock_resources/container/redis.py index 7984e54..53c5d89 100644 --- a/src/pytest_mock_resources/container/redis.py +++ b/src/pytest_mock_resources/container/redis.py @@ -1,7 +1,7 @@ from typing import ClassVar, Iterable from pytest_mock_resources.compat import redis -from pytest_mock_resources.config import DockerContainerConfig +from pytest_mock_resources.config import DockerContainerConfig, fallback from pytest_mock_resources.container.base import ContainerCheckFailed @@ -17,17 +17,30 @@ class RedisConfig(DockerContainerConfig): Defaults to :code:`6380`. ci_port (int): The port to bind the container to when a CI environment is detected. Defaults to :code:`6379`. + decode_responses (bool): Whether to decode responses from the server on the client. + Defaults to :code:`False`. """ name = "redis" - _fields: ClassVar[Iterable] = {"image", "host", "port", "ci_port"} + _fields: ClassVar[Iterable] = { + "image", + "host", + "port", + "ci_port", + "decode_responses", + } _fields_defaults: ClassVar[dict] = { "image": "redis:5.0.7", "port": 6380, "ci_port": 6379, + "decode_responses": False, } + @fallback + def decode_responses(self): + raise NotImplementedError() + def ports(self): return {6379: self.port} diff --git a/src/pytest_mock_resources/fixture/redis.py b/src/pytest_mock_resources/fixture/redis.py index 03f7ecf..e6a1584 100644 --- a/src/pytest_mock_resources/fixture/redis.py +++ b/src/pytest_mock_resources/fixture/redis.py @@ -1,5 +1,4 @@ import pytest - from pytest_mock_resources.compat import redis from pytest_mock_resources.container.base import get_container from pytest_mock_resources.container.redis import RedisConfig @@ -23,7 +22,7 @@ def pmr_redis_container(pytestconfig, pmr_redis_config): yield from get_container(pytestconfig, pmr_redis_config) -def create_redis_fixture(scope="function"): +def create_redis_fixture(scope="function", decode_responses: bool = False): """Produce a Redis fixture. Any number of fixture functions can be created. Under the hood they will all share the same @@ -44,6 +43,7 @@ def create_redis_fixture(scope="function"): Args: scope (str): The scope of the fixture can be specified by the user, defaults to "function". + decode_responses (bool): Whether to decode the responses from redis. Raises: KeyError: If any additional arguments are provided to the function than what is necessary. @@ -62,7 +62,12 @@ def _(request, pmr_redis_container, pmr_redis_config): "The redis fixture currently only supports up to 16 parallel executions" ) - db = redis.Redis(host=pmr_redis_config.host, port=pmr_redis_config.port, db=database_number) + db = redis.Redis( + host=pmr_redis_config.host, + port=pmr_redis_config.port, + db=database_number, + decode_responses=decode_responses or pmr_redis_config.decode_responses, + ) db.flushdb() Credentials.assign_from_credentials( diff --git a/tests/fixture/test_redis.py b/tests/fixture/test_redis.py index 349715d..8fda36c 100644 --- a/tests/fixture/test_redis.py +++ b/tests/fixture/test_redis.py @@ -2,6 +2,7 @@ from pytest_mock_resources.compat import redis redis_client = create_redis_fixture() +redis_client_decode = create_redis_fixture(decode_responses=True) def _sets_setup(redis_client): @@ -42,6 +43,65 @@ def test_custom_connection_url(self, redis_client): assert value == "bar" +class TestStringsDecoded: + def test_set(self, redis_client_decode): + redis_client_decode.set("foo", "bar") + value = redis_client_decode.get("foo") + assert value == "bar" + + def test_append(self, redis_client_decode): + redis_client_decode.set("foo", "bar") + redis_client_decode.append("foo", "baz") + value = redis_client_decode.get("foo") + assert value == "barbaz" + + def test_int_operations(self, redis_client_decode): + redis_client_decode.set("foo", 1) + redis_client_decode.incr("foo") + value = int(redis_client_decode.get("foo")) + assert value == 2 + + redis_client_decode.decr("foo") + value = int(redis_client_decode.get("foo")) + assert value == 1 + + redis_client_decode.incrby("foo", 4) + value = int(redis_client_decode.get("foo")) + assert value == 5 + + redis_client_decode.decrby("foo", 3) + value = int(redis_client_decode.get("foo")) + assert value == 2 + + def test_float_operations(self, redis_client_decode): + redis_client_decode.set("foo", 1.2) + value = float(redis_client_decode.get("foo")) + assert value == 1.2 + + redis_client_decode.incrbyfloat("foo", 4.1) + value = float(redis_client_decode.get("foo")) + assert value == 5.3 + + redis_client_decode.incrbyfloat("foo", -3.1) + value = float(redis_client_decode.get("foo")) + assert value == 2.2 + + def test_multiple_keys(self, redis_client_decode): + test_mapping = {"foo": "bar", "baz": 1, "flo": 1.2} + redis_client_decode.mset(test_mapping) + assert redis_client_decode.get("foo") == "bar" + assert int(redis_client_decode.get("baz")) == 1 + assert float(redis_client_decode.get("flo")) == 1.2 + + def test_querries(self, redis_client_decode): + test_mapping = {"foo1": "bar1", "foo2": "bar2", "flo": "flo"} + redis_client_decode.mset(test_mapping) + foo_keys = redis_client_decode.keys("foo*") + assert "foo1" in foo_keys + assert "foo2" in foo_keys + assert "flo" not in foo_keys + + class TestStrings: def test_set(self, redis_client): redis_client.set("foo", "bar")