## Neural Discrete Representation Learning, Aaron van den Oord et al.

### Related resources
[YouTube lecture](https://www.youtube.com/watch?v=WTqPCPeipEY)

Code 🧑‍💻
- https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/vae.py
- https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py
- https://keras.io/examples/generative/vq_vae/
- https://github.com/nadavbh12/VQ-VAE
- https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py
- https://github.com/ritheshkumar95/pytorch-vqvae
- https://nbviewer.org/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
- https://github.com/HwangJohn/vq-vae-test/blob/main/model/vqvae.py

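### Key takeaways
1. Instead of a continuous latent space (normal distribution), the VAE uses a discrete latent embedding (categorical distribution).
2. In other words, each encoder output is mapped to a discrete category, like a dictionary lookup.
3. Despite "Variational" in the name, it is essentially a plain AE: with a uniform prior the KL term is a constant ($\log K$), so no KLD term is optimized.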
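Why the KL term can be dropped: the paper uses a uniform prior $p(z=k) = 1/K$ and the posterior is one-hot, so $D_{KL}(q(z|x)\,\|\,p(z)) = 1 \cdot \log\frac{1}{1/K} = \log K$, which depends neither on $x$ nor on any trainable parameter. A minimal numerical check (the helper `kl_onehot_vs_uniform` is a hypothetical name made up for this sketch):

```python
import math

import torch


def kl_onehot_vs_uniform(k_selected: int, num_embeddings: int) -> float:
    """KL(q || p) for a one-hot posterior q and a uniform prior p over K codes (hypothetical helper)."""
    q = torch.zeros(num_embeddings)
    q[k_selected] = 1.0
    p = torch.full((num_embeddings,), 1.0 / num_embeddings)
    # terms with q_k = 0 contribute nothing (0 * log 0 := 0), so keep only q_k > 0
    mask = q > 0
    return torch.sum(q[mask] * torch.log(q[mask] / p[mask])).item()


K = 512
for k in (0, 17, 511):
    # approximately log K no matter which code was selected
    print(k, kl_onehot_vs_uniform(k, K), math.log(K))
```

Since the value is the same for every input, its gradient is zero and it can be left out of the training objective.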

### Motivations

[1] Discrete categorical nature
- Real-world objects can be thought of as discrete categories: cat, car, etc.
- DNA sequences: A, C, G, T
- The alphabet: A, B, C, …


[2] Posterior collapse

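- Definition of posterior collapse
    - When the signal passed from $x$ to the posterior parameters is too weak or too noisy, the posterior is said to collapse.
    - In that case the decoder starts to ignore the sample $z$ drawn from the posterior $q_{\phi}(z|x)$ ($z \sim q_{\phi}(z|x)$).
    - A noisy signal means $\mu_d, \sigma_d$ are unstable ⇒ the sampled $z$ is also unstable ⇒ the decoder starts ignoring $z$.
    - Once the decoder ignores $z$, $\hat{x}$ becomes independent of $z$ (and $q_{\phi}(z|x)$ ends up matching the prior $p(z)$, driving the KL term to zero).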

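- Causes of posterior collapse
    - 1) Imbalance between the encoder and the decoder
        - The encoder is too weak ⇒ it does not encode enough meaningful signal
        - The encoder is too strong ⇒ it also encodes the noise in the data
        - The decoder is relatively powerful ⇒ it can reconstruct the signal on its own, without relying on the uninformative latent
    - 2) The gap between the ELBO and the evidence, i.e., a failure to approximate the true posterior
    - 3) The problem is ill-posed, so many different latents $z$ may be consistent with the same data
    - 4) The assumed Gaussian prior may be a poor fit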

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

![Image](https://github.com/user-attachments/assets/1d4acc4e-12ed-403c-8298-436b017ca8a4)

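- Three components
    - `Encoder`
    - `Decoder`: decodes from the (quantized) codebook vectors
    - `VQ Layer`
        - holds an `embedding dictionary (codebook)` $e \in \Bbb{R}^{K \times D}$
        - each embedding is $e_i \in \Bbb{R}^D, i \in [1, \dots, K]$

- The encoder output `z_e` is compared with every embedding in the codebook and replaced by the closest one (nearest neighbor).

- Looking only at the indices, the posterior can be viewed as a deterministic (one-hot) categorical distribution: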
$$
q(z=k \mid x) = \begin{cases}
   1 & \text{for } k = \operatorname{argmin}_j \|z_e(x) - e_j\|_2 \\
   0 & \text{otherwise}
\end{cases}
$$
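The lookup and the one-hot posterior, as a minimal sketch on toy 2-D vectors (the function name `quantize` and all values are made up for illustration; it uses the same `torch.cdist` + `argmin` approach as the `VectorQuantizer` below, without the reshaping):

```python
import torch
import torch.nn.functional as F


def quantize(z_e_flat: torch.Tensor, codebook: torch.Tensor):
    """Nearest-neighbor lookup: z_e_flat (N, D), codebook (K, D) -> indices (N,), z_q (N, D), q (N, K)."""
    distances = torch.cdist(z_e_flat, codebook)  # (N, K) Euclidean distances ||z_e - e_j||_2
    indices = distances.argmin(dim=1)  # k = argmin_j ||z_e - e_j||_2
    z_q = codebook[indices]  # replace each vector with its nearest codebook entry
    q = F.one_hot(indices, num_classes=codebook.shape[0])  # deterministic one-hot posterior q(z=k|x)
    return indices, z_q, q


codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])  # K=3 toy embeddings, D=2
z_e_flat = torch.tensor([[0.9, 1.2], [0.1, -0.2], [-0.8, 1.7]])  # N=3 toy encoder outputs
indices, z_q, q = quantize(z_e_flat, codebook)
print(indices)  # tensor([1, 0, 2])
print(q)        # each row has a single 1 at the selected code
```

Only the index is stored per position, which is what makes the latent "discrete": a 512-entry codebook turns each position into one of 512 symbols.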



![Image](https://github.com/user-attachments/assets/355773d1-2f9c-4032-b338-b269d7f8aa4f)


In [None]:
class VectorQuantizer(nn.Module):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        commitment_cost,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings  # K
        self.embedding_dim = embedding_dim  # D
        self.commitment_cost = commitment_cost  # beta

        # Create the codebook: K embeddings of dimension D
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Initialization
        nn.init.uniform_(self.embedding.weight, -1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

    def forward(self, z_e):
        """
        z_e: (B, D, H, W)
        return:
          z_q: quantized output, same shape (B, D, H, W)
          loss: codebook loss + commitment_cost * commitment loss (scalar)
          encoding_indices: (B, H, W)
          perplexity: exp(entropy) of the average code usage (scalar)
        """
        B, D, H, W = z_e.shape
        # (B, D, H, W) -> (B, H, W, D) -> (B*H*W, D)
        z = z_e.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.embedding_dim)

        # (2) Compute distances to the embedding vectors, (3) pick the index of the nearest one
        # torch.cdist output shape: (B*H*W, K)
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        # (4) Look up the embedding vectors by index, (5) reshape back
        z_q = self.embedding(min_encoding_indices).view(B, H, W, D)  # (B, H, W, D)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()  # (B, D, H, W)

        # (6) Codebook loss ||sg[z_e] - e||^2: the gradient reaches only the embeddings
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        # (7) Commitment loss ||z_e - sg[e]||^2: the gradient reaches only the encoder
        commitment_loss = F.mse_loss(z_e, z_q.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: the forward value is z_q, but in the backward pass
        # the decoder's gradient w.r.t. z_q is copied unchanged to z_e
        z_q_out = z_e + (z_q - z_e).detach()

        # Perplexity of the code usage (1 = a single code used, K = all codes used equally)
        encodings = F.one_hot(min_encoding_indices, self.num_embeddings).type(z_e.dtype)
        avg_probs = encodings.mean(dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return z_q_out, loss, min_encoding_indices.view(B, H, W), perplexity

### Objectives

![Image](https://github.com/user-attachments/assets/59f31927-cb95-4c07-b20f-fe9d08021bb5)


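$$
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{recon} + \mathcal{L}_{codebook} + \beta\,\mathcal{L}_{commitment} \\
\mathcal{L}_{recon} &= \|D(z_q(x)) - x\|_2^2 \\
\mathcal{L}_{codebook} &= \frac{1}{S} \sum_{s=1}^{S} \|\mathrm{sg}[h_s] - e_{z_s}\|_2^2 \\
\mathcal{L}_{commitment} &= \frac{1}{S} \sum_{s=1}^{S} \|h_s - \mathrm{sg}[e_{z_s}]\|_2^2
\end{aligned}
$$
where $h_s$ is the encoder output $z_e(x) = E(x)$ at spatial position $s$ (out of $S$ positions), $z_s$ is the index of its nearest codebook entry, $z_q(x)$ is the quantized encoder output, and $\mathrm{sg}[\cdot]$ is the stop-gradient operator.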
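- Recon loss: optimizes the encoder and the decoder (the encoder receives it through the straight-through estimator)
- VQ loss (codebook loss)
    - updates the embeddings
    - pulls the codebook toward vectors the encoder actually outputs
    - loss term: $\|\mathrm{sg}[z_e] - e\|_2^2$
- Commitment loss
    - updates the encoder parameters
    - makes the encoder commit to codebook-like outputs so they do not drift away from the embeddings
    - loss term: $\beta\,\|z_e - \mathrm{sg}[e]\|_2^2$ (`commitment_cost` = β = 0.25)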
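Where each term's gradient goes can be checked on a toy problem. This is a minimal sketch, assuming linear stand-ins for the encoder and decoder and a 4-entry codebook (all sizes and names are made up); `.detach()` plays the role of $\mathrm{sg}[\cdot]$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: linear "encoder"/"decoder" and a codebook with K=4, D=2
encoder = torch.nn.Linear(3, 2)
decoder = torch.nn.Linear(2, 3)
codebook = torch.nn.Parameter(torch.randn(4, 2))
beta = 0.25  # commitment_cost

x = torch.randn(8, 3)
z_e = encoder(x)                                # (N, D)
idx = torch.cdist(z_e, codebook).argmin(dim=1)  # nearest code per row
e = codebook[idx]                               # selected embeddings (N, D)
z_q = z_e + (e - z_e).detach()                  # straight-through input for the decoder

recon_loss = F.mse_loss(decoder(z_q), x)        # decoder + encoder (via straight-through)
codebook_loss = F.mse_loss(e, z_e.detach())     # ||sg[z_e] - e||^2
commitment_loss = F.mse_loss(z_e, e.detach())   # ||z_e - sg[e]||^2


def grads(loss):
    """(encoder grad, codebook grad) of one loss term; None means no gradient path."""
    return torch.autograd.grad(loss, [encoder.weight, codebook], retain_graph=True, allow_unused=True)


# [encoder grad missing?, codebook grad missing?] for each term
routing = {
    "recon": [g is None for g in grads(recon_loss)],
    "codebook": [g is None for g in grads(codebook_loss)],
    "commitment": [g is None for g in grads(commitment_loss)],
}
print(routing)  # {'recon': [False, True], 'codebook': [True, False], 'commitment': [False, True]}

# The total objective still trains every part
loss = recon_loss + codebook_loss + beta * commitment_loss
loss.backward()
```

The reconstruction term never touches the codebook, which is exactly why the separate codebook term is needed.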


- The loss is split into these separate terms because of the argmin operation inside the VQ layer.
- argmin is non-linear and not differentiable with respect to its input ⇒ gradients cannot backpropagate through it.
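The problem in isolation, as a minimal sketch on random toy tensors (sizes are arbitrary): after `argmin`, the quantized vector is reachable from the codebook but not from `z_e`, so a loss on the quantized output gives the encoder no gradient at all:

```python
import torch

codebook = torch.randn(4, 2, requires_grad=True)  # toy codebook, K=4, D=2
z_e = torch.randn(3, 2, requires_grad=True)       # toy encoder outputs, N=3

idx = torch.cdist(z_e, codebook).argmin(dim=1)    # integer indices: no grad_fn
z_q = codebook[idx]                               # lookup depends on z_e only through idx

print(idx.dtype, idx.requires_grad)               # torch.int64 False
z_q.sum().backward()
print(z_e.grad)                                   # None: nothing flows back to the encoder output
print(codebook.grad is not None)                  # True: only the selected codebook rows get gradient
```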

<img src="https://github.com/user-attachments/assets/51c66354-bcc6-44ef-8f3f-63bbfd2e7660" width="250" height="250"/>


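- The gradient arriving from the decoder can bypass the VQ layer and be copied unchanged to the encoder (straight-through estimator).
- Because the encoder output and the decoder input share the same D-dimensional space, the decoder's gradient is expected to remain useful for updating the encoder.
- The decoder's gradient, however, never reaches the codebook in the VQ layer, which is why the codebook needs its own loss.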
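The straight-through trick on its own, as a minimal sketch with random toy tensors (`upstream` is a made-up stand-in for the decoder's gradient $\partial \mathcal{L} / \partial z_q$): the forward value equals the selected embedding, the backward pass hands the gradient to `z_e` unchanged, and the codebook receives nothing:

```python
import torch

codebook = torch.randn(4, 2, requires_grad=True)  # toy codebook, K=4, D=2
z_e = torch.randn(3, 2, requires_grad=True)       # toy encoder outputs, N=3

idx = torch.cdist(z_e, codebook).argmin(dim=1)
e = codebook[idx]
z_q = z_e + (e - z_e).detach()     # forward: value of e; backward: identity w.r.t. z_e

upstream = torch.randn(3, 2)       # hypothetical stand-in for dL/dz_q from the decoder
z_q.backward(upstream)

print(torch.allclose(z_q, e))              # True: forward value is the selected embedding (up to rounding)
print(torch.equal(z_e.grad, upstream))     # True: gradient copied unchanged to z_e
print(codebook.grad)                       # None: decoder gradient never reaches the codebook
```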

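- Why not merge the VQ loss and the commitment loss into a single $\|z_e - z_q\|_2^2$: each side should treat the other as a fixed target (via stop-gradient) for training to be stable, much like training against a fixed target in DQN.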
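A small numerical check of what the split actually changes, as a sketch on random toy tensors: without stop-gradients, the single loss $\|z_e - z_q\|_2^2$ produces exactly the gradients of codebook loss + commitment loss with β = 1. Splitting the terms is what lets the encoder's pull toward the codebook be weighted separately by β (0.25 in the paper) while the codebook update stays the same:

```python
import torch

torch.manual_seed(0)
z_e = torch.randn(5, 3, requires_grad=True)  # toy encoder outputs
e = torch.randn(5, 3, requires_grad=True)    # toy selected codebook vectors


def grads(loss):
    """Gradients of a loss w.r.t. (z_e, e)."""
    return torch.autograd.grad(loss, [z_e, e])


merged = grads(((z_e - e) ** 2).mean())


def split_grads(beta):
    codebook_loss = ((z_e.detach() - e) ** 2).mean()    # ||sg[z_e] - e||^2
    commitment_loss = ((z_e - e.detach()) ** 2).mean()  # ||z_e - sg[e]||^2
    return grads(codebook_loss + beta * commitment_loss)


split_1 = split_grads(1.0)
split_025 = split_grads(0.25)

print(all(torch.allclose(a, b) for a, b in zip(merged, split_1)))  # True: beta = 1 equals the merged loss
print(torch.allclose(split_025[0], 0.25 * merged[0]))              # True: beta only rescales the encoder side
print(torch.allclose(split_025[1], merged[1]))                     # True: the codebook update is unchanged
```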


In [None]:
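from typing import Optional, Tuple


class Swish(nn.Module):
    # Swish / SiLU activation: x * sigmoid(x)
    def forward(self, x):
        return x * torch.sigmoid(x)


class Upsample(nn.Module):
    """2x spatial upsampling: a transposed conv, or nearest-neighbor interpolation (+ optional 3x3 conv)."""
    def __init__(
            self,
            channels: int,
            use_conv: bool = False,
            use_conv_transpose: bool = False,
            out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(
                channels, self.out_channels, kernel_size=4, stride=2, padding=1, bias=True
            )
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=3, padding=1, bias=True)

        self.conv = conv

    def forward(self, hidden_states):
        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # nearest-neighbor interpolation doubles H and W
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            hidden_states = self.conv(hidden_states)
        return hidden_states


class Downsample(nn.Module):
    """2x spatial downsampling: stride-2 3x3 conv, or 2x2 average pooling when use_conv=False."""
    def __init__(
            self,
            channels: int,
            use_conv: bool = False,
            out_channels: Optional[int] = None,
            padding: int = 1,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding

        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=3, stride=2, padding=padding)
        else:
            assert self.channels == self.out_channels, "average pooling cannot change the channel count"
            self.conv = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, hidden_states):
        if self.use_conv and self.padding == 0:
            # pad right/bottom by 1 so the stride-2 conv exactly halves even-sized inputs
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)
        return hidden_states


class ResBlock(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            groups: int = 32,
            eps: float = 1e-6,
            dropout: float = 0.0,
            up: bool = False,
            down: bool = False,
            use_in_shortcut: Optional[bool] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = up
        self.down = down

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = Swish()

        # optional resampling inside the block (applied to both the residual branch and the skip input)
        self.upsample = Upsample(in_channels, use_conv=False) if up else None
        self.downsample = Downsample(in_channels, use_conv=False) if down else None

        # 1x1 conv on the skip path when the channel count changes
        self.use_in_shortcut = self.in_channels != out_channels if use_in_shortcut is None else use_in_shortcut
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if self.use_in_shortcut else None
        )

    def forward(self, x):
        """
        GroupNorm > Swish > (Upsample | Downsample) > Conv > GroupNorm > Swish > Dropout > Conv > + Shortcut
        """
        hidden_states = x
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        # resample both paths so their spatial sizes still match at the residual sum
        if self.upsample is not None:
            x = self.upsample(x)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            x = self.downsample(x)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        output_tensor = x + hidden_states
        return output_tensor


class DownEncoderBlock(nn.Module):
    """`num_layers` ResBlocks followed by an optional 2x Downsample (stride-2 conv)."""
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            num_layers: int = 2,
            groups: int = 32,
            eps: float = 1e-6,
            add_downsample: bool = True,
    ):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResBlock(in_channels if i == 0 else out_channels, out_channels, groups=groups, eps=eps)
            for i in range(num_layers)
        ])
        self.downsample = (
            Downsample(out_channels, use_conv=True, out_channels=out_channels, padding=0)
            if add_downsample else None
        )

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)
        return hidden_states


class UpDecoderBlock(nn.Module):
    """`num_layers` ResBlocks followed by an optional 2x Upsample (nearest-neighbor + 3x3 conv)."""
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            num_layers: int = 2,
            groups: int = 32,
            eps: float = 1e-6,
            add_upsample: bool = True,
    ):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResBlock(in_channels if i == 0 else out_channels, out_channels, groups=groups, eps=eps)
            for i in range(num_layers)
        ])
        self.upsample = Upsample(out_channels, use_conv=True) if add_upsample else None

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)
        return hidden_states


class Encoder(nn.Module):
    def __init__(
            self,
            in_channels: int = 3,
            out_channels: int = 64,
            block_out_channels: Tuple[int, ...] = (64,),
            layers_per_block: int = 2,
            norm_num_groups: int = 32,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # down: every block except the last one halves H and W
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, channels in enumerate(block_out_channels):
            input_channel, output_channel = output_channel, channels
            is_final_block = i == len(block_out_channels) - 1
            self.down_blocks.append(
                DownEncoderBlock(
                    input_channel,
                    output_channel,
                    num_layers=layers_per_block,
                    groups=norm_num_groups,
                    add_downsample=not is_final_block,
                )
            )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)

    def forward(self, sample):
        sample = self.conv_in(sample)
        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class Decoder(nn.Module):
    def __init__(
            self,
            in_channels: int = 64,
            out_channels: int = 3,
            block_out_channels: Tuple[int, ...] = (64,),
            layers_per_block: int = 2,
            norm_num_groups: int = 32,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        # up: mirror of the encoder, every block except the last one doubles H and W
        self.up_blocks = nn.ModuleList([])
        reversed_channels = list(reversed(block_out_channels))
        output_channel = reversed_channels[0]
        for i, channels in enumerate(reversed_channels):
            input_channel, output_channel = output_channel, channels
            is_final_block = i == len(reversed_channels) - 1
            self.up_blocks.append(
                UpDecoderBlock(
                    input_channel,
                    output_channel,
                    num_layers=layers_per_block,
                    groups=norm_num_groups,
                    add_upsample=not is_final_block,
                )
            )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

    def forward(self, z):
        sample = self.conv_in(z)
        # up
        for up_block in self.up_blocks:
            sample = up_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample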


In [None]:
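class VQVAE(nn.Module):
    def __init__(
        self,
        in_channels=3,
        embedding_dim=64,
        num_embeddings=512,
        commitment_cost=0.25,
        hidden_channels=128,
        latent_channels=64,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # 3 blocks -> 2 downsamplings: a 32x32 image becomes an 8x8 grid of codes
        block_out_channels = (hidden_channels,) * 3

        self.encoder = Encoder(in_channels, latent_channels, block_out_channels)
        # 1x1 conv: encoder channels -> codebook dimension D
        self.quant_conv = nn.Conv2d(latent_channels, embedding_dim, 1)
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        # 1x1 conv: codebook dimension D -> decoder channels
        self.post_quant_conv = nn.Conv2d(embedding_dim, latent_channels, 1)
        self.decoder = Decoder(latent_channels, in_channels, block_out_channels)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)  # z_e: (B, D, H', W')
        return h

    def decode(self, h):
        # VectorQuantizer returns (z_q, loss, indices, perplexity)
        z_q, vq_loss, _, perplexity = self.vq(h)
        z_q2 = self.post_quant_conv(z_q)
        dec = self.decoder(z_q2)
        return dec, vq_loss, perplexity

    def forward(self, x):
        h = self.encode(x)
        x_recon, vq_loss, perplexity = self.decode(h)
        return x_recon, vq_loss, perplexity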

In [None]:
import numpy as np

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid


batch_size = 256
num_training_updates = 10000
num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
embedding_dim = 64
num_embeddings = 512
commitment_cost = 0.25
decay = 0.99
learning_rate = 1e-3


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

training_data = datasets.CIFAR10(root="/home/aiteam/tykim/generative_model/data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

validation_data = datasets.CIFAR10(root="/home/aiteam/tykim/generative_model/data", train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))


training_loader = DataLoader(training_data, 
                             batch_size=batch_size, 
                             shuffle=True,
                             pin_memory=True)

validation_loader = DataLoader(validation_data,
                               batch_size=32,
                               shuffle=True,
                               pin_memory=True)


model = VQVAE(embedding_dim=embedding_dim, num_embeddings=num_embeddings, commitment_cost=commitment_cost).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)


model.train()
train_res_recon_error = []
train_res_perplexity = []

for i in range(num_training_updates):
    (data, _) = next(iter(training_loader))
    data = data.to(device)
    optimizer.zero_grad()

    data_recon, vq_loss, perplexity = model(data)
    recon_error = F.mse_loss(data_recon, data, reduction='mean') * data.shape[0]
    loss =  recon_error + vq_loss * data.shape[0]
    loss.backward()

    optimizer.step()

    train_res_recon_error.append(recon_error.item())
    train_res_perplexity.append(perplexity.item())

    if (i+1) % 100 == 0:
        print('%d iterations' % (i+1))
        print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
        print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
        print()
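
The `perplexity` logged in the training loop above measures how many codebook entries are actually in use: it is the exponential of the entropy of the average code assignment, so it equals 1 if every vector maps to the same code and K if all K codes are used equally often. A low value signals that most of the codebook is dead. A minimal sketch of that computation, assuming flat code indices (the helper name `codebook_perplexity` is hypothetical):

```python
import torch
import torch.nn.functional as F


def codebook_perplexity(indices: torch.Tensor, num_embeddings: int) -> float:
    """exp(entropy) of the average one-hot code assignment; ranges from 1 to num_embeddings."""
    avg_probs = F.one_hot(indices, num_embeddings).float().mean(dim=0)
    return torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))).item()


K = 512
print(codebook_perplexity(torch.zeros(1024, dtype=torch.long), K))  # ~1: a single code is used
print(codebook_perplexity(torch.arange(1024) % K, K))              # ~512: all codes used equally
print(codebook_perplexity(torch.arange(1024) % 8, K))              # ~8: only 8 codes in use
```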

In [None]:
model.eval()

(valid_originals, _) = next(iter(validation_loader))
valid_originals = valid_originals.to(device)

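with torch.no_grad():
    # encoder + 1x1 quant_conv -> z_e
    vq_output_eval = model.encode(valid_originals)
    # VectorQuantizer returns (z_q, loss, indices, perplexity)
    valid_quantize, _, _, _ = model.vq(vq_output_eval)
    valid_reconstructions = model.decoder(model.post_quant_conv(valid_quantize))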