## Import Libraries

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn 
import torch 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torchvision import datasets, transforms
from tqdm import tqdm


## Data processing / import (MNIST)

In [None]:

# Load (downloading on first use) the MNIST train and test splits from the
# same on-disk cache directory.
mnist_root = '../mnist_data'
train_set, test_set = (
    datasets.MNIST(root=mnist_root, train=is_train, download=True)
    for is_train in (True, False)
)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Preprocessing pipeline applied to every image by the dataset wrapper:
# a single resize step down to 20x20.
resize_to_20x20 = transforms.Resize((20, 20))
transform = transforms.Compose([resize_to_20x20])

class custom_mnist(Dataset):
    """Dataset wrapper that yields resized, binarised images without labels.

    Every element of ``input_data`` is indexed as ``element[0]`` to obtain the
    image; any further entries of the element are ignored.
    """

    def __init__(self, input_data):
        super().__init__()
        self.data = input_data

    def __len__(self):
        """Return the number of elements in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return image ``idx`` as a float tensor of 0.0 / 1.0 values."""
        image = self.data[idx][0]
        # ``transform`` is the module-level preprocessing pipeline.
        pixels = np.array(transform(image))
        # Every strictly positive intensity maps to 1.0, everything else to 0.0.
        binary = pixels > 0
        return torch.FloatTensor(binary)

In [4]:
# Wrap both splits so that indexing returns binarised image tensors only.
mnist_train, mnist_test = custom_mnist(train_set), custom_mnist(test_set)

# Mini-batches of 128 images; only the training split is shuffled.
BATCH_SIZE = 128
train_loader = DataLoader(dataset=mnist_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=BATCH_SIZE, shuffle=False)

## Original Data Visualization

In [None]:
# Show the first 100 training images on a 10x10 grid of axes.
f, axarr = plt.subplots(10, 10)

# ``axarr.flat`` walks the grid row by row, so flat position k corresponds to
# row k // 10 and column k % 10, i.e. image index 10 * row + column.
for k, ax in enumerate(axarr.flat):
    ax.imshow(mnist_train[k])
    ax.axis('off')


## MADE Architecture

In [6]:

class MaskedLayer(nn.Linear):
    """Linear layer whose weight matrix is multiplied element-wise by a mask."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # Start with an all-ones mask (no connection removed). Registering it
        # as a buffer keeps it in the state dict and moves it with the module
        # without making it a trainable parameter.
        all_ones = torch.ones(out_features, in_features)
        self.register_buffer('mask', all_ones)

    def set_mask(self, mask):
        """Overwrite the mask in place.

        ``mask`` is given as (in_features, out_features) and is transposed to
        match the (out_features, in_features) layout of ``self.weight``.
        """
        transposed = mask.T
        self.mask.data.copy_(transposed)

    def forward(self, x):
        """Apply the linear map using only the weights the mask keeps."""
        masked_weight = self.weight * self.mask
        return F.linear(x, masked_weight, self.bias)


# d: number of distinct values a pixel can take (256 for 0-255 intensities, 2 for binary 0/1 pixels).
# For binary pixels d = 1 also works: the network then emits a single logit per pixel.

class MADE(nn.Module):
    """Masked autoencoder built from ``MaskedLayer`` blocks separated by ReLUs.

    Every unit is given an integer "degree". Input and output units get the
    degrees ``0 .. input_shape - 1`` repeated ``d`` times; hidden unit ``k`` of
    any hidden layer gets degree ``k``. A connection into a hidden layer is
    kept when the source degree is <= the target degree, and a connection into
    the output layer is kept only when the source degree is strictly smaller.
    Consequently an output of degree ``j`` can only depend on inputs whose
    degree is < ``j``.

    Args:
        input_shape: number of variables (e.g. pixels) per sample.
        d: number of network inputs/outputs per variable; the network has
            ``input_shape * d`` inputs and outputs, laid out as ``d``
            consecutive groups of ``input_shape`` units.
        hidden_layers: sequence with the width of each hidden layer (may be
            empty, in which case inputs connect directly to outputs).

    Attributes:
        net: ``nn.ModuleList`` alternating ``MaskedLayer`` and ``nn.ReLU``
            (no activation after the last layer).
        m: list of degree tensors, one per layer of units (input, hidden...,
            output).
        masks: list with the (in_features, out_features) mask applied to each
            ``MaskedLayer``, in order.
    """

    def __init__(self, input_shape, d,  hidden_layers):
        super().__init__()

        self.in_features = input_shape * d
        self.out_features = input_shape * d

        # list(...) so that tuples and other sequences are accepted as well.
        layers = [self.in_features] + list(hidden_layers) + [self.out_features]

        # Degrees of the units in every layer: inputs, each hidden layer,
        # outputs. Inputs and outputs share the same degree pattern.
        io_degrees = torch.cat(d * [torch.arange(input_shape)])
        self.m = [io_degrees]
        self.m.extend(torch.arange(width) for width in layers[1:-1])
        self.m.append(io_degrees.clone())

        net = []
        # Previously this list held zero tensors that were never filled in;
        # it now holds the masks that are actually installed in the layers.
        self.masks = []

        last = len(layers) - 2  # index of the output layer
        for idx, (n_in, n_out) in enumerate(zip(layers[:-1], layers[1:])):
            source = self.m[idx][:, None]
            target = self.m[idx + 1][None, :]
            if idx == last:
                # Strict inequality: an output never sees its own input.
                mask = (source < target).int()
            else:
                mask = (source <= target).int()

            layer = MaskedLayer(n_in, n_out)
            layer.set_mask(mask)
            self.masks.append(mask)

            net.append(layer)
            if idx != last:
                net.append(nn.ReLU())

        self.net = nn.ModuleList(net)

    def forward(self, x):
        """Return one raw output (logit) per network output unit."""
        # Iterate over the registered modules directly instead of wrapping
        # them in a freshly built nn.Sequential on every call.
        for layer in self.net:
            x = layer(x)
        return x
        
    

## Training

In [None]:

# Train a MADE with one logit per pixel on the binarised 20x20 images
# (400 inputs, two hidden layers of 400 units). The flattened input image is
# also the target of the binary cross-entropy loss.
made = MADE(400,1,[400,400]).to(device)
optimizer = Adam(made.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
num_epochs = 10
train_loss_hist = []  # one Python float per training step
valid_loss_hist = []  # one Python float per epoch

for epoch in range(num_epochs):
    # Wrapping the loader itself (rather than enumerate(...)) lets tqdm see
    # its length and display a proper progress bar.
    for train_data in tqdm(train_loader):
        data = train_data.to(device)
        b,_,_ = train_data.shape

        # Flatten each image to a vector of pixels.
        data = data.view(b,-1)

        pred = made(data)
        loss = criterion(pred, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() stores a plain float: keeping the tensor itself would hold
        # on to autograd/device memory and cannot be plotted directly.
        train_loss_hist.append(loss.item())

    valid_loss = 0.0
    with torch.no_grad():
        for valid_data in tqdm(test_loader):
            data = valid_data.to(device)
            b,_,_ = valid_data.shape

            data = data.view(b,-1)

            pred = made(data)
            valid_loss += criterion(pred, data).item()

    # Average over the number of batches (the last enumerate index would be
    # one too small).
    valid_loss_hist.append(valid_loss / len(test_loader))

plt.plot(train_loss_hist)
plt.show()
plt.plot(valid_loss_hist)



## Sampling 

In [9]:
def sample(model, num_samples, num_input):
    """Draw binary samples from an autoregressive model, one variable at a time.

    Starting from all zeros, variable ``i`` is sampled from a Bernoulli whose
    logit is output ``i`` of ``model`` evaluated on the variables sampled so
    far, for ``i = 0 .. num_input - 1``.

    Args:
        model: module mapping a (num_samples, num_input) tensor to logits of
            the same shape. It is called as given; its mode is not changed.
        num_samples: number of independent samples to draw.
        num_input: number of variables per sample; must be a perfect square so
            the result can be returned as square images (400 -> 20x20).

    Returns:
        CPU tensor of shape (num_samples, side, side) with values 0.0 / 1.0,
        where ``side * side == num_input``.

    Raises:
        ValueError: if ``num_input`` is not a perfect square.
    """
    side = int(round(num_input ** 0.5))
    if side * side != num_input:
        raise ValueError(
            f"num_input must be a perfect square to form an image, got {num_input}"
        )

    # Work on whatever device the model lives on instead of relying on a
    # global; fall back to the CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        model_device = first_param.device
    else:
        model_device = torch.device("cpu")

    # Keep the buffer on the model's device so it is not re-uploaded each step.
    out = torch.zeros(num_samples, num_input, device=model_device)

    with torch.no_grad():
        for i in range(num_input):
            logits = model(out)
            # Only column i is needed at this step.
            out[:, i] = torch.bernoulli(torch.sigmoid(logits[:, i]))

    return out.reshape(num_samples, side, side).cpu()

In [10]:
# Draw 100 samples of 400 binary pixels each from the trained model.
out = sample(model=made, num_samples=100, num_input=400)

## Generation

In [None]:
# Show the 100 generated images on a 10x10 grid of axes.
f, axarr = plt.subplots(10, 10)

# ``axarr.flat`` walks the grid row by row, so flat position k corresponds to
# row k // 10 and column k % 10, i.e. sample index 10 * row + column.
for k, ax in enumerate(axarr.flat):
    ax.imshow(out[k])
    ax.axis('off')

