Errors when trying to resume distilling. #11

Closed · jvillegassmule opened this issue Apr 5, 2022 · 3 comments
@jvillegassmule commented Apr 5, 2022

  1. I have the image pairs for training ready in the /local_path directory, and I successfully trained the teacher with this command line:
!python train.py --dataroot /local_path \
  --model pix2pix \
  --log_dir /local_path/logs/teacher \
  --netG inception_9blocks \
  --lambda_recon 10 \
  --nepochs 500 --nepochs_decay 1000 \
  --norm batch \
  --norm_affine \
  --norm_affine_D \
  --norm_track_running_stats \
  --channels_reduction_factor 6 \
  --preprocess none \
  --kernel_sizes 1 3 5 \
  --save_epoch_freq 50 --save_latest_freq 20000 \
  --direction AtoB \
  --real_stat_path /local_path/out_stat.npz
  2. I got a folder full of model checkpoints in /local_path/checkpoints. I was able to resume the training session with something like:
!python train.py --dataroot /local_path \
  --model pix2pix \
  --log_dir /local_path/logs/teacher \
  --netG inception_9blocks \
  --lambda_recon 10 \
  --nepochs 0 --nepochs_decay 750 \
  --norm batch \
  --norm_affine \
  --norm_affine_D \
  --norm_track_running_stats \
  --channels_reduction_factor 6 \
  --preprocess none \
  --kernel_sizes 1 3 5 \
  --save_epoch_freq 50 --save_latest_freq 20000 \
  --direction AtoB \
  --real_stat_path /local_path/out_stat.npz \
  --epoch_base 750 \
  --iter_base 300001 \
  --restore_G_path /local_path/logs/teacher/checkpoints/latest_net_G.pth \
  --restore_D_path /local_path/logs/teacher/checkpoints/latest_net_D.pth

After training, the results on the eval/(it_number)/fake folder are acceptable.

  3. Then, I was able to run the distiller with this command:
!python distill.py --dataroot /local_path \
  --distiller inception \
  --log_dir /local_path/logs/student \
  --restore_teacher_G_path /local_path/logs/teacher/checkpoints/best_net_G.pth \
  --restore_pretrained_G_path /local_path/logs/teacher/checkpoints/best_net_G.pth \
  --restore_D_path /local_path/logs/teacher/checkpoints/best_net_D.pth \
  --real_stat_path /local_path/out_stat.npz \
  --nepochs 500 --nepochs_decay 750 \
  --save_latest_freq 25000 --save_epoch_freq 25 \
  --teacher_netG inception_9blocks --student_netG inception_9blocks \
  --pretrained_ngf 64 --teacher_ngf 64 --student_ngf 24 \
  --eval_batch_size 2 \
  --gpu_ids 0 \
  --norm batch \
  --norm_affine \
  --norm_affine_D \
  --norm_track_running_stats \
  --channels_reduction_factor 6 \
  --kernel_sizes 1 3 5 \
  --direction AtoB \
  --lambda_distill 2.0 \
  --prune_cin_lb 16 \
  --target_flops 2.6e9 \
  --distill_G_loss_type ka

I had to stop the session before it finished; again, different checkpoint models were saved in the folder /local_path/logs/student/checkpoints, including .pth files for G, D, optim-0, optim-1, A-0, A-1, A-2, and A-3.
Progress looks OK in the /local_path/logs/student/eval folder.

  4. I tried to resume distilling with this command line:
!python distill.py --dataroot /local_path \
  --distiller inception \
  --log_dir /local_path/logs/student \
  --restore_teacher_G_path /local_path/logs/teacher/checkpoints/best_net_G.pth \
  --restore_pretrained_G_path /local_path/logs/student/checkpoints/latest_net_G.pth \
  --restore_D_path /local_path/logs/student/checkpoints/latest_net_D.pth \
  --restore_student_G_path /local_path/logs/student/checkpoints/latest_net_G.pth \
  --pretrained_student_G_path /local_path/logs/student/checkpoints/latest_net_G.pth \
  --restore_A_path /local_path/logs/student/checkpoints/latest_net_A \
  --restore_O_path /local_path/logs/student/checkpoints/latest_optim \
  --real_stat_path /local_path/out_stat.npz \
  --nepochs 0 --nepochs_decay 325 \
  --save_latest_freq 25000 --save_epoch_freq 25 \
  --teacher_netG inception_9blocks --student_netG inception_9blocks \
  --pretrained_ngf 64 --teacher_ngf 64 --student_ngf 24 \
  --eval_batch_size 2 \
  --gpu_ids 0 \
  --norm batch \
  --norm_affine \
  --norm_affine_D \
  --norm_track_running_stats \
  --channels_reduction_factor 6 \
  --kernel_sizes 1 3 5 \
  --direction AtoB \
  --lambda_distill 2.0 \
  --prune_cin_lb 16 \
  --target_flops 2.6e9 \
  --distill_G_loss_type ka \
  --epoch_base 925 \
  --iter_base 370000

But now I get this error:

Load network at /local_path/logs/student/checkpoints/latest_net_G.pth
Traceback (most recent call last):
  File "distill.py", line 13, in <module>
    trainer = Trainer('distill')
  File "/content/CAT/trainer.py", line 80, in __init__
    model.setup(opt)
  File "/content/CAT/distillers/base_inception_distiller.py", line 260, in setup
    self.load_networks(verbose)
  File "/content/CAT/distillers/inception_distiller.py", line 203, in load_networks
    super(InceptionDistiller, self).load_networks()
  File "/content/CAT/distillers/base_inception_distiller.py", line 368, in load_networks
    self.opt.restore_student_G_path, verbose)
  File "/content/CAT/utils/util.py", line 139, in load_network
    net.load_state_dict(weights)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for InceptionGenerator:
	Missing key(s) in state_dict: "down_sampling.1.bias", "down_sampling.2.weight", "down_sampling.2.bias", "down_sampling.2.running_mean", "down_sampling.2.running_var", "down_sampling.2.num_batches_tracked", "down_sampling.4.bias", "down_sampling.5.weight", "down_sampling.5.bias", "down_sampling.5.running_mean", "down_sampling.5.running_var", "down_sampling.5.num_batches_tracked", "down_sampling.7.bias"

... lots of other missing layers, then

	size mismatch for down_sampling.1.weight: copying a param with shape torch.Size([16, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([24, 3, 7, 7]).
	size mismatch for down_sampling.4.weight: copying a param with shape torch.Size([16, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([48, 24, 3, 3]).
	size mismatch for down_sampling.7.weight: copying a param with shape torch.Size([210, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([96, 48, 3, 3]).

... lots of other size mismatches.

It seems to me that there is a mismatch between the network architecture that gets built internally and the previously trained checkpoint that is being loaded into it. I am not sure if it is a bug or if something is wrong in the command line I am using to resume.
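
For reference, one way to confirm the mismatch is to dump the tensor shapes stored in the saved checkpoint and compare them against the "shape in current model" values in the RuntimeError. A minimal sketch, assuming the .pth file is a plain PyTorch state_dict and adjusting ckpt_path to your setup:

import torch

# Load the saved student generator weights on the CPU; this only needs torch,
# not the CAT code base.
ckpt_path = "/local_path/logs/student/checkpoints/latest_net_G.pth"
state_dict = torch.load(ckpt_path, map_location="cpu")

# Print the shape of every tensor so it can be compared against the shapes
# the current model expects (listed in the size-mismatch messages above).
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))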

Any help will be appreciated.

@alanspike (Collaborator) commented

Hi @jvillegassmule, thanks for your interest in the work, and sorry for the late reply! There is a bug in the finetuning stage for the student model. Could you please try this branch and let me know if you have any questions? Thanks.

@jvillegassmule (Author) commented

The fix is working! Thanks a lot!

@alanspike (Collaborator) commented

Glad to hear! I'll close the issue, but please let me know if you need any other help.
