In [None]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [None]:
path_train = "data/binarized_mnist_train.amat"
path_valid = "data/binarized_mnist_valid.amat"
path_test = "data/binarized_mnist_test.amat"


def data_loader(path):
    """Load a binarized-MNIST ``.amat`` file as a float32 tensor of images.

    Each line of the file is one image: 28 * 28 whitespace-separated values.

    Args:
        path: path of the ``.amat`` file to read.

    Returns:
        torch.FloatTensor of shape (N, 28, 28), one image per line of the file.
    """
    # ndmin=2 keeps a single-line file as a (1, 784) array instead of (784,)
    pixels = np.loadtxt(path, dtype=np.float32, ndmin=2)
    return torch.from_numpy(pixels.reshape(-1, 28, 28))
# Each .amat row holds exactly 28*28 values (data_loader reshapes it to 28x28),
# so there is no label column: every dataset below yields 1-tuples of images
# only, which is all the VAE objective needs.
BATCH_SIZE = 128

train = torch.utils.data.TensorDataset(data_loader(path_train))
trainloader = torch.utils.data.DataLoader(
    train, batch_size=BATCH_SIZE, drop_last=True, shuffle=True)

test = torch.utils.data.TensorDataset(data_loader(path_test))
testloader = torch.utils.data.DataLoader(
    test, batch_size=BATCH_SIZE, drop_last=True, shuffle=False)

# NOTE(review): drop_last=True also discards the final partial batch from the
# valid/test evaluation; consider drop_last=False if the model accepts any batch size.
valid = torch.utils.data.TensorDataset(data_loader(path_valid))
validloader = torch.utils.data.DataLoader(
    valid, batch_size=BATCH_SIZE, drop_last=True, shuffle=False)

In [None]:
class VAE(nn.Module):
    """Convolutional variational auto-encoder for 28x28 single-channel images.

    Args:
        dim: size of the latent code z.
        batch_size: nominal mini-batch size, kept as an attribute for backward
            compatibility; the model itself accepts any batch size.
    """

    def __init__(self, dim=100, batch_size=128):
        super(VAE, self).__init__()

        self.dim = dim
        self.batch_size = batch_size

        # ENCODER: (B, 1, 28, 28) -> (B, 256, 1, 1) -> (B, 2 * dim) = [mu | log-variance]
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3), nn.ELU(),    # 26x26
            nn.AvgPool2d(kernel_size=2, stride=2),                                 # 13x13
            nn.Conv2d(in_channels=32, out_channels=64,
                      kernel_size=3), nn.ELU(),                                    # 11x11
            nn.AvgPool2d(kernel_size=2, stride=2),                                 # 5x5
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5),            # 1x1
            nn.ELU(),
            # Linear acts on the last axis, so drop the 1x1 spatial dims first
            nn.Flatten(),
            nn.Linear(256, 2 * dim))

        # DECODER: (B, dim) -> (B, 256, 1, 1) -> (B, 1, 28, 28) logits
        self.decoder = nn.Sequential(
            nn.Linear(dim, 256), nn.ELU(),
            # back to a 1x1 feature map so the convolutions can grow it
            nn.Unflatten(1, (256, 1, 1)),
            nn.Conv2d(
                in_channels=256, out_channels=64, kernel_size=5, padding=4),  # 5x5
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),                # 10x10
            nn.Conv2d(
                in_channels=64, out_channels=32, kernel_size=3, padding=2),   # 12x12
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),                # 24x24
            nn.Conv2d(
                in_channels=32, out_channels=16, kernel_size=3, padding=2),   # 26x26
            nn.ELU(),
            nn.Conv2d(
                in_channels=16, out_channels=1, kernel_size=3, padding=2))    # 28x28

    def encode(self, x):
        """Return (mu, std) of the diagonal Gaussian q(z | x).

        Args:
            x: images of shape (B, 1, 28, 28); inputs without the channel axis
               (e.g. (B, 28, 28)) are reshaped to (B, 1, 28, 28).

        Returns:
            mu, std: tensors of shape (B, dim).
        """
        if x.dim() != 4:
            x = x.reshape(-1, 1, 28, 28)
        q_params = self.encoder(x)
        mu = q_params[:, :self.dim]
        # the second half is read as a log-variance, so std = exp(logvar / 2)
        std = torch.exp(0.5 * q_params[:, self.dim:])
        return mu, std

    def decode(self, z):
        """Map latent codes of shape (B, dim) to logits of shape (B, 1, 28, 28)."""
        return self.decoder(z)

    def sample(self, mu, std):
        """Reparameterisation trick: z = mu + eps * std with eps ~ N(0, I)."""
        # randn_like follows the actual batch size, dtype and device of std
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode; returns (logits, mu, std)."""
        mu, std = self.encode(x)
        z = self.sample(mu, std)
        x_hat = self.decode(z)
        return x_hat, mu, std

    @staticmethod
    def criterion(x_hat, x, mu, std):
        """Negative ELBO summed over the batch.

        Args:
            x_hat: decoder logits.
            x: targets in [0, 1]; reshaped to x_hat's shape.
            mu, std: parameters of the diagonal Gaussian q(z | x).

        Returns:
            Scalar tensor: reconstruction BCE + KL(q(z|x) || N(0, I)).
        """
        x = x.reshape_as(x_hat)
        # the decoder emits unbounded logits; the *_with_logits form applies the
        # sigmoid internally and is numerically stable
        BCE = F.binary_cross_entropy_with_logits(x_hat, x, reduction='sum')
        KL = 0.5 * torch.sum(-1 - 2 * torch.log(std) + mu.pow(2) + std.pow(2))
        return BCE + KL

    def _run_epoch(self, loader, device, optimizer=None):
        """One pass over `loader`; updates the weights iff `optimizer` is given.

        Returns the mean per-example loss, or nan for an empty loader.
        """
        training = optimizer is not None
        self.train(training)
        total, count = 0.0, 0
        with torch.set_grad_enabled(training):
            for batch in loader:
                # single-tensor TensorDataset batches arrive as a 1-element list
                inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
                x = inputs.to(device).float().reshape(-1, 1, 28, 28)
                x_hat, mu, std = self.forward(x)
                loss = self.criterion(x_hat, x, mu, std)
                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                total += loss.item()
                count += x.size(0)
        return total / count if count else float('nan')

    def mytrain(self, trainloader, validloader, epochs, is_train=False,
                optimizer=None):
        """Run `epochs` passes over `trainloader` and evaluate on `validloader`.

        Args:
            trainloader: yields image batches, either bare tensors or tuples whose
                first element holds the images (as a single-tensor TensorDataset does).
            validloader: same format, evaluated after every epoch; None skips it.
            epochs: number of passes over `trainloader`.
            is_train: True puts the model in train mode and updates its weights;
                False only computes the loss (eval mode, no gradients).
            optimizer: optimizer over ``self.parameters()``; when None and
                `is_train` is True, a fresh ``Adam(lr=3e-4)`` is created.

        Returns:
            List with one (mean train loss, mean valid loss or None) tuple per
            epoch; losses are the negative ELBO per example.
        """
        if optimizer is None and is_train:
            optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
        device = next(self.parameters()).device

        history = []
        for epoch in range(epochs):
            train_loss = self._run_epoch(
                trainloader, device, optimizer if is_train else None)
            valid_loss = None
            if validloader is not None:
                valid_loss = self._run_epoch(validloader, device)
            history.append((train_loss, valid_loss))

            if valid_loss is None:
                print('[epoch %d] train loss: %.3f' % (epoch + 1, train_loss))
            else:
                print('[epoch %d] train loss: %.3f  valid loss: %.3f' %
                      (epoch + 1, train_loss, valid_loss))

        # leave the module in the mode the caller asked for
        self.train(is_train)
        return history

In [None]:
# Seed torch's RNG so weight initialisation (and other torch random draws) is reproducible.
torch.manual_seed(0)
model = VAE()
# Prefer the second GPU ("cuda:1") when there is one; fall back to the first
# GPU, then the CPU, so the notebook also runs on single-GPU / CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model.to(device)
print("Let\'s use {}".format(device))

In [None]:
# Adam over every parameter of the VAE, learning rate 3e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

In [None]:
# Train the VAE for five epochs on the training split.
model.mytrain(trainloader, validloader, epochs=5, is_train=True)