From f6c4dcc0865cfe7acb323de5c224f441d51bb825 Mon Sep 17 00:00:00 2001 From: ydcjeff Date: Sat, 10 Apr 2021 20:48:19 +0630 Subject: [PATCH] fix: create ignite loggers only on rank 0 --- templates/gan/main.py | 21 +++++++++++---------- templates/image_classification/main.py | 18 +++++++++++------- templates/single/main.py | 18 +++++++++++------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/templates/gan/main.py b/templates/gan/main.py index e5b36c40..94f430ee 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -21,13 +21,10 @@ from config import get_default_parser -PRINT_FREQ = 100 FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png" REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png" LOGS_FNAME = "logs.tsv" PLOT_FNAME = "plot.svg" -SAMPLES_FNAME = "samples.svg" -CKPT_PREFIX = "networks" def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): @@ -112,7 +109,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): lr_scheduler=lr_scheduler, output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"], ) - logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers) + + # setup ignite logger only on rank 0 + if rank == 0: + logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers) # ----------------------------------- # resume from the saved checkpoints @@ -192,12 +192,13 @@ def create_plots(engine): # close the logger after the training completed / terminated # ------------------------------------------------------------ - if isinstance(logger_handler, WandBLogger): - # why handle differently for wandb ? - # See : https://github.com/pytorch/ignite/issues/1894 - logger_handler.finish() - elif logger_handler: - logger_handler.close() + if rank == 0: + if isinstance(logger_handler, WandBLogger): + # why handle differently for wandb ? + # See : https://github.com/pytorch/ignite/issues/1894 + logger_handler.finish() + elif logger_handler: + logger_handler.close() # ----------------------------------------- # where is my best and last checkpoint ? diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py index 37a05341..874279ff 100644 --- a/templates/image_classification/main.py +++ b/templates/image_classification/main.py @@ -125,7 +125,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): lr_scheduler=lr_scheduler, output_names=None, ) - logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) + + # setup ignite logger only on rank 0 + if rank == 0: + logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) # ----------------------------------- # resume from the saved checkpoints @@ -198,12 +201,13 @@ def _(): # close the logger after the training completed / terminated # ------------------------------------------------------------ - if isinstance(logger_handler, WandBLogger): - # why handle differently for wandb ? - # See : https://github.com/pytorch/ignite/issues/1894 - logger_handler.finish() - elif logger_handler: - logger_handler.close() + if rank == 0: + if isinstance(logger_handler, WandBLogger): + # why handle differently for wandb ? + # See : https://github.com/pytorch/ignite/issues/1894 + logger_handler.finish() + elif logger_handler: + logger_handler.close() # ----------------------------------------- # where is my best and last checkpoint ? diff --git a/templates/single/main.py b/templates/single/main.py index 8d5915e8..a9a3e082 100644 --- a/templates/single/main.py +++ b/templates/single/main.py @@ -101,7 +101,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): lr_scheduler=lr_scheduler, output_names=None, ) - logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) + + # setup ignite logger only on rank 0 + if rank == 0: + logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) # ----------------------------------- # resume from the saved checkpoints @@ -175,12 +178,13 @@ def _(): # close the logger after the training completed / terminated # ------------------------------------------------------------ - if isinstance(logger_handler, WandBLogger): - # why handle differently for wandb ? - # See : https://github.com/pytorch/ignite/issues/1894 - logger_handler.finish() - elif logger_handler: - logger_handler.close() + if rank == 0: + if isinstance(logger_handler, WandBLogger): + # why handle differently for wandb ? + # See : https://github.com/pytorch/ignite/issues/1894 + logger_handler.finish() + elif logger_handler: + logger_handler.close() # ----------------------------------------- # where is my best and last checkpoint ?