https://www.youtube.com/watch?v=C9QSpl5nmrY

In [52]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as L
from torch.optim import Adam
import tiktoken

class PositionEncoding(nn.Module):
    def __init__(self, d_model=256, max_len=1024):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2).float() * (torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        # Supports inputs of shape (seq_len, d_model) or (batch, seq_len, d_model)
        if word_embeddings.dim() == 2:
            seq_len, d = word_embeddings.size()
            return word_embeddings + self.pe[:seq_len, :d]
        elif word_embeddings.dim() == 3:
            b, seq_len, d = word_embeddings.size()
            pe = self.pe[:seq_len, :d].unsqueeze(0)  # (1, seq_len, d)
            return word_embeddings + pe
        else:
            raise ValueError(f"Unsupported embedding tensor shape: {word_embeddings.shape}")

class Attention(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, q_emb, k_emb, v_emb, mask=None):
        q = self.W_q(q_emb)
        k = self.W_k(k_emb)
        v = self.W_v(v_emb)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v)

class DecoderOnlyTransformer(L.LightningModule):
    def __init__(self, vocab_size, d_model=256, max_len=1024):
        super().__init__()
        self.save_hyperparameters()
        self.we = nn.Embedding(vocab_size, d_model)
        self.pe = PositionEncoding(d_model, max_len)
        self.attn = Attention(d_model)
        self.fc = nn.Linear(d_model, vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) or (seq_len,)
        emb = self.we(token_ids)
        emb_pe = self.pe(emb)
        if emb_pe.dim() == 2:
            seq_len = emb_pe.size(0)
            mask = torch.tril(torch.ones(seq_len, seq_len, device=emb_pe.device)).logical_not()
        else:
            b, seq_len, _ = emb_pe.size()
            mask = torch.tril(torch.ones(seq_len, seq_len, device=emb_pe.device)).logical_not()
            mask = mask.unsqueeze(0).expand(b, seq_len, seq_len)
        attn_out = self.attn(emb_pe, emb_pe, emb_pe, mask)
        logits = self.fc(attn_out + emb_pe)
        return logits

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-4)

    def training_step(self, batch, batch_idx):
        input_ids, target_ids = batch  # both (batch, seq_len)
        logits = self.forward(input_ids)
        b, seq_len, vocab = logits.size()
        loss = self.loss_fn(logits.view(b * seq_len, vocab), target_ids.view(-1))
        print(f"Loss: {loss.item()}")
        self.log('train_loss', loss)
        return loss

    def generate(self, input_ids, max_new_tokens=50):
        self.eval()
        if input_ids.dim() == 1:
            generated = input_ids.unsqueeze(0).to(self.device)
        else:
            generated = input_ids.to(self.device)
        for _ in range(max_new_tokens):
            logits = self.forward(generated)  # (b, seq, vocab)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (b, 1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated

# Load JSON-formatted QA data and tokenize using GPT-2 BPE
def load_qa_data(file_path, encoder, seq_len):
    inputs, targets = [], []
    pad_id = encoder.eot_token_id if hasattr(encoder, 'eot_token_id') else encoder.encode('\n')[0]
    with open(file_path, 'r') as f:
        data = json.load(f)
    for item in data:
        context_ids = encoder.encode(item['context'])
        question_ids = encoder.encode(item['question'])
        answer_ids = encoder.encode(item['answers'])
        # questions_ids = context_ids + question_ids  # Ensure question ends with pad token
        tokens = context_ids + question_ids
        answer_tokens = answer_ids
        # pad or truncate to seq_len+1
        if len(tokens) < seq_len + 1:
            tokens += [pad_id] * (seq_len + 1 - len(tokens))
        if len(answer_tokens) < seq_len + 1:
            answer_tokens += [pad_id] * (seq_len + 1 - len(answer_tokens))
        tokens = tokens[:seq_len]
        answer_tokens = answer_tokens[:seq_len]
        inp = torch.tensor(tokens)
        tgt = torch.tensor(answer_tokens)
        inputs.append(inp)
        targets.append(tgt)
    return inputs, targets
  
  
from torch.utils.data import DataLoader, TensorDataset

# Hyperparameters
seq_len = 512
d_model = 256
max_len = seq_len + 50
batch_size = 4
# current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = 'datasets/resonning.json'

# GPT-2 BPE encoder
encoder = tiktoken.get_encoding("gpt2")
vocab_size = encoder.n_vocab

# Prepare data
inputs, targets = load_qa_data(data_file, encoder, seq_len)
dataset = TensorDataset(torch.stack(inputs), torch.stack(targets))
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model
model = DecoderOnlyTransformer(vocab_size=vocab_size, d_model=d_model, max_len=max_len)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

device: cuda


In [53]:
# Trainer
epochs = 60
trainer = L.Trainer(max_epochs=epochs, logger=True)
trainer.fit(model, loader)
# Save the model
torch.save(obj=model.state_dict(), f="models/decoder_only_transformer.pth")

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | we      | Embedding        | 12.9 M | train
1 | pe      | PositionEncoding | 0      | train
2 | attn    | Attention        | 196 K  | train
3 | fc      | Linear           | 12.9 M | train
4 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params
103.914   Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

Loss: 10.842822074890137
Loss: 10.610391616821289
Loss: 10.375754356384277
Loss: 10.140053749084473
Loss: 9.904047966003418
Loss: 9.668194770812988
Loss: 9.432716369628906
Loss: 9.197669982910156
Loss: 8.963007926940918
Loss: 8.728620529174805
Loss: 8.494359016418457
Loss: 8.260065078735352
Loss: 8.0255765914917
Loss: 7.7907395362854
Loss: 7.555402755737305
Loss: 7.319425106048584
Loss: 7.082683086395264
Loss: 6.84505033493042
Loss: 6.606421947479248
Loss: 6.3666911125183105
Loss: 6.125766277313232
Loss: 5.8835625648498535
Loss: 5.6400146484375
Loss: 5.395069599151611
Loss: 5.14870023727417
Loss: 4.900914669036865
Loss: 4.651764869689941
Loss: 4.401371479034424
Loss: 4.149939060211182
Loss: 3.8977901935577393
Loss: 3.645408868789673
Loss: 3.39347243309021
Loss: 3.142918348312378
Loss: 2.894987106323242
Loss: 2.6512742042541504
Loss: 2.4137537479400635
Loss: 2.1847567558288574
Loss: 1.9668797254562378
Loss: 1.762803077697754
Loss: 1.575027346611023
Loss: 1.405551552772522
Loss: 1.255595

`Trainer.fit` stopped: `max_epochs=60` reached.


In [54]:
# load the model for inference
model = DecoderOnlyTransformer(vocab_size=vocab_size, d_model=d_model, max_len=max_len)
model.load_state_dict(torch.load("models/decoder_only_transformer.pth"))
model.to(device)

DecoderOnlyTransformer(
  (we): Embedding(50257, 256)
  (pe): PositionEncoding()
  (attn): Attention(
    (W_q): Linear(in_features=256, out_features=256, bias=False)
    (W_k): Linear(in_features=256, out_features=256, bias=False)
    (W_v): Linear(in_features=256, out_features=256, bias=False)
  )
  (fc): Linear(in_features=256, out_features=50257, bias=True)
  (loss_fn): CrossEntropyLoss()
)

In [62]:
# Example generation on first sample
sample_input = inputs[0]
out = encoder.decode(sample_input.tolist())
question = "What is the capital of France?"
context = "France is a country in Europe. Its capital is Paris, known for the Eiffel Tower."
prompt = f"{context}{question}"
idx = torch.tensor([encoder.encode(prompt)], device=device)
out_ids = model.generate(idx, max_new_tokens=20).squeeze(0)
text_out = encoder.decode(out_ids.tolist())
print(text_out)

France is a country in Europe. Its capital is Paris, known for the Eiffel Tower.What is the capital of France?




















