# **一个VQVAE算法的简单实现**
## **算法概述**
- 传统的VAE算法将数据编码为高斯分布
- 利用向量量化（vector quantisation）对数据进行编码
- 

论文链接：*https://arxiv.org/abs/1711.00937*

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import torch.optim as optim

device = 0 if torch.cuda.is_available else "cpu"
dataset_name = "mnist"

In [49]:
class VectorQuantizer(nn.Module):
    
    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        
    def forward(self, z):
        """
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W, C)
        """
        z = z.permute(0,2,3,1).contiguous()
        z_flattened = z.view(-1, self.e_dim)  # C == self.e_dim ???
        
        # cal the dis between each emb in z_flattened and emb in embedding
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) # norm l2 squared
        
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) # B * H * W, 1
        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(device) # B* H * W, n
        min_encodings.scatter_(1, min_encoding_indices, 1)
        
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        
        # compute loss
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
        
        # unknown
        z_q = z + (z_q - z).detach()
        
        # perplexity
        e_mean = torch.mean(min_encodings, dim=0) # the frequency of each n_emb
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        
        return loss, z_q, perplexity, min_encodings, min_encoding_indices
        

In [16]:
class ResidualLayer(nn.Module):
    def __init__(self, in_dim, h_dim, res_h_dim):
        super(ResidualLayer, self).__init__()
        self.res_block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_dim, res_h_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(res_h_dim, h_dim, kernel_size=1, stride=1, bias=False)
        )
        
    def forward(self, x):
        x = x + self.res_block(x)
        return x

class ResidualStack(nn.Module):
    
    def __init__(self, in_dim, h_dim, res_h_dim, n_res_layers):
        super(ResidualStack, self).__init__()
        self.n_res_layers = n_res_layers
        self.stack = nn.ModuleList(
            [ResidualLayer(in_dim, h_dim, res_h_dim)] * n_res_layers)
        
    def forward(self, x):
        for layer in self.stack:
            x = layer(x)
        x = F.relu(x)
        return x

In [40]:
class Encoder(nn.Module):
    
    def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
        super(Encoder, self).__init__()
        kernel = 4
        stride = 2
        self.conv_stack = nn.Sequential(
            nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel, stride=stride, padding=1),
            nn.ReLU(),
            nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel, stride=stride, padding=1),
            nn.ReLU(),
            nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),
            ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers)
        )
        
    def forward(self, x):
        return self.conv_stack(x)
    
class Decoder(nn.Module):
    
    def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
        super(Decoder, self).__init__()
        kernel = 4
        stride = 2
        
        self.inverse_conv_stack = nn.Sequential(
            nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),
            ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),
            
            nn.ConvTranspose2d(h_dim, h_dim // 2, kernel_size=kernel, stride=stride, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(h_dim // 2, 1, kernel_size=kernel, stride=stride, padding=1)
        )
        
    def forward(self, x):
        x = self.inverse_conv_stack(x)
        return x

In [41]:
class VQVAE(nn.Module):
    
    def __init__(self, h_dim, res_h_dim, n_res_layers,
                 n_embeddings, embedding_dim, beta, save_img_embedding_map=False):
        super(VQVAE, self).__init__()
        
        self.encoder = Encoder(1, h_dim,n_res_layers, res_h_dim)
        self.pre_quantization_conv = nn.Conv2d(h_dim, embedding_dim, kernel_size=1, stride=1) # n_embeddings is a large num usually
        self.vector_quantization = VectorQuantizer(n_embeddings, embedding_dim, beta)
        self.decoder = Decoder(embedding_dim, h_dim, n_res_layers, res_h_dim)
        
    def forward(self, x, verbose=False):
        z_e = self.encoder(x)
        
        z_e = self.pre_quantization_conv(z_e)
        embedding_loss, z_q, perplexity, _, _ = self.vector_quantization(z_e)
        x_hat = self.decoder(z_q)
        
        return embedding_loss, x_hat, perplexity

In [35]:
from datasets import load_dataset
from torchvision.transforms import Compose
from torchvision import transforms
from torch.utils.data import DataLoader

batch_size_train = 32
dataset = load_dataset(dataset_name)
transform = Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Lambda(lambda t:(t*2) - 1)
])



def transforms_(examples):
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]
    return examples

transformed_dataset = dataset.with_transform(transforms_).remove_columns("label")

dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size_train, shuffle=True)

In [52]:
def train():
    n_hiddens = 128
    n_residual_hiddens = 32
    n_residual_layers = 2
    n_embeddings = 64
    embedding_dim = 512
    beta = 0.25
    
    learning_rate = 3e-4
    epochs = 20
    
    model = VQVAE(n_hiddens, n_residual_hiddens, n_residual_layers, n_embeddings, embedding_dim, beta)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)
    model.train()
    model.to(device)
    
    batch_count = 0
    loss_sum_in_one_batch = 0
    for epoch in range(epochs):
        for batch_idx, batch in enumerate(dataloader):
            data = batch["pixel_values"].to(device)

            embedding_loss, x_hat, perplexity = model(data)
            recon_loss = F.mse_loss(x_hat, data)
            loss = embedding_loss + recon_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_sum_in_one_batch += loss.item()
            batch_count += 1
            if batch_count % 100 == 0:
                print(f"epoch {epoch} batch count {batch_count} , avg loss is {loss_sum_in_one_batch / batch_count}")
train()

epoch 0 batch count 100 , avg loss is 0.5926684428751469
epoch 0 batch count 200 , avg loss is 0.8302716314047575
epoch 0 batch count 300 , avg loss is 0.7681336539487044
epoch 0 batch count 400 , avg loss is 0.6432002666592598
epoch 0 batch count 500 , avg loss is 0.5553462177217007
epoch 0 batch count 600 , avg loss is 0.49119510595997173
epoch 0 batch count 700 , avg loss is 0.43868058823049066
epoch 0 batch count 800 , avg loss is 0.39402047006413343
epoch 0 batch count 900 , avg loss is 0.3575250707359778
epoch 0 batch count 1000 , avg loss is 0.3273730631731451
epoch 0 batch count 1100 , avg loss is 0.3019266311993653
epoch 0 batch count 1200 , avg loss is 0.2804014690127224
epoch 0 batch count 1300 , avg loss is 0.2619312561246065
epoch 0 batch count 1400 , avg loss is 0.24590713979942458
epoch 0 batch count 1500 , avg loss is 0.23185643459359806
epoch 0 batch count 1600 , avg loss is 0.21950621614814736
epoch 0 batch count 1700 , avg loss is 0.20847704347551746
epoch 0 batch co

KeyboardInterrupt: 