
Commit

fix: make the seed respect distributed settings (#60)
* fix: set the seed inside run()

* fix: download datasets only on rank 0 (#62)

* fix: create ignite loggers only on rank 0 (#64)
Jeff Yang committed Apr 10, 2021
1 parent 920ac61 commit b17397b
Showing 3 changed files with 78 additions and 27 deletions.
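All three templates receive the same three changes: the seed is now set inside run() with a per-rank offset, the dataset download is gated so that only rank 0 downloads while the other ranks wait at a barrier, and the ignite logger is created and closed on rank 0 only. As a minimal sketch of the seeding part (config.seed and idist.get_rank() are the names used in the diffs below; the rest is illustrative, not the templates' exact code):

import ignite.distributed as idist
from ignite.utils import manual_seed


def run(local_rank, config):
    # every process seeds itself; the rank offset keeps shuffling and
    # augmentation different across processes while staying reproducible
    rank = idist.get_rank()
    manual_seed(config.seed + rank)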
37 changes: 26 additions & 11 deletions templates/gan/main.py
@@ -21,23 +21,35 @@
from config import get_default_parser


PRINT_FREQ = 100
FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
LOGS_FNAME = "logs.tsv"
PLOT_FNAME = "plot.svg"
SAMPLES_FNAME = "samples.svg"
CKPT_PREFIX = "networks"


def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
"""function to be run by idist.Parallel context manager."""

# ----------------------
# seed each rank deterministically
# ----------------------
rank = idist.get_rank()
manual_seed(config.seed + rank)

# -----------------------------
# datasets and dataloaders
# -----------------------------

if rank > 0:
# Ensure that only rank 0 downloads the dataset
idist.barrier()

train_dataset, num_channels = get_datasets(config.dataset, config.data_path)

if rank == 0:
# Rank 0 is done downloading; release the waiting ranks
idist.barrier()

train_dataloader = idist.auto_dataloader(
train_dataset,
batch_size=config.batch_size,
@@ -97,7 +109,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
lr_scheduler=lr_scheduler,
output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
)
logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)

# set up the ignite logger only on rank 0
if rank == 0:
logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)

# -----------------------------------
# resume from the saved checkpoints
@@ -177,12 +192,13 @@ def create_plots(engine):
# close the logger after the training completed / terminated
# ------------------------------------------------------------

if isinstance(logger_handler, WandBLogger):
# why handle wandb differently?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
if rank == 0:
if isinstance(logger_handler, WandBLogger):
# why handle wandb differently?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()

# -----------------------------------------
# where is my best and last checkpoint ?
@@ -194,7 +210,6 @@ def create_plots(engine):
def main():
parser = ArgumentParser(parents=[get_default_parser()])
config = parser.parse_args()
manual_seed(config.seed)

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
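The pair of barriers around get_datasets is the usual single-download idiom for multi-process jobs: non-zero ranks block first, rank 0 downloads and only then reaches its own barrier, which releases everyone with the data already on disk. A self-contained sketch of the same idea, using a torchvision MNIST download purely as a stand-in for the templates' get_datasets helper:

import ignite.distributed as idist
from torchvision import datasets, transforms


def load_train_dataset(data_path):
    rank = idist.get_rank()

    if rank > 0:
        # non-zero ranks wait here until rank 0 has finished downloading
        idist.barrier()

    # placeholder dataset; the templates call their own get_datasets() instead
    train_dataset = datasets.MNIST(
        data_path, train=True, download=True, transform=transforms.ToTensor()
    )

    if rank == 0:
        # rank 0 arrives only after the download, releasing the waiting ranks
        idist.barrier()

    return train_dataset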
34 changes: 26 additions & 8 deletions templates/image_classification/main.py
@@ -22,6 +22,12 @@
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
"""function to be run by idist.Parallel context manager."""

# ----------------------
# seed each rank deterministically
# ----------------------
rank = idist.get_rank()
manual_seed(config.seed + rank)

# -----------------------------
# datasets and dataloaders
# -----------------------------
@@ -30,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# TODO: PLEASE replace `kwargs` with your desired DataLoader arguments
# See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

if rank > 0:
# Ensure that only rank 0 downloads the dataset
idist.barrier()

train_dataset, eval_dataset = get_datasets(path=config.data_path)

if rank == 0:
# Rank 0 is done downloading; release the waiting ranks
idist.barrier()

train_dataloader = idist.auto_dataloader(
train_dataset,
batch_size=config.train_batch_size,
@@ -110,7 +125,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
lr_scheduler=lr_scheduler,
output_names=None,
)
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# set up the ignite logger only on rank 0
if rank == 0:
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# -----------------------------------
# resume from the saved checkpoints
@@ -183,12 +201,13 @@ def _():
# close the logger after the training completed / terminated
# ------------------------------------------------------------

if isinstance(logger_handler, WandBLogger):
# why handle wandb differently?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
if rank == 0:
if isinstance(logger_handler, WandBLogger):
# why handle wandb differently?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()

# -----------------------------------------
# where is my best and last checkpoint ?
@@ -200,7 +219,6 @@ def _():
def main():
parser = ArgumentParser(parents=[get_default_parser()])
config = parser.parse_args()
manual_seed(config.seed)

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
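Because the logger is now created on rank 0 only, the shutdown code at the end of run() is gated on rank 0 as well. A hypothetical helper (close_logger is not part of the templates) showing how the two ends fit together, including the wandb special case:

import ignite.distributed as idist
from ignite.contrib.handlers.wandb_logger import WandBLogger


def close_logger(logger_handler):
    # loggers exist on rank 0 only in these templates, so other ranks
    # (or a logger that was never created) simply return
    if idist.get_rank() != 0 or logger_handler is None:
        return
    if isinstance(logger_handler, WandBLogger):
        # wandb runs need an explicit finish() rather than close(),
        # see https://github.com/pytorch/ignite/issues/1894
        logger_handler.finish()
    else:
        logger_handler.close()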
34 changes: 26 additions & 8 deletions templates/single/main.py
@@ -20,6 +20,12 @@
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
"""function to be run by idist.Parallel context manager."""

# ----------------------
# seed each rank deterministically
# ----------------------
rank = idist.get_rank()
manual_seed(config.seed + rank)

# -----------------------------
# datasets and dataloaders
# -----------------------------
@@ -28,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# TODO: PLEASE replace `kwargs` with your desired DataLoader arguments
# See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

if rank > 0:
# Ensure that only rank 0 downloads the dataset
idist.barrier()

train_dataset = ...
eval_dataset = ...

if rank == 0:
# Rank 0 is done downloading; release the waiting ranks
idist.barrier()

train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

@@ -86,7 +101,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
lr_scheduler=lr_scheduler,
output_names=None,
)
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# set up the ignite logger only on rank 0
if rank == 0:
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# -----------------------------------
# resume from the saved checkpoints
@@ -160,12 +178,13 @@ def _():
# close the logger after the training completed / terminated
# ------------------------------------------------------------

if isinstance(logger_handler, WandBLogger):
# why handle wandb differently?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
if rank == 0:
if isinstance(logger_handler, WandBLogger):
# why handle wandb differently?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()

# -----------------------------------------
# where is my best and last checkpoint ?
@@ -177,7 +196,6 @@ def _():
def main():
parser = ArgumentParser(parents=[get_default_parser()])
config = parser.parse_args()
manual_seed(config.seed)

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
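With the seed handled inside run(), main() only parses the arguments and hands run() to the idist.Parallel launcher, which spawns one process per rank. A rough sketch of that launch, with the backend and process count hard-coded for illustration (the real templates read them from config, and run refers to the function shown in the diffs above):

from argparse import ArgumentParser

import ignite.distributed as idist

from config import get_default_parser  # the templates' own parser helper


def main():
    parser = ArgumentParser(parents=[get_default_parser()])
    config = parser.parse_args()
    # no manual_seed(config.seed) here any more; run() seeds each rank itself

    # illustrative values; the templates derive these from the parsed config
    with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
        parallel.run(run, config)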
