# Adversarial AutoEncoder

In [0]:
def get_free_gpu():
    """Return the index of the NVIDIA device that reports the most free memory.

    NVML is initialised on every call. The result is whatever `np.argmax`
    yields over the per-device free-memory readings, so ties resolve to the
    lowest index.
    """
    import pynvml

    pynvml.nvmlInit()

    free_memory = []
    for device_index in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        free_memory.append(pynvml.nvmlDeviceGetMemoryInfo(handle).free)

    return np.argmax(free_memory)

In [0]:
import numpy as np
import torch
import torchvision

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from IPython import display

### Pick the compute device once: the CUDA device with the most free memory
### when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    ### query the GPUs a single time and reuse the answer, so the id kept in
    ### `cuda_id` always matches the device named in `DEVICE`
    cuda_id = get_free_gpu()
    DEVICE = 'cuda:%d' % (cuda_id, )
    print('Selected %s' % (DEVICE, ))
else:
    DEVICE = 'cpu'
    print('WARNING: using cpu!')

### please, don't remove the following line
x = torch.tensor([1], dtype=torch.float32).to(DEVICE)

In [0]:
def show(filename):
    """Display the image stored at `filename` inline; do nothing if that fails.

    The call is best-effort: a missing or unreadable image must not
    interrupt the notebook, so ordinary errors are swallowed.

    Parameters:
        filename: path of the image file handed to `IPython.display.Image`.
    """
    from IPython import display
    try:
        display.display(
            display.Image(filename=filename)
        )
    except Exception:
        ### still best-effort, but unlike a bare `except:` this lets
        ### KeyboardInterrupt / SystemExit propagate
        pass

## Loading data

In [0]:
def one_hot(y, n_classes=10):
    """Encode integer class labels as rows of a float32 indicator matrix.

    Row `i` of the result is all zeros except for a one in column `y[i]`.
    The result has shape `(y.shape[0], n_classes)`.
    """
    n_samples = y.shape[0]
    row_index = np.arange(n_samples)

    encoded = np.zeros((n_samples, n_classes), dtype=np.float32)
    encoded[row_index, y] = 1

    return encoded

In [0]:
from torchvision.datasets import MNIST

### MNIST is fetched into ../../data/ (download=True) for both splits.
ds_train = MNIST("../../data/", train=True, download=True)
ds_test = MNIST("../../data/", train=False, download=True)

### --- training split ---
### images as float32 arrays reshaped to (N, 1, 28, 28) and divided by 255
data_train = ds_train.data.reshape(-1, 1, 28, 28).detach().numpy().astype(np.float32) / 255
labels_train = ds_train.targets.detach().numpy()

### to make everything fast the entire dataset is moved onto DEVICE up front
X_train = torch.tensor(data_train, dtype=torch.float32, device=DEVICE)
y_train = torch.tensor(labels_train, dtype=torch.long, device=DEVICE)
y_one_hot_train = torch.tensor(one_hot(labels_train), dtype=torch.float32, device=DEVICE)

### baseline: MSE of predicting the mean training image for every sample
X_avg = X_train.mean(dim=0)
MSE_baseline = (X_train - X_avg[None]).pow(2).mean()

### --- test split, prepared the same way ---
data_test = ds_test.data.reshape(-1, 1, 28, 28).detach().numpy().astype(np.float32) / 255
labels_test = ds_test.targets.detach().numpy()

X_test = torch.tensor(data_test, dtype=torch.float32, device=DEVICE)
y_test = torch.tensor(labels_test, dtype=torch.long, device=DEVICE)
y_one_hot_test = torch.tensor(one_hot(labels_test), dtype=torch.float32, device=DEVICE)

### test batches of 32 (image, label) pairs, in dataset order
dataset_test = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_test, y_test),
    batch_size=32,
)

In [0]:
### Preview: the first 200 training digits tiled into one 280 x 560 image
### (image k lands at tile row k % 10, tile column k // 10).
plt.figure(figsize=(12, 6), dpi=100)
plt.axis('off')
_ = plt.imshow(
    data_train[:200]
    .reshape(20, 10, 28, 28)
    .transpose(1, 2, 0, 3)
    .reshape(10 * 28, 20 * 28),
    cmap=plt.cm.Greys
)

## Models

In [0]:
class View(torch.nn.Module):
    """Module wrapper around `Tensor.view`, so a reshape can sit in a layer list."""

    def __init__(self, *shape):
        super().__init__()
        ### target shape, kept as the tuple of positional arguments
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(*target_shape)

In [0]:
### Convolutional feature extractor followed by a small fully connected head
### aka encoder
### E: (None, 1, 28, 28), (None, xi_size) -> (None, code_size)
class Inference(torch.nn.Module):
    """Encoder: maps an image (plus an optional extra vector) to a latent code.

    A stack of unpadded convolutions shrinks the 28 x 28 input down to a
    flat vector of length `code_size`. A two-layer dense head then combines
    that vector with the optional input `z` (of width `xi_size`) and emits
    the final code of length `code_size`.
    """

    def __init__(self, n, code_size, xi_size=None):
        super().__init__()

        ### (in_channels, out_channels, kernel_size, stride) of every conv that
        ### is followed by a LeakyReLU; the trailing comment is the spatial
        ### size of its output
        activated_convs = [
            (1, 2 * n, 3, 1),      ### 26 x 26
            (2 * n, 2 * n, 3, 1),  ### 24 x 24
            (2 * n, 2 * n, 2, 2),  ### 12 x 12, conv pooling
            (2 * n, 3 * n, 3, 1),  ### 10 x 10
            (3 * n, 3 * n, 3, 1),  ### 8 x 8
            (3 * n, 3 * n, 2, 2),  ### 4 x 4, conv pooling
            (3 * n, 4 * n, 3, 1),  ### 2 x 2
        ]

        embedding = []
        for in_channels, out_channels, kernel, stride in activated_convs:
            embedding.append(
                torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride)
            )
            embedding.append(torch.nn.LeakyReLU())

        ### 1 x 1, no activation after the last convolution
        embedding.append(torch.nn.Conv2d(4 * n, code_size, kernel_size=2, stride=1))
        embedding.append(torch.nn.Flatten())

        self.image_embedding = embedding

        ### plain lists are not tracked by torch, so each layer is registered by name
        for index, layer in enumerate(self.image_embedding):
            self.add_module('img_embedding%d' % (index, ), layer)

        extra_width = xi_size if xi_size is not None else 0

        self.combined = [
            torch.nn.Linear(code_size + extra_width, 2 * code_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(2 * code_size, code_size),
        ]

        for index, layer in enumerate(self.combined):
            self.add_module('combined%d' % (index, ), layer)

    def forward(self, x, z=None):
        features = x
        for layer in self.image_embedding:
            features = layer(features)

        ### the extra vector, when given, is appended to the image features
        if z is not None:
            features = torch.cat([features, z], dim=1)

        for layer in self.combined:
            features = layer(features)

        return features

In [0]:
### aka decoder
### G: (None, code_size) -> (None, 1, 28, 28)
class Generator(torch.nn.Module):
    """Decoder: maps a latent code to a single-channel 28 x 28 image.

    A linear layer lifts the code to `4 * n` channels at 1 x 1 resolution;
    unpadded transposed convolutions then grow it to 28 x 28. The last layer
    is a plain transposed convolution, so outputs are not squashed into any
    fixed range.

    Parameters:
        n: width multiplier for the channel counts.
        code_size: length of the latent code expected by `forward`.
    """

    def __init__(self, n, code_size):
        super(Generator, self).__init__()

        ### The layer list is deliberately NOT stored as `self.modules`: that
        ### name would shadow `torch.nn.Module.modules()` and break anything
        ### that walks the module tree through it.
        ### NOTE(review): there is no activation between the Linear layer and
        ### the first two transposed convolutions - kept as in the original.
        self.layers = [
            torch.nn.Linear(code_size, 4 * n),
            View(-1, 4 * n, 1, 1),

            ### 2 x 2
            torch.nn.ConvTranspose2d(4 * n, 4 * n, kernel_size=2, stride=1),
            ### 4 x 4
            torch.nn.ConvTranspose2d(4 * n, 3 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),

            ### 8 x 8
            torch.nn.ConvTranspose2d(3 * n, 3 * n, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),
            ### 10 x 10
            torch.nn.ConvTranspose2d(3 * n, 3 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),
            ### 12 x 12
            torch.nn.ConvTranspose2d(3 * n, 2 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),

            ### 24 x 24
            torch.nn.ConvTranspose2d(2 * n, 2 * n, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),
            ### 26 x 26
            torch.nn.ConvTranspose2d(2 * n, 2 * n, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),
            ### 28 x 28
            torch.nn.ConvTranspose2d(2 * n, 1, kernel_size=3, stride=1),
        ]

        ### a plain list is not tracked by torch, so each layer is registered
        ### by name (names unchanged: f0, f1, ...)
        for i, f in enumerate(self.layers):
            self.add_module('f%d' % (i, ), f)

    def forward(self, z):
        x = z

        for f in self.layers:
            x = f(x)

        return x

In [0]:
### aka critic
class Discriminator(torch.nn.Module):
    """Small fully connected network that scores its input vectors.

    With `n_outputs=None` the network ends in a single unit and `forward`
    flattens the result to one value per sample; otherwise it returns a
    `(batch, n_outputs)` tensor. No activation follows the last layer.
    """

    def __init__(self, n, input_size, n_outputs=None):
        super().__init__()

        head_width = 1 if n_outputs is None else n_outputs

        self.fs = [
            torch.nn.Linear(input_size, 2 * n),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(2 * n, n),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(n, head_width),
        ]
        self.n_outputs = n_outputs

        ### a plain list is not tracked by torch, so each layer is registered by name
        for index, layer in enumerate(self.fs):
            self.add_module('f%d' % (index, ), layer)

    def forward(self, x):
        scores = x
        for layer in self.fs:
            scores = layer(scores)

        return scores.view(-1) if self.n_outputs is None else scores

In [0]:
def iterate(f, n_epoches, n_steps, callback=None):
    """Call `f` `n_steps` times per epoch for `n_epoches` epochs and record its results.

    Parameters:
        f: zero-argument callable performing one step and returning a number.
        n_epoches: number of epochs (outer loop, tracked by the first progress bar).
        n_steps: number of calls to `f` per epoch (second progress bar).
        callback: optional zero-argument callable invoked after every epoch.

    Returns:
        float32 array of shape `(n_epoches, n_steps)`; entry `[i, j]` is the
        value returned by the j-th call of `f` in epoch i.
    """
    ### sized by `n_steps` (not by the global `n_batches`), so every column is
    ### filled and any number of steps per epoch fits into the buffer
    losses = np.zeros((n_epoches, n_steps), dtype=np.float32)

    primary_pbar = tqdm(total=n_epoches, leave=False)
    secondary_pbar = tqdm(total=n_steps, leave=False)

    try:
        for i in range(n_epoches):
            secondary_pbar.reset()

            for j in range(n_steps):
                losses[i, j] = f()

                secondary_pbar.update()

            primary_pbar.update()
            if callback is not None:
                callback()
    finally:
        ### close the bars even when a step fails or the run is interrupted
        secondary_pbar.close()
        primary_pbar.close()

    return losses

## Training procedures

In [0]:
### number of images drawn per autoencoder step
batch_size = 16

### discriminator is a light-weight network, thus,
### can easily handle large batches
### (images encoded per discriminator step / prior samples drawn per step)
batch_size_discr_real = 32
batch_size_discr_prior = 128

n_epoches = 4

### number of autoencoder steps that make up one epoch
n_batches = len(data_train) // batch_size

### width multiplier passed to the networks
n = 16
### length of the latent code
code_size = 16
### length of the noise vector fed to the encoder
xi_size = 4

In [0]:
def logit_binary_crossentropy(predictions_positive, predictions_negative):
    """
    Binary cross-entropy computed directly on logits (network outputs before sigmoid).
    - predictions_positive - logits for samples labelled y = 1 (real);
    - predictions_negative - logits for samples labelled y = 0 (generated).

    Each group is averaged separately, so the two groups may differ in size;
    the result is the sum of the two means.
    """
    softplus = torch.nn.functional.softplus

    ### -log sigmoid(p) = log(1 + exp(-p)) = softplus(-p)
    loss_positive = softplus(-predictions_positive).mean()
    ### -log(1 - sigmoid(p)) = log(1 + exp(p)) = softplus(p)
    loss_negative = softplus(predictions_negative).mean()

    return loss_positive + loss_negative

### Task 1

- implement training procedure for the discriminator;
- implement training procedure for the autoencoder.

![AAE](../../img/AAE.png)

In [0]:
def get_step_discriminator(inference, discriminator, opt_discriminator):
    """Build a closure that performs one optimisation step of the discriminator.

    Parameters:
        inference: encoder, called as `inference(X, xi)` to produce latent codes.
        discriminator: network scoring latent codes with logits.
        opt_discriminator: optimizer that updates the discriminator.

    Returns:
        `step()`: runs one update and returns the discriminator loss as a float.
    """
    def step():
        ### both batches are built without gradient tracking, so this step
        ### cannot push gradients into the encoder
        with torch.no_grad():
            indx = torch.randint(low=0, high=X_train.shape[0], size=(batch_size_discr_real, ), device=DEVICE)
            X_real = X_train[indx]

            ### xi makes inference stochastic
            xi = torch.randn(X_real.shape[0], xi_size, device=DEVICE)
            Z_inferred = inference(X_real, xi)

            Z_prior = torch.randn(batch_size_discr_prior, code_size, device=DEVICE)

        opt_discriminator.zero_grad()

        ### samples from the prior are the positive class (y = 1),
        ### codes produced by the encoder are the negative class (y = 0)
        predictions_prior = discriminator(Z_prior)
        predictions_inferred = discriminator(Z_inferred)
        loss = logit_binary_crossentropy(predictions_prior, predictions_inferred)

        loss.backward()
        opt_discriminator.step()

        return loss.item()

    return step

In [0]:
def get_step_AE(generator, inference, opt_AE, discriminator, alpha=1e-1):
    """Build a closure that performs one optimisation step of the autoencoder.

    The loss is the reconstruction MSE plus `alpha` times an adversarial
    penalty that rewards codes the discriminator scores as prior samples.

    Parameters:
        generator: decoder, maps latent codes to images.
        inference: encoder, called as `inference(X, xi)`.
        opt_AE: optimizer that updates the autoencoder.
        discriminator: network scoring latent codes with logits.
        alpha: weight of the adversarial penalty relative to the MSE.

    Returns:
        `step()`: runs one update and returns the reconstruction MSE as a
        float (the penalty is not part of the returned value).
    """
    def step():
        with torch.no_grad():
            indx = torch.randint(low=0, high=X_train.shape[0], size=(batch_size, ), device=DEVICE)
            X_real = X_train[indx]
            xi = torch.randn(X_real.shape[0], xi_size, device=DEVICE)

        opt_AE.zero_grad()

        ### loss consists of two terms - mse and penalty
        Z_inferred = inference(X_real, xi)
        X_reconstructed = generator(Z_inferred)
        mse = torch.mean((X_reconstructed - X_real) ** 2)

        ### non-saturating adversarial term, -log sigmoid(D(z)) = softplus(-D(z)):
        ### small when the discriminator takes inferred codes for prior samples.
        ### NOTE(review): backward() also fills the discriminator's gradients;
        ### this assumes opt_AE does not hold the discriminator's parameters.
        penalty = torch.mean(
            torch.nn.functional.softplus(-discriminator(Z_inferred))
        )

        loss = mse + alpha * penalty

        loss.backward()
        opt_AE.step()

        return mse.item()

    return step

In [0]:
def get_step_AAE(step_discriminator, step_AE, discriminator_steps = 4):
    """Combine both training steps into a single adversarial step.

    The returned closure runs `step_discriminator` `discriminator_steps`
    times (results discarded), then runs `step_AE` once and returns its result.
    """
    def step():
        ### discriminator updates come first
        for _round in range(discriminator_steps):
            step_discriminator()

        autoencoder_result = step_AE()
        return autoencoder_result

    return step

## Building Generator and Inference

In [0]:
### all three networks share the width multiplier `n` and live on DEVICE
### decoder: latent code -> image
generator = Generator(n, code_size).to(DEVICE)
### encoder: image + noise of size xi_size -> latent code
inference = Inference(n, code_size, xi_size).to(DEVICE)
### discriminator works on latent codes, hence input_size = code_size
discriminator = Discriminator(n, code_size).to(DEVICE)

In [0]:
### checks if shapes are correct
### a batch of 10 training images and matching noise is pushed through
### the encoder, the decoder and the discriminator
X_real = X_train[:10]
xi = torch.randn(10, xi_size, device=DEVICE)

### encoder output
Z_inferred = inference(X_real, xi)

print('Z inferred shape', Z_inferred.shape)

### decoder output for the inferred codes
X_generated = generator(Z_inferred)

print('X generated shape:', X_generated.shape)

### discriminator scores for the inferred codes
p_neg = discriminator(Z_inferred)

print('discriminator shape:', p_neg.shape)

In [0]:
def inspect():
    """Plot a four-row grid of 20 images each to eyeball the current models.

    Rows, top to bottom: images generated from random codes, reconstructions
    of those generated images, the first 20 training images, reconstructions
    of those training images.
    """
    m = 20

    def as_strip(images):
        ### batch of m images -> one (28, m * 28) strip, images side by side
        return np.concatenate(images.cpu().numpy().reshape(m, 28, 28), axis=1)

    with torch.no_grad():
        Z = torch.randn(m, code_size, device=DEVICE)
        X_gen = generator(Z)

        xi = torch.randn(m, xi_size, device=DEVICE)
        X_rec = generator(inference(X_gen, xi))

        X_original = X_train[:m]
        xi = torch.randn(m, xi_size, device=DEVICE)
        X_rec_original = generator(inference(X_original, xi))

    ### strips are stacked vertically in the order listed in the docstring
    grid = np.concatenate(
        [as_strip(batch) for batch in (X_gen, X_rec, X_original, X_rec_original)],
        axis=0
    )

    plt.figure(figsize=(m * 2, 8))
    plt.axis('off')
    plt.imshow(grid, vmin=0, vmax=1, cmap=plt.cm.Greys)
    plt.show()

In [0]:
### preview the models; at this point in the notebook no training step has run yet
inspect()

In [0]:
### note that we use separate optimizers for pretraining as
### optimization tasks are different.
### autoencoder optimizer: updates encoder and decoder parameters together
opt_AE = torch.optim.Adam(
    params=[*inference.parameters(), *generator.parameters()],
    lr=2e-4,
)

### discriminator optimizer: higher learning rate, with weight decay
opt_discr = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=2e-3,
    weight_decay=1e-3,
)

In [0]:
### closure performing one discriminator update per call
step_discr = get_step_discriminator(inference, discriminator, opt_discr)

### pretraining discriminator
### (8 epochs of 128 steps each; the recorded losses are discarded)
_ = iterate(step_discr, 8, 128)

In [0]:
### autoencoder step; alpha=1 overrides the default penalty weight of 1e-1
step_AE = get_step_AE(
    generator, inference,
    opt_AE=opt_AE,
    discriminator=discriminator,
    alpha=1
)

### one adversarial step = 8 discriminator updates followed by 1 autoencoder update
step_adv = get_step_AAE(
    step_discriminator=step_discr,
    step_AE=step_AE,
    discriminator_steps=8
)

In [0]:
### main adversarial training; the epoch count comes from the shared
### `n_epoches` setting instead of a second hard-coded copy of it.
### `inspect` is called after every epoch to show the current samples.
losses_AAE = iterate(
    step_adv, n_epoches=n_epoches, n_steps=n_batches,
    callback=inspect
)

In [0]:
### one point per epoch: the mean over that epoch's recorded step values
plt.figure(figsize=(6, 4))
plt.plot(np.mean(losses_AAE, axis=1))
plt.xlabel('epoch')
plt.ylabel('MSE')
plt.show()

In [0]:
### Evaluate on the test set: collect the latent code of every image and its
### reconstruction error (per-image MSE divided by MSE_baseline).
codes = list()
errors = list()


with torch.no_grad():
    for X_batch, _ in dataset_test:
        xi = torch.randn(X_batch.shape[0], xi_size, device=DEVICE)
        z = inference(X_batch, xi)
        codes.append(z.cpu().numpy())
        X_rec = generator(z)
        errors.append(
            torch.mean(
                ((X_rec - X_batch) ** 2).view(X_batch.shape[0], -1) / MSE_baseline,
                dim=1
            ).cpu().numpy()
        )

codes = np.concatenate(codes, axis=0)
errors = np.concatenate(errors, axis=0)

### statistics per latent dimension (axis 0 runs over the test samples);
### the same arrays are used for the checks and for the error messages
codes_mean = np.mean(codes, axis=0)
codes_std = np.std(codes, axis=0)

show('../../img/AAE-1.png')

if np.mean(errors) > 0.5:
    raise ValueError('Reconstruction error is too high [%.2lf]!' % (np.mean(errors), ))
else:
    show('../../img/AAE-2.png')

if np.any(np.abs(codes_mean) > 2.5e-1):
    raise ValueError('Latent variables are biased!\nmean = %s' % (codes_mean, ))
elif np.any(codes_std > 1.5) or np.any(codes_std < 0.5):
    ### this branch fires for too large AND too small spread
    raise ValueError('The variance of latent variables is too high or too low!\nstd = %s' % (codes_std, ))
else:
    show('../../img/AAE-3.png')

In [0]:
### per latent dimension: mean of the test codes, with their std as error bars
plt.scatter(np.arange(code_size), np.mean(codes, axis=0))
plt.errorbar(np.arange(code_size), np.mean(codes, axis=0), np.std(codes, axis=0))

In [0]:
### one step-style histogram per latent dimension, all drawn on the same axes
plt.figure(figsize=(9, 6))
_ = plt.hist(
    list(codes[:, :code_size].T),
    bins=50,
    histtype='step'
)
plt.title('Distribution of latent variables')