diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md index cdd9b6fdb730..dbc7d53f426e 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.md +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -19,4 +19,3 @@ about: Submit a bug report to help us improve Ignite - How you installed Ignite (`conda`, `pip`, source): - Python version: - Any other relevant information: - diff --git a/.github/workflows/autopep8-black.yml b/.github/workflows/code-style.yml similarity index 61% rename from .github/workflows/autopep8-black.yml rename to .github/workflows/code-style.yml index 9bafa6c396a2..6c8bf9a7698d 100644 --- a/.github/workflows/autopep8-black.yml +++ b/.github/workflows/code-style.yml @@ -1,18 +1,18 @@ name: Format python code on: push jobs: - autopep8-black: + code-style: runs-on: ubuntu-latest steps: - uses: actions/checkout@master - - name: autopep8 - uses: peter-evans/autopep8@v1.0.0 + - uses: actions/setup-python@v2 with: - args: --recursive --in-place --aggressive --aggressive . - - name: autoblack - uses: lgeiger/black-action@v1.0.1 - with: - args: "." + python-version: '3.7' + - run: | + python -m pip install autopep8 black isort + isort -rc . + autopep8 --recursive --in-place --aggressive --aggressive . + black . - name: Commit and push changes uses: stefanzweifel/git-auto-commit-action@v2.0.0 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 69846a72cb9f..f25aa11b8130 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,15 @@ repos: + - repo: https://github.com/asottile/seed-isort-config + rev: v1.9.4 + hooks: + - id: seed-isort-config + args: [--exclude=^((examples|docs)/.*)$] + + - repo: https://github.com/timothycrosley/isort + rev: 4.3.21-2 + hooks: + - id: isort + - repo: https://github.com/python/black rev: 19.10b0 hooks: diff --git a/.travis.yml b/.travis.yml index 82516c5e798b..bc30e0c9ec6f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -51,10 +51,11 @@ jobs: - stage: Lint check python: "3.7" before_install: # Nothing to do - install: pip install flake8 black + install: pip install flake8 black isort script: - flake8 . - black --check . + - isort -rc -c . after_success: # Nothing to do # GitHub Pages Deployment: https://docs.travis-ci.com/user/deployment/pages/ diff --git a/ignite/__init__.py b/ignite/__init__.py index a42a07b75d7c..e60a1f8d3bbd 100644 --- a/ignite/__init__.py +++ b/ignite/__init__.py @@ -1,8 +1,8 @@ +import ignite.contrib import ignite.engine +import ignite.exceptions import ignite.handlers import ignite.metrics -import ignite.exceptions -import ignite.contrib import ignite.utils __version__ = "0.4.0" diff --git a/ignite/_utils.py b/ignite/_utils.py index fe5d830151b5..cd2f428fe316 100644 --- a/ignite/_utils.py +++ b/ignite/_utils.py @@ -1,7 +1,7 @@ -from typing import Union, Tuple +from typing import Tuple, Union # For compatibilty -from ignite.utils import convert_tensor, apply_to_tensor, apply_to_type, to_onehot +from ignite.utils import apply_to_tensor, apply_to_type, convert_tensor, to_onehot def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, int]: diff --git a/ignite/contrib/engines/__init__.py b/ignite/contrib/engines/__init__.py index 47fab99d4be1..597801bebf61 100644 --- a/ignite/contrib/engines/__init__.py +++ b/ignite/contrib/engines/__init__.py @@ -1,2 +1 @@ -from ignite.contrib.engines.tbptt import create_supervised_tbptt_trainer -from ignite.contrib.engines.tbptt import Tbptt_Events +from ignite.contrib.engines.tbptt import Tbptt_Events, create_supervised_tbptt_trainer diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index aa50c4089016..3eed6a3ec383 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -1,21 +1,24 @@ import numbers import warnings -from collections.abc import Sequence, Mapping +from collections.abc import Mapping, Sequence from functools import partial import torch import torch.distributed as dist -from ignite.contrib.handlers import MLflowLogger -from ignite.contrib.handlers import PolyaxonLogger -from ignite.contrib.handlers import ProgressBar -from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine -from ignite.contrib.handlers import VisdomLogger -from ignite.contrib.handlers import NeptuneLogger -from ignite.contrib.handlers import WandBLogger +from ignite.contrib.handlers import ( + MLflowLogger, + NeptuneLogger, + PolyaxonLogger, + ProgressBar, + TensorboardLogger, + VisdomLogger, + WandBLogger, + global_step_from_engine, +) from ignite.contrib.metrics import GpuInfo from ignite.engine import Engine, Events -from ignite.handlers import TerminateOnNan, ModelCheckpoint, EarlyStopping +from ignite.handlers import EarlyStopping, ModelCheckpoint, TerminateOnNan from ignite.metrics import RunningAverage diff --git a/ignite/contrib/engines/tbptt.py b/ignite/contrib/engines/tbptt.py index 551c6f4c4ecc..f959673c2b0c 100644 --- a/ignite/contrib/engines/tbptt.py +++ b/ignite/contrib/engines/tbptt.py @@ -2,8 +2,8 @@ import torch +from ignite.engine import Engine, EventEnum, _prepare_batch from ignite.utils import apply_to_tensor -from ignite.engine import Engine, _prepare_batch, EventEnum class Tbptt_Events(EventEnum): diff --git a/ignite/contrib/handlers/__init__.py b/ignite/contrib/handlers/__init__.py index 024509e33a1f..a7655237d67f 100644 --- a/ignite/contrib/handlers/__init__.py +++ b/ignite/contrib/handlers/__init__.py @@ -1,21 +1,19 @@ +from ignite.contrib.handlers.base_logger import global_step_from_engine +from ignite.contrib.handlers.custom_events import CustomPeriodicEvent +from ignite.contrib.handlers.lr_finder import FastaiLRFinder +from ignite.contrib.handlers.mlflow_logger import MLflowLogger +from ignite.contrib.handlers.neptune_logger import NeptuneLogger from ignite.contrib.handlers.param_scheduler import ( - LinearCyclicalScheduler, - CosineAnnealingScheduler, ConcatScheduler, + CosineAnnealingScheduler, + LinearCyclicalScheduler, LRScheduler, - create_lr_scheduler_with_warmup, - PiecewiseLinear, ParamGroupScheduler, + PiecewiseLinear, + create_lr_scheduler_with_warmup, ) - -from ignite.contrib.handlers.custom_events import CustomPeriodicEvent - -from ignite.contrib.handlers.tqdm_logger import ProgressBar +from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger +from ignite.contrib.handlers.tqdm_logger import ProgressBar from ignite.contrib.handlers.visdom_logger import VisdomLogger -from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger -from ignite.contrib.handlers.mlflow_logger import MLflowLogger from ignite.contrib.handlers.wandb_logger import WandBLogger -from ignite.contrib.handlers.neptune_logger import NeptuneLogger -from ignite.contrib.handlers.base_logger import global_step_from_engine -from ignite.contrib.handlers.lr_finder import FastaiLRFinder diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 1f5bec91ce15..e8825b45005b 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -1,12 +1,11 @@ -from abc import ABCMeta, abstractmethod import numbers import warnings - -from typing import Mapping, Any +from abc import ABCMeta, abstractmethod +from typing import Any, Mapping import torch -from ignite.engine import State, Engine +from ignite.engine import Engine, State from ignite.handlers import global_step_from_engine diff --git a/ignite/contrib/handlers/custom_events.py b/ignite/contrib/handlers/custom_events.py index c9ec147db98a..1bdf43fd3039 100644 --- a/ignite/contrib/handlers/custom_events.py +++ b/ignite/contrib/handlers/custom_events.py @@ -1,6 +1,7 @@ -from ignite.engine import Events, State, EventEnum import warnings +from ignite.engine import EventEnum, Events, State + class CustomPeriodicEvent: """DEPRECATED. Use filtered events instead. diff --git a/ignite/contrib/handlers/lr_finder.py b/ignite/contrib/handlers/lr_finder.py index e4db2cddaffa..a6804ea3555b 100644 --- a/ignite/contrib/handlers/lr_finder.py +++ b/ignite/contrib/handlers/lr_finder.py @@ -1,17 +1,17 @@ # coding: utf-8 +import contextlib import logging +import tempfile import warnings from collections.abc import Mapping -import tempfile -import contextlib from pathlib import Path import torch from torch.optim.lr_scheduler import _LRScheduler -from ignite.engine import Events, Engine +from ignite.contrib.handlers.param_scheduler import LRScheduler, PiecewiseLinear +from ignite.engine import Engine, Events from ignite.handlers import Checkpoint -from ignite.contrib.handlers import LRScheduler, PiecewiseLinear class FastaiLRFinder: diff --git a/ignite/contrib/handlers/mlflow_logger.py b/ignite/contrib/handlers/mlflow_logger.py index 721b11f9f718..2ffd25306996 100644 --- a/ignite/contrib/handlers/mlflow_logger.py +++ b/ignite/contrib/handlers/mlflow_logger.py @@ -5,8 +5,8 @@ from ignite.contrib.handlers.base_logger import ( BaseLogger, - BaseOutputHandler, BaseOptimizerParamsHandler, + BaseOutputHandler, global_step_from_engine, ) diff --git a/ignite/contrib/handlers/neptune_logger.py b/ignite/contrib/handlers/neptune_logger.py index 36feef6f532d..0e2e29cb6cb1 100644 --- a/ignite/contrib/handlers/neptune_logger.py +++ b/ignite/contrib/handlers/neptune_logger.py @@ -1,13 +1,11 @@ import numbers import tempfile -from typing import Mapping import warnings +from typing import Mapping import torch import ignite - -from ignite.handlers.checkpoint import BaseSaveHandler from ignite.contrib.handlers.base_logger import ( BaseLogger, BaseOptimizerParamsHandler, @@ -15,6 +13,7 @@ BaseWeightsScalarHandler, global_step_from_engine, ) +from ignite.handlers.checkpoint import BaseSaveHandler __all__ = [ "NeptuneLogger", diff --git a/ignite/contrib/handlers/param_scheduler.py b/ignite/contrib/handlers/param_scheduler.py index 9f9299397365..ecbe91213ef4 100644 --- a/ignite/contrib/handlers/param_scheduler.py +++ b/ignite/contrib/handlers/param_scheduler.py @@ -1,16 +1,13 @@ -from collections import OrderedDict -from copy import copy - import math import numbers - from abc import ABCMeta, abstractmethod - -from collections.abc import Sequence, Mapping +from collections import OrderedDict +from collections.abc import Mapping, Sequence +from copy import copy import torch -from torch.optim.optimizer import Optimizer from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer class ParamScheduler(metaclass=ABCMeta): diff --git a/ignite/contrib/handlers/polyaxon_logger.py b/ignite/contrib/handlers/polyaxon_logger.py index 5471a60d9fee..705be6a23aa1 100644 --- a/ignite/contrib/handlers/polyaxon_logger.py +++ b/ignite/contrib/handlers/polyaxon_logger.py @@ -5,8 +5,8 @@ from ignite.contrib.handlers.base_logger import ( BaseLogger, - BaseOutputHandler, BaseOptimizerParamsHandler, + BaseOutputHandler, global_step_from_engine, ) diff --git a/ignite/contrib/handlers/tensorboard_logger.py b/ignite/contrib/handlers/tensorboard_logger.py index b8b9febd9437..c341bdb3834f 100644 --- a/ignite/contrib/handlers/tensorboard_logger.py +++ b/ignite/contrib/handlers/tensorboard_logger.py @@ -7,12 +7,11 @@ BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler, - BaseWeightsScalarHandler, BaseWeightsHistHandler, + BaseWeightsScalarHandler, global_step_from_engine, ) - __all__ = [ "TensorboardLogger", "OptimizerParamsHandler", diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py index 8341f636a44e..840298a28818 100644 --- a/ignite/contrib/handlers/tqdm_logger.py +++ b/ignite/contrib/handlers/tqdm_logger.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- import warnings - -from typing import Mapping, Any +from typing import Any, Mapping import torch -from ignite.engine import Events, Engine -from ignite.engine.events import CallableEventWithFilter from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler +from ignite.engine import Engine, Events +from ignite.engine.events import CallableEventWithFilter class _OutputHandler(BaseOutputHandler): diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py index 8195e1f45063..03fb65215fbd 100644 --- a/ignite/contrib/handlers/wandb_logger.py +++ b/ignite/contrib/handlers/wandb_logger.py @@ -1,7 +1,7 @@ from ignite.contrib.handlers.base_logger import ( BaseLogger, - BaseOutputHandler, BaseOptimizerParamsHandler, + BaseOutputHandler, global_step_from_engine, ) diff --git a/ignite/contrib/metrics/__init__.py b/ignite/contrib/metrics/__init__.py index 8191b3fafce7..e51efe469cac 100644 --- a/ignite/contrib/metrics/__init__.py +++ b/ignite/contrib/metrics/__init__.py @@ -1,5 +1,5 @@ -from ignite.contrib.metrics.average_precision import AveragePrecision -from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve -from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve import ignite.contrib.metrics.regression +from ignite.contrib.metrics.average_precision import AveragePrecision from ignite.contrib.metrics.gpu_info import GpuInfo +from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve +from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve diff --git a/ignite/contrib/metrics/gpu_info.py b/ignite/contrib/metrics/gpu_info.py index 5e6cad9eb074..d4b60df289ae 100644 --- a/ignite/contrib/metrics/gpu_info.py +++ b/ignite/contrib/metrics/gpu_info.py @@ -3,8 +3,8 @@ import torch -from ignite.metrics import Metric from ignite.engine import Events +from ignite.metrics import Metric class GpuInfo(Metric): diff --git a/ignite/contrib/metrics/regression/__init__.py b/ignite/contrib/metrics/regression/__init__.py index e443646737f7..fbee310e3b29 100644 --- a/ignite/contrib/metrics/regression/__init__.py +++ b/ignite/contrib/metrics/regression/__init__.py @@ -1,15 +1,15 @@ -from ignite.contrib.metrics.regression.maximum_absolute_error import MaximumAbsoluteError +from ignite.contrib.metrics.regression.canberra_metric import CanberraMetric +from ignite.contrib.metrics.regression.fractional_absolute_error import FractionalAbsoluteError from ignite.contrib.metrics.regression.fractional_bias import FractionalBias +from ignite.contrib.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError +from ignite.contrib.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError from ignite.contrib.metrics.regression.manhattan_distance import ManhattanDistance +from ignite.contrib.metrics.regression.maximum_absolute_error import MaximumAbsoluteError +from ignite.contrib.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError from ignite.contrib.metrics.regression.mean_error import MeanError from ignite.contrib.metrics.regression.mean_normalized_bias import MeanNormalizedBias -from ignite.contrib.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError -from ignite.contrib.metrics.regression.canberra_metric import CanberraMetric -from ignite.contrib.metrics.regression.fractional_absolute_error import FractionalAbsoluteError -from ignite.contrib.metrics.regression.wave_hedges_distance import WaveHedgesDistance -from ignite.contrib.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError from ignite.contrib.metrics.regression.median_absolute_error import MedianAbsoluteError -from ignite.contrib.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError from ignite.contrib.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError -from ignite.contrib.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError +from ignite.contrib.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError from ignite.contrib.metrics.regression.r2_score import R2Score +from ignite.contrib.metrics.regression.wave_hedges_distance import WaveHedgesDistance diff --git a/ignite/contrib/metrics/regression/_base.py b/ignite/contrib/metrics/regression/_base.py index 31455d352d91..ef70856e1e8a 100644 --- a/ignite/contrib/metrics/regression/_base.py +++ b/ignite/contrib/metrics/regression/_base.py @@ -2,7 +2,7 @@ import torch -from ignite.metrics import Metric, EpochMetric +from ignite.metrics import EpochMetric, Metric def _check_output_shapes(output): diff --git a/ignite/contrib/metrics/regression/fractional_absolute_error.py b/ignite/contrib/metrics/regression/fractional_absolute_error.py index 1a9a678c1319..70bb48dad768 100644 --- a/ignite/contrib/metrics/regression/fractional_absolute_error.py +++ b/ignite/contrib/metrics/regression/fractional_absolute_error.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class FractionalAbsoluteError(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/fractional_bias.py b/ignite/contrib/metrics/regression/fractional_bias.py index 9968ff21e452..70a9571dd0cd 100644 --- a/ignite/contrib/metrics/regression/fractional_bias.py +++ b/ignite/contrib/metrics/regression/fractional_bias.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class FractionalBias(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py b/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py index b2248640eb04..79d77393d37e 100644 --- a/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py +++ b/ignite/contrib/metrics/regression/geometric_mean_absolute_error.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class GeometricMeanAbsoluteError(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/maximum_absolute_error.py b/ignite/contrib/metrics/regression/maximum_absolute_error.py index 85d6d5228c90..fb487095246d 100644 --- a/ignite/contrib/metrics/regression/maximum_absolute_error.py +++ b/ignite/contrib/metrics/regression/maximum_absolute_error.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class MaximumAbsoluteError(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/mean_absolute_relative_error.py b/ignite/contrib/metrics/regression/mean_absolute_relative_error.py index affcd0ac2569..6f14e41e29a9 100644 --- a/ignite/contrib/metrics/regression/mean_absolute_relative_error.py +++ b/ignite/contrib/metrics/regression/mean_absolute_relative_error.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class MeanAbsoluteRelativeError(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/mean_error.py b/ignite/contrib/metrics/regression/mean_error.py index cc8e2483082f..ca13bb481ae5 100644 --- a/ignite/contrib/metrics/regression/mean_error.py +++ b/ignite/contrib/metrics/regression/mean_error.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class MeanError(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/mean_normalized_bias.py b/ignite/contrib/metrics/regression/mean_normalized_bias.py index ece0f10bb7ce..54863c7e37c2 100644 --- a/ignite/contrib/metrics/regression/mean_normalized_bias.py +++ b/ignite/contrib/metrics/regression/mean_normalized_bias.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class MeanNormalizedBias(_BaseRegression): diff --git a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index 6724841609ac..59cc23c9db1a 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -1,7 +1,7 @@ import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression._base import _BaseRegression +from ignite.exceptions import NotComputableError class R2Score(_BaseRegression): diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 60ce5370649b..3f2dcb08909f 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -1,12 +1,12 @@ -from typing import Sequence, Union, Optional, Callable, Dict, Any, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + import torch +from ignite.engine.deterministic import DeterministicEngine from ignite.engine.engine import Engine -from ignite.engine.events import State, Events, EventEnum, CallableEventWithFilter -from ignite.utils import convert_tensor +from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State from ignite.metrics import Metric -from ignite.engine.deterministic import DeterministicEngine - +from ignite.utils import convert_tensor __all__ = [ "State", diff --git a/ignite/engine/deterministic.py b/ignite/engine/deterministic.py index 6a0e256cff0d..cba625802888 100644 --- a/ignite/engine/deterministic.py +++ b/ignite/engine/deterministic.py @@ -1,15 +1,15 @@ import random import warnings -from functools import wraps -from typing import Optional, Generator, Callable, Iterator from collections import OrderedDict +from functools import wraps +from typing import Callable, Generator, Iterator, Optional import torch + from ignite.engine.engine import Engine from ignite.engine.events import Events from ignite.utils import manual_seed - __all__ = ["update_dataloader", "keep_random_state", "ReproducibleBatchSampler", "DeterministicEngine"] diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 95128abbc3c7..73857fd18c67 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -1,18 +1,16 @@ +import functools import logging import time -from collections import defaultdict, OrderedDict -from collections.abc import Mapping -import weakref import warnings -import functools -from typing import Optional, Callable, Iterable, Any, Tuple, List +import weakref +from collections import OrderedDict, defaultdict +from collections.abc import Mapping +from typing import Any, Callable, Iterable, List, Optional, Tuple -from ignite.engine.events import Events, State, CallableEventWithFilter, RemovableEventHandle, EventsList -from ignite.engine.utils import _check_signature from ignite._utils import _to_hours_mins_secs - from ignite.base import Serializable - +from ignite.engine.events import CallableEventWithFilter, Events, EventsList, RemovableEventHandle, State +from ignite.engine.utils import _check_signature __all__ = ["Engine"] diff --git a/ignite/engine/events.py b/ignite/engine/events.py index aab1863c22c5..82af5a4ef28f 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -1,9 +1,8 @@ -from typing import Callable, Optional, Union, Any - -from enum import Enum import numbers import weakref +from enum import Enum from types import DynamicClassAttribute +from typing import Any, Callable, Optional, Union from ignite.engine.utils import _check_signature diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py index 2c462372de9e..9100d630a65c 100644 --- a/ignite/handlers/__init__.py +++ b/ignite/handlers/__init__.py @@ -1,12 +1,11 @@ -from typing import Callable, Any, Union +from typing import Any, Callable, Union from ignite.engine import Engine -from ignite.engine.events import EventEnum, CallableEventWithFilter - -from ignite.handlers.checkpoint import ModelCheckpoint, Checkpoint, DiskSaver -from ignite.handlers.timing import Timer +from ignite.engine.events import CallableEventWithFilter, EventEnum +from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint from ignite.handlers.early_stopping import EarlyStopping from ignite.handlers.terminate_on_nan import TerminateOnNan +from ignite.handlers.timing import Timer __all__ = [ "ModelCheckpoint", diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 20021d15f643..6c117f52d318 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -1,17 +1,15 @@ +import collections.abc as collections +import numbers import os import tempfile -import numbers import warnings from abc import ABCMeta, abstractmethod from collections import namedtuple -import collections.abc as collections - - -from typing import Optional, Callable, Mapping, Union +from typing import Callable, Mapping, Optional, Union import torch -from ignite.engine import Events, Engine +from ignite.engine import Engine, Events __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"] diff --git a/ignite/handlers/terminate_on_nan.py b/ignite/handlers/terminate_on_nan.py index 3ef74b32c995..cff9a3496eef 100644 --- a/ignite/handlers/terminate_on_nan.py +++ b/ignite/handlers/terminate_on_nan.py @@ -1,11 +1,11 @@ import logging import numbers -from typing import Union, Callable +from typing import Callable, Union import torch -from ignite.utils import apply_to_type from ignite.engine import Engine +from ignite.utils import apply_to_type __all__ = ["TerminateOnNan"] diff --git a/ignite/handlers/timing.py b/ignite/handlers/timing.py index b08f759de3e2..336968588586 100644 --- a/ignite/handlers/timing.py +++ b/ignite/handlers/timing.py @@ -1,7 +1,7 @@ from time import perf_counter from typing import Optional -from ignite.engine import Events, Engine +from ignite.engine import Engine, Events __all__ = ["Timer"] diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 74721194dc36..dd477df5deb9 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -1,20 +1,20 @@ -from ignite.metrics.metric import Metric +from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation from ignite.metrics.accuracy import Accuracy +from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, mIoU +from ignite.metrics.epoch_metric import EpochMetric +from ignite.metrics.fbeta import Fbeta +from ignite.metrics.frequency import Frequency from ignite.metrics.loss import Loss from ignite.metrics.mean_absolute_error import MeanAbsoluteError from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError -from ignite.metrics.epoch_metric import EpochMetric +from ignite.metrics.metric import Metric +from ignite.metrics.metrics_lambda import MetricsLambda from ignite.metrics.precision import Precision from ignite.metrics.recall import Recall from ignite.metrics.root_mean_squared_error import RootMeanSquaredError -from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy from ignite.metrics.running_average import RunningAverage -from ignite.metrics.metrics_lambda import MetricsLambda -from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU, DiceCoefficient -from ignite.metrics.accumulation import VariableAccumulation, Average, GeometricAverage -from ignite.metrics.fbeta import Fbeta -from ignite.metrics.frequency import Frequency +from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy __all__ = [ "Metric", diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index 1a748151e736..2b178b81cb85 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -1,12 +1,10 @@ import numbers +from typing import Any, Callable, Optional, Union -from typing import Callable, Union, Any, Optional +import torch -from ignite.metrics import Metric -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced from ignite.exceptions import NotComputableError - -import torch +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["VariableAccumulation", "GeometricAverage", "Average"] diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 699ba055e61c..d885119ba9f0 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -1,11 +1,10 @@ -from typing import Callable, Union, Optional, Sequence - -from ignite.metrics import Metric -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced -from ignite.exceptions import NotComputableError +from typing import Callable, Optional, Sequence, Union import torch +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + __all__ = ["Accuracy"] diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 9ad089101d98..7dc93b181f8c 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -1,11 +1,11 @@ import numbers -from typing import Optional, Union, Any, Callable, Sequence +from typing import Any, Callable, Optional, Sequence, Union import torch -from ignite.metrics import Metric, MetricsLambda from ignite.exceptions import NotComputableError -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce +from ignite.metrics.metrics_lambda import MetricsLambda __all__ = ["ConfusionMatrix", "mIoU", "IoU", "DiceCoefficient", "cmAccuracy", "cmPrecision", "cmRecall"] diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index e0b34698004d..b66b9cdc1573 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -1,10 +1,12 @@ -from typing import Optional, Union, Callable - -__all__ = ["Fbeta"] +from typing import Callable, Optional, Union import torch -from ignite.metrics import Precision, Recall, MetricsLambda +from ignite.metrics.metrics_lambda import MetricsLambda +from ignite.metrics.precision import Precision +from ignite.metrics.recall import Recall + +__all__ = ["Fbeta"] def Fbeta( diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index 7ae1e0ce8b8d..595dc18522eb 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py @@ -2,9 +2,8 @@ import torch.distributed as dist from ignite.engine import Events -from ignite.metrics import Metric from ignite.handlers.timing import Timer -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce class Frequency(Metric): diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 5d3a3fb282b8..a163d527be7d 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -1,10 +1,9 @@ -from typing import Callable, Union, Optional, Sequence +from typing import Callable, Optional, Sequence, Union import torch from ignite.exceptions import NotComputableError -from ignite.metrics import Metric -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["Loss"] diff --git a/ignite/metrics/mean_absolute_error.py b/ignite/metrics/mean_absolute_error.py index 11bf5e3d5eb6..dd979b593009 100644 --- a/ignite/metrics/mean_absolute_error.py +++ b/ignite/metrics/mean_absolute_error.py @@ -3,8 +3,7 @@ import torch from ignite.exceptions import NotComputableError -from ignite.metrics.metric import Metric -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["MeanAbsoluteError"] diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index e8f9e85c3c4e..f33f3b370dbc 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ b/ignite/metrics/mean_pairwise_distance.py @@ -1,11 +1,10 @@ -from typing import Union, Sequence, Optional, Callable +from typing import Callable, Optional, Sequence, Union import torch from torch.nn.functional import pairwise_distance from ignite.exceptions import NotComputableError -from ignite.metrics.metric import Metric -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["MeanPairwiseDistance"] diff --git a/ignite/metrics/mean_squared_error.py b/ignite/metrics/mean_squared_error.py index 5c60623cc2cc..2aa5eef62b13 100644 --- a/ignite/metrics/mean_squared_error.py +++ b/ignite/metrics/mean_squared_error.py @@ -1,10 +1,9 @@ -from typing import Union, Sequence +from typing import Sequence, Union import torch from ignite.exceptions import NotComputableError -from ignite.metrics.metric import Metric -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["MeanSquaredError"] diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 19357e49876b..fef147d5d4d7 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -1,15 +1,14 @@ import numbers +import warnings from abc import ABCMeta, abstractmethod -from functools import wraps from collections.abc import Mapping -import warnings - -from typing import Callable, Union, Optional, Any +from functools import wraps +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist -from ignite.engine import Events, Engine +from ignite.engine import Engine, Events __all__ = ["Metric"] @@ -214,77 +213,77 @@ def is_attached(self, engine: Engine) -> bool: return engine.has_event_handler(self.completed, Events.EPOCH_COMPLETED) def __add__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x + y, self, other) def __radd__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x + y, other, self) def __sub__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x - y, self, other) def __rsub__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x - y, other, self) def __mul__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x * y, self, other) def __rmul__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x * y, other, self) def __pow__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x ** y, self, other) def __rpow__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x ** y, other, self) def __mod__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x % y, self, other) def __div__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__div__(y), self, other) def __rdiv__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__div__(y), other, self) def __truediv__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__truediv__(y), self, other) def __rtruediv__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x.__truediv__(y), other, self) def __floordiv__(self, other): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x, y: x // y, self, other) def __getattr__(self, attr: str) -> Callable: - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda def fn(x, *args, **kwargs): return getattr(x, attr)(*args, **kwargs) @@ -295,7 +294,7 @@ def wrapper(*args, **kwargs): return wrapper def __getitem__(self, index: Any): - from ignite.metrics import MetricsLambda + from ignite.metrics.metrics_lambda import MetricsLambda return MetricsLambda(lambda x: x[index], self) diff --git a/ignite/metrics/metrics_lambda.py b/ignite/metrics/metrics_lambda.py index 224e35c89cd7..5a0477eca37a 100644 --- a/ignite/metrics/metrics_lambda.py +++ b/ignite/metrics/metrics_lambda.py @@ -1,8 +1,8 @@ import itertools -from typing import Callable, Any +from typing import Any, Callable +from ignite.engine import Engine, Events from ignite.metrics.metric import Metric, reinit__is_reduced -from ignite.engine import Events, Engine __all__ = ["MetricsLambda"] diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index 4a4063c92902..2c3e75d07fc1 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -1,12 +1,12 @@ import warnings -from typing import Sequence, Callable, Optional, Union +from typing import Callable, Optional, Sequence, Union import torch -from ignite.metrics.accuracy import _BaseClassification from ignite.exceptions import NotComputableError -from ignite.utils import to_onehot +from ignite.metrics.accuracy import _BaseClassification from ignite.metrics.metric import reinit__is_reduced +from ignite.utils import to_onehot __all__ = ["Precision"] diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index c7a6739bf0e1..fd185df6095a 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -1,10 +1,10 @@ -from typing import Sequence, Callable, Optional, Union +from typing import Callable, Optional, Sequence, Union import torch +from ignite.metrics.metric import reinit__is_reduced from ignite.metrics.precision import _BasePrecisionRecall from ignite.utils import to_onehot -from ignite.metrics.metric import reinit__is_reduced __all__ = ["Recall"] diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index b771ac824b8b..be1830b43749 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -1,10 +1,9 @@ -from typing import Optional, Union, Callable, Sequence +from typing import Callable, Optional, Sequence, Union import torch -from ignite.engine import Events, Engine -from ignite.metrics import Metric -from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce +from ignite.engine import Engine, Events +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["RunningAverage"] diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index 60ed320a4205..e17cedca9c0d 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -1,10 +1,9 @@ -from typing import Union, Optional, Callable, Sequence +from typing import Callable, Optional, Sequence, Union import torch -from ignite.metrics.metric import Metric from ignite.exceptions import NotComputableError -from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["TopKCategoricalAccuracy"] diff --git a/ignite/utils.py b/ignite/utils.py index 254ee3056741..e9f7f6ac8e45 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -1,8 +1,8 @@ -import random import collections.abc as collections import logging +import random from functools import wraps -from typing import Union, Optional, Callable, Any, Type, Tuple +from typing import Any, Callable, Optional, Tuple, Type, Union import torch import torch.distributed as dist diff --git a/setup.cfg b/setup.cfg index f45305bf23e3..20ddf43a4581 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,16 @@ exclude = .eggs,*.egg,build,docs/*,.git,versioneer.py,*/conf.py ignore = E402, E721 max_line_length = 120 +[isort] +known_third_party=matplotlib,numpy,pytest,setuptools,sklearn,torch +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=120 +skip_glob=docs/**,examples/** +filter_files=True + [flake8] max-line-length = 120 ignore = E203,E231,E305,E402,E721,E722,E741,F401,F403,F405,F821,F841,F999,W503 diff --git a/setup.py b/setup.py index 667c5b279bdf..2cdc0ee04735 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ -import os import io +import os import re -from setuptools import setup, find_packages + +from setuptools import find_packages, setup def read(*names, **kwargs): diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py index eb2b51723d71..1893baf524ca 100644 --- a/tests/ignite/conftest.py +++ b/tests/ignite/conftest.py @@ -1,11 +1,10 @@ -import tempfile import shutil +import tempfile +import pytest import torch import torch.distributed as dist -import pytest - @pytest.fixture() def dirname(): diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py index c92d588e241a..9fb20763b72d 100644 --- a/tests/ignite/contrib/engines/test_common.py +++ b/tests/ignite/contrib/engines/test_common.py @@ -1,23 +1,21 @@ import os +from unittest.mock import MagicMock +import pytest import torch import torch.nn as nn -from ignite.engine import Events, Engine +import ignite.contrib.handlers.tensorboard_logger as tb_logger_module +import ignite.contrib.handlers.visdom_logger as visdom_logger_module from ignite.contrib.engines.common import ( - setup_common_training_handlers, - save_best_model_by_val_score, add_early_stopping_by_val_score, + save_best_model_by_val_score, + setup_common_training_handlers, setup_tb_logging, setup_visdom_logging, ) - +from ignite.engine import Engine, Events from ignite.handlers import TerminateOnNan -import ignite.contrib.handlers.tensorboard_logger as tb_logger_module -import ignite.contrib.handlers.visdom_logger as visdom_logger_module - -import pytest -from unittest.mock import MagicMock class DummyModel(nn.Module): diff --git a/tests/ignite/contrib/engines/test_tbptt.py b/tests/ignite/contrib/engines/test_tbptt.py index d7f2700bd200..0582852d4d74 100644 --- a/tests/ignite/contrib/engines/test_tbptt.py +++ b/tests/ignite/contrib/engines/test_tbptt.py @@ -1,13 +1,14 @@ # coding: utf-8 +import unittest.mock as mock + +import pytest import torch import torch.nn as nn -import torch.optim as optim import torch.nn.functional as F -import pytest -import unittest.mock as mock +import torch.optim as optim -from ignite.contrib.engines import create_supervised_tbptt_trainer, Tbptt_Events +from ignite.contrib.engines import Tbptt_Events, create_supervised_tbptt_trainer from ignite.contrib.engines.tbptt import _detach_hidden diff --git a/tests/ignite/contrib/handlers/conftest.py b/tests/ignite/contrib/handlers/conftest.py index 75b890a0b05a..a0c874022120 100644 --- a/tests/ignite/contrib/handlers/conftest.py +++ b/tests/ignite/contrib/handlers/conftest.py @@ -1,10 +1,9 @@ +from unittest.mock import Mock + import numpy as np import pytest - import torch -from unittest.mock import Mock - @pytest.fixture() def norm_mock(): diff --git a/tests/ignite/contrib/handlers/test_base_logger.py b/tests/ignite/contrib/handlers/test_base_logger.py index 872b84745c1d..6b41640d32a6 100644 --- a/tests/ignite/contrib/handlers/test_base_logger.py +++ b/tests/ignite/contrib/handlers/test_base_logger.py @@ -1,14 +1,12 @@ import math -import torch - -from ignite.engine import Engine, State, Events -from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler, BaseOptimizerParamsHandler -from ignite.contrib.handlers import global_step_from_engine -from ignite.contrib.handlers import CustomPeriodicEvent +from unittest.mock import MagicMock import pytest +import torch -from unittest.mock import MagicMock +from ignite.contrib.handlers import CustomPeriodicEvent, global_step_from_engine +from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler +from ignite.engine import Engine, Events, State class DummyOutputHandler(BaseOutputHandler): diff --git a/tests/ignite/contrib/handlers/test_custom_events.py b/tests/ignite/contrib/handlers/test_custom_events.py index 006ac5fc8dd6..9e9fb73da778 100644 --- a/tests/ignite/contrib/handlers/test_custom_events.py +++ b/tests/ignite/contrib/handlers/test_custom_events.py @@ -1,8 +1,9 @@ import math + import pytest -from ignite.engine import Engine from ignite.contrib.handlers.custom_events import CustomPeriodicEvent +from ignite.engine import Engine def test_bad_input(): diff --git a/tests/ignite/contrib/handlers/test_lr_finder.py b/tests/ignite/contrib/handlers/test_lr_finder.py index 26da79bba269..8de6d45f23fb 100644 --- a/tests/ignite/contrib/handlers/test_lr_finder.py +++ b/tests/ignite/contrib/handlers/test_lr_finder.py @@ -1,17 +1,15 @@ import copy import matplotlib - -matplotlib.use("agg") - +import pytest import torch from torch import nn from torch.optim import SGD -from ignite.engine import create_supervised_trainer from ignite.contrib.handlers import FastaiLRFinder +from ignite.engine import create_supervised_trainer -import pytest +matplotlib.use("agg") @pytest.fixture diff --git a/tests/ignite/contrib/handlers/test_mlflow_logger.py b/tests/ignite/contrib/handlers/test_mlflow_logger.py index df646f2278c0..5de42801d7b3 100644 --- a/tests/ignite/contrib/handlers/test_mlflow_logger.py +++ b/tests/ignite/contrib/handlers/test_mlflow_logger.py @@ -1,12 +1,11 @@ import os -import pytest - from unittest.mock import MagicMock, call +import pytest import torch -from ignite.engine import Engine, Events, State from ignite.contrib.handlers.mlflow_logger import * +from ignite.engine import Engine, Events, State def test_output_handler_with_wrong_logger_type(): diff --git a/tests/ignite/contrib/handlers/test_neptune_logger.py b/tests/ignite/contrib/handlers/test_neptune_logger.py index 5b1ba12c3648..b15b38e79ba7 100644 --- a/tests/ignite/contrib/handlers/test_neptune_logger.py +++ b/tests/ignite/contrib/handlers/test_neptune_logger.py @@ -1,13 +1,13 @@ import math import warnings +from unittest.mock import ANY, MagicMock, call -from unittest.mock import call, ANY, MagicMock import pytest import torch +from ignite.contrib.handlers.neptune_logger import * from ignite.engine import Engine, Events, State from ignite.handlers.checkpoint import Checkpoint -from ignite.contrib.handlers.neptune_logger import * def test_optimizer_params_handler_wrong_setup(): diff --git a/tests/ignite/contrib/handlers/test_param_scheduler.py b/tests/ignite/contrib/handlers/test_param_scheduler.py index 7aced3ef5d47..f7b74c19494e 100644 --- a/tests/ignite/contrib/handlers/test_param_scheduler.py +++ b/tests/ignite/contrib/handlers/test_param_scheduler.py @@ -1,13 +1,17 @@ -import pytest - import numpy as np - +import pytest import torch +from ignite.contrib.handlers.param_scheduler import ( + ConcatScheduler, + CosineAnnealingScheduler, + LinearCyclicalScheduler, + LRScheduler, + ParamGroupScheduler, + PiecewiseLinear, + create_lr_scheduler_with_warmup, +) from ignite.engine import Engine, Events -from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, CosineAnnealingScheduler -from ignite.contrib.handlers.param_scheduler import ConcatScheduler, LRScheduler, create_lr_scheduler_with_warmup -from ignite.contrib.handlers.param_scheduler import ParamGroupScheduler, PiecewiseLinear def test_linear_scheduler(): diff --git a/tests/ignite/contrib/handlers/test_polyaxon_logger.py b/tests/ignite/contrib/handlers/test_polyaxon_logger.py index f167192c1882..4e28791d9a35 100644 --- a/tests/ignite/contrib/handlers/test_polyaxon_logger.py +++ b/tests/ignite/contrib/handlers/test_polyaxon_logger.py @@ -1,12 +1,11 @@ import os -import pytest - from unittest.mock import MagicMock, call +import pytest import torch -from ignite.engine import Engine, Events, State from ignite.contrib.handlers.polyaxon_logger import * +from ignite.engine import Engine, Events, State os.environ["POLYAXON_NO_OP"] = "1" diff --git a/tests/ignite/contrib/handlers/test_tensorboard_logger.py b/tests/ignite/contrib/handlers/test_tensorboard_logger.py index f88800f4ad62..b945e0275f57 100644 --- a/tests/ignite/contrib/handlers/test_tensorboard_logger.py +++ b/tests/ignite/contrib/handlers/test_tensorboard_logger.py @@ -1,14 +1,12 @@ -import os import math +import os +from unittest.mock import ANY, MagicMock, call, patch import pytest - -from unittest.mock import MagicMock, call, ANY, patch - import torch -from ignite.engine import Engine, Events, State from ignite.contrib.handlers.tensorboard_logger import * +from ignite.engine import Engine, Events, State def test_optimizer_params_handler_wrong_setup(): diff --git a/tests/ignite/contrib/handlers/test_time_profilers.py b/tests/ignite/contrib/handlers/test_time_profilers.py index 0c9195964a64..169fcf7a6720 100644 --- a/tests/ignite/contrib/handlers/test_time_profilers.py +++ b/tests/ignite/contrib/handlers/test_time_profilers.py @@ -1,11 +1,11 @@ -import time import os - -from ignite.engine import Engine, Events -from ignite.contrib.handlers.time_profilers import BasicTimeProfiler +import time from pytest import approx +from ignite.contrib.handlers.time_profilers import BasicTimeProfiler +from ignite.engine import Engine, Events + def _do_nothing_update_fn(engine, batch): pass diff --git a/tests/ignite/contrib/handlers/test_tqdm_logger.py b/tests/ignite/contrib/handlers/test_tqdm_logger.py index 9fbf40cbec17..6649357ca4d5 100644 --- a/tests/ignite/contrib/handlers/test_tqdm_logger.py +++ b/tests/ignite/contrib/handlers/test_tqdm_logger.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- import time + import numpy as np import pytest import torch -from ignite.contrib.handlers import ProgressBar +from ignite.contrib.handlers import CustomPeriodicEvent, ProgressBar from ignite.engine import Engine, Events -from ignite.metrics import RunningAverage -from ignite.contrib.handlers import CustomPeriodicEvent from ignite.handlers import TerminateOnNan +from ignite.metrics import RunningAverage def update_fn(engine, batch): diff --git a/tests/ignite/contrib/handlers/test_visdom_logger.py b/tests/ignite/contrib/handlers/test_visdom_logger.py index 457f1d3b6b1f..ab94d98b4bd8 100644 --- a/tests/ignite/contrib/handlers/test_visdom_logger.py +++ b/tests/ignite/contrib/handlers/test_visdom_logger.py @@ -1,11 +1,11 @@ -import torch -import pytest +from unittest.mock import ANY, MagicMock, call -from unittest.mock import MagicMock, call, ANY +import pytest +import torch -from ignite.engine import Engine, Events, State from ignite.contrib.handlers.visdom_logger import * from ignite.contrib.handlers.visdom_logger import _DummyExecutor +from ignite.engine import Engine, Events, State @pytest.fixture diff --git a/tests/ignite/contrib/handlers/test_wandb_logger.py b/tests/ignite/contrib/handlers/test_wandb_logger.py index 28173c62242a..ae6f8528c1a2 100644 --- a/tests/ignite/contrib/handlers/test_wandb_logger.py +++ b/tests/ignite/contrib/handlers/test_wandb_logger.py @@ -1,9 +1,10 @@ -from unittest.mock import call, MagicMock +from unittest.mock import MagicMock, call + import pytest import torch -from ignite.engine import Events, State from ignite.contrib.handlers.wandb_logger import * +from ignite.engine import Events, State def test_optimizer_params_handler_wrong_setup(): diff --git a/tests/ignite/contrib/metrics/regression/test__base.py b/tests/ignite/contrib/metrics/regression/test__base.py index c47646d79e47..897225911811 100644 --- a/tests/ignite/contrib/metrics/regression/test__base.py +++ b/tests/ignite/contrib/metrics/regression/test__base.py @@ -1,6 +1,7 @@ -import torch import numpy as np import pytest +import torch + from ignite.contrib.metrics.regression._base import _BaseRegression, _BaseRegressionEpoch diff --git a/tests/ignite/contrib/metrics/regression/test_canberra_metric.py b/tests/ignite/contrib/metrics/regression/test_canberra_metric.py index 4ad552016a67..4c939720ebf5 100644 --- a/tests/ignite/contrib/metrics/regression/test_canberra_metric.py +++ b/tests/ignite/contrib/metrics/regression/test_canberra_metric.py @@ -1,6 +1,6 @@ -import torch import numpy as np import pytest +import torch from ignite.contrib.metrics.regression import CanberraMetric diff --git a/tests/ignite/contrib/metrics/regression/test_fractional_absolute_error.py b/tests/ignite/contrib/metrics/regression/test_fractional_absolute_error.py index 70f41ecb033d..369dcd5c2f08 100644 --- a/tests/ignite/contrib/metrics/regression/test_fractional_absolute_error.py +++ b/tests/ignite/contrib/metrics/regression/test_fractional_absolute_error.py @@ -1,8 +1,9 @@ -from ignite.exceptions import NotComputableError -from ignite.contrib.metrics.regression import FractionalAbsoluteError -import torch import numpy as np import pytest +import torch + +from ignite.contrib.metrics.regression import FractionalAbsoluteError +from ignite.exceptions import NotComputableError def test_zero_div(): diff --git a/tests/ignite/contrib/metrics/regression/test_fractional_bias.py b/tests/ignite/contrib/metrics/regression/test_fractional_bias.py index 30e940d7e7b8..3bea55b2be8b 100644 --- a/tests/ignite/contrib/metrics/regression/test_fractional_bias.py +++ b/tests/ignite/contrib/metrics/regression/test_fractional_bias.py @@ -1,8 +1,9 @@ -from ignite.exceptions import NotComputableError -from ignite.contrib.metrics.regression import FractionalBias -import torch import numpy as np import pytest +import torch + +from ignite.contrib.metrics.regression import FractionalBias +from ignite.exceptions import NotComputableError def test_zero_div(): diff --git a/tests/ignite/contrib/metrics/regression/test_geometric_mean_absolute_error.py b/tests/ignite/contrib/metrics/regression/test_geometric_mean_absolute_error.py index b75a5bc323d4..9e8d1f43480b 100644 --- a/tests/ignite/contrib/metrics/regression/test_geometric_mean_absolute_error.py +++ b/tests/ignite/contrib/metrics/regression/test_geometric_mean_absolute_error.py @@ -1,8 +1,9 @@ -from ignite.exceptions import NotComputableError -from ignite.contrib.metrics.regression import GeometricMeanAbsoluteError -import torch import numpy as np import pytest +import torch + +from ignite.contrib.metrics.regression import GeometricMeanAbsoluteError +from ignite.exceptions import NotComputableError def test_zero_div(): diff --git a/tests/ignite/contrib/metrics/regression/test_geometric_mean_relative_absolute_error.py b/tests/ignite/contrib/metrics/regression/test_geometric_mean_relative_absolute_error.py index 546fc68f3d70..28d25375fb1a 100644 --- a/tests/ignite/contrib/metrics/regression/test_geometric_mean_relative_absolute_error.py +++ b/tests/ignite/contrib/metrics/regression/test_geometric_mean_relative_absolute_error.py @@ -1,8 +1,9 @@ -import torch import numpy as np import pytest -from ignite.engine import Engine +import torch + from ignite.contrib.metrics.regression import GeometricMeanRelativeAbsoluteError +from ignite.engine import Engine def test_wrong_input_shapes(): diff --git a/tests/ignite/contrib/metrics/regression/test_manhattan_distance.py b/tests/ignite/contrib/metrics/regression/test_manhattan_distance.py index 4cd1c24458ca..c9f9aa7bae56 100644 --- a/tests/ignite/contrib/metrics/regression/test_manhattan_distance.py +++ b/tests/ignite/contrib/metrics/regression/test_manhattan_distance.py @@ -1,6 +1,6 @@ -import torch import numpy as np import pytest +import torch from ignite.contrib.metrics.regression import ManhattanDistance diff --git a/tests/ignite/contrib/metrics/regression/test_maximum_absolute_error.py b/tests/ignite/contrib/metrics/regression/test_maximum_absolute_error.py index ef5d3aafe2a8..1b5b98ccb625 100644 --- a/tests/ignite/contrib/metrics/regression/test_maximum_absolute_error.py +++ b/tests/ignite/contrib/metrics/regression/test_maximum_absolute_error.py @@ -1,8 +1,9 @@ -from ignite.exceptions import NotComputableError -from ignite.contrib.metrics.regression import MaximumAbsoluteError -import torch import numpy as np import pytest +import torch + +from ignite.contrib.metrics.regression import MaximumAbsoluteError +from ignite.exceptions import NotComputableError def test_zero_div(): diff --git a/tests/ignite/contrib/metrics/regression/test_mean_absolute_relative_error.py b/tests/ignite/contrib/metrics/regression/test_mean_absolute_relative_error.py index a2e622c995d1..7a1e2da58138 100644 --- a/tests/ignite/contrib/metrics/regression/test_mean_absolute_relative_error.py +++ b/tests/ignite/contrib/metrics/regression/test_mean_absolute_relative_error.py @@ -1,8 +1,8 @@ import torch from pytest import approx, raises -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression import MeanAbsoluteRelativeError +from ignite.exceptions import NotComputableError def test_wrong_input_shapes(): diff --git a/tests/ignite/contrib/metrics/regression/test_mean_error.py b/tests/ignite/contrib/metrics/regression/test_mean_error.py index 7e92be9753c2..86a30f3895f8 100644 --- a/tests/ignite/contrib/metrics/regression/test_mean_error.py +++ b/tests/ignite/contrib/metrics/regression/test_mean_error.py @@ -1,9 +1,9 @@ -import torch import numpy as np import pytest +import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression import MeanError +from ignite.exceptions import NotComputableError def test_zero_div(): diff --git a/tests/ignite/contrib/metrics/regression/test_mean_normalized_bias.py b/tests/ignite/contrib/metrics/regression/test_mean_normalized_bias.py index 143e8abefb94..a48bae39a844 100644 --- a/tests/ignite/contrib/metrics/regression/test_mean_normalized_bias.py +++ b/tests/ignite/contrib/metrics/regression/test_mean_normalized_bias.py @@ -1,9 +1,9 @@ -import torch import numpy as np import pytest +import torch -from ignite.exceptions import NotComputableError from ignite.contrib.metrics.regression import MeanNormalizedBias +from ignite.exceptions import NotComputableError def test_zero_div(): diff --git a/tests/ignite/contrib/metrics/regression/test_median_absolute_error.py b/tests/ignite/contrib/metrics/regression/test_median_absolute_error.py index 53acde1e86c0..9d32c60cc271 100644 --- a/tests/ignite/contrib/metrics/regression/test_median_absolute_error.py +++ b/tests/ignite/contrib/metrics/regression/test_median_absolute_error.py @@ -1,8 +1,9 @@ -import torch import numpy as np import pytest -from ignite.engine import Engine +import torch + from ignite.contrib.metrics.regression import MedianAbsoluteError +from ignite.engine import Engine def test_wrong_input_shapes(): diff --git a/tests/ignite/contrib/metrics/regression/test_median_absolute_percentage_error.py b/tests/ignite/contrib/metrics/regression/test_median_absolute_percentage_error.py index 34069556820b..bc00077f71f0 100644 --- a/tests/ignite/contrib/metrics/regression/test_median_absolute_percentage_error.py +++ b/tests/ignite/contrib/metrics/regression/test_median_absolute_percentage_error.py @@ -1,8 +1,9 @@ -import torch import numpy as np import pytest -from ignite.engine import Engine +import torch + from ignite.contrib.metrics.regression import MedianAbsolutePercentageError +from ignite.engine import Engine def test_wrong_input_shapes(): diff --git a/tests/ignite/contrib/metrics/regression/test_median_relative_absolute_error.py b/tests/ignite/contrib/metrics/regression/test_median_relative_absolute_error.py index c445ed44f3fe..68d2f6cc289b 100644 --- a/tests/ignite/contrib/metrics/regression/test_median_relative_absolute_error.py +++ b/tests/ignite/contrib/metrics/regression/test_median_relative_absolute_error.py @@ -1,8 +1,9 @@ -import torch import numpy as np import pytest -from ignite.engine import Engine +import torch + from ignite.contrib.metrics.regression import MedianRelativeAbsoluteError +from ignite.engine import Engine def test_wrong_input_shapes(): diff --git a/tests/ignite/contrib/metrics/regression/test_r2_score.py b/tests/ignite/contrib/metrics/regression/test_r2_score.py index 3fb4e411bee5..9f3e48bdecca 100644 --- a/tests/ignite/contrib/metrics/regression/test_r2_score.py +++ b/tests/ignite/contrib/metrics/regression/test_r2_score.py @@ -1,10 +1,11 @@ -import torch import numpy as np import pytest -from ignite.engine import Engine -from ignite.contrib.metrics.regression import R2Score +import torch from sklearn.metrics import r2_score +from ignite.contrib.metrics.regression import R2Score +from ignite.engine import Engine + def test_wrong_input_shapes(): m = R2Score() diff --git a/tests/ignite/contrib/metrics/regression/test_wave_hedges_distance.py b/tests/ignite/contrib/metrics/regression/test_wave_hedges_distance.py index bd8d1d558f73..06dddc1b1b93 100644 --- a/tests/ignite/contrib/metrics/regression/test_wave_hedges_distance.py +++ b/tests/ignite/contrib/metrics/regression/test_wave_hedges_distance.py @@ -1,6 +1,6 @@ -import torch import numpy as np import pytest +import torch from ignite.contrib.metrics.regression import WaveHedgesDistance diff --git a/tests/ignite/contrib/metrics/test_average_precision.py b/tests/ignite/contrib/metrics/test_average_precision.py index c32357eab351..f37a82f02064 100644 --- a/tests/ignite/contrib/metrics/test_average_precision.py +++ b/tests/ignite/contrib/metrics/test_average_precision.py @@ -1,10 +1,9 @@ import numpy as np -from sklearn.metrics import average_precision_score - import torch +from sklearn.metrics import average_precision_score -from ignite.engine import Engine from ignite.contrib.metrics import AveragePrecision +from ignite.engine import Engine def test_ap_score(): diff --git a/tests/ignite/contrib/metrics/test_gpu_info.py b/tests/ignite/contrib/metrics/test_gpu_info.py index 05ef7a7fc785..5c4965150214 100644 --- a/tests/ignite/contrib/metrics/test_gpu_info.py +++ b/tests/ignite/contrib/metrics/test_gpu_info.py @@ -1,12 +1,11 @@ import sys +from unittest.mock import Mock, patch +import pytest import torch -from ignite.engine import Engine, State from ignite.contrib.metrics import GpuInfo - -import pytest -from unittest.mock import Mock, patch +from ignite.engine import Engine, State python_below_36 = (sys.version[0] == "3" and int(sys.version[2]) < 6) or int(sys.version[0]) < 2 diff --git a/tests/ignite/contrib/metrics/test_precision_recall_curve.py b/tests/ignite/contrib/metrics/test_precision_recall_curve.py index e48d37919abc..7a5d643aecb0 100644 --- a/tests/ignite/contrib/metrics/test_precision_recall_curve.py +++ b/tests/ignite/contrib/metrics/test_precision_recall_curve.py @@ -1,7 +1,6 @@ import numpy as np -from sklearn.metrics import precision_recall_curve - import torch +from sklearn.metrics import precision_recall_curve from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve from ignite.engine import Engine diff --git a/tests/ignite/contrib/metrics/test_roc_auc.py b/tests/ignite/contrib/metrics/test_roc_auc.py index 4fff4ce6521c..92ce92d775ea 100644 --- a/tests/ignite/contrib/metrics/test_roc_auc.py +++ b/tests/ignite/contrib/metrics/test_roc_auc.py @@ -1,10 +1,9 @@ import numpy as np -from sklearn.metrics import roc_auc_score - import torch +from sklearn.metrics import roc_auc_score -from ignite.engine import Engine from ignite.contrib.metrics import ROC_AUC +from ignite.engine import Engine def test_roc_auc_score(): diff --git a/tests/ignite/contrib/metrics/test_roc_curve.py b/tests/ignite/contrib/metrics/test_roc_curve.py index 0564dc81d007..696078b4ae63 100644 --- a/tests/ignite/contrib/metrics/test_roc_curve.py +++ b/tests/ignite/contrib/metrics/test_roc_curve.py @@ -1,7 +1,6 @@ import numpy as np -from sklearn.metrics import roc_curve - import torch +from sklearn.metrics import roc_curve from ignite.contrib.metrics.roc_auc import RocCurve from ignite.engine import Engine diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py index 706c40412cb5..3788705d6098 100644 --- a/tests/ignite/engine/test_create_supervised.py +++ b/tests/ignite/engine/test_create_supervised.py @@ -1,14 +1,13 @@ from typing import Optional import pytest -from pytest import approx - import torch +from pytest import approx from torch.nn import Linear from torch.nn.functional import mse_loss from torch.optim import SGD -from ignite.engine import create_supervised_trainer, create_supervised_evaluator +from ignite.engine import create_supervised_evaluator, create_supervised_trainer from ignite.metrics import MeanSquaredError try: diff --git a/tests/ignite/engine/test_custom_events.py b/tests/ignite/engine/test_custom_events.py index 02bf04f1b8aa..adbbc15ec42f 100644 --- a/tests/ignite/engine/test_custom_events.py +++ b/tests/ignite/engine/test_custom_events.py @@ -1,14 +1,12 @@ from enum import Enum - from unittest.mock import MagicMock +import pytest import torch from ignite.engine import Engine, Events from ignite.engine.events import CallableEventWithFilter, EventEnum, EventsList -import pytest - def test_custom_events(): class CustomEvents(EventEnum): diff --git a/tests/ignite/engine/test_deterministic.py b/tests/ignite/engine/test_deterministic.py index 0633725a00d6..9e567d163305 100644 --- a/tests/ignite/engine/test_deterministic.py +++ b/tests/ignite/engine/test_deterministic.py @@ -1,24 +1,21 @@ import os -import pytest import random from unittest.mock import patch import numpy as np - +import pytest import torch import torch.nn as nn +from ignite.engine import Events from ignite.engine.deterministic import ( + DeterministicEngine, ReproducibleBatchSampler, - update_dataloader, keep_random_state, - DeterministicEngine, + update_dataloader, ) - -from ignite.engine import Events from ignite.utils import manual_seed - -from tests.ignite.engine import setup_sampler, BatchChecker +from tests.ignite.engine import BatchChecker, setup_sampler def test_update_dataloader(): diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index e9707dcafb8b..8ecec15803e2 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1,16 +1,15 @@ import os import time -import pytest -from unittest.mock import call, MagicMock, Mock +from unittest.mock import MagicMock, Mock, call import numpy as np +import pytest import torch from ignite.engine import Engine, Events, State from ignite.engine.deterministic import keep_random_state from ignite.metrics import Average - -from tests.ignite.engine import IterationCounter, EpochCounter, BatchChecker +from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter def test_terminate(): diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index eff30963d9f9..bb4d485ff024 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -1,12 +1,10 @@ import os -import pytest - from collections.abc import Mapping +import pytest import torch -from ignite.engine import Engine, State, Events - +from ignite.engine import Engine, Events, State from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter diff --git a/tests/ignite/engine/test_event_handlers.py b/tests/ignite/engine/test_event_handlers.py index 1dd1d67ea30f..fa1a1ec6c30d 100644 --- a/tests/ignite/engine/test_event_handlers.py +++ b/tests/ignite/engine/test_event_handlers.py @@ -1,6 +1,5 @@ import gc - -from unittest.mock import call, MagicMock, create_autospec +from unittest.mock import MagicMock, call, create_autospec import pytest from pytest import raises diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 30efb90164a2..c646bc8578e1 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1,16 +1,15 @@ import os import warnings +from unittest.mock import MagicMock +import pytest import torch import torch.nn as nn from ignite.engine import Engine, Events, State -from ignite.handlers import ModelCheckpoint, Checkpoint, DiskSaver +from ignite.handlers import Checkpoint, DiskSaver, ModelCheckpoint from ignite.handlers.checkpoint import BaseSaveHandler -import pytest -from unittest.mock import MagicMock - _PREFIX = "PREFIX" diff --git a/tests/ignite/handlers/test_early_stopping.py b/tests/ignite/handlers/test_early_stopping.py index a626fca165cf..23e9d5777d78 100644 --- a/tests/ignite/handlers/test_early_stopping.py +++ b/tests/ignite/handlers/test_early_stopping.py @@ -1,11 +1,11 @@ import os + +import pytest import torch from ignite.engine import Engine, Events from ignite.handlers import EarlyStopping -import pytest - def do_nothing_update_fn(engine, batch): pass diff --git a/tests/ignite/metrics/test_accumulation.py b/tests/ignite/metrics/test_accumulation.py index 2a101a7099d8..a50920cf7b87 100644 --- a/tests/ignite/metrics/test_accumulation.py +++ b/tests/ignite/metrics/test_accumulation.py @@ -1,13 +1,12 @@ import os -import torch -import numpy as np +import numpy as np import pytest +import torch -from ignite.metrics.accumulation import VariableAccumulation, Average, GeometricAverage +from ignite.engine import Engine, Events from ignite.exceptions import NotComputableError -from ignite.engine import Events, Engine - +from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation torch.manual_seed(15) diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 5622ce7cf12a..ef0a6834fd5f 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -1,10 +1,11 @@ import os -from ignite.exceptions import NotComputableError -from ignite.metrics import Accuracy + import pytest import torch from sklearn.metrics import accuracy_score +from ignite.exceptions import NotComputableError +from ignite.metrics import Accuracy torch.manual_seed(12) diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py index 6e7ed4ec5a21..f2cea2cf33c2 100644 --- a/tests/ignite/metrics/test_confusion_matrix.py +++ b/tests/ignite/metrics/test_confusion_matrix.py @@ -1,14 +1,13 @@ import os -import torch import numpy as np -from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score +import pytest +import torch +from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score from ignite.exceptions import NotComputableError from ignite.metrics import ConfusionMatrix, IoU, mIoU -from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall, DiceCoefficient -import pytest - +from ignite.metrics.confusion_matrix import DiceCoefficient, cmAccuracy, cmPrecision, cmRecall torch.manual_seed(12) diff --git a/tests/ignite/metrics/test_dill.py b/tests/ignite/metrics/test_dill.py index c70afd324a6d..8a7b2d85408d 100644 --- a/tests/ignite/metrics/test_dill.py +++ b/tests/ignite/metrics/test_dill.py @@ -1,5 +1,4 @@ import dill - from ignite.metrics import Metric diff --git a/tests/ignite/metrics/test_epoch_metric.py b/tests/ignite/metrics/test_epoch_metric.py index 3989c6c3475e..516572471e7e 100644 --- a/tests/ignite/metrics/test_epoch_metric.py +++ b/tests/ignite/metrics/test_epoch_metric.py @@ -1,11 +1,11 @@ import os + +import pytest import torch from ignite.metrics import EpochMetric from ignite.metrics.epoch_metric import EpochMetricWarning -import pytest - def test_epoch_metric_wrong_setup_or_input(): diff --git a/tests/ignite/metrics/test_fbeta.py b/tests/ignite/metrics/test_fbeta.py index 71a75ab9cce3..14050e7a0de8 100644 --- a/tests/ignite/metrics/test_fbeta.py +++ b/tests/ignite/metrics/test_fbeta.py @@ -1,15 +1,13 @@ import os -import numpy as np -from sklearn.metrics import fbeta_score +import numpy as np +import pytest import torch +from sklearn.metrics import fbeta_score from ignite.engine import Engine from ignite.metrics import Fbeta, Precision, Recall -import pytest - - torch.manual_seed(12) diff --git a/tests/ignite/metrics/test_frequency.py b/tests/ignite/metrics/test_frequency.py index 275e585c382b..fedb943a6b4f 100644 --- a/tests/ignite/metrics/test_frequency.py +++ b/tests/ignite/metrics/test_frequency.py @@ -1,7 +1,6 @@ import time import pytest - import torch.distributed as dist from ignite.engine import Engine, Events @@ -63,7 +62,6 @@ def test_frequency_with_engine_distributed(distributed_context_single_node_gloo) def test_frequency_with_engine_with_every(): device = "cpu" _test_frequency_with_engine(device, workers=1, every=1) - _test_frequency_with_engine(device, workers=1, every=2) _test_frequency_with_engine(device, workers=1, every=10) @@ -71,5 +69,4 @@ def test_frequency_with_engine_with_every(): def test_frequency_with_engine_distributed_with_every(distributed_context_single_node_gloo): device = "cpu" _test_frequency_with_engine(device, workers=dist.get_world_size(), every=1) - _test_frequency_with_engine(device, workers=dist.get_world_size(), every=2) _test_frequency_with_engine(device, workers=dist.get_world_size(), every=10) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 4daad2d8aec2..ed471c8e6c97 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -1,15 +1,14 @@ import os +import pytest import torch +from numpy.testing import assert_almost_equal from torch import nn from torch.nn.functional import nll_loss from ignite.exceptions import NotComputableError from ignite.metrics import Loss -import pytest -from numpy.testing import assert_almost_equal - def test_zero_div(): loss = Loss(nll_loss) diff --git a/tests/ignite/metrics/test_mean_absolute_error.py b/tests/ignite/metrics/test_mean_absolute_error.py index 37012fe8f270..c562f02a271f 100644 --- a/tests/ignite/metrics/test_mean_absolute_error.py +++ b/tests/ignite/metrics/test_mean_absolute_error.py @@ -1,12 +1,11 @@ import os +import pytest import torch from ignite.exceptions import NotComputableError from ignite.metrics import MeanAbsoluteError -import pytest - def test_zero_div(): mae = MeanAbsoluteError() diff --git a/tests/ignite/metrics/test_mean_pairwise_distance.py b/tests/ignite/metrics/test_mean_pairwise_distance.py index 7b0dd4b4a99b..a7ecc71462b4 100644 --- a/tests/ignite/metrics/test_mean_pairwise_distance.py +++ b/tests/ignite/metrics/test_mean_pairwise_distance.py @@ -1,13 +1,12 @@ import os +import pytest import torch +from pytest import approx from ignite.exceptions import NotComputableError from ignite.metrics import MeanPairwiseDistance -import pytest -from pytest import approx - def test_zero_div(): mpd = MeanPairwiseDistance() diff --git a/tests/ignite/metrics/test_mean_squared_error.py b/tests/ignite/metrics/test_mean_squared_error.py index f8c690753386..acc396fd3997 100644 --- a/tests/ignite/metrics/test_mean_squared_error.py +++ b/tests/ignite/metrics/test_mean_squared_error.py @@ -1,12 +1,11 @@ import os +import pytest import torch from ignite.exceptions import NotComputableError from ignite.metrics import MeanSquaredError -import pytest - def test_zero_div(): mse = MeanSquaredError() diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index 198e4c969720..53d978ba3547 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -1,19 +1,17 @@ import numbers import os import sys - -import torch - -from ignite.metrics import Metric, Precision, Recall, ConfusionMatrix -from ignite.metrics.metric import reinit__is_reduced -from ignite.engine import Engine, State - from unittest.mock import MagicMock + +import numpy as np import pytest +import torch from pytest import approx, raises +from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score -import numpy as np -from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix +from ignite.engine import Engine, State +from ignite.metrics import ConfusionMatrix, Metric, Precision, Recall +from ignite.metrics.metric import reinit__is_reduced class DummyMetric1(Metric): diff --git a/tests/ignite/metrics/test_metrics_lambda.py b/tests/ignite/metrics/test_metrics_lambda.py index f8621aa4b17d..d25b305723a8 100644 --- a/tests/ignite/metrics/test_metrics_lambda.py +++ b/tests/ignite/metrics/test_metrics_lambda.py @@ -1,16 +1,14 @@ import os import numpy as np -from sklearn.metrics import precision_score, recall_score, f1_score - +import pytest import torch +from pytest import approx +from sklearn.metrics import f1_score, precision_score, recall_score from ignite.engine import Engine from ignite.metrics import Metric, MetricsLambda, Precision, Recall -import pytest -from pytest import approx - class ListGatherMetric(Metric): def __init__(self, index): diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index f18c834856f5..b2b267ea4edd 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -1,16 +1,14 @@ import os +import warnings +import pytest import torch +from sklearn.exceptions import UndefinedMetricWarning +from sklearn.metrics import precision_score from ignite.exceptions import NotComputableError from ignite.metrics import Precision -import pytest -import warnings - -from sklearn.metrics import precision_score -from sklearn.exceptions import UndefinedMetricWarning - torch.manual_seed(12) diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index d43459d6c9d7..2ad27a7b1364 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -1,16 +1,14 @@ import os +import warnings +import pytest import torch +from sklearn.exceptions import UndefinedMetricWarning +from sklearn.metrics import recall_score from ignite.exceptions import NotComputableError from ignite.metrics import Recall -import pytest -import warnings - -from sklearn.metrics import recall_score -from sklearn.exceptions import UndefinedMetricWarning - torch.manual_seed(12) diff --git a/tests/ignite/metrics/test_root_mean_squared_error.py b/tests/ignite/metrics/test_root_mean_squared_error.py index 04e87ea86bf3..077a41fc7f51 100644 --- a/tests/ignite/metrics/test_root_mean_squared_error.py +++ b/tests/ignite/metrics/test_root_mean_squared_error.py @@ -1,12 +1,11 @@ import os +import pytest import torch from ignite.exceptions import NotComputableError from ignite.metrics import RootMeanSquaredError -import pytest - def test_zero_div(): rmse = RootMeanSquaredError() diff --git a/tests/ignite/metrics/test_running_average.py b/tests/ignite/metrics/test_running_average.py index af9d37bff77a..929bb7b9d587 100644 --- a/tests/ignite/metrics/test_running_average.py +++ b/tests/ignite/metrics/test_running_average.py @@ -2,13 +2,12 @@ from functools import partial import numpy as np +import pytest import torch from ignite.engine import Engine, Events from ignite.metrics import Accuracy, RunningAverage -import pytest - def test_wrong_input_args(): with pytest.raises(TypeError, match=r"Argument src should be a Metric or None."): diff --git a/tests/ignite/metrics/test_top_k_categorical_accuracy.py b/tests/ignite/metrics/test_top_k_categorical_accuracy.py index 3dae259fd68e..8b37355a12cc 100644 --- a/tests/ignite/metrics/test_top_k_categorical_accuracy.py +++ b/tests/ignite/metrics/test_top_k_categorical_accuracy.py @@ -1,12 +1,11 @@ import os +import pytest import torch from ignite.exceptions import NotComputableError from ignite.metrics import TopKCategoricalAccuracy -import pytest - def test_zero_div(): acc = TopKCategoricalAccuracy(2) diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index e5077e4fd16d..5426c16cc16a 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -1,12 +1,13 @@ -import os import logging +import os +from collections import namedtuple + import pytest import torch import torch.distributed as dist -from collections import namedtuple -from ignite.utils import convert_tensor, to_onehot, setup_logger, one_rank_only from ignite.engine import Engine, Events +from ignite.utils import convert_tensor, one_rank_only, setup_logger, to_onehot def test_convert_tensor():