diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 98d76482..79eb3060 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,8 +1,15 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type -from redis import ConnectionPool, Redis +from redis import Redis from redis.asyncio import Redis as AsyncRedis +from redis.asyncio import SSLConnection as ASSLConnection +from redis.connection import ( + AbstractConnection, + Connection, + ConnectionPool, + SSLConnection, +) from redisvl.redis.constants import REDIS_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes @@ -130,8 +137,18 @@ def validate_async_redis_modules( Raises: ValueError: If required Redis modules are not installed. """ + # pick the right connection class + connection_class: Type[AbstractConnection] = ( + SSLConnection + if client.connection_pool.connection_class == ASSLConnection + else Connection + ) + # set up a temp sync client temp_client = Redis( - connection_pool=ConnectionPool(**client.connection_pool.connection_kwargs) + connection_pool=ConnectionPool( + connection_class=connection_class, + **client.connection_pool.connection_kwargs, + ) ) RedisConnectionFactory.validate_redis_modules( temp_client, redis_required_modules diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index e9c69c14..45328363 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -51,4 +51,8 @@ def test_unknown_redis(): def test_required_modules(client): RedisConnectionFactory.validate_redis_modules(client) - RedisConnectionFactory.validate_async_redis_modules(client) + + +@pytest.mark.asyncio +async def test_async_required_modules(async_client): + RedisConnectionFactory.validate_async_redis_modules(async_client)