Train on a different dataset #14
Hi, the model file in the test folder is slightly different from the one in the train folder, which is why you are getting this error. Replace the Model.py in the test folder with the one you used for training, and it should work.
A side note: since you are using 5 classes, the ESP block might not be very effective in the decoder, because it projects the feature maps into a low-dimensional space whose size equals the number of classes. To learn better representations, you can follow this issue, where we provided a workaround for datasets with few classes.
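To check which layers disagree between the model definition and the saved weights, it can help to compare parameter shapes by name. The following is only a minimal sketch: `find_shape_mismatches` is a hypothetical helper, not part of this repository, and the commented PyTorch lines assume the .pth file holds a plain state_dict. The shapes in the example are the ones from the error message in this thread.

```python
def find_shape_mismatches(model_shapes, checkpoint_shapes):
    """Return {name: (model_shape, checkpoint_shape)} for parameters
    that exist in both mappings but have different shapes."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(model_shape), tuple(ckpt_shape))
    return mismatches


# With PyTorch, the two inputs could be built like this (not run here;
# assumes the checkpoint file stores a plain state_dict):
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   state = torch.load(path, map_location="cpu")
#   checkpoint_shapes = {k: tuple(v.shape) for k, v in state.items()}

# Shapes copied from the error message reported in this thread.
model_shapes = {"conv.conv.weight": (5, 21, 3, 3)}
checkpoint_shapes = {"conv.conv.weight": (5, 24, 3, 3)}

for name, (in_model, in_ckpt) in find_shape_mismatches(
    model_shapes, checkpoint_shapes
).items():
    print(f"{name}: model {in_model} vs checkpoint {in_ckpt}")
```

Any parameter listed here points at a layer that is defined differently in the Model.py used for testing than in the one used for training.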
Hello, @sacmehta!
I'm trying to train the network on my own dataset, which consists of 5 classes.
To train the encoder, I use the command:
CUDA_VISIBLE_DEVICES=1 python3 main.py --data_dir=./DataBase --inWidth=480 --inHeight=360 --classes=5 --cached_data_file=data.p --batch_size=10
To train the decoder, I use the command:
CUDA_VISIBLE_DEVICES=1 python3 main.py --data_dir=./DataBase --inWidth=480 --inHeight=360 --classes=5 --cached_data_file=data.p --batch_size=5 --decoder=True --pretrained=./results_enc__enc_2_8_long/model_161.pth --scaleIn=1 --savedir=./results_dec_
After completing the training, I start testing the neural network:
CUDA_VISIBLE_DEVICES=1 python3 VisualizeResults.py --modelType=1 --inWidth=480 --inHeight=360 --scaleIn=1 --weightsDir=../pretrained/decoder/ --classes=5 --cityFormat=False
The decoder folder contains the newly trained weights (espnet_p_2_q_8.pth).
As a result, I get the following error:
RuntimeError: Error(s) in loading state_dict for ESPNet: While copying the parameter named "conv.conv.weight", whose dimensions in the model are torch.Size([5, 21, 3, 3]) and whose dimensions in the checkpoint are torch.Size([5, 24, 3, 3]).
How can I fix this error?