tensorflow conflicts with nn.DataParallel #2230
What about setting the environment variable?
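One way to read that suggestion (an assumption on my part; the original comment is truncated and does not name the variable) is to restrict which GPUs TensorFlow can see via `CUDA_VISIBLE_DEVICES` before importing it, so it cannot move the process onto another device. CUDA reads this variable once, at initialization time:

```python
import os

# Must be set before importing tensorflow / initializing CUDA.
# Exposing only GPU 0 means TF cannot switch the process to GPU 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# import tensorflow as tf  # import only after the variable is set
print(os.environ["CUDA_VISIBLE_DEVICES"])  # "0"
```

Note this also hides the other GPU from PyTorch in the same process, so it does not combine with `DataParallel` over both cards.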
Can you print
Oops, it moves to another GPU after TF initializes!
It's too bad that TensorFlow changes the current device, but this is the expected PyTorch behavior: the model must be on the current device. Or, just set the current device back after the TensorFlow call:

```python
import tensorflow as tf
import torch
from torch.autograd import Variable

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    sess.run(emb.initializer)  # emb is the embedding variable from the report

torch.cuda.set_device(0)  # set the current device back to 0
model = torch.nn.Linear(128, 1).cuda()
model = torch.nn.DataParallel(model).cuda()
data = Variable(torch.Tensor(8, 128)).cuda()
x = model(data)
```
Environment: 2× GTX 1080 GPUs
minimal reproducible code:
error message:
By removing `model = torch.nn.DataParallel(model).cuda()` or the `sess.run` call, the code works fine.