Skip to content

Commit

Permalink
Provide aclose() / close() for classes requiring lifetime management (#…
Browse files Browse the repository at this point in the history
…2898)

* Define `aclose()` methods instead of `close()` for async Redis()

* update examples to use `aclose()`

* Update tests to use Redis.aclose()

* Add aclose() to asyncio.client.PubSub
close() and reset() retained as aliases

* Add aclose method to asyncio.RedisCluster

* Add aclose() to asyncio.client.Pipeline

* add `close()` method to sync Pipeline

* add `aclose()` to asyncio.connection.ConnectionPool

* Add `close()` method to redis.ConnectionPool

* Deprecate older functions.

* changes.txt

* fix unittest

* fix typo

* Update docs
  • Loading branch information
kristjanvalur committed Sep 20, 2023
1 parent 6207641 commit c46a28d
Show file tree
Hide file tree
Showing 16 changed files with 294 additions and 91 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Add 'aclose()' methods to async classes, deprecate async close().
* Fix #2831, add auto_close_connection_pool=True arg to asyncio.Redis.from_url()
* Fix incorrect redis.asyncio.Cluster type hint for `retry_on_error`
* Fix dead weakref in sentinel connection causing ReferenceError (#2767)
Expand Down
28 changes: 14 additions & 14 deletions docs/examples/asyncio_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"\n",
"## Connecting and Disconnecting\n",
"\n",
"Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.close` which disconnects all connections."
"Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.aclose` which disconnects all connections."
]
},
{
Expand All @@ -39,9 +39,9 @@
"source": [
"import redis.asyncio as redis\n",
"\n",
"connection = redis.Redis()\n",
"print(f\"Ping successful: {await connection.ping()}\")\n",
"await connection.close()"
"client = redis.Redis()\n",
"print(f\"Ping successful: {await client.ping()}\")\n",
"await client.aclose()"
]
},
{
Expand All @@ -60,8 +60,8 @@
"import redis.asyncio as redis\n",
"\n",
"pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n",
"connection = redis.Redis.from_pool(pool)\n",
"await connection.close()"
"client = redis.Redis.from_pool(pool)\n",
"await client.close()"
]
},
{
Expand Down Expand Up @@ -91,11 +91,11 @@
"import redis.asyncio as redis\n",
"\n",
"pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n",
"connection1 = redis.Redis(connection_pool=pool)\n",
"connection2 = redis.Redis(connection_pool=pool)\n",
"await connection1.close()\n",
"await connection2.close()\n",
"await pool.disconnect()"
"client1 = redis.Redis(connection_pool=pool)\n",
"client2 = redis.Redis(connection_pool=pool)\n",
"await client1.aclose()\n",
"await client2.aclose()\n",
"await pool.aclose()"
]
},
{
Expand All @@ -113,9 +113,9 @@
"source": [
"import redis.asyncio as redis\n",
"\n",
"connection = redis.Redis(protocol=3)\n",
"await connection.close()\n",
"await connection.ping()"
"client = redis.Redis(protocol=3)\n",
"await client.aclose()\n",
"await client.ping()"
]
},
{
Expand Down
46 changes: 32 additions & 14 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -65,6 +64,7 @@
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
deprecated_function,
get_lib_version,
safe_str,
str_if_bytes,
Expand Down Expand Up @@ -527,7 +527,7 @@ async def __aenter__(self: _RedisT) -> _RedisT:
return await self.initialize()

async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
await self.aclose()

_DEL_MESSAGE = "Unclosed Redis client"

Expand All @@ -539,7 +539,7 @@ def __del__(self, _warnings: Any = warnings) -> None:
context = {"client": self, "message": self._DEL_MESSAGE}
asyncio.get_running_loop().call_exception_handler(context)

async def close(self, close_connection_pool: Optional[bool] = None) -> None:
async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Closes Redis client connection
Expand All @@ -557,6 +557,13 @@ async def close(self, close_connection_pool: Optional[bool] = None) -> None:
):
await self.connection_pool.disconnect()

@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Alias for aclose(), for backwards compatibility
"""
await self.aclose(close_connection_pool)

async def _send_command_parse_response(self, conn, command_name, *args, **options):
"""
Send a command and parse the response
Expand Down Expand Up @@ -764,13 +771,18 @@ async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.reset()
await self.aclose()

def __del__(self):
if self.connection:
self.connection.clear_connect_callbacks()

async def reset(self):
async def aclose(self):
# In case a connection property does not yet exist
# (due to a crash earlier in the Redis() constructor), return
# immediately as there is nothing to clean-up.
if not hasattr(self, "connection"):
return
async with self._lock:
if self.connection:
await self.connection.disconnect()
Expand All @@ -782,13 +794,15 @@ async def reset(self):
self.patterns = {}
self.pending_unsubscribe_patterns = set()

def close(self) -> Awaitable[NoReturn]:
# In case a connection property does not yet exist
# (due to a crash earlier in the Redis() constructor), return
# immediately as there is nothing to clean-up.
if not hasattr(self, "connection"):
return
return self.reset()
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
"""Alias for aclose(), for backwards compatibility"""
await self.aclose()

@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset")
async def reset(self) -> None:
"""Alias for aclose(), for backwards compatibility"""
await self.aclose()

async def on_connect(self, connection: Connection):
"""Re-subscribe to any channels and patterns previously subscribed to"""
Expand Down Expand Up @@ -1232,6 +1246,10 @@ async def reset(self):
await self.connection_pool.release(self.connection)
self.connection = None

async def aclose(self) -> None:
"""Alias for reset(), a standard method name for cleanup"""
await self.reset()

def multi(self):
"""
Start a transactional block of the pipeline after WATCH commands
Expand Down Expand Up @@ -1264,14 +1282,14 @@ async def _disconnect_reset_raise(self, conn, error):
# valid since this connection has died. raise a WatchError, which
# indicates the user should retry this transaction.
if self.watching:
await self.reset()
await self.aclose()
raise WatchError(
"A ConnectionError occurred on while watching one or more keys"
)
# if retry_on_timeout is not set, or the error is not
# a TimeoutError, raise it
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
await self.reset()
await self.aclose()
raise

async def immediate_execute_command(self, *args, **options):
Expand Down
35 changes: 23 additions & 12 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@
TryAgainError,
)
from redis.typing import AnyKeyT, EncodableT, KeyT
from redis.utils import dict_merge, get_lib_version, safe_str, str_if_bytes
from redis.utils import (
deprecated_function,
dict_merge,
get_lib_version,
safe_str,
str_if_bytes,
)

TargetNodesT = TypeVar(
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
Expand Down Expand Up @@ -395,27 +401,32 @@ async def initialize(self) -> "RedisCluster":
)
self._initialize = False
except BaseException:
await self.nodes_manager.close()
await self.nodes_manager.close("startup_nodes")
await self.nodes_manager.aclose()
await self.nodes_manager.aclose("startup_nodes")
raise
return self

async def close(self) -> None:
async def aclose(self) -> None:
"""Close all connections & client if initialized."""
if not self._initialize:
if not self._lock:
self._lock = asyncio.Lock()
async with self._lock:
if not self._initialize:
self._initialize = True
await self.nodes_manager.close()
await self.nodes_manager.close("startup_nodes")
await self.nodes_manager.aclose()
await self.nodes_manager.aclose("startup_nodes")

@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
"""alias for aclose() for backwards compatibility"""
await self.aclose()

async def __aenter__(self) -> "RedisCluster":
return await self.initialize()

async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
await self.close()
await self.aclose()

def __await__(self) -> Generator[Any, None, "RedisCluster"]:
return self.initialize().__await__()
Expand Down Expand Up @@ -767,13 +778,13 @@ async def _execute_command(
self.nodes_manager.startup_nodes.pop(target_node.name, None)
# Hard force of reinitialize of the node/slots setup
# and try again with the new setup
await self.close()
await self.aclose()
raise
except ClusterDownError:
# ClusterDownError can occur during a failover and to get
# self-healed, we will try to reinitialize the cluster layout
# and retry executing the command
await self.close()
await self.aclose()
await asyncio.sleep(0.25)
raise
except MovedError as e:
Expand All @@ -790,7 +801,7 @@ async def _execute_command(
self.reinitialize_steps
and self.reinitialize_counter % self.reinitialize_steps == 0
):
await self.close()
await self.aclose()
# Reset the counter
self.reinitialize_counter = 0
else:
Expand Down Expand Up @@ -1323,7 +1334,7 @@ async def initialize(self) -> None:
# If initialize was called after a MovedError, clear it
self._moved_exception = None

async def close(self, attr: str = "nodes_cache") -> None:
async def aclose(self, attr: str = "nodes_cache") -> None:
self.default_node = None
await asyncio.gather(
*(
Expand Down Expand Up @@ -1471,7 +1482,7 @@ async def execute(
if type(e) in self.__class__.ERRORS_ALLOW_RETRY:
# Try again with the new cluster setup.
exception = e
await self._client.close()
await self._client.aclose()
await asyncio.sleep(0.25)
else:
# All other errors should be raised.
Expand Down
4 changes: 4 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,10 @@ async def disconnect(self, inuse_connections: bool = True):
if exc:
raise exc

async def aclose(self) -> None:
"""Close the pool, disconnecting all connections"""
await self.disconnect()

def set_retry(self, retry: "Retry") -> None:
for conn in self._available_connections:
conn.retry = retry
Expand Down
4 changes: 4 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,10 @@ def reset(self):
self.connection_pool.release(self.connection)
self.connection = None

def close(self):
"""Close the pipeline"""
self.reset()

def multi(self):
"""
Start a transactional block of the pipeline after WATCH commands
Expand Down
4 changes: 4 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,10 @@ def disconnect(self, inuse_connections=True):
for connection in connections:
connection.disconnect()

def close(self) -> None:
"""Close the pool, disconnecting all connections"""
self.disconnect()

def set_retry(self, retry: "Retry") -> None:
self.connection_kwargs.update({"retry": retry})
for conn in self._available_connections:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_asyncio/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@
except AttributeError:
import mock

try:
from contextlib import aclosing
except ImportError:
import contextlib

@contextlib.asynccontextmanager
async def aclosing(thing):
try:
yield thing
finally:
await thing.aclose()


def create_task(coroutine):
return asyncio.create_task(coroutine)
4 changes: 2 additions & 2 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def teardown():
# handle cases where a test disconnected a client
# just manually retry the flushdb
await client.flushdb()
await client.close()
await client.aclose()
await client.connection_pool.disconnect()
else:
if flushdb:
Expand All @@ -110,7 +110,7 @@ async def teardown():
# handle cases where a test disconnected a client
# just manually retry the flushdb
await client.flushdb(target_nodes="primaries")
await client.close()
await client.aclose()

teardown_clients.append(teardown)
return client
Expand Down

0 comments on commit c46a28d

Please sign in to comment.