In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm, tnrange, tqdm_notebook

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

from notify_run import Notify

In [None]:
# Client used by later cells (`notify.send(...)`) to push progress messages
# while training runs.
notify = Notify()
# NOTE(review): presumably registers a notification channel and shows its
# subscription details as the cell output -- confirm against notify_run docs.
notify.register()

In [None]:
# Training hyper-parameters shared by the cells below.
batch_size, epochs = 32, 50
seed = 1
log_interval = 10  # train() logs once every `log_interval` batches

# Set `no_cuda = True` to force CPU execution even when a GPU is present.
no_cuda = False
cuda = (not no_cuda) and torch.cuda.is_available()

# Seed PyTorch's RNG so runs are repeatable.
torch.manual_seed(seed)

device = torch.device("cuda") if cuda else torch.device("cpu")
# Extra DataLoader options that are only populated when running on a GPU.
kwargs = dict(num_workers=1, pin_memory=True) if cuda else dict()

In [None]:
# Tools dataset: images are read from the class sub-folders of ../../data/.
# Each image is resized and centre-cropped to 28x28, which is the input size
# the VAE below is built for (one 2x2 max-pool -> 14x14 feature maps).
dataset = torchvision.datasets.ImageFolder('../../data/',
                                           transform=transforms.Compose([
                                               transforms.Resize(28),
                                               transforms.CenterCrop(28),
                                               transforms.ToTensor()
                                           ]))

# shuffle=True: ImageFolder lists samples class by class, so unshuffled batches
# would each contain (mostly) a single class, which hurts SGD-style training.
# The shuffle order is still reproducible because torch.manual_seed(seed) is
# called in the configuration cell.
# **kwargs: forwards the num_workers / pin_memory settings that the
# configuration cell prepares for GPU runs (empty dict on CPU).
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          **kwargs)

In [None]:
class VAE(nn.Module):
    """Convolutional variational auto-encoder for 3-channel 28x28 images.

    The encoder applies two 3x3 convolutions and a 2x2 max-pool (28x28 ->
    14x14), then maps the flattened features to the mean and log-variance of a
    diagonal Gaussian over a ``zdim``-dimensional latent space.  The decoder
    mirrors this: two linear layers, nearest-neighbour upsampling back to
    28x28, and two transposed convolutions ending in a sigmoid.

    Args:
        zdim: Dimensionality of the latent space.
    """

    def __init__(self, zdim):
        super(VAE, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.pool1 = nn.MaxPool2d(2)

        # Latent vectors: fc21 produces the mean, fc22 the log-variance.
        self.fc1 = nn.Linear(14*14*32, 128)
        self.fc21 = nn.Linear(128, zdim)
        self.fc22 = nn.Linear(128, zdim)

        # Decoder
        self.fc3 = nn.Linear(zdim, 128)
        self.fc4 = nn.Linear(128, 14*14*32)

        self.conv3 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1, bias=False)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick).

        Args:
            mu: Latent means.
            log_var: Latent log-variances, same shape as ``mu``.

        Returns:
            A sample ``z`` with the same shape as ``mu``.
        """
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def encode(self, x):
        """Encode a batch of images into latent Gaussian parameters.

        Args:
            x: Image batch of shape ``(N, 3, 28, 28)``.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(N, zdim)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # F.dropout defaults to training=True, which would keep dropout active
        # even after model.eval(); tie it to the module's mode instead.
        x = F.dropout(self.pool1(x), training=self.training)
        x = x.view(-1, 14 * 14 * 32)
        x = F.relu(self.fc1(x))
        return self.fc21(x), self.fc22(x)

    def decode(self, z):
        """Decode latent vectors into images.

        Args:
            z: Latent batch of shape ``(N, zdim)``.

        Returns:
            Reconstructions of shape ``(N, 3, 28, 28)`` with values in [0, 1]
            (sigmoid output).
        """
        z = F.relu(self.fc3(z))
        z = F.relu(self.fc4(z))
        z = z.view(-1, 32, 14, 14)
        # Undo the encoder's 2x2 max-pool: 14x14 -> 28x28.
        z = F.interpolate(z, scale_factor=2)
        z = F.relu(self.conv3(z))
        z = torch.sigmoid(self.conv4(z))
        return z

    def forward(self, x):
        """Run encode -> sample -> decode.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.sampling(mu, log_var)
        return self.decode(z), mu, log_var

In [None]:
# Build the 2-D latent VAE, place it on the selected device, and attach an
# Adam optimizer over all of its parameters.
model = VAE(zdim=2)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def loss_function(recon_x, x, mu, log_var):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstructed batch with values in [0, 1].
        x: Target batch, same shape as ``recon_x``.
        mu: Latent means.
        log_var: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over all elements.
    kl_elements = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + kl_term

In [None]:
# Per-epoch loss histories; train() appends to `train_losses`, and both lists
# are plotted near the end of the notebook.
val_losses, train_losses = [], []

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Logs the per-sample loss every ``log_interval`` batches, then logs the
    epoch's average per-sample loss and appends it to ``train_losses``.

    Args:
        epoch: Epoch number, used only in the log messages.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    num_batches = len(train_loader)
    running_loss = 0
    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()
        if step % log_interval != 0:
            continue
        # loss_function sums over the batch, so divide by the batch length to
        # report a per-sample value.
        tqdm.write('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(images), dataset_size,
            100. * step / num_batches,
            batch_loss.item() / len(images)))

    epoch_average = running_loss / dataset_size
    tqdm.write('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, epoch_average))
    train_losses.append(epoch_average)

In [None]:
def test(epoch, loader=None):
    """Evaluate the model on held-out data and record the validation loss.

    Saves a side-by-side image of up to 8 inputs and their reconstructions
    from the first batch, logs the average per-sample loss, appends it to
    ``val_losses`` and returns it.

    Args:
        epoch: Epoch number, used in the name of the saved comparison image.
        loader: DataLoader to evaluate on.  Defaults to the notebook-level
            ``test_loader``.
            NOTE(review): no ``test_loader`` is defined in the visible cells;
            define one or pass ``loader`` explicitly.

    Returns:
        Average per-sample loss over ``loader``.
    """
    if loader is None:
        loader = test_loader
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape to the input's own shape instead of a hard-coded
                # (batch_size, 3, 64, 64): the images are not 64x64 and the
                # batch may be shorter than batch_size.
                comparison = torch.cat([data[:n],
                                        recon_batch.view_as(data)[:n]])
                save_image(comparison.cpu(),
                           './reconstruction_upsample_' + str(epoch) + '.png', nrow=n)

    # loss_function sums over samples, so normalise by the dataset size
    # (guarding against an empty dataset) -- mirrors train().
    test_loss /= max(len(loader.dataset), 1)
    tqdm.write('====> Test set loss: {:.4f}'.format(test_loss))
    val_losses.append(test_loss)
    return test_loss

In [None]:
# Main training loop: one notification plus one train() pass per epoch.
notify.send('Starting training')
for epoch in tnrange(epochs):
    # The loop index is 0-based; messages and train() use 1-based numbering.
    current_epoch = epoch + 1
    notify.send('Training on epoch {}'.format(current_epoch))
    train(current_epoch)

In [None]:
# Persist the trained weights (state dict only), then send a notification.
trained_weights = model.state_dict()
torch.save(trained_weights, "./mnist-vae-cnn-tools.torch")
notify.send("Saved model")

In [None]:
# map_location=device remaps the stored tensors onto the device selected in the
# configuration cell, so a checkpoint written on a GPU machine can still be
# loaded when this notebook falls back to CPU.
model.load_state_dict(torch.load("./mnist-vae-cnn-tools.torch", map_location=device))

In [None]:
# Plot the recorded validation losses on the current axes and save the figure.
val_ax = plt.gca()
val_ax.plot(val_losses)
val_ax.set_title('Tool Validation Loss\nCNN VAE 50 epochs\nz=2')
plt.savefig('./validation_loss.png')

In [None]:
# Plot the per-epoch training losses on the current axes and save the figure.
train_ax = plt.gca()
train_ax.plot(train_losses)
train_ax.set_title('Tool Training Loss\nCNN VAE 50 epochs\nz=2')
plt.savefig('./training_loss.png')

In [None]:
# Decode 64 latent vectors drawn from the N(0, I) prior and save them as one
# image grid.
with torch.no_grad():
    z = torch.randn(64, 2)
    # Move to the configured device rather than calling .cuda(), so the cell
    # also runs when the notebook has fallen back to CPU.
    sample = model.decode(z.to(device))
    save_image(sample.view(64, 3, 28, 28).cpu(), './sample_zdim_{}'.format(2) + '.png')

In [None]:
# Visualise the 2-D latent space: decode an n x n grid of latent points and
# tile the resulting images into a single picture.
n = 15
digit_size = 28

# Evenly spaced probabilities in (0, 1), mapped through the inverse CDF of the
# standard normal so the grid points follow the latent prior's quantiles.
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, n),
                               np.linspace(0.05, 0.95, n)))
z_grid = norm.ppf(u_grid)

# Use the configured device instead of .cuda() so this also runs on CPU, and
# skip autograd bookkeeping since the decoder is only used for inference here.
with torch.no_grad():
    z_batch = torch.from_numpy(z_grid.reshape(n*n, 2)).float().to(device)
    x_decoded = model.decode(z_batch)
x_decoded = x_decoded.reshape(n, n, 3, digit_size, digit_size)

plt.figure(figsize=(10, 10))
# np.block tiles the n x n nested list of (3, H, W) images into one
# (3, n*H, n*W) array; transpose puts channels last for imshow.
plt.imshow(np.block(list(map(list, x_decoded.detach().cpu().numpy()))).transpose(1,2,0))
plt.savefig('./latent_dimension_sample.png')
plt.show()