In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

In [None]:
class ResBlockBasic(nn.Module):
    """Plain residual block: conv -> ReLU -> conv, added back onto the input.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the branch output has the same shape as the input.
    """

    def __init__(
        self,
        n_channels: int,
    ):
        super().__init__()

        def make_conv() -> nn.Conv2d:
            # Shape-preserving 3x3 convolution shared by both layers.
            return nn.Conv2d(
                n_channels, n_channels, kernel_size=3, stride=1, padding=1
            )

        self.conv1 = make_conv()
        self.conv2 = make_conv()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the output of the conv-ReLU-conv branch."""
        branch = self.conv2(F.relu(self.conv1(x)))
        return x + branch


class ResBlock(nn.Module):
    """Pre-activation residual block.

    The branch applies GroupNorm -> SiLU -> 3x3 conv twice; its result is
    added to the unmodified input. Channel count and spatial size are kept.

    Args:
        n_channels: Number of input (and output) channels.
        num_groups: Number of groups used by both GroupNorm layers.
    """

    def __init__(
        self,
        n_channels: int,
        num_groups: int = 32,
    ):
        super().__init__()
        # Sub-modules are created in the order conv1, conv2, gn1, gn2.
        self.conv1 = nn.Conv2d(
            n_channels, n_channels, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            n_channels, n_channels, kernel_size=3, stride=1, padding=1
        )
        self.gn1 = nn.GroupNorm(num_groups, n_channels)
        self.gn2 = nn.GroupNorm(num_groups, n_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the normalised, activated, convolved branch."""
        branch = self.conv1(F.silu(self.gn1(x)))
        branch = self.conv2(F.silu(self.gn2(branch)))
        return x + branch


class NonLocalBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries, keys and values with
    1x1 convolutions, attended over all ``H*W`` positions, projected once more
    and added back onto the input.

    Args:
        n_channels: Number of input (and output) channels.
        num_groups: Number of groups for the GroupNorm layer. Defaults to 32,
            the value that used to be hard-coded.
    """

    def __init__(self, n_channels: int, num_groups: int = 32):
        super().__init__()
        self.gn = nn.GroupNorm(num_channels=n_channels, num_groups=num_groups)

        conv_params = {
            "in_channels": n_channels,
            "out_channels": n_channels,
            "kernel_size": 1,
            "stride": 1,
            "padding": 0,
        }
        self.q_conv = nn.Conv2d(**conv_params)
        self.k_conv = nn.Conv2d(**conv_params)
        self.v_conv = nn.Conv2d(**conv_params)
        self.out_conv = nn.Conv2d(**conv_params)

    def forward(self, x: Tensor) -> Tensor:
        """Apply spatial self-attention to ``x`` of shape (B, C, H, W).

        Returns:
            A tensor with the same shape as ``x``.
        """
        batch_size, n_channels, height, width = x.shape
        y = self.gn(x)

        # (B, C, H, W) -> (B, C, H*W)
        q = self.q_conv(y).flatten(2)
        k = self.k_conv(y).flatten(2)
        v = self.v_conv(y).flatten(2)

        # (B, H*W, C) @ (B, C, H*W) -> (B, H*W, H*W); row i holds the scores
        # of query position i against every key position.
        scores = q.transpose(1, 2).bmm(k) / n_channels**0.5
        weights = torch.softmax(scores, dim=-1)

        # (B, C, H*W) @ (B, H*W, H*W) -> (B, C, H*W): weighted sum of values
        # for each query position.
        out = v.bmm(weights.transpose(1, 2))
        out = out.reshape(batch_size, n_channels, height, width)
        out = self.out_conv(out)

        return out + x


class Encoder(nn.Module):
    """Convolutional encoder from a 3-channel image to a latent feature map.

    Layout: a stem convolution (3 -> 128 channels), four stages of
    ``ResBlock`` followed by a stride-2 convolution (128 -> 256 -> 512 ->
    1024 -> 1024 channels), a ResBlock / NonLocalBlock / ResBlock middle
    section at 1024 channels, and a GroupNorm -> SiLU -> 3x3 conv head that
    maps to ``latent_dim`` channels.

    Args:
        latent_dim: Number of channels in the returned feature map.
    """

    def __init__(
        self,
        latent_dim: int = 256,
    ):
        super().__init__()

        def downsample(c_in: int, c_out: int) -> nn.Conv2d:
            # Stride-2 3x3 convolution used at the end of every stage.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)

        self.initial_conv = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)

        self.conv1 = downsample(128, 256)
        self.conv2 = downsample(256, 512)
        self.conv3 = downsample(512, 1024)
        self.conv4 = downsample(1024, 1024)

        self.resblock1 = ResBlock(n_channels=128, num_groups=32)
        self.resblock2 = ResBlock(n_channels=256, num_groups=32)
        self.resblock3 = ResBlock(n_channels=512, num_groups=32)
        self.resblock4 = ResBlock(n_channels=1024, num_groups=32)

        self.resblock5 = ResBlock(n_channels=1024, num_groups=32)
        self.resblock6 = ResBlock(n_channels=1024, num_groups=32)

        self.nonlocal_block = NonLocalBlock(n_channels=1024)

        self.latent_conv = nn.Conv2d(
            1024, latent_dim, kernel_size=3, stride=1, padding=1
        )
        self.gn = nn.GroupNorm(32, 1024)

    def forward(self, x: Tensor) -> Tensor:
        """Encode an image batch into a ``latent_dim``-channel feature map."""
        h = self.initial_conv(x)

        # Each stage refines at the current width, then downsamples.
        stages = (
            (self.resblock1, self.conv1),
            (self.resblock2, self.conv2),
            (self.resblock3, self.conv3),
            (self.resblock4, self.conv4),
        )
        for refine, shrink in stages:
            h = shrink(refine(h))

        # Middle section: attention sandwiched between two residual blocks.
        h = self.resblock5(h)
        h = self.nonlocal_block(h)
        h = self.resblock6(h)

        return self.latent_conv(F.silu(self.gn(h)))


class UpsampleBlock(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, then
    apply a 3x3 convolution that maps ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the convolved 2x-upsampled version of ``x``."""
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)


class Decoder(nn.Module):
    """Convolutional decoder from a latent feature map to a 3-channel image.

    Layout: a 3x3 convolution lifting ``latent_dim`` to 1024 channels, a
    ResBlock / NonLocalBlock / ResBlock section, four stages of ``ResBlock``
    followed by an ``UpsampleBlock`` (1024 -> 1024 -> 512 -> 256 -> 128
    channels), and a GroupNorm -> SiLU -> 3x3 conv head producing 3 channels.

    Args:
        latent_dim: Number of channels of the incoming feature map.
    """

    def __init__(self, latent_dim: int = 256):
        super().__init__()
        self.conv1 = nn.Conv2d(latent_dim, 1024, kernel_size=3, stride=1, padding=1)
        self.resblock_nl1 = ResBlock(n_channels=1024, num_groups=32)
        self.resblock_nl2 = ResBlock(n_channels=1024, num_groups=32)

        self.nonlocal_block = NonLocalBlock(n_channels=1024)

        self.upblock1 = UpsampleBlock(in_channels=1024, out_channels=1024)
        self.upblock2 = UpsampleBlock(in_channels=1024, out_channels=512)
        self.upblock3 = UpsampleBlock(in_channels=512, out_channels=256)
        self.upblock4 = UpsampleBlock(in_channels=256, out_channels=128)

        self.resblock1 = ResBlock(n_channels=1024, num_groups=32)
        self.resblock2 = ResBlock(n_channels=1024, num_groups=32)
        self.resblock3 = ResBlock(n_channels=512, num_groups=32)
        self.resblock4 = ResBlock(n_channels=256, num_groups=32)

        self.gn = nn.GroupNorm(32, 128)
        self.conv2 = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Decode a latent feature map into a 3-channel image batch."""
        h = self.conv1(x)

        # Attention sandwiched between two residual blocks.
        h = self.resblock_nl1(h)
        h = self.nonlocal_block(h)
        h = self.resblock_nl2(h)

        # Each stage refines at the current width, then upsamples.
        stages = (
            (self.resblock1, self.upblock1),
            (self.resblock2, self.upblock2),
            (self.resblock3, self.upblock3),
            (self.resblock4, self.upblock4),
        )
        for refine, grow in stages:
            h = grow(refine(h))

        return self.conv2(F.silu(self.gn(h)))


class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Args:
        codebook_size: Number of codebook entries (K).
        embedding_dim: Dimensionality of each entry (C); must equal the
            channel count of the tensors passed to ``forward``.
        commitment_cost: Weight applied to the commitment term of the loss.
            The default of 1.0 gives ``codebook_loss + commitment_loss``.
    """

    def __init__(
        self,
        codebook_size: int,
        embedding_dim: int,
        commitment_cost: float = 1.0,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.table = nn.Embedding(  # (K, C)
            num_embeddings=codebook_size,
            embedding_dim=self.embedding_dim,
        )

    def forward(self, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Quantize every spatial position of ``z`` to its closest code.

        Args:
            z: Feature map of shape (B, C, H, W) with ``C == embedding_dim``.

        Returns:
            A tuple ``(z_q, indices, loss)`` where ``z_q`` has shape
            (B, C, H, W) and carries straight-through gradients to ``z``,
            ``indices`` has shape (B, H*W) and holds the selected codebook
            rows, and ``loss`` is the scalar codebook + weighted commitment
            loss.
        """
        batch_size, n_channels, height, width = z.shape
        assert n_channels == self.embedding_dim, (
            "Channel and embedding dimension is not the same"
        )

        # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C) = (N, C)
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, n_channels)
        codebook = self.table.weight  # (K, C)

        # Only the argmin of the distances is used and it has no gradient, so
        # the (N, K) matrix is computed outside the autograd graph.
        with torch.no_grad():
            distances = (
                z_flat.pow(2).sum(dim=1, keepdim=True)  # (N, 1)
                + codebook.pow(2).sum(dim=1)  # (K,)
                - 2 * (z_flat @ codebook.T)  # (N, K)
            )
            # Index of the closest codebook entry for each input vector.
            indices = distances.argmin(dim=1)

        z_q_raw = codebook[indices]  # (N, C)

        # Losses use the raw lookup, before the straight-through estimator:
        # the first moves codes towards encodings, the second the reverse.
        codebook_loss = (z_flat.detach() - z_q_raw).pow(2).mean()
        commitment_loss = (z_flat - z_q_raw.detach()).pow(2).mean()
        total_loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward value is the code, gradient
        # flows to z_flat unchanged.
        z_q = z_flat + (z_q_raw - z_flat).detach()  # (N, C)

        # (N, C) -> (B, H, W, C) -> (B, C, H, W)
        z_q = z_q.view(batch_size, height, width, n_channels).permute(0, 3, 1, 2)

        per_batch_indices = indices.view(batch_size, -1)

        return z_q, per_batch_indices, total_loss


class VAR(nn.Module):
    """Multi-scale residual vector-quantized autoencoder.

    The encoder's feature map is approximated as a sum of per-scale terms: at
    each scale the current residual is resized to ``(scale, scale)``,
    quantized with a shared codebook, resized back to the feature map's
    resolution, passed through a scale-specific 3x3 convolution, then
    subtracted from the residual and added to the running reconstruction.

    Args:
        scales: Side lengths of the square token maps, in processing order.
        latent_dim: Channel count of the latent feature map and codebook.
        codebook_size: Number of entries in the shared codebook.
    """

    def __init__(
        self,
        scales: tuple[int, ...] = (1, 2, 4, 8, 16),
        latent_dim: int = 256,
        codebook_size: int = 512,
    ):
        super().__init__()
        self.scales = scales
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim)
        self.quantizer = VectorQuantizer(
            codebook_size=codebook_size,
            embedding_dim=latent_dim,
        )

        conv_params = {
            "in_channels": latent_dim,
            "out_channels": latent_dim,
            "kernel_size": 3,
            "stride": 1,
            "padding": 1,
        }
        # One refinement convolution per scale.
        self.convs = nn.ModuleList([nn.Conv2d(**conv_params) for _ in self.scales])

    def forward(
        self, x: Tensor
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, list[Tensor]]:
        """Encode, quantize scale by scale, and decode ``x``.

        Returns:
            A tuple ``(x_hat, q_loss, f_loss, r_loss, scale_indices)``:
            the decoded image, the quantizer loss summed over scales, the MSE
            between the quantized and the original encoder features, the MSE
            between ``x_hat`` and ``x``, and one index tensor per scale as
            returned by the quantizer.
        """
        f = self.encoder(x)
        # `f` is only ever rebound below, never modified in place, so the
        # encoder output can be kept by reference for the feature loss.
        f_original = f

        # Resize quantized maps to the encoder map's actual resolution rather
        # than assuming it equals scales[-1]; otherwise the residual update
        # fails (or silently broadcasts) when the two differ.
        full_size = tuple(f.shape[-2:])

        f_hat = torch.zeros_like(f)
        q_loss = f.new_zeros(())
        scale_indices: list[Tensor] = []

        for conv, scale in zip(self.convs, self.scales):
            f_inter_scale = F.interpolate(
                f,
                size=(scale, scale),
                mode="bilinear",
                align_corners=False,
            )
            z_q, indices, loss_scale = self.quantizer(f_inter_scale)
            z_inter_full = F.interpolate(
                z_q,
                size=full_size,
                mode="bilinear",
                align_corners=False,
            )
            conv_out = conv(z_inter_full)
            f = f - conv_out  # what is still unexplained
            f_hat = f_hat + conv_out  # what has been explained so far
            q_loss = q_loss + loss_scale
            scale_indices.append(indices)

        x_hat = self.decoder(f_hat)

        f_loss = (f_hat - f_original).pow(2).mean()
        r_loss = (x_hat - x).pow(2).mean()

        return x_hat, q_loss, f_loss, r_loss, scale_indices


# Smoke test: run one random 256x256 3-channel image through a freshly
# constructed VAR model and print the shape of the reconstruction together
# with the shape of the index tensor produced at each scale.
x = torch.randn(1, 3, 256, 256)
var = VAR()
x_hat, _, _, _, scale_indices = var(x)  # the three losses are not needed here
print(x_hat.shape)
print([t.shape for t in scale_indices])

In [None]:
class VARTransformer(nn.Module):
    """Transformer that predicts the tokens of each scale from coarser scales.

    The input sequence is a learned start token followed by, for every scale
    except the last, that scale's token embeddings bilinearly resized to the
    next scale's grid. A block-wise attention mask lets every position attend
    to its own block and all earlier blocks.

    Args:
        vocab_size: Number of distinct token ids (size of the output logits).
        embedding_dim: Transformer model width.
        scales: Side lengths of the square token maps, coarse to fine.
        n_transformer_layers: Number of encoder layers.
        n_heads: Number of attention heads per layer; ``embedding_dim`` must
            be divisible by it.
        verbose: When True (the default), ``forward`` prints the intermediate
            tensor shapes. Pass False to silence the trace.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 256,
        scales: tuple[int, ...] = (1, 2, 4, 8, 16),
        n_transformer_layers: int = 4,
        n_heads: int = 16,
        verbose: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.scales = scales
        self.max_sequence_length = sum(scale**2 for scale in self.scales)
        self.n_transformer_layers = n_transformer_layers
        self.verbose = verbose

        self.start_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.vocab_embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.positional_embeddings = nn.Embedding(
            num_embeddings=self.max_sequence_length,
            embedding_dim=embedding_dim,
        )

        self.blocks = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=self.embedding_dim,
                nhead=n_heads,
                dim_feedforward=4 * embedding_dim,
                batch_first=True,
            ),
            num_layers=self.n_transformer_layers,
        )

        self.classification_head = nn.Linear(
            in_features=self.embedding_dim,
            out_features=self.vocab_size,
        )

    def _trace(self, label: str, tensor: Tensor) -> None:
        """Print ``label`` and the shape of ``tensor`` when verbose is on."""
        if self.verbose:
            print(f"{label}: {tensor.shape}")

    @torch.no_grad()
    def mask_matrix(self, device: "torch.device | str | None" = None) -> Tensor:
        """Build the additive block-wise attention mask.

        Args:
            device: Device to create the mask on; ``None`` uses the default
                device.

        Returns:
            A ``(max_sequence_length, max_sequence_length)`` float tensor
            holding 0 where attention is allowed and ``-inf`` elsewhere.
        """
        size = self.max_sequence_length
        mask = torch.full((size, size), float("-inf"), device=device)

        # The start token only sees itself.
        mask[0, 0] = 0
        current_start = 1

        # Input block i holds scale i resized to the grid of scale i + 1, so
        # block sizes follow scales[1:]. Every block sees itself and all
        # preceding positions.
        for scale in self.scales[1:]:
            current_end = current_start + scale**2
            mask[current_start:current_end, 0:current_end] = 0
            current_start = current_end

        return mask

    def forward(self, x: list[Tensor]) -> Tensor:
        """Compute next-scale token logits.

        Args:
            x: One integer tensor of token ids per scale, coarse to fine, where
                entry ``i`` has shape ``(batch_size, scales[i] ** 2)``. The
                entry for the last scale, if present, is ignored; a shorter
                list yields a correspondingly shorter output sequence.

        Returns:
            Logits of shape ``(batch_size, sequence_length, vocab_size)``.
        """
        batch_size = x[0].shape[0]
        n_channels = self.embedding_dim
        input_list = [self.start_token.expand(batch_size, -1, -1)]

        # Pair every scale with the one that follows it; zip stops before the
        # last scale, which is never used as input.
        for scale_tokens, scale, next_scale in zip(
            x, self.scales, self.scales[1:]
        ):
            self._trace("scale tokens", scale_tokens)

            # 1. Embed the token ids: (B, seq_length) -> (B, seq_length, C)
            words: Tensor = self.vocab_embeddings(scale_tokens)
            self._trace("word embeddings", words)

            # 2. Reshape into a batched grid of embeddings for the scale.
            # - We push the sequence length to the end, then fold it into the
            #   grid, which prepares us for the interpolation.
            words = words.permute(0, 2, 1)  # (B, seq_length, C) -> (B, C, seq_length)
            self._trace("words permuted", words)
            words = words.reshape(
                batch_size, n_channels, scale, scale
            )  # (B, C, seq_length) -> (B, C, H, W)
            self._trace("words view", words)

            # 3. Interpolate to the grid of the next scale.
            words = F.interpolate(
                words,
                size=(next_scale, next_scale),
                mode="bilinear",
                align_corners=False,
            )
            self._trace("words upscaled", words)

            # 4. Flatten back to a sequence for the transformer.
            words = words.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
            self._trace("words channel last", words)
            words = words.reshape(
                batch_size, next_scale * next_scale, n_channels
            )  # (B, H, W, C) -> (B, seq_length, C)
            self._trace("words (b, seq, C)", words)

            input_list.append(words)

        tokens = torch.cat(input_list, dim=1)  # (B, seq_len, C)
        sequence_length = tokens.shape[1]
        pos_indices = torch.arange(sequence_length, device=tokens.device)
        tokens = tokens + self.positional_embeddings(pos_indices)

        # Build the mask on the input's device and crop it to the sequence
        # actually fed in, so partial inputs work as well.
        mask = self.mask_matrix(device=tokens.device)
        mask = mask[:sequence_length, :sequence_length].to(dtype=tokens.dtype)

        output: Tensor = self.blocks(tokens, mask=mask)
        output = self.classification_head(output)
        if self.verbose:
            print(output.shape)
        return output


# Smoke test: one batch of dummy token ids per scale, each of shape
# (1, scale**2) for the default scales (1, 2, 4, 8, 16). All ids stay below
# the vocabulary size of 512 passed to the model.
x = [
    torch.arange(1, 2, 1).unsqueeze(0),  # scale 1  -> 1 token
    torch.arange(1, 5, 1).unsqueeze(0),  # scale 2  -> 4 tokens
    torch.arange(1, 17, 1).unsqueeze(0),  # scale 4  -> 16 tokens
    torch.arange(1, 65, 1).unsqueeze(0),  # scale 8  -> 64 tokens
    torch.arange(1, 257, 1).unsqueeze(0),  # scale 16 -> 256 tokens (skipped by forward)
]
model = VARTransformer(512)
y = model(x)