In [112]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerVAE(nn.Module):
    """Sequence VAE built from two `nn.TransformerEncoder` stacks.

    Token ids are embedded, run through the encoder stack, flattened over the
    sequence axis and projected to the mean and log-variance of a diagonal
    Gaussian. A latent sample is projected back to a `[seq_len, embed_dim]`
    grid, run through a causally masked stack and mapped to per-position
    vocabulary logits.
    """

    def __init__(self, seq_len, embed_dim, latent_dim, num_heads, num_layers, vocab_size):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        flat_dim = embed_dim * seq_len  # size of one flattened encoded sequence

        # --- encoder side ---
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads),
            num_layers=num_layers,
        )

        # --- latent bottleneck ---
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)
        self.fc_latent_to_hidden = nn.Linear(latent_dim, flat_dim)

        # --- decoder side ---
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads),
            num_layers=num_layers,
        )
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def _causal_mask(self, device):
        """Boolean `[seq_len, seq_len]` mask, True strictly above the diagonal."""
        upper = torch.triu(torch.ones(self.seq_len, self.seq_len), diagonal=1)
        return upper.bool().to(device)

    def encode(self, x):
        """Map token ids `[batch, seq_len]` to `(mu, logvar)`, each `[batch, latent_dim]`."""
        tokens = self.embedding(x)                    # [batch, seq_len, embed_dim]
        encoded = self.encoder(tokens.transpose(0, 1))  # [seq_len, batch, embed_dim]
        # Back to batch-major, then collapse sequence and feature axes together.
        flat = encoded.transpose(0, 1).reshape(encoded.size(1), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Draw `mu + eps * std` with `eps ~ N(0, I)` and `std = exp(logvar / 2)`."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latents `[batch, latent_dim]` to logits `[batch, seq_len, vocab_size]`."""
        hidden = self.fc_latent_to_hidden(z)          # [batch, seq_len * embed_dim]
        hidden = hidden.view(hidden.size(0), self.seq_len, self.embed_dim)
        hidden = hidden.transpose(0, 1)               # [seq_len, batch, embed_dim]

        hidden = self.decoder(hidden, mask=self._causal_mask(hidden.device), is_causal=True)
        return self.output_layer(hidden.transpose(0, 1))

    def forward(self, x):
        """Return `(logits, mu, logvar)` for a batch of token ids."""
        mu, logvar = self.encode(x)
        xh = self.decode(self.reparameterize(mu, logvar))
        return xh, mu, logvar


def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Compute the beta-VAE objective for a batch of token sequences.

    Args:
        recon_x: Logits whose last axis is the vocabulary; flattened to
            `[-1, vocab]` before the cross-entropy.
        x: Integer target ids; the first two axes are treated as batch and
            sequence length when normalising the KL term.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as `mu`.
        beta: Weight applied to the KL term.

    Returns:
        Tuple `(total, recon_loss, kl_loss)` where `total = recon_loss + beta * kl_loss`.
        `recon_loss` is the mean cross-entropy per token and `kl_loss` is the summed
        KL divergence divided by `x.size(0) * x.size(1)`.
    """
    # `reshape` rather than `view`: the inputs may be non-contiguous (for example
    # logits produced by a transpose), which `view` rejects with a RuntimeError.
    flat_logits = recon_x.reshape(-1, recon_x.size(-1))
    flat_targets = x.reshape(-1)
    recon_loss = F.cross_entropy(flat_logits, flat_targets, reduction='mean')

    kl_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Scale by the token count so the KL term is per token, like the reconstruction term.
    kl_loss = kl_sum / (x.size(0) * x.size(1))
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


# Small configuration used to instantiate the VAE on random token ids.
seq_len = 32
embed_dim = 128
latent_dim = 32
num_heads = 2
num_layers = 2
vocab_size = 128
batch_size = 4

vae = TransformerVAE(seq_len, embed_dim, latent_dim, num_heads, num_layers, vocab_size)
# Learning rate is spelled out: older PyTorch releases have no default `lr` for SGD.
optimizer = torch.optim.SGD(vae.parameters(), lr=1e-3)

x = torch.randint(0, vocab_size, (batch_size, seq_len))
# The trailing ';' keeps the notebook from echoing the module repr. It must sit on
# the statement itself, because a line containing only ';' is a SyntaxError.
vae.train();

''

In [None]:
# Load the whole corpus into memory as a single string.
with open("tiny.txt", 'r', encoding='utf-8') as f:
    data = f.read()

# Sort the distinct characters so that the id assigned to each one is the same on
# every run. Iterating a bare set of strings follows hash order, which can change
# between interpreter sessions and would silently remap the vocabulary.
chars = sorted(set(data))
c2i = {c: i for i, c in enumerate(chars)}   # character -> integer id
i2c = {i: c for c, i in c2i.items()}        # integer id -> character

class TextDataset(Dataset):
    """Character-level next-token dataset over a 90/10 train/val split of a text file.

    Item `idx` is the pair `(x, y)` where `x` is the window of `seq_len` ids starting
    at `idx` and `y` is the same window shifted one position to the right.

    Args:
        file_path: Path of the UTF-8 text file to read.
        seq_len: Number of ids in each window.
        c2i: Mapping from character to integer id; a character missing from it
            raises `KeyError`.
        split: `'train'` (first 90% of the characters) or `'val'` (the remainder).

    Raises:
        ValueError: If `split` is neither `'train'` nor `'val'`.
    """

    def __init__(self, file_path, seq_len, c2i, split='train'):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = f.read()

        n = len(data)
        split_idx = int(n * 0.9)
        if split == 'train':
            text = data[:split_idx]
        elif split == 'val':
            text = data[split_idx:]
        else:
            raise ValueError("split must be 'train' or 'val'")

        # Explicit dtype: an empty split would otherwise become a float tensor.
        self.data = torch.tensor([c2i[c] for c in text], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        # Clamp at zero: a split shorter than `seq_len` has no complete window, and
        # a negative value would make `len()` raise.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        size = len(self)
        if idx < 0:
            idx += size  # allow negative indices, as with ordinary sequences
        if not 0 <= idx < size:
            # Slicing never fails on its own, so without this check plain iteration
            # over the dataset would never terminate.
            raise IndexError("TextDataset index out of range")
        x = self.data[idx:idx + self.seq_len]
        y = self.data[idx + 1:idx + 1 + self.seq_len]
        return x, y


# Training split cut into 128-character windows, served two windows at a time.
train_ds = TextDataset("./tiny.txt", split="train", c2i=c2i, seq_len=128)
train_dl = DataLoader(train_ds, batch_size=2)

# Decode the first window of the first batch back to text as a sanity check.
x, y = next(iter(train_dl))
print("".join(i2c[token_id] for token_id in x[0, :128].tolist()))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to 


In [107]:
class CharTransfomer(nn.Module):
    """Character-level autoregressive language model on an `nn.TransformerEncoder` stack.

    Args:
        vocab_size: Number of distinct token ids; also the size of the output logits.
        embed_dim: Width of the token embeddings and of the transformer layers.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked layers.

    NOTE(review): no positional encoding is added to the token embeddings.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers):
        super(CharTransfomer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Still assigned as an attribute so the set of registered parameters is unchanged.
        self.decoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits `(batch, seq, vocab_size)` for ids `(batch, seq)`."""
        x = self.embedding(x)   # (batch, seq, emb)
        x = x.permute(1, 0, 2)  # (seq, batch, emb)

        # Causal mask: position t may attend only to positions <= t. Without it each
        # position could read the very token it is trained to predict, because the
        # target sequence is the input shifted by one.
        seq_len = x.size(0)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()

        x = self.decoder(x, mask=causal_mask)  # (seq, batch, emb)
        x = x.permute(1, 0, 2)                 # back to (batch, seq, emb)
        x = self.output_layer(x)               # (batch, seq, vocab_size)
        return x


# Hyper-parameters for a first, small character model.
vocab_size, embed_dim, num_heads, num_layers = len(chars), 128, 2, 4

# Build the model and push one batch through it to confirm the logits shape.
model = CharTransfomer(vocab_size, embed_dim, num_heads, num_layers)
x, y = next(iter(train_dl))
model(x).shape



torch.Size([2, 128, 65])

In [110]:
@torch.no_grad()
def sample_text(model, i2c, device, max_len=100, start_char_idx=None, temperature=1.0):
    """Generate text by sampling one character at a time from `model`.

    Args:
        model: Module called with an id tensor of shape `(1, t)`; its output is
            indexed as `(1, t, vocab)` logits.
        i2c: Mapping from integer id to character; its length is used as the vocabulary size.
        device: Device on which the id tensors are created.
        max_len: Total number of characters to return, including the start character.
        start_char_idx: Id of the first character; drawn uniformly at random when None.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        The generated string.

    The model runs in eval mode while sampling. Its previous train/eval mode is
    restored on exit, so calling this in the middle of a training loop does not
    leave dropout switched off for the rest of the epoch.
    """
    was_training = model.training
    model.eval()
    try:
        vocab_size = len(i2c)
        if start_char_idx is None:
            start_char_idx = torch.randint(0, vocab_size, (1,)).item()
        input_seq = torch.tensor([[start_char_idx]], device=device)
        generated_text = [i2c[start_char_idx]]

        for _ in range(max_len - 1):
            outputs = model(input_seq)
            # Only the distribution at the last position is needed.
            logits = outputs[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_char_idx = torch.multinomial(probs, num_samples=1).item()
            generated_text.append(i2c[next_char_idx])
            # Feed the whole history back in on the next step.
            next_token = torch.tensor([[next_char_idx]], device=device)
            input_seq = torch.cat([input_seq, next_token], dim=1)
    finally:
        model.train(was_training)

    return ''.join(generated_text)


# Sample from the model on the CPU using the default length and temperature.
sample_text(model, i2c, torch.device("cpu"))


' rxQpPyx?QYA$cOYHW,?sPonWw3\ncbHGwM;\neFciWVQtQ-I!LJo\ntfePVrnO -RvaeP$kbOuIVf,uWE-LcQLW?KR.fYRdlA ?f e'

In [119]:
train_ds = TextDataset("./tiny.txt", seq_len=128, c2i=c2i, split="train")
# Shuffle: neighbouring dataset indices are windows shifted by a single character,
# so an unshuffled batch would hold 32 almost identical sequences.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

device = torch.device("cpu")
vocab_size = len(chars)
embed_dim = 256
num_heads = 4
num_layers = 4
model = CharTransfomer(vocab_size, embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
num_epochs = 1          # number of passes over the training loader
sample_interval = 250   # print a generated sample every this many steps

# Iterate over `num_epochs` so the loop agrees with the "epoch i/num_epochs" label.
for epoch in range(num_epochs):
    model.train()
    tq = tqdm(train_dl, desc=f"epoch {epoch+1}/{num_epochs}", leave=True)

    for step, (inputs, targets) in enumerate(tq):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Flatten batch and sequence axes so every position counts as one sample.
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tq.set_postfix(loss=f"{loss.item():.4f}")

        if step % sample_interval == 0:
            tqdm.write(sample_text(model, i2c, device, max_len=100))
            # Sampling runs the model in eval mode; make sure the remaining steps
            # of the epoch train with dropout enabled.
            model.train()

epoch 1/1:   0%|          | 1/31367 [00:01<11:15:59,  1.29s/it, loss=4.3685]

fPrtes e, 
p T
 ,UrKl 
rA
ee eM;u  ee B

lrereS;t 
r  e;eSea Qh lvur reNeahe 
he pVt
 a
 t phI Vr Ly


epoch 1/1:   1%|          | 251/31367 [01:33<4:12:09,  2.06it/s, loss=2.6858]

xsrecdghihisenses an?
Wiler rkUSe y, EWhesousek s.s? hsen,std afinofn serawho. t, ofe iak s aese hal


epoch 1/1:   1%|          | 268/31367 [01:39<3:13:01,  2.69it/s, loss=2.9016]


KeyboardInterrupt: 