
sdv helper function for generating generator-discriminator loss charts #1828

Closed
srinify opened this issue Mar 4, 2024 · 2 comments · Fixed by #1868
Labels: feature:evaluation (related to running metrics or visualizations), feature request (request for a new feature)

Comments

srinify (Contributor) commented Mar 4, 2024

Problem Description

When working with GAN models in sdv, visualizing the generator and discriminator loss helps you understand how well the GAN model is performing and decide which experiment to try next.

Creating a good-looking generator-discriminator loss chart requires importing a visualization library, plotting the DataFrame of loss values, and tweaking many chart settings. Here is an example of one such chart from the post on interpreting the progress of CTGANs:

[Image: CTGAN loss function for Census dataset (generator and discriminator loss per epoch)]

This was the code needed to make this chart:

```python
import plotly.express as px

# `synthesizer` is an already-fitted sdv CTGANSynthesizer
loss_df = synthesizer.get_loss_values()

# Tidy up the loss values: they may be 0-d tensors, so .item() converts
# them to plain floats (values without .item() are left unchanged)
for column in ['Generator Loss', 'Discriminator Loss']:
    loss_df[column] = loss_df[column].apply(lambda x: x.item() if hasattr(x, 'item') else x)

# Create a pretty chart using Plotly Express
fig = px.line(loss_df, x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
fig.update_layout(
    template='plotly_white',
    title='CTGAN loss function for Census dataset',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    legend_title_text='',
    legend=dict(orientation='v', x=1.1, y=0.3),
)
fig.show()
```

Suggested Improvement

It would be great if sdv itself provided a function that did all of this for the user. We've already done this for other evaluation plots: https://docs.sdv.dev/sdv/single-table-data/evaluation/visualization

srinify added the "feature request", "feature:evaluation", and "new" labels on Mar 4, 2024
npatki removed the "new" label on Mar 4, 2024
npatki (Contributor) commented Mar 5, 2024

API

CTGANSynthesizer.get_loss_values_plot(): Use this function on a fitted CTGAN synthesizer to plot the generator and discriminator loss values. Under the hood, it uses the data from CTGANSynthesizer.get_loss_values().

Parameters: None
Output: A plotly Figure (plotly.graph_objects.Figure) containing a line plot of the generator and discriminator loss values for each training epoch.

Error State: If the CTGANSynthesizer has not been fitted yet, raise an error. This should be the same error that is raised if using get_loss_values() before fitting.

UX: Align the design/colors with SDMetrics visualizations. This includes all elements such as:

  • Line colors (use green/blue)
  • Plot background (should be a gray)
  • Font sizes should be large (for titles, legend, x/y-axis, etc.)

shahenoor commented:

Hello, is there a similar function for CopulaGAN to plot the generator and discriminator loss values?
