In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as model
import torchvision.utils as vutils
from torch.distributions import Categorical

import random
import numpy as np
import math
from IPython.display import clear_output
from PIL import Image
from tqdm.notebook import trange, tqdm

## Config

In [4]:
# config
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_workers = 0   # DataLoader worker processes (0 = load batches in the main process)
batch_size = 64   # mini-batch size used by the training DataLoader

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch -- torch.manual_seed
# also seeds CUDA) so shuffling and parameter/codebook initialisation are reproducible
# across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Dataset

In [3]:
# define transform
def build_mnist_transform(image_size=32):
    """Build the MNIST preprocessing pipeline.

    Resizes each image with ``transforms.Resize(image_size)``, converts it to a
    tensor with ``ToTensor`` (values in [0, 1]) and then applies
    ``Normalize([0.5], [0.5])``, which maps the values to [-1, 1].

    Args:
        image_size: target size passed to ``transforms.Resize`` (default 32).

    Returns:
        A fresh ``transforms.Compose`` instance.
    """
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


transform = build_mnist_transform()
# Backward-compatible alias: later cells refer to the original (misspelled) name.
transfrom = transform

# Same preprocessing for evaluation; built separately so the two pipelines are independent objects.
test_transform = build_mnist_transform()

In [6]:
# MNIST training split, stored under ./data (downloaded if missing) and
# preprocessed with the `transfrom` pipeline from the transform cell.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transfrom,
)

# Shuffled training mini-batches; sizes come from the config cell.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)

100%|██████████| 9.91M/9.91M [01:09<00:00, 143kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 57.6kB/s]
100%|██████████| 1.65M/1.65M [00:02<00:00, 564kB/s] 
100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]


## VQ-VAE

### Vector Quantizer Class  
Di sini kita akan mendefinisikan kelas dan operasi yang ada di bagian tengah dari VQ-VAE.  
K = ukuran codebook (jumlah vektor di dalamnya; argumen `code_book_size` di kode)

In [None]:
class VectorQuantizer(nn.Module):
    def __init__(self, code_book_size, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        '''
        Args :
            code_book_size : jumlah kode (vektor) yayng tersimpan di codebook -- ukuran kamus kuantisasi 
            embedding_dim : dimensi dari codebook -- ukuran vektor di codebook
            commitment_cost : mengontrol seberapa keras encoder dipaksa agar outputnya dekat dengan codebook -- hyperparamater dari commitment loss
        '''
        self.code_book_size = code_book_size
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # inisialisasi codebook
        self.embedding = nn.Embedding(code_book_size, embedding_dim)            # ukuran dari codebook 
        # isi codebook dengan vektor acak dari distribusi uniform
        self.embedding.weight.data.uniform_(-1.0 / code_book_size, 1.0 / code_book_size)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 3, 1).contiguous()                        # (batch_size, C, H, W) -> (batch_size, H, W, C)
        input_shape = inputs.shape

        # flatten -> ubah jadi 1D
        flat_input = inputs.view(-1, 1, self.embedding_dim)                     # (batch_size * H * W, 1, C)

        # hitung jarak antara setiap vektor input dengan setiap vektor di codebook
        distances = (flat_input - self.embedding.weight.unsqueeze(0)).pow(2).mean(2)

        