In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import argparse
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [17]:
class TextDataset(Dataset):
    """Character-level next-token dataset over one contiguous slice of a text file.

    The file is encoded with ``c2i`` and split at ``train_frac`` of its length:
    the first part is the ``"train"`` split, the remainder the ``"val"`` split.
    Item ``idx`` is a pair ``(x, y)`` of ``seq_len`` indices each, where ``y`` is
    ``x`` shifted one character to the right.
    """

    def __init__(self, file_path, seq_len, c2i, split="train", train_frac=0.9):
        """
        Args:
            file_path: path to a UTF-8 text file.
            seq_len: length of every returned input/target window.
            c2i: mapping from character to integer index; every character of
                the selected split must be a key (KeyError otherwise).
            split: ``"train"`` or ``"val"``.
            train_frac: fraction of the characters assigned to the train split.

        Raises:
            ValueError: if ``split`` is unknown or ``train_frac`` is outside [0, 1].
        """
        if split not in ("train", "val"):
            raise ValueError("split must be 'train' or 'val'")
        if not 0.0 <= train_frac <= 1.0:
            raise ValueError("train_frac must be in [0, 1]")

        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        split_idx = int(len(text) * train_frac)
        part = text[:split_idx] if split == "train" else text[split_idx:]
        # Explicit dtype: torch.tensor([]) would otherwise be float32 for an
        # empty split, which nn.Embedding cannot index with.
        self.data = torch.tensor([c2i[c] for c in part], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        # A window needs seq_len + 1 characters (inputs plus the shifted target);
        # clamp so a too-short split is empty instead of having negative length.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + 1 + self.seq_len]
        return x, y


class CharTransfomer(nn.Module):
    """Decoder-only character language model built from ``nn.TransformerEncoder``.

    Tokens are embedded, passed through ``num_layers`` self-attention layers
    under a causal mask (each position attends only to itself and earlier
    positions), and projected to per-position vocabulary logits.

    NOTE(review): no positional encoding is added to the embeddings; position
    information comes only from the causal mask. The class-name typo is kept so
    existing references keep working.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Kept as an attribute so the state_dict layout matches earlier checkpoints.
        self.decoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads
        )
        self.decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    @staticmethod
    def _causal_mask(size, device):
        """Boolean (size, size) attention mask; True marks a blocked position.

        PyTorch's boolean attention-mask convention is "True = may not attend",
        so a causal mask is True strictly above the diagonal (future positions).
        """
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def forward(self, x):
        """Map token indices ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``."""
        _, seq_len = x.shape
        causal_mask = self._causal_mask(seq_len, x.device)

        h = self.embedding(x)  # (batch, seq, emb)
        h = h.permute(1, 0, 2)  # (seq, batch, emb): layers are not batch_first
        h = self.decoder(h, mask=causal_mask, is_causal=True)  # (seq, batch, emb)
        h = h.permute(1, 0, 2)  # back to (batch, seq, emb)
        return self.output_layer(h)  # (batch, seq, vocab_size)


@torch.no_grad()
def sample_text(model, i2c, device, max_len=100, start_char_idx=None, temperature=1.0):
    """Autoregressively sample a string from a next-character model.

    Starting from one seed character, the model is run on the whole prefix
    generated so far; the logits at the last position are divided by
    ``temperature`` and the next index is drawn with ``torch.multinomial``.

    Args:
        model: module mapping a ``(1, seq)`` index tensor to logits of shape
            ``(1, seq, vocab)``. A tuple return such as ``(logits, mu, logvar)``
            is also accepted; its first element is used as the logits. The model
            must accept prefixes of any length (a model with a fixed-size input
            layer, like ``TransformerVAE``, will not work here).
        i2c: mapping from vocabulary index to character; ``len(i2c)`` is the
            range used when picking a random seed.
        device: device on which the prefix tensor is built.
        max_len: number of characters returned, seed included (for max_len >= 1).
        start_char_idx: index of the seed character; drawn uniformly if None.
        temperature: strictly positive softmax temperature.

    Returns:
        The generated string.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    was_training = model.training
    model.eval()
    try:
        vocab_size = len(i2c)
        if start_char_idx is None:
            start_char_idx = torch.randint(0, vocab_size, (1,)).item()
        input_seq = torch.tensor([[start_char_idx]], dtype=torch.long, device=device)
        generated_text = [i2c[start_char_idx]]

        for _ in range(max_len - 1):
            outputs = model(input_seq)
            # Some models return (logits, *extras); plain LMs return logits only.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            logits = outputs[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_char_idx = torch.multinomial(probs, num_samples=1).item()
            generated_text.append(i2c[next_char_idx])
            next_tok = torch.tensor([[next_char_idx]], dtype=torch.long, device=device)
            input_seq = torch.cat([input_seq, next_tok], dim=1)
    finally:
        # Restore the caller's mode so sampling mid-training does not silently
        # leave dropout disabled for subsequent optimizer steps.
        model.train(was_training)

    return "".join(generated_text)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerVAE(nn.Module):
    """Fixed-length sequence VAE with Transformer encoder and decoder stacks.

    Encoder: embed ``seq_len`` tokens, run a Transformer stack, project each
    position to ``2 * embed_dim`` features, split them into a mean half and a
    log-variance half, flatten each over the sequence and map it linearly to
    ``latent_dim``.

    Decoder: map a latent vector to ``seq_len * embed_dim`` features, reshape to
    ``seq_len`` positions, run a causally-masked Transformer stack and emit
    per-position vocabulary logits.

    Inputs must contain exactly ``seq_len`` tokens, because ``fc_mu`` and
    ``fc_logvar`` are sized for ``embed_dim * seq_len`` features.
    """

    def __init__(self, seq_len, embed_dim, latent_dim, num_heads, num_layers, vocab_size):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Per-position features: first half feeds mu, second half feeds logvar.
        self.enc_project = nn.Linear(embed_dim, 2 * embed_dim)

        # Flattened (seq_len * embed_dim) features <-> latent vector.
        self.fc_mu = nn.Linear(embed_dim * seq_len, latent_dim)
        self.fc_logvar = nn.Linear(embed_dim * seq_len, latent_dim)
        self.fc_latent_to_hidden = nn.Linear(latent_dim, embed_dim * seq_len)

        decoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)

        self.output_layer = nn.Linear(embed_dim, vocab_size)

    @staticmethod
    def _causal_mask(size, device):
        """Boolean (size, size) attention mask; True marks a blocked position.

        PyTorch's boolean attention-mask convention is "True = may not attend",
        so a causal mask is True strictly above the diagonal (future positions).
        """
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def encode(self, x):
        """Map token indices ``(batch, seq_len)`` to ``(mu, logvar)``, each ``(batch, latent_dim)``."""
        x = self.embedding(x)  # [batch_size, seq_len, embed_dim]
        batch_size, _, embed_dim = x.shape

        x = x.permute(1, 0, 2)  # [seq_len, batch_size, embed_dim]
        x = self.encoder(x)  # [seq_len, batch_size, embed_dim]

        x = self.enc_project(x)  # [seq_len, batch_size, 2 * embed_dim]
        x = x.permute(1, 0, 2)  # [batch_size, seq_len, 2 * embed_dim]
        x_mu, x_logvar = x.split(split_size=embed_dim, dim=2)  # each [batch_size, seq_len, embed_dim]
        x_mu = x_mu.reshape(batch_size, -1)  # [batch_size, seq_len * embed_dim]
        x_logvar = x_logvar.reshape(batch_size, -1)

        mu = self.fc_mu(x_mu)
        logvar = self.fc_logvar(x_logvar)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latents ``(batch, latent_dim)`` to logits ``(batch, seq_len, vocab_size)``."""
        h = self.fc_latent_to_hidden(z)  # [batch_size, seq_len * embed_dim]
        h = h.view(z.size(0), self.seq_len, self.embed_dim)  # [batch_size, seq_len, embed_dim]
        h = h.permute(1, 0, 2)  # [seq_len, batch_size, embed_dim]

        # Uses the model's own seq_len and the latent's device (previously this
        # read the notebook globals `seq_len` and `x` by accident).
        causal_mask = self._causal_mask(self.seq_len, h.device)
        h = self.decoder(h, mask=causal_mask, is_causal=True)  # [seq_len, batch_size, embed_dim]
        h = h.permute(1, 0, 2)  # [batch_size, seq_len, embed_dim]
        return self.output_layer(h)  # [batch_size, seq_len, vocab_size]

    def forward(self, x):
        """Encode, sample a latent with the reparameterization trick, decode.

        Returns ``(logits, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        xh = self.decode(z)
        return xh, mu, logvar

    def sample(self, num_samples, device):
        """Decode ``num_samples`` latents drawn from N(0, I); returns logits."""
        z = torch.randn(num_samples, self.latent_dim).to(device)
        return self.decode(z)


def vae_loss(xh, x, mu, logvar, beta=1.0):
    """Per-token VAE objective: reconstruction cross-entropy plus beta-weighted KL.

    Args:
        xh: logits, shape (batch, seq, vocab).
        x: target indices, shape (batch, seq).
        mu, logvar: parameters of the diagonal Gaussian posterior.
        beta: weight on the KL term.

    Returns:
        ``(total, recon_loss, kl_loss)`` as scalar tensors. ``recon_loss`` is the
        mean cross-entropy over all target tokens. ``kl_loss`` is
        KL(N(mu, exp(logvar)) || N(0, I)) summed over every element of ``mu`` and
        divided by the number of target tokens, so both terms are per token.
    """
    # reshape rather than view: also works when the logits are non-contiguous.
    recon_loss = F.cross_entropy(
        xh.reshape(-1, xh.size(-1)), x.reshape(-1), reduction="mean"
    )
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_loss = kl_loss / x.numel()
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


# Smoke test: push one random batch through a freshly built VAE and compute the
# loss, to check that all shapes line up end to end.
torch.manual_seed(0)

seq_len = 32
embed_dim = 128
latent_dim = 32
num_heads = 2
num_layers = 2
vocab_size = 128
batch_size = 4

vae = TransformerVAE(seq_len, embed_dim, latent_dim, num_heads, num_layers, vocab_size)
# (No optimizer here: nothing is trained in this cell, and torch.optim.SGD
# without an explicit lr raises on older PyTorch versions.)

x = torch.randint(0, vocab_size, (batch_size, seq_len))
vae.train()
xh, mu, logvar = vae(x)
assert xh.shape == (batch_size, seq_len, vocab_size)
assert mu.shape == logvar.shape == (batch_size, latent_dim)

vae_loss(xh, x, mu, logvar)



torch.Size([1, 32, 128])

In [44]:
@dataclass
class Args:
    """Run configuration.

    Fields are annotated so ``@dataclass`` turns them into real fields; any of
    them can be overridden by keyword, e.g. ``Args(device="cuda", batch_size=8)``.
    """

    device: str = "cpu"  # torch device name
    seq_len: int = 32  # characters per training window
    batch_size: int = 2
    embed_dim: int = 128  # Transformer model width
    latent_dim: int = 512  # size of the VAE latent vector
    num_heads: int = 2  # attention heads per layer
    num_layers: int = 2  # layers in each of the encoder / decoder stacks
    num_epochs: int = 1


args = Args()

# Build the character vocabulary, the training dataset/loader, the device and
# the VAE for this run.
corpus_path = "tiny.txt"

with open(corpus_path, "r", encoding="utf-8") as f:
    data = f.read()

# sorted(): the iteration order of a set of str depends on hash randomization,
# so an unsorted vocabulary would assign different indices after every kernel
# restart and make saved weights incompatible with a freshly built c2i/i2c.
chars = sorted(set(data))
c2i = {c: i for i, c in enumerate(chars)}
i2c = {i: c for c, i in c2i.items()}

train_ds = TextDataset(corpus_path, seq_len=args.seq_len, c2i=c2i, split="train")
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)

# Fall back to CPU when CUDA is requested but unavailable.
if args.device == "cuda" and not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = torch.device(args.device)
print("using device", device)

vocab_size = len(chars)
model = TransformerVAE(seq_len=args.seq_len,
                       embed_dim=args.embed_dim,
                       latent_dim=args.latent_dim,
                       num_heads=args.num_heads,
                       num_layers=args.num_layers,
                       vocab_size=vocab_size)


def sample(model):
    """Decode one sequence from a random latent and return it as a string.

    Uses the notebook-level ``device`` and ``i2c``. Each character is the argmax
    of the decoder logits at that position (greedy decoding). Sampling runs in
    eval mode without autograd; the model's previous train/eval mode is restored
    afterwards so this is safe to call inside a training loop.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model.sample(1, device)  # (1, seq, vocab)
    finally:
        model.train(was_training)
    tokens = logits.argmax(dim=2)  # (1, seq)
    return "".join(i2c[t.item()] for t in tokens.flatten())


# Sanity check: repeatedly train on one fixed batch to confirm the model can fit it.
model.to(device)
optimizer = optim.Adam(model.parameters())

overfit_steps = 1000
log_every = 100

x, y = next(iter(train_dl))
# The batch must live on the same device as the model.
x, y = x.to(device), y.to(device)
model.train()
for i in range(overfit_steps):
    xh, mu, logvar = model(x)
    loss, _, _ = vae_loss(xh, y, mu, logvar)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % log_every == 0:
        print(loss.item())
        print(sample(model))


using device cpu
5.706575870513916
    e   e ee eee eeee ee eeee  e




0.9964932799339294
etneir tri  ams the e eer mh h l
0.9007115364074707
tongue can name thee, let me hav
1.5066030025482178
tongeo uan nase tuernnret oe maa
1.3298052549362183

tngue t n   re ttee,hlet me hav
1.775547981262207

th hmmttebune hfth ever.h hestl
0.907386064529419
tongue can name thee, let me hav
0.9067481756210327

thei h rebunesrtor evea   his l
0.5479682087898254
tohgiertribufes tooeev et mhih l


KeyboardInterrupt: 

In [47]:
# Decode one random latent with the current weights and show the resulting text.
print(sample(model=model))

tnteh  utairae  ttein  ettme h e


In [48]:
sample_interval = 250  # print a generated sample every this many steps

# NOTE(review): in the logged run the samples turn into all blanks within a few
# hundred steps while the loss stays around 3; ramping up `beta` of vae_loss
# (KL warm-up) may be worth trying — not verified here.
for epoch in range(args.num_epochs):
    # Re-enter training mode each epoch in case an earlier cell left the model
    # in eval mode (e.g. after sampling), which would disable dropout.
    model.train()
    tq = tqdm(train_dl, desc=f"epoch {epoch+1}/{args.num_epochs}")

    for step, (inputs, targets) in enumerate(tq):
        inputs, targets = inputs.to(device), targets.to(device)

        xh, mu, logvar = model(inputs)
        loss, _, _ = vae_loss(xh, targets, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tq.set_postfix(loss=f"{loss.item():.4f}")

        if step % sample_interval == 0:
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(sample(model))

epoch 1/1:   0%|          | 6/501911 [00:00<4:50:54, 28.75it/s, loss=5.4719]

noneur tniburts thttevern,Thi el


epoch 1/1:   0%|          | 258/501911 [00:07<4:01:30, 34.62it/s, loss=3.2659]

                                


epoch 1/1:   0%|          | 506/501911 [00:14<4:02:59, 34.39it/s, loss=3.4416]

                                


epoch 1/1:   0%|          | 754/501911 [00:22<4:16:52, 32.52it/s, loss=3.0969]

                                


epoch 1/1:   0%|          | 759/501911 [00:22<4:06:05, 33.94it/s, loss=3.0628]


KeyboardInterrupt: 