Skip to content

Commit

Permalink
Merge pull request #646 from JP-Ellis/fix/decorator-type-hinting
Browse files Browse the repository at this point in the history
fix(typing): improve decorator type hinting
  • Loading branch information
youtux committed Dec 2, 2023
2 parents 206530b + 9c60589 commit 8a694ff
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 30 deletions.
23 changes: 15 additions & 8 deletions src/pytest_bdd/plugin.py
@@ -1,16 +1,15 @@
"""Pytest plugin entry point. Used for any fixtures needed."""
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar, cast

import pytest
from typing_extensions import ParamSpec

from . import cucumber_json, generation, gherkin_terminal_reporter, given, reporting, then, when
from .utils import CONFIG_STACK

if TYPE_CHECKING:
from typing import Any, Generator

from _pytest.config import Config, PytestPluginManager
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
Expand All @@ -21,6 +20,10 @@
from .parser import Feature, Scenario, Step


P = ParamSpec("P")
T = TypeVar("T")


def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Register plugin hooks."""
from pytest_bdd import hooks
Expand Down Expand Up @@ -94,7 +97,7 @@ def pytest_bdd_step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict,
exception: Exception,
) -> None:
Expand All @@ -103,7 +106,11 @@ def pytest_bdd_step_error(

@pytest.hookimpl(tryfirst=True)
def pytest_bdd_before_step(
request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable
request: FixtureRequest,
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
) -> None:
reporting.before_step(request, feature, scenario, step, step_func)

Expand All @@ -114,7 +121,7 @@ def pytest_bdd_after_step(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict[str, Any],
) -> None:
reporting.after_step(request, feature, scenario, step, step_func, step_func_args)
Expand All @@ -124,7 +131,7 @@ def pytest_cmdline_main(config: Config) -> int | None:
return generation.cmdline_main(config)


def pytest_bdd_apply_tag(tag: str, function: Callable) -> Callable:
def pytest_bdd_apply_tag(tag: str, function: Callable[P, T]) -> Callable[P, T]:
mark = getattr(pytest.mark, tag)
marked = mark(function)
return cast(Callable, marked)
return cast(Callable[P, T], marked)
10 changes: 8 additions & 2 deletions src/pytest_bdd/reporting.py
Expand Up @@ -155,15 +155,21 @@ def step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict,
exception: Exception,
) -> None:
"""Finalize the step report as failed."""
request.node.__scenario_report__.fail()


def before_step(request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable) -> None:
def before_step(
request: FixtureRequest,
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
) -> None:
"""Store step start time."""
request.node.__scenario_report__.add_step_report(StepReport(step=step))

Expand Down
20 changes: 12 additions & 8 deletions src/pytest_bdd/scenario.py
Expand Up @@ -16,24 +16,25 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Callable, Iterator, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast

import pytest
from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func
from _pytest.nodes import iterparentnodeids
from typing_extensions import ParamSpec

from . import exceptions
from .feature import get_feature, get_features
from .steps import StepFunctionContext, get_step_fixture_name, inject_fixture
from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path

if TYPE_CHECKING:
from typing import Any, Iterable

from _pytest.mark.structures import ParameterSet

from .parser import Feature, Scenario, ScenarioTemplate, Step

P = ParamSpec("P")
T = TypeVar("T")

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -197,14 +198,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ

def _get_scenario_decorator(
feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str
) -> Callable[[Callable], Callable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
# HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception
# when the decorator is misused.
# Pytest inspect the signature to determine the required fixtures, and in that case it would look
# for a fixture called "fn" that doesn't exist (if it exists then it's even worse).
# It will error with a "fixture 'fn' not found" message instead.
# We can avoid this hack by using a pytest hook and check for misuse instead.
def decorator(*args: Callable) -> Callable:
def decorator(*args: Callable[P, T]) -> Callable[P, T]:
if not args:
raise exceptions.ScenarioIsDecoratorOnly(
"scenario function can only be used as a decorator. Refer to the documentation."
Expand Down Expand Up @@ -236,7 +237,7 @@ def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str

scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}"
scenario_wrapper.__scenario__ = templated_scenario
return cast(Callable, scenario_wrapper)
return cast(Callable[P, T], scenario_wrapper)

return decorator

Expand All @@ -251,8 +252,11 @@ def collect_example_parametrizations(


def scenario(
feature_name: str, scenario_name: str, encoding: str = "utf-8", features_base_dir=None
) -> Callable[[Callable], Callable]:
feature_name: str,
scenario_name: str,
encoding: str = "utf-8",
features_base_dir: str | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Scenario decorator.
:param str feature_name: Feature file name. Absolute or relative to the configured feature base path.
Expand Down
24 changes: 13 additions & 11 deletions src/pytest_bdd/steps.py
Expand Up @@ -43,13 +43,15 @@ def _(article):

import pytest
from _pytest.fixtures import FixtureDef, FixtureRequest
from typing_extensions import ParamSpec

from .parser import Step
from .parsers import StepParser, get_parser
from .types import GIVEN, THEN, WHEN
from .utils import get_caller_module_locals

TCallable = TypeVar("TCallable", bound=Callable[..., Any])
P = ParamSpec("P")
T = TypeVar("T")


@enum.unique
Expand All @@ -63,7 +65,7 @@ class StepFunctionContext:
type: Literal["given", "when", "then"] | None
step_func: Callable[..., Any]
parser: StepParser
converters: dict[str, Callable[..., Any]] = field(default_factory=dict)
converters: dict[str, Callable[[str], Any]] = field(default_factory=dict)
target_fixture: str | None = None


Expand All @@ -74,10 +76,10 @@ def get_step_fixture_name(step: Step) -> str:

def given(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Given step decorator.
:param name: Step name or a parser object.
Expand All @@ -93,10 +95,10 @@ def given(

def when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""When step decorator.
:param name: Step name or a parser object.
Expand All @@ -112,10 +114,10 @@ def when(

def then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Then step decorator.
:param name: Step name or a parser object.
Expand All @@ -132,10 +134,10 @@ def then(
def step(
name: str | StepParser,
type_: Literal["given", "when", "then"] | None = None,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[TCallable], TCallable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Generic step decorator.
:param name: Step name as in the feature file.
Expand All @@ -155,7 +157,7 @@ def step(
if converters is None:
converters = {}

def decorator(func: TCallable) -> TCallable:
def decorator(func: Callable[P, T]) -> Callable[P, T]:
parser = get_parser(name)

context = StepFunctionContext(
Expand Down
2 changes: 1 addition & 1 deletion src/pytest_bdd/utils.py
Expand Up @@ -19,7 +19,7 @@
CONFIG_STACK: list[Config] = []


def get_args(func: Callable) -> list[str]:
def get_args(func: Callable[..., Any]) -> list[str]:
"""Get a list of argument names for a function.
:param func: The function to inspect.
Expand Down

0 comments on commit 8a694ff

Please sign in to comment.