Closed
Description
I can specify multiple GPU IDs and train a network on them just fine. However, when I run a validation dataset through the trained network, PyTorch raises an error if the list of GPU IDs does not start with ID 0.
Here's a sample function that will cause an error:
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable

# dset_loaders, dset_classes and imshow are defined elsewhere in the training script.
def visualize_model(model=None, num_images=5, use_gpu=False, gpu_ids=[0]):
    print("Using GPU: " + str(gpu_ids[0]))
    for i, data in enumerate(dset_loaders['val']):
        inputs, labels = data
        if use_gpu:
            # Send the batch to the first GPU in the list
            inputs, labels = Variable(inputs.cuda(device=gpu_ids[0])), Variable(labels.cuda(device=gpu_ids[0]))
        else:
            inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)  # ERROR HERE!
        _, preds = torch.max(outputs.data, 1)
        plt.figure()
        imshow(inputs.cpu().data[0])
        plt.title('pred: {}'.format(dset_classes[labels.data[0]]))
        plt.show()
        if i == num_images - 1:
            break
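The call marked `ERROR HERE` sends the batch to `gpu_ids[0]`. One likely cause, not confirmed in this thread, is that the model was wrapped in `nn.DataParallel` while its parameters still sat on `cuda:0` (a bare `model.cuda()` uses the current device). The PyTorch docs state that a `DataParallel` module must have its parameters and buffers on `device_ids[0]`. The sketch below assumes `nn.DataParallel` is in use; `place_for_data_parallel` is a hypothetical helper, not a PyTorch API.

```python
import torch
import torch.nn as nn


def place_for_data_parallel(model, gpu_ids):
    """Hypothetical helper: move `model` to gpu_ids[0] and wrap it in DataParallel.

    Returns the (possibly wrapped) model and the device that inputs should go to.
    Falls back to CPU when gpu_ids is empty or CUDA is unavailable.
    """
    if not gpu_ids or not torch.cuda.is_available():
        return model, torch.device("cpu")
    primary = torch.device("cuda", gpu_ids[0])
    # DataParallel expects parameters and buffers on device_ids[0]; model.cuda()
    # without an argument would use the current device (usually cuda:0) instead.
    model = model.to(primary)
    wrapped = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
    return wrapped, primary
```

With `gpu_ids=[1, 2, 3]`, both the parameters and the inputs (`inputs.to(device)`) end up on `cuda:1`, which matches what `visualize_model` does with the batch.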
If I train the network with gpu_ids = [0,1,2], the function above runs without problems. However, if I train it with gpu_ids = [1,2,3], it raises an error.
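Since the failure appears only when the list does not start at 0, a common workaround (an assumption here, not a fix from this thread) is to hide the other GPUs with the `CUDA_VISIBLE_DEVICES` environment variable. CUDA renumbers the visible devices from 0 in the order listed, so physical GPUs 1, 2, 3 become 0, 1, 2 inside the process. `restrict_visible_gpus` is a hypothetical helper name.

```python
import os


def restrict_visible_gpus(gpu_ids):
    """Hypothetical helper: expose only `gpu_ids` to CUDA, return in-process indices.

    Must run before CUDA is initialized (safest: before importing torch),
    because the CUDA runtime reads CUDA_VISIBLE_DEVICES when it initializes.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    # Visible devices are renumbered from 0 in the order they are listed.
    return list(range(len(gpu_ids)))
```

Calling `restrict_visible_gpus([1, 2, 3])` at the top of the script returns `[0, 1, 2]`, which can then be passed as `gpu_ids` for both training and validation.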