In [1]:
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import FashionMNIST
import random
import os

# Define the hyperparameters and load the training dataset

In [2]:
# Directory that receives the image grids written by the training cell.
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs('./vae_img', exist_ok=True)


# Training hyper-parameters.
num_epochs = 5
batch_size = 32
learning_rate = 0.001

# Fix the Python and PyTorch RNG seeds so repeated runs are comparable.
random.seed(2019)
torch.manual_seed(2019)

# Only conversion to a tensor is applied; no normalisation or augmentation.
img_transform = transforms.Compose([
    transforms.ToTensor()
])

# FashionMNIST is fetched into ./data if it is not already present, then
# served in shuffled mini-batches of `batch_size` images.
dataset = FashionMNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Define a VAE architecture

In [3]:
class VAE(nn.Module):
    """Small convolutional variational auto-encoder.

    Encoder: two stride-2 convolutions followed by two linear heads that
    produce the latent mean and log-variance.
    Decoder: a linear layer back to the flattened feature size followed by two
    stride-2 transposed convolutions and a sigmoid.

    Args:
        image_channels: number of channels of the input (and output) images.
        h_dim: size of the flattened encoder output. ``decode`` reshapes to
            ``(64, 5, 5)``, so this must be ``64 * 5 * 5 = 1600``, which is
            what the two convolutions yield for 28x28 inputs.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, image_channels=1, h_dim=1600, z_dim=10):
        super(VAE, self).__init__()

        # Encoder feature extractor.
        self.conv1 = nn.Conv2d(image_channels, 32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)

        # fc1 -> latent mean, fc2 -> latent log-variance, fc3 -> decoder input.
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)

        # Decoder: mirrors the encoder's spatial sizes (5 -> 13 -> 28).
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2)
        self.deconv1 = nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it always lives on the
        same device and has the same dtype as ``logvar``, regardless of
        whether CUDA happens to be available on the machine.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Map flattened features to ``(z, mu, logvar)``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        """Encode an image batch and return ``(z, mu, logvar)``."""
        h_1 = F.relu(self.conv1(x))
        h_2 = F.relu(self.conv2(h_1))
        # Flatten everything except the batch dimension.
        flat = h_2.view(h_2.size(0), -1)
        return self.bottleneck(flat)

    def decode(self, z):
        """Decode latent codes into images with values in [0, 1]."""
        fc = self.fc3(z)
        # Hard-coded to the encoder's output shape; requires h_dim == 64*5*5.
        reshape = fc.view(fc.size(0), 64, 5, 5)
        h_2 = F.relu(self.deconv2(reshape))
        h_1 = self.deconv1(h_2)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(h_1)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        z, mu, logvar = self.encode(x)
        recon = self.decode(z)
        return recon, mu, logvar


# Build the network with its default sizes; when a CUDA device is present,
# keep the parameters on the GPU.
model = VAE()
if torch.cuda.is_available():
    model = model.cuda()

# Define the loss function and optimizer

In [4]:
# Summed (not averaged) squared error over every element of the batch.
# `reduction='sum'` is the supported spelling of the deprecated
# `size_average=False`.
reconstruction_function = nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Args:
        recon_x: reconstructed (generated) images.
        x: original images, same shape as ``recon_x``.
        mu: latent mean.
        logvar: latent log-variance.

    Returns:
        A scalar tensor: the summed squared error between ``recon_x`` and
        ``x`` plus ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``, the KL
        divergence between N(mu, exp(logvar)) and the standard normal prior.
        Both terms are sums over the whole batch, not means.
    """
    # Despite the historical "BCE" name, the reconstruction term is an MSE.
    recon_loss = reconstruction_function(recon_x, x)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kld


# Adam over all of the model's parameters, stepping with `learning_rate`.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)



# Train the model and save the reconstructed images

In [5]:
# Feed batches to whatever device the model's parameters actually live on,
# instead of re-checking CUDA availability for every batch.
device = next(model.parameters()).device
# Write the input/reconstruction image grids every `save_interval` epochs.
save_interval = 1

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data  # labels are not used by the auto-encoder
        img = img.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(img)
        loss = loss_function(recon_batch, img, mu, logvar)
        loss.backward()
        optimizer.step()
        # Read the scalar once; it is reused for the running total and the log.
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                batch_loss / len(img)))

    # The loss is summed over samples, so divide by the dataset size.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % save_interval == 0:
        # Save the last batch of the epoch and its reconstruction. Both are
        # detached and moved to the CPU first (the original moved the input
        # to the CPU but then saved the un-moved tensor).
        save_image(img.detach().cpu(), './vae_img/input_{}.png'.format(epoch))
        save_image(recon_batch.detach().cpu(), './vae_img/output_{}.png'.format(epoch))

torch.save(model.state_dict(), './vae.pth')



====> Epoch: 0 Average loss: 32.2195
====> Epoch: 1 Average loss: 26.9809
====> Epoch: 2 Average loss: 26.3041
====> Epoch: 3 Average loss: 25.9495
====> Epoch: 4 Average loss: 25.7153
