In [1]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np

# Load data

In [2]:
path_train = "data/binarized_mnist_train.amat"
path_valid = "data/binarized_mnist_valid.amat"
path_test = "data/binarized_mnist_test.amat"

In [3]:
def data_loader(path):
    with open(path_train) as file:
        x = [
            np.array(list(map(int, img.split()))).reshape(1, 28, 28)
            for img in file.readlines()
        ]
    return torch.from_numpy(np.asarray(x)).float()

In [4]:
train = torch.utils.data.TensorDataset(data_loader(path_train))
trainloader = torch.utils.data.DataLoader(
    train, batch_size=128, drop_last=True, shuffle=True)

In [5]:
valid = torch.utils.data.TensorDataset(data_loader(path_train))
validloader = torch.utils.data.DataLoader(
    train, batch_size=128, drop_last=True, shuffle=False)

In [6]:
test = torch.utils.data.TensorDataset(data_loader(path_train))
testloader = torch.utils.data.DataLoader(
    train, batch_size=128, drop_last=True, shuffle=False)

# Model

In [10]:
class VAE(nn.Module):
    def __init__(self, dim=100, batch_size=128):
        super(VAE, self).__init__()

        self.dim = dim
        self.batch_size = batch_size

        #ENCODER LAYERS
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3), nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64,
                      kernel_size=3), nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5),
            nn.ELU())
        self.hidden_1 = nn.Linear(256, 200)

        #DECODER LAYERS
        self.hidden_2 = nn.Linear(100, 256)
        self.decoder = nn.Sequential(
            nn.ELU(),
            nn.Conv2d(
                in_channels=256, out_channels=64, kernel_size=5, padding=4),
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(
                in_channels=64, out_channels=32, kernel_size=3, padding=2),
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(
                in_channels=32, out_channels=16, kernel_size=3, padding=2),
            nn.ELU(),
            nn.Conv2d(
                in_channels=16, out_channels=1, kernel_size=3, padding=2),
            nn.Sigmoid())

    def encode(self, x):
        q_params = self.hidden_1(self.encoder(x).view(self.batch_size, 256))
        mu = q_params[:, :self.dim]
        logvar = q_params[:, self.dim:]
        return mu, logvar

    def decode(self, z):
        return self.decoder(self.hidden_2(z).view(self.batch_size, 256, 1, 1))

    def sample(self, mu, logvar):
        eps = torch.randn(self.batch_size, self.dim).to(device)
        return mu + eps * (0.5 * logvar).exp()

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.sample(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

    def criterion(self, x_hat, x, mu, logvar):
        BCE = F.binary_cross_entropy(x_hat, x, reduction='mean')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def evaluate(self, loader):
        with torch.no_grad():
            loss = 0
            for i, data in enumerate(loader, 1):
                # get the inputs
                x = data[0].to(device)
                x_hat, mu, logvar = self.forward(x)
                loss += self.criterion(x_hat, x, mu, logvar).item()
        return loss / i

# Train

In [11]:
# create model and move it to device
model = VAE()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Let's use {}".format(device))

Let's use cuda:0


In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        x = data[0].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        x_hat, mu, logvar = model.forward(x)

        loss = model.criterion(x_hat, x, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    # print statistics
    train_loss = running_loss / i
    val_loss = model.evaluate(validloader)
    print('epoch %2d: loss: %5.3f\tval_loss: %5.3f' % (epoch + 1, train_loss,
                                                      val_loss))

epoch  1: loss: 0.647	val_loss: 0.299
epoch  2: loss: 0.288	val_loss: 0.279
epoch  3: loss: 0.277	val_loss: 0.274
epoch  4: loss: 0.273	val_loss: 0.270
epoch  5: loss: 0.270	val_loss: 0.268
