# Transformer Text Generation (PyTorch)

Implementasi Transformer untuk text generation. Notebook ini dirancang agar metrik evaluasi konsisten dengan implementasi Rust (metrics.rs): BLEU dengan brevity penalty dan smoothing epsilon, diversity n-gram, serta repetition rate berbasis trigram.

In [1]:
import os, math, random
from collections import Counter
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(42)
random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cpu


## Tokenizer

In [2]:
class Tokenizer:
    def __init__(self, max_vocab_size=10000):
        self.max_vocab_size = max_vocab_size
        self.word2idx = {}
        self.idx2word = {}
        # special tokens
        self.PAD_TOKEN = '<PAD>'
        self.UNK_TOKEN = '<UNK>'
        self.SOS_TOKEN = '<SOS>'
        self.EOS_TOKEN = '<EOS>'
        self.word2idx[self.PAD_TOKEN] = 0
        self.word2idx[self.UNK_TOKEN] = 1
        self.word2idx[self.SOS_TOKEN] = 2
        self.word2idx[self.EOS_TOKEN] = 3
        for idx, tok in enumerate([self.PAD_TOKEN, self.UNK_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN]):
            self.idx2word[idx] = tok

    def fit(self, texts: List[str]):
        counts = Counter()
        for t in texts:
            counts.update(t.lower().split())
        most = counts.most_common(self.max_vocab_size - 4)
        for w, _ in most:
            if w not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[w] = idx
                self.idx2word[idx] = w
        print('Vocab size:', len(self.word2idx))

    def encode(self, text: str, max_len: int = None, add_sos=False, add_eos=False) -> List[int]:
        toks = text.lower().split()
        ids = []
        if add_sos:
            ids.append(self.word2idx[self.SOS_TOKEN])
        ids.extend(self.word2idx.get(w, self.word2idx[self.UNK_TOKEN]) for w in toks)
        if add_eos:
            ids.append(self.word2idx[self.EOS_TOKEN])
        if max_len is not None:
            if len(ids) > max_len:
                ids = ids[:max_len]
            else:
                ids += [self.word2idx[self.PAD_TOKEN]] * (max_len - len(ids))
        return ids

    def decode(self, ids: List[int]) -> str:
        out = []
        for i in ids:
            w = self.idx2word.get(int(i), self.UNK_TOKEN)
            if w in (self.PAD_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN):
                continue
            out.append(w)
        return ' '.join(out)

    @property
    def vocab_size(self) -> int:
        return len(self.word2idx)

    @property
    def pad_token_id(self) -> int:
        return self.word2idx[self.PAD_TOKEN]

    @property
    def eos_token_id(self) -> int:
        return self.word2idx[self.EOS_TOKEN]

## Dataset dan Loader

In [3]:
def load_lines(path: str, min_words=3) -> List[str]:
    if not os.path.exists(path):
        return []
    out = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            t = line.strip().strip('"').lower()
            if len(t.split()) >= min_words:
                out.append(t)
    return out

class LMDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer: Tokenizer, max_seq_len: int = 128):
        self.texts = texts
        self.tok = tokenizer
        self.max_seq_len = max_seq_len
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        t = self.texts[idx]
        ids = self.tok.encode(t, max_len=self.max_seq_len, add_sos=True, add_eos=True)
        x = torch.tensor(ids[:-1], dtype=torch.long)
        y = torch.tensor(ids[1:],  dtype=torch.long)
        return x, y

## Model: Transformer LM

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)
    def forward(self, x, mask=None):
        B, S, D = x.size()
        Q = self.q(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        K = self.k(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        V = self.v(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = attn @ V
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        return self.out(out)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff   = FeedForward(d_model, d_ff, dropout)
        self.do1 = nn.Dropout(dropout)
        self.do2 = nn.Dropout(dropout)
    def forward(self, x, mask=None):
        h = self.attn(self.ln1(x), mask)
        x = x + self.do1(h)
        h = self.ff(self.ln2(x))
        x = x + self.do2(h)
        return x

class TransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=4, d_ff=1024, max_seq_len=128, dropout=0.1):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
    def causal_mask(self, S, device):
        return torch.tril(torch.ones(S, S, device=device)).unsqueeze(0).unsqueeze(0)
    def forward(self, x):
        B, S = x.shape
        device = x.device
        pos = torch.arange(0, S, device=device).unsqueeze(0).expand(B, S)
        h = self.dropout(self.tok_emb(x) + self.pos_emb(pos))
        mask = self.causal_mask(S, device)
        for blk in self.blocks:
            h = blk(h, mask)
        h = self.ln_f(h)
        return self.head(h)  # [B, S, V]


## Metrics (selaras metrics.rs)

In [5]:
def perplexity(loss: float) -> float:
    try:
        return math.exp(loss)
    except OverflowError:
        return float('inf')

def _ngram_counts(tokens: List[str], n: int) -> Counter:
    c = Counter()
    L = len(tokens)
    if L < n:
        return c
    for i in range(L - n + 1):
        c[' '.join(tokens[i:i+n])] += 1
    return c

def bleu_bp_smooth(reference: str, candidate: str, max_n: int = 4, eps: float = 1e-9):
    ref = reference.split()
    cand = candidate.split()
    if not ref or not cand:
        return 0.0, [0.0]*max_n, 0.0
    precisions = []
    for n in range(1, max_n+1):
        ref_c  = _ngram_counts(ref, n)
        cand_c = _ngram_counts(cand, n)
        cand_total = sum(cand_c.values())
        if cand_total == 0:
            precisions.append(0.0)
            continue
        matched = sum(min(cnt, ref_c.get(ng, 0)) for ng, cnt in cand_c.items())
        p_n = (matched + eps) / (cand_total + eps)
        precisions.append(p_n)
    c, r = float(len(cand)), float(len(ref))
    bp = math.exp(1.0 - r/c) if c < r else 1.0
    if any(p <= 0.0 for p in precisions):
        geo = 0.0
    else:
        geo = math.exp(sum(math.log(p) for p in precisions) / max_n)
    return bp * geo, precisions, bp

def bleu_n_with_bp(reference: str, candidate: str, n: int = 2, eps: float = 1e-9) -> float:
    bleu, _, _ = bleu_bp_smooth(reference, candidate, max_n=n, eps=eps)
    return bleu

def bleu_score(reference: str, candidate: str, n: int = 1) -> float:
    ref = reference.split()
    cand = candidate.split()
    if len(cand) == 0 or len(ref) < n:
        return 0.0
    ref_c  = _ngram_counts(ref, n)
    cand_c = _ngram_counts(cand, n)
    matched = sum(min(cnt, ref_c.get(ng, 0)) for ng, cnt in cand_c.items())
    total   = sum(cand_c.values())
    return matched / max(1, total)

def diversity_score(text: str, n: int = 2) -> float:
    toks = text.split()
    L = len(toks)
    if L < n:
        return 0.0
    total = L - n + 1
    seen = set(' '.join(toks[i:i+n]) for i in range(total))
    return len(seen) / float(total)

def repetition_score(text: str) -> float:
    toks = text.split()
    L = len(toks)
    if L < 3:
        return 0.0
    cnt = Counter(' '.join(toks[i:i+3]) for i in range(L-2))
    repeated = sum(v for v in cnt.values() if v > 1)
    return repeated / float(L - 2)

class GenerationMetrics:
    def __init__(self, perplexity, bleu_1, bleu_2, diversity_2, diversity_3, repetition):
        self.perplexity = perplexity
        self.bleu_1 = bleu_1
        self.bleu_2 = bleu_2
        self.diversity_2 = diversity_2
        self.diversity_3 = diversity_3
        self.repetition = repetition
    @classmethod
    def calculate(cls, loss: float, reference: str | None, candidate: str):
        ppl = perplexity(loss)
        if reference is not None:
            b1 = bleu_n_with_bp(reference, candidate, n=1)
            b2 = bleu_n_with_bp(reference, candidate, n=2)
        else:
            b1 = 0.0
            b2 = 0.0
        d2 = diversity_score(candidate, 2)
        d3 = diversity_score(candidate, 3)
        rep = repetition_score(candidate)
        return cls(ppl, b1, b2, d2, d3, rep)

def self_bleu(texts: List[str], n: int = 2) -> float:
    if len(texts) < 2:
        return 0.0
    s, c = 0.0, 0
    for i in range(len(texts)):
        for j in range(len(texts)):
            if i == j:
                continue
            s += bleu_n_with_bp(texts[i], texts[j], n=n)
            c += 1
    return s / c if c > 0 else 0.0

## Top-k filtering dan generator wrapper

In [6]:
def top_k_filtering(logits, top_k=50):
    if top_k is None or top_k <= 0:
        return logits
    k = min(top_k, logits.size(-1))
    topk = torch.topk(logits, k=k, dim=-1)
    thresh = topk.values[..., -1].unsqueeze(-1)
    return torch.where(logits < thresh, torch.tensor(float('-inf'), device=logits.device), logits)

@torch.no_grad()
def generate_from_prompt(model: TransformerLM, tok: Tokenizer, prompt: str, max_new=50, temperature=0.9, top_k=50, seq_len=128, eos_id=None):
    start_ids = tok.encode(prompt, max_len=64, add_sos=True, add_eos=False)
    tokens = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
    model.eval()
    for _ in range(max_new):
        if tokens.size(1) >= seq_len:
            break
        logits = model(tokens)
        last = logits[:, -1, :] / max(1e-8, temperature)
        last = top_k_filtering(last, top_k=top_k)
        probs = F.softmax(last, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_id], dim=1)
        if eos_id is not None and int(next_id.item()) == eos_id:
            break
    return tok.decode(tokens[0].tolist())

## Data preparation dan hyperparameter

In [7]:
CORPUS = 'TextGen1.txt'
texts = load_lines(CORPUS, min_words=3)
if len(texts) == 0:
    demo = [
        'so shaken as we are so wan with care',
        'find we a time for frighted peace to pant',
        'and breathe short winded accents of new broils',
        'to be commenced in strands afar remote',
        'no more the thirsty entrance of this soil',
        'shall daub her lips with her own childrens blood'
    ]
    with open('demo_text.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(demo))
    texts = demo
print('Total lines:', len(texts))

random.shuffle(texts)
split = max(1, int(0.9 * len(texts)))
train_texts = texts[:split]
val_texts   = texts[split:]
print(f'Train: {len(train_texts)} | Val: {len(val_texts)}')

tok = Tokenizer(max_vocab_size=10000)
tok.fit(train_texts + val_texts)
PAD = tok.pad_token_id
EOS = tok.eos_token_id

D_MODEL, N_HEADS, N_LAYERS = 256, 8, 4
D_FF, DROPOUT = 1024, 0.1
SEQ_LEN, BATCH, EPOCHS, LR = 128, 32, 1, 1e-4

train_ds = LMDataset(train_texts, tok, SEQ_LEN)
val_ds   = LMDataset(val_texts, tok, SEQ_LEN)
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, drop_last=True)
val_dl   = DataLoader(val_ds, batch_size=BATCH, shuffle=False, drop_last=False)

model = TransformerLM(tok.vocab_size, D_MODEL, N_HEADS, N_LAYERS, D_FF, SEQ_LEN, DROPOUT).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
print(f'Total parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M')

Total lines: 10
Train: 9 | Val: 1
Vocab size: 69
Total parameters: 3.23M


## Training loop

In [8]:
def train_epoch(model, loader, opt, device, pad_id):
    model.train()
    total = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=pad_id)
        opt.zero_grad(); loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        total += float(loss.item())
    return total / max(1, len(loader))

@torch.no_grad()
def evaluate(model, loader, device, pad_id):
    model.eval()
    total = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=pad_id)
        total += float(loss.item())
    return total / max(1, len(loader))

best_val = float('inf')
for e in range(1, EPOCHS+1):
    tr = train_epoch(model, train_dl, opt, device, PAD)
    va = evaluate(model, val_dl, device, PAD)
    print(f'Epoch {e} | train {tr:.4f} (ppl {perplexity(tr):.2f}) | val {va:.4f} (ppl {perplexity(va):.2f})')
    if va < best_val:
        best_val = va
        os.makedirs('artifacts', exist_ok=True)
        torch.save({
            'model': model.state_dict(),
            'vocab': tok.word2idx,
            'cfg': {'d_model': D_MODEL, 'n_heads': N_HEADS, 'n_layers': N_LAYERS, 'd_ff': D_FF, 'dropout': DROPOUT, 'seq_len': SEQ_LEN}
        }, 'artifacts/best_transformer.pt')
        print('Saved -> artifacts/best_transformer.pt')

Epoch 1 | train 0.0000 (ppl 1.00) | val 4.3206 (ppl 75.24)
Saved -> artifacts/best_transformer.pt


## Evaluasi dan contoh generasi

In [None]:
import random

T_LIST = [0.7, 1.0, 1.2]
USE_TOP_K = True     # set ke False kalau mau lihat efek suhu lebih kentara
TOP_K_VAL = 50       # diabaikan kalau USE_TOP_K=False

print("\nTesting text generation...")
print("=" * 60)

# ============= 1. GENERATION WITH DIFFERENT TEMPERATURES =============
print("\n 1. Generation Examples (Different Temperatures):")
print("-" * 60)

prompts = ["once upon a time", "the quick brown", "in the beginning"]
for i, p in enumerate(prompts, 1):
    print(f'\n{i}. Prompt: "{p}"')
    for T in T_LIST:
        out = generate_from_prompt(
            model, tok, p,
            max_new=50,  # 50 seperti di Rust
            temperature=T,
            top_k=(TOP_K_VAL if USE_TOP_K else None),
            seq_len=SEQ_LEN,
            eos_id=EOS
        )
        print(f"   [T={T:.1f}] {out}")

# ============= 2. VALIDATION SAMPLES COMPARISON =============
print("\n\n 2. Validation Samples Comparison:")
print("-" * 60)

if len(val_texts) == 0:
    print("  (no validation samples)")
else:
    k = min(3, len(val_texts))
    samples = random.sample(val_texts, k=k)

    for idx, ref in enumerate(samples, 1):
        print(f"\n{idx}. Original:")
        preview = ref[:100] + ("..." if len(ref) > 100 else "")
        print(f"   {preview}")

        # prompt = 3 kata pertama dari reference
        pr = " ".join(ref.split()[:3])

        gen = generate_from_prompt(
            model, tok, pr,
            max_new=60,
            temperature=0.8,
            top_k=(TOP_K_VAL if USE_TOP_K else None),
            seq_len=SEQ_LEN,
            eos_id=EOS
        )
        print("   Generated:")
        print(f"   {gen}")

        # hitung metrik
        m = GenerationMetrics.calculate(loss=0.0, reference=ref, candidate=gen)
        print("   Metrics:")
        print(f"     BLEU-1: {m.bleu_1:.3f}")
        print(f"     BLEU-2: {m.bleu_2:.3f}")
        print(f"     Diversity-2: {m.diversity_2:.3f}")
        print(f"     Repetition: {m.repetition:.3f}")

# ============= 3. MODEL DIVERSITY TEST (Self-BLEU) =============
print("\n\n 3. Model Diversity Test (Self-BLEU):")
print("=" * 60)
print("Testing how diverse the model's outputs are when given the same prompt.")
print("Lower Self-BLEU = More diverse outputs (better creativity)")
print()

diversity_prompts = ["once upon a time", "the quick brown", "in the beginning"]
num_generations = 8  # Generate 8 teks per prompt

for idx, prompt in enumerate(diversity_prompts, 1):
    print(f"\n{idx}.{'─' * 55} Prompt: \"{prompt}\"")
    
    generated_texts = []
    
    # Generate beberapa kali dengan prompt yang SAMA
    for i in range(num_generations):
        gen = generate_from_prompt(
            model, tok, prompt,
            max_new=60,
            temperature=0.8,  # temperature tetap
            top_k=(TOP_K_VAL if USE_TOP_K else None),
            seq_len=SEQ_LEN,
            eos_id=EOS
        )
        
        # Print preview
        preview = gen[:70] + "..." if len(gen) > 70 else gen
        print(f"   {i + 1}. {preview}")
        
        generated_texts.append(gen)
    
    # Hitung Self-BLEU untuk prompt ini
    self_bleu_2 = self_bleu(generated_texts, 2)
    self_bleu_3 = self_bleu(generated_texts, 3)
    
    print(f"\n   Diversity Metrics for this prompt:")
    print(f"     Self-BLEU-2: {self_bleu_2:.3f} (bigram overlap)")
    print(f"     Self-BLEU-3: {self_bleu_3:.3f} (trigram overlap)")
    
    # Interpretasi
    if self_bleu_2 < 0.3:
        interpretation = "Excellent diversity!"
    elif self_bleu_2 < 0.5:
        interpretation = "Good diversity"
    elif self_bleu_2 < 0.7:
        interpretation = "Moderate diversity"
    else:
        interpretation = "Low diversity (repetitive outputs)"
    
    print(f"     Interpretation: {interpretation}")

# ============= 4. TEMPERATURE IMPACT ON DIVERSITY =============
print("\n\n 4. Temperature Impact on Diversity:")
print("-" * 60)

test_prompt = "once upon a time"
temperatures = [0.5, 0.8, 1.0, 1.2]

for temp in temperatures:
    print(f"\n   Temperature: {temp:.1f}")
    generated_texts = []
    
    for i in range(6):
        gen = generate_from_prompt(
            model, tok, test_prompt,
            max_new=60,
            temperature=temp,
            top_k=(TOP_K_VAL if USE_TOP_K else None),
            seq_len=SEQ_LEN,
            eos_id=EOS
        )
        
        preview = gen[:60] + "..." if len(gen) > 60 else gen
        print(f"     {i + 1}. {preview}")
        
        generated_texts.append(gen)
    
    sb = self_bleu(generated_texts, 2)
    print(f"     → Self-BLEU-2: {sb:.3f}")

# ============= 5. SUMMARY =============
print("\n\n" + "=" * 60)
print(" SUMMARY")
print("=" * 60)
print("✓ Generation test completed")
print("✓ Lower Self-BLEU indicates more diverse and creative outputs")
print("✓ Higher temperature generally increases diversity")
print("✓ Optimal temperature balances diversity and coherence (0.7-1.0)")
print("=" * 60)


Testing text generation...

 Generation Examples:

1. Prompt: "once upon a time"
   [T=0.7] <UNK> <UNK> a <UNK> a his bold going mirror, his hath hath whilst-- timon's till friends pike,
   [T=1.0] <UNK> <UNK> a <UNK> knife, engine, engine, noble give your wealth, pike, going them.
   [T=1.2] <UNK> <UNK> a <UNK> ionia, lydia strong noble give in engine, myself in bold timon's a myself hath which which him or death sword, aspiring timon's my his should them. friends need do you. you. death hand my banqueting aspiring i them. seem'd lord ionia, mutinous whilst-- death going a lord sword, sword, aspiring ionia, i need in <UNK> which need shall myself

2. Prompt: "the quick brown"
   [T=0.7] <UNK> <UNK> <UNK> his whilst-- seem'd should friends do engine, a aspiring friends which mutinous engine, his which myself profess strength them. in wealth, <UNK> not rider rider them. you. deliver like for
   [T=1.0] <UNK> <UNK> <UNK> his i hath knife, strong do hath lydia engine, myself a a
   [T=1.