loss going up #22

Closed · tankche1 opened this issue May 26, 2022 · 3 comments

@tankche1

I ran the following script:

torchrun --nproc_per_node=4 train.py \
--ckpt cat --load_G_only --padding_mode border --vis_every 5000 --ckpt_every 50000 \
--iter 1500000 --tv_weight 1000 --loss_fn vgg_ssl --exp-name lsun_cats --batch 10

and found that the loss keeps going up and the transformed images learn almost nothing.
Also, it takes 1.85 s/iter and needs 1,500,000 iterations, which would cost ~220 hours. Is that normal?

[Screenshot: training loss curve trending upward]

@wpeebles (Owner) commented May 26, 2022

Hi @tankche1, yeah, this is normal behavior. It happens because of how we construct the target images early in training: we gradually interpolate the latent code that synthesizes the target images from the input image’s code to the learned congealing vector over the first 150K gradient steps. This stabilizes training by letting the STN predict only small warps at the start. In other words, the learning task gets progressively harder for the STN during the early training steps, which produces the loss curve you’re seeing. Once you hit ~100K to 150K iterations, the STN will start making a lot of progress.
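For intuition, here is a minimal sketch of that annealing schedule, assuming the targets are synthesized from latent vectors; the names ANNEAL_STEPS and target_latent are hypothetical and not the repo's actual code:

import torch

# Assumption: the annealing window matches the "first 150K gradient steps" mentioned above.
ANNEAL_STEPS = 150_000

def target_latent(input_w: torch.Tensor, congeal_w: torch.Tensor, step: int) -> torch.Tensor:
    # Early in training the target latent stays close to the input image's latent,
    # so the STN only needs to predict small warps; by ANNEAL_STEPS it has fully
    # moved to the learned congealing vector.
    alpha = min(step / ANNEAL_STEPS, 1.0)  # ramps linearly from 0 to 1
    return torch.lerp(input_w, congeal_w, alpha)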

You don’t have to train for the full 1.5M steps unless you want to squeeze out every last bit of performance. In my experience, you get most of the benefit after about 400K to 712K steps. Make sure you use one of the checkpoints that get saved when the learning rate hits zero; those seem to be the best.

@tankche1 (Author) commented May 28, 2022

Thanks! The training looks great now. One question: I find that both the transformed sample and the target sample (truncated sample) gradually move from the original image to the final congealed version.

If I put the original image (e.g., a white cat) into the STN and set the final congealed output (e.g., the head of a white cat) as the target in the perceptual loss, the STN cannot learn the transform.

Does this mean that the congealing algorithm relies on gradual improvement from both the STN and the learned latent embedding? Also, is t_ema only used for visualization?

@wpeebles (Owner) commented May 29, 2022

Yeah, the gradual transformation you see is a result of this gradual annealing of the target latent code. We have an ablation in Table 4 of the supplementary materials where we omit this annealing, and it drops PCK@0.1 of the cat model from 67% to 59%, so it definitely has a significant impact.

That's an interesting experiment you ran. I suspect some images can be successfully congealed without the gradual annealing (otherwise the ablation would probably be closer to 0% PCK :), but I don't have great intuition for which specific subset of images it helps the most.

At the end of training, t is effectively discarded and t_ema is the final model used for everything (that's the reason we visualize it during training, since we care about t_ema's final performance more than t's). It's an exponential moving average of t's parameters over training, which is a trick that a lot of generative models (DDPMs, GANs, etc.) use to improve performance.
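
For reference, an EMA update of this kind boils down to something like the following sketch (a hypothetical helper, not the repo's exact accumulation code):

import torch

@torch.no_grad()
def ema_update(t_ema: torch.nn.Module, t: torch.nn.Module, decay: float = 0.999):
    # Blend t's current weights into the running average held by t_ema.
    for p_ema, p in zip(t_ema.parameters(), t.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

Called once per training step, this lets t_ema drift slowly toward t while smoothing out noise in the raw parameters, which is why the final t_ema checkpoint is the one to evaluate.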

As an aside, when you use your trained models at test time, I would recommend using iters=3 when calling the STN (e.g., stn(x, iters=3)). The iters argument recursively applies the similarity STN to its own output, which helps a lot for harder datasets like LSUN. If you're using the testing scripts in the applications folder, you can specify this from the command line with --iters 3. The visualizations made during training all use iters=1, so they represent a lower bound on performance.
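
Conceptually, the iters argument amounts to re-feeding the STN its own output, roughly like this sketch (a hypothetical wrapper with assumed call semantics; in practice you would just call stn(x, iters=3) directly or pass --iters 3 to the testing scripts):

import torch

def recursive_align(stn, x: torch.Tensor, iters: int = 3) -> torch.Tensor:
    # Each pass warps the previous pass's output, progressively refining the alignment.
    out = x
    for _ in range(iters):
        out = stn(out)  # assumption: stn(images) returns the warped images
    return out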

wpeebles closed this as completed Jun 4, 2022