Skip to content

Commit

Permalink
unittest.mock: use ParamSpec in patch (#10325)
Browse files Browse the repository at this point in the history
Fixes #10324
  • Loading branch information
hauntsaninja committed Jun 20, 2023
1 parent 7114aec commit 9e86c60
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
5 changes: 3 additions & 2 deletions stdlib/unittest/mock.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping, S
from contextlib import _GeneratorContextManager
from types import TracebackType
from typing import Any, Generic, TypeVar, overload
from typing_extensions import Final, Literal, Self, TypeAlias
from typing_extensions import Final, Literal, ParamSpec, Self, TypeAlias

_T = TypeVar("_T")
_TT = TypeVar("_TT", bound=type[Any])
_R = TypeVar("_R")
_F = TypeVar("_F", bound=Callable[..., Any])
_AF = TypeVar("_AF", bound=Callable[..., Coroutine[Any, Any, Any]])
_P = ParamSpec("_P")

if sys.version_info >= (3, 8):
__all__ = (
Expand Down Expand Up @@ -234,7 +235,7 @@ class _patch(Generic[_T]):
@overload
def __call__(self, func: _TT) -> _TT: ...
@overload
def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ...
def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ...
if sys.version_info >= (3, 8):
def decoration_helper(
self, patched: _patch[Any], args: Sequence[Any], keywargs: Any
Expand Down
29 changes: 29 additions & 0 deletions test_cases/stdlib/check_unittest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import unittest
from collections.abc import Callable
from datetime import datetime, timedelta
from decimal import Decimal
from fractions import Fraction
from typing_extensions import assert_type
from unittest.mock import Mock, patch

case = unittest.TestCase()

Expand Down Expand Up @@ -86,3 +89,29 @@ def __gt__(self, other: Bacon) -> bool:
case.assertGreater(Spam(), Eggs()) # type: ignore
case.assertGreater(Ham(), Bacon()) # type: ignore
case.assertGreater(Bacon(), Ham()) # type: ignore

###
# Tests for mock.patch
###


@patch("sys.exit", new=Mock())
def f(i: int) -> str:
return "asdf"


assert_type(f(1), str)
f("a") # type: ignore


@patch("sys.exit", new=Mock())
class TestXYZ(unittest.TestCase):
attr: int = 5

@staticmethod
def method() -> int:
return 123


assert_type(TestXYZ.attr, int)
assert_type(TestXYZ.method, Callable[[], int])

Check failure on line 117 in test_cases/stdlib/check_unittest.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.10)

"assert_type" mismatch: expected "() -> int" but received "() -> int" (reportGeneralTypeIssues)
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ def testcase_dir_from_package_name(package_name: str) -> Path:


def get_all_testcase_directories() -> list[PackageInfo]:
testcase_directories = [PackageInfo("stdlib", Path("test_cases"))]
testcase_directories: list[PackageInfo] = []
for package_name in os.listdir("stubs"):
potential_testcase_dir = testcase_dir_from_package_name(package_name)
if potential_testcase_dir.is_dir():
testcase_directories.append(PackageInfo(package_name, potential_testcase_dir))
return sorted(testcase_directories)
return [PackageInfo("stdlib", Path("test_cases"))] + sorted(testcase_directories)


# ====================================================================
Expand Down

0 comments on commit 9e86c60

Please sign in to comment.