Skip to content

Commit

Permalink
Merge pull request #144 from praw-dev/vcr_improvements
Browse files Browse the repository at this point in the history
VCR/test improvements
  • Loading branch information
LilSpazJoekp committed Nov 7, 2021
2 parents 056d558 + 816d7aa commit eecc268
Show file tree
Hide file tree
Showing 9 changed files with 343 additions and 55 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"pytest-asyncio",
"pytest-vcr",
"testfixtures >4.13.2, <7",
"vcrpy==4.0.2",
"vcrpy==4.1.1",
],
}
extras["dev"] += extras["lint"] + extras["test"]
Expand Down
20 changes: 6 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,15 @@ def deserialize(cassette_string):
return json.loads(cassette_string)


class CustomVCR(VCR):
"""Derived from VCR to make setting paths easier."""

def use_cassette(self, path="", **kwargs):
"""Use a cassette."""
path += ".json"
return super().use_cassette(path, **kwargs)


VCR = CustomVCR(
serializer="custom_serializer",
vcr = VCR(
before_record_response=filter_access_token,
cassette_library_dir="tests/integration/cassettes",
match_on=["uri", "method"],
before_record_response=filter_access_token,
path_transformer=VCR.ensure_suffix(".json"),
serializer="custom_serializer",
)
VCR.register_serializer("custom_serializer", CustomSerializer)
VCR.register_persister(CustomPersister)
vcr.register_serializer("custom_serializer", CustomSerializer)
vcr.register_persister(CustomPersister)


def after_init(func, *args):
Expand Down
12 changes: 8 additions & 4 deletions tests/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from asyncpraw import Reddit
from tests.conftest import VCR
from tests.conftest import vcr


class IntegrationTest(asynctest.TestCase):
Expand All @@ -26,7 +26,7 @@ async def tearDown(self) -> None:

def setup_vcr(self):
"""Configure VCR instance."""
self.recorder = VCR
self.recorder = vcr

# Disable response compression in order to see the response bodies in
# the VCR cassettes.
Expand Down Expand Up @@ -80,8 +80,7 @@ async def async_list(async_generator):
@staticmethod
async def async_next(async_generator):
"""Return the next item from an async iterator."""
async for item in async_generator:
return item
return await async_generator.__anext__()

def use_cassette(self, cassette_name=None, **kwargs):
"""Use a cassette. The cassette name is dynamically generated.
Expand All @@ -103,6 +102,11 @@ def use_cassette(self, cassette_name=None, **kwargs):
f"Dynamic cassette name for function {dynamic_name} does not match"
f" the provided cassette name: {cassette_name}"
)
match_on = kwargs.get(
"match_requests_on", None
) # keep interface same as in PRAW
if match_on:
kwargs["match_on"] = kwargs.pop("match_requests_on")
return self.recorder.use_cassette(cassette_name or dynamic_name, **kwargs)

def get_cassette_name(self) -> str:
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/integration/models/reddit/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class TestSubredditCollections(IntegrationTest):
async def test_call(self, _):
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
collection = await self.async_next(subreddit.collections)
collection = next(iter(await self.async_list(subreddit.collections)))
test_collection = await subreddit.collections(collection.collection_id)
assert collection == test_collection
test_collection = await subreddit.collections(
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/models/reddit/test_more.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def test_comments(self):
"cu5pbdh",
],
}
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
more = MoreComments(self.reddit, data)
more.submission = await self.reddit.submission("3hahrw")
comments = await more.comments()
Expand All @@ -41,7 +41,7 @@ async def test_comments__continue_thread_type(self):
"parent_id": "t1_cu5v5h7",
"children": [],
}
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
more = MoreComments(self.reddit, data)
more.submission = await self.reddit.submission("3hahrw")
comments = await more.comments()
Expand Down
33 changes: 14 additions & 19 deletions tests/integration/models/reddit/test_subreddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,15 +1118,15 @@ async def test_delete(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.delete(template["id"])

@mock.patch("asyncio.sleep", return_value=None)
async def test_update(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(
template["id"],
"PRAW updated",
Expand Down Expand Up @@ -1157,7 +1157,7 @@ async def test_update_fetch(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(
template["id"],
"PRAW updated",
Expand All @@ -1172,7 +1172,7 @@ async def test_update_fetch_no_css_class(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(
template["id"],
"PRAW updated",
Expand All @@ -1186,7 +1186,7 @@ async def test_update_fetch_no_text(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(
template["id"],
css_class="myCSS",
Expand All @@ -1200,7 +1200,7 @@ async def test_update_fetch_no_text_or_css_class(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(
template["id"],
text_color="dark",
Expand All @@ -1213,7 +1213,7 @@ async def test_update_fetch_only(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(template["id"], fetch=True)
newtemplate = list(
filter(
Expand All @@ -1228,7 +1228,7 @@ async def test_update_false(self, _):
self.reddit.read_only = False
with self.use_cassette():
subreddit = await self.reddit.subreddit(pytest.placeholders.test_subreddit)
template = await self.async_next(subreddit.flair.templates)
template = next(iter(await self.async_list(subreddit.flair.templates)))
await subreddit.flair.templates.update(
template["id"], text_editable=True, fetch=True
)
Expand Down Expand Up @@ -1844,18 +1844,13 @@ async def test_comments__with_pause(self, _):

@mock.patch("asyncio.sleep", return_value=None)
async def test_comments__with_skip_existing(self, _):
with self.use_cassette("TestSubredditStreams.test_comments__with_pause"):
with self.use_cassette():
subreddit = await self.reddit.subreddit("askreddit")
generator = subreddit.stream.comments(skip_existing=True)
count = 0
try:
async for comment in generator:
count += 1
except TypeError:
pass
# This test uses the same cassette as test_comments which shows
# that there are at least 100 comments in the stream.
assert count < 102
generator = subreddit.stream.comments(skip_existing=True, pause_after=-1)
comment = await self.async_next(generator)
assert comment is None
comment = await self.async_next(generator)
assert isinstance(comment, Comment)

@mock.patch("asyncio.sleep", return_value=None)
async def test_submissions(self, _):
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/models/test_comment_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def setUp(self):
self.reddit._core._requestor._http._default_headers["Accept-Encoding"] = "gzip"

async def test_replace__all(self):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.reddit.submission("3hahrw")
comments = await submission.comments()
before_count = len(await comments.list())
Expand All @@ -26,7 +26,7 @@ async def test_replace__all(self):
assert before_count < len(await comments.list())

async def test_replace__all_large(self):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = Submission(self.reddit, "n49rw")
comments = await submission.comments()
skipped = await comments.replace_more(None, threshold=0)
Expand All @@ -36,7 +36,7 @@ async def test_replace__all_large(self):
assert len(await comments.list()) == len(submission._comments_by_id)

async def test_replace__all_with_comment_limit(self):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.reddit.submission("3hahrw")
submission.comment_limit = 10
comments = await submission.comments()
Expand All @@ -46,7 +46,7 @@ async def test_replace__all_with_comment_limit(self):

@mock.patch("asyncio.sleep", return_value=None)
async def test_replace__all_with_comment_sort(self, _):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.reddit.submission("3hahrw")
submission.comment_sort = "old"
comments = await submission.comments()
Expand All @@ -55,14 +55,14 @@ async def test_replace__all_with_comment_sort(self, _):
assert len(await comments.list()) >= 500

async def test_replace__skip_at_limit(self):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.reddit.submission("3hahrw")
comments = await submission.comments()
skipped = await comments.replace_more(1)
assert len(skipped) == 5

# async def test_replace__skip_below_threshold(self): # FIXME: not currently working; same with praw
# with self.use_cassette(match_requests_on=["uri", "method", "body"]):
# with self.use_cassette():
# submission = Submission(self.reddit, "hkwbo0")
# comments = await submission.comments()
# before_count = len(await comments.list())
Expand All @@ -73,7 +73,7 @@ async def test_replace__skip_at_limit(self):
# assert before_count < len(await comments.list())

async def test_replace__skip_all(self):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.reddit.submission("3hahrw")
comments = await submission.comments()
before_count = len(await comments.list())
Expand All @@ -85,7 +85,7 @@ async def test_replace__skip_all(self):

@mock.patch("asyncio.sleep", return_value=None)
async def test_replace__on_comment_from_submission(self, _):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.reddit.submission("3hahrw")
comments = await submission.comments()
types = [type(x) for x in await comments.list()]
Expand All @@ -100,7 +100,7 @@ async def test_replace__on_comment_from_submission(self, _):

@mock.patch("asyncio.sleep", return_value=None)
async def test_replace__on_direct_comment(self, _):
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
comment = await self.reddit.comment("d8r4im1")
await comment.refresh()
assert any(
Expand All @@ -112,7 +112,7 @@ async def test_replace__on_direct_comment(self, _):
@mock.patch("asyncio.sleep", return_value=None)
async def test_comment_forest_refresh_error(self, _):
self.reddit.read_only = False
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
submission = await self.async_next(self.reddit.front.top())
# await submission._fetch()
submission.comment_limit = 1
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/models/test_inbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ async def test_mark_all_read(self, _):
@mock.patch("asyncio.sleep", return_value=None)
async def test_mark_read(self, _):
self.reddit.read_only = False
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
await self.reddit.inbox.mark_read(
await self.async_list(self.reddit.inbox.unread())
)

@mock.patch("asyncio.sleep", return_value=None)
async def test_mark_unread(self, _):
self.reddit.read_only = False
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
await self.reddit.inbox.mark_unread(
await self.async_list(self.reddit.inbox.all())
)
Expand Down Expand Up @@ -107,15 +107,15 @@ async def test_message__unauthorized(self):
@mock.patch("asyncio.sleep", return_value=None)
async def test_message_collapse(self, _):
self.reddit.read_only = False
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
await self.reddit.inbox.collapse(
await self.async_list(self.reddit.inbox.messages())
)

@mock.patch("asyncio.sleep", return_value=None)
async def test_message_uncollapse(self, _):
self.reddit.read_only = False
with self.use_cassette(match_requests_on=["uri", "method", "body"]):
with self.use_cassette():
await self.reddit.inbox.uncollapse(
await self.async_list(self.reddit.inbox.messages())
)
Expand Down

0 comments on commit eecc268

Please sign in to comment.