-
Notifications
You must be signed in to change notification settings - Fork 4.2k
chatbot_tutorial.py: Solve the optimizer cuda call problem #577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
If you don't configure this string of code, you will get an error when you iterate over the update from 4000_checkpoint.tar: ``` encoder_optimizer.step() ``` Error message: ``` exp_avg.mul_(beta1).add_(1 - beta1, grad) RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float ``` Fix it: pytorch/pytorch#2830 ``` with torch.no_grad(): correct = 0 total = 0 for images, labels in test_loader: images = images.to(device) # missing line from original code labels = labels.to(device) # missing line from original code images = images.reshape(-1, 28 * 28) out = model(images) _, predicted = torch.max(out.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() ```
If you don't configure this string of code, you will get an error when you iterate over the update from 4000_checkpoint.tar: ``` encoder_optimizer.step() ``` Error message: ``` exp_avg.mul_(beta1).add_(1 - beta1, grad) RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float ``` Fix it: pytorch/pytorch#2830 ``` model = Model() model.load_state_dict(checkpoint['model']) model.cuda() optimizer = optim.Adam(model.parameters()) optimizer.load_state_dict(checkpoint['optimizer']) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() ```
Deploy preview for pytorch-tutorials-preview ready! Built with commit 3e1613d https://deploy-preview-577--pytorch-tutorials-preview.netlify.com |
Deploy preview for pytorch-tutorials-preview ready! Built with commit a473681 https://deploy-preview-577--pytorch-tutorials-preview.netlify.com |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for updating the tutorial -- this looks good to me!
@brianjo -- can you merge this? |
Sure thing. It might be a while as we need to merge some other fixes first. Thanks! |
Cool ! I am so happy to help you a little! |
Thank you!!! @jiangzhonglian |
chatbot_tutorial.py: Solve the optimizer cuda call problem
If you don't configure this string of code, you will get an error when you iterate over the update from 4000_checkpoint.tar:
Error message:
Fix it: pytorch/pytorch#2830