Calling CTGANSynthesizer.fit() now always fits from scratch, but trained_epochs parameter is not being reset #131
Comments
Thanks for pointing this out! Yes, the "resume" functionality was removed in the refactor. @csala, any opinion on whether it makes sense to add this feature back? One concern here is that errors may happen if you fit on some dataset and then try to fit on a different dataset. It might make sense to add this as a separate method (like …)
Thanks for the quick response! I personally found the "resume training" feature useful when I was trying to determine how many epochs are necessary to build a good generative model. To this end, I sampled synthetic data at different points in training (one sample after every 20 epochs) and then measured the quality of the synthetic data to see where it caps or starts decreasing. Without the "resume training" feature, I'd have to rebuild from scratch every time, which not only increases the total number of epochs in a quadratic fashion, but also refits the GMM vectorizers every time.

Regarding a potential implementation, I think we have to consider how aligned we want to be with the conventions in SKLearn going forward. If the CTGAN implementation respects them well, it may integrate nicely with useful tools like RandomizedSearchCV or GridSearchCV. On the other hand, SKLearn doesn't seem too consistent in the way it handles this. As you mentioned, partial_fit() seems like a good option, but so does a warm_start optional parameter in the fit() method. Something that's not typical for SKLearn is the "discrete_columns" parameter in the fit() method: usually, vectorization and model fitting are structurally separated there, with the option to be recombined via Pipelines.

In any case, I agree we want good error handling in DataTransformer().transform() that would detect whether the structure of the input data is consistent with the data used during fitting. Let me know what you think and I may start working on it. I'm pretty new to contributing to Open Source projects, so I hope it goes smoothly :)
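To make the two SKLearn conventions mentioned above concrete, here is a minimal toy sketch. This is not the real CTGAN (the class, its attributes, and the epoch-counting "training" are all illustrative stand-ins); it only shows how a `warm_start` flag on `fit()` differs from a dedicated `partial_fit()` method:

```python
class ToyTrainer:
    """Toy model illustrating two resume-training conventions.

    Not the real CTGANSynthesizer: instead of training anything,
    it just counts epochs so the two behaviors are easy to compare.
    """

    def __init__(self, warm_start=False):
        self.warm_start = warm_start
        self.trained_epochs = 0

    def fit(self, epochs):
        # Convention 1: a warm_start flag on the estimator decides
        # whether fit() continues training or restarts from scratch.
        if not self.warm_start:
            self.trained_epochs = 0  # fresh start, counter reset too
        self.trained_epochs += epochs
        return self

    def partial_fit(self, epochs):
        # Convention 2: a dedicated method that always continues training.
        self.trained_epochs += epochs
        return self


cold = ToyTrainer().fit(10).fit(5)                  # second fit restarts
warm = ToyTrainer(warm_start=True).fit(10).fit(5)   # second fit resumes
print(cold.trained_epochs)  # 5
print(warm.trained_epochs)  # 15
```

Either convention keeps the reported epoch count consistent with the state of the underlying model, which is exactly what the current behavior breaks.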
Thanks for the suggestions @nrangelo1729. As @fealho mentioned, the removal of the feature was intentional, to fully ensure that CTGAN is aligned with the API and behavioral standards of SDV. Apart from that, some of the things that you are suggesting, like the quality evaluation every few epochs, will actually be covered by other features such as callback functions, so having a partial fit will possibly not be necessary. In any case, I think that for now we can close this issue, since the problem specified in the title is already solved, and we can eventually open a new issue to discuss all the details and work on the partial fit explicitly.
Environment Details
Please indicate the following details about the environment in which you found the bug:
Error Description
In previous versions, calling fit() multiple times would continue the training (the data transformer, discriminator, generator, etc. were not re-instantiated at every call). Therefore, it made sense for trained_epochs not to be reset to 0 at the beginning of every fit() call.
Now, it seems that a fit() call always starts the training from scratch and so trained_epochs should also be reset.
By the way, is this the intended behavior? I think that being able to continue training an already fitted CTGAN model is a useful feature. I didn't see the removal of this feature documented in the release notes, which leads me to think that it may have happened by accident.
Steps to reproduce
Just call fit() twice, first for 10 epochs and then for 5. The trained_epochs parameter increases to 15, but the internal GAN has undergone only 5 epochs of training.
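The mismatch above can be sketched with a toy stand-in (again, not the real CTGANSynthesizer; the attribute names here are illustrative) that reproduces the reported behavior, where the counter accumulates across calls even though each fit() rebuilds the model:

```python
class BuggyTrainer:
    """Toy stand-in reproducing the reported bug, not real CTGAN code."""

    def __init__(self):
        self.trained_epochs = 0    # reported counter
        self.effective_epochs = 0  # epochs the current model actually saw

    def fit(self, epochs):
        # The model is rebuilt from scratch on every call ...
        self.effective_epochs = epochs
        # ... but the counter is never reset, so it keeps accumulating.
        self.trained_epochs += epochs


m = BuggyTrainer()
m.fit(10)
m.fit(5)
print(m.trained_epochs)    # 15 (what is reported)
print(m.effective_epochs)  # 5  (what actually happened)
```

Resetting `trained_epochs` to 0 at the start of fit() would make the two numbers agree.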