diff --git a/CHANGES.rst b/CHANGES.rst index 7a2298917..d11268a5c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,8 @@ Unreleased * Ability to pass a gallery url to :meth:`.Reddit.submission`. * Ability to specify modmail mute duration. * Added :meth:`.Reddit.close` to close the requestor session. +* Ability to use :class:`.Reddit` as an asynchronous context manager that automatically + closes the requestor session on exit. **Changed** diff --git a/asyncpraw/reddit.py b/asyncpraw/reddit.py index 5178841c8..a77769c43 100644 --- a/asyncpraw/reddit.py +++ b/asyncpraw/reddit.py @@ -112,10 +112,33 @@ def validate_on_submit(self) -> bool: def validate_on_submit(self, val: bool): self._validate_on_submit = val - def __enter__(self): + async def __aenter__(self): """Handle the context manager open.""" return self + async def __aexit__(self, *_args): + """Handle the context manager close.""" + await self.close() + + def __enter__(self): + """Handle the context manager open. + + .. deprecated:: 7.1.1 + + Using this class as a synchronous context manager is deprecated and will + be removed in the next release. Use this class as an asynchronous context + manager instead. + + """ + warn( + "Using this class as a synchronous context manager is deprecated and will " + "be removed in the next release. Use this class as an asynchronous context " + "manager instead.", + category=DeprecationWarning, + stacklevel=3, + ) + return self # pragma: no cover + def __exit__(self, *_args): """Handle the context manager close.""" @@ -179,6 +202,24 @@ async def request(self, *args, **kwargs): reddit = Reddit(..., requestor_class=JSONDebugRequestor, requestor_kwargs={"session": my_session}) + You can automatically close the requestor session by using this class as an + context manager: + + .. code-block:: python + + async with Reddit(...) as reddit: + print(await reddit.user.me() + + You can also call :meth:`.Reddit.close`: + + .. code-block:: python + + reddit = Reddit(...) + # do stuff with reddit + ... + # then close the reqestor when done + await reddit.close() + """ self._core = self._authorized_core = self._read_only_core = None self._objector = None diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index ff36af4ac..7d5ef5b02 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -60,3 +60,14 @@ def test_gold_method(self): excinfo.value.args[0] == "`subreddits.gold` has be renamed to `subreddits.premium`." ) + + def test_synchronous_context_manager(self): + with pytest.raises(DeprecationWarning) as excinfo: + with self.reddit: + pass + assert ( + excinfo.value.args[0] + == "Using this class as a synchronous context manager is deprecated" + " and will be removed in the next release. Use this class as an " + "asynchronous context manager instead." + ) diff --git a/tests/unit/test_reddit.py b/tests/unit/test_reddit.py index 43314086a..00c02e9f6 100644 --- a/tests/unit/test_reddit.py +++ b/tests/unit/test_reddit.py @@ -21,17 +21,20 @@ class TestReddit(UnitTest): } async def test_close_session(self): - assert not self.reddit.requestor._http.closed - async with self.reddit as reddit: + temp_reddit = Reddit(**self.REQUIRED_DUMMY_SETTINGS) + assert not temp_reddit.requestor._http.closed + async with temp_reddit as reddit: pass - assert reddit.requestor._http.closed and self.reddit.requestor._http.closed + assert reddit.requestor._http.closed and temp_reddit.requestor._http.closed def test_comment(self): assert Comment(self.reddit, id="cklfmye").id == "cklfmye" - def test_context_manager(self): - with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit: + async def test_context_manager(self): + async with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit: assert not reddit._validate_on_submit + assert not reddit.requestor._http.closed + assert reddit.requestor._http.closed def test_info__invalid_param(self): with pytest.raises(TypeError) as excinfo: @@ -159,8 +162,8 @@ async def test_post_ratelimit(self, __, _): response = await self.reddit.post("test") assert response == {} - def test_read_only__with_authenticated_core(self): - with Reddit( + async def test_read_only__with_authenticated_core(self): + async with Reddit( password=None, refresh_token="refresh", username=None, @@ -172,8 +175,8 @@ def test_read_only__with_authenticated_core(self): reddit.read_only = False assert not reddit.read_only - def test_read_only__with_authenticated_core__non_confidential(self): - with Reddit( + async def test_read_only__with_authenticated_core__non_confidential(self): + async with Reddit( client_id="dummy", client_secret=None, redirect_uri="dummy", @@ -186,8 +189,8 @@ def test_read_only__with_authenticated_core__non_confidential(self): reddit.read_only = False assert not reddit.read_only - def test_read_only__with_script_authenticated_core(self): - with Reddit( + async def test_read_only__with_script_authenticated_core(self): + async with Reddit( password="dummy", username="dummy", **self.REQUIRED_DUMMY_SETTINGS ) as reddit: assert not reddit.read_only @@ -196,8 +199,8 @@ def test_read_only__with_script_authenticated_core(self): reddit.read_only = False assert not reddit.read_only - def test_read_only__without_trusted_authenticated_core(self): - with Reddit( + async def test_read_only__without_trusted_authenticated_core(self): + async with Reddit( password=None, username=None, **self.REQUIRED_DUMMY_SETTINGS ) as reddit: assert reddit.read_only @@ -207,10 +210,10 @@ def test_read_only__without_trusted_authenticated_core(self): reddit.read_only = True assert reddit.read_only - def test_read_only__without_untrusted_authenticated_core(self): + async def test_read_only__without_untrusted_authenticated_core(self): required_settings = self.REQUIRED_DUMMY_SETTINGS.copy() required_settings["client_secret"] = None - with Reddit(password=None, username=None, **required_settings) as reddit: + async with Reddit(password=None, username=None, **required_settings) as reddit: assert reddit.read_only with pytest.raises(ClientException): reddit.read_only = False