From caa0dee656d634b297e90bd05df67a412edf97dd Mon Sep 17 00:00:00 2001 From: ydcjeff Date: Sat, 10 Apr 2021 16:41:07 +0630 Subject: [PATCH] fix: download datasets only on rank 0 --- templates/gan/main.py | 9 +++++++++ templates/image_classification/main.py | 9 +++++++++ templates/single/main.py | 9 +++++++++ 3 files changed, 27 insertions(+) diff --git a/templates/gan/main.py b/templates/gan/main.py index 25b786d3..e5b36c40 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -43,7 +43,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # datasets and dataloaders # ----------------------------- + if rank > 0: + # Ensure that only rank 0 downloads the dataset + idist.barrier() + train_dataset, num_channels = get_datasets(config.dataset, config.data_path) + + if rank == 0: + # Rank 0 is done; release the ranks waiting at the barrier above + idist.barrier() + train_dataloader = idist.auto_dataloader( train_dataset, batch_size=config.batch_size, diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py index 7efb0e71..37a05341 100644 --- a/templates/image_classification/main.py +++ b/templates/image_classification/main.py @@ -36,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader + if rank > 0: + # Ensure that only rank 0 downloads the dataset + idist.barrier() + train_dataset, eval_dataset = get_datasets(path=config.data_path) + + if rank == 0: + # Rank 0 is done; release the ranks waiting at the barrier above + idist.barrier() + train_dataloader = idist.auto_dataloader( train_dataset, batch_size=config.train_batch_size, diff --git a/templates/single/main.py b/templates/single/main.py index 130dcdfb..8d5915e8 100644 --- a/templates/single/main.py +++ b/templates/single/main.py @@ -34,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, 
**kwargs: Any): # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader + if rank > 0: + # Ensure that only rank 0 downloads the dataset + idist.barrier() + train_dataset = ... eval_dataset = ... + + if rank == 0: + # Rank 0 is done; release the ranks waiting at the barrier above + idist.barrier() + train_dataloader = idist.auto_dataloader(train_dataset, **kwargs) eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)