diff --git a/beginner_source/ddp_series_multigpu.rst b/beginner_source/ddp_series_multigpu.rst
index 46059f286b1..73e49c6c299 100644
--- a/beginner_source/ddp_series_multigpu.rst
+++ b/beginner_source/ddp_series_multigpu.rst
@@ -177,8 +177,8 @@ Running the distributed training job
    +  ddp_setup(rank, world_size)
       dataset, model, optimizer = load_train_objs()
       train_data = prepare_dataloader(dataset, batch_size=32)
-   -  trainer = Trainer(model, dataset, optimizer, device, save_every)
-   +  trainer = Trainer(model, dataset, optimizer, rank, save_every)
+   -  trainer = Trainer(model, train_data, optimizer, device, save_every)
+   +  trainer = Trainer(model, train_data, optimizer, rank, save_every)
       trainer.train(total_epochs)
    +  destroy_process_group()
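
For context, a minimal sketch of how the corrected `main` reads once the fix is applied. The helpers (`ddp_setup`, `load_train_objs`, `prepare_dataloader`, `Trainer`) are assumed to be defined as elsewhere in this tutorial's script; the `mp.spawn` launch and the example values for `save_every`/`total_epochs` are illustrative and not part of this diff:

```python
# Sketch only: helper functions are assumed from the tutorial's training script.
import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group


def main(rank: int, world_size: int, save_every: int, total_epochs: int):
    ddp_setup(rank, world_size)              # initialize the process group for this rank
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size=32)
    # The corrected call: pass the DataLoader (train_data) and the process rank,
    # rather than the raw dataset and a single `device`.
    trainer = Trainer(model, train_data, optimizer, rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()                  # clean up the process group on exit


if __name__ == "__main__":
    save_every, total_epochs = 10, 50        # example values, not from this diff
    world_size = torch.cuda.device_count()
    # mp.spawn passes the rank as the first argument to main in each spawned process.
    mp.spawn(main, args=(world_size, save_every, total_epochs), nprocs=world_size)
```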