From e468a79b6c2fecf3e76c9bb25198e0f5636e927b Mon Sep 17 00:00:00 2001 From: Suraj Subramanian <5676233+suraj813@users.noreply.github.com> Date: Thu, 16 Feb 2023 11:06:57 -0500 Subject: [PATCH] Update ddp_series_multigpu.rst --- beginner_source/ddp_series_multigpu.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/beginner_source/ddp_series_multigpu.rst b/beginner_source/ddp_series_multigpu.rst index 46059f286b1..73e49c6c299 100644 --- a/beginner_source/ddp_series_multigpu.rst +++ b/beginner_source/ddp_series_multigpu.rst @@ -177,8 +177,8 @@ Running the distributed training job + ddp_setup(rank, world_size) dataset, model, optimizer = load_train_objs() train_data = prepare_dataloader(dataset, batch_size=32) - - trainer = Trainer(model, dataset, optimizer, device, save_every) - + trainer = Trainer(model, dataset, optimizer, rank, save_every) + - trainer = Trainer(model, train_data, optimizer, device, save_every) + + trainer = Trainer(model, train_data, optimizer, rank, save_every) trainer.train(total_epochs) + destroy_process_group()