Skip to content

Commit

Permalink
modified the hogwild example to perform testing outside the different…
Browse files Browse the repository at this point in the history
… training processes (#426)
  • Loading branch information
shagunsodhani authored and soumith committed Oct 29, 2018
1 parent 502e45d commit 05ed879
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
8 changes: 7 additions & 1 deletion mnist_hogwild/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn.functional as F
import torch.multiprocessing as mp

from train import train
from train import train, test

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
Expand Down Expand Up @@ -55,7 +55,13 @@ def forward(self, x):
processes = []
for rank in range(args.num_processes):
p = mp.Process(target=train, args=(rank, args, model))
# We first train the model across `num_processes` processes
p.start()
processes.append(p)
for p in processes:
p.join()

# Once training is complete, we can test the model
test(args, model)


19 changes: 12 additions & 7 deletions mnist_hogwild/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,22 @@ def train(rank, args, model):
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, num_workers=1)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train_epoch(epoch, args, model, train_loader, optimizer)
test_epoch(model, test_loader)

def test(args, model):
torch.manual_seed(args.seed)

test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, num_workers=1)

test_epoch(model, test_loader)


def train_epoch(epoch, args, model, data_loader, optimizer):
Expand Down

0 comments on commit 05ed879

Please sign in to comment.