Skip to content

Commit

Permalink
fix(template-segmentation): call .step() / attach to engine (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
ydcjeff committed May 25, 2021
1 parent a073e93 commit 3806a38
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/templates/template-vision-segmentation/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from functools import partial
from pprint import pformat
from typing import Any
from typing import Any, cast

import ignite.distributed as idist
import yaml
from data import denormalize, download_datasets, setup_data
from ignite.contrib.handlers import LRScheduler
from ignite.engine import Events
from ignite.metrics import ConfusionMatrix, IoU, mIoU
from ignite.utils import manual_seed
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
from torch.utils.data.distributed import DistributedSampler
from torchvision.models.segmentation import deeplabv3_resnet101
from trainers import setup_evaluator, setup_trainer
Expand Down Expand Up @@ -80,6 +81,16 @@ def set_epoch():
):
dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)

if isinstance(lr_scheduler, _LRScheduler):
trainer.add_event_handler(
Events.ITERATION_COMPLETED,
lambda engine: cast(_LRScheduler, lr_scheduler).step(),
)
elif isinstance(lr_scheduler, LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
else:
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

# setup ignite handlers
#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#

Expand Down

0 comments on commit 3806a38

Please sign in to comment.