inference with 4-channel model #56
Hi @SuzannaLin, it seems from the error that the checkpoint only has 3 channels, not 4. Maybe double-check the checkpoint?
You mean the checkpoint in the saved/ folder? Where in the script is it loaded?
The checkpoint you passed:
The error occurs here in the inference script:
Yes, I think there is just a mismatch between the created model and the checkpoint: one of them still has only 3 channels.
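One way to do that double-check is to compare parameter shapes before calling `load_state_dict`. The helper below is hypothetical (it is not part of CCT) and works on plain `{name: shape}` dicts. You could build these with `{k: tuple(v.shape) for k, v in sd.items()}` from `model.state_dict()` and from the loaded checkpoint. The example shapes are copied from the error in this issue. A `Conv2d` weight is laid out as `(out_channels, in_channels, kernel_h, kernel_w)`, so the second entry is the input channel count.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for every shared key whose shapes differ."""
    mismatches = []
    for name, ckpt_shape in ckpt_shapes.items():
        model_shape = model_shapes.get(name)
        # Keys missing from the model are a different kind of error; only compare shared keys.
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


if __name__ == "__main__":
    # Shapes taken from the RuntimeError reported in this issue.
    ckpt = {"prefix.conv1.weight": (64, 3, 3, 3)}   # ImageNet ResNet-50 stem: 3 input channels
    model = {"prefix.conv1.weight": (64, 4, 3, 3)}  # the 4-channel model
    for name, got, expected in find_shape_mismatches(model, ckpt):
        # Conv2d weights are (out_channels, in_channels, kH, kW); index 1 is the channel count.
        print(f"{name}: checkpoint has {got[1]} input channels, model expects {expected[1]}")
```

If the mismatch shows up against the ImageNet file (`3x3resnet50-imagenet.pth`) rather than your own `best_model.pth`, the problem is the backbone initialisation, not your trained checkpoint.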
I found the cause of the problem. Since I have a 4-channel model, I set `pretrained: False` in the config file. However, in model.py the default value of `pretrained` is True, so the ImageNet ResNet weights (which expect 3 input channels) are loaded anyway. So we have to pass the `pretrained` value from the config when defining the model.
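A minimal sketch of that fix, under stated assumptions. It assumes the flag lives under a top-level `"pretrained"` key in the JSON config; adjust this to wherever your config actually stores it. `in_channels` is a hypothetical key standing in for however the 4-channel input is configured. The `resolve_pretrained` helper is also hypothetical and not part of CCT. It falls back to `True` only to mirror the `model.py` default described above. It refuses to load ImageNet weights into a stem that does not take 3 input channels, which is exactly the mismatch in the traceback.

```python
import json


def resolve_pretrained(config: dict) -> bool:
    """Decide whether ImageNet backbone weights should be loaded.

    Hypothetical config layout: top-level "pretrained" and "in_channels" keys.
    """
    # Fall back to True to mirror the model.py default; better to set it explicitly.
    pretrained = bool(config.get("pretrained", True))
    in_channels = int(config.get("in_channels", 3))
    if pretrained and in_channels != 3:
        # The ImageNet conv1 weight is (64, 3, 3, 3) and cannot be copied
        # into a stem that expects a different number of input channels.
        raise ValueError(
            f"pretrained=True needs 3 input channels, but the config asks for {in_channels}"
        )
    return pretrained


if __name__ == "__main__":
    config = json.loads('{"pretrained": false, "in_channels": 4}')
    print(resolve_pretrained(config))  # False
```

In `inference.py`, the constructor call that currently ends with `conf=config['model'], testing=True, pretrained = True)` would then pass `pretrained=resolve_pretrained(config)` instead of the hard-coded `True`.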
@SuzannaLin thanks for posting the answer, glad it is working now!
Hi Yassine!
I have managed to train a model with 4 channels, but the inference is not working. I get this error message:
```
!python inference.py --config configs/config_70_30_sup_alti.json --model './saved/ABCE_70_30_sup_alti/best_model.pth' --output 'CCT_output/ABCE_70_30_sup_alti/Angers/' --images 'val/Angers/BDORTHO'
Loading pretrained model:models/backbones/pretrained/3x3resnet50-imagenet.pth
Traceback (most recent call last):
  File "inference.py", line 155, in <module>
    main()
  File "inference.py", line 102, in main
    conf=config['model'], testing=True, pretrained = True)
  File "/home/scuypers/CCT_4/models/model.py", line 55, in __init__
    self.encoder = Encoder(pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/encoder.py", line 49, in __init__
    model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/backbones/resnet_backbone.py", line 145, in ResNetBackbone
    orig_resnet = deepbase_resnet50(pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/backbones/resnet_models.py", line 227, in deepbase_resnet50
    model = ModuleHelper.load_model(model, pretrained=pretrained)
  File "/home/scuypers/CCT_4/models/backbones/module_helper.py", line 109, in load_model
    model.load_state_dict(load_dict)
  File "/home/scuypers/.conda/envs/envCCT/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for prefix.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).
```