diff --git a/templates/_base/_argparse.py b/templates/_base/_argparse.py
index 7e569326..9ce3da27 100644
--- a/templates/_base/_argparse.py
+++ b/templates/_base/_argparse.py
@@ -1,5 +1,6 @@
 {% block imports %}
 from argparse import ArgumentParser
+
 {% endblock %}

 {% block defaults %}
diff --git a/templates/_base/_handlers.py b/templates/_base/_handlers.py
index cab4dc99..08289b87 100644
--- a/templates/_base/_handlers.py
+++ b/templates/_base/_handlers.py
@@ -2,12 +2,13 @@
 Ignite handlers
 """
 from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
+
 from ignite.contrib.engines import common
 from ignite.contrib.handlers.base_logger import BaseLogger
 from ignite.contrib.handlers.param_scheduler import LRScheduler
 from ignite.engine.engine import Engine
 from ignite.engine.events import Events
-from ignite.handlers import TimeLimit, Timer, Checkpoint, EarlyStopping
+from ignite.handlers import Checkpoint, EarlyStopping, TimeLimit, Timer
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from torch.utils.data.distributed import DistributedSampler
diff --git a/templates/gan/_test_internal.py b/templates/gan/_test_internal.py
index 70f90c6d..64e25adc 100644
--- a/templates/gan/_test_internal.py
+++ b/templates/gan/_test_internal.py
@@ -23,7 +23,14 @@
 from test_all import set_up
 from torch import nn, optim
 from trainers import create_trainers
-from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger
+from utils import (
+    get_handlers,
+    get_logger,
+    hash_checkpoint,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)


 def test_get_handlers(tmp_path):
diff --git a/templates/gan/datasets.py b/templates/gan/datasets.py
index 32a74770..70a030d8 100644
--- a/templates/gan/datasets.py
+++ b/templates/gan/datasets.py
@@ -1,6 +1,6 @@
-from torchvision import transforms as T
-from torchvision import datasets as dset
 import ignite.distributed as idist
+from torchvision import datasets as dset
+from torchvision import transforms as T


 def get_datasets(dataset, dataroot):
diff --git a/templates/gan/main.py b/templates/gan/main.py
index 010385f9..5a683875 100644
--- a/templates/gan/main.py
+++ b/templates/gan/main.py
@@ -7,17 +7,24 @@
 from pathlib import Path
 from typing import Any

-import torch
 import ignite.distributed as idist
+import torch
+from config import get_default_parser
+from datasets import get_datasets
+from ignite.contrib.handlers.wandb_logger import WandBLogger
 from ignite.engine.events import Events
 from ignite.utils import manual_seed
 from torchvision import utils as vutils
-
-from datasets import get_datasets
 from trainers import create_trainers
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
-from config import get_default_parser
-
+from utils import (
+    get_handlers,
+    get_logger,
+    initialize,
+    log_basic_info,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)

 FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
 REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
diff --git a/templates/gan/trainers.py b/templates/gan/trainers.py
index 2e2e1222..2b5ebc4a 100644
--- a/templates/gan/trainers.py
+++ b/templates/gan/trainers.py
@@ -3,13 +3,12 @@
 """
 from typing import Any

-import torch
 import ignite.distributed as idist
+import torch
 from ignite.engine import Engine
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer

-
 # Edit below functions the way how the model will be training
 # train_function is how the model will be learning with given batch

diff --git a/templates/gan/utils.py b/templates/gan/utils.py
index b16a8277..1c32767f 100644
--- a/templates/gan/utils.py
+++ b/templates/gan/utils.py
@@ -8,19 +8,19 @@
 from pathlib import Path
 from pprint import pformat
 from typing import Any, Mapping, Optional, Tuple, Union

-from ignite.contrib.handlers.param_scheduler import ParamScheduler
 import ignite.distributed as idist
 import torch
-from torch import nn, optim
+from ignite.contrib.handlers.param_scheduler import ParamScheduler
 from ignite.engine import Engine
 from ignite.handlers.checkpoint import Checkpoint
 from ignite.utils import setup_logger
+from models import Discriminator, Generator
+from torch import nn, optim
 from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
-from models import Generator, Discriminator


 {% include "_handlers.py" %}

diff --git a/templates/image_classification/_test_internal.py b/templates/image_classification/_test_internal.py
index 921b9c60..27e608cb 100644
--- a/templates/image_classification/_test_internal.py
+++ b/templates/image_classification/_test_internal.py
@@ -30,7 +30,14 @@
     train_events_to_attr,
     train_function,
 )
-from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger
+from utils import (
+    get_handlers,
+    get_logger,
+    hash_checkpoint,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)


 def test_get_handlers(tmp_path):
diff --git a/templates/image_classification/datasets.py b/templates/image_classification/datasets.py
index daff42e3..0c494ffe 100644
--- a/templates/image_classification/datasets.py
+++ b/templates/image_classification/datasets.py
@@ -1,6 +1,13 @@
-from torchvision import datasets
-from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomHorizontalFlip, ToTensor
 import ignite.distributed as idist
+from torchvision import datasets
+from torchvision.transforms import (
+    Compose,
+    Normalize,
+    Pad,
+    RandomCrop,
+    RandomHorizontalFlip,
+    ToTensor,
+)

 train_transform = Compose(
     [
diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py
index 0efd92d2..f75e4c81 100644
--- a/templates/image_classification/main.py
+++ b/templates/image_classification/main.py
@@ -7,14 +7,22 @@
 from typing import Any

 import ignite.distributed as idist
+from config import get_default_parser
+from datasets import get_datasets
+from ignite.contrib.handlers.wandb_logger import WandBLogger
 from ignite.engine.events import Events
-from ignite.utils import manual_seed
 from ignite.metrics import Accuracy, Loss
-
-from datasets import get_datasets
+from ignite.utils import manual_seed
 from trainers import create_trainers
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
-from config import get_default_parser
+from utils import (
+    get_handlers,
+    get_logger,
+    initialize,
+    log_basic_info,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)


 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
diff --git a/templates/image_classification/test_all.py b/templates/image_classification/test_all.py
index b514d344..734abc08 100644
--- a/templates/image_classification/test_all.py
+++ b/templates/image_classification/test_all.py
@@ -11,9 +11,7 @@
 from torch.functional import Tensor
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import Dataset
-from trainers import (
-    evaluate_function,
-)
+from trainers import evaluate_function
 from utils import initialize


diff --git a/templates/image_classification/trainers.py b/templates/image_classification/trainers.py
index dba89bbf..db2bff87 100644
--- a/templates/image_classification/trainers.py
+++ b/templates/image_classification/trainers.py
@@ -2,12 +2,13 @@
 `trainer` and `evaluator` like trainer and evaluator
 """
 from typing import Any, Tuple

-from ignite.metrics import loss
 import torch
 from ignite.engine import Engine
+from ignite.metrics import loss
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer

+
 {% include "_events.py" %}

diff --git a/templates/image_classification/utils.py b/templates/image_classification/utils.py
index a30a83a0..b222bd79 100644
--- a/templates/image_classification/utils.py
+++ b/templates/image_classification/utils.py
@@ -8,19 +8,19 @@
 from pathlib import Path
 from pprint import pformat
 from typing import Any, Mapping, Optional, Tuple, Union

-from ignite.contrib.handlers.param_scheduler import ParamScheduler
 import ignite.distributed as idist
 import torch
+from ignite.contrib.handlers import PiecewiseLinear
+from ignite.contrib.handlers.param_scheduler import ParamScheduler
 from ignite.engine import Engine
 from ignite.handlers.checkpoint import Checkpoint
 from ignite.utils import setup_logger
-from ignite.contrib.handlers import PiecewiseLinear
-from torch.nn import Module, CrossEntropyLoss
+from models import get_model
+from torch.nn import CrossEntropyLoss, Module
+from torch.optim import SGD, Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from torch.optim import Optimizer, SGD
-from models import get_model


 {% include "_handlers.py" %}

diff --git a/templates/single/_test_internal.py b/templates/single/_test_internal.py
index dc79819c..3fc0b1e6 100644
--- a/templates/single/_test_internal.py
+++ b/templates/single/_test_internal.py
@@ -8,7 +8,6 @@
 import pytest
 import torch
 from config import get_default_parser
-
 from ignite.contrib.handlers import (
     ClearMLLogger,
     MLflowLogger,
@@ -32,7 +31,14 @@
     train_events_to_attr,
     train_function,
 )
-from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger
+from utils import (
+    get_handlers,
+    get_logger,
+    hash_checkpoint,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)


 def set_up():
diff --git a/templates/single/main.py b/templates/single/main.py
index 9cb27405..7f0d6a66 100644
--- a/templates/single/main.py
+++ b/templates/single/main.py
@@ -7,13 +7,21 @@
 from typing import Any

 import ignite.distributed as idist
+from config import get_default_parser
+from datasets import get_datasets
+from ignite.contrib.handlers.wandb_logger import WandBLogger
 from ignite.engine.events import Events
 from ignite.utils import manual_seed
-
-from datasets import get_datasets
-from trainers import create_trainers, TrainEvents
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
-from config import get_default_parser
+from trainers import TrainEvents, create_trainers
+from utils import (
+    get_handlers,
+    get_logger,
+    initialize,
+    log_basic_info,
+    log_metrics,
+    resume_from,
+    setup_logging,
+)


 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
diff --git a/templates/single/trainers.py b/templates/single/trainers.py
index 3b41c86a..807ff1a0 100644
--- a/templates/single/trainers.py
+++ b/templates/single/trainers.py
@@ -7,6 +7,7 @@
 from ignite.engine import Engine
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer

+
 {% include "_events.py" %}

diff --git a/templates/single/utils.py b/templates/single/utils.py
index 8af101f1..22374598 100644
--- a/templates/single/utils.py
+++ b/templates/single/utils.py
@@ -8,16 +8,17 @@
 from pathlib import Path
 from pprint import pformat
 from typing import Any, Mapping, Optional, Tuple, Union

-from ignite.contrib.handlers.param_scheduler import ParamScheduler
 import ignite.distributed as idist
 import torch
+from ignite.contrib.handlers.param_scheduler import ParamScheduler
 from ignite.engine import Engine
 from ignite.handlers.checkpoint import Checkpoint
 from ignite.utils import setup_logger
 from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer

+
 {% include "_handlers.py" %}
