1919from pytorch_lightning .callbacks .progress import TQDMProgressBar
2020from torch import nn
2121from torch .nn import functional as F
22- from torch .utils .data import DataLoader , random_split
22+ from torch .utils .data import DataLoader , random_split , RandomSampler
2323from torchmetrics import Accuracy
2424from torchvision import transforms
2525from torchvision .datasets import MNIST
@@ -127,7 +127,7 @@ def setup(self, stage=None):
127127 )
128128
129129 def train_dataloader (self ):
130- return DataLoader (self .mnist_train , batch_size = BATCH_SIZE )
130+ return DataLoader (self .mnist_train , batch_size = BATCH_SIZE , sampler = RandomSampler ( self . mnist_train , num_samples = 1000 ) )
131131
132132 def val_dataloader (self ):
133133 return DataLoader (self .mnist_val , batch_size = BATCH_SIZE )
@@ -147,10 +147,11 @@ def test_dataloader(self):
147147trainer = Trainer (
148148 accelerator = "auto" ,
149149 # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
150- max_epochs = 5 ,
150+ max_epochs = 3 ,
151151 callbacks = [TQDMProgressBar (refresh_rate = 20 )],
152152 num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 )),
153153 devices = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 )),
154+ replace_sampler_ddp = False ,
154155 strategy = "ddp" ,
155156)
156157
0 commit comments