Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Changelog
=========


- :release:`10.3.0 <29th August 2023>`
- :support:`193` Migrate testing helpers from bot repo.


- :release:`10.2.0 <28th August 2023>`
- :support:`192` Bump Discord.py to :literal-url:`2.3.2 <https://github.com/Rapptz/discord.py/releases/tag/v2.3.2>`.

Expand Down
8 changes: 8 additions & 0 deletions pydis_core/utils/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydis_core.utils.tests import base, helpers

__all__ = [
base,
helpers
]

__all__ = [module.__name__ for module in __all__]
65 changes: 65 additions & 0 deletions pydis_core/utils/tests/_autospec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import contextlib
import functools
import pkgutil
import unittest.mock
from collections.abc import Callable


@functools.wraps(unittest.mock._patch.decoration_helper)
@contextlib.contextmanager
def _decoration_helper(self, patched, args, keywargs):
"""Skips adding patchings as args if their `dont_pass` attribute is True."""
# Don't ask what this does. It's just a copy from stdlib, but with the dont_pass check added.
extra_args = []
with contextlib.ExitStack() as exit_stack:
for patching in patched.patchings:
arg = exit_stack.enter_context(patching)
if not getattr(patching, "dont_pass", False):
# Only add the patching as an arg if dont_pass is False.
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is unittest.mock.DEFAULT:
extra_args.append(arg)

args += tuple(extra_args)
yield args, keywargs


@functools.wraps(unittest.mock._patch.copy)
def _copy(self):
"""Copy the `dont_pass` attribute along with the standard copy operation."""
patcher_copy = _copy.original(self)
patcher_copy.dont_pass = getattr(self, "dont_pass", False)
return patcher_copy


# Monkey-patch the patcher class :)
_copy.original = unittest.mock._patch.copy
unittest.mock._patch.copy = _copy
unittest.mock._patch.decoration_helper = _decoration_helper


def autospec(target, *attributes: str, pass_mocks: bool = True, **patch_kwargs) -> Callable:
"""
Patch multiple `attributes` of a `target` with autospecced mocks and `spec_set` as True.

If `pass_mocks` is True, pass the autospecced mocks as arguments to the decorated object.
"""
# Caller's kwargs should take priority and overwrite the defaults.
kwargs = dict(spec_set=True, autospec=True)
kwargs.update(patch_kwargs)

# Import the target if it's a string.
# This is to support both object and string targets like patch.multiple.
if isinstance(target, str):
target = pkgutil.resolve_name(target)

def decorator(func):
for attribute in attributes:
patcher = unittest.mock.patch.object(target, attribute, **kwargs)
if not pass_mocks:
# A custom attribute to keep track of which patchings should be skipped.
patcher.dont_pass = True
func = patcher(func)
return func
return decorator
136 changes: 136 additions & 0 deletions pydis_core/utils/tests/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import logging
import unittest
import warnings
from contextlib import contextmanager

import discord
from discord.ext import commands

from pydis_core.utils.logging import get_logger

from . import helpers

try:
from async_rediscache import RedisSession
REDIS_AVAILABLE = True
except ImportError:
RedisSession = object
REDIS_AVAILABLE = False

class _CaptureLogHandler(logging.Handler):
"""A logging handler capturing all (raw and formatted) logging output."""

def __init__(self):
super().__init__()
self.records = []

def emit(self, record):
self.records.append(record)


class LoggingTestsMixin:
"""
A mixin that defines additional test methods for logging behavior.

This mixin relies on the availability of the `fail` attribute defined by the
test classes included in Python's unittest method to signal test failure.
"""

@contextmanager
def assertNotLogs(self, logger=None, level=None, msg=None): # noqa: N802
"""
Asserts that no logs of `level` and higher were emitted by `logger`.

You can specify a specific `logger`, the minimum `logging` level we want to watch and a
custom `msg` to be added to the `AssertionError` if thrown. If the assertion fails, the
recorded log records will be outputted with the `AssertionError` message. The context
manager does not yield a live `look` into the logging records, since we use this context
manager when we're testing under the assumption that no log records will be emitted.
"""
if not isinstance(logger, logging.Logger):
logger = get_logger(logger)

if level:
level = logging._nameToLevel.get(level, level)
else:
level = logging.INFO

handler = _CaptureLogHandler()
old_handlers = logger.handlers[:]
old_level = logger.level
old_propagate = logger.propagate

logger.handlers = [handler]
logger.setLevel(level)
logger.propagate = False

try:
yield
except Exception as exc:
raise exc
finally:
logger.handlers = old_handlers
logger.propagate = old_propagate
logger.setLevel(old_level)

if handler.records:
level_name = logging.getLevelName(level)
n_logs = len(handler.records)
base_message = f"{n_logs} logs of {level_name} or higher were triggered on {logger.name}:\n"
records = [str(record) for record in handler.records]
record_message = "\n".join(records)
standard_message = self._truncateMessage(base_message, record_message)
msg = self._formatMessage(msg, standard_message)
self.fail(msg)


class CommandTestCase(unittest.IsolatedAsyncioTestCase):
"""TestCase with additional assertions that are useful for testing Discord commands."""

async def assertHasPermissionsCheck( # noqa: N802
self,
cmd: commands.Command,
permissions: dict[str, bool],
) -> None:
"""
Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`.

Every permission in `permissions` is expected to be reported as missing. In other words, do
not include permissions which should not raise an exception along with those which should.
"""
# Invert permission values because it's more intuitive to pass to this assertion the same
# permissions as those given to the check decorator.
permissions = {k: not v for k, v in permissions.items()}

ctx = helpers.MockContext()
ctx.channel.permissions_for.return_value = discord.Permissions(**permissions)

with self.assertRaises(commands.MissingPermissions) as cm:
await cmd.can_run(ctx)

self.assertCountEqual(permissions.keys(), cm.exception.missing_permissions)


class RedisTestCase(unittest.IsolatedAsyncioTestCase):
"""
Use this as a base class for any test cases that require a redis session.

This will prepare a fresh redis instance for each test function, and will
not make any assertions on its own. Tests can mutate the instance as they wish.
"""

session = None

async def flush(self):
"""Flush everything from the redis database to prevent carry-overs between tests."""
await self.session.client.flushall()

async def asyncSetUp(self):
if not REDIS_AVAILABLE:
warnings.warn("redis_session kwarg passed, but async-rediscache not installed!", stacklevel=2)
self.session = await RedisSession(use_fakeredis=True).connect()
await self.flush()

async def asyncTearDown(self):
if self.session:
await self.session.client.close()
Loading