# Adversarial AutoEncoder (extra)

In [0]:
def get_free_gpu():
    """Return the index of the GPU reporting the largest amount of free memory.

    NVML is initialised through ``pynvml``, the ``free`` field of the memory
    info is read for every device NVML enumerates, and the position of the
    largest reading (as returned by ``np.argmax``) is the result.
    """
    import pynvml

    pynvml.nvmlInit()

    free_memory = []
    for gpu_index in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
        free_memory.append(pynvml.nvmlDeviceGetMemoryInfo(handle).free)

    return np.argmax(free_memory)

In [0]:
import numpy as np
import torch
import torchvision

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from IPython import display

### Select the compute device: the CUDA GPU picked by get_free_gpu() when
### CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    ### query the GPUs only once and reuse the answer, so that the stored
    ### index and the device string can never disagree
    cuda_id = get_free_gpu()
    DEVICE = 'cuda:%d' % (cuda_id, )
    print('Selected %s' % (DEVICE, ))
else:
    DEVICE = 'cpu'
    print('WARNING: using cpu!')

### please, don't remove the following line
x = torch.tensor([1], dtype=torch.float32).to(DEVICE)

In [0]:
def show(filename):
    """Display the image stored at ``filename`` inline, on a best-effort basis.

    If building or displaying the image fails with an ordinary exception,
    the failure is ignored and nothing is shown.

    Parameters:
    - filename - path passed to ``IPython.display.Image(filename=...)``.
    """
    from IPython import display

    try:
        display.display(
            display.Image(filename=filename)
        )
    except Exception:
        ### deliberately best-effort; unlike a bare ``except`` this still
        ### lets KeyboardInterrupt / SystemExit propagate
        pass

## Loading data

In [0]:
def one_hot(y, n_classes=10):
    """One-hot encode an array of integer class labels.

    Parameters:
    - y - array of class indices; its first dimension is the number of samples;
    - n_classes - width of the encoding (number of columns).

    Returns a float32 array of shape (y.shape[0], n_classes) that holds 1
    at position (i, y[i]) and 0 everywhere else.
    """
    n_samples = y.shape[0]
    encoded = np.zeros((n_samples, n_classes), dtype='float32')

    ### one write per row: row i gets its 1 in column y[i]
    row_indices = np.arange(n_samples)
    encoded[row_indices, y] = 1

    return encoded

In [0]:
from torchvision.datasets import MNIST

### MNIST train / test splits, stored under ../../data/ (download requested).
ds_train = MNIST("../../data/", train=True, download=True)
ds_test = MNIST("../../data/", train=False, download=True)

### ---- training split ----
### images reshaped to (N, 1, 28, 28), converted to float32 and divided by 255
data_train = ds_train.data.reshape(-1, 1, 28, 28).detach().numpy()
data_train = data_train.astype(np.float32) / 255

labels_train = ds_train.targets.detach().numpy()

### the whole dataset is moved to DEVICE once, so that training batches
### can be sliced directly from device memory
X_train = torch.tensor(data_train, dtype=torch.float32, device=DEVICE)
y_train = torch.tensor(labels_train, dtype=torch.long, device=DEVICE)
y_one_hot_train = torch.tensor(
    one_hot(labels_train), dtype=torch.float32, device=DEVICE
)

### reference error: MSE of predicting the average training image everywhere
X_avg = X_train.mean(dim=0)
MSE_baseline = ((X_train - X_avg[None, :, :, :]) ** 2).mean()

### ---- test split (same preprocessing) ----
data_test = ds_test.data.reshape(-1, 1, 28, 28).detach().numpy()
data_test = data_test.astype(np.float32) / 255

labels_test = ds_test.targets.detach().numpy()

X_test = torch.tensor(data_test, dtype=torch.float32, device=DEVICE)
y_test = torch.tensor(labels_test, dtype=torch.long, device=DEVICE)
y_one_hot_test = torch.tensor(
    one_hot(labels_test), dtype=torch.float32, device=DEVICE
)

### test images and integer labels served in batches of 32
dataset_test = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_test, y_test),
    batch_size=32
)

In [0]:
### Preview of the first 200 training digits: a grid of 10 rows by 20 columns.
plt.figure(figsize=(12, 6), dpi=100)
plt.axis('off')

_ = plt.imshow(
    (
        data_train[:200]
        .reshape(20, 10, 28, 28)
        ### (column, row, y, x) -> (row, y, column, x)
        .transpose(1, 2, 0, 3)
        .reshape(10 * 28, 20 * 28)
    ),
    cmap=plt.cm.Greys
)

## Models

In [0]:
class View(torch.nn.Module):
    """Module wrapper around ``Tensor.view`` with a fixed target shape."""

    def __init__(self, *shape):
        super().__init__()

        ### target shape, kept as the tuple of positional arguments
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the shape given at construction time."""
        target_shape = self.shape
        return x.view(*target_shape)

In [0]:
### Convolutional encoder followed by a small dense head
### aka inference network
### E: (None, 1, 28, 28), (None, xi_size) -> (None, code_size)
class Inference(torch.nn.Module):
    """Encoder mapping an image, and optionally an extra vector, to a code.

    Parameters:
    - n - width multiplier for the number of convolution channels;
    - code_size - size of the produced code;
    - xi_size - size of the optional extra input that is concatenated to the
      image embedding before the dense head (None means no extra input).
    """

    def __init__(self, n, code_size, xi_size=None):
        super().__init__()

        ### (in_channels, out_channels, kernel_size, stride) of every
        ### convolution that is followed by a LeakyReLU; the trailing comment
        ### is the spatial size after the layer for a 28 x 28 input
        hidden_convs = [
            (1, 2 * n, 3, 1),      ### 26 x 26
            (2 * n, 2 * n, 3, 1),  ### 24 x 24
            (2 * n, 2 * n, 2, 2),  ### 12 x 12, conv pooling
            (2 * n, 3 * n, 3, 1),  ### 10 x 10
            (3 * n, 3 * n, 3, 1),  ### 8 x 8
            (3 * n, 3 * n, 2, 2),  ### 4 x 4, conv pooling
            (3 * n, 4 * n, 3, 1),  ### 2 x 2
        ]

        self.image_embedding = []
        for channels_in, channels_out, kernel, stride in hidden_convs:
            self.image_embedding.append(
                torch.nn.Conv2d(channels_in, channels_out, kernel_size=kernel, stride=stride)
            )
            self.image_embedding.append(torch.nn.LeakyReLU())

        ### 1 x 1, no activation; then drop the singleton spatial dimensions
        self.image_embedding.append(
            torch.nn.Conv2d(4 * n, code_size, kernel_size=2, stride=1)
        )
        self.image_embedding.append(torch.nn.Flatten())

        ### a plain list is invisible to torch, so every layer is registered
        ### explicitly
        for index, layer in enumerate(self.image_embedding):
            self.add_module('img_embedding%d' % (index, ), layer)

        extra_inputs = xi_size if xi_size is not None else 0

        self.combined = [
            torch.nn.Linear(code_size + extra_inputs, 2 * code_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(2 * code_size, code_size),
        ]

        for index, layer in enumerate(self.combined):
            self.add_module('combined%d' % (index, ), layer)

    def forward(self, x, z=None):
        """Embed image ``x``, append ``z`` when it is given, apply the dense head."""
        hidden = x
        for layer in self.image_embedding:
            hidden = layer(hidden)

        if z is not None:
            hidden = torch.cat([hidden, z], dim=1)

        for layer in self.combined:
            hidden = layer(hidden)

        return hidden

In [0]:
### aka decoder
### G: (None, code_size) -> (None, 1, 28, 28)
class Generator(torch.nn.Module):
    """Decoder mapping a code vector to a single-channel 28 x 28 image.

    Parameters:
    - n - width multiplier for the number of channels;
    - code_size - size of the input vector.
    """

    def __init__(self, n, code_size):
        super(Generator, self).__init__()

        ### NOTE: the list must not be stored as `self.modules` - that name
        ### would shadow torch.nn.Module.modules() on the instance and make
        ### `generator.modules()` fail with "'list' object is not callable".
        self.layers = [
            torch.nn.Linear(code_size, 4 * n),
            View(-1, 4 * n, 1, 1),

            ### 2 x 2
            torch.nn.ConvTranspose2d(4 * n, 4 * n, kernel_size=2, stride=1),
            ### 4 x 4
            torch.nn.ConvTranspose2d(4 * n, 3 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),

            ### 8 x 8
            torch.nn.ConvTranspose2d(3 * n, 3 * n, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),
            ### 10 x 10
            torch.nn.ConvTranspose2d(3 * n, 3 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),
            ### 12 x 12
            torch.nn.ConvTranspose2d(3 * n, 2 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),

            ### 24 x 24
            torch.nn.ConvTranspose2d(2 * n, 2 * n, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),
            ### 26 x 26
            torch.nn.ConvTranspose2d(2 * n, 2 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),
            ### 28 x 28
            torch.nn.ConvTranspose2d(2 * n, 1, kernel_size=3, stride=1),
        ]

        ### a plain list is invisible to torch, so every layer is registered
        ### explicitly (names f0, f1, ... are kept for state_dict compatibility)
        for i, f in enumerate(self.layers):
            self.add_module('f%d' % (i, ), f)

    def forward(self, z):
        """Apply all layers in order to ``z`` and return the result."""
        x = z

        for f in self.layers:
            x = f(x)

        return x

In [0]:
### aka critic
class Discriminator(torch.nn.Module):
    """Three-layer perceptron operating on flat vectors.

    Parameters:
    - n - width multiplier of the hidden layers (2 * n units, then n units);
    - input_size - size of the input vectors;
    - n_outputs - number of outputs; None selects a single output which
      ``forward`` additionally flattens to shape (-1, ).
    """

    def __init__(self, n, input_size, n_outputs=None):
        super().__init__()

        self.n_outputs = n_outputs
        head_width = 1 if n_outputs is None else n_outputs

        self.fs = [
            torch.nn.Linear(input_size, 2 * n),
            torch.nn.LeakyReLU(),

            torch.nn.Linear(2 * n, n),
            torch.nn.LeakyReLU(),

            torch.nn.Linear(n, head_width),
        ]

        ### a plain list is invisible to torch, so every layer is registered
        ### explicitly
        for index, layer in enumerate(self.fs):
            self.add_module('f%d' % (index, ), layer)

    def forward(self, x):
        """Apply the layers to ``x``; flatten the result when n_outputs is None."""
        out = x
        for layer in self.fs:
            out = layer(out)

        return out.view(-1) if self.n_outputs is None else out

In [0]:
def iterate(f, n_epoches, n_steps, callback=None):
    """Call ``f`` repeatedly and record the values it returns.

    Parameters:
    - f - callable without arguments, invoked n_steps times per epoch;
      its return value is stored as the loss of that step;
    - n_epoches - number of epochs;
    - n_steps - number of calls of ``f`` per epoch;
    - callback - optional callable without arguments, invoked after every epoch.

    Returns a float32 array of shape (n_epoches, n_steps) with the recorded
    values.
    """
    ### the buffer is sized by the arguments of this call, not by the
    ### global `n_batches`, so any value of n_steps is valid
    losses = np.zeros((n_epoches, n_steps), dtype=np.float32)

    primary_pbar = tqdm(total=n_epoches, leave=False)
    secondary_pbar = tqdm(total=n_steps, leave=False)

    try:
        for i in range(n_epoches):
            secondary_pbar.reset()

            for j in range(n_steps):
                losses[i, j] = f()

                secondary_pbar.update()

            primary_pbar.update()
            if callback is not None:
                callback()
    finally:
        ### close the progress bars even if `f` or `callback` raises
        secondary_pbar.close()
        primary_pbar.close()

    return losses

## Training procedures

In [0]:
### ---- network sizes ----
### width multiplier passed to every network
n = 16
### size of the latent code
code_size = 16
### size of the noise vector fed to the inference network
xi_size = 4

### ---- batch sizes ----
batch_size = 16

### discriminator is a light-weight network, thus,
### can easily handle large batches
batch_size_discr_real = 32
batch_size_discr_prior = 128

### ---- schedule ----
n_epoches = 4

### number of batches of `batch_size` samples in the training set
n_batches = len(data_train) // batch_size

In [0]:
def logit_binary_crossentropy(predictions_positive, predictions_negative):
    """
    Binary cross-entropy computed directly from logits (network outputs before the sigmoid).
    - predictions_positive - logits predicted for samples of class y = 1;
    - predictions_negative - logits predicted for samples of class y = 0.

    Returns the sum of the mean loss over the positive samples and
    the mean loss over the negative samples.
    """
    softplus = torch.nn.functional.softplus

    ### -log sigmoid(p) = log( 1 + exp(-p) ) = softplus(-p)
    loss_positive = softplus(-predictions_positive).mean()
    ### -log( 1 - sigmoid(p) ) = log( 1 + exp(p) ) = softplus(p)
    loss_negative = softplus(predictions_negative).mean()

    return loss_positive + loss_negative

## Feature disentanglement (extra, not graded)

You are tasked with implementing feature disentanglement!
Here, features are labels, i.e., which digit is shown on the image.

1. The idea is simple: add another discriminator that tries to predict the labels given the latent representation from the inference network:

$$d' : \mathcal{Z} \to \mathbb{R}^{10}$$

2. Then, using this new discriminator penalize the latent variables for containing information about the labels.
However, the generator will have a hard time reconstructing the original input without any information about the label. Thus, we introduce it back --- the ground-truth labels (one-hot encoded) are supplied directly to the generator!

$$\begin{eqnarray}
\mathcal{L} &=& \mathrm{MSE} + \alpha \cdot \mathrm{penalty}_Z + \beta \cdot \mathrm{penalty}_y ;\\
\mathrm{MSE} &=& \mathbb{E}_{x, y, \xi} \big( \mathrm{generator}(\mathrm{inference}(x, \xi), y) - x \big)^2;\\
\mathrm{penalty}_Z &=& \mathbb{E}_{x, \xi} \log d(\mathrm{inference}(x, \xi));\\
\mathrm{penalty}_y &=& \mathbb{E}_{x, y, \xi} y \log d'(\mathrm{inference}(x, \xi))
\end{eqnarray}$$

where:
- $y \in \mathbb{R}^{10}$ --- one-hot encoded vector of labels;
- $d$ --- the discriminator trained to distinguish $\mathrm{inference}(x, \xi)$ from $\mathcal{N}^m(0, 1)$;
- $d'$ --- the label discriminator trained to predict labels $y$ given latent variables $\mathrm{inference}(x, \xi)$.

*(note signs)*

You can find more details in the [original paper](https://arxiv.org/abs/1511.05644), section 4. The only difference from the original AAE is that we explicitly penalize the latent variables for containing label information.

Notice, however, that $\beta$ can be set to a low value since $\mathrm{penalty}_y$ does not conflict with the main objective, $\mathrm{MSE}$, i.e. it is possible to reach the minimum of both loss functions simultaneously.

In [0]:
def get_step_discriminator(inference, discriminator, opt_discriminator):
    """
    Builds the training step of the latent-code discriminator (exercise template).
    - inference - network called as inference(X_batch, xi) to produce inferred codes;
    - discriminator - not used by the code visible here; presumably the network
      whose loss is to be computed at the `YOUR CODE HERE` marker;
    - opt_discriminator - optimizer that is zeroed before the marker and
      stepped after `loss.backward()`.

    Returns the `step` closure. The loss is left to be filled in: until the
    marker is replaced by code defining `loss`, calling `step` raises
    NotImplementedError. After that, `step` returns `loss.item()`.
    """
    def step():
        ### batch preparation is done without gradient tracking, hence
        ### Z_inferred carries no gradient back to the inference network
        with torch.no_grad():
            ### random training indices, sampled with replacement
            indx = torch.randint(low=0, high=X_train.shape[0], size=(batch_size_discr_real, ), device=DEVICE)
            X_batch = X_train[indx]
            
            ### xi makes inference stochastic
            xi = torch.randn(X_batch.shape[0], xi_size, device=DEVICE)
            Z_inferred = inference(X_batch, xi)
            
            ### codes sampled from the standard normal distribution
            Z_prior = torch.randn(batch_size_discr_prior, code_size, device=DEVICE)
        
        opt_discriminator.zero_grad()
        
        ### the code inserted here must define `loss`
        # YOUR CODE HERE
        raise NotImplementedError()

        loss.backward()
        opt_discriminator.step()

        return loss.item()
    
    return step

In [0]:
def get_step_label_discriminator(inference, discriminator, opt_discriminator):
    """
    Builds the training step of the label discriminator (exercise template).
    - inference - network called as inference(X_batch, xi) to produce inferred codes;
    - discriminator - not used by the code visible here; presumably the network
      that predicts labels from `Z_inferred` at the `YOUR CODE HERE` marker;
    - opt_discriminator - optimizer that is zeroed before the marker and
      stepped after `loss.backward()`.

    Returns the `step` closure. The loss is left to be filled in: until the
    marker is replaced by code defining `loss`, calling `step` raises
    NotImplementedError. After that, `step` returns `loss.item()`.
    """
    ### created for the code to be filled in; not used by the visible code
    loss_f = torch.nn.CrossEntropyLoss()
    
    def step():
        ### batch preparation is done without gradient tracking, hence
        ### Z_inferred carries no gradient back to the inference network
        with torch.no_grad():
            ### random training indices, sampled with replacement
            indx = torch.randint(low=0, high=X_train.shape[0], size=(batch_size, ), device=DEVICE)
            X_batch = X_train[indx]
            ### integer class labels of the selected samples
            y_real = y_train[indx]
            
            xi = torch.randn(X_batch.shape[0], xi_size, device=DEVICE)
            Z_inferred = inference(X_batch, xi)

        opt_discriminator.zero_grad()
        
        ### the code inserted here must define `loss`
        # YOUR CODE HERE
        raise NotImplementedError()
        
        loss.backward()
        opt_discriminator.step()

        return loss.item()
    
    return step

In [0]:
def get_step_disentanglement_AE(
    generator, inference, opt_AE,
    discriminator, discriminator_labels,
    alpha=1e-1, beta=2e-2
):
    """
    Builds the training step of the auto-encoder with disentanglement (exercise template).
    - generator, inference - not used by the code visible here; presumably the
      decoder and the encoder trained by this step;
    - opt_AE - optimizer that is zeroed before the marker and stepped after
      `loss.backward()`;
    - discriminator, discriminator_labels - not used by the code visible here;
      presumably the critics providing the two penalty terms;
    - alpha, beta - not used by the code visible here; presumably the weights
      of the two penalty terms.

    Returns the `step` closure. The loss is left to be filled in: until the
    `YOUR CODE HERE` marker is replaced by code defining both `loss` and `mse`,
    calling `step` raises NotImplementedError. After that, `step` returns
    `mse.item()` (not the full loss).
    """
    ### created for the code to be filled in; not used by the visible code
    ce_loss_f = torch.nn.CrossEntropyLoss()

    def step():
        ### only the batch selection and the noise are sampled without
        ### gradient tracking; the networks are not called in this block
        with torch.no_grad():
            ### random training indices, sampled with replacement
            indx = torch.randint(low=0, high=X_train.shape[0], size=(batch_size, ), device=DEVICE)
            X_batch = X_train[indx]
            ### labels in both encodings: integer class indices and one-hot vectors
            y_batch = y_train[indx]
            y_onehot_batch = y_one_hot_train[indx]
            
            xi = torch.randn(X_batch.shape[0], xi_size, device=DEVICE)
        
        opt_AE.zero_grad()
        
        ### loss now consists of three term - mse, penalty for distribution of Z,
        ### penalty for passing information about the labels.
        
        ### the code inserted here must define `loss` and `mse`
        # YOUR CODE HERE
        raise NotImplementedError()
        
        loss.backward()
        opt_AE.step()

        return mse.item()

    return step

In [0]:
def get_step_AAE(step_discriminator, step_AE, discriminator_steps = 4):
    """
    Combines a discriminator step and an auto-encoder step into one training step.
    - step_discriminator - callable without arguments, its result is discarded;
    - step_AE - callable without arguments;
    - discriminator_steps - number of calls of step_discriminator per call of step_AE.

    Returns a callable that runs step_discriminator `discriminator_steps` times,
    then runs step_AE once and returns the result of step_AE.
    """
    def combined_step():
        for _round in range(discriminator_steps):
            step_discriminator()

        ae_result = step_AE()
        return ae_result

    return combined_step

In [0]:
### critic predicting the 10 class labels from a latent code
discriminator_labels = Discriminator(
    n=n, input_size=code_size, n_outputs=10
).to(DEVICE)
### critic with a single output per latent code
discriminator = Discriminator(
    n=n, input_size=code_size, n_outputs=None
).to(DEVICE)

### the generator input is 10 entries wider than the latent code
### (the one-hot label is concatenated to the code)
generator = Generator(n=n, code_size=code_size + 10).to(DEVICE)
inference = Inference(n=n, code_size=code_size, xi_size=xi_size).to(DEVICE)

In [0]:
### both critics use the same Adam settings, including weight decay
opt_discriminator_labels = torch.optim.Adam(
    discriminator_labels.parameters(),
    lr=2e-3, weight_decay=1e-3
)

opt_discriminator = torch.optim.Adam(
    discriminator.parameters(),
    lr=2e-3, weight_decay=1e-3
)

### a single optimizer updates the generator and the inference network together
opt_AE = torch.optim.Adam(
    [*generator.parameters(), *inference.parameters()],
    lr=1e-3
)

In [0]:
### main discriminator step is unchanged...
step_discriminator = get_step_discriminator(
    inference=inference,
    discriminator=discriminator,
    opt_discriminator=opt_discriminator
)

### ...and the label discriminator gets a step of its own
step_label_discriminator = get_step_label_discriminator(
    inference=inference,
    discriminator=discriminator_labels,
    opt_discriminator=opt_discriminator_labels
)

step_AE = get_step_disentanglement_AE(
    generator=generator, inference=inference, opt_AE=opt_AE,
    discriminator=discriminator,
    discriminator_labels=discriminator_labels
)

def step_discriminators():
    """Run one update of the main discriminator, then one of the label discriminator."""
    step_discriminator()
    step_label_discriminator()

### 8 rounds of updates of both discriminators per auto-encoder update
step_disentanglement = get_step_AAE(
    step_discriminator=step_discriminators,
    step_AE=step_AE,
    discriminator_steps=8
)

In [0]:
### run one combined step: 8 rounds of updates of both discriminators,
### followed by a single auto-encoder update
step_disentanglement()

In [0]:
def inspect():
    """
    Plots images generated from random codes combined with every label.

    20 codes are sampled from the standard normal distribution; the same codes
    are decoded with each of the 10 one-hot labels. The images are shown as
    a grid: one row per label, one column per code.
    """
    m = 20
    per_label = []

    with torch.no_grad():
        Z = torch.randn(m, code_size, device=DEVICE)

        for digit in range(10):
            ### the same one-hot label for all m codes
            labels = torch.tensor(
                one_hot(np.repeat(digit, m)),
                dtype=torch.float32, device=DEVICE
            )
            X_gen = generator(torch.cat([Z, labels], dim=1))

            per_label.append(X_gen.cpu().numpy().reshape(m, 28, 28))

    ### (label, code, y, x) -> (label * 28 + y, code * 28 + x)
    grid = np.stack(per_label).transpose(0, 2, 1, 3).reshape(10 * 28, m * 28)

    plt.figure(figsize=(m * 2, 20))
    plt.axis('off')
    plt.imshow(grid, vmin=0, vmax=1, cmap=plt.cm.Greys)
    plt.show()

In [0]:
### full training run: 8 epochs of n_batches combined steps each,
### with generated samples plotted after every epoch
mse = iterate(
    f=step_disentanglement,
    n_epoches=8,
    n_steps=n_batches,
    callback=inspect
)

In [0]:
### Codes inferred from the first m training images are decoded with every
### one of the 10 labels.
m = 20

with torch.no_grad():
    X_original = X_train[:m]
    xi = torch.randn(m, xi_size, device=DEVICE)
    Z_inferred = inference(X_original, xi)

    ### first row of the plot: the original images, with intensities inverted
    tensors = [1 - X_original.cpu().numpy().reshape(m, 28, 28)]

    ### following rows: the same codes decoded with label i
    for i in range(10):
        labels = torch.tensor(
            one_hot(np.repeat(i, m)),
            dtype=torch.float32, device=DEVICE
        )
        Z_labels = torch.cat([Z_inferred, labels], dim=1)
        X_gen = generator(Z_labels)

        tensors.append(X_gen.cpu().numpy().reshape(m, 28, 28))


plt.figure(figsize=(m * 2, 20))
plt.axis('off')
plt.imshow(
    (
        np.stack(tensors)
        ### (row, image, y, x) -> (row * 28 + y, image * 28 + x)
        .transpose(0, 2, 1, 3)
        .reshape(len(tensors) * 28, m * 28)
    ),
    vmin=0, vmax=1,
    cmap=plt.cm.Greys
)
plt.show()