In [2]:
# import the usual libraries:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import tqdm

import numpy as np

In [5]:
# import custom modules/losses from NICE:
from nice.models import NICEModel
from nice.loss import logistic_nice_loglkhd, LogisticPriorNICELoss

## 1. model

Let's make sure that our model is capable of inverting arbitrary inputs.

In [109]:
# build model with and without batch norm, three layers deep:
nice_bn = NICEModel(input_dim=256, hidden_dim=64, num_layers=3, bn=True)
nice_mlp = NICEModel(input_dim=256, hidden_dim=64, num_layers=3, bn=False)

In [7]:
# compare a tensor with its inverse-mapped value:
def recover_tensor(tsr, mdl):
    with torch.no_grad():
        return torch.dist(mdl.inverse(mdl(tsr)), tsr, p=1)

In [35]:
# run this multiple times and make sure the values are small:
recover_tensor(torch.randn(8,256), nice_bn)

tensor(1.6177e-05)

In [68]:
# same thing for the non-batchnormed model:
recover_tensor(torch.randn(8,256), nice_mlp)

tensor(0.0455)

# 2. losses
Now let's make sure that we can train against our loss module to memorize a single instance.

In [69]:
# this will be the random tensor that we'll try to get our model to memorize; run this cell *once*.
constant = torch.randn(8,256)
print(constant)

tensor([[ 1.0470e-01, -3.2764e-01,  2.5212e-01,  ..., -6.7200e-01,
         -1.0572e+00,  1.8751e-01],
        [-3.2082e-01, -8.4130e-01,  7.4957e-01,  ..., -7.2816e-01,
          2.1122e-01,  1.9313e+00],
        [ 9.4984e-01,  4.1410e-01, -5.2190e-01,  ...,  2.1097e-01,
          1.3180e-02, -1.6768e+00],
        ...,
        [-1.2008e-01, -1.4522e-01, -8.4159e-01,  ...,  1.0249e+00,
         -9.0380e-01,  1.2813e-03],
        [ 4.3895e-01,  1.0337e+00, -1.0610e+00,  ...,  9.5216e-01,
         -9.6363e-02, -2.7394e-02],
        [-1.1182e+00, -2.3659e+00,  2.8266e-01,  ...,  3.7139e+00,
         -3.8203e-01, -7.4203e-01]])


In [187]:
# --- part 1: fitting with the batchnorm NICE model:
loss_fn = LogisticPriorNICELoss(size_average=True)
opt = optim.Adam(nice_bn.parameters(), lr=0.00005, betas=(0.9, 0.999), eps=1e-8)
def train(n=1):
    nice_bn.train()
    for k in tqdm.trange(n):
        opt.zero_grad()
        (-loss_fn(nice_bn(constant), nice_bn.scaling_diag)).backward()
        opt.step()
    with torch.no_grad():
        nice_bn.eval()
        print(-loss_fn(nice_bn(constant), nice_bn.scaling_diag).item())
        nice_bn.train()

In [111]:
# try out the loss function, which should compute negative log-likelihood against a logistic distribution:
with torch.no_grad():
    nice_bn.eval()
    print(-loss_fn(nice_bn(constant), nice_bn.scaling_diag).item())
    nice_bn.train()

253.4134979248047


In [189]:
train(n=100)

100%|██████████| 100/100 [00:08<00:00, 12.39it/s]

nan





In [183]:
# define method to sample (elementwise) from a standard logistic:
def sample_logistic_like(tsr):
    Z = torch.rand_like(tsr)
    Y = torch.log(Z) - torch.log(1-Z)
    return Y
with torch.no_grad():
    print(logistic_nice_loglkhd(sample_logistic_like(constant)).sum().item())

-4070.748779296875


In [191]:
# see if we can recover the tensor from sampling:
with torch.no_grad():
    print(torch.dist(constant, nice_bn.inverse(sample_logistic_like(constant)), p=2.).item())

nan


In [None]:
# --- part 2: the same thing, with the standard (no-batchnorm) NICE model:
# (N.B.: run this *after* running the above cells)

## 3. training
Finally, let's overfit against a single MNIST image and recover it.