Skip to content

Commit

Permalink
More type annotations (#845
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelboulton committed Feb 10, 2023
1 parent 334ad88 commit 924463f
Show file tree
Hide file tree
Showing 34 changed files with 381 additions and 305 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ ignore-imports = true

[tool.pylint.TYPECHECK]
ignored-classes = "RememberComposer"
ignored-modules = "distutils"

[tool.flake8]
ignore = ["E501", "W503", "C901", "W504"]
Expand Down
25 changes: 10 additions & 15 deletions tavern/_core/dict_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import collections.abc
import logging
import os
import re
import string
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Union

import box
from box import Box
from box.box import Box
import jmespath

from tavern._core import exceptions
Expand All @@ -19,12 +18,12 @@
)

from .formatted_str import FormattedString
from .strict_util import StrictSetting, extract_strict_setting
from .strict_util import StrictSetting, StrictSettingKinds, extract_strict_setting

logger = logging.getLogger(__name__)


def _check_and_format_values(to_format, box_vars):
def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str:
formatter = string.Formatter()
would_format = formatter.parse(to_format)

Expand Down Expand Up @@ -91,7 +90,7 @@ def _attempt_find_include(to_format: str, box_vars: box.Box):
return formatter.convert_field(would_replace, conversion) # type: ignore


def format_keys(val, variables, no_double_format=True):
def format_keys(val, variables: Mapping, no_double_format: bool = True):
"""recursively format a dictionary with the given values
Args:
Expand Down Expand Up @@ -138,7 +137,7 @@ def format_keys(val, variables, no_double_format=True):
return formatted


def recurse_access_key(data, query):
def recurse_access_key(data, query: str):
"""
Search for something in the given data using the given query.
Expand Down Expand Up @@ -225,7 +224,7 @@ def _deprecated_recurse_access_key(current_val, keys):
raise


def deep_dict_merge(initial_dct: Dict, merge_dct: collections.abc.Mapping) -> dict:
def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict:
"""Recursive dict merge. Instead of updating only top-level keys,
dict_merge recurses down into dicts nested to an arbitrary depth
and returns the merged dict. Keys values present in merge_dct take
Expand All @@ -242,19 +241,15 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: collections.abc.Mapping) -> di
dct = initial_dct.copy()

for k in merge_dct:
if (
k in dct
and isinstance(dct[k], dict)
and isinstance(merge_dct[k], collections.abc.Mapping)
):
if k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], Mapping):
dct[k] = deep_dict_merge(dct[k], merge_dct[k])
else:
dct[k] = merge_dct[k]

return dct


def check_expected_keys(expected, actual):
def check_expected_keys(expected, actual) -> None:
"""Check that a set of expected keys is a superset of the actual keys
Args:
Expand Down Expand Up @@ -328,7 +323,7 @@ def check_keys_match_recursive(
expected_val: Any,
actual_val: Any,
keys: List[Union[str, int]],
strict: Optional[Union[StrictSetting, bool]] = True,
strict: StrictSettingKinds = True,
) -> None:
"""Utility to recursively check response values
Expand Down
4 changes: 2 additions & 2 deletions tavern/_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class BadSchemaError(TavernException):
class TestFailError(TavernException):
"""Test failed somehow"""

def __init__(self, msg, failures=None):
def __init__(self, msg, failures=None) -> None:
super().__init__(msg)
self.failures = failures or []

Expand Down Expand Up @@ -117,7 +117,7 @@ class InvalidFormattedJsonError(TavernException):
class InvalidExtBlockException(TavernException):
"""Tried to use the '$ext' block in a place it is no longer valid to use it"""

def __init__(self, block):
def __init__(self, block) -> None:
super().__init__(
"$ext function found in block {} - this has been moved to verify_response_with block - see documentation".format(
block
Expand Down
15 changes: 7 additions & 8 deletions tavern/_core/extfunctions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import collections.abc
import functools
import importlib
import logging
from typing import Any, List
from typing import Any, List, Mapping, Optional

from tavern._core import exceptions

Expand All @@ -22,7 +21,7 @@ def is_ext_function(block: Any) -> bool:
return isinstance(block, dict) and block.get("$ext", None) is not None


def get_pykwalify_logger(module):
def get_pykwalify_logger(module: Optional[str]) -> logging.Logger:
"""Get logger for this module
Have to do it like this because the way that pykwalify load extension
Expand All @@ -36,7 +35,7 @@ def get_pykwalify_logger(module):
return logging.getLogger(module)


def _getlogger():
def _getlogger() -> logging.Logger:
return get_pykwalify_logger("tavern._core.extfunctions")


Expand Down Expand Up @@ -80,7 +79,7 @@ def import_ext_function(entrypoint: str):
return function


def get_wrapped_response_function(ext: collections.abc.Mapping):
def get_wrapped_response_function(ext: Mapping):
"""Wraps a ext function with arguments given in the test file
This is similar to functools.wrap, but this makes sure that 'response' is
Expand All @@ -107,7 +106,7 @@ def inner(response):
return inner


def get_wrapped_create_function(ext: collections.abc.Mapping):
def get_wrapped_create_function(ext: Mapping):
"""Same as get_wrapped_response_function, but don't require a response"""

func, args, kwargs = _get_ext_values(ext)
Expand All @@ -123,7 +122,7 @@ def inner():
return inner


def _get_ext_values(ext: collections.abc.Mapping):
def _get_ext_values(ext: Mapping):
args = ext.get("extra_args") or ()
kwargs = ext.get("extra_kwargs") or {}
try:
Expand All @@ -136,7 +135,7 @@ def _get_ext_values(ext: collections.abc.Mapping):
return func, args, kwargs


def update_from_ext(request_args: dict, keys_to_check: List[str]):
def update_from_ext(request_args: dict, keys_to_check: List[str]) -> None:
"""
Updates the request_args dict with any values from external functions
Expand Down
17 changes: 10 additions & 7 deletions tavern/_core/jmesutils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import operator
import re
from typing import Any, Dict, List, Sized

from tavern._core import exceptions


def test_type(val, mytype):
def test_type(val, mytype) -> bool:
"""Check value fits one of the types, if so return true, else false"""
typelist = TYPES.get(str(mytype).lower())
if typelist is None:
Expand All @@ -13,11 +14,11 @@ def test_type(val, mytype):
)
try:
for testtype in typelist:
if isinstance(val, testtype):
if isinstance(val, testtype): # type: ignore
return True
return False
except TypeError:
return isinstance(val, typelist)
return isinstance(val, typelist) # type: ignore


COMPARATORS = {
Expand All @@ -36,7 +37,7 @@ def test_type(val, mytype):
"regex": lambda x, y: regex_compare(str(x), str(y)),
"type": test_type,
}
TYPES = {
TYPES: Dict[str, List[Any]] = {
"none": [type(None)],
"number": [int, float],
"int": [int],
Expand All @@ -48,11 +49,11 @@ def test_type(val, mytype):
}


def regex_compare(_input, regex):
def regex_compare(_input, regex) -> bool:
return bool(re.search(regex, _input))


def safe_length(var):
def safe_length(var: Sized) -> int:
"""Exception-safe length check, returns -1 if no length on type or error"""
try:
return len(var)
Expand Down Expand Up @@ -82,7 +83,9 @@ def validate_comparison(each_comparison):
return jmespath, _operator, expected


def actual_validation(_operator, _actual, expected, _expression, expression):
def actual_validation(
_operator: str, _actual, expected, _expression, expression
) -> None:
if not COMPARATORS[_operator](_actual, expected):
raise exceptions.JMESError(
"Validation '{}' ({}) failed!".format(expression, _expression)
Expand Down
4 changes: 2 additions & 2 deletions tavern/_core/loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# https://gist.github.com/joshbode/569627ced3076931b02f
from abc import abstractmethod
import dataclasses
from distutils.util import strtobool # pylint: disable=deprecated-module
from itertools import chain
import logging
import os.path
Expand All @@ -19,6 +18,7 @@

from tavern._core import exceptions
from tavern._core.exceptions import BadSchemaError
from tavern._core.strtobool import strtobool

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -364,7 +364,7 @@ class StrToBoolConstructor:
"""Using `bool` as a constructor directly will evaluate all strings to `True`."""

def __new__(cls, s):
return bool(strtobool(s))
return strtobool(s)


class BoolToken(TypeConvertToken):
Expand Down
30 changes: 11 additions & 19 deletions tavern/_core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
This is here mainly to make MQTT easier, this will almost defintiely change
significantly if/when a proper plugin system is implemented!
"""
import collections.abc
import dataclasses
import logging
from typing import Any, List, Optional
from typing import Any, List, Mapping, Optional

import stevedore

Expand All @@ -30,7 +29,7 @@ def plugin_load_error(mgr, entry_point, err):
raise exceptions.PluginLoadError(msg) from err


def is_valid_reqresp_plugin(ext: Any):
def is_valid_reqresp_plugin(ext: Any) -> bool:
"""Whether this is a valid 'reqresp' plugin
Requires certain functions/variables to be present
Expand Down Expand Up @@ -70,9 +69,6 @@ class _PluginCache:

plugins: List[Any] = dataclasses.field(default_factory=list)

# def __init__(self):
# self.plugins = []

def __call__(self, config: Optional[TestConfig] = None):
if not config and not self.plugins:
raise exceptions.PluginLoadError("No config to load plugins from")
Expand Down Expand Up @@ -139,9 +135,7 @@ def enabled(ext):
load_plugins = _PluginCache()


def get_extra_sessions(
test_spec: collections.abc.Mapping, test_block_config: TestConfig
) -> dict:
def get_extra_sessions(test_spec: Mapping, test_block_config: TestConfig) -> dict:
"""Get extra 'sessions' for any extra test types
Args:
Expand Down Expand Up @@ -172,9 +166,9 @@ def get_extra_sessions(


def get_request_type(
stage: collections.abc.Mapping,
stage: Mapping,
test_block_config: TestConfig,
sessions: collections.abc.Mapping,
sessions: Mapping,
) -> BaseRequest:
"""Get the request object for this stage
Expand Down Expand Up @@ -231,9 +225,7 @@ class ResponseVerifier(dict):
plugin_name: str


def _foreach_response(
stage: collections.abc.Mapping, test_block_config: TestConfig, action
):
def _foreach_response(stage: Mapping, test_block_config: TestConfig, action):
"""Do something for each response
Args:
Expand All @@ -258,9 +250,9 @@ def _foreach_response(


def get_expected(
stage: collections.abc.Mapping,
stage: Mapping,
test_block_config: TestConfig,
sessions: collections.abc.Mapping,
sessions: Mapping,
):
"""Get expected responses for each type of request
Expand Down Expand Up @@ -294,10 +286,10 @@ def action(p, response_block):


def get_verifiers(
stage: collections.abc.Mapping,
stage: Mapping,
test_block_config: TestConfig,
sessions: collections.abc.Mapping,
expected: collections.abc.Mapping,
sessions: Mapping,
expected: Mapping,
):
"""Get one or more response validators for this stage
Expand Down
7 changes: 4 additions & 3 deletions tavern/_core/pytest/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import dataclasses
from typing import Any

from tavern._core.strict_util import StrictSetting
from tavern._core.strict_util import StrictLevel


@dataclasses.dataclass(frozen=True)
Expand All @@ -22,10 +22,11 @@ class TestConfig:
follow_redirects: whether the test should follow redirects
variables: variables available for use in the stage
strict: Strictness for test/stage
stages: Any extra stages imported from other config files
"""

variables: dict
strict: StrictSetting
strict: StrictLevel
follow_redirects: bool
stages: list

Expand All @@ -42,6 +43,6 @@ def with_new_variables(self) -> "TestConfig":
copied = self.copy()
return dataclasses.replace(copied, variables=copy.copy(self.variables))

def with_strictness(self, new_strict: StrictSetting) -> "TestConfig":
def with_strictness(self, new_strict: StrictLevel) -> "TestConfig":
"""Create a copy of the config but with a new strictness setting"""
return dataclasses.replace(self, strict=new_strict)

0 comments on commit 924463f

Please sign in to comment.