Skip to content

Commit

Permalink
fix test_models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yjxiong committed Aug 16, 2017
1 parent 357c863 commit 93b3441
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion test_models.py
Expand Up @@ -83,7 +83,13 @@
batch_size=1, shuffle=False,
num_workers=args.workers * 2, pin_memory=True)

net = torch.nn.DataParallel(net, device_ids=list(range(args.workers)))
if args.gpus is not None:
devices = [args.gpus[i] for i in range(args.workers)]
else:
devices = list(range(args.workers))


net = torch.nn.DataParallel(net, device_ids=devices)
net.cuda()
net.eval()

Expand Down

0 comments on commit 93b3441

Please sign in to comment.