In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerVAE(nn.Module):
    """Variational autoencoder over fixed-length token sequences.

    The encoder embeds the tokens, runs them through a Transformer encoder
    stack, and projects the flattened per-position features to the mean and
    log-variance of a ``latent_dim``-dimensional Gaussian.  The decoder
    expands a latent vector back to ``seq_len`` positions, runs a second
    Transformer encoder stack under a causal mask, and emits per-position
    vocabulary logits.

    Args:
        seq_len: Fixed sequence length the model is built for.
        embed_dim: Token embedding / Transformer model width.
        latent_dim: Size of the latent vector ``z``.
        num_heads: Number of attention heads in every Transformer layer.
        num_layers: Number of layers in both the encoder and decoder stacks.
        vocab_size: Number of distinct token ids.
    """

    def __init__(self, seq_len, embed_dim, latent_dim, num_heads, num_layers, vocab_size):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim

        # Sub-modules are created in a fixed order so that seeded
        # initialisation stays reproducible.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Doubles the feature width so it can be split into mu / logvar halves.
        self.enc_project = nn.Linear(embed_dim, 2 * embed_dim)

        self.fc_mu = nn.Linear(embed_dim * seq_len, latent_dim)
        self.fc_logvar = nn.Linear(embed_dim * seq_len, latent_dim)
        self.fc_latent_to_hidden = nn.Linear(latent_dim, embed_dim * seq_len)

        decoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)

        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def encode(self, x):
        """Map token ids to the parameters of the approximate posterior.

        Args:
            x: Integer token ids of shape ``[batch_size, seq_len]``; the
                second dimension must equal the ``seq_len`` the model was
                built with, because the features are flattened into linear
                layers of size ``embed_dim * seq_len``.

        Returns:
            ``(mu, logvar)``, each of shape ``[batch_size, latent_dim]``.
        """
        x = self.embedding(x)  # [batch_size, seq_len, embed_dim]
        batch_size, _, embed_dim = x.shape

        # The Transformer layers are sequence-first (batch_first is not set).
        x = x.permute(1, 0, 2)  # [seq_len, batch_size, embed_dim]
        x = self.encoder(x)  # [seq_len, batch_size, embed_dim]

        x = self.enc_project(x)  # [seq_len, batch_size, 2 * embed_dim]
        x = x.permute(1, 0, 2)  # [batch_size, seq_len, 2 * embed_dim]
        # First half of the features feeds mu, second half feeds logvar.
        x_mu, x_logvar = x.split(split_size=embed_dim, dim=2)
        x_mu = x_mu.reshape(batch_size, -1)  # [batch_size, seq_len * embed_dim]
        x_logvar = x_logvar.reshape(batch_size, -1)

        mu = self.fc_mu(x_mu)
        logvar = self.fc_logvar(x_logvar)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Turn latent vectors into per-position vocabulary logits.

        Args:
            z: Latent vectors of shape ``[batch_size, latent_dim]``.

        Returns:
            Logits of shape ``[batch_size, seq_len, vocab_size]``.
        """
        hidden = self.fc_latent_to_hidden(z)  # [batch_size, seq_len * embed_dim]
        hidden = hidden.view(z.size(0), self.seq_len, self.embed_dim)
        hidden = hidden.permute(1, 0, 2)  # [seq_len, batch_size, embed_dim]

        # Build the mask from the model's own configuration and the input's
        # device rather than from module-level globals.  For boolean masks
        # PyTorch treats True as "may NOT attend", so a causal mask is the
        # strict upper triangle (future positions), not the lower one.
        causal_mask = torch.ones(
            self.seq_len, self.seq_len, device=z.device
        ).triu(diagonal=1).bool()
        # is_causal is only a hint that the mask is causal; the mask itself
        # must therefore really be the causal mask.
        hidden = self.decoder(hidden, mask=causal_mask, is_causal=True)
        hidden = hidden.permute(1, 0, 2)  # [batch_size, seq_len, embed_dim]
        return self.output_layer(hidden)  # [batch_size, seq_len, vocab_size]

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            ``(xh, mu, logvar)`` where ``xh`` holds the reconstruction logits
            ``[batch_size, seq_len, vocab_size]`` and ``mu`` / ``logvar`` are
            the posterior parameters ``[batch_size, latent_dim]``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        xh = self.decode(z)
        return xh, mu, logvar


def vae_loss(xh, x, mu, logvar, beta=1.0):
    """Compute the beta-weighted VAE objective for token sequences.

    Args:
        xh: Reconstruction logits of shape ``(batch, seq, vocab)``.
        x: Target token ids of shape ``(batch, seq)``.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.
        beta: Weight applied to the KL term.

    Returns:
        ``(total, recon_loss, kl_loss)`` where ``recon_loss`` is the mean
        cross-entropy over all tokens, ``kl_loss`` is the KL divergence to a
        standard normal summed over every element of ``mu`` / ``logvar`` and
        divided by ``batch * seq`` (so both terms are per-token), and
        ``total = recon_loss + beta * kl_loss``.
    """
    # reshape (not view) so permuted / transposed inputs are accepted too.
    flat_logits = xh.reshape(-1, xh.size(-1))
    flat_targets = x.reshape(-1)
    recon_loss = F.cross_entropy(flat_logits, flat_targets, reduction='mean')

    kl_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    num_tokens = x.size(0) * x.size(1)
    kl_loss = kl_sum / num_tokens
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


# Smoke test: build a small model, push one random batch through it, and
# evaluate the loss once.  No optimisation step is taken in this snippet.

# Model / batch hyperparameters.
seq_len = 32
embed_dim = 128
latent_dim = 32
num_heads = 2
num_layers = 2
vocab_size = 128
batch_size = 4

vae = TransformerVAE(seq_len, embed_dim, latent_dim, num_heads, num_layers, vocab_size)
# NOTE(review): the optimizer is created but never stepped here, and no
# learning rate is passed explicitly -- confirm the installed PyTorch
# version provides a default for SGD.
optimizer = torch.optim.SGD(vae.parameters())

# Random token ids in [0, vocab_size) with shape (batch_size, seq_len).
x = torch.randint(0, vocab_size, (batch_size, seq_len))
vae.train()
xh, mu, logvar = vae(x)

# Bare expression: evaluates to the (total, reconstruction, KL) loss tuple;
# the result is not stored.
vae_loss(xh, x, mu, logvar)

(tensor(5.0003, grad_fn=<AddBackward0>),
 tensor(4.9205, grad_fn=<NllLossBackward0>),
 tensor(0.0798, grad_fn=<DivBackward0>))