From 41b02b1dcfc50f066ec3a63e5a5b247abd83f66f Mon Sep 17 00:00:00 2001
From: ydcjeff <32727188+ydcjeff@users.noreply.github.com>
Date: Sat, 17 Apr 2021 23:00:14 +0630
Subject: [PATCH] style: formatted with isort black

---
 templates/_base/_argparse.py                  |  1 +
 templates/_base/_handlers.py                  |  3 ++-
 templates/gan/_test_internal.py               |  9 +++++++-
 templates/gan/datasets.py                     |  4 ++--
 templates/gan/main.py                         | 20 +++++++++++-------
 templates/gan/trainers.py                     |  3 +--
 templates/gan/utils.py                        |  6 +++---
 .../image_classification/_test_internal.py    |  9 +++++++-
 templates/image_classification/datasets.py    | 11 ++++++++--
 templates/image_classification/main.py        | 21 ++++++++++++-------
 templates/image_classification/test_all.py    |  4 +---
 templates/image_classification/trainers.py    |  3 ++-
 templates/image_classification/utils.py       | 10 ++++-----
 templates/single/_test_internal.py            | 10 +++++++--
 templates/single/main.py                      | 19 +++++++++++------
 templates/single/trainers.py                  |  1 +
 templates/single/utils.py                     |  3 ++-
 17 files changed, 93 insertions(+), 44 deletions(-)

diff --git a/templates/_base/_argparse.py b/templates/_base/_argparse.py
index a2a89791..32069d09 100644
--- a/templates/_base/_argparse.py
+++ b/templates/_base/_argparse.py
@@ -1,5 +1,6 @@
 {% block imports %}
 from argparse import ArgumentParser
+
 {% endblock %}
 
 {% block defaults %}
diff --git a/templates/_base/_handlers.py b/templates/_base/_handlers.py
index e90677a9..af45b1fb 100644
--- a/templates/_base/_handlers.py
+++ b/templates/_base/_handlers.py
@@ -2,12 +2,13 @@
 Ignite handlers
 """
 from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
+
 from ignite.contrib.engines import common
 from ignite.contrib.handlers.base_logger import BaseLogger
 from ignite.contrib.handlers.param_scheduler import LRScheduler
 from ignite.engine.engine import Engine
 from ignite.engine.events import Events
-from ignite.handlers import TimeLimit, Timer, Checkpoint, EarlyStopping
+from ignite.handlers import Checkpoint, EarlyStopping, TimeLimit, Timer
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from torch.utils.data.distributed import DistributedSampler
diff --git a/templates/gan/_test_internal.py b/templates/gan/_test_internal.py
index 1710f354..f7f7f041 100644
--- a/templates/gan/_test_internal.py
+++ b/templates/gan/_test_internal.py
@@ -23,7 +23,14 @@
 from test_all import set_up
 from torch import nn, optim
 from trainers import create_trainers
-from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger
+from utils import (
+    get_handlers,
+    get_logger,
+    hash_checkpoint,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)
 
 
 def test_get_handlers(tmp_path):
diff --git a/templates/gan/datasets.py b/templates/gan/datasets.py
index 32a74770..70a030d8 100644
--- a/templates/gan/datasets.py
+++ b/templates/gan/datasets.py
@@ -1,6 +1,6 @@
-from torchvision import transforms as T
-from torchvision import datasets as dset
 import ignite.distributed as idist
+from torchvision import datasets as dset
+from torchvision import transforms as T
 
 
 def get_datasets(dataset, dataroot):
diff --git a/templates/gan/main.py b/templates/gan/main.py
index 454eb2ff..74916dae 100644
--- a/templates/gan/main.py
+++ b/templates/gan/main.py
@@ -6,19 +6,25 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Any
 
-from ignite.contrib.handlers.wandb_logger import WandBLogger
-import torch
 import ignite.distributed as idist
+import torch
+from config import get_default_parser
+from datasets import get_datasets
+from ignite.contrib.handlers.wandb_logger import WandBLogger
 from ignite.engine.events import Events
 from ignite.utils import manual_seed
 from torchvision import utils as vutils
-
-from datasets import get_datasets
 from trainers import create_trainers
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
-from config import get_default_parser
-
+from utils import (
+    get_handlers,
+    get_logger,
+    initialize,
+    log_basic_info,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)
 
 FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
 REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
diff --git a/templates/gan/trainers.py b/templates/gan/trainers.py
index 71cebb06..a108bbc5 100644
--- a/templates/gan/trainers.py
+++ b/templates/gan/trainers.py
@@ -3,13 +3,12 @@
 """
 from typing import Any
 
-import torch
 import ignite.distributed as idist
+import torch
 from ignite.engine import Engine
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer
 
-
 # Edit below functions the way how the model will be training
 # train_function is how the model will be learning with given batch
 
diff --git a/templates/gan/utils.py b/templates/gan/utils.py
index b16a8277..1c32767f 100644
--- a/templates/gan/utils.py
+++ b/templates/gan/utils.py
@@ -8,19 +8,19 @@
 from pathlib import Path
 from pprint import pformat
 from typing import Any, Mapping, Optional, Tuple, Union
 
-from ignite.contrib.handlers.param_scheduler import ParamScheduler
 import ignite.distributed as idist
 import torch
-from torch import nn, optim
+from ignite.contrib.handlers.param_scheduler import ParamScheduler
 from ignite.engine import Engine
 from ignite.handlers.checkpoint import Checkpoint
 from ignite.utils import setup_logger
+from models import Discriminator, Generator
+from torch import nn, optim
 from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 
-from models import Generator, Discriminator
 
 {% include "_handlers.py" %}
 
diff --git a/templates/image_classification/_test_internal.py b/templates/image_classification/_test_internal.py
index 3a92e89d..b405a502 100644
--- a/templates/image_classification/_test_internal.py
+++ b/templates/image_classification/_test_internal.py
@@ -30,7 +30,14 @@
     train_events_to_attr,
     train_function,
 )
-from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger
+from utils import (
+    get_handlers,
+    get_logger,
+    hash_checkpoint,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)
 
 
 def test_get_handlers(tmp_path):
diff --git a/templates/image_classification/datasets.py b/templates/image_classification/datasets.py
index daff42e3..0c494ffe 100644
--- a/templates/image_classification/datasets.py
+++ b/templates/image_classification/datasets.py
@@ -1,6 +1,13 @@
-from torchvision import datasets
-from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomHorizontalFlip, ToTensor
 import ignite.distributed as idist
+from torchvision import datasets
+from torchvision.transforms import (
+    Compose,
+    Normalize,
+    Pad,
+    RandomCrop,
+    RandomHorizontalFlip,
+    ToTensor,
+)
 
 train_transform = Compose(
     [
diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py
index fe687758..39771a69 100644
--- a/templates/image_classification/main.py
+++ b/templates/image_classification/main.py
@@ -5,17 +5,24 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Any
 
-from ignite.contrib.handlers.wandb_logger import WandBLogger
 import ignite.distributed as idist
+from config import get_default_parser
+from datasets import get_datasets
+from ignite.contrib.handlers.wandb_logger import WandBLogger
 from ignite.engine.events import Events
-from ignite.utils import manual_seed
 from ignite.metrics import Accuracy, Loss
-
-from datasets import get_datasets
-from trainers import create_trainers, TrainEvents
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
-from config import get_default_parser
+from ignite.utils import manual_seed
+from trainers import TrainEvents, create_trainers
+from utils import (
+    get_handlers,
+    get_logger,
+    initialize,
+    log_basic_info,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)
 
 
 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
diff --git a/templates/image_classification/test_all.py b/templates/image_classification/test_all.py
index b514d344..734abc08 100644
--- a/templates/image_classification/test_all.py
+++ b/templates/image_classification/test_all.py
@@ -11,9 +11,7 @@
 from torch.functional import Tensor
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import Dataset
-from trainers import (
-    evaluate_function,
-)
+from trainers import evaluate_function
 from utils import initialize
 
 
diff --git a/templates/image_classification/trainers.py b/templates/image_classification/trainers.py
index 0d551025..3c9f7955 100644
--- a/templates/image_classification/trainers.py
+++ b/templates/image_classification/trainers.py
@@ -2,12 +2,13 @@
 `train_engine` and `eval_engine` like trainer and evaluator
 """
 from typing import Any, Tuple
-from ignite.metrics import loss
 
 import torch
 from ignite.engine import Engine
+from ignite.metrics import loss
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer
+
 
 {% include "_events.py" %}
 
diff --git a/templates/image_classification/utils.py b/templates/image_classification/utils.py
index a30a83a0..b222bd79 100644
--- a/templates/image_classification/utils.py
+++ b/templates/image_classification/utils.py
@@ -8,19 +8,19 @@
 from pathlib import Path
 from pprint import pformat
 from typing import Any, Mapping, Optional, Tuple, Union
 
-from ignite.contrib.handlers.param_scheduler import ParamScheduler
 import ignite.distributed as idist
 import torch
+from ignite.contrib.handlers import PiecewiseLinear
+from ignite.contrib.handlers.param_scheduler import ParamScheduler
 from ignite.engine import Engine
 from ignite.handlers.checkpoint import Checkpoint
 from ignite.utils import setup_logger
-from ignite.contrib.handlers import PiecewiseLinear
-from torch.nn import Module, CrossEntropyLoss
+from models import get_model
+from torch.nn import CrossEntropyLoss, Module
+from torch.optim import SGD, Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from torch.optim import Optimizer, SGD
 
-from models import get_model
 
 {% include "_handlers.py" %}
 
diff --git a/templates/single/_test_internal.py b/templates/single/_test_internal.py
index 140f26c9..d3c6b80c 100644
--- a/templates/single/_test_internal.py
+++ b/templates/single/_test_internal.py
@@ -8,7 +8,6 @@
 import pytest
 import torch
 from config import get_default_parser
-
 from ignite.contrib.handlers import (
     ClearMLLogger,
     MLflowLogger,
@@ -32,7 +31,14 @@
     train_events_to_attr,
     train_function,
 )
-from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger
+from utils import (
+    get_handlers,
+    get_logger,
+    hash_checkpoint,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)
 
 
 def set_up():
diff --git a/templates/single/main.py b/templates/single/main.py
index 8ad7c496..0cb833ca 100644
--- a/templates/single/main.py
+++ b/templates/single/main.py
@@ -5,16 +5,23 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Any
 
-from ignite.contrib.handlers.wandb_logger import WandBLogger
 import ignite.distributed as idist
+from config import get_default_parser
+from datasets import get_datasets
+from ignite.contrib.handlers.wandb_logger import WandBLogger
 from ignite.engine.events import Events
 from ignite.utils import manual_seed
-
-from datasets import get_datasets
-from trainers import create_trainers, TrainEvents
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
-from config import get_default_parser
+from trainers import TrainEvents, create_trainers
+from utils import (
+    get_handlers,
+    get_logger,
+    initialize,
+    log_basic_info,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)
 
 
 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
diff --git a/templates/single/trainers.py b/templates/single/trainers.py
index eef5518d..3bd20b38 100644
--- a/templates/single/trainers.py
+++ b/templates/single/trainers.py
@@ -7,6 +7,7 @@
 from ignite.engine import Engine
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer
+
 
 {% include "_events.py" %}
diff --git a/templates/single/utils.py b/templates/single/utils.py
index 8af101f1..22374598 100644
--- a/templates/single/utils.py
+++ b/templates/single/utils.py
@@ -8,16 +8,17 @@
 from pathlib import Path
 from pprint import pformat
 from typing import Any, Mapping, Optional, Tuple, Union
 
-from ignite.contrib.handlers.param_scheduler import ParamScheduler
 import ignite.distributed as idist
 import torch
+from ignite.contrib.handlers.param_scheduler import ParamScheduler
 from ignite.engine import Engine
 from ignite.handlers.checkpoint import Checkpoint
 from ignite.utils import setup_logger
 from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
+
 
 {% include "_handlers.py" %}
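
Note on the hunks above: the reordering is the usual isort pass (straight imports before from-imports, then alphabetical by module) followed by black, which rewraps long import lists into the parenthesised, trailing-comma form seen in the `from utils import (...)` blocks. A minimal sketch of that two-step pipeline, using the old templates/gan/datasets.py imports as sample input; it assumes isort >= 5 and black with their default settings (the project's actual formatter configuration is not shown in this patch, and the helper name is illustrative only, not part of the change):

# Illustrative only -- not part of this patch.
import black
import isort

# The unsorted imports from templates/gan/datasets.py before this commit.
sample = (
    "from torchvision import transforms as T\n"
    "from torchvision import datasets as dset\n"
    "import ignite.distributed as idist\n"
)

def format_with_isort_black(source: str) -> str:
    # isort groups and alphabetises the import statements first...
    sorted_source = isort.code(source)
    # ...then black normalises line length, wrapping, and trailing commas.
    return black.format_str(sorted_source, mode=black.Mode())

print(format_with_isort_black(sample))

With default settings this should print the same ordering the templates/gan/datasets.py hunk ends up with: `import ignite.distributed as idist` first, then the two torchvision from-imports sorted alphabetically.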