diff --git a/src/templates/template-text-classification/data.py b/src/templates/template-text-classification/data.py index 87850c67..233eff87 100644 --- a/src/templates/template-text-classification/data.py +++ b/src/templates/template-text-classification/data.py @@ -42,10 +42,12 @@ def __len__(self): def setup_data(config): + #::: if (it.use_dist) { :::# local_rank = idist.get_local_rank() if local_rank > 0: idist.barrier() + #::: } :::# dataset_train, dataset_eval = load_dataset( "imdb", split=["train", "test"], cache_dir=config.data_path @@ -61,9 +63,10 @@ def setup_data(config): dataset_eval = TransformerDataset( test_texts, test_labels, tokenizer, config.max_length ) - + #::: if (it.use_dist) { :::# if local_rank == 0: idist.barrier() + #::: } :::# dataloader_train = idist.auto_dataloader( dataset_train, diff --git a/src/templates/template-vision-dcgan/data.py b/src/templates/template-vision-dcgan/data.py index e1ddf1e5..78503100 100644 --- a/src/templates/template-vision-dcgan/data.py +++ b/src/templates/template-vision-dcgan/data.py @@ -12,7 +12,9 @@ def setup_data(config: Any): ---------- config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers` """ + #::: if (it.use_dist) { :::# local_rank = idist.get_local_rank() + #::: } :::# transform = T.Compose( [ T.Resize(64), @@ -20,10 +22,11 @@ def setup_data(config: Any): T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] ) - + #::: if (it.use_dist) { :::# if local_rank > 0: # Ensure that only rank 0 download the dataset idist.barrier() + #::: } :::# dataset_train = torchvision.datasets.CIFAR10( root=config.data_path, @@ -38,9 +41,11 @@ def setup_data(config: Any): transform=transform, ) nc = 3 + #::: if (it.use_dist) { :::# if local_rank == 0: # Ensure that only rank 0 download the dataset idist.barrier() + #::: } :::# dataloader_train = idist.auto_dataloader( dataset_train, diff --git a/src/templates/template-vision-segmentation/data.py b/src/templates/template-vision-segmentation/data.py index bc286793..f9c22020 100644 --- a/src/templates/template-vision-segmentation/data.py +++ b/src/templates/template-vision-segmentation/data.py @@ -177,14 +177,17 @@ def prepare_image_mask(batch, device, non_blocking): def download_datasets(data_path): + #::: if (it.use_dist) { :::# local_rank = idist.get_local_rank() if local_rank > 0: # Ensure that only rank 0 download the dataset idist.barrier() + #::: } :::# VOCSegmentation(data_path, image_set="train", download=True) VOCSegmentation(data_path, image_set="val", download=True) - + #::: if (it.use_dist) { :::# if local_rank == 0: # Ensure that only rank 0 download the dataset idist.barrier() + #::: } :::#