In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

bs = 100  # batch size shared by both loaders

# MNIST Dataset.
# download=True is a no-op when the files are already present under the root,
# and fetches them on a fresh machine so "Restart & Run All" still works.
mnist_root = '/root/data/'
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(root=mnist_root, train=True, transform=to_tensor, download=True)
test_dataset = datasets.MNIST(root=mnist_root, train=False, transform=to_tensor, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [63]:
import torch
import torch.nn as nn



class VQVAE(nn.Module):
    """Vector-quantised auto-encoder built from conv / transposed-conv stages.

    Encoder: one ``Conv2d(k=3, stride=2) -> BatchNorm2d -> ReLU`` stage per entry
    of ``back_kernels`` (each halves H and W, rounding up), then a stride-1 3x3
    conv projecting to ``latent_dim`` channels.
    Quantiser: each spatial latent vector is replaced by its nearest codebook
    row (squared Euclidean distance); a straight-through estimator copies the
    decoder's gradient back onto the encoder output.
    Decoder: a 3x3 conv back to the deepest encoder width, then one
    ``ConvTranspose2d(stride=2) -> BatchNorm2d -> ReLU`` stage per encoder stage,
    walking ``channels`` back down to ``channels[0]``.

    Args:
        back_kernels: kernel size of each decoder up-sampling stage; its length
            also sets the number of encoder/decoder stages.
        channels: ``[input_channels, stage_1_channels, ..., stage_k_channels]``;
            needs at least ``len(back_kernels) + 1`` entries.
        latent_dim: size of every latent / codebook vector.
        n_embedding: number of codebook vectors.
    """

    def __init__(self, back_kernels=(3, 3), channels=(1, 64, 256), latent_dim=2, n_embedding=10) -> None:
        super().__init__()
        n_stages = len(back_kernels)

        # encoder
        # Input width comes from channels[0] (was hard-coded to 1) so that it
        # matches the decoder, which already reconstructs channels[0] channels.
        pre_channel = channels[0]
        modules = []
        for i in range(n_stages):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel, channels[i + 1], kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(channels[i + 1]),
                    nn.ReLU(),
                )
            )
            pre_channel = channels[i + 1]
        # Project to the latent dimension without changing H and W.
        modules.append(
            nn.Sequential(
                nn.Conv2d(pre_channel, latent_dim, kernel_size=3, stride=1, padding=1),
            )
        )
        self.encoder = nn.Sequential(*modules)

        # Codebook: n_embedding vectors of size latent_dim, small uniform init.
        self.vq_embedding = nn.Embedding(n_embedding, latent_dim)
        self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0 / n_embedding)

        # decoder
        modules = [
            nn.Sequential(
                nn.Conv2d(latent_dim, pre_channel, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(pre_channel),
                nn.ReLU(),
            )
        ]
        for i, kernel in enumerate(back_kernels):
            in_ch = channels[n_stages - i]
            out_ch = channels[n_stages - i - 1]
            modules.append(
                nn.Sequential(
                    # out = (in-1)*2 - 2*padding + kernel + output_padding; these
                    # choices give exactly 2*in for every kernel size (the old
                    # fixed padding=1/output_padding=1 only did so for kernel 3,
                    # which still gets (1, 1) here).
                    nn.ConvTranspose2d(in_ch,
                                       out_ch,
                                       kernel_size=kernel,
                                       stride=2,
                                       padding=(kernel - 1) // 2,
                                       output_padding=kernel % 2),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(),
                )
            )
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns:
            ``(x_hat, ze, zq)``: the reconstruction, the raw encoder output
            ``[N, C, H, W]`` and its quantised counterpart (same shape, every
            spatial vector an exact codebook row, differentiable w.r.t. the
            codebook).
        """
        ze = self.encoder(x)
        n, c, h, w = ze.shape
        # The nearest-neighbour search is not differentiable (argmin), so do it
        # without building an autograd graph.
        with torch.no_grad():
            codebook = self.vq_embedding.weight  # [K, C]
            # [1, K, C, 1, 1] - [N, 1, C, H, W] -> summed over C -> [N, K, H, W]
            distance = ((codebook.reshape(1, -1, c, 1, 1) - ze.reshape(n, 1, c, h, w)) ** 2).sum(dim=2)
            nearest = distance.argmin(dim=1)  # [N, H, W] codebook indices
        # Embedding lookup outside no_grad so the codebook loss can train it.
        zq = self.vq_embedding(nearest).permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
        # Straight-through estimator: forward value is zq, gradient flows to ze.
        decoder_input = ze + (zq - ze).detach()
        x_hat = self.decoder(decoder_input)
        return x_hat, ze, zq


In [64]:
from time import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Training hyper-parameters (read as module globals by train() and loss_fn()).
n_epochs = 50      # passes over train_loader
lr = 5e-3          # Adam learning rate
alpha = 1          # weight of the codebook term   mse(ze.detach(), zq)
beta = 0.25        # weight of the commitment term mse(ze, zq.detach())
kl_weight = 2.5e-4  # NOTE(review): not read by any visible cell - VAE leftover?

def loss_fn(x, x_hat, ze, zq):
    """VQ-VAE objective: reconstruction + alpha * codebook + beta * commitment.

    Args:
        x: input batch.
        x_hat: reconstruction of ``x``.
        ze: encoder output.
        zq: quantised encoder output (same shape as ``ze``).

    Returns:
        Scalar tensor; ``alpha`` and ``beta`` are read from module globals.
    """
    reconstruction = F.mse_loss(x, x_hat)
    # Encoder output frozen: this term only moves the codebook towards ze.
    codebook_term = F.mse_loss(ze.detach(), zq)
    # Codebook frozen: this term only pulls the encoder towards its code.
    commitment_term = F.mse_loss(ze, zq.detach())
    return reconstruction + alpha * codebook_term + beta * commitment_term

def _format_elapsed(seconds):
    """Render a duration in seconds as ``M:SS`` (minutes are not folded into hours)."""
    minute, second = divmod(int(seconds), 60)
    return f'{minute}:{second:02d}'


def train(device, model):
    """Train ``model`` with Adam for ``n_epochs`` epochs over ``train_loader``.

    Reads the module globals ``train_loader``, ``n_epochs``, ``lr`` and
    ``loss_fn``. After every epoch the weights are saved to
    ``/root/VQVAE/VQVAE_model.pth``; at the end the per-epoch mean loss curve
    is saved to ``/root/VQVAE/pictures/training.png`` and shown.

    Args:
        device: device the batches are moved to (the model is assumed to be
            on it already).
        model: module whose forward returns ``(x_hat, ze, zq)``.
    """
    # reconstruct()/latent_space() switch the model to eval mode; BatchNorm
    # must use (and update) batch statistics while training.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr)
    begin_time = time()
    n_batches = len(train_loader)
    loss_list = []
    for epoch in range(n_epochs):
        loss_sum = 0.0
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.to(device)
            x_hat, ze, zq = model(x)
            loss = loss_fn(x, x_hat, ze, zq)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            loss_sum += batch_loss
            if batch_idx % 100 == 0:
                # mse_loss already averages over the batch, so the value is
                # reported as-is (previously divided by the batch size again).
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(x), len(train_loader.dataset),
                    100. * batch_idx / n_batches, batch_loss))
        # Mean of per-batch mean losses; dividing by the dataset size instead
        # would shrink the curve by a factor of the batch size.
        epoch_loss = loss_sum / n_batches
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_loss))
        loss_list.append(epoch_loss)
        print(f'time loss {_format_elapsed(time() - begin_time)}')
        torch.save(model.state_dict(), '/root/VQVAE/VQVAE_model.pth')
    print(f'total time loss {_format_elapsed(time() - begin_time)}')

    # plot epoch - on a fresh figure so earlier plots are not drawn over
    epochs = range(1, len(loss_list) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, loss_list, marker='o', linestyle='-')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Loss per Epoch')
    ax.grid(True)
    fig.savefig('/root/VQVAE/pictures/training.png')
    plt.show()

from PIL import Image, ImageDraw, ImageFont

def reconstruct(device, model):
    """Save a labelled image comparing 8 training digits with their reconstructions.

    The first 8 images of ``train_dataset`` are passed through ``model``; each
    input is stacked above its reconstruction (concatenated along height),
    written with ``save_image`` to ``/root/VQVAE/pictures/compare.png``, then
    re-opened, padded with a 100 px white margin on the left carrying the
    "Input" / "VQVAE" captions, and saved back to the same path.

    The model is run in eval mode and then returned to whatever mode it was in.
    """
    path = '/root/VQVAE/pictures/compare.png'
    num = 8
    batch = torch.stack([train_dataset[i][0] for i in range(num)])

    was_training = model.training
    model.eval()
    try:
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            output = model(batch.to(device))[0].cpu()
    finally:
        model.train(was_training)

    # dim=2 is height: each input sits directly above its reconstruction.
    concatenated_images = torch.cat((batch, output), dim=2)
    save_image(concatenated_images, path)

    # Open with a context manager so the handle is closed before the same
    # path is overwritten below; paste() forces the pixel data to load.
    with Image.open(path) as original_image:
        original_width, original_height = original_image.size
        new_image = Image.new("RGB", (original_width + 100, original_height), color=(255, 255, 255))
        new_image.paste(original_image, (100, 0))

    draw = ImageDraw.Draw(new_image)
    title_height = original_height // 2
    draw.text((10, 10), "Input", fill=(0, 0, 0))
    draw.text((10, title_height + 10), "VQVAE", fill=(0, 0, 0))
    new_image.save(path)
    
def latent_space(device, model):
    """Scatter-plot encoder outputs of the first 1000 training images, coloured by label.

    Channel 0 of ``ze`` is used as x and channel 1 as y; every spatial position
    of an image's ``ze`` contributes one point in that image's label colour.
    The figure is saved to ``/root/VQVAE/pictures/latent_space.png`` and shown.
    The model is run in eval mode and then returned to whatever mode it was in.
    """
    num = 1000
    images, labels = zip(*(train_dataset[i] for i in range(num)))
    batch = torch.stack(images)

    was_training = model.training
    model.eval()
    try:
        # Inference only: skip building a graph for 1000 images.
        with torch.no_grad():
            ze = model(batch.to(device))[1]
    finally:
        model.train(was_training)

    z_cpu = ze.cpu().numpy()
    label_arr = np.array(labels)
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']
    # Labels in first-occurrence order (keeps the original legend order).
    label_order = list(dict.fromkeys(label_arr.tolist()))

    fig, ax = plt.subplots()
    # One scatter call per label instead of one per image; same points.
    for label in label_order:
        members = z_cpu[label_arr == label]
        ax.scatter(members[:, 0].ravel(), members[:, 1].ravel(), color=colors[label], s=10)
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=f'{label}',
                   markerfacecolor=colors[label], markersize=10)
        for label in label_order
    ]
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title('Scatter Plot with Different Colors')
    ax.legend(handles=legend_handles)
    fig.savefig('/root/VQVAE/pictures/latent_space.png')
    plt.show()

In [67]:
def main():
    """Build the VQ-VAE, load the saved checkpoint and run the selected task(s).

    Uncomment the task(s) you want in the "Choose which to play" menu, then
    print the number of trainable parameters.
    """
    # Fall back to CPU so the notebook still runs without a CUDA device.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    model = VQVAE().to(device)

    # Load the model, mapping tensors onto the device actually selected above
    # (previously hard-coded to 'cuda:0').
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    model.load_state_dict(torch.load('/root/VQVAE/VQVAE_model.pth', map_location=device))

    #Choose which to play
    #train(device, model)
    #generate(device, model)  # NOTE(review): no `generate` is defined in the visible cells
    reconstruct(device, model)
    #latent_space(device, model)
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))


if __name__ == '__main__':
    main()

307225
