In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, repeat, pack, unpack
from torch.utils.data import DataLoader
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import lovely_tensors as lt
lt.monkey_patch()

from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid

In [None]:
class VQ(nn.Module):
    """Vector-quantization bottleneck (VQ-VAE) with a gradient-trained codebook.

    Every channel vector of a ``(B, C, H, W)`` input is replaced by its nearest
    codebook entry under squared Euclidean distance. The codebook is trained by
    the codebook loss, the encoder is pulled towards its codes by the
    commitment loss, and gradients reach the encoder through the
    straight-through estimator.

    Args:
        commitment_cost: weight (beta) of the commitment term in the returned loss.
        num_embedding: number of codebook entries K.
        embedding_dim: dimensionality D of each entry; must equal the input's
            channel count C.
    """

    def __init__(self, commitment_cost, num_embedding=10, embedding_dim=32):
        super().__init__()

        self._num_embeddings = num_embedding
        self._embedding_dim = embedding_dim
        # Registered as ``_embedding`` because that is the name forward() reads.
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        nn.init.uniform_(self._embedding.weight,
                         -1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    @property
    def embedding(self):
        """The codebook ``nn.Embedding`` (read-only alias of ``_embedding``)."""
        return self._embedding

    def forward(self, inputs):
        """Quantize ``inputs``.

        Args:
            inputs: tensor of shape (B, C, H, W) with C == embedding_dim.

        Returns:
            loss: scalar, codebook loss + commitment_cost * commitment loss.
            quantized: tensor of shape (B, C, H, W) holding the selected codes;
                its gradient w.r.t. ``inputs`` is the identity (straight-through).
            perplexity: scalar, exp of the entropy of the batch's average code usage.
            encodings: one-hot assignments of shape (B*H*W, num_embedding).

        Raises:
            ValueError: if the channel dimension differs from embedding_dim.
        """
        # B, C, H, W => B, H, W, C
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        if inputs.shape[-1] != self._embedding_dim:
            raise ValueError(
                f"expected {self._embedding_dim} input channels, got {inputs.shape[-1]}"
            )
        input_shape = inputs.shape
        # B, H, W, C => BHW, C
        flat_input = inputs.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # ||x||^2 + ||e||^2 - 2 x.e  ->  [BHW, num_embedding]
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # Encoding: nearest code, taken along the codebook dimension -> [BHW, 1]
        encoding_indices = torch.argmin(distances, dim=1, keepdim=True)
        # One-hot assignments -> [BHW, num_embedding]
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings,
                                device=inputs.device, dtype=codebook.dtype)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten back to B, H, W, C
        quantized = torch.matmul(encodings, codebook).view(input_shape)

        # Commitment loss: pulls the encoder output towards its (frozen) code.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        # Codebook loss: pulls the selected codes towards the (frozen) encoder output.
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is `quantized`, backward is identity.
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [None]:
class VQEMA(nn.Module):
    """Vector-quantization bottleneck whose codebook is learned by exponential moving averages.

    Same call interface as ``VQ``: every channel vector of a ``(B, C, H, W)``
    input is replaced by its nearest codebook entry under squared Euclidean
    distance. There is no codebook loss. Instead, each ``forward`` in training
    mode updates the codebook in place:

        N_i <- decay * N_i + (1 - decay) * n_i            (assignment counts)
        m_i <- decay * m_i + (1 - decay) * sum_j z_j      (sum of assigned vectors)
        e_i <- m_i / N~_i                                 (N~ = Laplace-smoothed N)

    NOTE: a code that has never been assigned has a smoothed count of about
    ``epsilon``, so on the first update its vector is divided by roughly
    ``epsilon``. Consider resetting unused codes if that matters.

    Args:
        num_embeddings: number of codebook entries K.
        embedding_dim: dimensionality D of each entry; must equal the input's
            channel count C.
        commitment_cost: weight (beta) of the commitment loss, the only returned
            loss term.
        decay: EMA decay in [0, 1). A value of 0 replaces the statistics with the
            current batch.
        epsilon: Laplace-smoothing constant that keeps the smoothed counts positive.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super().__init__()

        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._commitment_cost = commitment_cost
        self._decay = decay
        self._epsilon = epsilon

        self._embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self._embedding.weight)
        # The codebook is written by the EMA update, never by an optimizer.
        self._embedding.weight.requires_grad_(False)

        # EMA statistics live in buffers so they follow .to(device) and state_dict().
        self.register_buffer("_ema_cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("_ema_w", self._embedding.weight.detach().clone())

    def forward(self, inputs):
        """Quantize ``inputs`` and, in training mode, update the codebook by EMA.

        Args:
            inputs: tensor of shape (B, C, H, W) with C == embedding_dim.

        Returns:
            loss: scalar, commitment_cost * commitment loss.
            quantized: tensor of shape (B, C, H, W) holding the codes selected
                *before* this call's EMA update; its gradient w.r.t. ``inputs``
                is the identity (straight-through).
            perplexity: scalar, exp of the entropy of the batch's average code usage.
            encodings: one-hot assignments of shape (B*H*W, num_embeddings).

        Raises:
            ValueError: if the channel dimension differs from embedding_dim.
        """
        # B, C, H, W => B, H, W, C
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        if inputs.shape[-1] != self._embedding_dim:
            raise ValueError(
                f"expected {self._embedding_dim} input channels, got {inputs.shape[-1]}"
            )
        input_shape = inputs.shape
        # B, H, W, C => BHW, C
        flat_input = inputs.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # Code assignment and the EMA update are not part of the autograd graph.
        with torch.no_grad():
            # ||x||^2 + ||e||^2 - 2 x.e  ->  [BHW, num_embeddings]
            distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                         + torch.sum(codebook ** 2, dim=1)
                         - 2 * torch.matmul(flat_input, codebook.t()))
            encoding_indices = torch.argmin(distances, dim=1, keepdim=True)
            encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings,
                                    device=inputs.device, dtype=codebook.dtype)
            encodings.scatter_(1, encoding_indices, 1)
            # Quantize with the codebook as it was before this step's update.
            quantized = torch.matmul(encodings, codebook).view(input_shape)

            if self.training and flat_input.shape[0] > 0:
                counts = encodings.sum(dim=0)
                self._ema_cluster_size.mul_(self._decay).add_(counts, alpha=1 - self._decay)
                # Laplace smoothing keeps every count strictly positive (no division by zero).
                n = self._ema_cluster_size.sum()
                cluster_size = ((self._ema_cluster_size + self._epsilon)
                                / (n + self._num_embeddings * self._epsilon) * n)
                dw = torch.matmul(encodings.t(), flat_input)
                self._ema_w.mul_(self._decay).add_(dw, alpha=1 - self._decay)
                codebook.copy_(self._ema_w / cluster_size.unsqueeze(1))

        # Only the commitment loss: the codebook is trained by the EMA update above.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is `quantized`, backward is identity.
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [None]:
class CosineSimCodebook(nn.Module):
    """Placeholder codebook module (the name suggests cosine-similarity lookup).

    TODO: not implemented yet. The module registers no parameters or buffers
    and defines no ``forward``.
    """

    def __init__(self):
        nn.Module.__init__(self)
        

In [None]:
# Download/cache directory for CIFAR-10, shared by both splits.
DATA_ROOT = "/home/aiteam/tykim/generative_model/data"

# ToTensor yields floats in [0, 1]; Normalize with mean 0.5 and std 1.0 then
# just subtracts 0.5, so pixels end up in [-0.5, 0.5].
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
])

training_data = datasets.CIFAR10(root=DATA_ROOT,
                                 train=True,
                                 download=True,
                                 transform=cifar_transform)

validation_data = datasets.CIFAR10(root=DATA_ROOT,
                                   train=False,
                                   download=True,
                                   transform=cifar_transform)

In [None]:
# Shape of the raw training-image array held by the dataset object.
np.shape(training_data.data)