In [7]:
import torch
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision import datasets

# Device selection: keep the original preference for the second GPU (index 1),
# but only when that GPU actually exists. With a single GPU, "cuda:1" is an
# invalid device ordinal, so fall back to the first GPU and finally to the CPU.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

numworkers = 4   # passed as num_workers to the DataLoader
batch_size = 10  # passed as batch_size to the DataLoader

In [10]:
# Load the MNIST training split (downloaded on first use) and wrap it in a
# shuffling, batched DataLoader.
# Without a transform MNIST yields PIL images, which the DataLoader's default
# collate function cannot stack into a batch; ToTensor() converts every image
# to a tensor so batches can be formed.
# NOTE: the variable is still called `imagenet_data` (although the dataset is
# MNIST) so that any code referring to that name keeps working.
imagenet_data = datasets.MNIST('../datasets/',
                               download=True,
                               transform=transforms.ToTensor())
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=numworkers)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../datasets/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ../datasets/MNIST/raw/train-images-idx3-ubyte.gz to ../datasets/MNIST/raw


28.4%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../datasets/MNIST/raw/train-labels-idx1-ubyte.gz


0.5%5%

Extracting ../datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ../datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ../datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../datasets/MNIST/raw
Processing...
Done!




In [11]:
def LinearBlock(_in,_out,reg=True, inplace=True):
    '''
    Build the layers of one fully connected block as a plain Python list.

    Args:
        _in: number of input features of the nn.Linear layer.
        _out: number of output features of the nn.Linear layer.
        reg: when truthy, an nn.ReLU is appended after the linear layer.
        inplace: forwarded to nn.ReLU(inplace=...); ignored when reg is falsy.

    Returns:
        [nn.Linear] or [nn.Linear, nn.ReLU], ready to be unpacked into a
        container such as nn.Sequential.
    '''
    linear = nn.Linear(_in, _out)
    if not reg:
        # No activation requested: the block is just the affine map.
        return [linear]
    return [linear, nn.ReLU(inplace=inplace)]

In [13]:
class AdditiveCoupling(nn.Module):
    '''
    Single Layer to be used in NICE. For MNIST we use 5 of these that are rectified.
    '''
    def __init__(self, _in, _out, numlayers=5):
        super(AdditiveCoupling,self).__init__()
        
        relu = nn.ReLU(inplace=True)

        self.inputLayer = nn.Sequential(relu(nn.Linear(_in,1000)))
        self.hiddenLayer = nn.Sequential(relu(nn.Linear(1000,1000)))
        self.outputLayer = nn.Sequential(relu(nn.Linear(1000,_out)))
        self.reverseInputLayer = nn.Sequential(relu(nn.Linear(1000,_in)))
        self.reverseOutputLayer = nn.Sequential(relu(nn.Linear(_out,1000)))
        self.s = nn.Parameter(torch.zeros(1))
        self.si = nn.Parameter(torch.zeros(1))
        
    def forward(self,x):
        x1,x2 = self.split(x)
        h1_1 = x1
        h1_2 = x2 + self.inputLayer(x1)
        #
        h2_1 = h1_2
        h2_2 = h1_1 + self.hiddenLayer(x2)
        #
        h3_1 = h2_1
        h3_2 = h2_2 + self.hiddenLayer(x1)
        #
        h4_1 = h3_2
        h4_2 = h3_1 + self.outputLayer(x2)
        #
        h4 = torch.cat((h4_1,h4_2))
        h = torch.exp(self.s)*h4
        return h
    
    def backward(self,y):
        y1,y2 = self.split(y)
        x1_1 = y1
        x1_2 = y2 - self.reverseInputLayer(y1)
        # 
        x2_1 = x1_2
        x2_2 = x1_1 - self.hiddenLayer(y2)
        #
        x3_1 = x2_1
        x3_2 = x2_2 - self.hiddenLayer(y1)
        #
        x4_1 = x3_2
        x4_2 = x3_1 - self.hiddenLayer(y2)
        #
        h4 = torch.cat((x4_1,x4_2))
        h = torch.exp(self.si)*h4