
Commit af4f9d0

Shard test dataset, reduce accuracies epoch end. (#1845)
* Shard test dataset, reduce accuracies epoch end.
* Rendezvous name shortened to test_accuracy
1 parent 2bec7ad commit af4f9d0
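
The first half of the change shards the held-out set so each replica evaluates a disjoint slice instead of running the full test set on every core. Below is a minimal, self-contained sketch of that pattern; it is an illustration only, with a stand-in dataset and batch size, while the sampler arguments mirror the diff:

import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Stand-in evaluation dataset; the real scripts use ImageNet / MNIST.
test_dataset = TensorDataset(
    torch.randn(1000, 3, 224, 224), torch.randint(0, 1000, (1000,)))

test_sampler = None
if xm.xrt_world_size() > 1:
  # Each replica gets its own shard; evaluation order is not shuffled.
  test_sampler = DistributedSampler(
      test_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=False)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,  # stand-in for FLAGS.test_set_batch_size
    sampler=test_sampler,
    shuffle=False)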

File tree

2 files changed: 27 additions, 18 deletions

test/test_train_mp_imagenet.py

Lines changed: 20 additions & 11 deletions
@@ -38,6 +38,7 @@
 
 import os
 import schedulers
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -80,16 +81,17 @@
   if getattr(FLAGS, arg) is None:
     setattr(FLAGS, arg, value)
 
+
 def get_model_property(key):
   default_model_property = {
-    'img_dim': 224,
-    'model_fn': getattr(torchvision.models, FLAGS.model)
+      'img_dim': 224,
+      'model_fn': getattr(torchvision.models, FLAGS.model)
   }
   model_properties = {
-    'inception_v3': {
-      'img_dim': 299,
-      'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
-    },
+      'inception_v3': {
+          'img_dim': 299,
+          'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
+      },
   }
   model_fn = model_properties.get(FLAGS.model, default_model_property)[key]
   return model_fn
@@ -139,13 +141,18 @@ def train_imagenet():
           normalize,
       ]))
 
-  train_sampler = None
+  train_sampler, test_sampler = None, None
   if xm.xrt_world_size() > 1:
     train_sampler = torch.utils.data.distributed.DistributedSampler(
         train_dataset,
         num_replicas=xm.xrt_world_size(),
         rank=xm.get_ordinal(),
         shuffle=True)
+    test_sampler = torch.utils.data.distributed.DistributedSampler(
+        test_dataset,
+        num_replicas=xm.xrt_world_size(),
+        rank=xm.get_ordinal(),
+        shuffle=False)
   train_loader = torch.utils.data.DataLoader(
       train_dataset,
       batch_size=FLAGS.batch_size,
@@ -156,6 +163,7 @@ def train_imagenet():
   test_loader = torch.utils.data.DataLoader(
       test_dataset,
       batch_size=FLAGS.test_set_batch_size,
+      sampler=test_sampler,
       drop_last=FLAGS.drop_last,
       shuffle=False,
       num_workers=FLAGS.num_workers)
@@ -206,13 +214,13 @@ def test_loop_fn(loader, epoch):
     for step, (data, target) in enumerate(loader):
       output = model(data)
       pred = output.max(1, keepdim=True)[1]
-      correct += pred.eq(target.view_as(pred)).sum().item()
+      correct += pred.eq(target.view_as(pred)).sum()
       total_samples += data.size()[0]
       if step % FLAGS.log_steps == 0:
         xm.add_step_closure(
             test_utils.print_test_update, args=(device, None, epoch, step))
-    accuracy = 100.0 * correct / total_samples
-    test_utils.print_test_update(device, accuracy=accuracy, epoch=epoch)
+    accuracy = 100.0 * correct.item() / total_samples
+    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
     return accuracy
 
   accuracy, max_accuracy = 0.0, 0.0
@@ -223,7 +231,8 @@ def test_loop_fn(loader, epoch):
     xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
     para_loader = pl.ParallelLoader(test_loader, [device])
     accuracy = test_loop_fn(para_loader.per_device_loader(device), epoch)
-    xm.master_print('Epoch {} test end {}'.format(epoch, test_utils.now()))
+    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
+        epoch, test_utils.now(), accuracy))
     max_accuracy = max(accuracy, max_accuracy)
     test_utils.write_to_summary(
         writer,
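
Taken together, the hunks above turn the per-epoch evaluation into: accumulate `correct` as a device tensor, materialize it once with `.item()` after the loop, compute the accuracy over this replica's shard, and average the per-replica accuracies at the `test_accuracy` rendezvous. A condensed sketch of that loop follows; it assumes `model`, `device`, and a per-device `loader` from `pl.ParallelLoader` already exist as in the script, and omits the step-logging closure:

import numpy as np
import torch_xla.core.xla_model as xm

def test_loop_fn(loader, epoch):
  total_samples, correct = 0, 0
  model.eval()
  for step, (data, target) in enumerate(loader):
    output = model(data)
    pred = output.max(1, keepdim=True)[1]
    # Keep the running count as a device tensor; no per-step .item() sync.
    correct += pred.eq(target.view_as(pred)).sum()
    total_samples += data.size()[0]
  # Accuracy over this replica's shard of the test set.
  accuracy = 100.0 * correct.item() / total_samples
  # Mean of the per-shard accuracies across all replicas ('test_accuracy'
  # is the shortened rendezvous tag from this commit).
  accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
  return accuracy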

test/test_train_mp_mnist.py

Lines changed: 7 additions & 7 deletions
@@ -11,7 +11,7 @@
 import os
 import shutil
 import sys
-import time
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -136,15 +136,14 @@ def test_loop_fn(loader):
     for data, target in loader:
       output = model(data)
       pred = output.max(1, keepdim=True)[1]
-      correct += pred.eq(target.view_as(pred)).sum().item()
+      correct += pred.eq(target.view_as(pred)).sum()
       total_samples += data.size()[0]
 
-    accuracy = 100.0 * correct / total_samples
-    test_utils.print_test_update(device, accuracy)
+    accuracy = 100.0 * correct.item() / total_samples
+    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
     return accuracy
 
-  accuracy = 0.0
-  max_accuracy = 0.0
+  accuracy, max_accuracy = 0.0, 0.0
   for epoch in range(1, FLAGS.num_epochs + 1):
     xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
     para_loader = pl.ParallelLoader(train_loader, [device])
@@ -153,7 +152,8 @@ def test_loop_fn(loader):
 
     para_loader = pl.ParallelLoader(test_loader, [device])
     accuracy = test_loop_fn(para_loader.per_device_loader(device))
-    xm.master_print('Epoch {} test end {}'.format(epoch, test_utils.now()))
+    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
+        epoch, test_utils.now(), accuracy))
     max_accuracy = max(accuracy, max_accuracy)
     test_utils.write_to_summary(
         writer,
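
For intuition about what the `xm.mesh_reduce('test_accuracy', accuracy, np.mean)` call returns: the tag names a rendezvous point, every replica contributes its local scalar, and the supplied reduce function is applied to the gathered values, so each core ends up with the same averaged accuracy. A purely local illustration, with made-up numbers for a hypothetical 8-core run:

import numpy as np

# Hypothetical per-shard accuracies contributed by 8 replicas (made up).
per_replica_accuracies = [98.1, 97.9, 98.3, 98.0, 97.8, 98.2, 98.1, 98.0]

# With np.mean as the reduce function, the epoch-end value every replica
# reports is the arithmetic mean of the shard accuracies.
print('Accuracy={:.2f}'.format(np.mean(per_replica_accuracies)))  # 98.05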
