Skip to content

Commit fb9e972

Browse files
author
Jeff Yang
authored
fix: download datasets only on rank 0 (#62)
1 parent 3f9bdf7 commit fb9e972

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

templates/gan/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
4343
# datasets and dataloaders
4444
# -----------------------------
4545

46+
if rank > 0:
47+
# Ensure that only rank 0 download the dataset
48+
idist.barrier()
49+
4650
train_dataset, num_channels = get_datasets(config.dataset, config.data_path)
51+
52+
if rank == 0:
53+
# Ensure that only rank 0 download the dataset
54+
idist.barrier()
55+
4756
train_dataloader = idist.auto_dataloader(
4857
train_dataset,
4958
batch_size=config.batch_size,

templates/image_classification/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
3636
# TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
3737
# See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
3838

39+
if rank > 0:
40+
# Ensure that only rank 0 download the dataset
41+
idist.barrier()
42+
3943
train_dataset, eval_dataset = get_datasets(path=config.data_path)
44+
45+
if rank == 0:
46+
# Ensure that only rank 0 download the dataset
47+
idist.barrier()
48+
4049
train_dataloader = idist.auto_dataloader(
4150
train_dataset,
4251
batch_size=config.train_batch_size,

templates/single/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
3434
# TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
3535
# See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
3636

37+
if rank > 0:
38+
# Ensure that only rank 0 download the dataset
39+
idist.barrier()
40+
3741
train_dataset = ...
3842
eval_dataset = ...
43+
44+
if rank == 0:
45+
# Ensure that only rank 0 download the dataset
46+
idist.barrier()
47+
3948
train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
4049
eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)
4150

0 commit comments

Comments
 (0)