In [1]:
from time import time

import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributions as TD

from scipy.stats import norm

import numpy as np
from matplotlib import pyplot as plt

from tqdm.notebook import tqdm

# Select the compute device. When CUDA is available we prefer GPU index 1 (the
# original choice for this notebook) but fall back to GPU 0 on single-GPU
# machines, where `torch.cuda.set_device(1)` would fail with an invalid device
# ordinal.
if torch.cuda.is_available():
    DEVICE = 'cuda'
    GPU_DEVICE = min(1, torch.cuda.device_count() - 1)
    torch.cuda.set_device(GPU_DEVICE)
else:
    DEVICE = 'cpu'
# To force CPU execution regardless of the hardware, uncomment the next line:
# DEVICE = 'cpu'

import warnings
# Suppress every Python warning for the rest of the session to keep cell output
# short. NOTE(review): this also hides deprecation and numerical warnings.
warnings.filterwarnings('ignore')

# dgm_utils
from dgm_utils import train_model, show_samples, visualize_images
from dgm_utils import visualize_2d_samples, visualize_2d_densities, visualize_2d_data

def reset_seed():
    """Re-seed the global torch and numpy RNGs with the same fixed seed."""
    seed = 0xBADBEEF
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)


# Seed once at import time so that the cells below start from a known RNG state.
reset_seed()

In [25]:
from WGAN import WGAN, WGAN_GP, VanillaGAN

In [3]:
def get_simple_model(hiddens):
    """Build a fully connected ReLU network from a list of layer widths.

    ``hiddens[0]`` is the input width and ``hiddens[-1]`` the output width.
    Every ``nn.Linear`` except the last one is followed by ``nn.ReLU``.

    Returns:
        ``nn.Sequential`` holding the layers in order.

    Raises:
        AssertionError: if fewer than two widths are given.
    """
    assert len(hiddens) > 1

    # Consecutive (in, out) width pairs; the last pair gets no activation.
    *hidden_pairs, (last_in, last_out) = zip(hiddens[:-1], hiddens[1:])

    layers = []
    for fan_in, fan_out in hidden_pairs:
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    layers.append(nn.Linear(last_in, last_out))

    return nn.Sequential(*layers)

def plot_gan_data(data_fn, noise_fn, data_pdf=None):
    """Overlay histograms of 5000 noise samples and 5000 target samples.

    Args:
        data_fn: callable taking a sample count and returning a torch tensor
            of target samples.
        noise_fn: callable taking a sample count and returning a torch tensor
            of noise samples.
        data_pdf: optional callable evaluated on a numpy grid over [-6, 6];
            when given, its curve is drawn on top of the histograms.
    """
    n_samples = 5000
    # Noise is drawn before the target, matching the RNG consumption order.
    noise_samples = noise_fn(n_samples).numpy().flatten()
    target_samples = data_fn(n_samples).numpy().flatten()

    for samples, label, color in (
        (noise_samples, 'noise', 'b'),
        (target_samples, 'target', 'g'),
    ):
        plt.hist(samples, label=label, alpha=0.5, density=True, color=color)

    if data_pdf is not None:
        grid = np.linspace(-6, 6, 100)
        plt.plot(grid, data_pdf(grid), 'g', label='real distibution')

    plt.legend(loc='upper left')
    plt.show()

def visualize_GAN(gan, data_pdf=None):
    """Plot noise, real and generated sample histograms plus the output of D.

    Args:
        gan: object exposing ``data_fn``, ``noise_fn``, ``generate_samples``
            and ``D``; the tensors they return must be convertible with
            ``.numpy()``.
        data_pdf: optional callable evaluated on a numpy grid over [-6, 6];
            when given, its curve is drawn as the real density.
    """
    n_samples = 500
    grid = np.linspace(-6, 6, 100)
    bin_edges = np.linspace(-6, 6, 60)

    # Draw real data first, then noise, then map the noise through the generator.
    real = gan.data_fn(n_samples)
    latent = gan.noise_fn(n_samples)
    fake = gan.generate_samples(latent)

    hist_kwargs = dict(alpha=0.5, density=True, bins=bin_edges)
    plt.hist(latent.numpy(), label='noise', color='b', **hist_kwargs)
    plt.hist(real.numpy(), label='real data', color='g', **hist_kwargs)
    plt.hist(fake.numpy(), label='G samples', color='r', **hist_kwargs)

    if data_pdf is not None:
        plt.plot(grid, data_pdf(grid), 'g', label='real distibution')

    # Evaluate D on the grid without tracking gradients so .numpy() is allowed.
    with torch.no_grad():
        grid_input = torch.from_numpy(grid).float().unsqueeze(-1)
        d_values = gan.D(grid_input).numpy()
    plt.plot(grid, d_values, 'b', label='D distibution')

    plt.legend(loc='upper left')
    plt.show()

In [50]:
# 2d
from seminar7_utils import make_inference, visualize_GAN_output, FullyConnectedMLP
from WGAN import train_wgan, WGAN_MLPCritic
from wgan_gp import train_wgan_gp

def plot_losses(losses, title):
    """Plot a sequence of loss values against the iteration index.

    Args:
        losses: sequence of per-iteration loss values.
        title: figure title.
    """
    iterations = np.arange(len(losses))
    label_size, tick_size = 14, 12

    plt.figure(figsize=(7, 5))
    plt.plot(iterations, losses)
    plt.title(title, fontsize=label_size)
    plt.xlabel('Iterations', fontsize=label_size)
    plt.ylabel('Loss', fontsize=label_size)

    plt.xticks(fontsize=tick_size)
    plt.yticks(fontsize=tick_size)
    plt.show()

In [51]:
class MLPGenerator(FullyConnectedMLP):
    """MLP generator that maps standard-normal latent vectors to samples."""

    def sample(self, n):
        """Draw ``n`` latent vectors and pass them through the network.

        The latent batch has shape ``(n, self.input_dim)``.
        NOTE(review): ``input_dim`` is expected to be set by
        ``FullyConnectedMLP`` - confirm.
        """
        reference = next(iter(self.parameters()))
        latent = torch.randn(size=(n, self.input_dim))
        # `.to(tensor)` matches both the device and the dtype of `reference`.
        return self.forward(latent.to(reference))
    
    
class WGAN_MLPGenerator(MLPGenerator):
    """Generator used in the 2D WGAN experiment; adds nothing to ``MLPGenerator``."""
    pass


class WGANGP_MLPGenerator(MLPGenerator):
    """Generator used in the 2D WGAN-GP experiment; adds nothing to ``MLPGenerator``."""
    pass


class WGANGP_MLPCritic(FullyConnectedMLP):
    """Critic used in the 2D WGAN-GP experiment; adds nothing to ``FullyConnectedMLP``."""
    pass

# <center>Deep Generative Models</center>
## <center>Seminar 7</center>

<center>22.10.2024</center>


## Plan

Wasserstein GANs

- Vanilla GAN

- WGAN

- WGAN-GP

# Vanilla GAN

<img src="pics/gan_objective.jpg" width=800 height=800 />

**Practical Note**:

Use **RMSProp** or **Adam** with $\beta_1 = 0$ when training your GAN. Large $\beta_1$ of Adam leads to training instabilities!

In [5]:
# Mean of the 1D Gaussian target distribution.
mu = 2


def noise_fn(x):
    """Return ``x`` noise samples of shape (x, 1), uniform on [-2, -1]."""
    return torch.rand((x, 1), device='cpu') - 2


def data_fn(x):
    """Return ``x`` target samples of shape (x, 1) drawn from N(mu, 1)."""
    return mu + torch.randn((x, 1), device='cpu')


def data_pdf(X):
    """Evaluate the N(mu, 1) target density at ``X``."""
    return norm.pdf(X - mu)

In [None]:
# Compare histograms of the noise and target samples with the analytic target density.
plot_gan_data(data_fn, noise_fn, data_pdf)

In [7]:
# Layer widths of the generator and discriminator MLPs: 1 -> 64 -> 64 -> 1.
gen_hiddens = [1,64,64,1]
dis_hiddens = [1,64,64,1]
G = get_simple_model(gen_hiddens)
# The discriminator unpacks the layers of a simple MLP and appends a Sigmoid output.
D = nn.Sequential(*get_simple_model(dis_hiddens), nn.Sigmoid())

# VanillaGAN comes from the project-local WGAN module; everything runs on the CPU here.
gan = VanillaGAN(G, D, noise_fn, data_fn, device='cpu')

In [8]:
epochs = 50    # number of passes of the outer training loop
batches = 100  # number of gan.train_step() calls per epoch

In [None]:
# Train the vanilla GAN: `epochs` epochs of `batches` calls to gan.train_step(),
# logging the per-epoch mean losses and re-plotting the model after each epoch.
# The original cell also carried a commented-out debugging variant that called
# gan.train_step_D() and gan.train_step_G() separately and ran visualize_GAN(gan)
# every `step_size` iterations; `step_size` is kept for that purpose.
step_size = 30
loss_g, loss_d_real, loss_d_fake = [], [], []
start = time()
for epoch in range(epochs):
    g_total = d_real_total = d_fake_total = 0
    for _ in range(batches):
        # train_step() is unpacked as (generator loss, (D loss on real, D loss on fake)).
        g_step, (d_real_step, d_fake_step) = gan.train_step()
        g_total += g_step
        d_real_total += d_real_step
        d_fake_total += d_fake_step

    # Store the epoch means.
    for history, total in (
        (loss_g, g_total),
        (loss_d_real, d_real_total),
        (loss_d_fake, d_fake_total),
    ):
        history.append(total / batches)

    elapsed = int(time() - start)
    print(
        f"Epoch {epoch+1}/{epochs} ({elapsed}s):"
        f" G={loss_g[-1]:.3f},"
        f" Dr={loss_d_real[-1]:.3f},"
        f" Df={loss_d_fake[-1]:.3f}"
    )
    visualize_GAN(gan, data_pdf=data_pdf)

## WGAN

<img src="pics/WD.jpg" width=800 height=800 />

<img src="pics/KRD.jpg" width=800 height=800 />

[WGAN](https://arxiv.org/abs/1701.07875) model uses weight clipping to enforce Lipschitzness of the critic.

The model objective is
$$
\min_{G} W(\pi || p) \approx \min_{G} \max_{\boldsymbol{\phi} \in \boldsymbol{\Phi}} \left[ \mathbb{E}_{\pi(\mathbf{x})} f(\mathbf{x}, \boldsymbol{\phi})  - \mathbb{E}_{p(\mathbf{z})} f(G(\mathbf{z}, \boldsymbol{\theta}), \boldsymbol{\phi} )\right].
$$
Here $f(\mathbf{x}, \boldsymbol{\phi})$ is the critic model. The critic weights $\boldsymbol{\phi}$ should lie in the compact set $\boldsymbol{\Phi} = [-c, c]^d$.

<img src="pics/wgan_alg.jpg" width=800 height=800 />

In [10]:
# Mean of the 1D Gaussian target distribution.
mu = 2


def noise_fn(x):
    """Return ``x`` noise samples of shape (x, 1), uniform on [-2, -1]."""
    return torch.rand((x, 1), device='cpu') - 2


def data_fn(x):
    """Return ``x`` target samples of shape (x, 1) drawn from N(mu, 1)."""
    return mu + torch.randn((x, 1), device='cpu')


def data_pdf(X):
    """Evaluate the N(mu, 1) target density at ``X``."""
    return norm.pdf(X - mu)

In [11]:
# Layer widths of the generator and critic MLPs: 1 -> 64 -> 64 -> 1.
gen_hiddens = [1,64,64,1]
dis_hiddens = [1,64,64,1]
G = get_simple_model(gen_hiddens)
# The critic is a plain MLP with an unbounded output (no final Sigmoid).
D = get_simple_model(dis_hiddens)

# NOTE(review): `n_critic` and `clip_c` are presumably the number of critic updates
# per generator update and the weight-clipping bound c - confirm against WGAN.WGAN.
gan = WGAN(G, D, noise_fn, data_fn, device='cpu', n_critic=5, clip_c=0.01)

In [None]:
# Plot the freshly built (not yet trained) model.
visualize_GAN(gan, data_pdf=data_pdf)

In [13]:
epochs = 50    # number of passes of the outer training loop
batches = 100  # number of gan.train_step() calls per epoch

In [None]:
# Train the 1D WGAN: `epochs` epochs of `batches` calls to gan.train_step(),
# recording per-epoch mean losses and re-plotting the model after each epoch.
loss_g, loss_d_real, loss_d_fake, loss_WD = [], [], [], []
start = time()
for epoch in range(epochs):
    loss_g_running, loss_d_real_running, loss_d_fake_running, loss_WD_running = 0, 0, 0, 0
    for _ in range(batches):
        # train_step() is unpacked as (generator loss, (critic real term, critic fake term)).
        lg_, (ldr_, ldf_) = gan.train_step()
        loss_g_running += lg_
        loss_d_real_running += ldr_
        loss_d_fake_running += ldf_
        # Accumulate over the epoch. The original used `=` here, so the logged
        # "WD" was only the last batch's difference divided by `batches`.
        loss_WD_running += ldr_ - ldf_

    # Convert the running sums into per-epoch means.
    loss_g.append(loss_g_running / batches)
    loss_d_real.append(loss_d_real_running / batches)
    loss_d_fake.append(loss_d_fake_running / batches)
    loss_WD.append(loss_WD_running / batches)

    print(f"Epoch {epoch+1}/{epochs} ({int(time() - start)}s):"
          f" G={loss_g[-1]:.3f},"
          f" Dr={loss_d_real[-1]:.3f},"
          f" Df={loss_d_fake[-1]:.3f},"
          f" WD={loss_WD[-1]:.3f}")
    visualize_GAN(gan, data_pdf=data_pdf)

In [15]:
# Flatten every parameter tensor of the critic into one list of scalar values.
params = [
    value
    for tensor in gan.D.parameters()
    for value in tensor.detach().numpy().flatten()
]

In [None]:
# Histogram (100 bins) of all scalar critic parameters collected in `params`;
# the trailing semicolon suppresses the returned value in the cell output.
plt.hist(params, bins=100);

### Bimodal distribution

In [17]:
def noise_fn(x):
    """Return ``x`` noise samples of shape (x, 1), uniform on [-0.5, 0.5]."""
    return torch.rand((x, 1), device='cpu') - 0.5


# Two-component Gaussian mixture: weights 0.7 / 0.3, means -3 / 3, unit scales.
pi = torch.tensor([0.7, 0.3])
mu = torch.tensor([-3., 3.])
scale = torch.tensor([1., 1.])

mixture_gaussian = TD.MixtureSameFamily(
    mixture_distribution=TD.Categorical(pi),
    component_distribution=TD.Normal(mu, scale),
)


def data_fn(x):
    """Return ``x`` samples of shape (x, 1) drawn from the mixture."""
    sample_shape = (x, 1)
    return mixture_gaussian.sample(sample_shape)


def data_pdf(x):
    """Evaluate the mixture density at ``x`` and return a numpy array."""
    log_density = mixture_gaussian.log_prob(torch.tensor(x))
    return torch.exp(log_density).numpy()

In [18]:
# Layer widths of the generator and critic MLPs: 1 -> 64 -> 64 -> 1.
gen_hiddens = [1,64,64,1]
dis_hiddens = [1,64,64,1]
G = get_simple_model(gen_hiddens)
# The critic is a plain MLP with an unbounded output (no final Sigmoid).
D = get_simple_model(dis_hiddens)

# Same construction as the unimodal WGAN above, but with clip_c=0.1 instead of 0.01.
# NOTE(review): `clip_c` is presumably the weight-clipping bound c - confirm against WGAN.WGAN.
gan = WGAN(G, D, noise_fn, data_fn, device='cpu', n_critic=5, clip_c=0.1)

In [None]:
# Plot the freshly built (not yet trained) model against the bimodal target density.
visualize_GAN(gan, data_pdf=data_pdf)

In [None]:
# Train the WGAN on the bimodal target: `epochs` epochs of `batches` calls to
# gan.train_step(), recording per-epoch mean losses and re-plotting each epoch.
loss_g, loss_d_real, loss_d_fake, loss_WD = [], [], [], []
start = time()
for epoch in range(epochs):
    loss_g_running, loss_d_real_running, loss_d_fake_running, loss_WD_running = 0, 0, 0, 0
    for _ in range(batches):
        # train_step() is unpacked as (generator loss, (critic real term, critic fake term)).
        lg_, (ldr_, ldf_) = gan.train_step()
        loss_g_running += lg_
        loss_d_real_running += ldr_
        loss_d_fake_running += ldf_
        # Accumulate over the epoch. The original used `=` here, so the logged
        # "WD" was only the last batch's difference divided by `batches`.
        loss_WD_running += ldr_ - ldf_

    # Convert the running sums into per-epoch means.
    loss_g.append(loss_g_running / batches)
    loss_d_real.append(loss_d_real_running / batches)
    loss_d_fake.append(loss_d_fake_running / batches)
    loss_WD.append(loss_WD_running / batches)

    print(f"Epoch {epoch+1}/{epochs} ({int(time() - start)}s):"
          f" G={loss_g[-1]:.3f},"
          f" Dr={loss_d_real[-1]:.3f},"
          f" Df={loss_d_fake[-1]:.3f},"
          f" WD={loss_WD[-1]:.3f}")
    visualize_GAN(gan, data_pdf=data_pdf)

### 2D WGAN

In [21]:
def generate_2d_data(size, var=0.02):
    """Sample ``size`` 2D points from eight Gaussian blobs placed on a circle.

    Args:
        size: number of points to generate.
        var: factor applied to standard-normal noise around each centre.
            Despite the name it multiplies the noise directly, so it acts as
            a standard deviation.

    Returns:
        float32 array of shape ``(size, 2)``, with every coordinate divided
        by 1.414.
    """
    radius = 2
    diag = 1. / np.sqrt(2)
    unit_centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (diag, diag),
        (diag, -diag),
        (-diag, diag),
        (-diag, -diag),
    ]
    centers = [(radius * cx, radius * cy) for cx, cy in unit_centers]

    points = []
    for _ in range(size):
        # Draw the noise first and the blob index second: this keeps the
        # consumption order of the global numpy RNG unchanged.
        point = np.random.randn(2) * var
        cx, cy = centers[np.random.choice(np.arange(len(centers)))]
        point[0] += cx
        point[1] += cy
        points.append(point)

    dataset = np.array(points, dtype='float32')
    dataset /= 1.414  # stdev

    return dataset

In [None]:
# Re-seed so the generated dataset is reproducible.
reset_seed()
COUNT = 20000  # number of 2D training points

# `var` is the noise scale around each blob centre; 0.02, 0.1 and 0.4 are
# alternative values noted by the author.
train_data = generate_2d_data(COUNT, var=0.08) # 0.02, 0.1, 0.4
visualize_2d_samples(train_data, "Train data")

In [None]:
# Author's empirical notes on the number of critic steps:
# CRITIC_STEPS = 5 => more or less learning
# CRITIC_STEPS = 1 => no learning

reset_seed()
BATCH_SIZE = 1024 # any adequate value
# NOTE(review): the two hints below say "< 128 neurons", yet the widths used reach 128 and 256.
GEN_HIDDENS = [32, 128, 128, 32] # 4 layers with < 128 neurons would be enough
DISCR_HIDDENS = [64, 256, 256, 64] # 4 layers with < 128 neurons would be enough
CRITIC_STEPS = 5 # > 2
LR = 2e-4 # < 1e-2
CLIP_C = 0.05 # < 1

N_EPOCHS = 600 # change it if you want

train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# NOTE(review): constructor arguments are presumably (input_dim, hidden widths,
# output_dim): a 16-dim latent mapped to 2D points, and a 2D -> scalar critic.
# Confirm against seminar7_utils.FullyConnectedMLP.
generator = WGAN_MLPGenerator(16, GEN_HIDDENS, 2).to(DEVICE)
critic = WGAN_MLPCritic(2, DISCR_HIDDENS, 1).to(DEVICE)

# train_wgan is imported from the project-local WGAN module; the returned
# mapping is indexed below by 'discriminator_losses' and 'generator_losses'.
train_losses = train_wgan(
    generator, 
    critic, 
    train_loader,
    critic_steps=CRITIC_STEPS, 
    batch_size=BATCH_SIZE, 
    n_epochs=N_EPOCHS,
    lr=LR,
    clip_c=CLIP_C,
    visualize_steps=50,
    use_cuda = DEVICE != 'cpu'
)

plot_losses(train_losses['discriminator_losses'], 'Critic loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# make_inference and visualize_GAN_output come from seminar7_utils; the four
# values returned by the former are passed on to the plot together with the training data.
samples, grid, critic_output, critic_grad_norms = make_inference(generator, critic)
visualize_GAN_output(samples, train_data, grid, critic_output, critic_grad_norms)

In [None]:
# Flatten every parameter tensor of the 2D critic (moved to the CPU first)
# into one list of scalar values and plot their histogram.
params = [
    value
    for tensor in critic.parameters()
    for value in tensor.detach().cpu().numpy().flatten()
]
plt.hist(params, bins=30)
plt.show()

## WGAN-GP

[WGAN-GP](https://arxiv.org/pdf/1704.00028.pdf)  model uses gradient penalty to enforce Lipschitzness.

The model objective is
$$
    W(\pi || p) = \underbrace{\mathbb{E}_{\pi(\mathbf{x})} f(\mathbf{x})  - \mathbb{E}_{p(\mathbf{x} | \boldsymbol{\theta})} f(\mathbf{x})}_{\text{original critic loss}} + \lambda \underbrace{\mathbb{E}_{U[0, 1]} \left[ \left( \| \nabla_{\hat{\mathbf{x}}} f(\hat{\mathbf{x}}) \|_2 - 1 \right) ^ 2\right]}_{\text{gradient penalty}},
$$
where the samples $\hat{\mathbf{x}}_t = t \mathbf{x} + (1 - t) \mathbf{y}$ with $t \in [0, 1]$ are uniformly sampled along straight lines between pairs of points: $\mathbf{x}$ from the data distribution $\pi(\mathbf{x})$ and $\mathbf{y}$ from the generator distribution $p(\mathbf{x} | \boldsymbol{\theta})$.

<img src="pics/WGAN-GP_theorem.jpg" width=800 height=800 />

<img src="pics/wgan-gp_alg.jpg" width=800 height=800 />

In [39]:
# Mean of the 1D Gaussian target distribution.
mu = 2


def noise_fn(x):
    """Return ``x`` noise samples of shape (x, 1), uniform on [-2, -1]."""
    return torch.rand((x, 1), device='cpu') - 2


def data_fn(x):
    """Return ``x`` target samples of shape (x, 1) drawn from N(mu, 1)."""
    return mu + torch.randn((x, 1), device='cpu')


def data_pdf(X):
    """Evaluate the N(mu, 1) target density at ``X``."""
    return norm.pdf(X - mu)

In [48]:
# Layer widths of the generator and critic MLPs: 1 -> 64 -> 64 -> 1.
gen_hiddens = [1,64,64,1]
dis_hiddens = [1,64,64,1]
G = get_simple_model(gen_hiddens)
# The critic is a plain MLP with an unbounded output (no final Sigmoid).
D = get_simple_model(dis_hiddens)

# NOTE(review): `Lambda` is presumably the gradient-penalty weight and `n_critic`
# the number of critic updates per generator update - confirm against WGAN.WGAN_GP.
gan = WGAN_GP(G, D, noise_fn, data_fn, device='cpu', n_critic=5, Lambda=1)

In [None]:
# Plot the freshly built (not yet trained) WGAN-GP model.
visualize_GAN(gan, data_pdf=data_pdf)

In [46]:
epochs = 50    # number of passes of the outer training loop
batches = 100  # number of gan.train_step() calls per epoch

In [None]:
# Train the 1D WGAN-GP: `epochs` epochs of `batches` calls to gan.train_step(),
# recording per-epoch mean losses and re-plotting the model after each epoch.
# `step_size` was only used by a commented-out debugging variant of this loop
# (separate gan.train_step_D() / gan.train_step_G() calls with periodic plots);
# it is kept so the name stays defined.
step_size = 30
loss_g, loss_d_real, loss_d_fake, loss_WD, loss_gp = [], [], [], [], []
start = time()
for epoch in range(epochs):
    loss_g_running, loss_d_real_running, loss_d_fake_running, loss_WD_running, loss_gp_running = 0, 0, 0, 0, 0
    for _ in range(batches):
        # train_step() is unpacked as
        # (generator loss, (critic real term, critic fake term, gradient-penalty term)).
        lg_, (ldr_, ldf_, lgp_) = gan.train_step()

        loss_g_running += lg_
        loss_d_real_running += ldr_
        loss_d_fake_running += ldf_
        loss_gp_running += lgp_
        # Accumulate over the epoch. The original used `=` here, so the logged
        # "WD" was only the last batch's difference divided by `batches`.
        loss_WD_running += ldr_ - ldf_

    # Convert the running sums into per-epoch means.
    loss_g.append(loss_g_running / batches)
    loss_d_real.append(loss_d_real_running / batches)
    loss_d_fake.append(loss_d_fake_running / batches)
    loss_gp.append(loss_gp_running / batches)
    loss_WD.append(loss_WD_running / batches)

    # A comma now follows the Df field, matching the other fields of the log line.
    print(f"Epoch {epoch+1}/{epochs} ({int(time() - start)}s):"
          f" G={loss_g[-1]:.3f},"
          f" Dr={loss_d_real[-1]:.3f},"
          f" Df={loss_d_fake[-1]:.3f},"
          f" WD={loss_WD[-1]:.3f},"
          f" GP={loss_gp[-1]:.3f}")
    visualize_GAN(gan, data_pdf=data_pdf)

In [None]:
reset_seed()
BATCH_SIZE = 1024 # any adequate value
# NOTE(review): the two hints below say "< 128 neurons", yet the widths used reach 128 and 256.
GEN_HIDDENS = [32, 128, 128, 32] # 4 layers with < 128 neurons would be enough
DISCR_HIDDENS = [64, 256, 256, 64] # 4 layers with < 128 neurons would be enough
CRITIC_STEPS = 5 # > 2
LR = 2e-4 # < 1e-2
GP_WEIGHT = 10 # > 5

N_EPOCHS = 800 # change it if you want

# Reuses `train_data` generated earlier in the notebook for the 2D WGAN experiment.
train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# NOTE(review): constructor arguments are presumably (input_dim, hidden widths,
# output_dim): a 16-dim latent mapped to 2D points, and a 2D -> scalar critic.
# Confirm against seminar7_utils.FullyConnectedMLP.
generator = WGANGP_MLPGenerator(16, GEN_HIDDENS, 2).to(DEVICE)
critic = WGANGP_MLPCritic(2, DISCR_HIDDENS, 1).to(DEVICE)

# train_wgan_gp is imported from the project-local wgan_gp module; the returned
# mapping is indexed below by 'discriminator_losses' and 'generator_losses'.
train_losses = train_wgan_gp(
    generator, 
    critic, 
    train_loader,
    critic_steps=CRITIC_STEPS, 
    batch_size=BATCH_SIZE, 
    n_epochs=N_EPOCHS,
    lr=LR,
    gp_weight=GP_WEIGHT,
    visualize_steps=50,
    use_cuda = DEVICE != 'cpu'
)

plot_losses(train_losses['discriminator_losses'], 'Critic loss')
plot_losses(train_losses['generator_losses'], 'Generator loss')

In [None]:
# make_inference and visualize_GAN_output come from seminar7_utils; the four
# values returned by the former are passed on to the plot together with the training data.
samples, grid, critic_output, critic_grad_norms = make_inference(generator, critic)
visualize_GAN_output(samples, train_data, grid, critic_output, critic_grad_norms)

## WGAN / WGAN-GP Recap

1. **Lipschitz Condition for the Critic**
   - Required by the **Kantorovich-Rubinstein duality** for Wasserstein Distance (WD).
   - **Key difference**: No weight clipping for the generator!


2. **Important Insights**:
   - **KL & JSD Divergences fail** where WD succeeds (e.g., two parallel lines example).
   - **Weaker topology**: Convergence in JS implies convergence in WD (WD is easier to optimize).
   - Lipschitz continuity is essential for the critic—achieved through **weight clipping**.
   - **Loss correlates with visual quality** of the generated images.


3. **Problems with Clipping**:
   - **Weight sticking to boundaries** causes exploding or vanishing gradients.
   - Refer to the visual example from the WGAN-GP paper.


4. **WGAN-GP: More Robust**
   - Gradient norm between real and generated samples is forced to be **1**.
   - **No weight clipping** for the critic in WGAN-GP.


5. **Why no oscillations ("steps") as in Vanilla GAN?**
   - Thanks to the Lipschitz condition, training is smoother.


6. **WGAN-GP Theorem Illustration**:
   - The gradient norm between real and fake samples remains strictly **1** (e.g., set $\lambda = 100$).

## Bonus Chapter

### SNGAN

Spectral Normalization GAN [article](https://arxiv.org/pdf/1802.05957.pdf) replaces the weights in the critic $f(\mathbf{x}, \boldsymbol{\phi})$ by 
$$
    \mathbf{W}^{SN} = \frac{\mathbf{W}}{\|\mathbf{W}\|_2}.
$$

This ensures that $\| f\|_L \leq 1$.

The power iteration method makes it possible to efficiently compute $\| \mathbf{W} \|_2 = \sqrt{\lambda_{\text{max}}(\mathbf{W}^T \mathbf{W})}$.
    
The pseudocode of the method is:
* $\mathbf{u}_0$ -- random vector.
* for $k = 0, \dots, n - 1$: 
$$
    \mathbf{v}_{k+1} = \frac{\mathbf{W}^T \mathbf{u}_{k}}{\| \mathbf{W}^T \mathbf{u}_{k} \|}, \quad \mathbf{u}_{k+1} = \frac{\mathbf{W} \mathbf{v}_{k+1}}{\| \mathbf{W} \mathbf{v}_{k+1} \|}.
$$
* approximate the spectral norm
$$
    \| \mathbf{W} \|_2 = \sqrt{\lambda_{\text{max}}(\mathbf{W}^T \mathbf{W})} \approx \mathbf{u}_{n}^T \mathbf{W} \mathbf{v}_{n}.
$$


## GANs zoo

### Losses

- Vanilla GAN

    <img src="pics/gan_objective.jpg" width=800 height=800 />

    - Nonsaturating Vanilla GAN

- Wasserstein GAN 

    <img src="pics/WGAN_obj.jpg" width=800 height=800 />
    
    - WGAN-GP

- IPM GAN 

    **IPM** (Integral Probability Metric):
    
    $$
    \gamma_{\mathcal{F}}(\mathbb{P}, \mathbb{Q}) = \sup\limits_{f \in \mathcal{F}} \left\vert \int f d \mathbb{P} - \int f d \mathbb{Q} \right\vert,
    $$ 
    see [Sriperumbudur et al.](https://arxiv.org/pdf/0901.2698.pdf) for the details on the IPM.
    
    see [Mroueh et al.](https://arxiv.org/pdf/1711.04894.pdf) for examples of IPM GANs.
    
    
- GAN with Hinge loss

    <img src="pics/hinge_loss_GAN.png" width=800 height=800 />
    
    see [Lim et al.](https://arxiv.org/pdf/1705.02894.pdf)

- fGAN 

    <img src="pics/fgan_loss.png" width=800 height=800 />
    
    article: [Nowozin et al.](https://arxiv.org/pdf/1606.00709.pdf)
    
    * **Question:** Over which parameters do we maximize, and over which do we minimize?
    
- ...

### Regularizations

- Weight clipping, Gradient penalty in WGAN

- Spectral Normalization (for general GAN architectures)

- $R_1$, $R_2$, $R_3$ regularizations (penalize discriminator gradients) [paper](https://arxiv.org/pdf/1801.04406.pdf), [paper2](https://arxiv.org/pdf/1705.09367.pdf)

- Improved techniques for training GANs [paper](https://arxiv.org/pdf/1606.03498.pdf)

- Orthogonal regularization [paper](https://arxiv.org/pdf/1809.11096.pdf)

    <img src="pics/ortho_reg.png" width=400 height=800 />

- ...