Skip to content

Commit

Permalink
Merge pull request #17 from GerardRodes/master
Browse files Browse the repository at this point in the history
Close requestor session
  • Loading branch information
LilSpazJoekp committed Jan 17, 2021
2 parents e0884d1 + c6a63f6 commit b33ad52
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 14 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Expand Up @@ -68,4 +68,5 @@ Source Contributors
- Todd Roberts `@toddrob99 <https://github.com/toddrob99>`_
- MaybeNetwork `@MaybeNetwork <https://github.com/MaybeNetwork>`_
- Nick Kelly `@nickatnight <https://github.com/nickatnight>`_
- Gerard Rodes <GerardRodesVidal@gmail.com> `@GerardRodes <https://github.com/GerardRodes>`_
<!-- - Add "Name <email (optional)> and github profile link" above this line. -->
3 changes: 3 additions & 0 deletions CHANGES.rst
Expand Up @@ -9,6 +9,9 @@ Unreleased
* Ability to submit image galleries with :meth:`.submit_gallery`.
* Ability to pass a gallery url to :meth:`.Reddit.submission`.
* Ability to specify modmail mute duration.
* Added :meth:`.Reddit.close` to close the requestor session.
* Ability to use :class:`.Reddit` as an asynchronous context manager that automatically
closes the requestor session on exit.

**Changed**

Expand Down
51 changes: 49 additions & 2 deletions asyncpraw/reddit.py
Expand Up @@ -112,13 +112,40 @@ def validate_on_submit(self) -> bool:
def validate_on_submit(self, val: bool):
self._validate_on_submit = val

def __enter__(self):
async def __aenter__(self):
"""Handle the context manager open."""
return self

async def __aexit__(self, *_args):
"""Handle the context manager close."""
await self.close()

def __enter__(self):
"""Handle the context manager open.
.. deprecated:: 7.1.1
Using this class as a synchronous context manager is deprecated and will
be removed in the next release. Use this class as an asynchronous context
manager instead.
"""
warn(
"Using this class as a synchronous context manager is deprecated and will "
"be removed in the next release. Use this class as an asynchronous context "
"manager instead.",
category=DeprecationWarning,
stacklevel=3,
)
return self # pragma: no cover

def __exit__(self, *_args):
"""Handle the context manager close."""

async def close(self):
"""Close the requestor."""
await self.requestor.close()

def __init__(
self,
site_name: str = None,
Expand Down Expand Up @@ -175,6 +202,24 @@ async def request(self, *args, **kwargs):
reddit = Reddit(..., requestor_class=JSONDebugRequestor,
requestor_kwargs={"session": my_session})
You can automatically close the requestor session by using this class as an
context manager:
.. code-block:: python
async with Reddit(...) as reddit:
print(await reddit.user.me()
You can also call :meth:`.Reddit.close`:
.. code-block:: python
reddit = Reddit(...)
# do stuff with reddit
...
# then close the reqestor when done
await reddit.close()
"""
self._core = self._authorized_core = self._read_only_core = None
self._objector = None
Expand Down Expand Up @@ -221,7 +266,7 @@ async def request(self, *args, **kwargs):
)

self._prepare_objector()
self._prepare_asyncprawcore(requestor_class, requestor_kwargs)
self.requestor = self._prepare_asyncprawcore(requestor_class, requestor_kwargs)

self.auth = models.Auth(self, None)
"""An instance of :class:`.Auth`.
Expand Down Expand Up @@ -418,6 +463,8 @@ def _prepare_asyncprawcore(self, requestor_class=None, requestor_kwargs=None):
else:
self._prepare_untrusted_asyncprawcore(requestor)

return requestor

def _prepare_trusted_asyncprawcore(self, requestor):
authenticator = TrustedAuthenticator(
requestor,
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_deprecations.py
Expand Up @@ -60,3 +60,14 @@ def test_gold_method(self):
excinfo.value.args[0]
== "`subreddits.gold` has be renamed to `subreddits.premium`."
)

def test_synchronous_context_manager(self):
with pytest.raises(DeprecationWarning) as excinfo:
with self.reddit:
pass
assert (
excinfo.value.args[0]
== "Using this class as a synchronous context manager is deprecated"
" and will be removed in the next release. Use this class as an "
"asynchronous context manager instead."
)
33 changes: 21 additions & 12 deletions tests/unit/test_reddit.py
Expand Up @@ -20,12 +20,21 @@ class TestReddit(UnitTest):
x: "dummy" for x in ["client_id", "client_secret", "user_agent"]
}

async def test_close_session(self):
temp_reddit = Reddit(**self.REQUIRED_DUMMY_SETTINGS)
assert not temp_reddit.requestor._http.closed
async with temp_reddit as reddit:
pass
assert reddit.requestor._http.closed and temp_reddit.requestor._http.closed

def test_comment(self):
assert Comment(self.reddit, id="cklfmye").id == "cklfmye"

def test_context_manager(self):
with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit:
async def test_context_manager(self):
async with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit:
assert not reddit._validate_on_submit
assert not reddit.requestor._http.closed
assert reddit.requestor._http.closed

def test_info__invalid_param(self):
with pytest.raises(TypeError) as excinfo:
Expand Down Expand Up @@ -153,8 +162,8 @@ async def test_post_ratelimit(self, __, _):
response = await self.reddit.post("test")
assert response == {}

def test_read_only__with_authenticated_core(self):
with Reddit(
async def test_read_only__with_authenticated_core(self):
async with Reddit(
password=None,
refresh_token="refresh",
username=None,
Expand All @@ -166,8 +175,8 @@ def test_read_only__with_authenticated_core(self):
reddit.read_only = False
assert not reddit.read_only

def test_read_only__with_authenticated_core__non_confidential(self):
with Reddit(
async def test_read_only__with_authenticated_core__non_confidential(self):
async with Reddit(
client_id="dummy",
client_secret=None,
redirect_uri="dummy",
Expand All @@ -180,8 +189,8 @@ def test_read_only__with_authenticated_core__non_confidential(self):
reddit.read_only = False
assert not reddit.read_only

def test_read_only__with_script_authenticated_core(self):
with Reddit(
async def test_read_only__with_script_authenticated_core(self):
async with Reddit(
password="dummy", username="dummy", **self.REQUIRED_DUMMY_SETTINGS
) as reddit:
assert not reddit.read_only
Expand All @@ -190,8 +199,8 @@ def test_read_only__with_script_authenticated_core(self):
reddit.read_only = False
assert not reddit.read_only

def test_read_only__without_trusted_authenticated_core(self):
with Reddit(
async def test_read_only__without_trusted_authenticated_core(self):
async with Reddit(
password=None, username=None, **self.REQUIRED_DUMMY_SETTINGS
) as reddit:
assert reddit.read_only
Expand All @@ -201,10 +210,10 @@ def test_read_only__without_trusted_authenticated_core(self):
reddit.read_only = True
assert reddit.read_only

def test_read_only__without_untrusted_authenticated_core(self):
async def test_read_only__without_untrusted_authenticated_core(self):
required_settings = self.REQUIRED_DUMMY_SETTINGS.copy()
required_settings["client_secret"] = None
with Reddit(password=None, username=None, **required_settings) as reddit:
async with Reddit(password=None, username=None, **required_settings) as reddit:
assert reddit.read_only
with pytest.raises(ClientException):
reddit.read_only = False
Expand Down

0 comments on commit b33ad52

Please sign in to comment.