@@ -38,6 +38,7 @@
 
 import os
 import schedulers
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -80,16 +81,17 @@
   if getattr(FLAGS, arg) is None:
     setattr(FLAGS, arg, value)
 
+
 def get_model_property(key):
   default_model_property = {
-    'img_dim': 224,
-    'model_fn': getattr(torchvision.models, FLAGS.model)
+      'img_dim': 224,
+      'model_fn': getattr(torchvision.models, FLAGS.model)
   }
   model_properties = {
-    'inception_v3': {
-      'img_dim': 299,
-      'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
-    },
+      'inception_v3': {
+          'img_dim': 299,
+          'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
+      },
   }
   model_fn = model_properties.get(FLAGS.model, default_model_property)[key]
   return model_fn
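
Note (not part of the commit): get_model_property resolves per-model settings through a dict lookup with a fallback default; only inception_v3 overrides the defaults. A minimal sketch of the same pattern with the values used above (variable and model names are illustrative):

# Sketch only: per-model property lookup with a fallback default.
defaults = {'img_dim': 224}
overrides = {'inception_v3': {'img_dim': 299}}

def img_dim_for(model_name):
  # Models without an explicit entry fall back to the 224x224 default.
  return overrides.get(model_name, defaults)['img_dim']

assert img_dim_for('inception_v3') == 299
assert img_dim_for('resnet50') == 224
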
@@ -139,13 +141,18 @@ def train_imagenet():
           normalize,
       ]))
 
-  train_sampler = None
+  train_sampler, test_sampler = None, None
   if xm.xrt_world_size() > 1:
     train_sampler = torch.utils.data.distributed.DistributedSampler(
         train_dataset,
         num_replicas=xm.xrt_world_size(),
         rank=xm.get_ordinal(),
         shuffle=True)
+    test_sampler = torch.utils.data.distributed.DistributedSampler(
+        test_dataset,
+        num_replicas=xm.xrt_world_size(),
+        rank=xm.get_ordinal(),
+        shuffle=False)
   train_loader = torch.utils.data.DataLoader(
       train_dataset,
       batch_size=FLAGS.batch_size,
@@ -156,6 +163,7 @@ def train_imagenet():
   test_loader = torch.utils.data.DataLoader(
       test_dataset,
       batch_size=FLAGS.test_set_batch_size,
+      sampler=test_sampler,
      drop_last=FLAGS.drop_last,
      shuffle=False,
      num_workers=FLAGS.num_workers)
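
Note (not part of the commit): the new test_sampler mirrors the existing train_sampler, so on a multi-replica run each process evaluates only its own unshuffled shard of the test set; the per-shard accuracies are then combined in the next hunk. A minimal standalone sketch of the same pattern (the names ds, world_size, rank, and make_eval_loader are illustrative, standing in for the script's dataset and xm.xrt_world_size()/xm.get_ordinal()):

# Sketch only: shard an evaluation dataset across replicas.
import torch

def make_eval_loader(ds, world_size, rank, batch_size):
  sampler = None
  if world_size > 1:
    # Each replica gets a disjoint ~1/world_size slice, padded so every
    # replica sees the same number of samples. No shuffling for evaluation.
    sampler = torch.utils.data.distributed.DistributedSampler(
        ds, num_replicas=world_size, rank=rank, shuffle=False)
  # When a sampler is supplied, the DataLoader must not also shuffle.
  return torch.utils.data.DataLoader(ds, batch_size=batch_size, sampler=sampler)
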
@@ -206,13 +214,13 @@ def test_loop_fn(loader, epoch):
     for step, (data, target) in enumerate(loader):
       output = model(data)
       pred = output.max(1, keepdim=True)[1]
-      correct += pred.eq(target.view_as(pred)).sum().item()
+      correct += pred.eq(target.view_as(pred)).sum()
       total_samples += data.size()[0]
       if step % FLAGS.log_steps == 0:
         xm.add_step_closure(
             test_utils.print_test_update, args=(device, None, epoch, step))
-    accuracy = 100.0 * correct / total_samples
-    test_utils.print_test_update(device, accuracy=accuracy, epoch=epoch)
+    accuracy = 100.0 * correct.item() / total_samples
+    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
     return accuracy
 
   accuracy, max_accuracy = 0.0, 0.0
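
Note (not part of the commit): with this change, correct stays an XLA tensor inside the loop and is materialized with a single .item() call after it, and the per-replica accuracy is averaged across replicas with xm.mesh_reduce instead of being printed per device. A minimal sketch of the reduction (global_accuracy, local_correct, and local_samples are illustrative names):

# Sketch only: average a per-replica metric across all replicas.
import numpy as np
import torch_xla.core.xla_model as xm

def global_accuracy(local_correct, local_samples):
  local_acc = 100.0 * local_correct / local_samples
  # mesh_reduce is a rendezvous: every replica must call it with the same tag.
  # It collects one Python value per replica and applies the reducer
  # (np.mean here) to that list, returning the same result on every replica.
  return xm.mesh_reduce('test_accuracy', local_acc, np.mean)

Averaging per-shard accuracies matches the exact global accuracy only when every shard holds the same number of samples, which DistributedSampler arranges by padding.
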
@@ -223,7 +231,8 @@ def test_loop_fn(loader, epoch):
     xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
     para_loader = pl.ParallelLoader(test_loader, [device])
     accuracy = test_loop_fn(para_loader.per_device_loader(device), epoch)
-    xm.master_print('Epoch {} test end {}'.format(epoch, test_utils.now()))
+    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
+        epoch, test_utils.now(), accuracy))
     max_accuracy = max(accuracy, max_accuracy)
     test_utils.write_to_summary(
         writer,