Can't save a model with torch.save if model has a torch.Device attr #7545

Closed
qlwang25 opened this issue May 14, 2018 · 7 comments
@qlwang25


Issue description

  File "/home/wangqianlong/MODELS/sentence_express_model/src/sst_main.py", line 393, in main
    train(train_data_set, test_data_set, parameters)
  File "/home/wangqianlong/MODELS/sentence_express_model/src/sst_main.py", line 283, in train
    save_model(model, epoch, loss, e_acc, parameters.model_dir)
  File "/home/wangqianlong/MODELS/sentence_express_model/src/utils.py", line 109, in save_model
    torch.save(model, fp)
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 161, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 118, in _with_file_like
    return body(f)
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 161, in <lambda>
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 232, in _save
    pickler.dump(obj)
TypeError: can't pickle torch.Device objects

class RaceModel(nn.Module):
    def __init__(self, parameter):
        super(RaceModel, self).__init__()
        self.cuda_device = parameter.device
        self.A = Model(parameter)
        self.B = Model(parameter)

model = RaceModel(parameters).to(parameters.device)

System Info

  • PyTorch version: 0.4
  • Python version: 3.5
@qlwang25 qlwang25 reopened this May 14, 2018
@zou3519
Contributor

zou3519 commented May 14, 2018

Let's say your model is net:
What you're probably doing right now is torch.save(net, PATH).

The recommended way to save a model is torch.save(net.state_dict(), PATH), which will probably fix your problem.
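
A minimal sketch of the state_dict approach, assuming a small stand-in module (`Net` here is hypothetical, not the reporter's `RaceModel`) and an in-memory buffer in place of PATH. Only the tensors in the state dict are serialized, so a torch.device attribute stored on the module is never handed to the pickler; the device is passed in again when the model is rebuilt.

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in model that keeps a torch.device attribute."""

    def __init__(self, device):
        super().__init__()
        self.device = device  # plain attribute, not part of state_dict()
        self.linear = nn.Linear(5, 3)


device = torch.device("cpu")
net = Net(device).to(device)

# Save only the parameters and buffers, not the module object itself.
buffer = io.BytesIO()  # stands in for PATH; a file path works the same way
torch.save(net.state_dict(), buffer)

# To restore, rebuild the module in code and load the saved tensors into it.
buffer.seek(0)
restored = Net(device)
restored.load_state_dict(torch.load(buffer, map_location=device))
```

The trade-off is that loading requires the model class and its constructor arguments to be available in code, since the file holds tensors only.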

@zou3519 zou3519 changed the title from "Why save model is failed ? Please help me." to "Can't save a model with torch.save if model has a torch.Device as parameter" May 14, 2018
@zou3519 zou3519 changed the title from "Can't save a model with torch.save if model has a torch.Device as parameter" to "Can't save a model with torch.save if model has a torch.Device attr" May 14, 2018
@zou3519
Contributor

zou3519 commented May 14, 2018

Here is a minimal example demonstrating the problem.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.param = nn.Parameter(torch.Tensor(3, 5))
        self.device = torch.device('cpu:0')

net = Net()
torch.save(net.state_dict(), "net.pth")  # OK
torch.save(net, "net2.pth")              # errors out
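
One possible workaround, sketched here as an assumption rather than anything proposed in this thread, is to avoid storing the device on the module at all and to look it up from the parameters when needed. The `device` property below is a hypothetical helper; because it lives on the class, no torch.device object ends up in the instance's attributes when the whole module is pickled.

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.zeros(3, 5))

    @property
    def device(self):
        # Look the device up from the parameters instead of storing it,
        # so no torch.device object lives in the instance's __dict__.
        return next(self.parameters()).device


net = Net()
buffer = io.BytesIO()    # stands in for a file path
torch.save(net, buffer)  # pickles the whole module object
```

This assumes all parameters sit on the same device and that the module has at least one parameter.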

@ssnl
Collaborator

ssnl commented May 14, 2018

The real problem is how to make device objects serializable.

@zou3519
Contributor

zou3519 commented May 14, 2018

We can add __getstate__ and __setstate__ attributes.
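
As a rough illustration of the idea, the sketch below shows how Python's pickle module uses `__getstate__` and `__setstate__` on an ordinary class. `DeviceLike` is a hypothetical stand-in written for this example; it is not how torch.device is implemented, and the lock attribute is only there to make default pickling fail.

```python
import pickle
import threading


class DeviceLike:
    """Hypothetical stand-in for a device object; not torch.device itself."""

    def __init__(self, type_, index=0):
        self.type = type_
        self.index = index
        # A lock cannot be pickled, so default pickling of this object fails.
        self._handle = threading.Lock()

    def __getstate__(self):
        # Reduce the object to plain, picklable data.
        return {"type": self.type, "index": self.index}

    def __setstate__(self, state):
        # Restore the plain fields and recreate the unpicklable part.
        self.type = state["type"]
        self.index = state["index"]
        self._handle = threading.Lock()


original = DeviceLike("cpu", 0)
copy = pickle.loads(pickle.dumps(original))
```

With the two methods in place, pickle stores only the dictionary returned by `__getstate__` and hands it back to `__setstate__` on load.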

@qlwang25
Author

Thank you, that solved the problem @zou3519

@apaszke
Contributor

apaszke commented May 21, 2018

Fixed in #7713

@leeqiaogithub

I still encounter similar problems. How did you solve them?
