<a href="https://colab.research.google.com/github/psuarezserrato/aprendizaje-geometrico/blob/main/AGP_AMallasto_WGAN_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Motivational video](https://www.youtube.com/watch?v=XOxxPcy5Gr4&t=81s)

#Brief Introduction to WGANs
The aim of this tutorial is to give a brief introduction to Wasserstein GANs (WGANs). The objective is simple: we wish to learn a target data distribution $\mu_t$ by minimizing the $1$-Wasserstein distance $W_1$ between $\mu_t$ and a parametrized model distribution $\mu_\omega$ with parameter $\omega$. That is, our training objective is
\begin{equation}
\min\limits_{\omega} W_1(\mu_{\omega}, \mu_t).
\end{equation}
This has two main ingredients: 1) expressing the model distribution as a **push-forward** through a **generator** neural network, and 2) estimating the $1$-Wasserstein distance using a **discriminator** neural network. Below, we first take a brief look at pytorch notation. After this, we consider 1), which is the easy part. Then, we move on to 2), where the magic of WGANs happens.

In [None]:
#Import the python libraries that we will need

import numpy as np #Standard python package for numerical computations

import torch #GPU enabled deep learning library by Facebook
import torch.nn as nn
import torch.autograd as autograd
from torch.autograd import Variable

import seaborn as sns #For statistical visualization
import pandas as pd #Cannot do data science without pandas

import matplotlib.pyplot as plt #For visualization

# Some Pytorch Notation
Pytorch operates on its own data structure, namely [torch.Tensors](https://pytorch.org/docs/stable/tensors.html). Below, we will define one, apply a function on it, and show  how pytorch can automatically compute gradients using [Autograd](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html).

In [None]:
#Define a tensor and tell pytorch that we want to compute gradients
#with respect to it
x = torch.tensor([[1., 2., 3., 4.]], requires_grad=True)

#Define a simple function using python's lambda notation
f = lambda x: torch.sum(x**2)

#Compute f(x)
y = f(x)

#Compute the gradient at x by backpropagation
y.backward()
v = x.grad

print(v)

##Exercise 1:
Now you do it! Compute the gradient of 
\begin{equation}
A \mapsto \log\left(\mathrm{det}(AB + A^{-1})\right)
\end{equation}
at 
\begin{equation}
A = \begin{bmatrix}5&2\\2&1\end{bmatrix},
\end{equation}
when
\begin{equation}
B = \begin{bmatrix}1&2\\2&1\end{bmatrix}.
\end{equation}
( Hint: torch.log, torch.det, torch.inverse, matrix multiplication is given by A@B and for example the identity matrix is written as torch.Tensor([[1,0],[0,1]]) )


#  Generative Adversarial Networks
**Generative adversarial networks** (GANs) aim at learning a generative model for sampling from a given **data distribution** $\mu_{\mathrm{data}}\in \mathcal{P}(\mathbb{R}^n)$. This is carried out by defining a **source distribution** $\mu_{\mathrm{source}}\in \mathcal{P}(\mathbb{R}^d)$, where $d \ll n$, and then pushing it forward with the **generator** $g_\omega: \mathbb{R}^d \to \mathbb{R}^n$; the result is denoted by $(g_\omega)_\# \mu_{\mathrm{source}}$. Then, the parameter $\omega$ is optimized to minimize a given similarity measure between $(g_\omega)_\# \mu_{\mathrm{source}}$ and $\mu_{\mathrm{data}}$.

Two reminders before we move on. First, recall that the **manifold assumption** is prevalent in machine learning, stating that natural data lies on low-dimensional submanifolds of the ambient data space, and therefore $d \ll n$ is justified. Second, the **push-forward** of a measure $\mu$ on $\Omega$ with respect to a measurable map $f:\Omega \to \Omega'$ is defined by $(f_\#\mu)(A) = \mu(f^{-1}(A))$ for any measurable set $A$ in $\Omega'$. This definition translates into something even simpler: if a random variable $X$ has law $\mu$, then $f(X)$ has law $f_\#\mu$.
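
The last remark can be checked numerically. The sketch below is only an illustration: the affine map `f` and its parameters `a` and `b` are hypothetical choices that do not appear elsewhere in this tutorial. Pushing the standard normal distribution on $\mathbb{R}$ forward by $f(x) = ax + b$ yields the normal distribution with mean $b$ and standard deviation $|a|$, so the samples $f(X)$ should show those statistics.

```python
import torch

torch.manual_seed(0)

# Hypothetical affine map, used only for this illustration
a, b = 2.0, -1.0
f = lambda x: a * x + b

# X has law mu = N(0, 1); then f(X) has law f_# mu = N(b, a^2)
X = torch.randn(100_000)
Y = f(X)

print(Y.mean().item(), Y.std().item())  # close to b and |a|
```

Sampling from a push-forward therefore never requires the density of $f_\#\mu$; we only need to sample from $\mu$ and apply $f$.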

# Implementing the Push-Forward
Let's get our hands dirty by creating a source distribution $\mu_{\mathrm{source}}\in \mathcal{P}(\mathbb{R}^2)$ and a generator $g_\omega:\mathbb{R}^2\to \mathbb{R}^2$ given by a multilayer perceptron (the simplest neural network), and see what we get.

The function **source** eats the argument N and spits out N samples from the standard normal distribution on $\mathbb{R}^2$.

The generator **g** is constructed as a multilayer perceptron (MLP) with two hidden layers of 128 neurons, using rectified linear units (ReLU) as activation functions. It feeds on $M\times 2$ matrices, applying itself row-wise, and returns an $M\times 2$ matrix.

See the [wikipedia article](https://en.wikipedia.org/wiki/Multilayer_perceptron) for more about MLPs.

See the [pytorch documentation](https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) for different activations and layers pytorch has to offer. Feel free to play around with them, if you are feeling adventurous!


In [None]:
#Source samples [N] points from the standard normal distribution in R^2.
#These will be given in a N x 2 matrix, with rows corresponding to samples.
source = lambda N:torch.randn((N,2))
#Generator network (used to push forward the source to the data distribution).
g = nn.Sequential(nn.Linear(2,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,2)
                   )

#Sample 100 points from source and the push-forward.
#We pick different samples for the push-forward to avoid correlation.
samples_source = source(100)
samples_push = g(source(100)).detach()
#above, detach tells pytorch that this variable is now static, and should not
#be considered in the computational graph. This allows us to plot the points.

#Let's see how the push-forward changes the distribution.
plt.figure()
plt.scatter(samples_source[:,0], samples_source[:,1], color='b')
plt.scatter(samples_push[:,0], samples_push[:,1], color='r')
plt.legend(['source', 'push-forward'])
plt.show()

# Picking a Similarity Measure
Right now, our push-forward is not producing anything useful. To train the model to sample from a given data distribution $\mu_{\mathrm{data}}$, we need to pick a similarity measure to be minimized. This choice specifies which GAN we are working with. For example, [the original GAN](https://papers.nips.cc/paper/5423-generative-adversarial-nets) minimizes the [Jensen-Shannon divergence](https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence), whereas the [WGAN](https://arxiv.org/abs/1701.07875) minimizes the [$1$-Wasserstein metric](https://en.wikipedia.org/wiki/Wasserstein_metric). There are plenty of other choices, but in this tutorial we will focus on WGANs.

#$p$-Wasserstein Metric
Let us define the $p$-Wasserstein metric between two probability measures, which results from solving a constrained linear program. We first give the **primal formulation**, after which we look at the **dual formulation**.

##Primal formulation
Let $(X,d)$ be a metric space that is complete and separable (a Polish space), and pick two probability measures $\mu$ and $\nu$ with finite $p$-moments. Then, the $p$-Wasserstein metric between $\mu$ and $\nu$ is given by
\begin{equation}
W_p(\mu,\nu) = \left(\min_\gamma\mathbb{E}_\gamma[d^p]\right)^{\frac{1}{p}},
\end{equation}
where $\gamma$ is constrained to be a joint distribution with marginals $\mu$ and $\nu$ (a coupling), $d^p$ stands for the function $(x,y)\mapsto d(x,y)^p$, and we use the notation
\begin{equation}
\mathbb{E}_\mu[f] = \int_X f(x)d\mu(x).
\end{equation}
Note that $W_p$ is indeed a metric: it is symmetric, positive-definite, and satisfies the triangle inequality.
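
For intuition, the primal problem can be solved by hand in a special case. The sketch below assumes two empirical measures on $\mathbb{R}$ with the same number of equally weighted samples; in that case an optimal coupling pairs the sorted samples, so no linear program solver is needed. The helper name `wasserstein_1d` is made up for this example.

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """p-Wasserstein distance between two equally sized empirical
    measures on the real line (illustrative helper)."""
    x_sorted, y_sorted = np.sort(x), np.sort(y)
    # The optimal coupling pairs the i-th smallest x with the i-th smallest y
    return np.mean(np.abs(x_sorted - y_sorted) ** p) ** (1.0 / p)

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
y = x + 3.0  # translate every sample by 3

print(wasserstein_1d(x, y))  # a translation by 3 gives distance 3
```

In higher dimensions no such sorting trick is available, which is one reason the dual formulation is the one used for WGANs.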

##Dual formulation
As the Wasserstein metric results from solving a linear program, it admits a dual formulation given by 
\begin{equation}
W_p^p(\mu, \nu) = \max_{\varphi, \psi}\lbrace \mathbb{E}_\mu[\varphi] + \mathbb{E}_\nu[\psi]\rbrace,
\end{equation}
where we require $\varphi(x) + \psi(y) \leq d^p(x,y)$ for any $x,y$. The optimal $\varphi, \psi$ are called **Kantorovich potentials**, satisfying $\varphi(x) + \psi(y) = d^p(x,y)$ for any $(x,y)$ in the support of the optimal $\gamma$ that solves the primal problem. Furthermore, we know that $\psi(y) = \varphi^c(y)$, where
\begin{equation}
\varphi^c(y) = \inf_x\{d^p(x,y)-\varphi(x)\},
\end{equation}
is the **$c$-transform** of $\varphi$ with cost $c = d^p$. This allows us to write the dual problem as
\begin{equation}
W_p^p(\mu, \nu) = \max_{\varphi}\lbrace \mathbb{E}_\mu[\varphi] + \mathbb{E}_\nu[\varphi^c]\rbrace.
\end{equation}
If $\varphi$ is $1$-Lipschitz and $c=d$ (that is, $p=1$), magic happens: $\varphi^c(y) = -\varphi(y)$. It can be shown that when $p=1$, the optimal $\varphi$ is indeed $1$-Lipschitz, which we will use later on when formulating the WGAN objective. (Trivia: for a general $p$, if the support of any joint measure of $\mu$ and $\nu$ is bounded by diameter $D$, then the Kantorovich potentials will be $pD^{p-1}$-Lipschitz. However, the $d^p$-transform does not behave as nicely for Lipschitz functions.)
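
The identity $\varphi^c = -\varphi$ for $1$-Lipschitz $\varphi$ and $c = d$ can be verified numerically. The following is a minimal sketch under stated assumptions: the space is replaced by a finite grid on the real line, $d(x,y) = |x-y|$, and $\varphi = \sin$ is a hypothetical choice of a $1$-Lipschitz function. Because every $y$ considered lies on the grid itself, the minimum over the grid is attained at $x = y$.

```python
import numpy as np

# A 1-Lipschitz function on the real line (hypothetical choice)
phi = np.sin

# Finite grid standing in for the space; d(x, y) = |x - y|
grid = np.linspace(-5.0, 5.0, 1001)

# c-transform with c = d: phi^c(y) = inf_x { |x - y| - phi(x) }
cost = np.abs(grid[:, None] - grid[None, :])        # cost[i, j] = |x_i - y_j|
phi_c = np.min(cost - phi(grid)[:, None], axis=0)   # minimize over x for each y

print(np.max(np.abs(phi_c + phi(grid))))  # numerically zero
```

Replacing `phi` by a function that is not $1$-Lipschitz, for example `lambda x: 2 * np.sin(x)`, breaks the identity.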

#Computing the 1-Wasserstein Distance
Let's put this into practice and compute the 1-Wasserstein distance between $\mu_{\mathrm{source}}$ and $(g_\omega)_\#\mu_{\mathrm{source}}$, which we implemented earlier. We will do this using a Monte Carlo scheme and approximating $\varphi$ with an MLP $f_{\omega'}$, which we call the **discriminator**.

Sample $N$ points from the source $\mu_{\mathrm{source}}$ and the push-forward $(g_\omega)_\#\mu_{\mathrm{source}}$, yielding $\{x_i\}_{i=1}^N$ and $\{y_i\}_{i=1}^N$, respectively, which we call the **mini-batches**. Then, we compute the objective function $\rho$
\begin{equation}
W_1(\mu_{\mathrm{source}}, (g_\omega)_\#\mu_{\mathrm{source}}) \approx \max_{\omega'} \rho(\omega'), \qquad \rho(\omega') = \frac{1}{N}\left\lbrace \sum_{i=1}^N f_{\omega'}(x_i) - \sum_{i=1}^N f_{\omega'}(y_i)  \right\rbrace,
\end{equation}
where we used that at optimality, $f_{\omega'}^c(y) = -f_{\omega'}(y)$. We optimize using stochastic gradient ascent (the maximizing counterpart of [stochastic gradient descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)): after sampling mini-batches, we compute the gradient $v$ of $\rho$ with respect to $\omega'$, update $\omega' + \eta v \mapsto \omega'$ with a small step size $\eta > 0$, and then sample new mini-batches and repeat until convergence (or until we are satisfied).

Computing the gradients and updating the parameters is made very easy in pytorch: the gradients are computed by automatic differentiation via [backpropagation](https://en.wikipedia.org/wiki/Backpropagation), and the parameter updates are carried out by [pytorch optimizers](https://pytorch.org/docs/stable/optim.html).
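
To make the division of labour concrete, here is a minimal sketch of a single update step on a toy problem. The parameter `w` and the quadratic objective are hypothetical and chosen so that the result can be checked by hand; plain `torch.optim.SGD` is used here instead of the RMSprop optimizer of the next cell, because its update rule is simply $w - \eta\, \nabla \mathrm{loss} \mapsto w$.

```python
import torch

# A toy parameter and objective (hypothetical, for illustration only)
w = torch.tensor([1.0, 2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

loss = torch.sum(w ** 2)   # the gradient of this loss is 2 * w
loss.backward()            # autograd fills w.grad by backpropagation
optimizer.step()           # w <- w - lr * w.grad
optimizer.zero_grad()      # reset the accumulated gradient

print(w)  # 0.8 times the initial value
```

Since optimizers minimize, maximizing an objective amounts to minimizing its negative, which is what the training loops in this tutorial do.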

##Implementation
The implementation can be found below. We define the discriminator and the related optimizer (1), after which we begin the training loop (2). At each step in the training loop, we sample mini-batches from each measure (3), and compute the mean of $f_{\omega'}$ under $\mu$ and the mean of $-f_{\omega'}$ under $\nu$ (4). Summing these together gives us the objective function (5). Then, we just compute the gradient of the loss with respect to $\omega'$ and update $\omega'$ (6).

In [None]:

#-----------(1)-----------
#Define the discriminator that assigns a real value to
#points in R^2.
f = nn.Sequential(nn.Linear(2,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,1)
                   )
#For updating the parameters, we need to use a pytorch
#optimizer object.
f_optim = torch.optim.RMSprop(f.parameters(), lr = 1e-4)
#We are using the RMSprop optimizer, but see the above cell
#for other optimization methods pytorch has to offer.

N_batch = 64 #Mini-batch size
N_iterations = 1000 #Amount of gradient steps we will take

#Let's keep track of how our discriminator is approximating
#the 1-Wasserstein distance at each iteration
loss_history = []

#-----------(2)-----------
#Train the discriminator
for i in range(N_iterations):
  #-----------(3)-----------
  #Sample from the source and the push-forward
  samples_source = source(N_batch)
  samples_push = g(source(N_batch))
  
  #-----------(4)-----------
  #Compute f values
  f_source = f(samples_source)
  f_push = -f(samples_push)
  
  #-----------(5)-----------
  #Pytorch optimizers minimize objective functions, as we
  #want to maximize instead, we will be minimizing the negative
  #of the objective
  loss = -(f_source.mean() + f_push.mean())
  
  #-----------(6)-----------
  #tell pytorch to compute the gradient
  loss.backward()
  
  #update the parameters of f
  f_optim.step()
  
  #Pytorch accumulates gradients, so we will have to zero
  #them after each update
  f_optim.zero_grad()
  
  #save the current value
  loss_history.append(float(-loss))
  
#Plot the loss
plt.figure()
plt.plot(np.arange(N_iterations), loss_history)
plt.title('Approximating 1-Wasserstein')
plt.show()

##Exercise 2: 
Increase N_iterations and observe the behavior of the loss function. Are we able to compute the 1-Wasserstein distance?

#Incorporating the Discriminator Constraints
As you can see, the loss does not converge. This is because we are not taking into account the constraint on the Kantorovich potentials
\begin{equation}
\varphi(x)-\varphi(y) \leq d(x,y).
\end{equation}
The constraint is essential, as it is the only way the metric $d$ (here the Euclidean, or $l^2$, metric) enters the problem. Therefore, in the cell above, we are not computing the 1-Wasserstein distance!

Implementing the constraint is an art of its own, and a considerable number of papers study how to do this properly. The original WGAN paper does this through **weight clipping**. Other notable versions include the [WGAN-GP](https://arxiv.org/abs/1704.00028), which is a widely used variant, and the [CT-WGAN](https://arxiv.org/abs/1803.01541).

The papers approach the constraint through Lipschitzness. If you look at the expression above, this is exactly the requirement for $\varphi$ to be $1$-Lipschitz. The papers then focus on enforcing this:

- Original WGANs clip the weights of the neural network $f_{\omega'}$ to a box (they enforce $-c\leq \omega'_{\mathrm{weights}} \leq c$, where $c$ is a small constant and the inequalities are considered element-wise). This forces $f_{\omega'}$ to be $K$-Lipschitz for some $K$, which is not a problem, as maximizing the dual expression with respect to $K$-Lipschitz functions yields the Wasserstein distance multiplied by $K$.

- Gradient penalty WGAN (WGAN-GP) relies on the result that $f_{\omega'}$ is $1$-Lipschitz if $\|\nabla_x f_{\omega'}(x)\| \leq 1$ for almost every $x$. Therefore, they apply a penalty term to the objective function to enforce this condition.

- Consistency term WGAN (CT-WGAN) enforces the Lipschitz condition directly, by adding a penalty term $\mathbb{E}(|f_{\omega'}(x) - f_{\omega'}(y)|-1)^2$.

Do note, that these techniques only work in the $1$-Wasserstein case. Below, we will apply the original weight clipping scheme with $c=0.01$. This does not perform too well in practice, but is the easiest to implement. Afterwards, we will try out the gradient penalty used in WGAN-GP.
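
To see why clipping yields a $K$-Lipschitz network, recall that ReLU is $1$-Lipschitz, so the product of the spectral norms of the weight matrices bounds the Lipschitz constant of an MLP from above. The sketch below illustrates this on a small network; the layer sizes and the name `net` are assumptions made for this example, and the bound is generally not tight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small discriminator-like MLP (sizes chosen for illustration)
net = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))

# Clip every weight matrix to the box [-c, c]
c = 0.01
with torch.no_grad():
    for param in net.parameters():
        if param.dim() > 1:
            param.clamp_(-c, c)

# ReLU is 1-Lipschitz, so the product of the spectral norms of the
# weight matrices is an upper bound K on the Lipschitz constant
K = 1.0
for layer in net:
    if isinstance(layer, nn.Linear):
        K *= torch.linalg.matrix_norm(layer.weight, ord=2).item()

# Empirical check on random pairs: |f(x) - f(y)| <= K * ||x - y||
x, y = torch.randn(1000, 2), torch.randn(1000, 2)
with torch.no_grad():
    ratios = (net(x) - net(y)).abs().squeeze(1) / (x - y).norm(dim=1)

print(K, ratios.max().item())
```

With $c = 0.01$ the constant $K$ is tiny, so the estimated quantity is the Wasserstein distance scaled by a small factor.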

##Exercise 3:
I am lying when saying that we are not taking the constraint on Kantorovich potentials into account. Why?

**Extra**: We are taking the constraint into account in a way. Why isn't the method converging?

In [None]:
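#A utility function for clipping the weights of f
def clip_weights(f, c):
  for param in f.parameters():
    #Weight matrices are 2-dimensional, biases are 1-dimensional.
    #Testing len(param) > 1 would skip the 1 x 128 weight of the last layer.
    if param.dim() > 1:
        param.data.clamp_(-c, c)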

#Discriminator
f = nn.Sequential(nn.Linear(2,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,1)
                   )
f_optim = torch.optim.RMSprop(f.parameters(), lr = 1e-4)

N_batch = 64 #Mini-batch size
N_iterations = 1000 #Amount of gradient steps we will take

loss_history = [] #Save the loss at each iteration here

#Train the discriminator
for i in range(N_iterations):
    #Sample from the source and the push-forward
    samples_source = source(N_batch)
    samples_push = g(source(N_batch))

    #-----------Weight clipping-----------
    #Clip the weights of f
    clip_weights(f, .01) 
    
    #Compute the f values
    f_source = f(samples_source)
    f_push = -f(samples_push)
    loss = -(f_source.mean() + f_push.mean())

    #Compute the gradient
    loss.backward()

    #update the parameters of f
    f_optim.step()
    f_optim.zero_grad()

    #save the current value
    loss_history.append(float(-loss))
  
#Plot the loss
plt.figure()
plt.plot(np.arange(N_iterations), loss_history)
plt.title('Approximating 1-Wasserstein with Weight Clipping')
plt.show()

# Gradient Penalty
We are doing better, but the objective still will not converge in a reasonable number of iterations. Let's see how the gradient penalty introduced in WGAN-GP performs. The penalty is added to the objective function
\begin{equation}
W_1(\mu_{\mathrm{source}}, (g_{\omega})_\#\mu_{\mathrm{source}}) \approx \max_{\omega'} \rho(\omega'), \qquad \rho(\omega') = \frac{1}{N}\sum_{i=1}^N f_{\omega'}(x_i) - \frac{1}{N}\sum_{i=1}^N f_{\omega'}(y_i) - \lambda\mathbb{E}_\nu[(\|\nabla_x f_{\omega'}(x)\| -1)^2],
\end{equation}
where $\lambda$ is the weight of the penalization, and $\nu$ is some reference measure. Note that in theory, instead of the equality $\|\nabla_x f_{\omega'}(x)\| = 1$ we should be enforcing $\|\nabla_x f_{\omega'}(x)\| \leq 1$, but the WGAN-GP authors remarked that in practice this suffices.
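
The key ingredient of the penalty is the gradient of the discriminator with respect to its *input*, not its parameters. The sketch below shows how `torch.autograd.grad` provides it, using a hypothetical linear function $f(x) = \langle w, x\rangle$ whose gradient is $w$ everywhere, so the result can be checked by hand. It is not the discriminator used in this tutorial.

```python
import torch

torch.manual_seed(0)

# Hypothetical linear "discriminator" f(x) = <w, x>, whose gradient is w
w = torch.tensor([3.0, 4.0])
f = lambda x: x @ w

# Points at which the penalty is evaluated; gradients are taken w.r.t. them
x = torch.randn(5, 2, requires_grad=True)
out = f(x)

# grad_outputs of ones gives the per-sample gradient of f at each row of x
grads = torch.autograd.grad(outputs=out, inputs=x,
                            grad_outputs=torch.ones_like(out),
                            create_graph=True)[0]

norms = grads.norm(2, dim=1)          # every entry equals ||w|| = 5
penalty = ((norms - 1) ** 2).mean()   # (5 - 1)^2 = 16

print(norms, penalty)
```

For a neural network, `create_graph=True` keeps the gradient computation in the graph, so that the penalty itself can be differentiated with respect to the network parameters.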

##Exercise 4:

Below, I have provided a function that computes the gradient penalty. Incorporate this into the discriminator training procedure.

In [None]:
#A utility function for computing the gradient penalty of f
def gradient_penalty(f, samples_source, samples_push, lambda_reg=5, use_cuda = False):
    BATCH_SIZE = samples_source.shape[0]
    #Random interpolation weights, one per pair of samples
    alpha = torch.rand(BATCH_SIZE, 1, device=samples_source.device)
    alpha = alpha.expand(samples_source.size())

    interpolates = alpha * samples_source + ((1 - alpha) * samples_push)
    
    if use_cuda:
        interpolates = interpolates.cuda()
    #Detach from the graph of g and track gradients with respect to the inputs
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = f(interpolates)

    #Gradient of f with respect to each interpolated point
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_reg
    return penalty

#Discriminator
f = nn.Sequential(nn.Linear(2,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,1)
                   )
f_optim = torch.optim.RMSprop(f.parameters(), lr = 1e-4)

N_batch = 64 #Mini-batch size
N_iterations = 1000 #Amount of gradient steps we will take

loss_history = [] #Save the loss at each iteration here

#Train the discriminator
for i in range(N_iterations):
    #Sample from the source and from the push-forward
    samples_source = source(N_batch)
    samples_push = g(source(N_batch))
    
    #Remember to add the gradient penalty!!
        
    f_source = f(samples_source)
    f_push = -f(samples_push)
    loss = -(f_source.mean() + f_push.mean())
     
    #backpropagate for the gradient
    loss.backward()

    #update the parameters of f
    f_optim.step()
    f_optim.zero_grad()

    #save the current value
    loss_history.append(float(-loss))
  
#Plot the loss
plt.figure()
plt.plot(np.arange(N_iterations), loss_history)
plt.title('Approximating 1-Wasserstein with Gradient Penalty')
plt.show()

#Implementing the WGAN
We can now estimate the $1$-Wasserstein distance between two measures $\mu$ and $\nu$, where above we took $\mu=\mu_s$ and $\nu = (g_{\omega})_\#\mu_s$.  The final step in the WGAN implementation is to update the weights $\omega$ of the generator to minimize the objective function that is given by the $1$-Wasserstein distance and possibly a penalization term.

To do this, we will do something slightly more interesting, and try to learn the distribution of a Gaussian mixture model with three components in $\mathbb{R}^2$, which we have implemented below.

In [None]:
def gaussian_mixture(N, weights, variances, means):
    M = len(weights) #Number of clusters
    d = means[0].shape[1] #Dimension of the Gaussian
    
    #Sample from multinomial distribution
    pvals = torch.multinomial(weights, N, replacement = True)
    
    #counts of each index appearing in pvals
    counts = [int(torch.sum(pvals == i)) for i in range(M)]
    
    #Sample from clusters
    samples = []
    
    for i in range(M):
        samplesi = torch.sqrt(variances[i])*torch.randn((counts[i],d)) + means[i]
        samples.append(samplesi)
        
    return torch.cat(samples,0)
    
weights = torch.Tensor([.2, .3, .5])
variances = torch.Tensor([.5, .2, .6])
means = [torch.Tensor([[0,2]]), torch.Tensor([[-4,-1]]), torch.Tensor([[2,-2]])]

#Sample data points for visualization
samples = gaussian_mixture(200, weights, variances, means)


#visualize data
plt.figure()
plt.scatter(samples[:,0], samples[:,1])
plt.show()

#Updating the Generator
Now that we have our data set, we will learn how to sample from it.

In [None]:

#Discriminator
f = nn.Sequential(nn.Linear(2,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,1)
                   )
f_optim = torch.optim.RMSprop(f.parameters(), lr = 1e-3)

#Source
source = lambda N:torch.randn((N,2))
#Target
target = lambda N: gaussian_mixture(N, weights, variances, means)

#Generator
g = nn.Sequential(nn.Linear(2,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,2)
                   )
g_optim = torch.optim.RMSprop(g.parameters(), lr = 1e-3)

N_batch = 64 #Mini-batch size
N_discriminator = 5 #Discriminator iterations per generator iteration
N_generator = 500 #Amount of generator iterations

loss_history = []
#Train the generator
for i in range(N_generator):
    #Sample from the source and apply the push-forward
    samples_source = source(N_batch)
    #Sample from the target
    samples_nu = target(N_batch)
    
    #Train the discriminator
    for j in range(N_discriminator):
        #Push-forward source samples
        samples_mu = g(samples_source)
        
        #Compute the gradient penalty
        penalty = gradient_penalty(f, samples_mu, samples_nu)
        
        #Compute f values
        f_mu = f(samples_mu)
        f_nu = -f(samples_nu)
        loss_discriminator = -(f_mu.mean() + f_nu.mean()) + penalty

        #backpropagate for the gradient
        loss_discriminator.backward()

        #update the parameters of f
        f_optim.step()
        f_optim.zero_grad()
        g_optim.zero_grad()

        #save the current value
        loss_history.append(float(-loss_discriminator))
   
    #Update generator
    samples_mu = g(samples_source)
    f_mu = f(samples_mu)
    f_nu = -f(samples_nu)
    
    #This time we want to minimize the expression, so no minus needed.
    loss_generator = f_mu.mean() + f_nu.mean()
    
    loss_generator.backward()
    g_optim.step()
    g_optim.zero_grad()
    #The backward pass also accumulated gradients in f; clear them so that
    #they do not leak into the next discriminator update
    f_optim.zero_grad()
    
#Let's plot the results!
gen_samples = g(source(200)).detach()
dat_samples = pd.DataFrame(target(200).numpy(), columns=['x', 'y'])
gen_samples_df = pd.DataFrame(gen_samples.numpy(), columns=['x', 'y'])

plt.figure()
plt.scatter(dat_samples.x, dat_samples.y, c='b', alpha=0.2)
#Recent seaborn versions expect keyword arguments (x=, y=, levels=, fill=)
sns.kdeplot(x=gen_samples_df.x, y=gen_samples_df.y, zorder=0, levels=10, fill=True)
plt.legend(['Target Samples'])
plt.title('Model distribution')
plt.show()

#A Challenge for the Brave
Below, I have added a script for downloading the MNIST hand-written digits dataset. Running the script will yield a **train_loader**, which is practically indispensable for working with large datasets. It reads mini-batches from the data into memory when needed, and discards them when we are done with that mini-batch. This avoids loading the whole dataset into memory at once. Below this script, there is a cell with the skeleton for the WGAN procedure on the MNIST dataset. It is your task to complete this.
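
As a small illustration of how such a loader behaves, the sketch below wraps a toy tensor dataset (a stand-in for MNIST, with made-up sizes and the hypothetical name `toy_loader`) in a `DataLoader` and iterates over one epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 10 samples with 3 features each (stand-in for MNIST)
data = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.zeros(10)
toy_loader = DataLoader(TensorDataset(data, labels), batch_size=4, shuffle=True)

# Each pass over the loader yields the mini-batches of one epoch
batch_sizes = [batch.shape[0] for batch, _ in toy_loader]
print(batch_sizes)  # the last batch is smaller: 10 = 4 + 4 + 2
```

The smaller final batch is why the skeleton further down skips any batch whose size differs from N_batch.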

In order to make this a bit more interesting, we will use the GPU provided by Google for this task. To do this, we will have to make some modifications in the code. We will have to change our datatype from torch.FloatTensor to torch.cuda.FloatTensor. Additionally, we will also have to specify that our discriminator and generator work with this datatype, by using the command .cuda() (see the code below).
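
As an aside, a device-agnostic variant is also possible. This is not what the skeleton below uses; it is a minimal sketch, assuming you want the same code to run with or without a GPU. The names `device` and `net` are chosen for this example only.

```python
import torch
import torch.nn as nn

# Fall back to the CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = nn.Linear(4, 1).to(device)        # move the parameters to the device
x = torch.randn(8, 4, device=device)    # create the data on the same device
print(net(x).shape, x.device)
```

The only rule to remember is that a network and its inputs must live on the same device.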

Changing the notebook to use the GPU will reset the session, so we will have to import the libraries again below. You can start using the GPU, by clicking on [Runtime] in the upper left corner, then choose [Change runtime type] and finally pick GPU as the [Hardware accelerator].

In [None]:
#Import the python libraries that we will need

import numpy as np #Standard python package for numerical computations

import torch #GPU enabled deep learning library by Facebook
import torch.nn as nn
import torch.autograd as autograd
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms

import seaborn as sns #For statistical visualization
import pandas as pd #Cannot do data science without pandas

import matplotlib.pyplot as plt #For visualization

def imshow(image):
    #Undo the normalization to [-1, 1] applied by the transform below
    I = image.clone().cpu().numpy().reshape(DIM_IMG, DIM_IMG) * 0.5 + 0.5
    plt.imshow(I, cmap='gray')

DLATENT = 64
DIM_IMG = 28
DIM = DIM_IMG*DIM_IMG
N_batch = 64
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (.5,))])
# if not exist, download mnist dataset
train_set = dset.MNIST(root = './', train=True, transform=trans, download=True)



train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=N_batch,
                 shuffle=True)

##WGAN Skeleton for the MNIST Dataset

In [None]:
#We have to work with a GPU friendly data type.
dtype = torch.cuda.FloatTensor

def gradient_penalty(f, samples_source, samples_push, lambda_reg=5, use_cuda = False):
    BATCH_SIZE = samples_source.shape[0]
    #Random interpolation weights, one per pair of samples
    alpha = torch.rand(BATCH_SIZE, 1).type(dtype)
    alpha = alpha.expand(samples_source.size())

    interpolates = alpha * samples_source + ((1 - alpha) * samples_push)
    
    if use_cuda:
        interpolates = interpolates.cuda()
    #Detach from the graph of g and track gradients with respect to the inputs
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = f(interpolates)

    #Gradient of f with respect to each interpolated point
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_reg
    return penalty
  
  
#Discriminator
f = nn.Sequential(nn.Linear(DIM,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,1)
                   ).cuda() #<--- Notice the .cuda()

f_optim = torch.optim.Adam(f.parameters(), lr = 1e-4, betas=(0, .9))

#Source
source = lambda N:torch.randn((N,DLATENT)).type(dtype)

#Generator
g = nn.Sequential(nn.Linear(DLATENT,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,128), nn.ReLU(),
                    nn.Linear(128,DIM), nn.Tanh()
                   ).cuda()
g_optim = torch.optim.Adam(g.parameters(), lr = 1e-4, betas=(0,.9))

N_batch = 64 #Mini-batch size
N_discriminator = 5 #Discriminator iterations per generator iteration
N_epoch= 10 #Amount of times we iterate over the full dataset

loss_history = []
#Train the generator
for i in range(N_epoch):
  #Sample from target using the train_loader
    for samples_nu, _ in train_loader:
        if samples_nu.shape[0] != N_batch:
          break

        samples_nu = samples_nu.view(N_batch,-1).type(dtype) #Make the data GPU friendly
        samples_source= source(N_batch)

        #Train the discriminator
        for j in range(N_discriminator):
            pass #Add the code! (replace this placeholder)

        #Update generator
        #Add the code!

    print('%d Epochs done' %(i+1))
    


#Plot some generator samples!

In [None]:
samples_generator = g(source(8)).detach()

fig = plt.figure(figsize=(8,4))
for i in range(8):
  plt.subplot(2,4, i+1)
  imshow(samples_generator[i])
  plt.grid(False)
  plt.axis('off')
  plt.xticks([])
  plt.yticks([])
  plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
plt.show()
print(torch.max(samples_generator[0]))