KeyError: 'unexpected key "module.conv1_1.weight" in state_dict' #30

Closed
nelaturuharsha opened this issue Nov 7, 2019 · 6 comments

@nelaturuharsha

nelaturuharsha commented Nov 7, 2019

Traceback (most recent call last):
  File "predict.py", line 18, in <module>
    model = create_model(opt)
  File "/home/Documents//enlighten/EnlightenGAN/models/models.py", line 36, in create_model
    model.initialize(opt)
  File "/home/Documents/enlighten/EnlightenGAN/models/single_model.py", line 72, in initialize
    self.load_network(self.netG_A, 'G_A', which_epoch)
  File "/home/Documents/enlighten/EnlightenGAN/models/base_model.py", line 53, in load_network
    network.load_state_dict(torch.load(save_path))
  File "/home/anaconda3/envs/enlighten/lib/python3.5/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
    .format(name))
KeyError: 'unexpected key "module.conv1_1.weight" in state_dict'

Python: 3.5.6
PyTorch: 0.3.1
CPU

I get this error when I run the test command in script/scripts.py. I would appreciate any help!

Thank you in advance!

@nelaturuharsha
Author

nelaturuharsha commented Nov 8, 2019

Fix:

In models/base_model.py, replace load_network with the following:

    from collections import OrderedDict  # add at the top of base_model.py

    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        state_dict = torch.load(save_path)
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Strip the 'module.' prefix (7 characters) only where it is present
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        network.load_state_dict(new_state_dict)
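
The key renaming in this fix can be tried on its own, without a checkpoint file. The following is a minimal sketch using only the standard library; the helper name strip_module_prefix and the placeholder string values are assumptions made for illustration and are not part of EnlightenGAN.

```python
from collections import OrderedDict

PREFIX = "module."  # prefix that torch.nn.DataParallel adds to parameter names


def strip_module_prefix(state_dict):
    """Return a copy of state_dict with a leading 'module.' removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(PREFIX):
            key = key[len(PREFIX):]
        cleaned[key] = value
    return cleaned


# Placeholder strings stand in for tensors; only the key names matter here.
checkpoint = OrderedDict([
    ("module.conv1_1.weight", "w"),
    ("module.conv1_1.bias", "b"),
    ("conv10.bias", "already clean"),
])
cleaned = strip_module_prefix(checkpoint)
print(list(cleaned))
```

Keys that do not start with the prefix pass through unchanged, so the helper is safe to apply to a checkpoint that was saved without DataParallel.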

@Dear-Mr

Dear-Mr commented Nov 11, 2019

Hello, after making the change above,
and changing new_state_dict = OrderedDict() to new_state_dict = collections.OrderedDict(),
I get KeyError: 'unexpected key "conv10.bias" in state_dict'. How can I resolve this?

@nelaturuharsha
Author

Is your python 3.5, pytorch 0.3.1? And are you trying to run via CPU or GPU?

@yifanjiang19
Collaborator

yifanjiang19 commented Nov 11, 2019

@Dear-Mr @SreeHarshaNelaturu
This problem is caused by torch.nn.DataParallel(). DataParallel() wraps the original model, so the original model is only accessible through model.module. If you directly save a model wrapped by DataParallel, the saved keys will be module.conv.weight instead of conv.weight.
To solve this, wrap the model in DataParallel before you load the pre-trained model.

E.g.:

model = ResNet()
parallel_model = torch.nn.DataParallel(model)
print(parallel_model.module)
# Here parallel_model.module is equal to model
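
The effect on the key names can be checked directly. The sketch below assumes a recent PyTorch release rather than the 0.3.1 used in this thread, and TinyNet is a hypothetical stand-in for the real generator; its single layer is named conv1_1 only to mirror the key in the error message.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical one-layer model, used only to show the key names."""

    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, x):
        return self.conv1_1(x)


model = TinyNet()
parallel_model = nn.DataParallel(model)

# Keys of the plain model versus keys of the DataParallel wrapper.
plain_keys = list(model.state_dict().keys())
wrapped_keys = list(parallel_model.state_dict().keys())
print(plain_keys)
print(wrapped_keys)

# Option 1: weights saved from a wrapped model load into another wrapped model.
fresh_wrapped = nn.DataParallel(TinyNet())
fresh_wrapped.load_state_dict(parallel_model.state_dict())

# Option 2: saving parallel_model.module.state_dict() keeps the plain key names,
# so the result loads into an unwrapped model.
unwrapped_keys = list(parallel_model.module.state_dict().keys())
fresh_plain = TinyNet()
fresh_plain.load_state_dict(parallel_model.module.state_dict())
```

Either option keeps the checkpoint keys and the model keys consistent; which one fits depends on whether the network is wrapped in DataParallel at load time.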

@yifanjiang19
Collaborator

I've pushed a new version. You can directly pull the new version to avoid this problem.

@Dear-Mr

Dear-Mr commented Nov 12, 2019

Is your python 3.5, pytorch 0.3.1? And are you trying to run via CPU or GPU?

Thanks for your reply.
My environment is Python 3.5, PyTorch 0.3.1, torchvision 0.2.0, and I tried to run on GPU.
