In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# ToTensor converts each image to a float tensor; Normalize is currently disabled.
transform = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

root = '/root/data/SVHN/'

# Load the SVHN train / test splits from `root` (must already be on disk: download=False).
train_dataset = torchvision.datasets.SVHN(root=root, split='train', transform=transform, download=False)
test_dataset = torchvision.datasets.SVHN(root=root, split='test', transform=transform, download=False)

bs = 100
# Shuffle the training split so mini-batch composition changes every epoch; the order
# of the evaluation split does not matter, so it is read in a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [3]:
class ResidualBlock(nn.Module):
    """Channel- and size-preserving residual unit used by VQVAE's encoder and decoder.

    Computes ``relu(x + bn2(conv1x1(relu(bn1(conv3x3(x))))))``. The 3x3 convolution
    uses stride 1 and padding 1, so the output has the same shape as the input.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        # One BatchNorm per convolution. Sharing a single BatchNorm between both
        # convolutions ties the affine parameters of two different activations
        # together and blends their running statistics, which makes eval-mode
        # normalisation wrong for both.
        self.batchnorm2d = nn.BatchNorm2d(dim)
        self.batchnorm2d_2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        out = self.relu(self.batchnorm2d(self.conv1(x)))
        # Models pickled with the older shared-BatchNorm layout have no batchnorm2d_2;
        # fall back to the shared module so such checkpoints still load and run.
        bn2 = getattr(self, 'batchnorm2d_2', self.batchnorm2d)
        out = bn2(self.conv2(out))
        return self.relu(x + out)


class VQVAE(nn.Module):
    """Convolutional VQ-VAE: encoder -> nearest-codebook quantisation -> decoder.

    The encoder applies two (kernel 4, stride 2, padding 1) convolutions, each halving
    H and W, followed by two ResidualBlocks; the decoder mirrors it with transposed
    convolutions.

    Args:
        dim: channel widths ``[input_channels, hidden_channels, latent_channels]``.
        n_embedding: number of codebook vectors K (each of size ``dim[2]``).
    """

    def __init__(self, dim, n_embedding):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(dim[0], dim[1], 4, 2, 1),
                                     nn.ReLU(), nn.Conv2d(dim[1], dim[2], 4, 2, 1),
                                     nn.ReLU(),
                                     ResidualBlock(dim[2]), ResidualBlock(dim[2]))
        # Codebook [K, C], initialised uniformly in [-1/K, 1/K].
        self.vq_embedding = nn.Embedding(n_embedding, dim[2])
        self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding,
                                               1.0 / n_embedding)
        self.decoder = nn.Sequential(
            ResidualBlock(dim[2]), ResidualBlock(dim[2]),
            nn.ConvTranspose2d(dim[2], dim[1], 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(dim[1], dim[0], 4, 2, 1))
        # Number of stride-2 stages in the encoder (used by get_latent_HW).
        self.n_downsample = 2

    def _nearest_neighbor(self, ze):
        """Return the index of the closest codebook vector at every position of ``ze``.

        Args:
            ze: encoder output of shape [N, C, H, W].
        Returns:
            Integer tensor of shape [N, H, W] with values in [0, K).
        """
        # argmin is not differentiable, so no autograd graph is needed here.
        embedding = self.vq_embedding.weight.detach()     # [K, C]
        ze = ze.detach()                                  # [N, C, H, W]
        # ||ze - e_k||^2 = ||ze||^2 - 2 <ze, e_k> + ||e_k||^2. The first term is the
        # same for every k and does not affect the argmin. Dropping it avoids
        # materialising an [N, K, C, H, W] difference tensor.
        cross = torch.einsum('kc,nchw->nkhw', embedding, ze)                # [N, K, H, W]
        embedding_sq = (embedding ** 2).sum(1).reshape(1, -1, 1, 1)         # [1, K, 1, 1]
        return torch.argmin(embedding_sq - 2 * cross, 1)                    # [N, H, W]

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns:
            ``(x_hat, ze, zq)``: the reconstruction, the encoder output and the
            quantised latent (codebook vectors, [N, C, H, W]).
        """
        ze = self.encoder(x)
        nearest_neighbor = self._nearest_neighbor(ze)
        # Embedding lookup gives [N, H, W, C]; move C to the second dim.
        zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)
        # Straight-through estimator: the decoder sees zq's values, while gradients
        # flow back into ze (and from there into the encoder).
        decoder_input = ze + (zq - ze).detach()
        x_hat = self.decoder(decoder_input)
        return x_hat, ze, zq

    def encode(self, x):
        """Map images ``x`` to codebook indices of shape [N, H', W']."""
        return self._nearest_neighbor(self.encoder(x))

    def decode(self, discrete_latent):
        """Decode codebook indices [N, H', W'] back to images."""
        zq = self.vq_embedding(discrete_latent).permute(0, 3, 1, 2)
        x_hat = self.decoder(zq)
        return x_hat

    # Shape: [C, H, W]
    def get_latent_HW(self, input_shape):
        """Return the (H', W') of the latent grid for an input of shape [C, H, W]."""
        C, H, W = input_shape
        return (H // 2**self.n_downsample, W // 2**self.n_downsample)

In [4]:
from time import time

# Training hyper-parameters, read as globals by train() and the loss functions below.
n_epochs = 1_000
learning_rate = 0.05   # Adagrad step size
alpha = 1              # weight of the codebook term mse(ze.detach(), zq)
beta = 0.25            # weight of the commitment term mse(ze, zq.detach())

# NOTE(review): not referenced by any cell visible in this notebook.
pi = torch.tensor(torch.pi)

def loss_fn(x, x_hat, ze, zq):
    """Total VQ-VAE loss: (reconstruction + alpha * codebook + beta * commitment) / len(x).

    Args:
        x: input batch.
        x_hat: reconstruction of ``x``.
        ze: encoder output.
        zq: quantised latent (codebook vectors) matching ``ze``.
    Returns:
        Scalar loss tensor.
    """
    reconstruction = F.mse_loss(x, x_hat)
    # Codebook term: ze is detached, so its gradient only reaches zq (the codebook).
    codebook = F.mse_loss(ze.detach(), zq)
    # Commitment term: zq is detached, so its gradient only reaches ze (the encoder).
    commitment = F.mse_loss(ze, zq.detach())
    total = reconstruction + alpha * codebook + beta * commitment
    # NOTE(review): F.mse_loss already averages over every element (batch included),
    # so dividing by the batch size scales the loss by an extra 1/len(x).
    batch_size = len(x)
    return total / batch_size

def loss_reconstruction(x, x_hat):
    """Mean-squared reconstruction error between ``x`` and ``x_hat``, divided by len(x)."""
    per_element_mse = F.mse_loss(x, x_hat)
    batch_size = len(x)
    return per_element_mse / batch_size

def loss_emb(ze, zq):
    """Embedding part of the VQ-VAE loss: (alpha * codebook + beta * commitment) / len(ze).

    The codebook term detaches ``ze`` (gradient reaches only ``zq``); the commitment
    term detaches ``zq`` (gradient reaches only ``ze``).
    """
    codebook = F.mse_loss(ze.detach(), zq)
    commitment = F.mse_loss(ze, zq.detach())
    weighted = alpha * codebook + beta * commitment
    batch_size = len(ze)
    return weighted / batch_size

def _estimate_losses(model, dataset, device, n_samples):
    """Evaluate loss_fn and its two components on ``n_samples`` random items of ``dataset``.

    Draws the subset with ``torch.randperm``, so it consumes the global torch RNG.
    Call it with the model in eval mode and under ``torch.no_grad()``.

    Returns:
        ``(total, reconstruction, embedding)`` loss tensors from ``loss_fn``,
        ``loss_reconstruction`` and ``loss_emb``.
    """
    indices = torch.randperm(len(dataset))[:n_samples]
    x = torch.stack([dataset[j][0] for j in indices]).to(device)
    x_hat, ze, zq = model(x)
    return loss_fn(x, x_hat, ze, zq), loss_reconstruction(x, x_hat), loss_emb(ze, zq)


def _print_losses(epoch, losses):
    """Print one ``(total, reconstruction, embedding)`` evaluation for ``epoch``."""
    loss, loss_recons, loss_embed = losses
    print('====> Epoch: {} elbo loss: {:.7f}'.format(epoch, loss))
    print('====> Epoch: {} reconst loss: {:.7f}'.format(epoch, loss_recons))
    print('====> Epoch: {} embedding loss: {:.7f}'.format(epoch, loss_embed))


def _format_losses(losses):
    """Space-separated ``.item()`` values of ``losses`` for one half of a log line."""
    return ' '.join(str(value.item()) for value in losses)


def train(device, model,
          loss_path='/root/fornewdata/SVHN/VQVAE/loss.txt',
          model_path='/root/fornewdata/SVHN/VQVAE/model.pth'):
    """Train ``model`` with Adagrad for ``n_epochs`` epochs over ``train_loader``.

    After every epoch, losses on 100 random train and 100 random test images are
    written to ``loss_path`` as one line of six numbers:
    train total/reconstruction/embedding, then test total/reconstruction/embedding.
    Every 10th epoch they are also printed, together with the elapsed time. The whole
    model object is saved to ``model_path`` with ``torch.save`` at the end.

    Args:
        device: device the input batches are moved to. ``model`` is not moved here,
            so it should already live on that device.
        model: module whose forward returns ``(x_hat, ze, zq)``, like VQVAE.
        loss_path: path of the per-epoch loss log (overwritten).
        model_path: path the trained model is saved to.
    """
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    each_epoch = 10   # print / timing interval, in epochs
    n_samples = 100   # images per loss estimate
    begin_time = time()
    # train
    with open(loss_path, 'w') as file:
        for i in range(n_epochs):
            # The evaluation below switches to eval mode. Return to training mode at the
            # start of every epoch; otherwise every epoch after the first would train
            # with BatchNorm frozen to its running statistics.
            model.train()
            for x, _ in train_loader:
                x = x.to(device)
                x_hat, ze, zq = model(x)
                loss = loss_fn(x, x_hat, ze, zq)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # estimate loss on random train / test subsets
            model.eval()
            with torch.no_grad():
                train_losses = _estimate_losses(model, train_dataset, device, n_samples)
                if i % each_epoch == 0:
                    _print_losses(i, train_losses)
                file.write(_format_losses(train_losses) + ' ')

                test_losses = _estimate_losses(model, test_dataset, device, n_samples)
                if i % each_epoch == 0:
                    _print_losses(i, test_losses)
                file.write(_format_losses(test_losses) + '\n')

            if i % each_epoch == 0:
                training_time = time() - begin_time
                minute = int(training_time // 60)
                second = int(training_time % 60)
                print(f'time loss {minute}:{second}')

        torch.save(model, model_path)

    tot_training_time = time() - begin_time
    minute = int(tot_training_time // 60)
    second = int(tot_training_time % 60)
    print(f'total time loss {minute}:{second}')

def initialize_parameters(model):
    """Overwrite every parameter of ``model`` in place with samples from N(0, 0.01^2).

    NOTE(review): this also replaces BatchNorm scale/shift parameters and any
    embedding table (e.g. the codebook's uniform init in VQVAE.__init__). Confirm
    that is intended.
    """
    for weight in model.parameters():
        nn.init.normal_(weight, mean=0, std=0.01)

In [None]:
def main():
    """Build a VQ-VAE for 3-channel images on cuda:0, re-initialise its weights and train it."""
    vqvae = VQVAE(dim=[3, 500, 500], n_embedding=500)
    vqvae = vqvae.to('cuda:0')
    initialize_parameters(vqvae)
    train('cuda:0', vqvae)


if __name__ == '__main__':
    main()

In [6]:
device = 'cpu'
# The checkpoint is a whole pickled nn.Module (train() calls torch.save(model, ...)),
# so weights_only=False is needed on PyTorch versions where it defaults to True.
# Unpickling can run arbitrary code: only load checkpoints you produced yourself.
model = torch.load('/root/fornewdata/SVHN/VQVAE/model.pth', map_location=device, weights_only=False)
# Inference below should use BatchNorm running statistics, not per-batch ones.
model.eval()

In [7]:
# Reconstruction
from torchvision.utils import save_image
with torch.no_grad():
    # First training batch; keep 49 images to fill a 7x7 grid.
    batch_x, _ = next(iter(train_loader))
    true_imgs = batch_x[0:49].view(-1, 3, 32, 32)
    upscale = torchvision.transforms.Resize((50, 50))

    # Originals, resized to 50x50 before saving.
    save_image(upscale(true_imgs), '/root/fornewdata/SVHN/VQVAE/pictures/oring.png', nrow=7)

    # Reconstructions of the same images: element 0 of VQVAE.forward's output is x_hat.
    x = true_imgs.to(device)
    reconst_imgs = model(x)[0].view(-1, 3, 32, 32)
    save_image(upscale(reconst_imgs), '/root/fornewdata/SVHN/VQVAE/pictures/recons.png', nrow=7)

In [None]:
#drawing

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
        
import numpy as np
import pandas as pd

def moving_average(data, window_size):
    """Trailing rolling mean of ``data`` over ``window_size`` points, as a NumPy array.

    With ``min_periods=1`` the first ``window_size - 1`` outputs average over however
    many points exist so far, so the result has the same length as ``data``.
    """
    series = pd.Series(data)
    smoothed = series.rolling(window=window_size, min_periods=1).mean()
    return smoothed.to_numpy()

def _read_loss_log(path):
    """Parse the loss log written by ``train`` into six column lists.

    Columns: train total, train reconstruction, train embedding, test total,
    test reconstruction, test embedding. Lines with fewer than six fields are skipped,
    e.g. a last line left half-written by an interrupted run.
    """
    columns = [[] for _ in range(6)]
    with open(path, 'r') as file:
        for line in file:
            parts = line.split()
            if len(parts) < 6:
                continue
            for column, value in zip(columns, parts):
                column.append(float(value))
    return columns


def _tail_mean(values, size):
    """Mean of the last ``size`` entries of ``values`` (all of them if fewer); nan if empty."""
    tail = values[-size:]
    return sum(tail) / len(tail) if tail else float('nan')


def drawing():
    """Print tail-averaged losses and plot the smoothed train/test loss curves.

    Reads the per-epoch loss log written by ``train``. Prints the mean reconstruction
    and embedding losses over the last 100 logged epochs (or all of them if fewer).
    Plots all six curves, smoothed with a 50-point moving average, against the number
    of training samples seen, and saves the figure.
    """
    (trains, trains_recon, trains_emb,
     tests, tests_recon, tests_emb) = _read_loss_log('/root/fornewdata/SVHN/VQVAE/loss.txt')

    # One log line per epoch, so the x value is epochs * training-set size.
    samples = [len(train_dataset) * i for i in range(1, len(trains) + 1)]

    size = 100
    print("Train Reconstruction Loss: ", _tail_mean(trains_recon, size))
    print("Train Embedding Loss: ", _tail_mean(trains_emb, size))
    print("Test Reconstruction Loss: ", _tail_mean(tests_recon, size))
    print("Test Embedding Loss: ", _tail_mean(tests_emb, size))

    window_size = 50
    curves = [
        (trains, 'Train Loss'),
        (tests, 'Test Loss'),
        (trains_recon, 'Train reconstruction Loss'),
        (tests_recon, 'Test reconstruction Loss'),
        (trains_emb, 'Train embedding Loss'),
        (tests_emb, 'Test embedding Loss'),
    ]
    fig, ax = plt.subplots(figsize=(8, 8))
    for values, label in curves:
        ax.plot(samples, moving_average(values, window_size), label=label)

    ax.set_xlabel('Samples')
    ax.set_ylabel('Loss')
    ax.set_title('SVHN')
    ax.legend()
    ax.set_xscale('log')
    ax.set_yscale('log')
    fig.savefig('/root/fornewdata/SVHN/VQVAE/pictures/train.png', bbox_inches='tight')
    plt.show()


drawing()