Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support late creation of DB connection using callable for Meta.database #543

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion aredis_om/__init__.py
Expand Up @@ -8,11 +8,11 @@
FindQuery,
HashModel,
JsonModel,
VectorFieldOptions,
KNNExpression,
NotFoundError,
QueryNotSupportedError,
QuerySyntaxError,
RedisModel,
RedisModelError,
VectorFieldOptions,
)
2 changes: 1 addition & 1 deletion aredis_om/model/__init__.py
Expand Up @@ -4,8 +4,8 @@
Field,
HashModel,
JsonModel,
VectorFieldOptions,
KNNExpression,
NotFoundError,
RedisModel,
VectorFieldOptions,
)
29 changes: 17 additions & 12 deletions aredis_om/model/model.py
Expand Up @@ -25,6 +25,8 @@
)

from more_itertools import ichunked
from redis import Redis
from redis.asyncio import Redis as RedisAsync
from redis.commands.json.path import Path
from redis.exceptions import ResponseError
from typing_extensions import Protocol, get_args, get_origin
Expand Down Expand Up @@ -1255,9 +1257,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
base_meta, "primary_key_pattern", "{pk}"
)
if not getattr(new_class._meta, "database", None):
new_class._meta.database = getattr(
base_meta, "database", get_redis_connection()
)
new_class._meta.database = getattr(base_meta, "database", None)
if not getattr(new_class._meta, "encoding", None):
new_class._meta.encoding = getattr(base_meta, "encoding")
if not getattr(new_class._meta, "primary_key_creator_cls", None):
Expand All @@ -1282,6 +1282,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901

class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk: Optional[str] = Field(default=None, primary_key=True)
_conn: Optional[Union[Redis, RedisAsync]] = None

Meta = DefaultMeta

Expand Down Expand Up @@ -1370,7 +1371,19 @@ def make_primary_key(cls, pk: Any):

@classmethod
def db(cls):
return cls._meta.database
if not cls._conn:
conn = (
cls._meta.database()
if callable(cls._meta.database)
else cls._meta.database or get_redis_connection()
)
if not has_redis_json(conn):
log.error(
"Your Redis instance does not have the RedisJson module "
"loaded. JsonModel depends on RedisJson."
)
cls._conn = conn
return cls._conn

@classmethod
def find(
Expand Down Expand Up @@ -1674,14 +1687,6 @@ def __init_subclass__(cls, **kwargs):
# Generate the RediSearch schema once to validate fields.
cls.redisearch_schema()

def __init__(self, *args, **kwargs):
if not has_redis_json(self.db()):
log.error(
"Your Redis instance does not have the RedisJson module "
"loaded. JsonModel depends on RedisJson."
)
super().__init__(*args, **kwargs)

async def save(
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Expand Up @@ -3,7 +3,8 @@

import pytest

from aredis_om import get_redis_connection
from aredis_om import RedisModel, get_redis_connection
from aredis_om.model.model import DefaultMeta, model_registry


TEST_PREFIX = "redis-om:testing"
Expand Down Expand Up @@ -59,3 +60,13 @@ def cleanup_keys(request):
# Delete keys only once
if conn.decr(once_key) == 0:
_delete_test_keys(TEST_PREFIX, conn)


@pytest.fixture(autouse=True)
def reset_meta():
yield
RedisModel.Meta.database = DefaultMeta
if hasattr(RedisModel, "_meta"):
del RedisModel._meta
RedisModel._conn = None
model_registry.clear()
92 changes: 92 additions & 0 deletions tests/test_json_model.py
Expand Up @@ -10,6 +10,8 @@

import pytest
import pytest_asyncio
from redis import ConnectionError, Redis
from redis.asyncio import Redis as AsyncRedis

from aredis_om import (
EmbeddedJsonModel,
Expand Down Expand Up @@ -849,3 +851,93 @@ async def test_count(members, m):
m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
).count()
assert actual_count == 1


@py_test_mark_asyncio
async def test_default_connection_not_configured_at_class_definition_time():
class MyJsonModel(JsonModel):
a_field: int

assert MyJsonModel._meta.database is None


@py_test_mark_asyncio
async def test_default_connection_configured_and_opened_at_usage_time():
class MyJsonModel(JsonModel):
a_field: int

obj = MyJsonModel(a_field=42)
await obj.save()

assert MyJsonModel._meta.database is None
assert isinstance(MyJsonModel._conn, (Redis, AsyncRedis))
assert MyJsonModel._conn.connection_pool.connection_kwargs["host"] == "localhost"


@py_test_mark_asyncio
async def test_custom_connection_configured_at_class_definition_time():
class MyJsonModel(JsonModel):
a_field: int

class Meta:
database = Redis(host="10.20.30.40", port=1234)

assert isinstance(MyJsonModel._meta.database, Redis)
assert (
MyJsonModel._meta.database.connection_pool.connection_kwargs["host"]
== "10.20.30.40"
)
assert MyJsonModel._meta.database.connection_pool.connection_kwargs["port"] == 1234


@py_test_mark_asyncio
async def test_custom_connection_opened_at_usage_time():
class MyJsonModel(JsonModel):
a_field: int

class Meta:
database = Redis(host="10.20.30.40", port=5678)

obj = MyJsonModel(a_field=42)
with pytest.raises(ConnectionError, match="connecting to 10.20.30.40:5678"):
await obj.save()


@py_test_mark_asyncio
async def test_lazy_connection_configured_and_opened_at_usage_time():
def my_connection():
return Redis(host="10.20.30.40", port=9012)

class MyJsonModel(JsonModel):
a_field: int

class Meta:
database = my_connection

obj = MyJsonModel(a_field=42)

assert not isinstance(MyJsonModel._meta.database, Redis)
assert callable(MyJsonModel._meta.database)
assert MyJsonModel._conn is None

with pytest.raises(ConnectionError, match="connecting to 10.20.30.40:9012"):
await obj.save()


@py_test_mark_asyncio
async def test_lazy_connection_cached(redis):
def my_connection():
return redis

class MyJsonModel(JsonModel):
a_field: int

class Meta:
database = my_connection

obj = MyJsonModel(a_field=42)
await obj.save()

assert isinstance(MyJsonModel._conn, (Redis, AsyncRedis))
assert MyJsonModel.db() is MyJsonModel._conn
assert MyJsonModel.db() is MyJsonModel.db()