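# Results

Evaluation of the trained discriminators, VGG-11 classifiers and the INN on in-distribution data (EMNIST digits, train and test splits) and on out-of-distribution data (EMNIST letters, FashionMNIST, KMNIST). In the discriminator and classifier comparisons, the labels for the out-of-distribution sets are drawn uniformly at random from the 10 digit classes.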

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# import models
from models import Discriminator, INN
import data
import torchvision
import pandas as pd
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from utils import config

# Experiment configuration shared with the training scripts.
c = config.Config()
c.load('./config/default.toml')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fix the RNG state so that shuffled loaders and the random labels drawn for
# non-MNIST datasets (rand_y=True) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

def make_cond(labels, ncl=None, conditional=None):
    """Build a (batch, ncl) float conditioning matrix from class labels.

    Args:
        labels: 1-D LongTensor of class indices.
        ncl: width of each conditioning row; defaults to ``c.ncl`` from the config.
        conditional: if True, rows are one-hot encodings of ``labels``; if False,
            every row is the constant vector with a 1 in column 0 (labels ignored).
            Defaults to ``c.conditional`` from the config.

    Returns:
        Float tensor of shape ``(len(labels), ncl)`` on the same device as ``labels``.
    """
    if ncl is None:
        ncl = c.ncl
    if conditional is None:
        conditional = c.conditional
    # Allocate on the labels' device instead of a hard-coded .cuda(): that failed on
    # CPU-only machines and made scatter_ fail when labels were still on the CPU.
    cond_tensor = torch.zeros(labels.size(0), ncl, device=labels.device)
    if conditional:
        cond_tensor.scatter_(1, labels.view(-1, 1), 1.)
    else:
        cond_tensor[:, 0] = 1
    return cond_tensor

# fill[k] is a (10, 32, 32) stack of planes: plane k is all ones, every other plane
# is zero, i.e. a spatially broadcast one-hot code for class k (indexed as fill[y]).
fill = torch.eye(10, device=device).reshape(10, 10, 1, 1).repeat(1, 1, 32, 32)

In [2]:
def _mnist_style_transform():
    """Shared preprocessing for all MNIST-family datasets used below.

    Pads by 2 pixels per side (assumes 28x28 inputs -> 32x32), converts to a tensor,
    maps [0, 1] to [-1, 1], and swaps the last two (H/W) axes — presumably to undo
    EMNIST's transposed storage.

    NOTE(review): the same H/W swap is also applied to FashionMNIST and KMNIST,
    which are presumably not stored transposed, so those OOD sets may be fed in
    transposed. This is kept as-is so the reported numbers stay comparable; confirm
    it is intended.
    """
    return transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Lambda(lambda x: x.permute(0, 2, 1)),
    ])


def _make_loader(dataset_cls, shuffle, batch_size=512, **dataset_kwargs):
    """Build a DataLoader over ``dataset_cls`` with the notebook's common settings.

    ``dataset_kwargs`` (e.g. ``split=``, ``train=``) are forwarded to the dataset.
    ``drop_last=True`` keeps every batch at exactly ``batch_size`` samples.
    """
    dataset = dataset_cls('~/Data', download=True, transform=_mnist_style_transform(),
                          **dataset_kwargs)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
        num_workers=4, drop_last=True,
    )


train_loader = _make_loader(datasets.EMNIST, shuffle=False, split='digits', train=True)
test_loader = _make_loader(datasets.EMNIST, shuffle=True, split='digits', train=False)
letter_loader = _make_loader(datasets.EMNIST, shuffle=True, split='letters', train=True)
fashion_loader = _make_loader(datasets.FashionMNIST, shuffle=True, train=True)
kmnist_loader = _make_loader(datasets.KMNIST, shuffle=True, train=True)


In [8]:
# Run directories of the trained models to compare (relative to runs/, or to
# ../archetypal_analysis/runs/ for the deepaa variants). An entry whose run id is an
# empty string is skipped by the evaluation loops.
models_disc = {
    'base_letters': 'Jun09_15-46-31_GLaDOS_DISCRIMINATOR',
    'base_fashion': 'Jun13_21-01-11_GLaDOS_DISCRIMINATOR',
    'inn': 'Jun09_13-20-28_GLaDOS_DISCRIMINATOR',
    # 'inn+minlh': '',  (not trained yet)
    'deepaa_sq': 'Jun14_15-49-00_GLaDOS_DISCRIMINATOR',
    'deepaa_linear': 'Jun13_20-08-48_GLaDOS_DISCRIMINATOR',
    'deepaa_pre': 'Jun30_18-15-11_GLaDOS_DISCRIMINATOR',
}
models_class = {
    'base': 'Jun14_17-10-51_GLaDOS_CLASSIFIER',
    'base_letters': 'Jun13_21-23-28_GLaDOS_CLASSIFIER',
    'base_fashion': 'Jun13_21-30-23_GLaDOS_CLASSIFIER',
    'inn': 'Jun13_17-30-06_GLaDOS_CLASSIFIER',
    # 'inn+minlh': '',  (not trained yet)
    'deepaa_linear': 'Jun13_20-50-34_GLaDOS_CLASSIFIER',
    'deepaa_sq': 'Jun14_15-55-01_GLaDOS_CLASSIFIER',
    'deepaa_pre': 'Jul01_13-28-04_GLaDOS_CLASSIFIER',
}
models_inn = {
    'inn': 'May13_19-39-26_GLaDOS',
    # 'inn+minlh': '',  (not trained yet)
}

# Evaluation datasets, keyed by the column names used in the result tables.
dsets = dict(
    mnist_train=train_loader,
    mnist_test=test_loader,
    letters=letter_loader,
    fashion=fashion_loader,
    kmnist=kmnist_loader,
)

# One row per model, one column per dataset (classifier: per dataset and metric,
# grouped metric-major: all *_acc, then all *_conf, then all *_ind).
results_disc = pd.DataFrame(index=list(models_disc), columns=list(dsets), dtype=np.float64)
results_class = pd.DataFrame(
    index=list(models_class),
    columns=[f'{name}_{metric}' for metric in ('acc', 'conf', 'ind') for name in dsets],
    dtype=np.float64,
)

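## Comparing Discriminators

Each entry is the fraction of samples whose sigmoid discriminator output is at least the threshold `t = 0.7` (the mean of per-batch acceptance rates).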

In [4]:
def discriminator_test(model, data_loader, t, rand_y=False):
    """Fraction of samples a conditional discriminator accepts at threshold ``t``.

    Args:
        model: discriminator called as ``model(x, fill[y])``; a sigmoid is applied to
            its output, so the output is treated as logits.
        data_loader: iterable of ``(x, y)`` batches.
        t: acceptance threshold on the sigmoid output.
        rand_y: if True, replace the labels with uniform random classes 0-9 (for
            datasets whose labels do not index the 10 digit classes of ``fill``).

    Returns:
        Mean over batches of the per-batch acceptance rate, as a float
        (NaN for an empty loader).
    """
    rates = []
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            if rand_y:
                # Size from the actual batch instead of a hard-coded 512.
                y = torch.randint(10, (x.size(0),))
            x, y = x.to(device), y.to(device)
            scores = torch.sigmoid(model(x, fill[y]).reshape(-1))
            rates.append((scores >= t).float().mean())

    if not rates:
        return float('nan')
    # stack keeps the 0-d per-batch rates on-device; torch.tensor(list_of_tensors)
    # round-trips each one through Python.
    return torch.stack(rates).mean().item()

In [6]:
for k_d, data_loader in dsets.items():
    for k_m, v_m in models_disc.items():
        # Empty run ids are placeholders. `v_m is ''` tested identity, not equality,
        # and emits a SyntaxWarning on Python 3.8+.
        if not v_m:
            continue

        model = Discriminator(c, conditional=True)
        pre = '../archetypal_analysis/' if k_m.startswith('deepaa') else ''
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        state = torch.load(f'{pre}runs/{v_m}/checkpoints/discriminator.pt', map_location=device)
        # Drop checkpoint entries whose key contains 'tmp' before loading.
        model.load_state_dict({k: v for k, v in state.items() if 'tmp' not in k})
        model.to(device)
        # NOTE(review): model.eval() is not called, so any dropout/batch-norm layers
        # run in training mode during evaluation. Confirm this is intended.

        # TODO: bootstrap threshold
        t = 0.7

        acc = discriminator_test(model, data_loader, t, rand_y=not k_d.startswith('mnist'))

        # .loc writes in place. Chained `results_disc[k_d][k_m] = acc` can write to a
        # temporary copy, and never updates the frame under pandas copy-on-write.
        results_disc.loc[k_m, k_d] = acc

HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




In [7]:
results_disc.round(decimals=3)

Unnamed: 0,mnist_train,mnist_test,letters,fashion,kmnist
base_letters,1.0,1.0,0.0,0.0,0.004
base_fashion,1.0,1.0,1.0,0.0,0.901
inn,0.989,0.99,0.0,0.0,0.0
deepaa_sq,1.0,1.0,1.0,1.0,1.0
deepaa_linear,0.864,0.865,0.616,0.282,0.178
deepaa_pre,1.0,1.0,0.621,0.0,0.0


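## Comparing Classifiers

Metrics per dataset:
- `acc`: top-1 accuracy. On the out-of-distribution sets it is measured against random labels, so about 0.1 is chance level.
- `conf`: mean maximum softmax probability.
- `ind`: fraction of samples whose maximum softmax probability reaches the model's threshold `t`.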

In [9]:
def classifier_test(model, data_loader, t, rand_y=False):
    """Evaluate a classifier's accuracy and max-softmax confidence on a dataset.

    Args:
        model: classifier whose output is treated as per-class logits along dim 1.
        data_loader: iterable of ``(x, y)`` batches.
        t: threshold on the maximum softmax probability for the "positive" rate.
        rand_y: if True, replace the labels with uniform random classes 0-9.

    Returns:
        ``(acc, confidence, positive)``, each the mean over batches of:
        - top-1 accuracy against ``y``,
        - the mean maximum softmax probability,
        - the fraction of samples whose maximum softmax probability is ``>= t``.
        All three are NaN for an empty loader.
    """
    accs, confs, positives = [], [], []
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            if rand_y:
                # Size from the actual batch instead of a hard-coded 512.
                y = torch.randint(10, (x.size(0),))
            x, y = x.to(device), y.to(device)
            # Replicate the (presumably single) channel so the model sees 3 channels.
            probs = F.softmax(model(x.repeat(1, 3, 1, 1)), dim=1)
            top_p, top_cls = probs.max(dim=1)  # softmax computed once, reused below
            confs.append(top_p.mean())
            positives.append((top_p >= t).float().mean())
            accs.append((top_cls == y).float().mean())

    if not accs:
        nan = float('nan')
        return nan, nan, nan
    return (
        torch.stack(accs).mean().item(),
        torch.stack(confs).mean().item(),
        torch.stack(positives).mean().item(),
    )

In [10]:
for k_d, data_loader in dsets.items():
    for k_m, v_m in models_class.items():
        # Empty run ids are placeholders. `v_m is ''` tested identity, not equality,
        # and emits a SyntaxWarning on Python 3.8+.
        if not v_m:
            continue

        model = torchvision.models.vgg11(num_classes=10).to(device)
        pre = '../archetypal_analysis/' if k_m.startswith('deepaa') else ''
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        state = torch.load(f'{pre}runs/{v_m}/checkpoints/classifier.pt', map_location=device)
        # Drop checkpoint entries whose key contains 'tmp' before loading.
        model.load_state_dict({k: v for k, v in state.items() if 'tmp' not in k})
        model.to(device)
        # NOTE(review): model.eval() is not called. torchvision's VGG classifier head
        # contains Dropout, which then stays active during evaluation. Confirm this
        # is intended.

        # TODO: bootstrap threshold
        if k_m.startswith('deepaa'):
            t = 0.85
        elif k_m.startswith('inn'):
            t = 0.6
        else:
            t = 0.7

        acc, confidence, positive = classifier_test(
            model, data_loader, t, rand_y=not k_d.startswith('mnist'))

        # .loc writes in place. Chained `results_class[col][k_m] = ...` can write to a
        # temporary copy, and never updates the frame under pandas copy-on-write.
        results_class.loc[k_m, f'{k_d}_acc'] = acc
        results_class.loc[k_m, f'{k_d}_conf'] = confidence
        results_class.loc[k_m, f'{k_d}_ind'] = positive

HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=243.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=117.0), HTML(value='')))




In [11]:
results_class.loc[:, [name + '_acc' for name in dsets]].round(decimals=3)

Unnamed: 0,mnist_train_acc,mnist_test_acc,letters_acc,fashion_acc,kmnist_acc
base,0.999,0.995,0.101,0.1,0.101
base_letters,0.996,0.995,0.1,0.102,0.1
base_fashion,0.992,0.99,0.1,0.1,0.101
inn,0.963,0.963,0.101,0.101,0.099
deepaa_linear,0.975,0.975,0.101,0.1,0.101
deepaa_sq,0.99,0.99,0.1,0.098,0.099
deepaa_pre,0.982,0.981,0.1,0.099,0.102


In [12]:
results_class.loc[:, [name + '_conf' for name in dsets]].round(decimals=3)

Unnamed: 0,mnist_train_conf,mnist_test_conf,letters_conf,fashion_conf,kmnist_conf
base,0.999,0.998,0.897,0.69,0.821
base_letters,0.987,0.986,0.393,0.247,0.379
base_fashion,0.992,0.992,0.824,0.112,0.661
inn,0.856,0.856,0.57,0.342,0.425
deepaa_linear,0.931,0.932,0.599,0.266,0.211
deepaa_sq,0.988,0.988,0.806,0.622,0.706
deepaa_pre,0.949,0.948,0.706,0.364,0.443


In [13]:
results_class.loc[:, [name + '_ind' for name in dsets]].round(decimals=3)

Unnamed: 0,mnist_train_ind,mnist_test_ind,letters_ind,fashion_ind,kmnist_ind
base,0.999,0.997,0.85,0.517,0.735
base_letters,0.991,0.99,0.236,0.022,0.145
base_fashion,0.991,0.991,0.731,0.002,0.493
inn,0.905,0.903,0.437,0.083,0.183
deepaa_linear,0.874,0.876,0.301,0.05,0.017
deepaa_sq,0.975,0.974,0.54,0.178,0.322
deepaa_pre,0.915,0.914,0.407,0.017,0.063


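## Results directly based on INN output

Each test image is passed through the INN once per digit class with that class's conditioning. The predicted class is the one whose latent code has the highest log-density under `latent_dist`.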

In [22]:
inn = INN().to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine; entries whose
# key contains 'tmp' are dropped before loading.
inn_state = torch.load('runs/May13_19-39-26_GLaDOS/checkpoints/generator_in.pt', map_location=device)
inn.load_state_dict({k: v for k, v in inn_state.items() if 'tmp' not in k})

# Encode the training set and collect the latent codes and their labels. Chunks are
# concatenated once at the end: growing a tensor with torch.cat on every batch
# re-copies everything collected so far (quadratic total work).
latent_chunks, class_chunks = [], []
with torch.no_grad():
    for x, y in tqdm(data.train_loader):
        x = x.to(device)
        y = y.to(device)
        cond = [
            fill[:, :, :16, :16][y],
            fill[:, :, :8, :8][y],
            make_cond(y),
        ]
        output = inn(x, cond)
        latent_chunks.append(output.cpu())
        class_chunks.append(y.cpu())

latent = torch.cat(latent_chunks) if latent_chunks else torch.empty(0, 32 * 32)
classes = torch.cat(class_chunks) if class_chunks else torch.empty(0).long()

# Empirical Gaussian fit of the training latents (np.cov expects variables as rows).
mean = latent.mean(dim=0).to(device)
cov = torch.tensor(np.cov(latent.numpy().T), device=device, dtype=torch.float)

# Scoring uses a standard normal over the 32*32 = 1024 latent dims rather than the
# empirical fit above. To use the fit instead:
#   latent_dist = torch.distributions.multivariate_normal.MultivariateNormal(mean, cov)
latent_dist = torch.distributions.multivariate_normal.MultivariateNormal(
    torch.zeros(1024, dtype=torch.float, device=device),
    torch.eye(1024, dtype=torch.float, device=device),
)

HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




In [None]:
# Entropy of the latent prior used for scoring: a standard normal over 1024 dims.
# The first line of this cell used to build MultivariateNormal(mean, cov) and then
# immediately overwrite it (dead store). To inspect the empirical Gaussian fit instead,
# pass mean and cov here.
latent_dist = torch.distributions.multivariate_normal.MultivariateNormal(
    torch.zeros(1024, dtype=torch.float, device=device),
    torch.eye(1024, dtype=torch.float, device=device),
)
latent_dist.entropy()

In [None]:
def inn_test(inn, mean, cov, data_loader, rand_y=False):
    """Classify by maximum latent log-density over the 10 class conditions.

    For every batch the INN is run once per candidate class 0..9 with that class's
    conditioning. Each sample is assigned the class whose latent code has the highest
    ``latent_dist.log_prob`` (``latent_dist`` is the module-level prior).

    Args:
        inn: conditional INN called as ``inn(x, cond)``.
        mean, cov: currently unused; kept so existing calls keep working.
        data_loader: iterable of ``(x, y)`` batches.
        rand_y: currently unused; kept so existing calls keep working.

    Returns:
        ``(accuracy, mean_max_logp, std_max_logp)``. ``accuracy`` is the mean of the
        per-batch accuracies; the log-prob statistics are taken over all samples of
        the per-sample maximum over classes. All are NaN for an empty loader.

    NOTE(review): only ``log p(z)`` is scored. The Jacobian log-determinant (e.g. from
    ``inn.jacobian(run_forward=False)``) is not added, so this is not the full model
    log-likelihood. Confirm this is the intended score.
    """
    batch_accs, best_logps = [], []
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            x, y = x.to(device), y.to(device)
            n = x.size(0)  # actual batch size instead of a hard-coded 512

            per_class = []
            # `cls`, not `c`: avoids shadowing the global config object.
            for cls in range(10):
                targets = torch.full((n,), cls, dtype=torch.long, device=device)
                cond = [
                    fill[:, :, :16, :16][targets],
                    fill[:, :, :8, :8][targets],
                    make_cond(targets),
                ]
                z = inn(x, cond)
                per_class.append(latent_dist.log_prob(z).reshape(n))

            logps = torch.stack(per_class, dim=1)  # (n, 10): one column per class
            best_logp, best_cls = logps.max(dim=1)
            best_logps.append(best_logp)
            batch_accs.append((best_cls == y).float().mean())

    if not batch_accs:
        nan = float('nan')
        return nan, nan, nan
    confidence = torch.cat(best_logps)
    return torch.stack(batch_accs).mean().item(), confidence.mean().item(), confidence.std().item()

In [None]:
inn_test(inn, mean, cov, data_loader=test_loader)

In [None]:
inn_test(inn, mean, cov, data_loader=letter_loader)

In [None]:
inn_test(inn, mean, cov, data_loader=fashion_loader)

## Adding Min-Likelihood
<img src="./Loss_Negative_Log-Likelihood_In-Dist.svg" alt="Negative log-likelihood loss on in-distribution data" width="50%" style="background-color: #FFF">

## Translations in bigger images

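![In-distribution samples, TensorBoard run May07_17-11-48_GLaDOS](http://glados:6007/data/plugin/images/individualImage?ts=1589014600.3871336&run=May07_17-11-48_GLaDOS&tag=Samples%2FIn-Distribution&sample=0&index=1498)

*(This image is served by the internal TensorBoard host `glados` and will not render outside that network.)*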