In [None]:
import os

import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

import numpy as np

import matplotlib.pyplot as plt

from pythae.models import AutoModel
from pythae.data.datasets import DatasetOutput

import tqdm

In [None]:
# Load the model stored in the given folder, then move it to the GPU.
# The last expression is left bare so the notebook echoes the model.
vae_128 = AutoModel.load_from_folder('my_models_on_cifar/final_model')
vae_128.to('cuda')

In [None]:
# CIFAR-10 test split, converted to tensors and served in fixed-order batches of 32.
transform = transforms.Compose([transforms.ToTensor()])

test_data = CIFAR10(
    root='data',
    train=False,
    transform=transform,
    download=True,
)
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)

In [None]:
# Grab the first batch of the test loader. Unlike a loop that breaks right away,
# this raises StopIteration on an empty loader instead of silently leaving
# `img` and `label` undefined for the cells below.
img, label = next(iter(test_loader))

In [None]:
# Reconstruct the batch on the GPU, then detach the result and bring it back to the CPU.
reconstructed_128 = (
    vae_128
    .reconstruct(img.to('cuda'))
    .detach()
    .cpu()
)

In [None]:
# Wrap the batch (moved to the GPU) in a DatasetOutput and run it through the model;
# the cells below read the latent codes from out['z'].
dataset = DatasetOutput(data=img.to(device='cuda'))
out = vae_128(dataset)

In [None]:
# from https://github.com/tml-tuebingen/explanations-manifold 
def compute_tangent_space(decoder, z, device='cuda'):
    """Compute the tangent space of a generative model at the latent vector ``z``.

    The tangent space is spanned by the partial derivatives of the decoder output
    with respect to each latent coordinate (the columns of the decoder Jacobian).

    Args:
        decoder: The decoder. A pytorch module that is called on a batch of latent
            vectors of shape (1, latent dim) and returns a mapping whose
            ``'reconstruction'`` entry holds the decoded batch.
        z: pytorch tensor of shape (latent dim,). A batch dimension is not
            supported. The tensor passed in is not modified.
        device: Device on which the decoder is evaluated. As a side effect the
            decoder is moved to this device.

    Returns:
        Tensor of shape (latent dim, *model output shape), allocated with the
        ``torch.zeros`` defaults. The vectors span the tangent space and
        correspond 1:1 to the latent dimensions of ``z``.

    Raises:
        AssertionError: if ``z`` is not 1-dimensional.
    """
    assert len(z.shape) == 1, "compute_tangent_space: batch dimension in z is not supported. z has to be a 1-dimensional vector"
    decoder.to(device)
    # Differentiate w.r.t. a private leaf copy: ``z.to(device)`` returns the very
    # same tensor when it already lives on ``device``, so enabling gradients on it
    # would silently change the caller's tensor.
    z = z.detach().clone().to(device).requires_grad_(True)
    latent_dim = z.shape[0]
    out = decoder(z.unsqueeze(0))
    # Remove only the singleton batch dimension; a bare squeeze() would also drop
    # other singleton axes such as a single colour channel.
    out = out['reconstruction'].squeeze(0)
    output_shape = out.shape  # store original output shape
    out = out.reshape(-1)     # and transform the output into a vector
    tangent_space = torch.zeros((latent_dim, out.shape[0]))
    for i in range(out.shape[0]):
        # autograd.grad returns d out[i] / d z directly and, unlike backward(),
        # does not accumulate gradients into the decoder parameters.
        (grad,) = torch.autograd.grad(out[i], z, retain_graph=True, allow_unused=True)
        if grad is not None:  # None: out[i] does not depend on z, the column stays zero
            tangent_space[:, i] = grad.cpu()
    tangent_space = tangent_space.reshape((-1, *output_shape))  # tangent space in model output shape
    return tangent_space

# Tangent space at the latent code of the first image of the batch.
tangent_space = compute_tangent_space(decoder=vae_128.decoder, z=out['z'].detach()[0])

In [None]:
# One tangent space per latent code in the batch. The count is read from the
# latent codes themselves instead of hard-coding the loader's batch size (32).
tangent_space_batch = []
for i in tqdm.tqdm(range(out['z'].shape[0])):
    tangent_space_batch.append(compute_tangent_space(vae_128.decoder, out['z'].detach()[i]))

In [None]:
def plot_image_grid(images, nrows=5, ncols=5):
    """Show the first ``nrows * ncols`` images on a grid of axes without ticks.

    Args:
        images: Indexable collection of image tensors in channels-first layout.
            A leading singleton dimension is squeezed away, so single-channel
            images are drawn as 2-D arrays with the gray colormap.
        nrows: Number of grid rows (default 5).
        ncols: Number of grid columns (default 5).

    If there are fewer images than grid cells, the remaining cells stay empty.
    """
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(2 * ncols, 2 * nrows),  # 10 x 10 for the default 5 x 5 grid
        squeeze=False,                   # keep a 2-D axes array even for one row/column
    )

    # axes.flat walks the grid row by row, i.e. cell (i, j) shows image i * ncols + j
    for idx, ax in enumerate(axes.flat):
        if idx < len(images):
            picture = images[idx].cpu().squeeze(0).numpy()
            if picture.ndim == 3:
                # channels-first -> channels-last, as expected by imshow
                picture = picture.transpose((1, 2, 0))
            ax.imshow(picture, cmap='gray')
        ax.axis('off')
    fig.tight_layout(pad=0.)

In [None]:
# one grid for the original images, followed by one for their reconstructions
for _batch in (img, reconstructed_128):
    plot_image_grid(_batch)

## Figure 7

In [None]:
from scipy.linalg import orth

def project_into_tangent_space(tangent_space, vector):
    """Orthogonally project ``vector`` onto the span of the tangent vectors.

    Args:
        tangent_space: numpy array of shape (number of tangent vectors, *output shape).
            The output shape is arbitrary; images do not have to be square.
        vector: array with as many elements as one tangent vector; it is flattened
            before the projection.

    Returns:
        numpy array of shape ``output shape`` holding the component of ``vector``
        that lies in the span of the tangent vectors.
    """
    output_shape = tangent_space.shape[1:]
    # orth() returns an orthonormal basis of the column space of its argument, so
    # the flattened tangent vectors go in as columns; after transposing back, every
    # row of `basis` is one orthonormal basis vector.
    basis = orth(tangent_space.reshape((tangent_space.shape[0], -1)).T).T
    # Coefficients of the vector in the orthonormal basis. They are stored as
    # float64 so the result keeps at least double precision for any input dtype.
    coeff = np.asarray(basis @ np.asarray(vector).reshape(-1), dtype=np.float64)
    vector_in_tangent_space = (coeff @ basis).reshape(output_shape)
    return vector_in_tangent_space

In [None]:
# project vectors onto the respective tangent spaces
# The noise comes from a dedicated, seeded generator so that the figure is the same
# on every run, without touching the global random state used elsewhere.
noise = torch.randn(img.shape, dtype=img.dtype, generator=torch.Generator().manual_seed(0))
tangent_noise = torch.zeros_like(noise)
orthogonal_noise = torch.zeros_like(noise)

def direction_to_image(direction):
    """Rescale a direction so it can be shown as an image with values in [0, 1].

    The direction is divided by its largest absolute value and then mapped
    affinely from [-1, 1] to [0, 1]. The input tensor is not modified. An
    all-zero direction maps to constant 0.5 instead of dividing by zero.
    """
    scale = direction.abs().max()
    if scale == 0:
        return torch.full_like(direction, 0.5)
    return (1 + direction / scale) / 2

for i in range(noise.shape[0]):
    projected = project_into_tangent_space(tangent_space_batch[i].numpy(), noise[i].numpy())
    tangent_noise[i] = torch.as_tensor(projected, dtype=noise.dtype)
    # the orthogonal part has to be taken before the tangent part is rescaled for display
    orthogonal_noise[i] = noise[i] - tangent_noise[i]

    tangent_noise[i] = direction_to_image(tangent_noise[i])
    orthogonal_noise[i] = direction_to_image(orthogonal_noise[i])

plot_image_grid(tangent_noise)
plot_image_grid(orthogonal_noise)

## Estimate the tangent space for all images in the test set of CIFAR-10

In [None]:
loader = DataLoader(test_data, batch_size=1, shuffle=False)

# torch.save does not create missing directories, so make sure the output folder exists.
os.makedirs('results', exist_ok=True)

test_tangent_spaces = []
# NOTE: the loop reuses the names `img`, `label`, `dataset` and `out` from the cells above.
# The loop stops after idx == 1000, i.e. after 1001 images; the total only sizes the progress bar.
for idx, (img, label) in tqdm.tqdm(enumerate(loader), total=min(len(loader), 1001)):
    dataset = DatasetOutput(data=img.to('cuda'))
    out = vae_128(dataset)
    test_tangent_spaces.append(compute_tangent_space(vae_128.decoder, out['z'].detach()[0]))
    if idx % 10 == 0:
        # rolling checkpoint: write the new one first, then drop the previous one
        torch.save(test_tangent_spaces, f'results/test_tangent_spaces_{idx}.pt')
        if idx > 0:
            os.remove(f'results/test_tangent_spaces_{idx-10}.pt')
    if idx > 999:
        break
torch.save(test_tangent_spaces, 'results/test_tangent_spaces.pt')

In [None]:
# Compare one row of the first test-set tangent space with the same row of the
# first tangent space computed for the batch above.
np.allclose(
    test_tangent_spaces[0][0, 0, 0],
    tangent_space_batch[0][0, 0, 0],
)

In [None]:
# Reload the tangent spaces written by the loop above.
test_tangent_spaces = torch.load('results/test_tangent_spaces.pt')