In [None]:
import os

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from comet_ml import Experiment

In [None]:
# Read the Comet API key from the environment instead of hard-coding it in the notebook
# (a key that was ever committed here should be treated as leaked and rotated).
# If COMET_API_KEY is unset, api_key=None is passed through; comet_ml is documented to then
# fall back to its own configuration lookup -- verify for the installed comet_ml version.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"), project_name="pytorch-vae")

In [None]:
# Geometry shared by the data pipeline, Encoder, Decoder and VAE heads:
# square image side in pixels, latent code width, and number of image channels.
img_size, latent_dim, channels = 64, 128, 3

In [None]:
# The encoder applies four 2x max-pools, so each spatial side shrinks by 16;
# its last stage emits 256 feature maps, which are flattened into h_dim features.
ds_size = img_size // 16
h_dim = 256 * ds_size * ds_size
h_dim

In [None]:
# Legacy CUDA float32 tensor type. Code below calls `Tensor(...)` / `.type(Tensor)`
# to build or move tensors onto the GPU, so this notebook requires CUDA.
Tensor = torch.cuda.FloatTensor

In [None]:
class Normal(object):
    """Plain container for the parameters of a diagonal Gaussian.

    Args:
        mu: mean tensor.
        sigma: standard deviation -- either the stdev diagonal itself, or the stdev
            diagonal from a decomposition.
        log_sigma: log standard deviation (stored as ``self.logsigma``).
        v, r: optional auxiliary tensors. When omitted, an *uninitialized* tensor with
            ``mu``'s shape, dtype and device is allocated for each.
    """

    def __init__(self, mu, sigma, log_sigma, v=None, r=None):
        self.mu = mu
        self.sigma = sigma  # either stdev diagonal itself, or stdev diagonal from decomposition
        self.logsigma = log_sigma
        # torch tensors have no `get_shape()` (that is TensorFlow API); use `mu.shape`.
        # `new_empty` allocates on mu's device/dtype instead of always forcing CUDA float32.
        if v is None:
            v = mu.new_empty(mu.shape)
        if r is None:
            r = mu.new_empty(mu.shape)
        self.v = v
        self.r = r

In [None]:
class Encoder(nn.Module):
    """Convolutional feature extractor for the VAE.

    Four stages of (Conv 3x3 -> Conv 3x3 -> LeakyReLU(0.2) -> MaxPool 2x2 -> Dropout2d(0.25)
    -> BatchNorm2d) with widths 32, 64, 128, 256. Each stage halves H and W, so an input of
    shape ``(batch, in_channels, H, W)`` is returned flattened to
    ``(batch, 256 * (H // 16) * (W // 16))``.

    Args:
        in_channels: number of input image channels; defaults to the notebook-level ``channels``.
    """

    def __init__(self, in_channels=None):
        super(Encoder, self).__init__()
        if in_channels is None:
            in_channels = channels

        widths = [in_channels, 32, 64, 128, 256]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend(self._conv_stage(c_in, c_out))
        # One flat Sequential with the same layer order as a hand-written list, so the
        # state_dict keys (conv_blocks.0 ... conv_blocks.23) match saved checkpoints.
        self.conv_blocks = nn.Sequential(*layers)

    @staticmethod
    def _conv_stage(c_in, c_out):
        """Return the six layers of one downsampling stage (c_in -> c_out channels, H and W halved)."""
        return [
            nn.Conv2d(c_in, c_out, 3, 1, 1),
            nn.Conv2d(c_out, c_out, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            # Default eps. Do not pass 0.8 positionally: BatchNorm2d's second parameter is
            # `eps` (not `momentum`), and eps=0.8 is added to the variance in the denominator,
            # which weakens the normalization whenever activation variance is small.
            nn.BatchNorm2d(c_out),
        ]

    def forward(self, img):
        """Encode ``img`` and flatten every non-batch dimension into one feature vector per sample."""
        out = self.conv_blocks(img)
        return out.view(out.shape[0], -1)

In [None]:
class Decoder(nn.Module):
    """Convolutional generator mapping latent codes back to images.

    A linear layer projects ``z`` to a ``256 x init_size x init_size`` feature map. Four stages of
    (BatchNorm2d -> 2x Upsample -> Conv 3x3 -> Conv 3x3 [-> LeakyReLU(0.2)]) then grow each
    spatial side 16x while narrowing channels 256 -> 128 -> 64 -> 32 -> out_channels. A final Tanh
    bounds the output to [-1, 1]. The output side is ``16 * (out_size // 16)``, i.e. exactly
    ``out_size`` only when it is a multiple of 16.

    Args:
        out_size: target image side; defaults to the notebook-level ``img_size``.
        z_dim: latent width; defaults to the notebook-level ``latent_dim``.
        out_channels: output image channels; defaults to the notebook-level ``channels``.
    """

    def __init__(self, out_size=None, z_dim=None, out_channels=None):
        super(Decoder, self).__init__()
        out_size = img_size if out_size is None else out_size
        z_dim = latent_dim if z_dim is None else z_dim
        out_channels = channels if out_channels is None else out_channels

        self.init_size = out_size // 2**4
        self.l1 = nn.Sequential(nn.Linear(z_dim, 256*self.init_size**2))

        # BatchNorm2d layers use the default eps. Passing 0.8 as the second positional
        # argument sets `eps` (not `momentum`), which weakens the normalization.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.BatchNorm2d(64),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.BatchNorm2d(32),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.Conv2d(32, out_channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Decode latent codes ``z`` of shape ``(batch, z_dim)`` into an image batch."""
        out = self.l1(z)
        out = out.view(out.shape[0], 256, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [None]:
class VAE(nn.Module):
    """Variational autoencoder: encoder -> diagonal-Gaussian latent -> decoder.

    Args:
        encoder: module mapping an input batch to flat features of width ``hidden_dim``.
        decoder: module mapping latent codes ``(batch, z_dim)`` back to input space.
        hidden_dim: encoder output width; defaults to the notebook-level ``h_dim``.
        z_dim: latent width; defaults to the notebook-level ``latent_dim``.

    After each forward pass, ``self.z_mean`` and ``self.z_sigma`` hold the posterior mean
    and standard deviation for the last batch.
    """

    def __init__(self, encoder, decoder, hidden_dim=None, z_dim=None):
        super(VAE, self).__init__()
        hidden_dim = h_dim if hidden_dim is None else hidden_dim
        z_dim = latent_dim if z_dim is None else z_dim
        # Registration order (encoder, decoder, heads) fixes parameter order and state_dict keys.
        self.encoder = encoder
        self.decoder = decoder
        self._enc_mu = nn.Linear(hidden_dim, z_dim)
        self._enc_log_sigma = nn.Linear(hidden_dim, z_dim)

    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2) via the reparameterization trick.

        Also stores ``mu`` and ``sigma`` on ``self.z_mean`` / ``self.z_sigma``.
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        # eps ~ N(0, I), drawn directly on sigma's device/dtype (no numpy CPU round-trip,
        # no dependence on a CUDA-only tensor type); it carries no gradient.
        eps = torch.randn_like(sigma)

        self.z_mean = mu
        self.z_sigma = sigma

        return mu + sigma * eps  # Reparameterization trick

    def forward(self, state):
        """Encode ``state``, sample a latent code, and return the decoder's output."""
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

In [None]:
def latent_loss(z_mean, z_stddev):
    """KL(N(mu, sigma^2) || N(0, 1)) averaged over every element.

    Computes ``0.5 * mean(mu^2 + sigma^2 - log(sigma^2) - 1)``. Note the mean runs over both
    the batch and the latent dimensions (not a sum over latent dims).

    Args:
        z_mean: posterior means.
        z_stddev: posterior standard deviations, same shape; a zero entry yields an infinite loss.

    Returns:
        0-dim tensor.
    """
    variance = z_stddev * z_stddev
    per_element = z_mean * z_mean + variance - torch.log(variance) - 1
    return 0.5 * per_element.mean()

In [None]:
# NOTE(review): `device` is not referenced by any cell below -- models and tensors are
# moved with `.cuda()` / the `Tensor` alias instead.
device = torch.device("cuda")

In [None]:
# Build both halves on the GPU, then wrap them; the outer .cuda() also moves the VAE's
# own linear heads (mu / log-sigma).
encoder, decoder = Encoder().cuda(), Decoder().cuda()
vae = VAE(encoder, decoder).cuda()

In [None]:
# Show the architecture of each network, one module repr after another.
print(encoder, decoder, vae, sep="\n")

In [None]:
def weights_init_normal(m):
    """Per-module initializer intended for ``Module.apply``.

    Any module whose class name contains ``'Conv'``: weight ~ N(0, 0.02) (bias untouched).
    Any module whose class name contains ``'BatchNorm2d'``: weight ~ N(1, 0.02), bias = 0.
    All other modules (Linear, activations, containers, ...) are left unchanged.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm2d' in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Initialize weights: Module.apply calls weights_init_normal on every submodule (children
# first, then vae itself), re-initializing Conv and BatchNorm2d layers; Linear layers keep
# PyTorch's default init. The returned module is shown as the cell output.
vae.apply(weights_init_normal)

In [None]:
# Root of the CelebA ImageFolder tree. Set the CELEBA_ROOT environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
dataroot = os.environ.get("CELEBA_ROOT", "/home/santiago/Downloads/celebA/")

In [None]:
batch_size = 64
workers = 4
dataset = datasets.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.CenterCrop(128),
        # Resize to the configured model resolution rather than a literal 64, so changing
        # img_size keeps the data in sync with the Encoder/Decoder geometry.
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1] per channel, the same range as the decoder's Tanh output.
        transforms.Normalize((0.5,) * channels, (0.5,) * channels),
    ]),
)
assert len(dataset) > 0, "no images found under %s" % dataroot
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

In [None]:
# Adam over every VAE parameter (encoder, decoder and latent heads), learning rate 1e-4.
optimizer = optim.Adam(params=vae.parameters(), lr=1e-4)
# Pixel-wise reconstruction loss (mean squared error); it has no parameters to move.
criterion = nn.MSELoss().cuda()

In [None]:
# Sanity check: how many images ImageFolder found under dataroot.
print('Number of samples: ', len(dataset), sep=' ')

In [None]:
# Output folders (relative to the notebook's working directory) for checkpoints and
# sample grids; exist_ok makes re-running this cell harmless.
os.makedirs(name="../checkpoints", exist_ok=True)
os.makedirs(name="../images", exist_ok=True)

In [None]:
# Number of passes over the dataset, and how many optimizer steps between sample/checkpoint dumps.
n_epochs, sample_interval = 1, 100

In [None]:
# Global optimizer-step counter; used as the Comet step and in sample/checkpoint file names.
batches_done = 0

In [None]:
# Train the VAE. Per batch: minimize reconstruction MSE + KL term, log the loss to Comet, and
# every `sample_interval` global steps save a sample grid plus all checkpoints.
vae.train()  # ensure BatchNorm/Dropout are in training mode even if eval() was called earlier
with experiment.train():
    for epoch in range(n_epochs):
        for i, data in enumerate(dataloader):
            # ImageFolder batches are (images, labels); labels are unused. `.type(Tensor)`
            # casts the images to the CUDA float tensor type.
            inputs = data[0].type(Tensor)
            optimizer.zero_grad()
            dec = vae(inputs)
            # vae.forward stored this batch's posterior mean/stddev on the module.
            ll = latent_loss(vae.z_mean, vae.z_sigma)
            loss = criterion(dec, inputs) + ll
            loss.backward()
            optimizer.step()
            # loss is a 0-dim tensor: `loss.data[0]` raises IndexError on PyTorch >= 0.4,
            # while .item() returns the Python float.
            l = loss.item()
            batches_done += 1
            experiment.log_metric("loss", l, step=batches_done)
            if batches_done % sample_interval == 0:
                save_image(dec.detach()[:25], "../images/%d.png" % batches_done, nrow=5, normalize=True)
                torch.save(encoder.state_dict(), '../checkpoints/encoder_%d.pth' % batches_done)
                torch.save(decoder.state_dict(), '../checkpoints/decoder_%d.pth' % batches_done)
                torch.save(vae.state_dict(), '../checkpoints/vae_%d.pth' % batches_done)
                torch.save(optimizer.state_dict(), '../checkpoints/optimizer_%d.pth' % batches_done)
            print("epoch: {}/{}, step: {}/{}, global_step: {}, loss: {}".format(epoch, n_epochs, i, len(dataloader), batches_done, l))

In [None]:
# Total optimizer steps taken so far (shown as the cell output).
batches_done

In [None]:
# Save a reconstruction grid from the last training batch, then checkpoint every
# stateful object under the current global step.
save_image(dec.data[:25], "../images/%d.png" % batches_done, nrow=5, normalize=True)
for ckpt_path, stateful in (
    ('../checkpoints/encoder_%d.pth' % batches_done, encoder),
    ('../checkpoints/decoder_%d.pth' % batches_done, decoder),
    ('../checkpoints/vae_%d.pth' % batches_done, vae),
    ('../checkpoints/optimizer_%d.pth' % batches_done, optimizer),
):
    torch.save(stateful.state_dict(), ckpt_path)

In [None]:
# Latent codes drawn from the standard-normal prior, one per image of the last training batch.
z = Variable(Tensor(np.random.normal(loc=0, scale=1, size=(inputs.shape[0], latent_dim))))

In [None]:
# Decode the prior samples in inference mode. In train mode, BatchNorm would normalize with
# statistics of this random batch and also overwrite its running statistics; no_grad avoids
# building an autograd graph. The previous train/eval mode is restored afterwards (even on
# error) so a later re-run of the training cell still trains in train mode.
was_training = decoder.training
decoder.eval()
try:
    with torch.no_grad():
        test = decoder(z)
finally:
    decoder.train(was_training)

In [None]:
# Write up to 25 prior samples as a 5-wide grid, min-max rescaled to [0, 1] for display.
save_image(test.data[:25], "../images/test.png", normalize=True, nrow=5)