
Visualizing Training Loss #21 (Closed)

brannondorsey opened this issue Nov 29, 2016 · 11 comments

@brannondorsey (Contributor)

Hi there,

Is any kind of real-time training plotting visualization in the works for pix2pix?

I'm interested in visualizing Err_G, Err_D, and ErrL1 with the UI display while training with train.lua. I've only poked around with lua and torch with various machine learning projects and have yet to write much of either, although I'm happy to dig-in and try and figure something like this out if it seems like it would be rather trivial/helpful.

My initial thought would be to use display.plot(...) to update the display on each training batch. Anybody more familiar with the code base have any ideas or examples they would like to share?
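For instance, here's a rough sketch of that idea (untested, and assuming the szym/display package's plot API; the helper name and window handling are my guesses):

```lua
-- sketch: accumulate losses and redraw a single plot window per logging step
local display = require 'display'

local loss_history = {}
local plot_win  -- reuse one window id so the plot updates in place

local function update_plot(iter, errG, errD, errL1)
  table.insert(loss_history, {iter, errG, errD, errL1})
  plot_win = display.plot(loss_history, {
    win = plot_win,
    labels = {'iteration', 'errG', 'errD', 'errL1'},
    title = 'pix2pix training loss',
  })
end
```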

P.S. Really rad paper + source code, super excited to have access to this research :) Thanks to all who are working on this!

@brannondorsey (Contributor, Author) commented Nov 29, 2016

green  -> ErrG
purple -> ErrD
teal   -> ErrL1

[screenshot from 2016-11-29 15-32-21: plot of ErrG, ErrD, and ErrL1 during training]

I've managed to get a plot of ErrG, ErrD, and ErrL1 over ~140 epochs of training on the Facades dataset. I believe the display is accurately plotting the error values reported to stdout by train.lua; however, I am surprised by how sporadic the errors are, as well as by the lack of convergence (the plot suggests that very little changes from about the 2nd epoch to the 140th). When testing with the model checkpoints produced during this training session I do seem to get decent visual results, but am surprised that the loss values are what they are. Does anyone have a possible explanation for this?

Update: Here is the last saved output of the model that I was training while generating the above plot.

[image 55000_train_res: last saved model output]

@phillipi (Owner) commented Nov 29, 2016

Thanks for implementing this! Care to share the code?

These plots look accurate. The ErrG and ErrD losses don't converge since the "loss function" is changing on every iteration. Consider school as an analogy. When the training starts, the generator is in kindergarten and the discriminator is the kindergarten teacher. The generator will work hard and start to get good grades, but then the teacher will respond by making the tests harder. A few epochs later, the generator is in first grade. Now the generator is much smarter, but the tests are also harder. If the generator produces kindergarten-level work, the discriminator will give it a worse score than it did back in kindergarten. In fact, if the generator never improves, then its test scores will go down over time. The fact that the scores don't go down is actually a sign that the generator is probably improving.

The L1 loss, on the other hand, should go down over time, and it appears it slightly does in the plot you shared. In practice, ErrL1 usually flattens out after a bit, even though the higher order structure is still improving.
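For reference, the full objective from the paper is G* = arg min_G max_D L_cGAN(G, D) + lambda * L_L1(G). ErrG and ErrD come from the adversarial min-max term, which is a moving target by construction; ErrL1 measures only the fixed L1 term, which is why it is the one you should expect to decrease.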

@brannondorsey (Contributor, Author)

No problem 👍. You can find my current implementation of this plot on the plot-loss branch of my fork. Here is the specific (and currently only) additional commit. I'd be happy to make a PR from that fork (or you could make one yourself); however, perhaps it would be better to get this functionality working in the most helpful way possible before it's integrated. Also, I'm no Lua programmer, so I'm not sure whether there is a cleaner way to add this. I'm trying to be as low-impact as possible on the original code base while making my changes.

Thank you very much for the detailed description; it is a very helpful analogy! By the sounds of it, perhaps the only interesting value to plot is L1, not ErrG or ErrD. Would you agree? How are GANs like this usually plotted to give helpful information to the trainer? Is there something else that should be plotted instead of simply L1? Personally, I would like to see accuracy plotted (per-pixel accuracy perhaps, or whatever method you deem most appropriate from your paper). I've read the paper once, but perhaps I need to give it another read to better understand what would be most helpful to see plotted here. Any pointers are also very welcome 😸!

As of now, the plot display is updated as often as logging occurs (dependent on opt.print_freq). There are also currently no checks to ensure that errG, errD, and errL1 are not nil.
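A minimal guard for the nil case could look like this (sketch; update_plot is the hypothetical helper sketched earlier, and counter stands in for whatever iteration count train.lua tracks):

```lua
-- skip the plot update whenever a loss value is missing
if errG and errD and errL1 then
  update_plot(counter, errG, errD, errL1)
end
```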

@phillipi (Owner) commented Nov 30, 2016

Thanks for sharing the branch! Feel free to make a PR when you feel it is ready, but I agree low-impact is best.

ErrG and ErrD can be interesting to plot for understanding properties of the optimization (like how much the scores oscillate), but you're right that they are not especially useful for knowing when the training has converged.

ErrL1 is useful for checking that the generator is really making progress, but it's not perfect since it doesn't measure a lot of the structure we care about. For example, it will score a blurry image as preferable to a sharp image in many cases. It would be interesting to add other error metrics, like SSIM or Gram statistics, and track their performance as training proceeds.
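For what it's worth, Gram statistics are cheap to compute; a minimal torch sketch, assuming C x H x W feature map tensors (the feature variables are hypothetical):

```lua
require 'torch'

-- normalized C x C Gram matrix of a C x H x W feature map
local function gram(feat)
  local C, H, W = feat:size(1), feat:size(2), feat:size(3)
  local F = feat:view(C, H * W)
  return torch.mm(F, F:t()):div(C * H * W)
end

-- a simple Gram-based distance between generated and target features
local gram_err = (gram(fake_feat) - gram(real_feat)):norm()
```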

If we had an error metric that exactly reflected what we care about (e.g., perceptual quality), then that would be wonderful. Instead of using a GAN we could directly optimize that metric (assuming it's differentiable). Unfortunately, we don't have such a metric in general.

@ppwwyyxx commented Nov 30, 2016

My loss curves look a bit different: there is visible progress over time. (The learning rate was decreased at step 60k.)

[screenshot 1130-00 48 41: loss curves]

It still generates good results, by the way.

@phillipi (Owner)

Interesting. Those curves look a bit strange but not entirely out of the ordinary. GAN training dynamics are definitely pretty weird. In my experience the training curves can take on quite a variety of shapes, depending on the dataset, hyperparameters, maybe even the random initialization.

@brannondorsey (Contributor, Author)

@phillipi Ok, I've created a pull request with my cleaned up plotting code integrated. I've exposed the plotting functionality via the display_plot env variable. By default, it plots L1 loss, however it can also plot errG and errD with display_plot="errG,errD". I've tested the code and everything should be working correctly. Feel free to test or make edits before merging :).
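For anyone curious, parsing a comma-separated env variable like that can be as simple as this (sketch; the exact variable names in the PR may differ):

```lua
-- read display_plot from the environment, defaulting to L1 only
local display_plot = os.getenv('display_plot') or 'errL1'
local plot_fields = {}
for field in string.gmatch(display_plot, '[^,%s]+') do
  table.insert(plot_fields, field)
end
```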

I've re-read the paper and I believe things are starting to make a bit more sense to me 😺. I now understand why error between images (btw, is error equivalent to loss? Can the words be used interchangeably in the way they are talked about here, or no?) is subjective and hard to quantify. SSIM does seem like a fantastic approach and one that I would definitely like to implement and add to the plotter if I can carve out some time.

In the paper it seems like different combinations of loss functions like L1, L1+cGAN, L1+GAN were each successful (or unsuccessful) in different ways and for different training sets/objectives. I also see env variables that are parsed in train.lua which allow you to specify which type of loss function to use when training your model. What then is the errL1 loss really showing? Is it always showing L1 loss only, or is it the name of whatever loss function you've specified with your envs? E.g., if I use L1 and cGAN, which I believe occurs by default, is errL1 a representation of L1 or L1+cGAN? If it is the former, how would I get access to the L1+cGAN error value so that I could add it to the plotter?

Finally, I'm having a bit of trouble understanding the role of lambda. Do you think you could briefly shed some light on the use of that environment variable for me? Thanks so much!

@brannondorsey (Contributor, Author)

@ppwwyyxx great, thanks for sharing! If you are comfortable with it, would you mind sharing info about what training set/hyperparameters you are training your model with so that I can try and reproduce your plots with my display plotter? Thanks!

@ppwwyyxx commented Dec 1, 2016

@brannondorsey I'm not using torch to train the model, and I'm using the same hyperparameters except for batchSize=4.
I just did some experiments and found that the reason you get noisy plots is that you're plotting only the losses of the 50th, 100th, 150th, ... 400th image in each epoch. That is certainly very noisy and not informative.
If you instead compute a running average of the losses of all images, e.g. loss = 0.99 * loss + 0.01 * new_loss, you'll get a plot similar to mine. (Mine actually uses more smoothing than this.)
Another small difference is that I use a batch size of 4, so my losses tend to be a bit smoother.
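A minimal sketch of that running average (assuming per-batch loss values; the names are hypothetical):

```lua
-- exponential moving average of a per-batch loss, for smoother plots
local smooth        -- running average, starts unset
local alpha = 0.01  -- weight on the newest batch; smaller = smoother

local function smooth_loss(new_loss)
  smooth = smooth and ((1 - alpha) * smooth + alpha * new_loss) or new_loss
  return smooth
end
```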

@phillipi (Owner)

> What then is the errL1 loss really showing? Is it always showing L1 loss only, or is it the name of whatever loss function you've specified with your envs? E.g., if I use L1 and cGAN, which I believe occurs by default, is errL1 a representation of L1 or L1+cGAN? If it is the former, how would I get access to the L1+cGAN error value so that I could add it to the plotter?

errL1 shows the L1 loss only; it doesn't show L1+cGAN. errG and errD are the losses associated with the GAN. The total L1+cGAN error for the generator (which is really what we care about in the end) is lambda*errL1 + errG.

> Finally, I'm having a bit of trouble understanding the role of lambda. Do you think you could briefly shed some light on the use of that environment variable for me? Thanks so much!

lambda is just the weight on the L1 objective relative to the GAN objective. As you can see above, the total objective is the weighted sum of these two terms.
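In code terms, the combined generator error would simply be the line below (sketch; opt.lambda is the lambda option as it appears in train.lua's opt table):

```lua
-- total generator objective: weighted L1 term plus the GAN term
local totalErrG = opt.lambda * errL1 + errG
```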

@noob16 commented Mar 12, 2018

@brannondorsey is it possible to save the loss plot somehow during training?
