From 7a2574ed48a889e7f3d886a30ba378d51f48a114 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 12 Mar 2024 11:50:13 -0400 Subject: [PATCH 1/3] fix connection pool for SSL --- redisvl/redis/connection.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 98d76482..d6d1d280 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,8 +1,10 @@ import os from typing import Any, Dict, List, Optional -from redis import ConnectionPool, Redis +from redis import Connection, ConnectionPool, Redis, SSLConnection +from redis.asyncio import Connection as AConnection from redis.asyncio import Redis as AsyncRedis +from redis.asyncio import SSLConnection as ASSLConnection from redisvl.redis.constants import REDIS_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes @@ -130,8 +132,16 @@ def validate_async_redis_modules( Raises: ValueError: If required Redis modules are not installed. """ + # pick the right connection class + connection_class = Connection + if isinstance(client.connection_pool.connection_class, ASSLConnection): + connection_class = SSLConnection + # set up a temp sync client temp_client = Redis( - connection_pool=ConnectionPool(**client.connection_pool.connection_kwargs) + connection_pool=ConnectionPool( + connection_class=connection_class, + **client.connection_pool.connection_kwargs, + ) ) RedisConnectionFactory.validate_redis_modules( temp_client, redis_required_modules From 5a72b891355805514930360da606d2fc7a80c4a3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 12 Mar 2024 14:29:26 -0400 Subject: [PATCH 2/3] fix typing --- redisvl/redis/connection.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index d6d1d280..701f4dbd 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,10 +1,15 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type -from redis import Connection, ConnectionPool, Redis, SSLConnection -from redis.asyncio import Connection as AConnection +from redis import Redis from redis.asyncio import Redis as AsyncRedis from redis.asyncio import SSLConnection as ASSLConnection +from redis.connection import ( + AbstractConnection, + Connection, + ConnectionPool, + SSLConnection, +) from redisvl.redis.constants import REDIS_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes @@ -133,7 +138,7 @@ def validate_async_redis_modules( ValueError: If required Redis modules are not installed. """ # pick the right connection class - connection_class = Connection + connection_class: Type[AbstractConnection] = Connection if isinstance(client.connection_pool.connection_class, ASSLConnection): connection_class = SSLConnection # set up a temp sync client From 392661136486dda1ad0e2d5a48a0d790184e17a6 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 13 Mar 2024 14:24:58 -0400 Subject: [PATCH 3/3] consolidate logic and update tests --- redisvl/redis/connection.py | 8 +++++--- tests/integration/test_connection.py | 6 +++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 701f4dbd..79eb3060 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -138,9 +138,11 @@ def validate_async_redis_modules( ValueError: If required Redis modules are not installed. """ # pick the right connection class - connection_class: Type[AbstractConnection] = Connection - if isinstance(client.connection_pool.connection_class, ASSLConnection): - connection_class = SSLConnection + connection_class: Type[AbstractConnection] = ( + SSLConnection + if client.connection_pool.connection_class == ASSLConnection + else Connection + ) # set up a temp sync client temp_client = Redis( connection_pool=ConnectionPool( diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index e9c69c14..45328363 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -51,4 +51,8 @@ def test_unknown_redis(): def test_required_modules(client): RedisConnectionFactory.validate_redis_modules(client) - RedisConnectionFactory.validate_async_redis_modules(client) + + +@pytest.mark.asyncio +async def test_async_required_modules(async_client): + RedisConnectionFactory.validate_async_redis_modules(async_client)