In [None]:
# Code source: Sebastian Curi and Andreas Krause, based on Jaques Grobler (sklearn demos).
# License: BSD 3 clause

# We start by importing some modules and running some magic commands.
%matplotlib inline
%reload_ext autoreload
%load_ext autoreload
%autoreload 2

# General math and plotting modules.
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.special import erfinv
from scipy import linalg
from scipy.stats import multivariate_normal, norm

# Project files.
from utilities.util import gradient_descent
from utilities.classifiers import Logistic
from utilities.regularizers import L2Regularizer
from utilities.load_data import polynomial_data, linear_separable_data
from utilities import plot_helpers
from utilities.widgets import noise_widget, n_components_widget, min_prob_widget

# Widget and formatting modules
import IPython
import ipywidgets
from ipywidgets import interact, interactive, interact_manual, fixed
from matplotlib import rcParams
import matplotlib as mpl 

# If the figures are not displayed nicely in your browser, change the following line.
rcParams['figure.figsize'] = (10, 5)
rcParams['font.size'] = 16

# Machine Learning library. 
import torch 
import torch.nn as nn 


import warnings
warnings.filterwarnings("ignore")


# GMM Generative Model

In [None]:
class GMM(object):
    """One-dimensional Gaussian mixture model that can be sampled from.

    Parameters
    ----------
    weights : array-like of shape (num_centers,)
        Mixing weights of the components. They are divided by their sum, so
        they do not need to add up to one.
    means : array-like of shape (num_centers,)
        Mean of each component.
    scales : array-like of shape (num_centers,)
        Standard deviation of each component (a standard-normal draw is
        multiplied by it in `sample`).
    """

    def __init__(self, weights, means, scales):
        # Coerce everything to float arrays: `sample` indexes `means` and
        # `scales` with an integer array, which plain lists/tuples do not support.
        weights = np.atleast_1d(np.asarray(weights, dtype=float))
        self.num_centers = len(weights)
        self.weights = weights / np.sum(weights)
        self.means = np.atleast_1d(np.asarray(means, dtype=float))
        self.scales = np.atleast_1d(np.asarray(scales, dtype=float))

    def sample(self, batch_size=1):
        """Draw `batch_size` samples from the mixture.

        Uses NumPy's global random state, so results are reproducible after
        `np.random.seed`.

        Returns
        -------
        np.ndarray of shape (batch_size,)
        """
        # First choose a component for every sample, then draw from that component.
        centers = np.random.choice(self.num_centers, batch_size, p=self.weights)
        eps = np.random.randn(batch_size)
        return self.means[centers] + eps * self.scales[centers]

In [None]:
def plot_gmm(true_model, sampling_model, title, num_samples=10000, bins=100, xlim=(-1.25, 1.25)):
    """Overlay the exact mixture density on a histogram of model samples.

    The figure replaces the current cell output, which lets a training loop
    call this repeatedly to animate progress.

    Parameters
    ----------
    true_model : object with `weights`, `means` and `scales` sequences
        Mixture whose exact PDF is drawn as the reference curve.
    sampling_model : object with a `sample(n)` method
        Model whose samples are shown as a normalized histogram.
    title : str
        Title of the figure.
    num_samples : int, optional
        Number of samples requested from `sampling_model`.
    bins : int, optional
        Number of histogram bins.
    xlim : pair of floats, optional
        Range of the x-axis; the exact PDF is evaluated on this range.

    Returns
    -------
    None
    """
    # Dividing by the total weight keeps the curve a proper density even if
    # the weights of `true_model` are not normalized.
    total_weight = sum(true_model.weights)
    grid = np.linspace(xlim[0], xlim[1], 1000)
    density = np.zeros_like(grid)
    for weight, mean, scale in zip(true_model.weights, true_model.means, true_model.scales):
        density += weight * norm(mean, scale).pdf(grid) / total_weight

    fig, ax = plt.subplots(1, 1)
    ax.plot(grid, density, label='Exact PDF')
    ax.hist(sampling_model.sample(num_samples), bins=bins, density=True, label='Samples')
    ax.legend(loc='best')
    ax.set_xlim(xlim[0], xlim[1])
    ax.set_title(title)

    # Swap the previously displayed figure for the new one.
    IPython.display.clear_output(wait=True)
    IPython.display.display(fig)
    # Close exactly this figure so that it is neither shown a second time by
    # the inline backend nor kept alive across repeated calls.
    plt.close(fig)


# GAN Architecture

In [None]:
class Generator(nn.Module):
    """Map random noise to samples through a small fully connected network.

    Parameters
    ----------
    input_dim : int
        Dimension of the noise vector fed to the network.
    output_dim : int
        Dimension of each generated sample.
    noise : {'uniform', 'normal'}
        Noise distribution used by `rsample`: `torch.rand` for 'uniform',
        `torch.randn` for 'normal'.
    """

    def __init__(self, input_dim: int, output_dim: int, noise='uniform'):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        self.noise = noise

        self.main = nn.Sequential(
            nn.Linear(input_dim, 15),
            nn.ReLU(True),
            nn.Linear(15, output_dim),
            nn.Tanh()  # Distribution is bounded between -1 and 1.
        )

    def forward(self, x):
        """Transform a batch of noise vectors `x` into generated samples."""
        return self.main(x)

    def rsample(self, batch_size=1):
        """Get a differentiable sample of the generator model.

        Returns a tensor with `batch_size` rows; the trailing dimension is
        squeezed away when it has size one.

        Raises
        ------
        NotImplementedError
            If `self.noise` is neither 'uniform' nor 'normal'.
        """
        # Draw the noise with the same device and dtype as the network, so the
        # generator keeps working after `.to(device)` or `.double()`.
        reference = next(self.parameters())
        shape = (batch_size, self.input_dim)
        if self.noise == 'uniform':
            noise = torch.rand(shape, device=reference.device, dtype=reference.dtype)
        elif self.noise == 'normal':
            noise = torch.randn(shape, device=reference.device, dtype=reference.dtype)
        else:
            raise NotImplementedError(f"Unknown noise model: {self.noise!r}")

        return self(noise).squeeze(-1)

    def sample(self, batch_size=1):
        """Get a sample of the generator model, detached from the autograd graph."""
        return self.rsample(batch_size).detach()


class Discriminator(nn.Module):
    """Score samples with a value in (0, 1) through a small fully connected network.

    Parameters
    ----------
    input_dim : int
        Dimension of each sample to be scored.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        hidden_units = 25
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),  # Squash the score into the interval (0, 1).
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per row of `x`, with shape (..., 1)."""
        scores = self.main(x)
        return scores


# GAN Training Algorithm

In [None]:
def train_gan(generator, discriminator, true_model, generator_optimizer, discriminator_optimizer,
              num_iter, discriminator_loss, generator_loss, batch_size=64, plot_every=1000):
    """Alternate generator and discriminator updates for `num_iter` iterations.

    Parameters
    ----------
    generator : module with an `rsample(batch_size)` method
        Produces differentiable fake samples.
    discriminator : callable module
        Maps samples of shape (batch_size, 1) to scores in (0, 1).
    true_model : object with a `sample(batch_size)` method
        Source of real data.
    generator_optimizer, discriminator_optimizer : torch optimizers
        Optimizers over the parameters of the respective network.
    num_iter : int
        Number of training iterations.
    discriminator_loss, generator_loss : list
        Loss histories. One value per iteration is appended IN PLACE, so the
        caller still holds the partial history if training is interrupted.
    batch_size : int, optional
        Number of real and of fake samples used per iteration.
    plot_every : int or None, optional
        Plot the generator against `true_model` every `plot_every` iterations
        (starting at iteration 0). A falsy value disables plotting.

    Returns
    -------
    (discriminator_loss, generator_loss) : the same two lists that were passed in.
    """
    loss = nn.BCELoss()

    # The targets are identical at every iteration, so build them only once.
    true_label = torch.full((batch_size,), 1.)
    fake_label = torch.full((batch_size,), 0.)

    for i in range(num_iter):
        true_data = torch.tensor(true_model.sample(batch_size)).float().unsqueeze(-1)
        # `rsample` keeps the autograd graph, so the generator loss below can
        # back-propagate through the discriminator into the generator.
        fake_data = generator.rsample(batch_size).unsqueeze(-1)

        ###################################################################################
        # Update G network: maximize log(D(G(z)))                                         #
        ###################################################################################
        generator_optimizer.zero_grad()
        # The generator is rewarded when its samples are classified as true.
        loss_g = loss(discriminator(fake_data).squeeze(-1), true_label)
        loss_g.backward()
        generator_optimizer.step()

        generator_loss.append(loss_g.item())

        ###################################################################################
        # Update D network: maximize log(D(x)) + log(1 - D(G(z)))                         #
        ###################################################################################
        # This also discards the discriminator gradients left by `loss_g.backward()`.
        discriminator_optimizer.zero_grad()

        # train on true data.
        loss_d_true = loss(discriminator(true_data).squeeze(-1), true_label)
        loss_d_true.backward()

        # train on fake data; detached so that no gradient reaches the generator.
        loss_d_fake = loss(discriminator(fake_data.detach()).squeeze(-1), fake_label)
        loss_d_fake.backward()

        discriminator_optimizer.step()

        loss_d = loss_d_true + loss_d_fake
        discriminator_loss.append(loss_d.item())

        if plot_every and i % plot_every == 0:
            plot_gmm(true_model, generator, f"Episode {i}")

    return discriminator_loss, generator_loss


def train_gan_interactive(num_iter, true_model, noise_model, noise_dim, generator_lr, discriminator_lr):
    """Train a fresh GAN on `true_model` and plot the outcome.

    Parameters
    ----------
    num_iter : int
        Number of training iterations.
    true_model : object with a `sample(batch_size)` method
        Distribution the generator should learn.
    noise_model : str
        Noise distribution handed to `Generator` ('uniform' or 'normal').
    noise_dim : int
        Dimension of the generator's noise input.
    generator_lr, discriminator_lr : float
        Adam learning rates of the two networks.
    """
    # Fixed seeds make every run with the same settings start identically.
    torch.manual_seed(0)
    np.random.seed(0)

    # The generator is created before the discriminator; this order determines
    # which seeded random numbers initialize which network.
    gen_net = Generator(input_dim=noise_dim, output_dim=1, noise=noise_model)
    disc_net = Discriminator(input_dim=1)

    gen_opt = torch.optim.Adam(gen_net.parameters(), lr=generator_lr, betas=(0.5, 0.999))
    # NOTE(review): beta2 is 0.99 here but 0.999 for the generator -- confirm
    # whether the difference is intentional.
    disc_opt = torch.optim.Adam(disc_net.parameters(), lr=discriminator_lr, betas=(0.5, 0.99))

    # `train_gan` appends to these lists in place, so whatever was recorded
    # before an interruption is still available for the plot below.
    disc_history = []
    gen_history = []
    try:
        train_gan(
            generator=gen_net,
            discriminator=disc_net,
            true_model=true_model,
            generator_optimizer=gen_opt,
            discriminator_optimizer=disc_opt,
            num_iter=num_iter,
            discriminator_loss=disc_history,
            generator_loss=gen_history,
        )
    except KeyboardInterrupt:
        # Interrupting the kernel stops training early; the plots are still drawn.
        pass

    plot_gmm(true_model, gen_net, "Final Generator Model")

    # Loss curves of both networks on a shared axis.
    for history, curve_label in ((gen_history, 'Generator Loss'),
                                 (disc_history, 'Discriminator Loss')):
        plt.plot(history, label=curve_label)
    plt.xlabel('Iteration Number')
    plt.ylabel(' Loss')
    plt.legend(loc='best')
    plt.show()

# GANs for fitting a Gaussian

In [None]:
# Wide canvas and a readable font for the figures of this section.
rcParams.update({'figure.figsize': (20, 8), 'font.size': 16})

# Target distribution: a single Gaussian with mean 0.5 and standard deviation 0.2.
gaussian_model = GMM(
    weights=np.array([1.]),
    means=np.array([0.5]),
    scales=np.array([0.2]),
)

# The same model serves as reference and as sampler, so the histogram of its
# own samples is drawn under its exact PDF.
plot_gmm(gaussian_model, gaussian_model, 'Exact Model')

In [None]:
# Figure settings for the training plots of this cell.
rcParams['figure.figsize'] = (20, 8)
rcParams['font.size'] = 16
# Number of training iterations used by the callback below.
num_iter = 15000
# Interactive GAN training on the single-Gaussian target.
# The lambda exposes only the four tunable hyper-parameters as controls;
# `num_iter` and `gaussian_model` are globals that are looked up when the
# callback runs, not when this cell is executed.
# For the learning-rate sliders, `min`/`max` are exponents (presumably base 10,
# the ipywidgets default -- confirm), i.e. the range is 1e-6 to 1.
# The trailing `;` keeps the return value of `interact_manual` out of the cell output.
interact_manual(lambda noise_model, noise_dim, generator_lr, discriminator_lr: train_gan_interactive(
    num_iter, gaussian_model, noise_model, noise_dim, generator_lr, discriminator_lr),
                noise_model=ipywidgets.Dropdown(options=['uniform', 'normal'], value='normal', description='Noise model:', style={'description_width': 'initial'}, continuous_update=False),
                noise_dim=ipywidgets.IntSlider(min=1, max=10, value=4, description='Noise dimension:', style={'description_width': 'initial'}, continuous_update=False),
                generator_lr=ipywidgets.FloatLogSlider(value=1e-4, min=-6, max=0, description="Generator lr", style={'description_width': 'initial'}, continuous_update=False),
                discriminator_lr=ipywidgets.FloatLogSlider(value=1e-4, min=-6, max=0, description="Discriminator lr", style={'description_width': 'initial'}, continuous_update=False),
               );

# GANs for fitting a GMM

In [None]:
# Wide canvas and a readable font for the figures of this section.
rcParams.update({'figure.figsize': (20, 8), 'font.size': 16})

# Target distribution: a three-component mixture. Means and scales are divided
# by 5, which places the component means at -0.6, 0 and 0.4.
gmm_model = GMM(
    weights=np.array([0.3, 0.5, 0.2]),
    means=np.array([-3., 0., 2.]) / 5,
    scales=np.array([0.5, 1.0, 0.1]) / 5,
)

# The same model serves as reference and as sampler, so the histogram of its
# own samples is drawn under its exact PDF.
plot_gmm(gmm_model, gmm_model, 'Exact Model')

In [None]:
# Figure settings for the training plots of this cell.
rcParams['figure.figsize'] = (20, 8)
rcParams['font.size'] = 16
# Number of training iterations used by the callback below.
num_iter = 15000
# Interactive GAN training on the three-component mixture target.
# The lambda exposes only the four tunable hyper-parameters as controls;
# `num_iter` and `gmm_model` are globals that are looked up when the callback
# runs, not when this cell is executed.
# For the learning-rate sliders, `min`/`max` are exponents (presumably base 10,
# the ipywidgets default -- confirm), i.e. the range is 1e-6 to 1.
# The trailing `;` keeps the return value of `interact_manual` out of the cell output.
interact_manual(lambda noise_model, noise_dim, generator_lr, discriminator_lr: train_gan_interactive(
    num_iter, gmm_model, noise_model, noise_dim, generator_lr, discriminator_lr),
                noise_model=ipywidgets.Dropdown(options=['uniform', 'normal'], value='normal', description='Noise model:', style={'description_width': 'initial'}, continuous_update=False),
                noise_dim=ipywidgets.IntSlider(min=1, max=10, value=8, description='Noise dimension:', style={'description_width': 'initial'}, continuous_update=False),
                generator_lr=ipywidgets.FloatLogSlider(value=5e-4, min=-6, max=0, description="Generator lr", style={'description_width': 'initial'}, continuous_update=False),
                discriminator_lr=ipywidgets.FloatLogSlider(value=1e-3, min=-6, max=0, description="Discriminator lr", style={'description_width': 'initial'}, continuous_update=False),
               );