
torch.nn.modules.module.ModuleAttributeError: 'TokenClassification' object has no attribute 'module' #6

Closed
mridenour7 opened this issue Mar 2, 2022 · 2 comments


mridenour7 commented Mar 2, 2022

Hello, I was getting a couple of errors when I tried fine-tuning (on CP tokens).
After running this: python3 finetune.py --task=melody --name=default --ckpt='pretrain_model.ckpt'
It trains for one epoch, and then, when it tries to save a checkpoint, I get this error:
torch.nn.modules.module.ModuleAttributeError: 'TokenClassification' object has no attribute 'module'

I was also getting a separate error when fine-tuning on sequence-classification tasks.
After running this: python3 finetune.py --task=composer --name=default --ckpt='pretrain_model.ckpt'
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I was able to fix this one by changing line 92 of finetune_trainer.py to explicitly move the attention mask onto the GPU:
attn = (y != 0).float().to(self.device)
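The pattern behind that fix generalizes: a mask built by comparing a tensor inherits that tensor's device, so it must be moved to the model's device explicitly. A minimal standalone sketch (hypothetical label batch; assumes PyTorch is installed):

```python
import torch

# Pick the same device the model lives on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical padded label batch; 0 marks padding.
y = torch.tensor([[5, 3, 0, 0]])

# (y != 0) is created on y's device (CPU here), so move the mask explicitly.
attn = (y != 0).float().to(device)
print(attn.device.type)  # matches the chosen device
```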

But I can’t figure out how to fix the first error.

@sophia1488
Collaborator

Hi @mridenour7,
For the melody task you mentioned, I just cloned the repo and ran python3 finetune.py --task=melody --name=default --ckpt='pretrain_model.ckpt', and it works fine for me.
The state_dict of the saved checkpoint has keys like 'midibert.bert.embeddings.position_ids', 'midibert.in_linear.bias', and 'classifier.1.weight'.

A quick fix would be to change this line:

'state_dict': self.model.module.state_dict(),

to 'state_dict': self.model.state_dict(),
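For context, the missing .module attribute comes down to whether the model was wrapped in nn.DataParallel: the wrapper exposes the underlying model as .module and prefixes every state_dict key with 'module.'. A minimal sketch (nn.Linear stands in for TokenClassification; assumes PyTorch is installed):

```python
import torch.nn as nn

model = nn.Linear(4, 2)           # stands in for TokenClassification
wrapped = nn.DataParallel(model)  # exposes .module and prefixes keys

print(hasattr(model, 'module'))    # False -> the ModuleAttributeError above
print(list(wrapped.state_dict()))  # every key gains a 'module.' prefix
```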

But I suppose the state_dict of the saved checkpoint would then have keys like 'module.midibert.bert***', and you'd have to be careful when loading the model for evaluation or your specific task.
Something like the following:
# remove the 'module.' prefix from each key
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    name = k[7:]  # len('module.') == 7
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
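Spelled out with a hypothetical checkpoint dict (key names made up for illustration, values stand in for tensors), the prefix-stripping loop behaves like this:

```python
from collections import OrderedDict

# Hypothetical checkpoint whose keys carry the 'module.' prefix.
checkpoint = {'state_dict': OrderedDict([
    ('module.midibert.in_linear.bias', 0.1),
    ('module.classifier.1.weight', 0.2),
])}

new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    new_state_dict[k[7:]] = v  # drop the leading 'module.' (7 chars)

print(list(new_state_dict))  # ['midibert.in_linear.bias', 'classifier.1.weight']
```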

Hope this helps!

@mridenour7
Author

@sophia1488 Thank you for the quick response! That fixed my error :)
