Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions redisvl/redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

from redis import ConnectionPool, Redis
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio import SSLConnection as ASSLConnection
from redis.connection import (
AbstractConnection,
Connection,
ConnectionPool,
SSLConnection,
)

from redisvl.redis.constants import REDIS_REQUIRED_MODULES
from redisvl.redis.utils import convert_bytes
Expand Down Expand Up @@ -130,8 +137,18 @@ def validate_async_redis_modules(
Raises:
ValueError: If required Redis modules are not installed.
"""
# pick the right connection class
connection_class: Type[AbstractConnection] = (
SSLConnection
if client.connection_pool.connection_class == ASSLConnection
else Connection
)
# set up a temp sync client
temp_client = Redis(
connection_pool=ConnectionPool(**client.connection_pool.connection_kwargs)
connection_pool=ConnectionPool(
connection_class=connection_class,
**client.connection_pool.connection_kwargs,
)
)
RedisConnectionFactory.validate_redis_modules(
temp_client, redis_required_modules
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,8 @@ def test_unknown_redis():

def test_required_modules(client):
RedisConnectionFactory.validate_redis_modules(client)
RedisConnectionFactory.validate_async_redis_modules(client)


@pytest.mark.asyncio
async def test_async_required_modules(async_client):
RedisConnectionFactory.validate_async_redis_modules(async_client)