Skip to content

Commit

Permalink
Cleanup tests and fully migrate from nose to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
LilSpazJoekp committed Dec 5, 2022
1 parent 16fabf8 commit 553520e
Show file tree
Hide file tree
Showing 42 changed files with 5,689 additions and 6,875 deletions.
17 changes: 7 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""asyncpraw setup.py"""
import os
import re
from codecs import open
from os import path
Expand All @@ -24,15 +23,13 @@
"sphinxcontrib-trio",
],
"test": [
"asynctest >=0.13.0 ; python_version < '3.8'",
"mock >=0.8",
"pytest ==7.2.*",
"pytest-asyncio",
"pytest-vcr",
"testfixtures >4.13.2, <7",
"vcrpy >=4.1.1"
if os.getenv("PYPI_UPLOAD", False)
else "vcrpy@git+https://github.com/kevin1024/vcrpy.git@b1bc5c3a02db0447c28ab9a4cee314aeb6cdf1a7",
"asynctest ==0.13.* ; python_version < '3.8'", # TODO: Remove me when support for 3.7 is dropped
"mock ==4.*",
"pytest ==7.*",
"pytest-asyncio ==0.18.*",
"pytest-vcr ==1.*",
"testfixtures ==6.*",
"vcrpy ==4.*",
],
}
extras["lint"] += extras["readthedocs"]
Expand Down
186 changes: 29 additions & 157 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,158 +1,15 @@
"""Prepare py.test."""
import asyncio
import json
import os
from base64 import b64encode
from datetime import datetime
from functools import wraps

import pytest
from _pytest.tmpdir import _mk_tmp
from vcr import VCR
from vcr.cassette import Cassette
from vcr.persisters.filesystem import FilesystemPersister
from vcr.serialize import deserialize, serialize


# Prevent calls to sleep
async def _sleep(*args):
raise Exception("Call to sleep")


asyncio.sleep = _sleep


def b64_string(input_string):
"""Return a base64 encoded string (not bytes) from input_string."""
return b64encode(input_string.encode("utf-8")).decode("utf-8")


def env_default(key):
"""Return environment variable or placeholder string."""
return os.environ.get(f"prawtest_{key}", f"placeholder_{key}")


def filter_access_token(response):
"""Add VCR callback to filter access token."""
request_uri = response["url"]
if "api/v1/access_token" not in request_uri or response["status"]["code"] != 200:
return response
body = response["body"]["string"].decode()
try:
token = json.loads(body)["access_token"]
response["body"]["string"] = response["body"]["string"].replace(
token.encode("utf-8"), b"<ACCESS_TOKEN>"
)
placeholders["access_token"] = token
except (KeyError, TypeError, ValueError):
pass
return response


def serialize_dict(data: dict):
"""This is to filter out buffered readers."""
new_dict = {}
for key, value in data.items():
if key == "file":
new_dict[key] = serialize_file(value.name)
elif isinstance(value, dict):
new_dict[key] = serialize_dict(value)
elif isinstance(value, list):
new_dict[key] = serialize_list(value)
else:
new_dict[key] = value
return new_dict


def serialize_file(file_name):
with open(file_name, "rb") as f:
return f.read().decode("utf-8", "replace")


def serialize_list(data: list):
new_list = []
for item in data:
if isinstance(item, dict):
new_list.append(serialize_dict(item))
elif isinstance(item, list):
new_list.append(serialize_list(item))
elif isinstance(item, tuple):
if item[0] == "file":
item = (item[0], serialize_file(item[1].name))
new_list.append(item)
else:
new_list.append(item)
return new_list


placeholders = {
x: env_default(x)
for x in (
"auth_code client_id client_secret password redirect_uri test_subreddit"
" user_agent username refresh_token"
).split()
}

placeholders["basic_auth"] = b64_string(
f"{placeholders['client_id']}:{placeholders['client_secret']}"
)


class CustomPersister(FilesystemPersister):
@classmethod
def load_cassette(cls, cassette_path, serializer):
try:
with open(cassette_path) as f:
cassette_content = f.read()
except OSError:
raise ValueError("Cassette not found.")
for replacement, value in [
(v, f"<{k.upper()}>") for k, v in placeholders.items()
]:
cassette_content = cassette_content.replace(value, replacement)
cassette = deserialize(cassette_content, serializer)
return cassette

@staticmethod
def save_cassette(cassette_path, cassette_dict, serializer):
data = serialize(cassette_dict, serializer)
for replacement, value in [
(f"<{k.upper()}>", v) for k, v in placeholders.items()
]:
data = data.replace(value, replacement)
dirname, filename = os.path.split(cassette_path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
with open(cassette_path, "w") as f:
f.write(data)


class CustomSerializer(object):
@staticmethod
def serialize(cassette_dict):
cassette_dict["recorded_at"] = datetime.now().isoformat()[:-7]
return (
f"{json.dumps(serialize_dict(cassette_dict), sort_keys=True, indent=2)}\n"
)

@staticmethod
def deserialize(cassette_string):
return json.loads(cassette_string)


vcr = VCR(
before_record_response=filter_access_token,
cassette_library_dir="tests/integration/cassettes",
match_on=["uri", "method"],
path_transformer=VCR.ensure_suffix(".json"),
serializer="custom_serializer",
)
vcr.register_serializer("custom_serializer", CustomSerializer)
vcr.register_persister(CustomPersister)


def after_init(func, *args):
func(*args)
class Placeholders:
def __init__(self, _dict):
self.__dict__ = _dict


def add_init_hook(original_init):
Expand All @@ -166,25 +23,40 @@ def wrapper(self, *args, **kwargs):
return wrapper


Cassette.__init__ = add_init_hook(Cassette.__init__)
def after_init(func, *args):
func(*args)


def init_hook(cassette):
if not cassette.requests:
pytest.set_up_record() # dynamically defined in __init__.py


class Placeholders:
def __init__(self, _dict):
self.__dict__ = _dict


def pytest_configure():
pytest.placeholders = Placeholders(placeholders)


@pytest.fixture
def tmp_path(request, tmp_path_factory):
# Manually create tmp_path fixture since asynctest does not play nicely with
# fixtures as args
request.cls.tmp_path = _mk_tmp(request, tmp_path_factory)
def pytest_load_initial_conftests(early_config, parser, args):
early_config.addinivalue_line(
"markers", "cassette_name: Name of cassette to use for test."
)
early_config.addinivalue_line(
"markers", "recorder_args: Arguments to pass to the recorder."
)


Cassette.__init__ = add_init_hook(Cassette.__init__)

os.environ["praw_check_for_updates"] = "False"

placeholders = {
x: os.environ.get(f"prawtest_{x}", f"placeholder_{x}")
for x in (
"auth_code client_id client_secret password redirect_uri refresh_token"
" test_subreddit user_agent username"
).split()
}

placeholders["basic_auth"] = b64encode(
f"{placeholders['client_id']}:{placeholders['client_secret']}".encode("utf-8")
).decode("utf-8")
Loading

0 comments on commit 553520e

Please sign in to comment.