diff --git a/.travis.yml b/.travis.yml index e148849bbb98..d9a8b82f74b5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,9 +38,11 @@ before_install: &before_install - pip install -r requirements-dev.txt install: + - pip install "mypy==0.782" - python setup.py install script: + - mypy --config-file mypy.ini - CI_PYTHON_VERSION="$TRAVIS_PYTHON_VERSION" sh tests/run_cpu_tests.sh after_success: @@ -51,12 +53,11 @@ jobs: - stage: Lint check python: "3.7" before_install: # Nothing to do - install: pip install flake8 "black==19.10b0" "isort==4.3.21" "mypy==0.782" + install: pip install flake8 "black==19.10b0" "isort==4.3.21" script: - flake8 . - black --check . - isort -rc -c . - - mypy --config-file mypy.ini after_success: # Nothing to do # GitHub Pages Deployment: https://docs.travis-ci.com/user/deployment/pages/ diff --git a/ignite/utils.py b/ignite/utils.py index 75c41e7ed0ea..b28b0b2544b6 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -1,7 +1,7 @@ import collections.abc as collections import logging import random -from typing import Any, Callable, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Tuple, Type, Union, cast import torch @@ -41,11 +41,13 @@ def apply_to_type( if isinstance(input_, (str, bytes)): return input_ if isinstance(input_, collections.Mapping): - return type(input_)({k: apply_to_type(sample, input_type, func) for k, sample in input_.items()}) + return cast(Callable, type(input_))( + {k: apply_to_type(sample, input_type, func) for k, sample in input_.items()} + ) if isinstance(input_, tuple) and hasattr(input_, "_fields"): # namedtuple - return type(input_)(*(apply_to_type(sample, input_type, func) for sample in input_)) + return cast(Callable, type(input_))(*(apply_to_type(sample, input_type, func) for sample in input_)) if isinstance(input_, collections.Sequence): - return type(input_)([apply_to_type(sample, input_type, func) for sample in input_]) + return cast(Callable, type(input_))([apply_to_type(sample, input_type, func) for sample in input_]) raise TypeError(("input must contain {}, dicts or lists; found {}".format(input_type, type(input_)))) diff --git a/mypy.ini b/mypy.ini index 5ef86a5e30fe..586ae4633cf3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -27,6 +27,5 @@ ignore_errors = True ignore_errors = True -[mypy-ignite.utils.*] - -ignore_errors = True +[mypy-numpy.*] +ignore_missing_imports = True