
inference with 4-channel model #56

Closed
SuzannaLin opened this issue Mar 4, 2022 · 7 comments


@SuzannaLin
Contributor

Hi Yassine!
I have managed to train a model with 4 channels, but the inference is not working. I get this error message:

```
!python inference.py --config configs/config_70_30_sup_alti.json --model './saved/ABCE_70_30_sup_alti/best_model.pth' --output 'CCT_output/ABCE_70_30_sup_alti/Angers/' --images 'val/Angers/BDORTHO'
```

```
Loading pretrained model:models/backbones/pretrained/3x3resnet50-imagenet.pth
Traceback (most recent call last):
  File "inference.py", line 155, in <module>
    main()
  File "inference.py", line 102, in main
    conf=config['model'], testing=True, pretrained = True)
  File "/home/scuypers/CCT_4/models/model.py", line 55, in __init__
    self.encoder = Encoder(pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/encoder.py", line 49, in __init__
    model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/backbones/resnet_backbone.py", line 145, in ResNetBackbone
    orig_resnet = deepbase_resnet50(pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/backbones/resnet_models.py", line 227, in deepbase_resnet50
    model = ModuleHelper.load_model(model, pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/backbones/module_helper.py", line 109, in load_model
    model.load_state_dict(load_dict)
  File "/home/scuypers/.conda/envs/envCCT/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for prefix.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).
```
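For readers hitting the same message: the error can be reproduced in isolation with two bare convolutions. This is a minimal sketch, not CCT code; `pretrained_conv` and `model_conv` are hypothetical stand-ins for the ImageNet-pretrained `prefix.conv1` and a 4-channel version of it. PyTorch's `load_state_dict` is strict by default and refuses to copy a weight whose shape differs from the target parameter.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained first conv: 3 input channels (RGB).
pretrained_conv = nn.Conv2d(3, 64, kernel_size=3, bias=False)

# Hypothetical stand-in for the modified model's first conv: 4 input channels.
model_conv = nn.Conv2d(4, 64, kernel_size=3, bias=False)

message = ""
try:
    # Conv2d weights are [out_channels, in_channels, kH, kW], so this tries to
    # copy a [64, 3, 3, 3] tensor into a [64, 4, 3, 3] parameter.
    model_conv.load_state_dict(pretrained_conv.state_dict())
except RuntimeError as err:
    message = str(err)

print(message)
```

The printed message has the same "size mismatch ... copying a param with shape ..." form as the traceback above, only with `weight` as the key instead of `prefix.conv1.weight`.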

@yassouali
Owner

Hi @SuzannaLin

It seems from the error that the checkpoint only has 3 channels, not 4. Could you double-check the checkpoint?

@SuzannaLin
Contributor Author

You mean the checkpoint in the saved/ folder? Where in the script is this loaded?

@yassouali
Owner

The checkpoint you passed, `ABCE_70_30_sup_alti/best_model.pth`, is either not being loaded or not the one you expect, since the error says `copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3])`.
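One way to do that double-check is to open the checkpoint and read the input-channel dimension of the first convolution. This is a sketch under assumptions: the file is taken to be either a raw `state_dict` or a dict storing it under a `"state_dict"` key, and the first conv's key is assumed to end in `prefix.conv1.weight` (the name from the traceback). Both helper names are hypothetical.

```python
import torch

def first_conv_in_channels(state_dict, key_suffix="prefix.conv1.weight"):
    """Return (key, in_channels) for the first 4-D weight whose key ends with
    `key_suffix`, or (None, None) if no such key exists.
    The suffix is an assumption taken from the traceback; adjust as needed."""
    for name, tensor in state_dict.items():
        if name.endswith(key_suffix) and tensor.dim() == 4:
            # Conv2d weights are laid out as [out_channels, in_channels, kH, kW].
            return name, tensor.shape[1]
    return None, None

def load_state_dict_from_checkpoint(path):
    """Load a checkpoint on CPU and return its state_dict.
    Assumption: either a raw state_dict or a dict with a "state_dict" key."""
    checkpoint = torch.load(path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint

# Usage (path from the command above):
# sd = load_state_dict_from_checkpoint('./saved/ABCE_70_30_sup_alti/best_model.pth')
# print(first_conv_in_channels(sd))  # a 4-channel model should report 4
```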

@SuzannaLin
Contributor Author

SuzannaLin commented Mar 4, 2022

The error occurs here in the inference script:

```python
model = models.CCT(num_classes=num_classes, conf=config['model'], testing=True, pretrained=True)
```

The checkpoint is only loaded after that, right?

@yassouali
Owner

yassouali commented Mar 4, 2022

Yes, I think there is just a mismatch between the created model and the checkpoint; one of them still has only 3 channels.

@SuzannaLin
Contributor Author

I found the cause of the problem. Since I have a 4-channel model, I set `pretrained: False` in the config file. However, in model.py the default value for `pretrained` is `True`, so the 3-channel ImageNet ResNet weights are loaded when the model is constructed, before the trained checkpoint is ever read. So we have to pass the `pretrained` value from the config when defining the model:

```python
model = models.CCT(num_classes=num_classes, conf=config['model'], pretrained=config['pretrained'], testing=True)
```
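The fix can be illustrated with a toy model. This is a sketch under assumptions, not the CCT code: `ToyEncoder` and `build_for_inference` are hypothetical, and the single conv stands in for the ResNet stem. It shows why the flag matters: pretrained weights are loaded inside the constructor, so a 3-channel ImageNet state dict collides with a 4-channel conv before the trained checkpoint is ever reached.

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the CCT encoder: just a first conv layer."""

    def __init__(self, in_channels=4, pretrained=True, pretrained_state=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, bias=False)
        if pretrained and pretrained_state is not None:
            # Mirrors the backbone loading ImageNet weights at construction
            # time, i.e. before any trained checkpoint is loaded.
            self.conv1.load_state_dict(pretrained_state)

def build_for_inference(config, imagenet_state, trained_state):
    # Take the flag from the config (as in the fix above) instead of
    # hard-coding pretrained=True.
    model = ToyEncoder(in_channels=4,
                       pretrained=config["pretrained"],
                       pretrained_state=imagenet_state)
    # The trained 4-channel weights are loaded only after construction.
    model.load_state_dict(trained_state)
    return model.eval()
```

With `{"pretrained": False}` the model builds and takes its weights from the trained checkpoint; with `{"pretrained": True}` construction fails with the same kind of size-mismatch `RuntimeError` as in the original report.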

@yassouali
Owner

@SuzannaLin thank you for posting the answer, glad it is working now.
