## Task 1: Theory (3pt)

1. Let's recall the variational lower bound for the discrete data dequantization problem from Lecture 8:
$$
    \log P(\mathbf{x} | \boldsymbol{\theta}) \geq  \int q(\mathbf{u} | \mathbf{x}) \log \frac{p(\mathbf{x} + \mathbf{u} | \boldsymbol{\theta})}{q(\mathbf{u} | \mathbf{x})} d \mathbf{u} = \mathcal{L}(q, \boldsymbol{\theta}).
$$
As discussed in the lecture, the variational lower bound can be tightened with importance sampling. Write out the importance-weighted lower bound $\mathcal{L}_K$ for dequantization by analogy with the IWAE model (use a set of samples $\{\mathbf{u}_k\}_{k=1}^K$ instead of a single $\mathbf{u}$).
    
2. The vanilla GAN often suffers from saturating (vanishing) generator gradients. [Least Squares GAN](https://arxiv.org/abs/1611.04076) tries to solve this problem by replacing the objective with the following:
$$
   	\min_D V(D) = \min_D \frac{1}{2}\left[ \mathbb{E}_{\pi(\mathbf{x})} (D(\mathbf{x}) - b)^2 + \mathbb{E}_{p(\mathbf{z})} (D(G(\mathbf{z})) - a)^2 \right]
$$
$$
   	\min_G V(G) = \min_G \frac{1}{2}\left[ \mathbb{E}_{\pi(\mathbf{x})} (D(\mathbf{x}) - c)^2 + \mathbb{E}_{p(\mathbf{z})} (D(G(\mathbf{z})) - c)^2 \right],
$$
where $a, b, c \in \mathbb{R}$ are fixed constants.
    * Find the formula for the optimal discriminator $D^*$.
    * **(1 pt)** Write out the expression for the generator's error function $V(G)$ under the optimal discriminator $D^*$.
    * **(1.5 pt)** Prove that for $b - c = 1$, $b - a = 2$, the generator's error function $V(G)$ under the optimal discriminator $D^*$ takes the form:
$$
   	V(G) = \frac{1}{2} \chi^2_{\text{Pearson}} (\pi(\mathbf{x}) + p(\mathbf{x} | \boldsymbol{\theta}) || 2 p(\mathbf{x} | \boldsymbol{\theta})), 
$$
where $\chi^2_{\text{Pearson}} (p || q)$ is the Pearson $\chi^2$ divergence:
$$
   	\chi^2_{\text{Pearson}} (p || q) = \int \frac{(p(\mathbf{x}) - q(\mathbf{x}))^2}{p(\mathbf{x})} d \mathbf{x}.
$$

```your solution```

In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision.utils import make_grid

# Global device switch; the training calls below receive it as use_cuda=USE_CUDA.
USE_CUDA = torch.cuda.is_available()
print('cuda is available:', USE_CUDA)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def plot_losses(losses, title):
    """Plot a sequence of loss values against their iteration index.

    Args:
        losses: sequence of scalar loss values, one per recorded iteration.
        title: figure title.
    """
    iterations = np.arange(len(losses))

    plt.figure(figsize=(7, 5))
    plt.plot(iterations, losses)
    plt.title(title, fontsize=14)
    plt.xlabel('Iterations', fontsize=14)
    plt.ylabel('Loss', fontsize=14)

    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=12)
    plt.show()

# Task 2: Vanilla GAN for 1d task (3pt)

In this task you will train a simple GAN model on a 1D distribution (a mixture of two Gaussians).

In [None]:
def generate_1d_data(count):
    """Sample a two-component 1D Gaussian mixture and min-max scale it to [0, 1].

    Components are N(-1, 0.25^2) and N(0.5, 0.5^2). The first component gets
    ``count // 2`` points and the second gets the remainder, so exactly
    ``count`` points are returned even when ``count`` is odd.

    Args:
        count: total number of samples (needs at least 2 distinct values for
            the min-max scaling to be well defined).

    Returns:
        float32 array of shape (count, 1) with values in [0, 1].
    """
    n_first = count // 2
    gaussian1 = np.random.normal(loc=-1, scale=0.25, size=(n_first,))
    # the remainder goes here so odd counts are not silently truncated
    gaussian2 = np.random.normal(loc=0.5, scale=0.5, size=(count - n_first,))
    samples = np.concatenate([gaussian1, gaussian2]).reshape([-1, 1]).astype('float32')
    # min-max scaling; any constant shift applied beforehand would cancel out here
    return (samples - samples.min()) / (samples.max() - samples.min())


def visualize_1d_data(data):
    """Show a 50-bin histogram of 1D samples."""
    plt.figure(figsize=(7, 4))
    plt.hist(data, bins=50)
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=12)
    plt.show()


def make_inference(generator, discriminator, n_samples=5000):
    """Draw generator samples and evaluate the discriminator on a grid over [0, 1].

    Both models are switched to eval mode. Evaluation runs under
    ``torch.no_grad`` on the discriminator's device, so it works on CPU as
    well as CUDA.

    Args:
        generator: module exposing ``sample(n)``.
        discriminator: module mapping an (N, 1) tensor to scores.
        n_samples: number of generator samples to draw.

    Returns:
        (samples, xs, discr_output): numpy arrays with the generated samples,
        the 1000 evaluation points in [0, 1], and the discriminator outputs at
        those points.
    """
    generator.eval()
    discriminator.eval()
    xs = np.linspace(0, 1, 1000)

    # follow the discriminator's device instead of hard-coding CUDA
    first_param = next(iter(discriminator.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        samples = generator.sample(n_samples).cpu().numpy()
        grid = torch.as_tensor(xs, dtype=torch.float32, device=device).unsqueeze(1)
        discr_output = discriminator(grid).cpu().numpy()
    return samples, xs, discr_output


def plot_results(data, samples, xs, ys, title):
    """Overlay density histograms of real and generated 1D samples with a discriminator curve.

    Args:
        data: real samples.
        samples: generated samples.
        xs: points where the discriminator was evaluated.
        ys: discriminator outputs at ``xs``.
        title: figure title.
    """
    plt.figure(figsize=(7, 5))
    plt.hist(samples, bins=50, density=True, alpha=0.7, label='fake')
    plt.hist(data, bins=50, density=True, alpha=0.7, label='real')

    plt.plot(xs, ys, label='discrim')
    plt.title(title, fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    # render immediately, like plot_losses / visualize_1d_data
    plt.show()

In [None]:
# Build the 1D training set and show its histogram.
COUNT = 20_000
train_data = generate_1d_data(COUNT)
visualize_1d_data(train_data)

The following functions are used to train the model. Read them carefully.

In [None]:
def make_step(loss_fn, x, generator, discriminator, optimizer):
    """Take one optimizer step on ``loss_fn(generator, discriminator, x)``.

    Returns:
        The loss tensor, so callers can log ``.item()``.
    """
    optimizer.zero_grad()
    batch_loss = loss_fn(generator, discriminator, x)
    batch_loss.backward()
    optimizer.step()
    return batch_loss


def train_epoch(
    generator, 
    discriminator, 
    gen_loss_fn, 
    discr_loss_fn, 
    train_loader, 
    gen_optimizer, 
    discr_optimizer, 
    discr_steps,
    use_cuda
):
    """Run one pass over ``train_loader``, alternating discriminator and generator updates.

    The discriminator is updated on every batch. The generator is updated only on
    batches whose index is a multiple of ``discr_steps``, so batch 0 always gets one.

    Returns:
        {'generator_losses': [...], 'discriminator_losses': [...]} with one
        float per update taken.
    """
    for model in (generator, discriminator):
        model.train()

    history = {'generator_losses': [], 'discriminator_losses': []}
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.cuda() if use_cuda else batch

        d_loss = make_step(discr_loss_fn, batch, generator, discriminator, discr_optimizer)
        history['discriminator_losses'].append(d_loss.item())

        if batch_idx % discr_steps != 0:
            continue
        g_loss = make_step(gen_loss_fn, batch, generator, discriminator, gen_optimizer)
        history['generator_losses'].append(g_loss.item())

    return history


def train_gan(
    generator, 
    discriminator, 
    gen_loss_fn, 
    discr_loss_fn, 
    train_loader, 
    epochs,
    lr,
    discr_steps=1,
    use_tqdm=False, 
    use_cuda=False
):
    """Train a GAN with Adam (betas=(0, 0.9)) for ``epochs`` passes over ``train_loader``.

    Args:
        generator, discriminator: the two models.
        gen_loss_fn, discr_loss_fn: callables (generator, discriminator, x) -> loss tensor.
        train_loader: iterable of real-data batches.
        epochs: number of passes over the data.
        lr: learning rate for both optimizers.
        discr_steps: generator is updated once every ``discr_steps`` batches.
        use_tqdm: show a progress bar over epochs.
        use_cuda: move models and batches to CUDA.

    Returns:
        dict with 'generator_losses' and 'discriminator_losses', each the
        concatenation of the per-epoch loss lists.
    """
    # Move the models before building optimizers, as the torch.optim docs recommend.
    if use_cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    gen_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0, 0.9))
    discr_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0, 0.9))

    train_losses = {}
    epoch_iter = tqdm(range(epochs)) if use_tqdm else range(epochs)
    for _ in epoch_iter:
        epoch_losses = train_epoch(
            generator, 
            discriminator, 
            gen_loss_fn, 
            discr_loss_fn,
            train_loader, 
            gen_optimizer, 
            discr_optimizer, 
            discr_steps=discr_steps,
            use_cuda=use_cuda
        )
        for key, values in epoch_losses.items():
            train_losses.setdefault(key, []).extend(values)

    return train_losses

The generator and discriminator are simple MLP models.

In [None]:
class FullyConnectedMLP(nn.Module):
    # do not change this class
    """Plain MLP: (Linear -> ReLU) for each hidden width, then a final Linear.

    Args:
        input_dim: size of the last input dimension.
        hiddens: list of hidden layer widths (may be empty).
        output_dim: size of the output.

    The layers are stored in ``self.net`` (an ``nn.Sequential``).
    """

    def __init__(self, input_dim, hiddens, output_dim):
        assert isinstance(hiddens, list)
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hiddens = hiddens

        # stack Linear + ReLU for every hidden width; the last layer is Linear only
        layers = []
        in_features = input_dim
        for width in hiddens:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must equal ``input_dim``)."""
        return self.net(x)

In [None]:
class MLPGenerator(nn.Module):
    """MLP generator mapping latent vectors z ~ N(0, I) to data space.

    Args:
        latent_dim: size of z.
        hiddens: hidden layer widths of the MLP.
        data_dim: size of a generated sample.
        use_sigmoid: squash outputs to (0, 1), for data scaled to [0, 1].
    """

    def __init__(self, latent_dim, hiddens, data_dim, use_sigmoid=False):
        super().__init__()
        self.latent_dim = latent_dim
        self.use_sigmoid = use_sigmoid
        self.mlp = FullyConnectedMLP(latent_dim, hiddens, data_dim)

    def forward(self, z):
        """Map a (N, latent_dim) batch of latents to (N, data_dim) samples."""
        output = self.mlp(z)
        if self.use_sigmoid:
            output = torch.sigmoid(output)
        return output

    def sample(self, n):
        """Draw ``n`` latents from N(0, I) on the model's device and decode them."""
        device = next(self.parameters()).device
        z = torch.randn(n, self.latent_dim, device=device)
        return self.forward(z)


class MLPDiscriminator(nn.Module):
    """MLP discriminator/critic producing one score per sample.

    Args:
        data_dim: size of an input sample.
        hiddens: hidden layer widths of the MLP.
        use_sigmoid: squash scores to (0, 1) (probability of "real").
            Leave False for a WGAN critic.

    The MLP is stored in ``self.mlp``.
    """

    def __init__(self, data_dim, hiddens, use_sigmoid=False):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.mlp = FullyConnectedMLP(data_dim, hiddens, 1)

    def forward(self, x):
        """Return an (N, 1) tensor of scores for an (N, data_dim) batch."""
        output = self.mlp(x)
        if self.use_sigmoid:
            output = torch.sigmoid(output)
        return output

The objective function is 
$$\min_{G} \max_{D} \mathbb{E}_{\mathbf{x} \sim \pi(\mathbf{x})} [\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log (1-D(G(\mathbf{z})))]$$

Now you have to implement the objective.

In [None]:
def gen_loss_fn(generator, discriminator, x):
    """Minimax generator loss E_z[log(1 - D(G(z)))], to be minimised.

    ``x`` (the real batch) is only used for its batch size. ``discriminator``
    is expected to output probabilities in (0, 1).
    """
    fake = generator.sample(x.shape[0])
    eps = 1e-8  # keeps log finite if D(G(z)) saturates at exactly 1
    return torch.log(1 - discriminator(fake) + eps).mean()


def discr_loss_fn(generator, discriminator, x):
    """Discriminator loss -(E_x[log D(x)] + E_z[log(1 - D(G(z)))]), to be minimised.

    Minimising this maximises the GAN value function with respect to D.
    Generator samples are drawn without building a graph, so the discriminator
    step does not backpropagate into the generator.
    """
    with torch.no_grad():
        fake = generator.sample(x.shape[0])
    eps = 1e-8  # keeps log finite when D saturates at 0 or 1
    real_term = torch.log(discriminator(x) + eps).mean()
    fake_term = torch.log(1 - discriminator(fake) + eps).mean()
    return -(real_term + fake_term)

In [None]:
# ====
# your code
# choose these parameters
# Starting values for the minimax (saturating) objective; tune as needed.
BATCH_SIZE = 256
GEN_HIDDENS = [128, 128, 128]
DISCR_HIDDENS = [128, 128, 128]
EPOCHS = 100
LR = 1e-4
DISCR_STEPS = 1
# ====

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

generator = MLPGenerator(latent_dim=1, hiddens=GEN_HIDDENS, data_dim=1, use_sigmoid=True)
discriminator = MLPDiscriminator(data_dim=1, hiddens=DISCR_HIDDENS, use_sigmoid=True)

# train
train_losses = train_gan(
    generator, 
    discriminator, 
    gen_loss_fn, 
    discr_loss_fn, 
    train_loader, 
    epochs=EPOCHS,
    lr=LR,
    discr_steps=DISCR_STEPS,
    use_tqdm=True,
    use_cuda=USE_CUDA
)

plot_losses(train_losses['discriminator_losses'], 'Discriminator loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# Sample the trained generator and plot it against the data and D(x).
samples, xs, discr_output = make_inference(generator, discriminator)
plot_results(train_data, samples, xs, discr_output, title='Results')

Now we'll use the non-saturating formulation of the GAN objective. In this case we have two separate losses, both of which are minimized:
$$
    L^{(D)} = - \mathbb{E}_{\mathbf{x} \sim \pi(\mathbf{x})} [\log D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log (1-D(G(\mathbf{z})))]
$$
$$
    L^{(G)} = - \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log D(G(\mathbf{z}))]
$$

The discriminator loss stays the same; only the generator loss changes.

In [None]:
def gen_nonsaturating_loss_fn(generator, discriminator, x):
    """Non-saturating generator loss -E_z[log D(G(z))], to be minimised.

    ``x`` (the real batch) is only used for its batch size. ``discriminator``
    is expected to output probabilities in (0, 1).
    """
    fake = generator.sample(x.shape[0])
    eps = 1e-8  # keeps log finite if D(G(z)) underflows to 0
    return -torch.log(discriminator(fake) + eps).mean()

In [None]:
# ====
# your code
# choose these parameters
# Starting values for the non-saturating objective; tune as needed.
BATCH_SIZE = 256
GEN_HIDDENS = [128, 128, 128]
DISCR_HIDDENS = [128, 128, 128]
EPOCHS = 100
LR = 1e-4
DISCR_STEPS = 1
# ====

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# model
generator = MLPGenerator(latent_dim=1, hiddens=GEN_HIDDENS, data_dim=1, use_sigmoid=True)
discriminator = MLPDiscriminator(data_dim=1, hiddens=DISCR_HIDDENS, use_sigmoid=True)

# train
train_losses = train_gan(
    generator, 
    discriminator, 
    gen_nonsaturating_loss_fn, 
    discr_loss_fn, 
    train_loader, 
    epochs=EPOCHS,
    lr=LR,
    discr_steps=DISCR_STEPS,
    use_tqdm=True,
    use_cuda=USE_CUDA
)

plot_losses(train_losses['discriminator_losses'], 'Discriminator loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# Sample the trained generator and plot it against the data and D(x).
samples, xs, discr_output = make_inference(generator, discriminator)
plot_results(train_data, samples, xs, discr_output, title='Results')

# Task 3: WGAN vs WGAN-GP on 2d data (3pt)

Here your task is to reproduce the experiment with a toy dataset from [WGAN-GP](https://arxiv.org/pdf/1704.00028.pdf) paper.

Let's generate the data.

In [None]:
def generate_2d_data(size):
    """Sample ``size`` points from eight tight Gaussians placed evenly on a circle.

    Centers lie on a circle of radius 2. Each point is a center chosen uniformly
    at random plus isotropic noise with std 0.02. The noise is drawn before the
    center for every point. Finally the whole dataset is divided by 1.414.

    Returns:
        float32 array of shape (size, 2).
    """
    radius = 2
    noise_std = 0.02
    diag = 1. / np.sqrt(2)
    unit_centers = [
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (diag, diag), (diag, -diag), (-diag, diag), (-diag, -diag),
    ]
    scaled_centers = [(radius * cx, radius * cy) for cx, cy in unit_centers]
    center_ids = np.arange(len(scaled_centers))

    points = []
    for _ in range(size):
        sample = np.random.randn(2) * noise_std
        cx, cy = scaled_centers[np.random.choice(center_ids)]
        sample[0] += cx
        sample[1] += cy
        points.append(sample)

    dataset = np.array(points, dtype='float32')
    dataset /= 1.414  # stdev
    return dataset


def visualize_2d_data(train_data, train_labels=None):
    """Scatter-plot 2D points, optionally coloured by ``train_labels``."""
    xs, ys = train_data[:, 0], train_data[:, 1]
    plt.figure(figsize=(6, 6))
    plt.title('train data', fontsize=14)
    plt.scatter(xs, ys, s=1, c=train_labels)
    plt.show()

In [None]:
# Build the 8-Gaussians training set and scatter-plot it.
COUNT = 20_000
train_data = generate_2d_data(COUNT)
visualize_2d_data(train_data)

The data has many separate modes (eight Gaussian clusters). Our goal is to compare WGAN with WGAN-GP.

In [None]:
def make_inference(generator, critic, n_samples=5000):
    """Draw generator samples and evaluate the critic on a 1001 x 1001 grid over [-3, 3]^2.

    Both models are switched to eval mode. Evaluation runs under
    ``torch.no_grad`` on the critic's device, so it works on CPU as well as CUDA.

    Returns:
        (samples, grid, critic_output):
        - samples: numpy array of generator samples;
        - grid: (1001 * 1001, 2) array of evaluation points;
        - critic_output: (1001, 1001) array of critic values on the grid, with
          the product taken over the critic's last output dimension.
    """
    generator.eval()
    critic.eval()
    n_points = 1000 + 1
    xs = np.linspace(-3.0, 3.0, n_points)
    xg, yg = np.meshgrid(xs, xs)
    grid = np.concatenate((xg.reshape(-1, 1), yg.reshape(-1, 1)), axis=-1)

    # follow the critic's device instead of hard-coding CUDA
    first_param = next(iter(critic.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        samples = generator.sample(n_samples).cpu().numpy()
        grid_tensor = torch.as_tensor(grid, dtype=torch.float32, device=device)
        critic_output = critic(grid_tensor).cpu().numpy()

    critic_output = np.prod(critic_output, axis=-1).reshape((n_points, n_points))
    return samples, grid, critic_output


def visualize_critic_output(samples, grid, critic_output, npts=100 + 1):
    """Draw a filled contour plot of critic values with generated samples on top.

    Args:
        samples: (n, 2) generated points, drawn as red dots.
        grid: (npts * npts, 2) grid coordinates matching ``critic_output``.
        critic_output: (npts, npts) critic values on the grid.
        npts: ignored; the resolution is always taken from
            ``critic_output.shape[0]``. Kept only for backward compatibility.
    """
    plt.figure(figsize=(6, 6))
    plt.gca().set_aspect("equal")

    npts = critic_output.shape[0]
    grid_x = grid[:, 0].reshape((npts, npts))
    grid_y = grid[:, 1].reshape((npts, npts))
    cnt = plt.contourf(grid_x, grid_y, critic_output, levels=25, cmap="cividis")
    plt.scatter(samples[:, 0], samples[:, 1], marker=".", color="red", s=0.5)
    plt.colorbar(cnt)
    # render immediately, like visualize_2d_data
    plt.show()

## WGAN

[WGAN](https://arxiv.org/abs/1701.07875) model uses weight clipping to enforce Lipschitzness.

In [None]:
# WGAN train loop

class MLPCritic(MLPDiscriminator):
    """MLP critic with WGAN weight clipping.

    Defined in this cell because the WGAN experiment cell below instantiates
    ``MLPCritic`` before the WGAN-GP section of the notebook is reached.
    """

    def clip_weights(self, c):
        """Clamp every ``nn.Linear`` weight to [-c, c] in place (biases are not clipped).

        Clamping in place keeps the Parameter objects the optimizer holds, so
        later optimizer steps still update the clipped weights.
        """
        with torch.no_grad():
            for layer in self.modules():
                if isinstance(layer, nn.Linear):
                    layer.weight.clamp_(-c, c)


def train_wgan(
    generator, 
    critic, 
    train_loader,
    critic_steps, 
    batch_size,
    n_epochs,
    lr, 
    clip_c,
    use_cuda
):
    """Train a WGAN with weight clipping.

    Every batch performs one critic step on E[D(fake)] - E[D(real)], followed
    by ``critic.clip_weights(clip_c)``. Every ``critic_steps``-th batch also
    performs one generator step on -E[D(G(z))] using ``batch_size`` samples.

    Returns:
        dict with 'discriminator_losses' and 'generator_losses' (floats),
        recorded once per generator step.
    """
    if use_cuda:
        generator = generator.cuda()
        critic = critic.cuda()

    gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0, 0.9))
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0, 0.9))

    generator.train()
    critic.train()

    curr_iter = 0
    # placeholders so the progress bar can show G loss before the first generator step
    d_loss, g_loss = torch.zeros(1), torch.zeros(1)

    batch_loss_history = {'discriminator_losses': [], 'generator_losses': []}
    for epoch_i in range(n_epochs):

        with tqdm(train_loader, desc=f"epoch {epoch_i}", leave=False) as pbar:
            for x in pbar:
                curr_iter += 1
                if use_cuda:
                    x = x.cuda()

                # do a critic update
                with torch.no_grad():
                    fake_data = generator.sample(x.shape[0])

                critic_optimizer.zero_grad()
                # discriminator loss: D(x_fake) - D(x_real)
                d_loss = critic(fake_data).mean() - critic(x).mean()
                d_loss.backward()
                critic_optimizer.step()
                critic.clip_weights(clip_c)
                pbar.set_postfix({"D loss": d_loss.item(), "G loss": g_loss.item()})

                # generator update
                if curr_iter % critic_steps == 0:
                    gen_optimizer.zero_grad()
                    fake_data = generator.sample(batch_size)
                    # generator loss: -D(x_fake)
                    g_loss = -critic(fake_data).mean()
                    g_loss.backward()
                    gen_optimizer.step()
                    pbar.set_postfix({"D loss": d_loss.item(), "G loss": g_loss.item()})

                    batch_loss_history['generator_losses'].append(g_loss.item())
                    batch_loss_history['discriminator_losses'].append(d_loss.item())

    return batch_loss_history

In [None]:
# ====
# your code
# choose these parameters
# Starting values roughly following the WGAN-GP toy setup (3 x 512 MLPs,
# 5 critic steps per generator step, batch 256) and the WGAN paper's c = 0.01.
BATCH_SIZE = 256
GEN_HIDDENS = [512, 512, 512]
DISCR_HIDDENS = [512, 512, 512]
CRITIC_STEPS = 5
LR = 1e-4
CLIP_C = 0.01
# ====

N_EPOCHS = 200

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

generator = MLPGenerator(latent_dim=16, hiddens=GEN_HIDDENS, data_dim=2, use_sigmoid=False)
critic = MLPCritic(data_dim=2, hiddens=DISCR_HIDDENS, use_sigmoid=False)

train_losses = train_wgan(
    generator, 
    critic, 
    train_loader,
    critic_steps=CRITIC_STEPS, 
    batch_size=BATCH_SIZE, 
    n_epochs=N_EPOCHS,
    lr=LR,
    clip_c=CLIP_C,
    use_cuda=USE_CUDA
)

plot_losses(train_losses['discriminator_losses'], 'Critic loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# Evaluate the critic on a grid and overlay generator samples.
samples, grid, critic_output = make_inference(generator, critic)
visualize_critic_output(samples, grid, critic_output=critic_output)

## WGAN-GP

[WGAN-GP](https://arxiv.org/pdf/1704.00028.pdf)  model uses gradient penalty to enforce Lipschitzness.

In [None]:
class MLPCritic(MLPDiscriminator):
    """MLP critic whose Linear weights can be clipped to [-c, c] (WGAN weight clipping)."""

    def clip_weights(self, c):
        """Clamp every ``nn.Linear`` weight to [-c, c] in place (biases are not clipped).

        Assigning a fresh ``nn.Parameter`` would detach the layer from the
        tensors the optimizer holds, so later optimizer steps would silently
        stop updating these weights. Clamping in place under ``no_grad`` keeps
        the same Parameter objects.
        """
        with torch.no_grad():
            for layer in self.modules():
                if isinstance(layer, nn.Linear):
                    layer.weight.clamp_(-c, c)

In [None]:
def gradient_penalty(critic, real_data, fake_data):
    """WGAN-GP gradient penalty E[(||grad_x D(x_t)||_2 - 1)^2].

    x_t = t * x_real + (1 - t) * x_fake, with one t ~ U[0, 1] per sample
    broadcast over all non-batch dimensions. Works for any input rank
    (e.g. (B, D) points or (B, C, H, W) images) and on any device.

    Args:
        critic: module mapping a batch shaped like ``real_data`` to scores.
        real_data: real batch of shape (B, ...).
        fake_data: generated batch with the same shape; it is detached, so
            the penalty never backpropagates into the generator.

    Returns:
        Scalar tensor, differentiable w.r.t. the critic's parameters.
    """
    batch_size = real_data.shape[0]

    # one interpolation coefficient per sample, broadcast over the rest
    t_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    t = torch.rand(t_shape, device=real_data.device, dtype=real_data.dtype)
    interpolated = (t * real_data + (1 - t) * fake_data.detach()).detach()
    interpolated.requires_grad_(True)

    d_output = critic(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_output, 
        inputs=interpolated, 
        grad_outputs=torch.ones_like(d_output), 
        create_graph=True, 
        retain_graph=True
    )[0]

    gradients = gradients.reshape(batch_size, -1)
    # the tiny epsilon keeps sqrt differentiable when a gradient is exactly zero
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    return ((gradients_norm - 1) ** 2).mean()


def train_wgan_gp(
    generator, 
    critic, 
    train_loader,
    critic_steps, 
    batch_size,
    n_epochs,
    lr, 
    gp_weight=10,
    use_cuda=False
):
    """Train a WGAN with gradient penalty (WGAN-GP).

    Every batch performs one critic step on
    E[D(fake)] - E[D(real)] + gp_weight * gradient_penalty(critic, real, fake).
    Every ``critic_steps``-th batch also performs one generator step on
    -E[D(G(z))] using ``batch_size`` samples.

    Returns:
        dict with 'discriminator_losses' and 'generator_losses' (floats),
        recorded once per generator step.
    """

    if use_cuda:
        critic = critic.cuda()
        generator = generator.cuda()
    critic.train()
    generator.train()

    gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0, 0.9))
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0, 0.9))

    curr_iter = 0
    # placeholders so the progress bar can show G loss before the first generator step
    d_loss, g_loss = torch.zeros(1), torch.zeros(1)
    batch_loss_history = {'discriminator_losses': [], 'generator_losses': []}
    for epoch_i in range(n_epochs):
        with tqdm(train_loader, desc=f"epoch {epoch_i}", leave=False) as pbar:
            for x in pbar:
                curr_iter += 1
                if use_cuda:
                    x = x.cuda()

                # do a critic update; fake samples need no generator graph here
                critic_optimizer.zero_grad()
                with torch.no_grad():
                    fake_data = generator.sample(x.shape[0])

                # discriminator loss: D(x_fake) - D(x_real) + gp_weight * grad_pen
                grad_pen = gradient_penalty(critic, x, fake_data)
                d_loss = critic(fake_data).mean() - critic(x).mean() + gp_weight * grad_pen

                d_loss.backward()
                critic_optimizer.step()
                pbar.set_postfix({"D loss": d_loss.item(), "G loss": g_loss.item()})
                # generator update
                if curr_iter % critic_steps == 0:
                    gen_optimizer.zero_grad()
                    fake_data = generator.sample(batch_size)
                    # generator loss: -D(x_fake)
                    g_loss = -critic(fake_data).mean()
                    g_loss.backward()
                    gen_optimizer.step()
                    pbar.set_postfix({"D loss": d_loss.item(), "G loss": g_loss.item()})

                    batch_loss_history['generator_losses'].append(g_loss.item())
                    batch_loss_history['discriminator_losses'].append(d_loss.item())

    return batch_loss_history

In [None]:
# ====
# your code
# choose these parameters
BATCH_SIZE = 256
GEN_HIDDENS = [512, 512, 512]
DISCR_HIDDENS = [512, 512, 512]
CRITIC_STEPS = 5
LR = 1e-4
# paper default; a smaller weight (e.g. 0.1) may work better on toy 2D data
GP_WEIGHT = 10
# ====

N_EPOCHS = 200
train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

generator = MLPGenerator(latent_dim=16, hiddens=GEN_HIDDENS, data_dim=2, use_sigmoid=False)
critic = MLPCritic(data_dim=2, hiddens=DISCR_HIDDENS, use_sigmoid=False)

# keep the result under train_losses so the plots below show this WGAN-GP run
train_losses = train_wgan_gp(
    generator, 
    critic, 
    train_loader,
    CRITIC_STEPS, 
    batch_size=BATCH_SIZE, 
    n_epochs=N_EPOCHS,
    lr=LR,
    gp_weight=GP_WEIGHT,
    use_cuda=USE_CUDA
)

plot_losses(train_losses['discriminator_losses'], 'Critic loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# Evaluate the critic on a grid and overlay generator samples.
samples, grid, critic_output = make_inference(generator, critic)
visualize_critic_output(samples, grid, critic_output=critic_output)

# Task 4: WGAN-GP for CIFAR 10 (4pt)

In this task you will fit [Wasserstein GAN](https://arxiv.org/abs/1701.07875) with [Gradient Penalty](https://arxiv.org/pdf/1704.00028.pdf) model to the CIFAR10 dataset (download it from [here](https://drive.google.com/file/d/16j3nrJV821VOkkuRz7aYam8TyIXLnNme/view?usp=sharing)).  

In [None]:
def load_pickle(path, flatten=False, binarize=False):
    """Load a pickled {'train': ..., 'test': ...} image dataset as float32 NCHW arrays.

    Assumes the stored arrays are (N, H, W, C) with values in [0, 255]. This is
    inferred from the transpose and the /255 below; verify against the file.

    Args:
        path: path to the pickle file.
        flatten: reshape each split to (N, C * H * W).
        binarize: map pixels to {0., 1.} with threshold > 128 instead of
            scaling to [0, 1].

    Returns:
        (train_data, test_data) float32 arrays.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        data = pickle.load(f)
    train_data = data['train'].astype('float32')
    test_data = data['test'].astype('float32')
    if binarize:
        train_data = (train_data > 128).astype('float32')
        test_data = (test_data > 128).astype('float32')
    else:
        train_data = train_data / 255.
        test_data = test_data / 255.
    train_data = np.transpose(train_data, (0, 3, 1, 2))
    test_data = np.transpose(test_data, (0, 3, 1, 2))
    if flatten:
        # keep the sample axis of each split, flatten everything else
        train_data = train_data.reshape(train_data.shape[0], -1)
        test_data = test_data.reshape(test_data.shape[0], -1)
    return train_data, test_data


def show_samples(samples, title, nrow=10):
    """Plot a batch of CHW images as one grid with ``nrow`` images per row."""
    grid_img = make_grid(torch.FloatTensor(samples), nrow=nrow)
    plt.figure()
    plt.title(title)
    # the grid is channel-first; imshow expects channel-last
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()


def visualize_data(data, title):
    """Show up to 100 distinct, randomly chosen images from ``data`` in a grid.

    Args:
        data: indexable array of CHW images.
        title: figure title.
    """
    # sample without replacement, so never ask for more images than exist
    n_show = min(100, len(data))
    idxs = np.random.choice(len(data), replace=False, size=(n_show,))
    # index the argument, not a global dataset
    images = data[idxs]
    show_samples(images, title)

In [None]:
# Load CIFAR-10 from the mounted Drive folder and preview 100 training images.
cifar_path = os.path.join('drive', 'MyDrive', 'DGM', 'homework_supplementary', 'cifar10.pkl')
train_data, test_data = load_pickle(cifar_path)
visualize_data(train_data, 'CIFAR10 samples')

Here we will use a convolution-based generator and critic.

In [None]:
class ConvGenerator(nn.Module):
    """Convolutional generator producing (3, 32, 32) images in [0, 1].

    z (B, input_size) -> Linear -> (B, 4n, 4, 4)
      -> ConvT 4n->2n, BN, ReLU  (8x8)
      -> ConvT 2n->n, BN, ReLU   (16x16)
      -> ConvT n->3              (32x32) -> sigmoid,
    where n = ``n_channels``. Each transposed conv uses kernel 2, stride 2.
    """

    def __init__(self, input_size=128, n_channels=64):
        super().__init__()
        self.n_channels = n_channels
        self.input_size = input_size
        # 1) project the latent code to 4 * n_channels feature maps of size 4x4
        self.preprocess = nn.Sequential(
            nn.Linear(input_size, 4 * 4 * 4 * n_channels),
            nn.ReLU(),
        )
        # 2) 4x4 -> 8x8
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(4 * n_channels, 2 * n_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(2 * n_channels),
            nn.ReLU(),
        )
        # 3) 8x8 -> 16x16
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_channels, n_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(n_channels),
            nn.ReLU(),
        )
        # final 16x16 -> 32x32 with 3 channels, so the output fits (3, 32, 32)
        self.to_rgb = nn.ConvTranspose2d(n_channels, 3, kernel_size=2, stride=2)

    def forward(self, input):
        """Map a (B, input_size) latent batch to (B, 3, 32, 32) images in [0, 1]."""
        output = self.preprocess(input)
        output = output.view(-1, 4 * self.n_channels, 4, 4)
        output = self.block1(output)
        output = self.block2(output)
        output = self.to_rgb(output)
        # sigmoid keeps pixels in [0, 1], the range load_pickle scales data to
        output = torch.sigmoid(output)
        return output.view(-1, 3, 32, 32)

    def sample(self, n_samples):
        """Draw ``n_samples`` latents from N(0, I) on the model's device and decode them."""
        device = next(self.parameters()).device
        z = torch.randn(n_samples, self.input_size, device=device)
        return self.forward(z)


class ConvCritic(nn.Module):
    """Convolutional WGAN critic for (3, 32, 32) images, producing one score per image.

    Three stride-2 3x3 convolutions with LeakyReLU halve the resolution
    32 -> 16 -> 8 -> 4 while widening channels 3 -> n -> 2n -> 4n. A Linear
    layer then maps the flattened features to a scalar. Batch norm is avoided
    because the gradient penalty is defined per sample.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels

        # sequence of Conv2D (kernel 3, stride 2) + LeakyReLU
        self.features = nn.Sequential(
            nn.Conv2d(3, n_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_channels, 2 * n_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2 * n_channels, 4 * n_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.head = nn.Linear(4 * 4 * 4 * n_channels, 1)

    def forward(self, x):
        """Map a (B, 3, 32, 32) batch to (B, 1) critic scores."""
        features = self.features(x)
        output = self.head(features.reshape(features.shape[0], -1))
        return output

In [None]:
# ====
# your code
# choose these parameters (you have to train the model more than 20 epochs to get good results)
# Starting values close to the WGAN-GP CIFAR-10 setup; tune as needed.
BATCH_SIZE = 64
N_CHANNELS = 128
N_EPOCHS = 30
CRITIC_STEPS = 5
GP_WEIGHT = 10
LR = 1e-4
# ====

train_data, test_data = load_pickle(os.path.join('drive', 'MyDrive', 'DGM', 'homework_supplementary', 'cifar10.pkl'))

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

generator = ConvGenerator(n_channels=N_CHANNELS)
critic = ConvCritic(n_channels=N_CHANNELS)

train_losses = train_wgan_gp(
    generator, 
    critic, 
    train_loader,
    critic_steps=CRITIC_STEPS, 
    batch_size=BATCH_SIZE, 
    n_epochs=N_EPOCHS,
    lr=LR,
    gp_weight=GP_WEIGHT,
    use_cuda=USE_CUDA
)

plot_losses(train_losses['discriminator_losses'], 'Discriminator loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# Draw 1000 images from the trained generator (eval mode, no autograd) and show the first 100.
for model in (generator, critic):
    model.eval()

with torch.no_grad():
    samples = generator.sample(1000).cpu().detach().numpy()

show_samples(samples[:100], title='CIFAR-10 generated samples')