In [8]:
print("yes")

yes


## Imports & Global Constants

In [9]:
import json
import math
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None

PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"
SPECIAL_TOKENS = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
MAX_TRAIN_LIMIT = 100_000


## Vocabulary Utilities

In [10]:
class CharVocab:
    def __init__(self, tokens: Optional[List[str]] = None, min_freq: int = 1):
        self.min_freq = min_freq
        self.token2idx: Dict[str, int] = {}
        self.idx2token: List[str] = []
        if tokens:
            self.build(tokens)

    def build(self, tokens: List[str]) -> None:
        freq: Dict[str, int] = {}
        for token in tokens:
            freq[token] = freq.get(token, 0) + 1
        self.idx2token = list(SPECIAL_TOKENS)
        for ch in sorted([c for c, f in freq.items() if f >= self.min_freq]):
            if ch in SPECIAL_TOKENS:
                continue
            self.idx2token.append(ch)
        self.token2idx = {t: i for i, t in enumerate(self.idx2token)}

    def __len__(self) -> int:
        return len(self.idx2token)

    def encode(self, text: str) -> List[int]:
        return [self.token2idx.get(ch, self.token2idx[UNK_TOKEN]) for ch in text]

    def decode(self, ids: List[int]) -> str:
        out: List[str] = []
        for idx in ids:
            if idx < 0 or idx >= len(self.idx2token):
                out.append(UNK_TOKEN)
            else:
                out.append(self.idx2token[idx])
        return "".join(out)

    @property
    def pad_idx(self) -> int:
        return self.token2idx[PAD_TOKEN]

    @property
    def sos_idx(self) -> int:
        return self.token2idx[SOS_TOKEN]

    @property
    def eos_idx(self) -> int:
        return self.token2idx[EOS_TOKEN]

    @property
    def unk_idx(self) -> int:
        return self.token2idx[UNK_TOKEN]

## Dataset Wrappers & Collate Function

In [11]:
class TransliterationDataset(Dataset):
    def __init__(self, records: List[Tuple[str, str]]):
        self.records = records

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.records[idx]


def collate_fn(batch, src_vocab: CharVocab, tgt_vocab: CharVocab, device: torch.device):
    src_seqs = [s for s, _ in batch]
    tgt_seqs = [t for _, t in batch]

    src_idxs = [src_vocab.encode(s) for s in src_seqs]
    tgt_in_idxs = [[tgt_vocab.sos_idx] + tgt_vocab.encode(t) for t in tgt_seqs]
    tgt_out_idxs = [tgt_vocab.encode(t) + [tgt_vocab.eos_idx] for t in tgt_seqs]

    src_lengths = [len(x) for x in src_idxs]
    tgt_lengths = [len(x) for x in tgt_out_idxs]

    max_src = max(src_lengths)
    max_tgt = max(tgt_lengths)

    src_padded = torch.full((len(batch), max_src), src_vocab.pad_idx, dtype=torch.long)
    for i, seq in enumerate(src_idxs):
        src_padded[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)

    tgt_in_padded = torch.full((len(batch), max_tgt), tgt_vocab.pad_idx, dtype=torch.long)
    tgt_out_padded = torch.full((len(batch), max_tgt), tgt_vocab.pad_idx, dtype=torch.long)
    for i, (tin, tout) in enumerate(zip(tgt_in_idxs, tgt_out_idxs)):
        tgt_in_padded[i, : len(tin)] = torch.tensor(tin, dtype=torch.long)
        tgt_out_padded[i, : len(tout)] = torch.tensor(tout, dtype=torch.long)

    return (
        src_padded.to(device),
        torch.tensor(src_lengths, dtype=torch.long, device=device),
        tgt_in_padded.to(device),
        tgt_out_padded.to(device),
        torch.tensor(tgt_lengths, dtype=torch.long, device=device),
    )

## Positional Encoding Module

In [15]:
# ========================================
# Configuration Class
# Add this cell BEFORE LocalTransformer
# ========================================

from dataclasses import dataclass
from typing import Optional

@dataclass
class TransformerConfig:
    d_model: int = 256
    nhead: int = 4
    num_encoder_layers: int = 2
    num_decoder_layers: int = 2
    dim_feedforward: int = 512
    dropout: float = 0.1
    local_window: Optional[int] = None

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, maxlen: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, maxlen, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(maxlen, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class LocalTransformer(nn.Module):
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, config: TransformerConfig, device: torch.device):
        super().__init__()
        self.config = config
        self.device = device

        self.src_tok_emb = nn.Embedding(src_vocab_size, config.d_model, padding_idx=0)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, config.d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(config.d_model, dropout=config.dropout)
        self.pos_decoder = PositionalEncoding(config.d_model, dropout=config.dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True,
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True,
        )

        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_encoder_layers)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_decoder_layers)
        self.generator = nn.Linear(config.d_model, tgt_vocab_size)

        self._init_weights()

    def _init_weights(self):
        for name, p in self.named_parameters():
            if 'weight' in name and p.dim() > 1:
                nn.init.xavier_uniform_(p, gain=1.0)
            elif 'bias' in name:
                nn.init.zeros_(p)

    def build_encoder_mask(self, seq_len: int, device: torch.device) -> Optional[torch.Tensor]:
        window = self.config.local_window

        # Use full attention if window is None or too large
        if window is None or window <= 0 or window >= seq_len:
            return None

        # Build mask: 0.0 = can attend, -inf = cannot attend
        mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=torch.float32)

        for i in range(seq_len):
            start = max(0, i - window)
            end = min(seq_len, i + window + 1)
            mask[i, start:end] = 0.0  # Allow attention in window

        return mask

    def build_decoder_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        window = self.config.local_window

        # Causal mask (upper triangular with -inf)
        mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'),
                                     device=device, dtype=torch.float32), diagonal=1)

        # Add local window constraint
        if window is not None and window > 0 and window < seq_len:
            for i in range(seq_len):
                start = max(0, i - window)
                if start > 0:
                    mask[i, :start] = float('-inf')

        return mask

    def forward(self, src, tgt_in, src_key_padding_mask, tgt_key_padding_mask):
        # Embed and scale
        src_emb = self.src_tok_emb(src) * math.sqrt(self.config.d_model)
        tgt_emb = self.tgt_tok_emb(tgt_in) * math.sqrt(self.config.d_model)

        # Add positional encoding
        src_emb = self.pos_encoder(src_emb)
        tgt_emb = self.pos_decoder(tgt_emb)

        # Build masks
        src_mask = self.build_encoder_mask(src_emb.size(1), src_emb.device)
        tgt_mask = self.build_decoder_mask(tgt_emb.size(1), tgt_emb.device)

        # Encode
        memory = self.encoder(
            src_emb,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask
        )

        if torch.isnan(memory).any():
            print("WARNING: NaN in encoder output")

        # Decode
        out = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        return self.generator(out)

    def encode(self, src, src_key_padding_mask):
        src_emb = self.src_tok_emb(src) * math.sqrt(self.config.d_model)
        src_emb = self.pos_encoder(src_emb)
        src_mask = self.build_encoder_mask(src_emb.size(1), src_emb.device)
        return self.encoder(src_emb, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, src_key_padding_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_tok_emb(tgt) * math.sqrt(self.config.d_model)
        tgt_emb = self.pos_decoder(tgt_emb)
        tgt_mask = self.build_decoder_mask(tgt_emb.size(1), tgt_emb.device)
        out = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.generator(out)

## Transformer Configuration & Model

## Data Loading Helpers

In [16]:
from __future__ import annotations

import json, os, re, random, zipfile, shutil, tempfile
from typing import List, Tuple, Optional
from huggingface_hub import hf_hub_download, list_repo_files


def read_jsonl(path: str) -> List[Tuple[str, str]]:
    """Read Aksharantar-style JSONL into (native, english) tuples."""
    records: List[Tuple[str, str]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            src = (obj.get("native word") or obj.get("native_word") or obj.get("native") or "").strip()
            tgt = (obj.get("english word") or obj.get("english_word") or obj.get("english") or "").strip().lower()
            if src and tgt:
                records.append((src, tgt))
    return records




# ---------- Local file helper ----------

_SPLIT_MAP = {
    "train": "train",
    "training": "train",
    "validation": "valid",
    "val": "valid",
    "dev": "valid",
    "test": "test",
}

def _resolve_local_file(base_dir: str, language: str, split: str) -> Optional[str]:
    """Find {lang}_{split}.{jsonl|json} in base_dir; return path or None."""
    target = _SPLIT_MAP.get(split.lower(), split.lower())
    candidates = [
        f"{language}_{target}.jsonl",
        f"{language}_{target}.json",
    ]
    for name in candidates:
        p = os.path.join(base_dir, name)
        if os.path.isfile(p):
            return p
    return None


import json, os, re, random
from typing import List, Tuple, Optional

BAD_BACKSLASH = re.compile(r'(?<!\\)\\(?!["\\/bfnrtu])')  # unescaped backslash not starting a valid escape

def _loads_relaxed(s: str):
    """Try strict JSON; if it fails, patch common issues and try again."""
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        s2 = BAD_BACKSLASH.sub(r'\\\\', s)

        try:
            return json.loads(s2)
        except json.JSONDecodeError:
            return None

def read_jsonl(path: str, show_errors: int = 3) -> List[Tuple[str, str]]:
    """Robust reader: JSONL or JSON-array; skips malformed lines with minimal repair."""
    records: List[Tuple[str, str]] = []

    # Detect JSON array vs JSONL
    with open(path, "r", encoding="utf-8") as fpeek:
        start = fpeek.read(256).lstrip()
        is_array = start.startswith("[")

    if is_array:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        iterable = data
        for obj in iterable:
            src = (obj.get("native word") or obj.get("native_word") or obj.get("native") or "").strip()
            tgt = (obj.get("english word") or obj.get("english_word") or obj.get("english") or "").strip().lower()
            if src and tgt:
                records.append((src, tgt))
        return records

    # JSONL path
    bad, shown = 0, 0
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            obj = _loads_relaxed(line)
            if obj is None:
                bad += 1
                if shown < show_errors:
                    print(f"[warn] Skipping malformed JSON at line {i}: {line[:160]}...")
                    shown += 1
                continue
            src = (obj.get("native word") or obj.get("native_word") or obj.get("native") or "").strip()
            tgt = (obj.get("english word") or obj.get("english_word") or obj.get("english") or "").strip().lower()
            if src and tgt:
                records.append((src, tgt))
    if bad:
        print(f"[info] Skipped {bad} malformed lines in {os.path.basename(path)}")
    return records

def build_vocabs(records: List[Tuple[str, str]], min_freq: int = 1) -> Tuple["CharVocab","CharVocab"]:
    src_chars: List[str] = []
    tgt_chars: List[str] = []
    for src, tgt in records:
        src_chars.extend(list(src))
        tgt_chars.extend(list(tgt))
    src_vocab = CharVocab(src_chars, min_freq=min_freq)
    tgt_vocab = CharVocab(tgt_chars, min_freq=min_freq)
    return src_vocab, tgt_vocab

_SPLIT_MAP = {"train":"train","training":"train","validation":"valid","val":"valid","dev":"valid","test":"test"}

def load_aksharantar_local(
    language: str,
    split: str,
    base_dir: str = ".",
    max_examples: Optional[int] = None,
    shuffle: bool = False,
    seed: int = 42,
) -> List[Tuple[str, str]]:
    split_key = _SPLIT_MAP.get(split.lower(), split.lower())
    # Prefer .jsonl, then .json
    candidates = [f"{language}_{split_key}.jsonl", f"{language}_{split_key}.json"]
    path = next((os.path.join(base_dir, c) for c in candidates if os.path.isfile(os.path.join(base_dir, c))), None)
    if path is None:
        raise FileNotFoundError(f"Missing file for split '{split}': tried {candidates} in {os.path.abspath(base_dir)}")

    records = read_jsonl(path)

    if shuffle:
        rng = random.Random(seed)
        rng.shuffle(records)
    if max_examples and max_examples > 0:
        records = records[:max_examples]
    return records

# ---------- Optional: fetch from Hugging Face if file is missing locally ----------

def _download_single_file(language: str, target: str) -> Optional[str]:
    """Try to download {lang}_{target}.jsonl/.json; return local path or None."""
    for ext in ("jsonl", "json"):
        fname = f"{language}_{target}.{ext}"
        try:
            return hf_hub_download(
                repo_id="ai4bharat/Aksharantar",
                filename=fname,
                repo_type="dataset",
            )
        except Exception:
            continue
    return None


def _download_from_zip(language: str, target: str) -> Optional[str]:
    """Fallback: download {lang}.zip and extract the target file; return path or None."""
    try:
        zip_path = hf_hub_download(
            repo_id="ai4bharat/Aksharantar",
            filename=f"{language}.zip",
            repo_type="dataset",
        )
    except Exception:
        return None

    target_candidates = [f"{language}_{target}.jsonl", f"{language}_{target}.json"]
    with zipfile.ZipFile(zip_path) as zf:
        names = set(zf.namelist())
        hit = next((n for n in target_candidates if n in names), None)
        if not hit:
            return None
        tempdir = tempfile.mkdtemp(prefix=f"aksh_{language}_")
        out_path = os.path.join(tempdir, hit)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with zf.open(hit) as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        return out_path


def load_aksharantar(
    language: str,
    split: str,
    base_dir: str = ".",
    max_examples: Optional[int] = None,
    shuffle: bool = False,
    seed: int = 42,
) -> List[Tuple[str, str]]:
    """
    Local-first loader; if missing, fetches from HF (dataset repo).
    """
    target = _SPLIT_MAP.get(split.lower(), split.lower())

    # 1) Try local
    path = _resolve_local_file(base_dir, language, split)

    # 2) Try direct file download ({lang}_{target}.jsonl/.json)
    if path is None:
        path = _download_single_file(language, target)

    # 3) Fallback: try zip and extract
    if path is None:
        path = _download_from_zip(language, target)

    if path is None:
        raise FileNotFoundError(
            f"Could not resolve split '{split}' for language '{language}' locally or on HF.\n"
            f"Tried: {language}_{target}.jsonl/.json and {language}.zip."
        )

    # Read file
    records = read_jsonl(path) if path.endswith(".jsonl") else _read_json_as_jsonl(path)

    # Shuffle / truncate
    if shuffle:
        rng = random.Random(seed)
        rng.shuffle(records)
    if max_examples and max_examples > 0:
        records = records[:max_examples]
    return records


## Training & Evaluation Utilities

In [18]:
def trainepoch(model, dataloader, optimizer, criterion, srcpadidx: int, tgtpadidx: int) -> float:
    model.train()
    totalloss = 0.0
    for src, _, tgtin, tgtout, _ in tqdm(dataloader, desc="Train", leave=False):
        optimizer.zero_grad()
        srckeypaddingmask = (src == srcpadidx)
        tgtkeypaddingmask = (tgtin == tgtpadidx)
        logits = model(src, tgtin, srckeypaddingmask, tgtkeypaddingmask)
        B, T, V = logits.shape
        loss = criterion(logits.view(B * T, V), tgtout.view(B * T))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        totalloss += loss.item()
    return totalloss / max(len(dataloader), 1)


def evaluate(model, dataloader, criterion, srcpadidx: int, tgtpadidx: int) -> float:
    model.eval()
    totalloss = 0.0
    count = 0
    with torch.no_grad():
        for src, _, tgtin, tgtout, _ in tqdm(dataloader, desc="Val", leave=False):
            srckeypaddingmask = (src == srcpadidx)
            tgtkeypaddingmask = (tgtin == tgtpadidx)
            logits = model(src, tgtin, srckeypaddingmask, tgtkeypaddingmask)
            B, T, V = logits.shape
            loss = criterion(logits.view(B * T, V), tgtout.view(B * T))

            # Check if loss is valid before adding
            if not torch.isnan(loss) and not torch.isinf(loss):
                totalloss += loss.item()
                count += 1

    # Return average only if we have valid losses
    return totalloss / max(count, 1) if count > 0 else float('nan')


In [19]:
def subsample_records(records: List[Tuple[str, str]], limit: int = MAX_TRAIN_LIMIT, seed: int = 42) -> List[Tuple[str, str]]:
    if limit <= 0 or len(records) <= limit:
        return records
    limit = min(limit, MAX_TRAIN_LIMIT)
    rng = random.Random(seed)
    return rng.sample(records, limit)

## Decoding Helpers

In [20]:
def greedy_decode(model: LocalTransformer, src_seq: str, src_vocab: CharVocab, tgt_vocab: CharVocab, max_len: int = 80) -> str:
    model.eval()
    device = model.device
    with torch.no_grad():
        src_idxs = src_vocab.encode(src_seq)
        if not src_idxs:
            return ""
        src_tensor = torch.tensor([src_idxs], dtype=torch.long, device=device)
        src_key_padding_mask = (src_tensor == src_vocab.pad_idx)
        memory = model.encode(src_tensor, src_key_padding_mask)

        ys = torch.tensor([[tgt_vocab.sos_idx]], dtype=torch.long, device=device)
        tgt_key_padding_mask = (ys == tgt_vocab.pad_idx)

        output_tokens: List[int] = []
        for _ in range(max_len):
            logits = model.decode(ys, memory, src_key_padding_mask, tgt_key_padding_mask)
            next_token = logits[:, -1, :].argmax(dim=-1).item()
            if next_token == tgt_vocab.eos_idx:
                break
            output_tokens.append(next_token)
            ys = torch.cat([ys, torch.tensor([[next_token]], device=device)], dim=1)
            tgt_key_padding_mask = (ys == tgt_vocab.pad_idx)
        return tgt_vocab.decode(output_tokens)


def beam_search_decode(
    model: LocalTransformer,
    src_seq: str,
    src_vocab: CharVocab,
    tgt_vocab: CharVocab,
    beam_size: int = 3,
    max_len: int = 80,
) -> str:
    model.eval()
    device = model.device
    with torch.no_grad():
        src_idxs = src_vocab.encode(src_seq)
        if not src_idxs:
            return ""
        src_tensor = torch.tensor([src_idxs], dtype=torch.long, device=device)
        src_key_padding_mask = (src_tensor == src_vocab.pad_idx)
        memory = model.encode(src_tensor, src_key_padding_mask)

        beams: List[Tuple[float, List[int]]] = [(0.0, [tgt_vocab.sos_idx])]
        completed: List[Tuple[float, List[int]]] = []

        for _ in range(max_len):
            new_beams: List[Tuple[float, List[int]]] = []
            for log_prob, seq in beams:
                if seq[-1] == tgt_vocab.eos_idx:
                    completed.append((log_prob, seq))
                    continue
                ys = torch.tensor([seq], dtype=torch.long, device=device)
                tgt_key_padding_mask = (ys == tgt_vocab.pad_idx)
                logits = model.decode(ys, memory, src_key_padding_mask, tgt_key_padding_mask)
                log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)
                top_logp, top_idx = log_probs.topk(beam_size)
                for lp, idx in zip(top_logp[0], top_idx[0]):
                    next_seq = seq + [idx.item()]
                    new_beams.append((log_prob + lp.item(), next_seq))
            if not new_beams:
                break
            new_beams.sort(key=lambda x: x[0], reverse=True)
            beams = new_beams[:beam_size]

        if not completed:
            completed = beams
        best_seq = max(completed, key=lambda x: x[0])[1]
        decoded: List[int] = []
        for tok in best_seq[1:]:
            if tok == tgt_vocab.eos_idx:
                break
            decoded.append(tok)
        return tgt_vocab.decode(decoded)

## Training Loop Helper

The following cell wraps model creation, DataLoader setup, training, and validation into a reusable helper function.

In [21]:
def prepare_dataloaders(
    records: List[Tuple[str, str]],
    device: torch.device,
    batch_size: int = 64,
    train_split: float = 0.9,
) -> Tuple[DataLoader, Optional[DataLoader], CharVocab, CharVocab]:
    random.shuffle(records)
    split_idx = int(len(records) * train_split)
    train_records = records[:split_idx]
    val_records = records[split_idx:] if split_idx < len(records) else []

    src_vocab, tgt_vocab = build_vocabs(train_records)

    def collate(batch):
        return collate_fn(batch, src_vocab, tgt_vocab, device)

    train_loader = DataLoader(
        TransliterationDataset(train_records),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate,
    )
    val_loader = None
    if val_records:
        val_loader = DataLoader(
            TransliterationDataset(val_records),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate,
        )
    return train_loader, val_loader, src_vocab, tgt_vocab


def train_model(
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],
    src_vocab: CharVocab,
    tgt_vocab: CharVocab,
    device: torch.device,
    epochs: int = 5,
    lr: float = 3e-4,
    config: Optional[TransformerConfig] = None,
) -> Tuple[LocalTransformer, Dict[str, List[float]]]:
    if config is None:
        config = TransformerConfig()

    model = LocalTransformer(len(src_vocab), len(tgt_vocab), config, device=device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

    history = {"train_loss": [], "val_loss": []}
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, src_vocab.pad_idx, tgt_vocab.pad_idx)
        history["train_loss"].append(train_loss)

        if val_loader is not None:
            val_loss = evaluate(model, val_loader, criterion, src_vocab.pad_idx, tgt_vocab.pad_idx)
            history["val_loss"].append(val_loss)
            print(f"Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}")
        else:
            print(f"Epoch {epoch}: train={train_loss:.4f}")
    return model, history

In [22]:


language = "hin"
train = load_aksharantar_local(language, "train", base_dir=".",max_examples=MAX_TRAIN_LIMIT)
valid = load_aksharantar_local(language, "validation", base_dir=".")
test  = load_aksharantar_local(language, "test", base_dir=".")

src_vocab, tgt_vocab = build_vocabs(train + valid)

[warn] Skipping malformed JSON at line 276011: {"unique_identifier": "hin276011", "native word": "समुद्रायायन", "english wor...
[info] Skipped 1 malformed lines in hin_train.jsonl


In [23]:
import inspect

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

val_records = valid or []                      # from your previous cell
all_records = (train or []) + (val_records)    # for vocab building
src_vocab, tgt_vocab = build_vocabs(all_records, min_freq=1)

def prepare_dataloaders_compat(prepare_fn, *,
                               train_records, val_records,
                               src_vocab, tgt_vocab,
                               device, batch_size):
    """Call `prepare_fn` regardless of whether it expects
       - (all_records, device, batch_size)  -> returns (train_loader, val_loader[, src_vocab, tgt_vocab])
       - (train, val, src_vocab, tgt_vocab, device, batch_size)
       - (train, val, device, batch_size)   -> returns vocabs or not
    """
    try:
        sig = inspect.signature(prepare_fn)
        params = list(sig.parameters.keys())
    except Exception:
        params = []

    # Case A: legacy signature: (all_records, device, batch_size)
    if len(params) >= 3 and params[0] not in {"train_records", "train", "X_train"}:
        out = prepare_fn(train_records + (val_records or []), device, batch_size)
        try:
            train_loader, val_loader, sv, tv = out
            src, tgt = sv, tv
        except ValueError:
            train_loader, val_loader = out
            src, tgt = src_vocab, tgt_vocab
        return train_loader, val_loader, src, tgt

    # Case B: expects (train, val, src_vocab, tgt_vocab, device, batch_size)
    try:
        out = prepare_fn(train_records, (val_records or None), src_vocab, tgt_vocab, device, batch_size)
        try:
            train_loader, val_loader, sv, tv = out
            return train_loader, val_loader, sv, tv
        except ValueError:
            train_loader, val_loader = out
            return train_loader, val_loader, src_vocab, tgt_vocab
    except TypeError:
        # Case C: (train, val, device, batch_size)
        out = prepare_fn(train_records, (val_records or None), device, batch_size)
        try:
            train_loader, val_loader, sv, tv = out
            return train_loader, val_loader, sv, tv
        except ValueError:
            train_loader, val_loader = out
            return train_loader, val_loader, src_vocab, tgt_vocab

# ---- use the compat wrapper ----
train_loader, val_loader, src_vocab, tgt_vocab = prepare_dataloaders_compat(
    prepare_dataloaders,
    train_records=train,
    val_records=val_records,
    src_vocab=src_vocab,
    tgt_vocab=tgt_vocab,
    device=device,
    batch_size=128,
)

print(f"Device: {device.type}")
print(f"Train batches: {len(train_loader) if train_loader else 0} | "
      f"Val batches: {len(val_loader) if val_loader else 0}")
print(f"Source vocab size: {len(src_vocab)} | Target vocab size: {len(tgt_vocab)}")

Device: cuda
Train batches: 748 | Val batches: 84
Source vocab size: 72 | Target vocab size: 30


In [25]:
import torch
import torch.nn as nn
from tqdm import tqdm

def train_step(model, batch, criterion, optimizer):
    model.train()
    src, src_len, tgt_in, tgt_out, tgt_len = batch

    src_mask = (src == src_vocab.pad_idx)
    tgt_mask = (tgt_in == tgt_vocab.pad_idx)

    optimizer.zero_grad()
    logits = model(src, tgt_in, src_mask, tgt_mask)

    B, T, V = logits.shape
    loss = criterion(logits.view(B * T, V), tgt_out.view(B * T))
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.item()

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in dataloader:
            src, src_len, tgt_in, tgt_out, tgt_len = batch
            logits = model(src, tgt_in, (src == 0), (tgt_in == 0))

            B, T, V = logits.shape
            loss = criterion(logits.view(B * T, V), tgt_out.view(B * T))
            total_loss += loss.item()

    return total_loss / len(dataloader)
config = TransformerConfig(
    d_model=256,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    local_window=None  # Use full attention (or set to a number like 10 for local)
)

# Create model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LocalTransformer(
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    config=config,
    device=device
).to(device)

# Training loop
criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

history = {'train_loss': [], 'val_loss': []}

for epoch in range(10):
    train_losses = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch in pbar:
        loss = train_step(model, batch, criterion, optimizer)
        train_losses.append(loss)
        pbar.set_postfix({'loss': f'{loss:.4f}'})

    avg_train_loss = sum(train_losses) / len(train_losses)
    val_loss = evaluate(model, val_loader, criterion)

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(val_loss)

    print(f"Epoch {epoch+1}: Train={avg_train_loss:.4f}, Val={val_loss:.4f}")

# Save model
torch.save({
    'model_state': model.state_dict(),
    'config': config.__dict__,
    'history': history
}, 'transformer_model.pt')

print("✅ Training complete!")


Epoch 1: 100%|██████████| 748/748 [00:22<00:00, 33.42it/s, loss=0.6495]
  output = torch._nested_tensor_from_mask(


Epoch 1: Train=0.9387, Val=0.5587


Epoch 2: 100%|██████████| 748/748 [00:21<00:00, 34.81it/s, loss=0.5947]


Epoch 2: Train=0.5884, Val=0.4776


Epoch 3: 100%|██████████| 748/748 [00:21<00:00, 34.91it/s, loss=0.4362]


Epoch 3: Train=0.5157, Val=0.4514


Epoch 4: 100%|██████████| 748/748 [00:21<00:00, 34.66it/s, loss=0.4899]


Epoch 4: Train=0.4842, Val=0.4305


Epoch 5: 100%|██████████| 748/748 [00:21<00:00, 34.36it/s, loss=0.4909]


Epoch 5: Train=0.4660, Val=0.4213


Epoch 6: 100%|██████████| 748/748 [00:21<00:00, 34.91it/s, loss=0.4340]


Epoch 6: Train=0.4521, Val=0.4191


Epoch 7: 100%|██████████| 748/748 [00:21<00:00, 35.32it/s, loss=0.4103]


Epoch 7: Train=0.4412, Val=0.4114


Epoch 8: 100%|██████████| 748/748 [00:21<00:00, 35.13it/s, loss=0.4033]


Epoch 8: Train=0.4324, Val=0.4118


Epoch 9: 100%|██████████| 748/748 [00:21<00:00, 34.81it/s, loss=0.4567]


Epoch 9: Train=0.4251, Val=0.4036


Epoch 10: 100%|██████████| 748/748 [00:21<00:00, 34.54it/s, loss=0.4084]


Epoch 10: Train=0.4191, Val=0.4023
✅ Training complete!


In [26]:
def vocab_to_dict(vocab):
    """Convert CharVocab object to dictionary for saving"""
    return {
        'idx2token': list(vocab.idx2token),
        'token2idx': dict(vocab.token2idx)
    }

# Save model with vocabularies included
torch.save({
    'model_state': model.state_dict(),
    'config': config.__dict__,
    'src_vocab': vocab_to_dict(src_vocab),
    'tgt_vocab': vocab_to_dict(tgt_vocab),
    'history': history,
}, 'transformer_model.pt')

print("✅ Model saved with vocabularies!")
print(f"   File: transformer_model.pt")
print(f"   Source vocab size: {len(src_vocab)}")
print(f"   Target vocab size: {len(tgt_vocab)}")
print(f"   Model parameters: {sum(p.numel() for p in model.parameters()):,}")


✅ Model saved with vocabularies!
   File: transformer_model.pt
   Source vocab size: 72
   Target vocab size: 30
   Model parameters: 2,669,598


In [27]:
word = "भागना"
greedy_output = greedy_decode(model, word, src_vocab, tgt_vocab)
beam_output = beam_search_decode(model, word, src_vocab, tgt_vocab, beam_size=10)
print("Input:", word)
print("Greedy:", greedy_output)
print("Beam:", beam_output)

Input: भागना
Greedy: bhagana
Beam: bhagna


In [28]:

import numpy as np

# -------------------- Metric Functions --------------------

def edit_distance(s1, s2):
    """Compute Levenshtein edit distance between two strings"""
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2+1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]

def longest_common_subsequence_length(s1, s2):
    """
    Compute LCS length using formula from NEWS 2015 (Eq. 2):
    LCS(c, r) = 1/2 * (|c| + |r| - ED(c, r))
    """
    return 0.5 * (len(s1) + len(s2) - edit_distance(s1, s2))

def compute_word_accuracy(predictions, references):
    """
    Word Accuracy in Top-1 (ACC) - Section 3.1 of NEWS 2015
    Returns the proportion of exact matches between prediction and reference
    """
    correct = 0
    for pred, ref in zip(predictions, references):
        if pred == ref:
            correct += 1
    return correct / len(predictions) if predictions else 0.0

def compute_character_f1(predictions, references):
    """
    Mean F-score (Character-level) - Section 3.2 of NEWS 2015
    Based on longest common subsequence between prediction and reference
    """
    f_scores = []

    for pred, ref in zip(predictions, references):
        if len(pred) == 0 and len(ref) == 0:
            f_scores.append(1.0)
            continue
        elif len(pred) == 0 or len(ref) == 0:
            f_scores.append(0.0)
            continue

        # Compute LCS length
        lcs_len = longest_common_subsequence_length(pred, ref)

        # Precision and Recall
        precision = lcs_len / len(pred)
        recall = lcs_len / len(ref)

        # F1 score
        if precision + recall > 0:
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            f1 = 0.0

        f_scores.append(f1)

    return np.mean(f_scores) if f_scores else 0.0

# -------------------- Greedy Decoding --------------------

def greedy_decode(model, src, src_vocab, tgt_vocab, device, max_len=50):
    """
    Greedy decoding for transliteration
    Returns the predicted transliteration as a string
    """
    model.eval()

    with torch.no_grad():
        # Encode source
        src_mask = (src == src_vocab.pad_idx)
        memory = model.encode(src, src_mask)

        # Start with SOS token
        tgt = torch.LongTensor([[tgt_vocab.sos_idx]]).to(device)

        for _ in range(max_len):
            tgt_mask = (tgt == tgt_vocab.pad_idx)

            # Decode
            out = model.decode(tgt, memory, src_mask, tgt_mask)

            # Get next token
            logits = out[:, -1, :]
            next_token = logits.argmax(dim=-1, keepdim=True)

            # Stop if EOS
            if next_token.item() == tgt_vocab.eos_idx:
                break

            tgt = torch.cat([tgt, next_token], dim=1)

    # Convert to string
    tokens = tgt.squeeze(0).cpu().tolist()
    chars = [tgt_vocab.idx2token[idx] for idx in tokens
             if idx not in [tgt_vocab.pad_idx, tgt_vocab.sos_idx, tgt_vocab.eos_idx]]

    return ''.join(chars)

# -------------------- Evaluation Function --------------------

def evaluate_transliteration(model, dataloader, src_vocab, tgt_vocab, device, dataset_name="Dataset"):
    """
    Evaluate model on a dataset and return metrics
    """
    model.eval()
    predictions = []
    references = []

    print(f"\n{'='*70}")
    print(f"Evaluating on {dataset_name}")
    print('='*70)
    print("Generating predictions...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Decoding"):
            src, src_len, tgt_in, tgt_out, tgt_len = batch
            batch_size = src.size(0)

            # Generate predictions for each example in batch
            for i in range(batch_size):
                src_i = src[i:i+1]

                # Greedy decode
                pred = greedy_decode(model, src_i, src_vocab, tgt_vocab, device)
                predictions.append(pred)

                # Get reference (ground truth)
                ref_tokens = tgt_out[i].cpu().tolist()
                ref_chars = [tgt_vocab.idx2token[idx] for idx in ref_tokens
                           if idx not in [tgt_vocab.pad_idx, tgt_vocab.eos_idx]]
                ref = ''.join(ref_chars)
                references.append(ref)

    # Compute metrics
    print("\nComputing metrics...")
    word_acc = compute_word_accuracy(predictions, references)
    char_f1 = compute_character_f1(predictions, references)

    return word_acc, char_f1, predictions, references

# -------------------- Run Evaluation --------------------

# Evaluate on validation set
val_word_acc, val_char_f1, val_preds, val_refs = evaluate_transliteration(
    model, val_loader, src_vocab, tgt_vocab, device, "Validation Set"
)

# Evaluate on test set
if 'test_loader' not in globals():
    print("\nCreating test loader...")
    test_dataset = TransliterationDataset(test)

    def make_collate_fn(src_v, tgt_v, dev):
        def collate(batch):
            return collate_fn(batch, src_v, tgt_v, dev)
        return collate

    test_loader = DataLoader(
        test_dataset,
        batch_size=128,
        shuffle=False,
        collate_fn=make_collate_fn(src_vocab, tgt_vocab, device)
    )

test_word_acc, test_char_f1, test_preds, test_refs = evaluate_transliteration(
    model, test_loader, src_vocab, tgt_vocab, device, "Test Set"
)

# -------------------- Display Results --------------------

print("\n" + "="*70)
print("EVALUATION RESULTS")
print("="*70)
print("\n| Dataset       | Word Accuracy | Character F1 |")
print("|---------------|---------------|--------------|")
print(f"| Validation    | {val_word_acc:13.4f} | {val_char_f1:12.4f} |")
print(f"| Test          | {test_word_acc:13.4f} | {test_char_f1:12.4f} |")
print("="*70)

# -------------------- Show Examples --------------------

print("\n" + "="*70)
print("SAMPLE PREDICTIONS (First 10 from test set)")
print("="*70)
print(f"{'Source':<20} {'Reference':<20} {'Prediction':<20} {'Match':<10}")
print("-"*70)

for i in range(min(10, len(test_preds))):
    src_word = test[i][0] if i < len(test) else "N/A"
    ref = test_refs[i]
    pred = test_preds[i]
    match = "✓" if pred == ref else "✗"
    print(f"{src_word:<20} {ref:<20} {pred:<20} {match:<10}")

print("="*70)

# Save results dictionary
results = {
    'validation': {'word_accuracy': val_word_acc, 'character_f1': val_char_f1},
    'test': {'word_accuracy': test_word_acc, 'character_f1': test_char_f1},
    'predictions': {'val': val_preds, 'test': test_preds},
    'references': {'val': val_refs, 'test': test_refs}
}

print(f"\n✅ Evaluation complete!")
print(f"\nSummary:")
print(f"  Validation: {val_word_acc*100:.2f}% accuracy, {val_char_f1:.4f} F1")
print(f"  Test:       {test_word_acc*100:.2f}% accuracy, {test_char_f1:.4f} F1")



Evaluating on Validation Set
Generating predictions...


Decoding: 100%|██████████| 84/84 [04:42<00:00,  3.37s/it]



Computing metrics...

Creating test loader...

Evaluating on Test Set
Generating predictions...


Decoding: 100%|██████████| 79/79 [04:17<00:00,  3.26s/it]


Computing metrics...

EVALUATION RESULTS

| Dataset       | Word Accuracy | Character F1 |
|---------------|---------------|--------------|
| Validation    |        0.2640 |       0.9073 |
| Test          |        0.3507 |       0.9162 |

SAMPLE PREDICTIONS (First 10 from test set)
Source               Reference            Prediction           Match     
----------------------------------------------------------------------
मैट्रोलॉजिस्ट        maitrologist         matrologist          ✗         
पीएचडब्ल्यूसीएस      phwcs                phwcs                ✓         
प्रतिद्वन्दियों      pratidwandiyon       pratidwandiyon       ✓         
प्रतियुक्ति          pratiyukti           pratiyukti           ✓         
एक्सिसटेंस           eksisatens           existens             ✗         
फ़िल्मनिर्माता       filmnirmata          filmanirmata         ✗         
अद्र्घ               adrgh                adrgh                ✓         
लड़ेगे               ladhege              ladege     




In [29]:
# -------------------- Metric Functions --------------------

def edit_distance(s1, s2):
    """Compute Levenshtein edit distance between two strings"""
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2+1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]

def longest_common_subsequence_length(s1, s2):
    """Compute LCS length using formula from NEWS 2015"""
    return 0.5 * (len(s1) + len(s2) - edit_distance(s1, s2))

def compute_precision_recall_f1(pred, ref):
    """
    Compute precision, recall, and F1 for a single prediction
    Returns: (precision, recall, f1)
    """
    if len(pred) == 0 and len(ref) == 0:
        return 1.0, 1.0, 1.0
    elif len(pred) == 0 or len(ref) == 0:
        return 0.0, 0.0, 0.0

    # Compute LCS length
    lcs_len = longest_common_subsequence_length(pred, ref)

    # Precision and Recall
    precision = lcs_len / len(pred)
    recall = lcs_len / len(ref)

    # F1 score
    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    return precision, recall, f1

def compute_word_accuracy(predictions, references):
    """Word Accuracy in Top-1 (ACC)"""
    correct = 0
    for pred, ref in zip(predictions, references):
        if pred == ref:
            correct += 1
    return correct / len(predictions) if predictions else 0.0

def compute_metrics(predictions, references):
    """
    Compute all metrics: accuracy, mean precision, mean recall, mean F1
    Returns: (accuracy, mean_precision, mean_recall, mean_f1, per_example_metrics)
    """
    precisions = []
    recalls = []
    f1_scores = []

    for pred, ref in zip(predictions, references):
        prec, rec, f1 = compute_precision_recall_f1(pred, ref)
        precisions.append(prec)
        recalls.append(rec)
        f1_scores.append(f1)

    accuracy = compute_word_accuracy(predictions, references)
    mean_precision = np.mean(precisions)
    mean_recall = np.mean(recalls)
    mean_f1 = np.mean(f1_scores)

    per_example = list(zip(precisions, recalls, f1_scores))

    return accuracy, mean_precision, mean_recall, mean_f1, per_example

# -------------------- Greedy Decoding --------------------

def greedy_decode(model, src, src_vocab, tgt_vocab, device, max_len=50):
    """Greedy decoding for transliteration"""
    model.eval()

    with torch.no_grad():
        src_mask = (src == src_vocab.pad_idx)
        memory = model.encode(src, src_mask)

        tgt = torch.LongTensor([[tgt_vocab.sos_idx]]).to(device)

        for _ in range(max_len):
            tgt_mask = (tgt == tgt_vocab.pad_idx)
            out = model.decode(tgt, memory, src_mask, tgt_mask)
            logits = out[:, -1, :]
            next_token = logits.argmax(dim=-1, keepdim=True)

            if next_token.item() == tgt_vocab.eos_idx:
                break

            tgt = torch.cat([tgt, next_token], dim=1)

    tokens = tgt.squeeze(0).cpu().tolist()
    chars = [tgt_vocab.idx2token[idx] for idx in tokens
             if idx not in [tgt_vocab.pad_idx, tgt_vocab.sos_idx, tgt_vocab.eos_idx]]

    return ''.join(chars)

# -------------------- Beam Search Decoding --------------------

def beam_search_decode(model, src, src_vocab, tgt_vocab, device, beam_width=5, max_len=50):
    """
    Beam search decoding for transliteration
    Returns the best transliteration found
    """
    model.eval()

    with torch.no_grad():
        src_mask = (src == src_vocab.pad_idx)
        memory = model.encode(src, src_mask)

        # Initialize beam with SOS token
        # Each beam element: (score, sequence)
        beams = [(0.0, [tgt_vocab.sos_idx])]
        completed = []

        for step in range(max_len):
            candidates = []

            for score, seq in beams:
                # Skip if sequence already ended
                if seq[-1] == tgt_vocab.eos_idx:
                    completed.append((score, seq))
                    continue

                # Prepare input
                tgt = torch.LongTensor([seq]).to(device)
                tgt_mask = (tgt == tgt_vocab.pad_idx)

                # Decode
                out = model.decode(tgt, memory, src_mask, tgt_mask)
                logits = out[:, -1, :]
                log_probs = torch.log_softmax(logits, dim=-1)

                # Get top-k tokens
                top_log_probs, top_indices = torch.topk(log_probs, beam_width)

                # Add to candidates
                for log_prob, token_idx in zip(top_log_probs[0], top_indices[0]):
                    new_score = score + log_prob.item()
                    new_seq = seq + [token_idx.item()]
                    candidates.append((new_score, new_seq))

            # Select top beam_width candidates
            candidates = sorted(candidates, key=lambda x: x[0], reverse=True)
            beams = candidates[:beam_width]

            # Check if all beams completed
            if all(seq[-1] == tgt_vocab.eos_idx for _, seq in beams):
                completed.extend(beams)
                break

        # Add remaining beams to completed
        completed.extend(beams)

        # Select best sequence
        if not completed:
            best_seq = [tgt_vocab.sos_idx, tgt_vocab.eos_idx]
        else:
            # Normalize by length to avoid bias towards shorter sequences
            best_score, best_seq = max(completed, key=lambda x: x[0] / len(x[1]))

    # Convert to string
    chars = [tgt_vocab.idx2token[idx] for idx in best_seq
             if idx not in [tgt_vocab.pad_idx, tgt_vocab.sos_idx, tgt_vocab.eos_idx]]

    return ''.join(chars)

# -------------------- Evaluation Functions --------------------

def evaluate_with_decoding(model, dataloader, src_vocab, tgt_vocab, device,
                          decode_fn, dataset_name="Dataset", decode_name="Decoding"):
    """
    Evaluate model with specified decoding strategy
    """
    model.eval()
    predictions = []
    references = []

    print(f"\n{'='*70}")
    print(f"Evaluating on {dataset_name} using {decode_name}")
    print('='*70)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"{decode_name}"):
            src, src_len, tgt_in, tgt_out, tgt_len = batch
            batch_size = src.size(0)

            for i in range(batch_size):
                src_i = src[i:i+1]

                # Decode
                pred = decode_fn(model, src_i, src_vocab, tgt_vocab, device)
                predictions.append(pred)

                # Get reference
                ref_tokens = tgt_out[i].cpu().tolist()
                ref_chars = [tgt_vocab.idx2token[idx] for idx in ref_tokens
                           if idx not in [tgt_vocab.pad_idx, tgt_vocab.eos_idx]]
                ref = ''.join(ref_chars)
                references.append(ref)

    # Compute all metrics
    accuracy, mean_prec, mean_rec, mean_f1, per_example = compute_metrics(predictions, references)

    return accuracy, mean_prec, mean_rec, mean_f1, per_example, predictions, references

# -------------------- Run Evaluations --------------------

print("\n" + "="*70)
print("RUNNING COMPREHENSIVE EVALUATION")
print("="*70)

# Create test loader if needed
if 'test_loader' not in globals():
    print("\nCreating test loader...")
    test_dataset = TransliterationDataset(test)

    def make_collate_fn(src_v, tgt_v, dev):
        def collate(batch):
            return collate_fn(batch, src_v, tgt_v, dev)
        return collate

    test_loader = DataLoader(
        test_dataset,
        batch_size=128,
        shuffle=False,
        collate_fn=make_collate_fn(src_vocab, tgt_vocab, device)
    )

# Evaluate with Greedy Decoding
print("\n[1/2] Greedy Decoding...")
greedy_acc, greedy_prec, greedy_rec, greedy_f1, greedy_per_ex, greedy_preds, greedy_refs = \
    evaluate_with_decoding(model, test_loader, src_vocab, tgt_vocab, device,
                          greedy_decode, "Test Set", "Greedy Decoding")

# Evaluate with Beam Search
print("\n[2/2] Beam Search Decoding...")
beam_acc, beam_prec, beam_rec, beam_f1, beam_per_ex, beam_preds, beam_refs = \
    evaluate_with_decoding(model, test_loader, src_vocab, tgt_vocab, device,
                          lambda m, s, sv, tv, d: beam_search_decode(m, s, sv, tv, d, beam_width=5),
                          "Test Set", "Beam Search (width=5)")

# -------------------- Display Results --------------------

print("\n" + "="*70)
print("EVALUATION RESULTS - COMPARISON")
print("="*70)

print("\n### GREEDY DECODING")
print(f"Top-1 Accuracy (ACC): {greedy_acc:.4f} ({greedy_acc*100:.2f}%)")
print(f"Mean Precision:       {greedy_prec:.4f}")
print(f"Mean Recall:          {greedy_rec:.4f}")
print(f"Mean F1 (Fuzziness):  {greedy_f1:.4f}")

print("\n### BEAM SEARCH DECODING (beam_width=5)")
print(f"Top-1 Accuracy (ACC): {beam_acc:.4f} ({beam_acc*100:.2f}%)")
print(f"Mean Precision:       {beam_prec:.4f}")
print(f"Mean Recall:          {beam_rec:.4f}")
print(f"Mean F1 (Fuzziness):  {beam_f1:.4f}")

# Comparison table
print("\n" + "="*70)
print("COMPARISON TABLE")
print("="*70)
print("\n| Metric           | Greedy Decoding | Beam Search (w=5) | Improvement |")
print("|------------------|-----------------|-------------------|-------------|")
print(f"| Top-1 Accuracy   | {greedy_acc:15.4f} | {beam_acc:17.4f} | {(beam_acc-greedy_acc)*100:10.2f}% |")
print(f"| Mean Precision   | {greedy_prec:15.4f} | {beam_prec:17.4f} | {(beam_prec-greedy_prec)*100:10.2f}% |")
print(f"| Mean Recall      | {greedy_rec:15.4f} | {beam_rec:17.4f} | {(beam_rec-greedy_rec)*100:10.2f}% |")
print(f"| Mean F1          | {greedy_f1:15.4f} | {beam_f1:17.4f} | {(beam_f1-greedy_f1)*100:10.2f}% |")
print("="*70)

# -------------------- Analyze Low F1 Examples (BOTH STRATEGIES) --------------------

print("\n" + "="*70)
print("WORDS WITH F1-SCORE < 0.5")
print("="*70)

# Collect low F1 examples for GREEDY
greedy_low_f1 = []
for i, (prec, rec, f1) in enumerate(greedy_per_ex):
    if f1 < 0.5:
        src_word = test[i][0] if i < len(test) else "N/A"
        ref = greedy_refs[i]
        pred = greedy_preds[i]
        greedy_low_f1.append((src_word, ref, pred, f1, prec, rec))

# Collect low F1 examples for BEAM SEARCH
beam_low_f1 = []
for i, (prec, rec, f1) in enumerate(beam_per_ex):
    if f1 < 0.5:
        src_word = test[i][0] if i < len(test) else "N/A"
        ref = beam_refs[i]
        pred = beam_preds[i]
        beam_low_f1.append((src_word, ref, pred, f1, prec, rec))

print(f"\n### GREEDY DECODING")
print(f"Total examples with F1 < 0.5: {len(greedy_low_f1)} / {len(greedy_refs)} ({len(greedy_low_f1)/len(greedy_refs)*100:.2f}%)")

if greedy_low_f1:
    print(f"\nShowing first 20 examples:")
    print(f"\n{'Source':<20} {'Reference':<20} {'Prediction':<20} {'F1':<8} {'Prec':<8} {'Rec':<8}")
    print("-"*100)

    for src, ref, pred, f1, prec, rec in greedy_low_f1[:20]:
        print(f"{src:<20} {ref:<20} {pred:<20} {f1:<8.4f} {prec:<8.4f} {rec:<8.4f}")
else:
    print("\n✅ No examples with F1 < 0.5!")

print(f"\n### BEAM SEARCH DECODING (beam_width=5)")
print(f"Total examples with F1 < 0.5: {len(beam_low_f1)} / {len(beam_refs)} ({len(beam_low_f1)/len(beam_refs)*100:.2f}%)")

if beam_low_f1:
    print(f"\nShowing first 20 examples:")
    print(f"\n{'Source':<20} {'Reference':<20} {'Prediction':<20} {'F1':<8} {'Prec':<8} {'Rec':<8}")
    print("-"*100)

    for src, ref, pred, f1, prec, rec in beam_low_f1[:20]:
        print(f"{src:<20} {ref:<20} {pred:<20} {f1:<8.4f} {prec:<8.4f} {rec:<8.4f}")
else:
    print("\n✅ No examples with F1 < 0.5!")

# -------------------- Compare Low F1 Examples --------------------

print("\n" + "="*70)
print("COMPARISON: Examples where one strategy has F1 < 0.5 but other doesn't")
print("="*70)

# Find examples where greedy fails but beam search succeeds
greedy_fails_beam_succeeds = []
for i, (g_prec, g_rec, g_f1) in enumerate(greedy_per_ex):
    b_prec, b_rec, b_f1 = beam_per_ex[i]
    if g_f1 < 0.5 and b_f1 >= 0.5:
        src_word = test[i][0] if i < len(test) else "N/A"
        ref = greedy_refs[i]
        greedy_fails_beam_succeeds.append((
            src_word, ref, greedy_preds[i], beam_preds[i], g_f1, b_f1
        ))

# Find examples where beam fails but greedy succeeds
beam_fails_greedy_succeeds = []
for i, (g_prec, g_rec, g_f1) in enumerate(greedy_per_ex):
    b_prec, b_rec, b_f1 = beam_per_ex[i]
    if b_f1 < 0.5 and g_f1 >= 0.5:
        src_word = test[i][0] if i < len(test) else "N/A"
        ref = greedy_refs[i]
        beam_fails_greedy_succeeds.append((
            src_word, ref, greedy_preds[i], beam_preds[i], g_f1, b_f1
        ))

print(f"\n### Greedy fails (F1<0.5) but Beam Search succeeds (F1≥0.5)")
print(f"Count: {len(greedy_fails_beam_succeeds)}")

if greedy_fails_beam_succeeds:
    print(f"\nFirst 15 examples:")
    print(f"\n{'Source':<15} {'Reference':<15} {'Greedy Pred':<15} {'Beam Pred':<15} {'G-F1':<8} {'B-F1':<8}")
    print("-"*95)

    for src, ref, g_pred, b_pred, g_f1, b_f1 in greedy_fails_beam_succeeds[:15]:
        print(f"{src:<15} {ref:<15} {g_pred:<15} {b_pred:<15} {g_f1:<8.4f} {b_f1:<8.4f}")

print(f"\n### Beam Search fails (F1<0.5) but Greedy succeeds (F1≥0.5)")
print(f"Count: {len(beam_fails_greedy_succeeds)}")

if beam_fails_greedy_succeeds:
    print(f"\nFirst 15 examples:")
    print(f"\n{'Source':<15} {'Reference':<15} {'Greedy Pred':<15} {'Beam Pred':<15} {'G-F1':<8} {'B-F1':<8}")
    print("-"*95)

    for src, ref, g_pred, b_pred, g_f1, b_f1 in beam_fails_greedy_succeeds[:15]:
        print(f"{src:<15} {ref:<15} {g_pred:<15} {b_pred:<15} {g_f1:<8.4f} {b_f1:<8.4f}")

# -------------------- Both Fail Analysis --------------------

print("\n" + "="*70)
print("WORST CASES: Both strategies fail (F1 < 0.5)")
print("="*70)

both_fail = []
for i, (g_prec, g_rec, g_f1) in enumerate(greedy_per_ex):
    b_prec, b_rec, b_f1 = beam_per_ex[i]
    if g_f1 < 0.5 and b_f1 < 0.5:
        src_word = test[i][0] if i < len(test) else "N/A"
        ref = greedy_refs[i]
        both_fail.append((
            src_word, ref, greedy_preds[i], beam_preds[i], g_f1, b_f1
        ))

print(f"\nTotal: {len(both_fail)} examples where both strategies have F1 < 0.5")

if both_fail:
    # Sort by average F1 to show worst cases first
    both_fail_sorted = sorted(both_fail, key=lambda x: (x[4] + x[5]) / 2)

    print(f"\nWorst 20 examples (sorted by average F1):")
    print(f"\n{'Source':<15} {'Reference':<15} {'Greedy Pred':<15} {'Beam Pred':<15} {'G-F1':<8} {'B-F1':<8}")
    print("-"*95)

    for src, ref, g_pred, b_pred, g_f1, b_f1 in both_fail_sorted[:20]:
        print(f"{src:<15} {ref:<15} {g_pred:<15} {b_pred:<15} {g_f1:<8.4f} {b_f1:<8.4f}")

# Save detailed results
results['low_f1_analysis'] = {
    'greedy_low_f1': greedy_low_f1,
    'beam_low_f1': beam_low_f1,
    'greedy_fails_beam_succeeds': greedy_fails_beam_succeeds,
    'beam_fails_greedy_succeeds': beam_fails_greedy_succeeds,
    'both_fail': both_fail
}

print("\n" + "="*70)
print(f"✅ Low F1 analysis complete!")
print("="*70)




RUNNING COMPREHENSIVE EVALUATION

[1/2] Greedy Decoding...

Evaluating on Test Set using Greedy Decoding


Greedy Decoding: 100%|██████████| 79/79 [04:34<00:00,  3.47s/it]



[2/2] Beam Search Decoding...

Evaluating on Test Set using Beam Search (width=5)


Beam Search (width=5): 100%|██████████| 79/79 [1:10:16<00:00, 53.38s/it]



EVALUATION RESULTS - COMPARISON

### GREEDY DECODING
Top-1 Accuracy (ACC): 0.3507 (35.07%)
Mean Precision:       0.9328
Mean Recall:          0.9085
Mean F1 (Fuzziness):  0.9162

### BEAM SEARCH DECODING (beam_width=5)
Top-1 Accuracy (ACC): 0.3518 (35.18%)
Mean Precision:       0.9285
Mean Recall:          0.9126
Mean F1 (Fuzziness):  0.9161

COMPARISON TABLE

| Metric           | Greedy Decoding | Beam Search (w=5) | Improvement |
|------------------|-----------------|-------------------|-------------|
| Top-1 Accuracy   |          0.3507 |            0.3518 |       0.11% |
| Mean Precision   |          0.9328 |            0.9285 |      -0.43% |
| Mean Recall      |          0.9085 |            0.9126 |       0.41% |
| Mean F1          |          0.9162 |            0.9161 |      -0.01% |

WORDS WITH F1-SCORE < 0.5

### GREEDY DECODING
Total examples with F1 < 0.5: 35 / 10112 (0.35%)

Showing first 20 examples:

Source               Reference            Prediction           F1       