
Can't save a model with torch.save if model has a torch.Device attr #7545

Closed
@qlwang25

Description



Issue description

torch.save(model, fp) fails with "TypeError: can't pickle torch.Device objects" when the model being saved stores a torch.device as an attribute (here, self.cuda_device in RaceModel).

Traceback:

  File "/home/wangqianlong/MODELS/sentence_express_model/src/sst_main.py", line 393, in main
    train(train_data_set, test_data_set, parameters)
  File "/home/wangqianlong/MODELS/sentence_express_model/src/sst_main.py", line 283, in train
    save_model(model, epoch, loss, e_acc, parameters.model_dir)
  File "/home/wangqianlong/MODELS/sentence_express_model/src/utils.py", line 109, in save_model
    torch.save(model, fp)
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 161, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 118, in _with_file_like
    return body(f)
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 161, in <lambda>
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/wangqianlong/.local/lib/python3.6/site-packages/torch/serialization.py", line 232, in _save
    pickler.dump(obj)
TypeError: can't pickle torch.Device objects

Code example

The model keeps the device passed in through parameter.device:

class RaceModel(nn.Module):
    def __init__(self, parameter):
        super(RaceModel, self).__init__()
        self.cuda_device = parameter.device  # a torch.device; this attribute is what fails to pickle
        self.A = Model(parameter)
        self.B = Model(parameter)

model = RaceModel(parameters).to(parameters.device)
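One possible workaround, sketched here under stated assumptions rather than as an official fix, is to keep the device object out of the pickled state: convert it to a string in __getstate__ and rebuild it in __setstate__. The traceback shows torch.save ending in pickler.dump(obj), so the sketch uses the standard pickle module directly. So that it runs without PyTorch installed, FakeDevice is a hypothetical stand-in that, like torch.device in this report, refuses to be pickled, and RaceModelSketch is a hypothetical, simplified stand-in for RaceModel. With real PyTorch, the rebuild step would presumably call torch.device(...) on the saved string instead.

```python
import pickle


class FakeDevice:
    """Hypothetical stand-in for torch.device as reported here: pickling it fails."""

    def __init__(self, spec: str) -> None:
        self.spec = spec

    def __reduce_ex__(self, protocol):
        # Mimic the reported behaviour: refuse to be pickled.
        raise TypeError("can't pickle FakeDevice objects")

    def __str__(self) -> str:
        return self.spec


class RaceModelSketch:
    """Hypothetical, simplified stand-in for RaceModel that keeps a device attribute."""

    def __init__(self, device: FakeDevice) -> None:
        self.cuda_device = device
        self.weights = [0.1, 0.2]

    def __getstate__(self):
        # Pickle a copy of the instance dict with the device replaced by its
        # string form; the live object keeps its FakeDevice untouched.
        state = self.__dict__.copy()
        state["cuda_device"] = str(state["cuda_device"])
        return state

    def __setstate__(self, state):
        # Rebuild the device object from the saved string when unpickling.
        state["cuda_device"] = FakeDevice(state["cuda_device"])
        self.__dict__.update(state)


if __name__ == "__main__":
    model = RaceModelSketch(FakeDevice("cuda:0"))
    restored = pickle.loads(pickle.dumps(model))
    print(type(restored.cuda_device).__name__, restored.cuda_device)  # prints: FakeDevice cuda:0
```

Another common route, which avoids pickling the module object at all, is to save model.state_dict() with torch.save and load it into a freshly constructed model; PyTorch's serialization notes recommend that approach over saving the whole model.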

System Info

  • PyTorch or Caffe2:
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • OS:
  • PyTorch version: 0.4
  • Python version: 3.5
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • GCC version (if compiling from source):
  • CMake version:
  • Versions of any other relevant libraries:

Metadata


Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
