In [19]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data as data
import torch.optim as optim
import torchvision
from tqdm import tqdm, trange
import numpy as np


class StopExecution(Exception):
    def _render_traceback_(self):
        pass

In [3]:
# set CUDA training on if detected:
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
    CUDA = True
else:
    DEVICE = torch.device("cpu")
    CUDA = False


In [4]:
def rescale(x, lo, hi):
    """Rescale a tensor to [lo,hi]."""
    assert lo < hi, "[rescale] lo={0} must be smaller than hi={1}".format(lo, hi)
    old_width = torch.max(x) - torch.min(x)
    old_center = torch.min(x) + (old_width / 2.0)
    new_width = float(hi - lo)
    new_center = lo + (new_width / 2.0)
    # shift everything back to zero:
    x = x - old_center
    # rescale to correct width:
    x = x * (new_width / old_width)
    # shift everything to the new center:
    x = x + new_center
    # return:
    return x


def zca_matrix(data_tensor):
    """
    Helper function: compute ZCA whitening matrix across a dataset ~ (N, C, H, W).
    """
    # 1. flatten dataset:
    X = data_tensor.view(data_tensor.shape[0], -1)

    # 2. zero-center the matrix:
    X = rescale(X, -1.0, 1.0)

    # 3. compute covariances:
    cov = torch.t(X) @ X

    # 4. compute ZCA(X) == U @ (diag(1/S)) @ torch.t(V) where U, S, V = SVD(cov):
    U, S, V = torch.svd(cov)
    return U @ torch.diag(torch.reciprocal(S)) @ torch.t(V)

In [5]:
def download_torchvision_data(dataset: str = "mnist"):
    ### download training datasets:
    if dataset == "mnist":
        print("Downloading MNIST...")
        mnist = torchvision.datasets.MNIST(
            root="./datasets/MNIST/torchvision",
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
    elif dataset == "cifar10":
        print("Downloading CIFAR10...")
        cifar10 = torchvision.datasets.CIFAR10(
            root="./datasets/CIFAR10",
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
        ### save ZCA whitening matrices:
        print("Computing CIFAR10 ZCA matrix...")
        torch.save(
            zca_matrix(torch.cat([x for (x, _) in cifar10], dim=0)),
            "./datasets/CIFAR10/zca_matrix.pt",
        )
    elif dataset == "svhn":
        print("Downloading SVHN...")
        svhn = torchvision.datasets.SVHN(
            root="./datasets/SVHN",
            split="train",
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
        ### save ZCA whitening matrices:
        print("Computing SVHN ZCA matrix...")
        torch.save(
            zca_matrix(torch.cat([x for (x, _) in svhn], dim=0)),
            "./datasets/SVHN/zca_matrix.pt",
        )
    else:
        raise NotImplementedError(
            f"Dataset {dataset} not implemented yet. Please choose from ['mnist', 'cifar10', 'svhn']"
        )
    print("...All done.")


In [6]:
download_torchvision_data("mnist")


Downloading MNIST...
...All done.


In [7]:
def load_mnist(train=True, batch_size=1, num_workers=0):
    """Rescale and preprocess MNIST dataset."""
    mnist_transform = torchvision.transforms.Compose(
        [
            # convert PIL image to tensor:
            torchvision.transforms.ToTensor(),
            # flatten:
            torchvision.transforms.Lambda(lambda x: x.view(-1)),
            # add uniform noise:
            torchvision.transforms.Lambda(
                lambda x: (x + torch.rand_like(x).div_(256.0))
            ),
            # rescale to [0,1]:
            torchvision.transforms.Lambda(lambda x: rescale(x, 0.0, 1.0)),
        ]
    )
    return data.DataLoader(
        torchvision.datasets.MNIST(
            root="./datasets/MNIST/torchvision",
            train=train,
            transform=mnist_transform,
            download=False,
        ),
        batch_size=batch_size,
        pin_memory=CUDA,
        drop_last=train,
        num_workers=num_workers,
    )


In [17]:
class Args:
    dataset = "mnist"
    epochs = 1500
    batch_size = 16
    nlayers = 5
    nhidden = 1000
    prior = "gaussian"
    lr = 0.001
    beta1 = 0.9
    beta2 = 0.01
    eps = 0.0001
    lmbda = 0.0


args = Args()

In [9]:
_get_even = lambda xs: xs[:, 0::2]
_get_odd = lambda xs: xs[:, 1::2]


def _interleave(first, second, order):
    """
    Given 2 rank-2 tensors with same batch dimension, interleave their columns.

    The tensors "first" and "second" are assumed to be of shape (B,M) and (B,N)
    where M = N or N+1, repsectively.
    """
    cols = []
    if order == "even":
        for k in range(second.shape[1]):
            cols.append(first[:, k])
            cols.append(second[:, k])
        if first.shape[1] > second.shape[1]:
            cols.append(first[:, -1])
    else:
        for k in range(first.shape[1]):
            cols.append(second[:, k])
            cols.append(first[:, k])
        if second.shape[1] > first.shape[1]:
            cols.append(second[:, -1])
    return torch.stack(cols, dim=1)


class _BaseCouplingLayer(nn.Module):
    def __init__(self, dim, partition, nonlinearity):
        """
        Base coupling layer that handles the permutation of the inputs and wraps
        an instance of torch.nn.Module.

        Usage:
        >> layer = AdditiveCouplingLayer(1000, 'even', nn.Sequential(...))

        Args:
        * dim: dimension of the inputs.
        * partition: str, 'even' or 'odd'. If 'even', the even-valued columns are sent to
        pass through the activation module.
        * nonlinearity: an instance of torch.nn.Module.
        """
        super(_BaseCouplingLayer, self).__init__()
        # store input dimension of incoming values:
        self.dim = dim
        # store partition choice and make shorthands for 1st and second partitions:
        assert partition in [
            "even",
            "odd",
        ], "[_BaseCouplingLayer] Partition type must be `even` or `odd`!"
        self.partition = partition
        if partition == "even":
            self._first = _get_even
            self._second = _get_odd
        else:
            self._first = _get_odd
            self._second = _get_even
        # store nonlinear function module:
        # (n.b. this can be a complex instance of torch.nn.Module, for ex. a deep ReLU network)
        self.add_module("nonlinearity", nonlinearity)

    def forward(self, x):
        """Map an input through the partition and nonlinearity."""
        return _interleave(
            self._first(x),
            self.coupling_law(self._second(x), self.nonlinearity(self._first(x))),
            self.partition,
        )

    def inverse(self, y):
        """Inverse mapping through the layer. Gradients should be turned off for this pass."""
        return _interleave(
            self._first(y),
            self.anticoupling_law(self._second(y), self.nonlinearity(self._first(y))),
            self.partition,
        )

    def coupling_law(self, a, b):
        # (a,b) --> g(a,b)
        raise NotImplementedError(
            "[_BaseCouplingLayer] Don't call abstract base layer!"
        )

    def anticoupling_law(self, a, b):
        # (a,b) --> g^{-1}(a,b)
        raise NotImplementedError(
            "[_BaseCouplingLayer] Don't call abstract base layer!"
        )


class AdditiveCouplingLayer(_BaseCouplingLayer):
    """Layer with coupling law g(a;b) := a + b."""

    def coupling_law(self, a, b):
        return a + b

    def anticoupling_law(self, a, b):
        return a - b


class MultiplicativeCouplingLayer(_BaseCouplingLayer):
    """Layer with coupling law g(a;b) := a .* b."""

    def coupling_law(self, a, b):
        return torch.mul(a, b)

    def anticoupling_law(self, a, b):
        return torch.mul(a, torch.reciprocal(b))

In [10]:
def _build_relu_network(latent_dim, hidden_dim, num_layers):
    """Helper function to construct a ReLU network of varying number of layers."""
    _modules = [nn.Linear(latent_dim, hidden_dim)]
    for _ in range(num_layers):
        _modules.append(nn.Linear(hidden_dim, hidden_dim))
        _modules.append(nn.ReLU())
        _modules.append(nn.BatchNorm1d(hidden_dim))
    _modules.append(nn.Linear(hidden_dim, latent_dim))
    return nn.Sequential(*_modules)


class NICEModel(nn.Module):
    """
    Replication of model from the paper:
      "Nonlinear Independent Components Estimation",
      Laurent Dinh, David Krueger, Yoshua Bengio (2014)
      https://arxiv.org/abs/1410.8516

    Contains the following components:
    * four additive coupling layers with nonlinearity functions consisting of
      five-layer RELUs
    * a diagonal scaling matrix output layer
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super(NICEModel, self).__init__()
        assert (
            input_dim % 2 == 0
        ), "[NICEModel] only even input dimensions supported for now"
        assert num_layers > 2, "[NICEModel] num_layers must be at least 3"
        self.input_dim = input_dim
        half_dim = int(input_dim / 2)
        self.layer1 = AdditiveCouplingLayer(
            input_dim, "odd", _build_relu_network(half_dim, hidden_dim, num_layers)
        )
        self.layer2 = AdditiveCouplingLayer(
            input_dim, "even", _build_relu_network(half_dim, hidden_dim, num_layers)
        )
        self.layer3 = AdditiveCouplingLayer(
            input_dim, "odd", _build_relu_network(half_dim, hidden_dim, num_layers)
        )
        self.layer4 = AdditiveCouplingLayer(
            input_dim, "even", _build_relu_network(half_dim, hidden_dim, num_layers)
        )
        self.scaling_diag = nn.Parameter(torch.ones(input_dim))

        # randomly initialize weights:
        for p in self.layer1.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity="relu")
            else:
                init.normal_(p, mean=0.0, std=0.001)
        for p in self.layer2.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity="relu")
            else:
                init.normal_(p, mean=0.0, std=0.001)
        for p in self.layer3.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity="relu")
            else:
                init.normal_(p, mean=0.0, std=0.001)
        for p in self.layer4.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity="relu")
            else:
                init.normal_(p, mean=0.0, std=0.001)

    def forward(self, xs):
        """
        Forward pass through all invertible coupling layers.

        Args:
        * xs: float tensor of shape (B,dim).

        Returns:
        * ys: float tensor of shape (B,dim).
        """
        ys = self.layer1(xs)
        ys = self.layer2(ys)
        ys = self.layer3(ys)
        ys = self.layer4(ys)
        ys = torch.matmul(ys, torch.diag(torch.exp(self.scaling_diag)))
        return ys

    def inverse(self, ys):
        """Invert a set of draws from gaussians"""
        with torch.no_grad():
            xs = torch.matmul(
                ys, torch.diag(torch.reciprocal(torch.exp(self.scaling_diag)))
            )
            xs = self.layer4.inverse(xs)
            xs = self.layer3.inverse(xs)
            xs = self.layer2.inverse(xs)
            xs = self.layer1.inverse(xs)
        return xs

In [11]:
def gaussian_nice_loglkhd(h, diag):
    """
    Definition of log-likelihood function with a Gaussian prior, as in the paper.

    Args:
    * h: float tensor of shape (N,D). First dimension is batch dim, second dim consists of components
      of a factorized probability distribution.
    * diag: scaling diagonal of shape (D,).

    Returns:
    * loss: torch float tensor of shape (N,).
    """
    # \sum^D_i s_{ii} - { (1/2) * \sum^D_i  h_i**2) + (D/2) * log(2\pi) }
    return torch.sum(diag) - (
        0.5 * torch.sum(torch.pow(h, 2), dim=1)
        + h.size(1) * 0.5 * torch.log(torch.tensor(2 * np.pi))
    )


def logistic_nice_loglkhd(h, diag):
    """
    Definition of log-likelihood function with a Logistic prior.

    Same arguments/returns as gaussian_nice_loglkhd.
    """
    # \sum^D_i s_{ii} - { \sum^D_i log(exp(h)+1) + torch.log(exp(-h)+1) }
    return torch.sum(diag) - (
        torch.sum(torch.log1p(torch.exp(h)) + torch.log1p(torch.exp(-h)), dim=1)
    )


# wrap above loss functions in Modules:
class GaussianPriorNICELoss(nn.Module):
    def __init__(self, size_average=True):
        super(GaussianPriorNICELoss, self).__init__()
        self.size_average = size_average

    def forward(self, fx, diag):
        if self.size_average:
            return torch.mean(-gaussian_nice_loglkhd(fx, diag))
        else:
            return torch.sum(-gaussian_nice_loglkhd(fx, diag))


class LogisticPriorNICELoss(nn.Module):
    def __init__(self, size_average=True):
        super(LogisticPriorNICELoss, self).__init__()
        self.size_average = size_average

    def forward(self, fx, diag):
        if self.size_average:
            return torch.mean(-logistic_nice_loglkhd(fx, diag))
        else:
            return torch.sum(-logistic_nice_loglkhd(fx, diag))

In [18]:
if args.prior == "logistic":
    nice_loss_fn = LogisticPriorNICELoss(size_average=True)
else:
    nice_loss_fn = GaussianPriorNICELoss(size_average=True)


def loss_fn(fx):
    """Compute NICE loss w/r/t a prior and optional L1 regularization."""
    if args.lmbda == 0.0:
        return nice_loss_fn(fx, model.scaling_diag)
    else:
        return nice_loss_fn(fx, model.scaling_diag) + args.lmbda * l1_norm(
            model, include_bias=True
        )

In [12]:
def train(args):
    """Construct a NICE model and train over a number of epochs."""
    # === choose which dataset to build:
    if args.dataset == "mnist":
        dataloader_fn = load_mnist
        input_dim = 28 * 28
    if args.dataset == "svhn":
        dataloader_fn = load_svhn
        input_dim = 32 * 32 * 3
    if args.dataset == "cifar10":
        dataloader_fn = load_cifar10
        input_dim = 32 * 32 * 3
    else:
        raise NotImplementedError(
            "[train] dataset {} not supported".format(args.dataset)
        )

    # === build model & optimizer:
    model = NICEModel(input_dim, args.nhidden, args.nlayers)

In [15]:
input_dim = 28 * 28
dataloader_fn = load_mnist
dataloader = dataloader_fn(train=True, batch_size=args.batch_size)
dataloader.dataset.data.shape

torch.Size([60000, 28, 28])

In [16]:
model = NICEModel(input_dim, args.nhidden, args.nlayers)
opt = optim.Adam(
    model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps
)

In [21]:
for inputs, _ in tqdm(dataloader):
    opt.zero_grad()
    fx = model(inputs)
    print(fx.shape)
    print(fx, "\n")
    loss = loss_fn(fx)
    print(loss)
    loss.backward()
    opt.step()
    raise StopExecution

  0%|          | 0/3750 [00:01<?, ?it/s]

torch.Size([16, 784])
tensor([[0.0098, 0.0059, 0.0161,  ..., 0.0052, 0.0149, 0.0215],
        [0.0103, 0.0086, 0.0142,  ..., 0.0054, 0.0146, 0.0142],
        [0.0058, 0.0088, 0.0162,  ..., 0.0079, 0.0204, 0.0200],
        ...,
        [0.0062, 0.0089, 0.0084,  ..., 0.0063, 0.0169, 0.0232],
        [0.0102, 0.0095, 0.0149,  ..., 0.0042, 0.0123, 0.0206],
        [0.0055, 0.0110, 0.0090,  ..., 0.0043, 0.0112, 0.0152]],
       grad_fn=<MmBackward0>) 

tensor(227.2925, grad_fn=<MeanBackward0>)



