In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, latent_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, z):
        return self.net(z)


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)

    def forward(self, z_e):
        # z_e: [B, D]
        flat_input = z_e.view(-1, self.embedding_dim)
        # Compute distances to codebook entries
        distances2 = (
            torch.sum(flat_input ** 2, dim=1, keepdim=True)
            + torch.sum(self.embeddings.weight ** 2, dim=1)
            - 2 * torch.matmul(flat_input, self.embeddings.weight.t())
        )
        encoding_indices = torch.argmin(distances2, dim=1)
        quantized = self.embeddings(encoding_indices).view(z_e.shape)
        # Codebook and commitment losses
        commitment_loss = F.mse_loss(quantized.detach(), z_e)
        codebook_loss = F.mse_loss(quantized, z_e.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss
        # Straight-through estimator
        quantized = z_e + (quantized - z_e).detach()
        return quantized, loss, encoding_indices


class VQVAE(nn.Module):
    def __init__(self, in_dim, hidden_dim, latent_dim, num_embeddings, commitment_cost):
        super().__init__()
        self.encoder = Encoder(in_dim, hidden_dim, latent_dim)
        self.vq = VectorQuantizer(num_embeddings, latent_dim, commitment_cost)
        self.decoder = Decoder(latent_dim, hidden_dim, in_dim)

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, vq_loss, encoding_indices = self.vq(z_e)
        x_recon = self.decoder(z_q)
        recon_loss = F.mse_loss(x_recon, x)
        return x_recon, vq_loss, recon_loss, encoding_indices

In [10]:
in_dim = 32
hidden_dim = 64
latent_dim = 16
num_embeddings = 128
commitment_cost = 0.25

model = VQVAE(in_dim, hidden_dim, latent_dim, num_embeddings, commitment_cost)
x = torch.randn(8, in_dim)  # batch of data
x_recon, vq_loss, recon_loss, encoding_indices = model(x)
total_loss = recon_loss + vq_loss

print(f"Reconstruction loss: {recon_loss.item()}")
print(f"VQ loss (commitment + codebook): {vq_loss.item()}")
print(f"Total loss:{total_loss.item()}")
print(f"Encoding indices (codebook ids used): {encoding_indices}")
print()

# Suppose we want to sample 4 new examples
num_samples = 4
sampled_indices = torch.randint(low=0, high=num_embeddings, size=(num_samples,))
sampled_embeddings = model.vq.embeddings(sampled_indices)
generated = model.decoder(sampled_embeddings)
print(f"Sampled embeddings: {sampled_embeddings}")
print()
print(f"Generated samples: {generated}")
print()

Reconstruction loss: 1.0836589336395264
VQ loss (commitment + codebook): 0.08529667556285858
Total loss:1.1689555644989014
Encoding indices (codebook ids used): tensor([  0,  13,  27,   3,  63, 115,  64,   2])

Sampled embeddings: tensor([[ 0.0021, -0.0006,  0.0030, -0.0030, -0.0055,  0.0051,  0.0027,  0.0025,
         -0.0011, -0.0010, -0.0070, -0.0073, -0.0061, -0.0044, -0.0062,  0.0070],
        [ 0.0013, -0.0036,  0.0056, -0.0037, -0.0001,  0.0076,  0.0014,  0.0066,
          0.0077, -0.0047,  0.0075, -0.0028,  0.0022,  0.0040, -0.0010,  0.0009],
        [-0.0062,  0.0046,  0.0078, -0.0004, -0.0036,  0.0034,  0.0018, -0.0010,
         -0.0067, -0.0063,  0.0059, -0.0013, -0.0059, -0.0045, -0.0014,  0.0050],
        [ 0.0013,  0.0076, -0.0018, -0.0040,  0.0054,  0.0004,  0.0076,  0.0044,
         -0.0009, -0.0040, -0.0014,  0.0003,  0.0069,  0.0044,  0.0029,  0.0037]],
       grad_fn=<EmbeddingBackward0>)

Generated samples: tensor([[-0.1059,  0.0553,  0.2098, -0.0345, -0.0691, -0.14