In [1]:
import random
import re
from pathlib import Path
from collections import Counter
from typing import List, Tuple, Dict, Any, Optional
import math
import json
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
import sacrebleu
from typing import Optional

DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

plt.rcParams["figure.figsize"] = (8,5)
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DATA_PATH = Path("spa.txt")
SAMPLE_N = 10_000

TRAIN_RATIO = 0.80
VAL_RATIO   = 0.10
TEST_RATIO  = 0.10

LOWERCASE = True
MIN_FREQ = 2
MAX_VOCAB_SRC = 20_000
MAX_VOCAB_TGT = 20_000

MAX_LEN_SRC = 60
MAX_LEN_TGT = 60

BATCH_SIZE = 64
NUM_WORKERS = 0

D_MODEL = 256
DROPOUT = 0.1
MAX_LEN = 256

BASE_DIR = Path("outputs")
TRANS_DIR = Path("outputs_transformer")

DEVICE: mps


In [2]:
def load_pairs(path: Path) -> List[Tuple[str, str]]:
    pairs: List[Tuple[str, str]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split("\t")
            if len(parts) < 2:
                continue
            en, es = parts[0], parts[1]
            pairs.append((en, es))
    return pairs

def split_pairs(
    pairs: List[Tuple[str, str]],
    sample_n: Optional[int] = None,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
):
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

    rng = random.Random(seed)
    pairs = pairs[:]
    rng.shuffle(pairs)

    if sample_n is not None:
        pairs = pairs[:sample_n]

    n = len(pairs)
    n_train = int(n * train_ratio)
    n_val   = int(n * val_ratio)

    train = pairs[:n_train]
    val   = pairs[n_train:n_train + n_val]
    test  = pairs[n_train + n_val:]

    return train, val, test

pairs = load_pairs(DATA_PATH)
train_pairs, val_pairs, test_pairs = split_pairs(
    pairs,
    sample_n=SAMPLE_N,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
    seed=SEED,
)

print("Total pairs in file:", len(pairs))
print("Using:", len(train_pairs) + len(val_pairs) + len(test_pairs))
print("Train:", len(train_pairs), "Val:", len(val_pairs), "Test:", len(test_pairs))


Total pairs in file: 118964
Using: 10000
Train: 8000 Val: 1000 Test: 1000


In [3]:
PAD = "<pad>"
UNK = "<unk>"
BOS = "<bos>"
EOS = "<eos>"
SPECIAL_TOKENS = [PAD, UNK, BOS, EOS]

def normalize(text: str, lowercase: bool = True) -> str:
    text = text.strip()
    if lowercase:
        text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text

def tokenize(text: str) -> List[str]:
    text = re.sub(r"([.,!?;:()\"'])", r" \1 ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.split() if text else []

print(tokenize(normalize("Hello, World!   ", True)))

['hello', ',', 'world', '!']


In [4]:
def build_vocab(
    pairs: List[Tuple[str, str]],
    side: str,
    min_freq: int = 2,
    max_vocab: Optional[int] = None,
    lowercase: bool = True,
) -> Dict[str, Any]:
    assert side in {"src", "tgt"}
    idx = 0 if side == "src" else 1

    counter = Counter()
    for en, es in pairs:
        txt = en if idx == 0 else es
        toks = tokenize(normalize(txt, lowercase))
        counter.update(toks)

    itos = list(SPECIAL_TOKENS)
    for tok, c in counter.most_common():
        if c < min_freq:
            break
        if tok in SPECIAL_TOKENS:
            continue
        if max_vocab is not None and len(itos) >= max_vocab:
            break
        itos.append(tok)

    stoi = {t: i for i, t in enumerate(itos)}
    return {"itos": itos, "stoi": stoi}

src_vocab = build_vocab(train_pairs, "src", min_freq=MIN_FREQ, max_vocab=MAX_VOCAB_SRC, lowercase=LOWERCASE)
tgt_vocab = build_vocab(train_pairs, "tgt", min_freq=MIN_FREQ, max_vocab=MAX_VOCAB_TGT, lowercase=LOWERCASE)

pad_id_src = src_vocab["stoi"][PAD]
pad_id_tgt = tgt_vocab["stoi"][PAD]
bos_id_tgt = tgt_vocab["stoi"][BOS]
eos_id_tgt = tgt_vocab["stoi"][EOS]

print("SRC vocab size:", len(src_vocab["itos"]))
print("TGT vocab size:", len(tgt_vocab["itos"]))
print("pad_id_src:", pad_id_src, "pad_id_tgt:", pad_id_tgt, "bos_id_tgt:", bos_id_tgt, "eos_id_tgt:", eos_id_tgt)

SRC vocab size: 2279
TGT vocab size: 3052
pad_id_src: 0 pad_id_tgt: 0 bos_id_tgt: 2 eos_id_tgt: 3


In [5]:
def encode_src(text: str, vocab: Dict[str, Any], max_len: int) -> List[int]:
    toks = tokenize(normalize(text, LOWERCASE))
    toks = toks[:max_len]
    ids = [vocab["stoi"].get(t, vocab["stoi"][UNK]) for t in toks]
    ids.append(vocab["stoi"][EOS])
    return ids

def encode_tgt(text: str, vocab: Dict[str, Any], max_len: int) -> List[int]:
    toks = tokenize(normalize(text, LOWERCASE))
    toks = toks[:max_len]
    toks = [BOS] + toks + [EOS]
    ids = [vocab["stoi"].get(t, vocab["stoi"][UNK]) for t in toks]
    return ids

class TranslationDataset(Dataset):
    def __init__(self, pairs: List[Tuple[str, str]]):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        en, es = self.pairs[idx]
        src_ids = encode_src(en, src_vocab, MAX_LEN_SRC)
        tgt_ids = encode_tgt(es, tgt_vocab, MAX_LEN_TGT)
        return src_ids, tgt_ids

def make_subsequent_mask(sz: int) -> torch.Tensor:
    return torch.triu(torch.ones((sz, sz), dtype=torch.bool), diagonal=1)

def collate_fn(batch):
    src_seqs, tgt_seqs = zip(*batch)
    B = len(batch)

    src_lens = [len(s) for s in src_seqs]
    tgt_lens = [len(t) for t in tgt_seqs]

    max_src = max(src_lens)
    max_tgt = max(tgt_lens)

    src = torch.full((B, max_src), pad_id_src, dtype=torch.long)
    tgt = torch.full((B, max_tgt), pad_id_tgt, dtype=torch.long)

    for i, (s, t) in enumerate(zip(src_seqs, tgt_seqs)):
        src[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        tgt[i, :len(t)] = torch.tensor(t, dtype=torch.long)

    tgt_in  = tgt[:, :-1].contiguous()
    tgt_out = tgt[:, 1:].contiguous()

    src_key_padding_mask = (src == pad_id_src)
    tgt_key_padding_mask = (tgt_in == pad_id_tgt)

    Tt = tgt_in.size(1)
    tgt_subsequent_mask = make_subsequent_mask(Tt)

    return {
        "src": src,
        "tgt_in": tgt_in,
        "tgt_out": tgt_out,
        "src_lens": torch.tensor(src_lens, dtype=torch.long),
        "tgt_lens": torch.tensor(tgt_lens, dtype=torch.long),
        "src_key_padding_mask": src_key_padding_mask,
        "tgt_key_padding_mask": tgt_key_padding_mask,
        "tgt_subsequent_mask": tgt_subsequent_mask,
    }

train_dl = DataLoader(TranslationDataset(train_pairs), batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, collate_fn=collate_fn, drop_last=True)
val_dl   = DataLoader(TranslationDataset(val_pairs), batch_size=BATCH_SIZE, shuffle=False,
                      num_workers=NUM_WORKERS, collate_fn=collate_fn)
test_dl  = DataLoader(TranslationDataset(test_pairs), batch_size=BATCH_SIZE, shuffle=False,
                      num_workers=NUM_WORKERS, collate_fn=collate_fn)

batch = next(iter(train_dl))
print("src:", batch["src"].shape)
print("tgt_in:", batch["tgt_in"].shape)
print("tgt_out:", batch["tgt_out"].shape)
print("src_key_padding_mask:", batch["src_key_padding_mask"].shape, batch["src_key_padding_mask"].dtype)
print("tgt_key_padding_mask:", batch["tgt_key_padding_mask"].shape, batch["tgt_key_padding_mask"].dtype)
print("tgt_subsequent_mask:", batch["tgt_subsequent_mask"].shape, batch["tgt_subsequent_mask"].dtype)

src: torch.Size([64, 18])
tgt_in: torch.Size([64, 16])
tgt_out: torch.Size([64, 16])
src_key_padding_mask: torch.Size([64, 18]) torch.bool
tgt_key_padding_mask: torch.Size([64, 16]) torch.bool
tgt_subsequent_mask: torch.Size([16, 16]) torch.bool


In [6]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, pad_idx: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.scale = math.sqrt(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.emb(x) * self.scale

In [7]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T = x.size(1)
        x = x + self.pe[:, :T, :].to(x.dtype)
        return self.dropout(x)

In [8]:
class TransformerInputEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, pad_idx: int, max_len: int, dropout: float):
        super().__init__()
        self.tok = TokenEmbedding(vocab_size, d_model, pad_idx)
        self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pos(self.tok(x))

In [9]:
src_embed = TransformerInputEmbedding(
    vocab_size=len(src_vocab["itos"]),
    d_model=D_MODEL,
    pad_idx=pad_id_src,
    max_len=MAX_LEN,
    dropout=DROPOUT
).to(DEVICE)

tgt_embed = TransformerInputEmbedding(
    vocab_size=len(tgt_vocab["itos"]),
    d_model=D_MODEL,
    pad_idx=pad_id_tgt,
    max_len=MAX_LEN,
    dropout=DROPOUT
).to(DEVICE)

batch = next(iter(train_dl))
src = batch["src"].to(DEVICE)
tgt_in = batch["tgt_in"].to(DEVICE)

src_x = src_embed(src)
tgt_x = tgt_embed(tgt_in)

print("src ids:", src.shape, "-> src embed:", src_x.shape)
print("tgt ids:", tgt_in.shape, "-> tgt embed:", tgt_x.shape)

src ids: torch.Size([64, 18]) -> src embed: torch.Size([64, 18, 256])
tgt ids: torch.Size([64, 16]) -> tgt embed: torch.Size([64, 16, 256])


In [10]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    key_padding_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    training: bool = False,
):
    Dh = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Dh)

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            scores = scores.masked_fill(attn_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
        else:
            scores = scores.masked_fill(attn_mask, float("-inf"))

    if key_padding_mask is not None:
        scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(1), float("-inf"))

    attn = torch.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p, training=training)

    out = torch.matmul(attn, V)
    return out, attn

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.dropout = dropout

        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        x = x.view(B, T, self.num_heads, self.d_head)
        return x.transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        B, H, T, Dh = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(B, T, H * Dh)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        Q = self._split_heads(self.Wq(q))
        K = self._split_heads(self.Wk(k))
        V = self._split_heads(self.Wv(v))

        out_h, attn = scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            dropout_p=self.dropout,
            training=self.training
        )

        out = self.Wo(self._merge_heads(out_h))
        return out, attn

In [12]:
torch.manual_seed(0)

B, Tq, Tk, D, H = 2, 4, 6, 8, 2
mha = MultiHeadAttention(d_model=D, num_heads=H, dropout=0.0).to(DEVICE)

q = torch.randn(B, Tq, D, device=DEVICE)
k = torch.randn(B, Tk, D, device=DEVICE)
v = torch.randn(B, Tk, D, device=DEVICE)

key_padding_mask = torch.zeros(B, Tk, dtype=torch.bool, device=DEVICE)
key_padding_mask[0, -2:] = True

attn_mask = torch.triu(torch.ones((Tq, Tk), dtype=torch.bool, device=DEVICE), diagonal=1)

out, attn = mha(q, k, v, attn_mask=None, key_padding_mask=key_padding_mask)
print("Without causal mask -> out:", out.shape, "attn:", attn.shape)

out2, attn2 = mha(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
print("With causal mask -> out:", out2.shape, "attn:", attn2.shape)

print("attn batch0 head0 last2 keys (should be ~0):", attn[0,0,0,-2:].detach().cpu())

Without causal mask -> out: torch.Size([2, 4, 8]) attn: torch.Size([2, 2, 4, 6])
With causal mask -> out: torch.Size([2, 4, 8]) attn: torch.Size([2, 2, 4, 6])
attn batch0 head0 last2 keys (should be ~0): tensor([0., 0.])


In [13]:
batch = next(iter(train_dl))
src = batch["src"].to(DEVICE)
tgt_in = batch["tgt_in"].to(DEVICE)

D_MODEL = 64
NUM_HEADS = 8

emb_src = nn.Embedding(int(src.max().item()) + 1, D_MODEL).to(DEVICE)
emb_tgt = nn.Embedding(int(tgt_in.max().item()) + 1, D_MODEL).to(DEVICE)

src_x = emb_src(src)
tgt_x = emb_tgt(tgt_in)

mha = MultiHeadAttention(d_model=D_MODEL, num_heads=NUM_HEADS, dropout=0.0).to(DEVICE)

src_pad_mask = batch["src_key_padding_mask"].to(DEVICE)
tgt_pad_mask = batch["tgt_key_padding_mask"].to(DEVICE)
tgt_causal = batch["tgt_subsequent_mask"].to(DEVICE)

enc_out, enc_attn = mha(src_x, src_x, src_x, attn_mask=None, key_padding_mask=src_pad_mask)
print("enc_out:", enc_out.shape, "enc_attn:", enc_attn.shape)

dec_out, dec_attn = mha(tgt_x, tgt_x, tgt_x, attn_mask=tgt_causal, key_padding_mask=tgt_pad_mask)
print("dec_out:", dec_out.shape, "dec_attn:", dec_attn.shape)

cross_out, cross_attn = mha(tgt_x, src_x, src_x, attn_mask=None, key_padding_mask=src_pad_mask)
print("cross_out:", cross_out.shape, "cross_attn:", cross_attn.shape)

enc_out: torch.Size([64, 18, 64]) enc_attn: torch.Size([64, 8, 18, 18])
dec_out: torch.Size([64, 20, 64]) dec_attn: torch.Size([64, 8, 20, 20])
cross_out: torch.Size([64, 20, 64]) cross_attn: torch.Size([64, 8, 20, 18])


In [14]:
class PositionwiseFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)

In [15]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int = 1024, dropout: float = 0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.ffn = PositionwiseFFN(d_model, d_ff=d_ff, dropout=dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_key_padding_mask=None):
        attn_out, _ = self.self_attn(
            src, src, src,
            attn_mask=None,
            key_padding_mask=src_key_padding_mask
        )
        src = self.norm1(src + self.dropout(attn_out))

        ffn_out = self.ffn(src)
        src = self.norm2(src + self.dropout(ffn_out))

        return src

In [16]:
class TransformerEncoder(nn.Module):
    def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff=d_ff, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, src, src_key_padding_mask=None):
        out = src
        for layer in self.layers:
            out = layer(out, src_key_padding_mask=src_key_padding_mask)
        return out

In [17]:
torch.manual_seed(0)

B, T, D = 2, 10, 32
NUM_HEADS = 4
NUM_LAYERS = 3

encoder = TransformerEncoder(
    num_layers=NUM_LAYERS,
    d_model=D,
    num_heads=NUM_HEADS,
    d_ff=128,
    dropout=0.1
).to(DEVICE)

x = torch.randn(B, T, D).to(DEVICE)

pad_mask = torch.zeros(B, T, dtype=torch.bool).to(DEVICE)
pad_mask[0, -3:] = True

out = encoder(x, src_key_padding_mask=pad_mask)
print("Input:", x.shape)
print("Output:", out.shape)

Input: torch.Size([2, 10, 32])
Output: torch.Size([2, 10, 32])


In [18]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int = 1024, dropout: float = 0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout=dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.ffn = PositionwiseFFN(d_model, d_ff=d_ff, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        tgt,
        memory,
        tgt_subsequent_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        self_out, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_subsequent_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = self.norm1(tgt + self.dropout(self_out))

        cross_out, _ = self.cross_attn(
            tgt, memory, memory,
            attn_mask=None,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = self.norm2(tgt + self.dropout(cross_out))

        ffn_out = self.ffn(tgt)
        tgt = self.norm3(tgt + self.dropout(ffn_out))

        return tgt

In [19]:
class TransformerDecoder(nn.Module):
    def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff=d_ff, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(
        self,
        tgt,
        memory,
        tgt_subsequent_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        out = tgt
        for layer in self.layers:
            out = layer(
                out,
                memory,
                tgt_subsequent_mask=tgt_subsequent_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )
        return out

In [20]:
torch.manual_seed(0)

B, Ts, Tt, D = 2, 9, 7, 32
NUM_HEADS = 4
NUM_LAYERS = 2

decoder = TransformerDecoder(
    num_layers=NUM_LAYERS,
    d_model=D,
    num_heads=NUM_HEADS,
    d_ff=128,
    dropout=0.1
).to(DEVICE)

tgt = torch.randn(B, Tt, D, device=DEVICE)
memory = torch.randn(B, Ts, D, device=DEVICE)

tgt_sub = torch.triu(torch.ones((Tt, Tt), dtype=torch.bool, device=DEVICE), diagonal=1)
tgt_pad = torch.zeros(B, Tt, dtype=torch.bool, device=DEVICE); tgt_pad[0, -2:] = True
src_pad = torch.zeros(B, Ts, dtype=torch.bool, device=DEVICE); src_pad[1, -3:] = True

out = decoder(
    tgt,
    memory,
    tgt_subsequent_mask=tgt_sub,
    tgt_key_padding_mask=tgt_pad,
    memory_key_padding_mask=src_pad
)

print("tgt:", tgt.shape)
print("memory:", memory.shape)
print("out:", out.shape)

tgt: torch.Size([2, 7, 32])
memory: torch.Size([2, 9, 32])
out: torch.Size([2, 7, 32])


In [21]:
class TransformerSeq2Seq(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        pad_id_src,
        pad_id_tgt,
        d_model=256,
        num_heads=8,
        num_layers=4,
        d_ff=1024,
        dropout=0.1,
        max_len=256,
    ):
        super().__init__()

        self.src_embed = TransformerInputEmbedding(
            src_vocab_size, d_model, pad_id_src, max_len, dropout
        )
        self.tgt_embed = TransformerInputEmbedding(
            tgt_vocab_size, d_model, pad_id_tgt, max_len, dropout
        )

        self.encoder = TransformerEncoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout
        )

        self.decoder = TransformerDecoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout
        )

        self.out_proj = nn.Linear(d_model, tgt_vocab_size)

        self.pad_id_src = pad_id_src
        self.pad_id_tgt = pad_id_tgt

    def forward(
        self,
        src,
        tgt_in,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        tgt_subsequent_mask=None,
    ):
        src_x = self.src_embed(src)
        tgt_x = self.tgt_embed(tgt_in)

        memory = self.encoder(
            src_x,
            src_key_padding_mask=src_key_padding_mask
        )

        dec_out = self.decoder(
            tgt_x,
            memory,
            tgt_subsequent_mask=tgt_subsequent_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        logits = self.out_proj(dec_out)
        return logits

In [22]:
D_MODEL = 256
NUM_HEADS = 8
NUM_LAYERS = 4
D_FF = 1024
DROPOUT = 0.1
MAX_LEN = 256

model = TransformerSeq2Seq(
    src_vocab_size=len(src_vocab["itos"]),
    tgt_vocab_size=len(tgt_vocab["itos"]),
    pad_id_src=pad_id_src,
    pad_id_tgt=pad_id_tgt,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT,
    max_len=MAX_LEN,
).to(DEVICE)

print("Model parameters:", sum(p.numel() for p in model.parameters())/1e6, "M")

Model parameters: 9.509612 M


In [23]:
batch = next(iter(train_dl))

src = batch["src"].to(DEVICE)
tgt_in = batch["tgt_in"].to(DEVICE)

src_pad_mask = batch["src_key_padding_mask"].to(DEVICE)
tgt_pad_mask = batch["tgt_key_padding_mask"].to(DEVICE)
tgt_causal = batch["tgt_subsequent_mask"].to(DEVICE)

logits = model(
    src,
    tgt_in,
    src_key_padding_mask=src_pad_mask,
    tgt_key_padding_mask=tgt_pad_mask,
    tgt_subsequent_mask=tgt_causal
)

print("src:", src.shape)
print("tgt_in:", tgt_in.shape)
print("logits:", logits.shape)

src: torch.Size([64, 27])
tgt_in: torch.Size([64, 37])
logits: torch.Size([64, 37, 3052])


In [24]:
OUT_DIR = Path("outputs_transformer")
OUT_DIR.mkdir(parents=True, exist_ok=True)

CKPT_DIR = OUT_DIR / "checkpoints"
CKPT_DIR.mkdir(parents=True, exist_ok=True)

LOG_PATH = OUT_DIR / "run_log.json"

def append_jsonl(path: Path, obj: dict):
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def safe_update_json(path: Path, key: str, value: dict):
    if path.exists():
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            data = {}
    else:
        data = {}

    data[key] = value
    path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")

In [25]:
@dataclass
class TrainConfig:
    epochs: int = 50
    lr: float = 3e-4
    weight_decay: float = 0.0
    grad_clip: float = 1.0

cfg = TrainConfig()

In [26]:
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id_tgt)

opt = torch.optim.Adam(
    model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay
)

In [27]:
def train_one_epoch(model, dl, opt, loss_fn, grad_clip: float):
    model.train()
    total_loss = 0.0
    steps = 0

    for batch in dl:
        src = batch["src"].to(DEVICE)
        tgt_in = batch["tgt_in"].to(DEVICE)
        tgt_out = batch["tgt_out"].to(DEVICE)

        src_pad_mask = batch["src_key_padding_mask"].to(DEVICE)
        tgt_pad_mask = batch["tgt_key_padding_mask"].to(DEVICE)
        tgt_causal = batch["tgt_subsequent_mask"].to(DEVICE)

        logits = model(
            src, tgt_in,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_subsequent_mask=tgt_causal
        )
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()

        total_loss += float(loss.item())
        steps += 1

    return total_loss / max(steps, 1)


@torch.no_grad()
def eval_one_epoch(model, dl, loss_fn):
    model.eval()
    total_loss = 0.0
    steps = 0

    for batch in dl:
        src = batch["src"].to(DEVICE)
        tgt_in = batch["tgt_in"].to(DEVICE)
        tgt_out = batch["tgt_out"].to(DEVICE)

        src_pad_mask = batch["src_key_padding_mask"].to(DEVICE)
        tgt_pad_mask = batch["tgt_key_padding_mask"].to(DEVICE)
        tgt_causal = batch["tgt_subsequent_mask"].to(DEVICE)

        logits = model(
            src, tgt_in,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_subsequent_mask=tgt_causal
        )

        loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        total_loss += float(loss.item())
        steps += 1

    return total_loss / max(steps, 1)

In [28]:
best_val = float("inf")
best_path = CKPT_DIR / "transformer_best.pt"

run_id = time.strftime("%Y%m%d_%H%M%S")

history = []

print("Starting training...")
t0_global = time.time()

for epoch in range(1, cfg.epochs + 1):
    t0 = time.time()

    train_loss = train_one_epoch(model, train_dl, opt, loss_fn, cfg.grad_clip)
    val_loss = eval_one_epoch(model, val_dl, loss_fn)

    epoch_time = time.time() - t0

    improved = val_loss < best_val
    if improved:
        best_val = val_loss
        torch.save(
            {
                "model_state": model.state_dict(),
                "config": {
                    "d_model": getattr(model, "d_model", None),
                    "num_layers": getattr(model, "num_layers", None),
                    "num_heads": getattr(model, "num_heads", None),
                },
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
            },
            best_path
        )

    row = {
        "run_id": run_id,
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "epoch_time_sec": epoch_time,
        "best_val_loss_so_far": best_val,
        "saved_best": bool(improved),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    history.append(row)

    append_jsonl(OUT_DIR / "run_log.jsonl", row)
    safe_update_json(LOG_PATH, f"epoch_{epoch}", row)

    print(f"Epoch {epoch:02d} | train {train_loss:.4f} | val {val_loss:.4f} | time {epoch_time:.1f}s")

t_total = time.time() - t0_global
print(f"Training finished. Total time: {t_total:.1f}s")
print("Best checkpoint:", best_path)
print("Logs:", OUT_DIR / "run_log.jsonl", "and", LOG_PATH)


Starting training...
Epoch 01 | train 4.8136 | val 3.8597 | time 105.9s
Epoch 02 | train 3.8265 | val 3.3927 | time 39.4s
Epoch 03 | train 3.3822 | val 3.1408 | time 33.3s
Epoch 04 | train 3.0538 | val 2.9318 | time 52.2s
Epoch 05 | train 2.7713 | val 2.7202 | time 26.4s
Epoch 06 | train 2.5186 | val 2.5867 | time 28.5s
Epoch 07 | train 2.2932 | val 2.4803 | time 25.2s
Epoch 08 | train 2.0807 | val 2.4074 | time 25.9s
Epoch 09 | train 1.8863 | val 2.3346 | time 21.5s
Epoch 10 | train 1.6965 | val 2.2839 | time 22.1s
Epoch 11 | train 1.5261 | val 2.2515 | time 21.6s
Epoch 12 | train 1.3670 | val 2.2669 | time 21.3s
Epoch 13 | train 1.2145 | val 2.2031 | time 21.4s
Epoch 14 | train 1.0763 | val 2.2334 | time 21.4s
Epoch 15 | train 0.9676 | val 2.1965 | time 20.9s
Epoch 16 | train 0.8499 | val 2.2560 | time 22.2s
Epoch 17 | train 0.7498 | val 2.3235 | time 22.7s
Epoch 18 | train 0.6567 | val 2.3107 | time 20.8s
Epoch 19 | train 0.5851 | val 2.3496 | time 21.2s
Epoch 20 | train 0.5165 | va

In [29]:
CKPT_PATH = Path("outputs_transformer/checkpoints/transformer_best.pt")

assert CKPT_PATH.exists(), f"Checkpoint not found: {CKPT_PATH}"

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

print("Loaded checkpoint from epoch:", ckpt.get("epoch"))

Loaded checkpoint from epoch: 15


In [30]:
@torch.no_grad()
def greedy_decode_transformer(model, src, src_pad_mask, bos_id, eos_id, max_len=40):
    B = src.size(0)
    ys = torch.full((B,1), bos_id, dtype=torch.long, device=DEVICE)

    for _ in range(max_len):
        tgt_pad_mask = (ys == pad_id_tgt)

        T = ys.size(1)
        tgt_causal = torch.triu(
            torch.ones((T, T), dtype=torch.bool, device=DEVICE),
            diagonal=1
        )

        logits = model(
            src,
            ys,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_subsequent_mask=tgt_causal
        )

        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)

        if (next_token == eos_id).all():
            break

    return ys

In [31]:
def ids_to_sentence(ids, vocab):
    tokens = []
    for i in ids:
        tok = vocab["itos"][int(i)]
        if tok == "<eos>":
            break
        if tok not in ["<bos>", "<pad>"]:
            tokens.append(tok)
    return " ".join(tokens)

In [32]:
refs = []
hyps = []

bos_id_tgt = tgt_vocab["stoi"]["<bos>"]
eos_id_tgt = tgt_vocab["stoi"]["<eos>"]

for batch in tqdm(test_dl, desc="Decoding"):
    src = batch["src"].to(DEVICE)
    src_pad_mask = batch["src_key_padding_mask"].to(DEVICE)
    tgt_out = batch["tgt_out"]

    pred_ids = greedy_decode_transformer(
        model,
        src,
        src_pad_mask,
        bos_id_tgt,
        eos_id_tgt,
        max_len=40
    )

    for ref_seq, hyp_seq in zip(tgt_out, pred_ids.cpu()):
        ref = ids_to_sentence(ref_seq, tgt_vocab)
        hyp = ids_to_sentence(hyp_seq, tgt_vocab)

        refs.append(ref)
        hyps.append(hyp)

print("Example:")
print("REF:", refs[0])
print("HYP:", hyps[0])

Decoding: 100%|██████████| 16/16 [00:46<00:00,  2.88s/it]

Example:
REF: él es su único hijo .
HYP: es solo niño .





In [40]:
bleu = sacrebleu.corpus_bleu(hyps, [refs], tokenize="13a", force=True)
print("BLEU:", bleu.score)

BLEU: 28.755664016702184


In [41]:
OUT_DIR = Path("outputs_transformer")
OUT_DIR.mkdir(exist_ok=True, parents=True)

pred_path = OUT_DIR / "transformer_predictions.tsv"
metrics_path = OUT_DIR / "transformer_bleu.json"

with open(pred_path, "w", encoding="utf-8") as f:
    for r, h in zip(refs, hyps):
        f.write(r + "\t" + h + "\n")

with open(metrics_path, "w", encoding="utf-8") as f:
    json.dump({
        "bleu": bleu.score,
        "num_sentences": len(hyps)
    }, f, indent=2)

print("Saved:")
print("-", pred_path)
print("-", metrics_path)

Saved:
- outputs_transformer/transformer_predictions.tsv
- outputs_transformer/transformer_bleu.json
