In [1]:
"""Utility functions for NICE.
"""

import torch
import torch.nn.functional as F

def dequantize(x, dataset):
    '''Dequantize data.

    Add noise sampled from Uniform(0, 1) to each pixel (in [0, 255]).

    Args:
        x: input tensor.
        reverse: True in inference mode, False in training mode.
    Returns:
        dequantized data.
    '''
    noise = torch.distributions.Uniform(0., 1.).sample(x.size())
    return (x * 255. + noise) / 256.

def prepare_data(x, dataset, zca=None, mean=None, reverse=False):
    """Prepares data for NICE.

    In training mode, flatten and dequantize the input.
    In inference mode, reshape tensor into image size.

    Args:
        x: input minibatch.
        dataset: name of dataset.
        zca: ZCA whitening transformation matrix.
        mean: center of original dataset.
        reverse: True if in inference mode, False if in training mode.
    Returns:
        transformed data.
    """
    if reverse:
        assert len(list(x.size())) == 2
        [B, W] = list(x.size())

        if dataset in ['mnist', 'fashion-mnist']:
            assert W == 1 * 28 * 28
            x += mean
            x = x.reshape((B, 1, 28, 28))
        elif dataset in ['svhn', 'cifar10']:
            assert W == 3 * 32 * 32
            x = torch.matmul(x, zca.inverse()) + mean
            x = x.reshape((B, 3, 32, 32))
    else:
        assert len(list(x.size())) == 4
        [B, C, H, W] = list(x.size())

        if dataset in ['mnist', 'fashion-mnist']:
            assert [C, H, W] == [1, 28, 28]
        elif dataset in ['svhn', 'cifar10']:
            assert [C, H, W] == [3, 32, 32]

        x = dequantize(x, dataset)
        x = x.reshape((B, C*H*W))

        if dataset in ['mnist', 'fashion-mnist']:
            x -= mean
        elif dataset in ['svhn', 'cifar10']:
            x = torch.matmul((x - mean), zca)
    return x

"""Standard logistic distribution.
"""
class StandardLogistic(torch.distributions.Distribution):
    def __init__(self):
        super(StandardLogistic, self).__init__()

    def log_prob(self, x):
        """Computes data log-likelihood.

        Args:
            x: input tensor.
        Returns:
            log-likelihood.
        """
        return -(F.softplus(x) + F.softplus(-x))

    def sample(self, size):
        """Samples from the distribution.

        Args:
            size: number of samples to generate.
        Returns:
            samples.
        """
        z = torch.distributions.Uniform(0., 1.).sample(size).cuda()
        return torch.log(z) - torch.log(1. - z)

In [2]:
"""Utility classes for NICE.
"""

import torch
import torch.nn as nn

"""Additive coupling layer.
"""
class Coupling(nn.Module):
    def __init__(self, in_out_dim, mid_dim, hidden, mask_config):
        """Initialize a coupling layer.

        Args:
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: 1 if transform odd units, 0 if transform even units.
        """
        super(Coupling, self).__init__()
        self.mask_config = mask_config

        self.in_block = nn.Sequential(
            nn.Linear(in_out_dim//2, mid_dim),
            nn.ReLU())
        self.mid_block = nn.ModuleList([
            nn.Sequential(
                nn.Linear(mid_dim, mid_dim),
                nn.ReLU()) for _ in range(hidden - 1)])
        self.out_block = nn.Linear(mid_dim, in_out_dim//2)

    def forward(self, x, reverse=False):
        """Forward pass.

        Args:
            x: input tensor.
            reverse: True in inference mode, False in sampling mode.
        Returns:
            transformed tensor.
        """
        [B, W] = list(x.size())
        x = x.reshape((B, W//2, 2))
        if self.mask_config:
            on, off = x[:, :, 0], x[:, :, 1]
        else:
            off, on = x[:, :, 0], x[:, :, 1]

        off_ = self.in_block(off)
        for i in range(len(self.mid_block)):
            off_ = self.mid_block[i](off_)
        shift = self.out_block(off_)
        if reverse:
            on = on - shift
        else:
            on = on + shift

        if self.mask_config:
            x = torch.stack((on, off), dim=2)
        else:
            x = torch.stack((off, on), dim=2)
        return x.reshape((B, W))

"""Log-scaling layer.
"""
class Scaling(nn.Module):
    def __init__(self, dim):
        """Initialize a (log-)scaling layer.

        Args:
            dim: input/output dimensions.
        """
        super(Scaling, self).__init__()
        self.scale = nn.Parameter(
            torch.zeros((1, dim)), requires_grad=True)

    def forward(self, x, reverse=False):
        """Forward pass.

        Args:
            x: input tensor.
            reverse: True in inference mode, False in sampling mode.
        Returns:
            transformed tensor and log-determinant of Jacobian.
        """
        log_det_J = torch.sum(self.scale)
        if reverse:
            x = x * torch.exp(-self.scale)
        else:
            x = x * torch.exp(self.scale)
        return x, log_det_J

"""NICE main model.
"""
class NICE(nn.Module):
    def __init__(self, prior, coupling, 
        in_out_dim, mid_dim, hidden, mask_config):
        """Initialize a NICE.

        Args:
            prior: prior distribution over latent space Z.
            coupling: number of coupling layers.
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: 1 if transform odd units, 0 if transform even units.
        """
        super(NICE, self).__init__()
        self.prior = prior
        self.in_out_dim = in_out_dim

        self.coupling = nn.ModuleList([
            Coupling(in_out_dim=in_out_dim, 
                     mid_dim=mid_dim, 
                     hidden=hidden, 
                     mask_config=(mask_config+i)%2) \
            for i in range(coupling)])
        self.scaling = Scaling(in_out_dim)

    def g(self, z):
        """Transformation g: Z -> X (inverse of f).

        Args:
            z: tensor in latent space Z.
        Returns:
            transformed tensor in data space X.
        """
        x, _ = self.scaling(z, reverse=True)
        for i in reversed(range(len(self.coupling))):
            x = self.coupling[i](x, reverse=True)
        return x

    def f(self, x):
        """Transformation f: X -> Z (inverse of g).

        Args:
            x: tensor in data space X.
        Returns:
            transformed tensor in latent space Z.
        """
        for i in range(len(self.coupling)):
            x = self.coupling[i](x)
        return self.scaling(x)

    def log_prob(self, x):
        """Computes data log-likelihood.

        (See Section 3.3 in the NICE paper.)

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input.
        """
        z, log_det_J = self.f(x)
        log_ll = torch.sum(self.prior.log_prob(z), dim=1)
        return log_ll + log_det_J

    def sample(self, size):
        """Generates samples.

        Args:
            size: number of samples to generate.
        Returns:
            samples from the data space X.
        """
        z = self.prior.sample((size, self.in_out_dim)).cuda()
        return self.g(z)

    def forward(self, x):
        """Forward pass.

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input.
        """
        return self.log_prob(x)

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd "/content/drive/MyDrive/NICE_3/NICE-master"
%ls

/content/drive/.shortcut-targets-by-id/1D0bG3a6FGKxZBFzHH4JtJN4IJU_eoNSh/NICE_3/NICE-master
LICENSE     nice.py       [0m[01;34mreconstruction[0m/  Test_NICE_loss.ipynb  zca.py
[01;34mmodels[0m/     [01;34m__pycache__[0m/  [01;34msamples[0m/         train.py
NICE.ipynb  README.md     [01;34mstatistics[0m/      utils.py


In [None]:
!python nice.py
!python utils.py
#!python zca.py

In [4]:
import torch, torchvision
from torchvision import transforms
import numpy as np
#from VAE import Model
import matplotlib.pyplot as plt
import torch.nn as nn

#from google.colab import drive
#drive.mount('/content/drive')
latent = 'logistic'
if latent == 'normal':
    prior = torch.distributions.Normal(
        torch.tensor(0.).to(device), torch.tensor(1.).to(device))
elif latent == 'logistic':
    prior = StandardLogistic()

# model hyperparameters
coupling = 4
mask_config = 1.
(full_dim, mid_dim, hidden) = (3 * 32 * 32, 2000, 4)
in_out_dim=full_dim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr = 1e-3
momentum = 0.999
decay = 0.9

flow = NICE(prior, coupling, in_out_dim, mid_dim, hidden, mask_config).to(device)

################@@@@@@@@@
optimizer = torch.optim.Adam(flow.parameters(), lr=lr, betas=(momentum, decay), eps=1e-4)
 
checkpoint = torch.load("/content/drive/MyDrive/NICE_3/NICE-master/models/cifar10/cifar10_bs200_logistic_cp4_md2000_hd4_iter25000.tar",map_location=torch.device(device))

flow.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


transform = torchvision.transforms.Compose(
        [torchvision.transforms.RandomHorizontalFlip(p=0.5),
         torchvision.transforms.ToTensor()])

testset_cifar10 = torchvision.datasets.CIFAR10(root='./data/CIFAR10',train=False, download=True, transform=transform)
testloader_cifar10 = torch.utils.data.DataLoader(testset_cifar10,batch_size=1, shuffle=False, num_workers=2)

testset_svhn = torchvision.datasets.SVHN(root='./data/SVHN', download=True, transform=transform)
testloader_svhn = torch.utils.data.DataLoader(testset_svhn,batch_size=1, shuffle=False, num_workers=2)

  'with `validate_args=False` to turn off validation.')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/CIFAR10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/CIFAR10/cifar-10-python.tar.gz to ./data/CIFAR10
Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/SVHN/train_32x32.mat


  0%|          | 0/182040794 [00:00<?, ?it/s]

In [5]:
def test_nice(flow, testloader, sample_size=0, dataset='cifar10'):
    if dataset == 'cifar10':
        zca = torch.load('/content/drive/MyDrive/NICE_3/NICE-master/statistics/cifar10_zca_3.pt')
        mean = torch.load('/content/drive/MyDrive/NICE_3/NICE-master/statistics/cifar10_mean.pt')

    if dataset == 'svhn':
        zca = torch.load('/content/drive/MyDrive/NICE_3/NICE-master/statistics/svhn_zca_3.pt')
        mean = torch.load('/content/drive/MyDrive/NICE_3/NICE-master/statistics/svhn_mean.pt')
    
    running_loss = []
    with torch.no_grad():
        for inputs, _ in testloader:
            inputs = prepare_data(inputs, dataset, zca=zca, mean=mean).to(device)
            loss = -flow(inputs).mean()
            running_loss.append(loss)
    return running_loss
            #z, _ = flow.f(inputs)
            #reconst = flow.g(z).cpu()
            #reconst = prepare_data(reconst, dataset, zca=zca, mean=mean, reverse=True)
            #print(reconst)            
        

In [6]:
test_loss_per_batch_cifar10_nice = test_nice(flow=flow, testloader=testloader_cifar10)
print(test_loss_per_batch_cifar10_nice)
with open('/content/drive/MyDrive/NICE_3/NICE-master/test_loss_per_batch_cifar10_NICE.txt', 'w') as writefile:
    writefile.write(str(test_loss_per_batch_cifar10_nice))

[tensor(3317.7261), tensor(2397.6001), tensor(1952.5781), tensor(2880.3081), tensor(3919.5869), tensor(2287.2915), tensor(3645.9990), tensor(1833.7554), tensor(2117.1987), tensor(3072.1050), tensor(1387.5327), tensor(3927.3618), tensor(2469.3540), tensor(3778.1348), tensor(4788.1821), tensor(4096.6050), tensor(3822.4810), tensor(3220.4541), tensor(2236.6826), tensor(4341.7739), tensor(3315.2114), tensor(1860.6333), tensor(1255.5713), tensor(6070.7241), tensor(2833.4429), tensor(4099.1836), tensor(2341.3911), tensor(2279.9409), tensor(3877.2080), tensor(4746.2632), tensor(2498.6177), tensor(1238.5581), tensor(3358.1821), tensor(2355.0220), tensor(1692.8877), tensor(2924.4639), tensor(2600.7407), tensor(4369.0098), tensor(4880.2612), tensor(2231.5947), tensor(3079.3716), tensor(3785.5815), tensor(3999.7188), tensor(1951.8052), tensor(4582.9380), tensor(2968.3999), tensor(1909.1426), tensor(4168.1060), tensor(2491.6021), tensor(1550.1187), tensor(2789.7842), tensor(2284.8828), tensor(2347

In [7]:
test_loss_per_batch_svhn_nice = test_nice(flow=flow, testloader=testloader_svhn, dataset = 'svhn')
print(test_loss_per_batch_svhn_nice)
with open('/content/drive/MyDrive/NICE_3/NICE-master/test_loss_per_batch_svhn_NICE.txt', 'w') as writefile:
    writefile.write(str(test_loss_per_batch_svhn_nice))

[tensor(3922.2764), tensor(3154.7378), tensor(4497.1177), tensor(4147.5088), tensor(782.2520), tensor(798.5679), tensor(1892.5986), tensor(1894.0825), tensor(1128.7285), tensor(1137.3516), tensor(753.1230), tensor(800.3545), tensor(1713.5566), tensor(1327.9302), tensor(789.0640), tensor(807.2764), tensor(850.1714), tensor(1189.2261), tensor(1354.2710), tensor(1333.1260), tensor(1731.6533), tensor(1589.5034), tensor(836.5874), tensor(853.2065), tensor(2211.6631), tensor(2114.6836), tensor(1559.4082), tensor(1488.9443), tensor(1129.6919), tensor(1232.4990), tensor(1184.5977), tensor(1225.4258), tensor(1205.2939), tensor(1011.7085), tensor(1213.7803), tensor(1165.3608), tensor(1984.5078), tensor(1746.0742), tensor(1795.1079), tensor(1975.4712), tensor(1710.1138), tensor(758.9829), tensor(781.1294), tensor(3814.1675), tensor(855.3633), tensor(859.5024), tensor(869.8892), tensor(932.8467), tensor(911.7441), tensor(2602.8350), tensor(2950.2163), tensor(909.3999), tensor(1034.7192), tensor(90