In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST digits are grayscale: replicate the gray channel three times so the
# images match the encoder's 3-channel input, then convert to float tensors.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])
train_ds = datasets.MNIST(
    ".", train=True, download=True, transform=transform
)
train_dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=128)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping (B, 3, H, W) images to 64-channel latents.

    Two stride-2 convolutions shrink the spatial size (28x28 -> 14x14 -> 7x7).
    The last convolution has no activation, so its output is the unbounded
    pre-quantization latent z_e.
    """

    def __init__(self):
        super().__init__()
        # All layers are 3x3 convolutions with padding 1; only the stride differs.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden))
        # Deliberately no ReLU here: the quantizer sees raw latents.
        return self.conv3(hidden)

In [None]:
# Shape walk-through of the reshape used inside VectorQuantizer.forward:
# (B, D, H, W) -> channels-last (B, H, W, D) -> one row per position (B*H*W, D).
feature_map = torch.randn(32, 64, 7, 7)
channels_last = feature_map.permute(0, 2, 3, 1)
x = channels_last.flatten(0, 2)
for stage in (feature_map, channels_last, x):
    print(stage.shape)

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through estimator.

    Every D-dimensional vector of a (B, D, H, W) feature map is replaced by
    the closest (squared-L2) row of a learned (K, D) codebook.

    Args:
        num_embeddings: codebook size K.
        embedding_dim: codebook vector size D; must equal the input channel count.
        beta: weight of the commitment loss term.
    """

    def __init__(
        self, num_embeddings: int = 128, embedding_dim: int = 64, beta: float = 0.25
    ):
        super().__init__()
        self.beta = beta
        self.n_embeddings = num_embeddings
        # Codebook rows keep nn.Embedding's default N(0, 1) initialisation.
        self.table = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize a feature map.

        Args:
            x: (B, D, H, W) feature map with D == embedding_dim.

        Returns:
            z_q: (B, D, H, W) quantized map. Its forward values are codebook
                rows, while gradients flow straight through to ``x``.
            indices: (B, H, W) codebook index chosen for each position.
            total_loss: scalar codebook loss + beta * commitment loss.

        Raises:
            ValueError: if ``x`` is not 4-D or its channel dim != embedding_dim.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, D, H, W) tensor, got shape {tuple(x.shape)}"
            )
        batch_size, dimension, height, width = x.shape
        if dimension != self.table.embedding_dim:
            raise ValueError(
                f"channel dim {dimension} does not match embedding_dim "
                f"{self.table.embedding_dim}"
            )

        x = x.permute(0, 2, 3, 1)  # BxDxHxW -> BxHxWxD
        x_flat = x.flatten(0, 2)  # BxHxWxD -> (B*H*W)xD = NxD

        # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e, evaluated for all pairs at once.
        distances = (  # (N, K)
            x_flat.pow(2).sum(dim=1, keepdim=True)  # (N, D) -> (N, 1)
            + self.table.weight.pow(2).sum(dim=1)  # (K, D) -> (K,)
            - 2 * x_flat @ self.table.weight.t()  # (N, D) @ (D, K) -> (N, K)
        )

        indices_flat = distances.argmin(dim=1)  # (N,)
        embeddings_flat = self.table(indices_flat)  # (N, D)

        # Straight-through estimator: forward value is the codebook row,
        # backward treats quantization as identity w.r.t. x.
        z_q_flat = x_flat + (embeddings_flat - x_flat).detach()

        # Codebook loss moves codebook rows toward the (frozen) encoder outputs;
        # commitment loss pulls encoder outputs toward the (frozen) codebook rows.
        codebook_loss = (x_flat.detach() - embeddings_flat).pow(2).mean()
        commitment_loss = self.beta * (x_flat - embeddings_flat.detach()).pow(2).mean()
        total_loss = codebook_loss + commitment_loss

        z_q = z_q_flat.view(batch_size, height, width, dimension)
        z_q = z_q.permute(0, 3, 1, 2)  # BxHxWxD -> BxDxHxW

        indices = indices_flat.view(batch_size, height, width)  # (N,) -> (B, H, W)

        return z_q, indices, total_loss

In [None]:
# Sanity check on a random latent grid shaped like the encoder output.
# Demo-suffixed names avoid clobbering `x` / `indices` / `loss` used later.
vq_demo = VectorQuantizer()
z_e_demo = torch.randn(32, 64, 7, 7)
# Call the module (not .forward) so nn.Module hooks run.
z_q_demo, indices_demo, loss_demo = vq_demo(z_e_demo)
assert z_q_demo.shape == z_e_demo.shape
assert indices_demo.shape == (32, 7, 7)
z_q_demo.shape, indices_demo.shape, loss_demo

In [None]:
class Decoder(nn.Module):
    """Transposed-conv decoder: (B, 64, h, w) latents -> (B, 3, 4h, 4w) outputs.

    The two 4x4 stride-2 transposed convolutions each double H and W
    (7 -> 14 -> 28). The final 3x3 stride-1 layer keeps the size and maps to
    3 channels with no activation, so the outputs are raw (pre-sigmoid) values.
    """

    def __init__(self):
        super().__init__()
        self.tconv1 = nn.ConvTranspose2d(
            in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1
        )
        self.tconv2 = nn.ConvTranspose2d(
            in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1
        )
        self.tconv3 = nn.ConvTranspose2d(
            in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.tconv3(F.relu(self.tconv2(F.relu(self.tconv1(x)))))


# Sanity check: the decoder upsamples a 7x7 latent grid back to 28x28.
# Demo-suffixed names keep this cell from rebinding the global `model`.
latent_demo = torch.randn(4, 64, 7, 7)
decoder_demo = Decoder()
decoded_demo = decoder_demo(latent_demo)
assert decoded_demo.shape == (4, 3, 28, 28), decoded_demo.shape
decoded_demo.shape

In [None]:
class VQVAE(nn.Module):
    """Encoder -> VectorQuantizer -> Decoder autoencoder.

    Args:
        num_embeddings: codebook size passed to the quantizer (default 128).
        beta: commitment-loss weight passed to the quantizer (default 0.25).

    The codebook dimension stays fixed at 64. That is the Encoder's output
    channel count (conv3 out_channels=64) and the Decoder's input channel count.
    """

    def __init__(self, num_embeddings: int = 128, beta: float = 0.25):
        super().__init__()
        self.encoder = Encoder()
        self.vq = VectorQuantizer(
            num_embeddings=num_embeddings, embedding_dim=64, beta=beta
        )
        self.decoder = Decoder()

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, quantize and decode ``x``.

        Args:
            x: (B, 3, H, W) images.

        Returns:
            logits: raw decoder outputs (pre-sigmoid); same shape as ``x`` when
                H and W are divisible by 4.
            vq_loss: scalar quantizer loss (codebook + commitment).
            indices: (B, H', W') codebook indices of the latent grid.
        """
        z_e = self.encoder(x)
        z_q, indices, vq_loss = self.vq(z_e)
        logits = self.decoder(z_q)
        return logits, vq_loss, indices

In [None]:
SEED = 0
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Seeding before model construction fixes both weight init and shuffle order.
torch.manual_seed(SEED)

model = VQVAE()
# Prefer CUDA, then Apple-silicon MPS, else CPU.
# torch.backends.mps.is_available() is the documented availability check.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    model.train()
    # Accumulate on-device; a single .item() per epoch avoids a sync every step.
    loss_sum = torch.zeros((), device=device)
    n_seen = 0
    for x, _ in train_dl:
        x = x.to(device)
        optimizer.zero_grad()
        logits, vq_loss, indices = model(x)
        recon_loss = criterion(logits, x)
        loss = recon_loss + vq_loss
        loss.backward()
        optimizer.step()
        # Weight by batch size so the partial final batch is not over-counted.
        loss_sum += loss.detach() * x.size(0)
        n_seen += x.size(0)
    # Report the epoch mean, not just the (noisy) last-minibatch loss.
    print(f"Epoch {epoch}  mean loss: {(loss_sum / n_seen).item():.4f}")

In [None]:
# Show one training digit next to its VQ-VAE reconstruction.
# Re-encode a fixed dataset sample rather than reusing the last training batch.
# That batch is a partial batch (60000 is not a multiple of 128), so `index`
# can fall outside it, and its contents change with every shuffle.
index = 78
image, _ = train_ds[index]  # (3, 28, 28) after Grayscale(3) + ToTensor
model.eval()
with torch.no_grad():
    recon_logits, _, _ = model(image.unsqueeze(0).to(device))
recon = recon_logits[0].sigmoid().cpu()

fig, (ax_orig, ax_recon) = plt.subplots(1, 2, figsize=(10, 5))
ax_orig.imshow(image.permute(1, 2, 0).numpy())
ax_orig.set_title(f"Original (train_ds[{index}])")
ax_recon.imshow(recon.permute(1, 2, 0).numpy())
ax_recon.set_title("Reconstruction")
for ax in (ax_orig, ax_recon):
    ax.axis("off")
plt.show()