diff --git a/templates/gan/main.py b/templates/gan/main.py index 25b786d3..e5b36c40 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -43,7 +43,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # datasets and dataloaders # ----------------------------- + if rank > 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataset, num_channels = get_datasets(config.dataset, config.data_path) + + if rank == 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataloader = idist.auto_dataloader( train_dataset, batch_size=config.batch_size, diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py index 7efb0e71..37a05341 100644 --- a/templates/image_classification/main.py +++ b/templates/image_classification/main.py @@ -36,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader + if rank > 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataset, eval_dataset = get_datasets(path=config.data_path) + + if rank == 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataloader = idist.auto_dataloader( train_dataset, batch_size=config.train_batch_size, diff --git a/templates/single/main.py b/templates/single/main.py index 130dcdfb..8d5915e8 100644 --- a/templates/single/main.py +++ b/templates/single/main.py @@ -34,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader + if rank > 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataset = ... eval_dataset = ... + + if rank == 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataloader = idist.auto_dataloader(train_dataset, **kwargs) eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)