source:: https://github.com/Jackson-Kang/Pytorch-VAE-tutorial/blob/master/02_Vector_Quantized_Variational_AutoEncoder.ipynb

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid

In [3]:

dataset_path = '~/datasets'

cuda = True
DEVICE = torch.device("cuda" if cuda else "cpu")

batch_size = 128
img_size = (32, 32) # (width, height)

input_dim = 3
hidden_dim = 512
latent_dim = 16
n_embeddings= 512
output_dim = 3
commitment_beta = 0.25

lr = 2e-4

epochs = 50

print_step = 50

In [4]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

mnist_transform=transforms.Compose([ # 여러 변환 순차적 적용
    transforms.ToTensor()
    ])
kwargs={'num_workers': 1, 'pin_memory': True}
train_dataset=CIFAR10(dataset_path, transform=mnist_transform, train=True, download=True) # train=True → 학습용 데이터셋 (training set) 을 불러오라는 의미
test_dataset  = CIFAR10(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs) # 키워드 인자들을 딕셔너리처럼 받아옴
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False,  **kwargs)

100%|██████████| 170M/170M [00:04<00:00, 39.5MB/s]


In [5]:
class Encoder(nn.Module):
  def __init__(self, input, hidden, out, kernel_size=(4,4,3,1), stride=2):
    super().__init__()

    kernel1, kernel2, kernel3, kernel4= kernel_size
    self.conv1=nn.Conv2d(input, hidden, kernel1, stride, padding=1)
    # nn.Conv2d(input_dim, hidden_dim는 채널 수, out 채널 수만큼 독립적인 hidden개의 커널이 생성됨 (사이즈는 kernel1)
    self.conv2=nn.Conv2d(hidden, hidden, kernel2, stride, padding=1)

    self.residualconv1=nn.Conv2d(hidden, hidden, kernel3, padding=1)
    self.residualconv2=nn.Conv2d(hidden, hidden, kernel4, padding=0)

    self.proj=nn.Conv2d(hidden, out, kernel_size=1) # 해상도는 유지하며 채널을 out개로 만듬

  def forward(self, x):
    x=self.conv1(x)
    x=self.conv2(x)
    x=F.relu(x)

    y=self.residualconv1(x)
    y=y+x

    x=F.relu(y)
    y=self.residualconv2(x)
    y=y+x

    y=self.proj(y)
    return y


In [6]:
class VQEmbeddingEMA(nn.Module): # 코드북(임베딩 테이블) 학습
  def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
    super().__init__()
    self.commitment_cost=commitment_cost
    self.decay=decay
    self.epsilon=epsilon

    init_bound = 1 / n_embeddings
    embedding = torch.Tensor(n_embeddings, embedding_dim)
    embedding.uniform_(-init_bound, init_bound)    # Uniform distribution에 따라 제자리에서 초기화
    '''
    코드북을 학습하기 때문에 적절한 값으로 초기화하지 않으면 학습 초기에 코드북 벡터들이 지나치게 크거나 유사해서 학습
    이 불안정해짐 - 작은 값의 균등분포로 초기
    '''
    self.register_buffer("embedding", embedding) # 코드북 임베딩은 일반적인 gradient descent로 학습되지 않음
    self.register_buffer("ema_count", torch.zeros(n_embeddings)) # EMA방식은 직접 수식으로 업데이트
    self.register_buffer("ema_weight", self.embedding.clone())

  def encode(self, x):
    '''
    x를 입력으로 받아서 코드북 참조해서 가장 가까운 인덱스로 양자화하고
    그 벡터(코드북)와 인덱스 반환
    '''
    M, D=self.embedding.size()
    x_flat= x.detach().reshape(-1,D) # x.detach(): 해당 텐서에서 기울기 흐름 끊음

    distances=torch.cdist(x_flat, self.embedding, p=2) # .cdist: 각 x_flat[i]와 코드북의 모든 벡터 간 거리 계산

    indices=torch.argmin(distances.float(), dim=1) # 가장 가까운 코드북벡터 찾기
    quantized = F.embedding(indices, self.embedding) # 인덱스들을 통해 각 위치에 대응하는 코드북 벡터들 뽑아냄
    quantized = quantized.view_as(x)                 # 다시 x shape대로 복원
    return quantized, indices.view(x.shape[0], x.shape[2], x.shape[3])  # 양자화된 출력 텐서 (B, D, H, W)와 각 위치에 어떤 코드북 인덱스 사용했는지

  def retrieve_random_codebook(self, random_indices):
    quantized = F.embedding(random_indices, self.embedding)
    quantized = quantized.permute(0, 3, 1, 2) # 텐서의 차원 순서를 바꾸는 함수

    return quantized

  def forward(self, x):
    M,D=self.embedding.size() # M개 코드, D 임베딩벡터크기
    x_flat=x.detach().reshape(-1,D)

    distances=torch.cdist(x_flat, self.embedding,p=2)

    indices=torch.argmin(distances.float(), dim=1)
    encodings=F.one_hot(indices, M).float() # 양자화된 코드북 인덱스를 원핫벡터로 / 각 코드북 벡터가 얼마나 선택되었는지 카운트

    quantized=F.embedding(indices, self.embedding)
    quantized=quantized.view_as(x)

    if self.training:
      self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
      n = torch.sum(self.ema_count)
      self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n

      dw = torch.matmul(encodings.t(), x_flat)
      self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
      self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

    codebook_loss=F.mse_loss(x.detach(),quantized) # codebook만 학습
    e_latent_loss=F.mse_loss(x, quantized.detach()) # encoder만 학습
    commitment_loss=self.commitment_cost*e_latent_loss # scale된 encoder 손실

    quantized = x + (quantized - x).detach() # gradient override를 위한 테크닉- encoder 쪽에는 gradient를 흘리되, codebook 쪽에는 gradient가 안 흐르도록

    avg_probs = torch.mean(encodings, dim=0) # 각 코드북 벡터의 선택 확률(avg_probs)
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) # 코드북 사용의 다양성을 측정

    return quantized, commitment_loss, codebook_loss, perplexity






In [7]:
class Decoder(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, kernel_sizes=(1, 3, 2, 2), stride=2):
        super(Decoder, self).__init__()

        kernel_1, kernel_2, kernel_3, kernel_4 = kernel_sizes

        self.in_proj = nn.Conv2d(input_dim, hidden_dim, kernel_size=1)

        self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_1, padding=0)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, padding=1)

        self.strided_t_conv_1 = nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_3, stride, padding=0)
        self.strided_t_conv_2 = nn.ConvTranspose2d(hidden_dim, output_dim, kernel_4, stride, padding=0)

    def forward(self, x):

        x = self.in_proj(x)

        y = self.residual_conv_1(x)
        y = y+x
        x = F.relu(y)

        y = self.residual_conv_2(x)
        y = y+x
        y = F.relu(y)

        y = self.strided_t_conv_1(y)
        y = self.strided_t_conv_2(y)

        return y

In [11]:
class Model(nn.Module):
    def __init__(self, Encoder, Codebook, Decoder):
        super(Model, self).__init__()
        self.encoder = Encoder
        self.codebook = Codebook
        self.decoder = Decoder

    def forward(self, x):
        z = self.encoder(x)
        z_quantized, commitment_loss, codebook_loss, perplexity = self.codebook(z)
        x_hat = self.decoder(z_quantized)

        return x_hat, commitment_loss, codebook_loss, perplexity

In [16]:
encoder = Encoder(input=input_dim, hidden=hidden_dim, out=latent_dim)
codebook = VQEmbeddingEMA(n_embeddings=n_embeddings, embedding_dim=latent_dim)
decoder = Decoder(input_dim=latent_dim, hidden_dim=hidden_dim, output_dim=output_dim)

model = Model(Encoder=encoder, Codebook=codebook, Decoder=decoder).to(DEVICE)

In [17]:
from torch.optim import Adam

mse_loss = nn.MSELoss()

optimizer = Adam(model.parameters(), lr=lr)

In [18]:
model.train()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
        recon_loss = mse_loss(x_hat, x)

        loss =  recon_loss + commitment_loss * commitment_beta + codebook_loss

        loss.backward()
        optimizer.step()

        if batch_idx % print_step ==0:
            print("epoch:", epoch + 1, "(", batch_idx + 1, ") recon_loss:", recon_loss.item(), " perplexity: ", perplexity.item(),
              " commit_loss: ", commitment_loss.item(), "\n\t codebook loss: ", codebook_loss.item(), " total_loss: ", loss.item(), "\n")

print("Finish!!")

epoch: 1 ( 1 ) recon_loss: 0.5390774011611938  perplexity:  14.41859245300293  commit_loss:  0.0031496435403823853 
	 codebook loss:  0.012598574161529541  total_loss:  0.5524634122848511 

epoch: 1 ( 51 ) recon_loss: 0.03720606490969658  perplexity:  25.929798126220703  commit_loss:  0.04992463439702988 
	 codebook loss:  0.1996985375881195  total_loss:  0.2493857592344284 

epoch: 1 ( 101 ) recon_loss: 0.023040538653731346  perplexity:  93.21887969970703  commit_loss:  0.03653959929943085 
	 codebook loss:  0.1461583971977234  total_loss:  0.1783338338136673 

epoch: 1 ( 151 ) recon_loss: 0.017749948427081108  perplexity:  162.18373107910156  commit_loss:  0.031406134366989136 
	 codebook loss:  0.12562453746795654  total_loss:  0.1512260138988495 

epoch: 1 ( 201 ) recon_loss: 0.01515465509146452  perplexity:  203.0373077392578  commit_loss:  0.028093479573726654 
	 codebook loss:  0.11237391829490662  total_loss:  0.13455194234848022 



KeyboardInterrupt: 