# Deep Past Challenge v2: ByT5-base Two-Stage Training

**Improvements over v1:**
1. ByT5-base (580M) instead of ByT5-small (300M) with gradient checkpointing
2. Sentence-level alignment data from `Sentences_Oare_FirstWord_LinNum.csv`
3. R-Drop regularization (dual forward passes with KL divergence)
4. Bidirectional training in Stage 1 (Akkadian->English + English->Akkadian)
5. Checkpoint averaging for final model
6. Adafactor optimizer (memory-efficient)
7. Label smoothing (0.1), cosine annealing, better hyperparameters

**Metric**: `sqrt(BLEU * chrF++)` via SacreBLEU

In [None]:
# Install third-party dependencies quietly (-q); sacrebleu and transformers
# are imported in the next cell.
!pip install -q sacrebleu transformers accelerate sentencepiece

In [None]:
import warnings
warnings.simplefilter("ignore")

import os
import re
import gc
import copy
import math
import json
import glob
import unicodedata
import collections
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
)
from transformers.optimization import Adafactor, AdafactorSchedule
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import sacrebleu

# Pick the compute device once; the rest of the notebook reads DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device: {DEVICE}")

if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    # Use `total_mem` when the properties object exposes it, otherwise
    # fall back to `total_memory`.
    if hasattr(props, 'total_mem'):
        mem_bytes = props.total_mem
    else:
        mem_bytes = props.total_memory
    print(f"Memory: {mem_bytes / 1e9:.1f} GB")

## Configuration

In [None]:
# ============================================================
# Paths - Kaggle layout when /kaggle/input exists, local otherwise
# ============================================================
IS_KAGGLE = os.path.exists("/kaggle/input")

# Order of the tuple: competition data, Akkademia, ORACC, output directory.
COMP_DATA, AKKADEMIA_PATH, ORACC_PATH, OUTPUT_DIR = (
    (
        "/kaggle/input/deep-past-initiative-machine-translation",
        "/kaggle/input/akkademia-nmt",
        "/kaggle/input/oracc-akkadian-english-parallel-corpus",
        "/kaggle/working",
    )
    if IS_KAGGLE
    else (
        "data",
        "external_data/akkademia/NMT_input",
        "external_data/oracc",
        "output",
    )
)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================
# Model & Training Config
# ============================================================
MODEL_NAME = "google/byt5-base"  # 580M params, byte-level

# Which data sources are loaded
USE_AKKADEMIA = True
USE_ORACC = True
USE_COMP_DATA = True
USE_SENTENCE_ALIGN = True  # sentence-level alignment data
USE_BIDIRECTIONAL = True   # reverse pairs for Stage 1

# Sequence limits and task prefixes shared by both stages
MAX_SOURCE_LEN = 384
MAX_TARGET_LEN = 384
PREFIX_AK2EN = "translate Akkadian to English: "
PREFIX_EN2AK = "translate English to Akkadian: "

# Stage 1: General Akkadian (batch 2 x 16 accumulation steps = 32 effective)
STAGE1_EPOCHS = 5
STAGE1_LR = 1e-4
STAGE1_BATCH = 2
STAGE1_GRAD_ACC = 16
STAGE1_WARMUP = 500

# Stage 2: Old Assyrian specialization (batch 2 x 16 = 32 effective)
STAGE2_EPOCHS = 15
STAGE2_LR = 5e-5
STAGE2_BATCH = 2
STAGE2_GRAD_ACC = 16
STAGE2_WARMUP = 100

# Regularization
LABEL_SMOOTHING = 0.1
WEIGHT_DECAY = 0.01
RDROP_ALPHA = 1.0  # weight of the KL term in R-Drop

# Early stopping: epochs without improvement before training stops
PATIENCE = 5

# Inference (beam search settings)
BEAM_WIDTH = 4
REP_PENALTY = 1.2

# Checkpoint averaging: how many of the best checkpoints are averaged
NUM_CKPTS_TO_AVG = 3

# Seed both torch and numpy RNGs
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Configuration banner, emitted as a single write
print("\n".join([
    "Configuration set.",
    f"  Model: {MODEL_NAME}",
    f"  Stage 1: {STAGE1_EPOCHS} epochs, lr={STAGE1_LR}, batch={STAGE1_BATCH}x{STAGE1_GRAD_ACC}",
    f"  Stage 2: {STAGE2_EPOCHS} epochs, lr={STAGE2_LR}, batch={STAGE2_BATCH}x{STAGE2_GRAD_ACC}",
    f"  R-Drop alpha: {RDROP_ALPHA}",
    f"  Label smoothing: {LABEL_SMOOTHING}",
    f"  Bidirectional: {USE_BIDIRECTIONAL}",
    f"  Checkpoint averaging: top {NUM_CKPTS_TO_AVG}",
]))

## Preprocessing

In [None]:
# Subscript digits (U+2080..U+2089) mapped to plain ASCII digits.
SUBSCRIPT_MAP = str.maketrans(
    "\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089",
    "0123456789"
)

# ASCII spellings of Akkadian consonants mapped to their diacritic form.
# Entries are applied in insertion order by normalize_ascii().
ASCII_TO_DIACRITIC = {
    "sz": "\u0161", "SZ": "\u0160", "Sz": "\u0160",
    "sh": "\u0161", "SH": "\u0160", "Sh": "\u0160",
    "s,": "\u1E63", "S,": "\u1E62",
    "t,": "\u1E6D", "T,": "\u1E6C",
    ".s": "\u1E63", ".S": "\u1E62",
    ".t": "\u1E6D", ".T": "\u1E6C",
    "h,": "\u1E2B", "H,": "\u1E2A",
    ".h": "\u1E2B", ".H": "\u1E2A",
}


def normalize_ascii(text):
    """Replace ASCII digraphs (e.g. ``sz``, ``s,``, ``.t``) with diacritics.

    Every key of ASCII_TO_DIACRITIC is substituted with plain
    ``str.replace``, in the dictionary's insertion order.
    """
    for old, new in ASCII_TO_DIACRITIC.items():
        text = text.replace(old, new)
    return text


def normalize_gaps(text):
    """Rewrite damage markers as ``<gap>`` / ``<big_gap>`` tokens.

    ``[x]`` becomes ``<gap>``; a bracket starting with three or more dots,
    a bare run of three or more dots, and the ellipsis character all
    become ``<big_gap>``. The bracketed patterns only match while the
    square brackets are still present in the text.
    """
    text = re.sub(r'\[x\]', '<gap>', text)
    text = re.sub(r'\[\.{3,}[^\]]*\]', '<big_gap>', text)
    text = re.sub(r'\.{3,}', '<big_gap>', text)
    text = re.sub(r'\u2026', '<big_gap>', text)  # ellipsis character
    return text


def clean_akkadian(text):
    """Clean and normalize an Akkadian transliteration string.

    Args:
        text: Raw transliteration; NaN/None or blank input is accepted.

    Returns:
        The cleaned string, or ``""`` when the input is missing or blank.
    """
    if pd.isna(text) or not str(text).strip():
        return ""
    text = str(text)
    text = unicodedata.normalize("NFC", text)
    # Remove editorial marks
    text = text.replace("!", "").replace("?", "")
    text = re.sub(r'[\u02F9\u02FA]', '', text)  # half-brackets
    # Gap markers must be resolved BEFORE brackets are stripped and before
    # ASCII normalization: otherwise "[x]" degrades to a literal "x" sign
    # and a sequence such as "...s" is misread as the ".s" digraph.
    text = normalize_gaps(text)
    # Remove the remaining square brackets but keep their content
    text = re.sub(r'\[([^\]]*)\]', r'\1', text)
    text = normalize_ascii(text)
    text = text.translate(SUBSCRIPT_MAP)
    # Replace slashes, colons and dots not followed by a digit with a space
    text = re.sub(r'[/:.](?![\d])', ' ', text)
    # Collapse whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def clean_english(text):
    """Return *text* as a string with whitespace runs collapsed to one space.

    Missing values (NaN/None) and whitespace-only input yield ``""``.
    """
    if pd.isna(text):
        return ""
    # A whitespace-only string collapses to " " and strips down to "".
    collapsed = re.sub(r'\s+', ' ', str(text))
    return collapsed.strip()


print("Preprocessing functions defined.")

# Smoke test: show one raw transliteration next to its cleaned form.
sample = "KIŠIB ma-nu-ba-lúm-a-šur DUMU ṣí-lá-(d)IM [x] ... i-ša-qal!"
for label, shown in (("Original: ", sample), ("Cleaned:  ", clean_akkadian(sample))):
    print(f"{label}{shown}")

## Load & Prepare Data

In [None]:
# ============================================================
# 1. Competition data (Old Assyrian - used in both stages)
# ============================================================
old_assyrian_data = []
comp_df = None

if USE_COMP_DATA:
    comp_df = pd.read_csv(os.path.join(COMP_DATA, "train.csv"))
    print(f"Competition data: {len(comp_df)} documents")

    # Walk the two text columns in lockstep; a pair is kept only when
    # both cleaned sides are longer than 10 characters.
    for raw_src, raw_tgt in zip(comp_df['transliteration'], comp_df['translation']):
        src = clean_akkadian(raw_src)
        tgt = clean_english(raw_tgt)
        if min(len(src), len(tgt)) > 10:
            old_assyrian_data.append({'source': src, 'target': tgt})

    print(f"  After cleaning: {len(old_assyrian_data)} document-level pairs")

# ============================================================
# 2. Akkademia data (mixed periods - Stage 1 only)
# ============================================================
# Each split is a pair of line-aligned text files: <split>.tr holds the
# transliterations and <split>.en the English side.
general_akkadian_data = []

if not USE_AKKADEMIA:
    print("Akkademia data disabled, skipping.")
elif not os.path.exists(AKKADEMIA_PATH):
    print("Akkademia data not found, skipping.")
else:
    for split in ['train', 'valid', 'test']:
        tr_path = os.path.join(AKKADEMIA_PATH, f'{split}.tr')
        en_path = os.path.join(AKKADEMIA_PATH, f'{split}.en')
        if not (os.path.exists(tr_path) and os.path.exists(en_path)):
            continue

        # Context managers close the handles instead of leaking them.
        with open(tr_path, encoding='utf-8') as tr_file:
            translits = tr_file.read().splitlines()
        with open(en_path, encoding='utf-8') as en_file:
            transls = en_file.read().splitlines()

        # zip() below stops at the shorter file; surface a mismatch rather
        # than silently dropping the tail of the longer file.
        if len(translits) != len(transls):
            print(f"  Warning: {split}.tr has {len(translits)} lines but "
                  f"{split}.en has {len(transls)}; extra lines are ignored.")

        for tr, en in zip(translits, transls):
            src = clean_akkadian(tr)
            tgt = clean_english(en)
            if src and tgt and len(src) > 5 and len(tgt) > 5:
                general_akkadian_data.append({'source': src, 'target': tgt})
    print(f"Akkademia data: {len(general_akkadian_data)} pairs")

# ============================================================
# 3. ORACC parallel corpus
# ============================================================
# Pairs are appended to the same general-purpose list as Akkademia.
if USE_ORACC:
    oracc_csv = os.path.join(ORACC_PATH, "train.csv")
    if not os.path.exists(oracc_csv):
        print("ORACC data not found, skipping.")
    else:
        oracc_df = pd.read_csv(oracc_csv)
        print(f"ORACC data: {len(oracc_df)} pairs")
        for _, oracc_row in oracc_df.iterrows():
            # .get() keeps a missing column from raising; it cleans to "".
            akk = clean_akkadian(oracc_row.get('akkadian', ''))
            eng = clean_english(oracc_row.get('english', ''))
            if min(len(akk), len(eng)) > 5:
                general_akkadian_data.append({'source': akk, 'target': eng})
        print(f"  Total general data after ORACC: {len(general_akkadian_data)} pairs")

print("\n".join([
    f"\n=== Data Summary (before sentence alignment) ===",
    f"Old Assyrian (competition): {len(old_assyrian_data)} doc-level pairs",
    f"General Akkadian (external): {len(general_akkadian_data)} pairs",
]))

## Sentence Alignment

In [None]:
# ============================================================
# 4. Sentence-level alignment data
# ============================================================
# Sentences_Oare_FirstWord_LinNum.csv holds sentence-level English
# translations; its text_uuid column is matched against oare_id in
# train.csv. For each matching document, the Akkadian segment of a
# sentence is cut out of the whitespace-split transliteration using
# first_word_number, treated as the 1-indexed position of the sentence's
# first word (NOTE(review): confirm this matches the OARE word count).
# ============================================================

sentence_pairs = []

if USE_SENTENCE_ALIGN and comp_df is not None:
    sent_file = os.path.join(COMP_DATA, "Sentences_Oare_FirstWord_LinNum.csv")
    if os.path.exists(sent_file):
        sent_df = pd.read_csv(sent_file)
        print(f"Sentence alignment file: {len(sent_df)} rows, {sent_df['text_uuid'].nunique()} unique documents")

        # Lookup: oare_id -> raw document transliteration
        doc_translit = {
            oare_id: str(raw_translit)
            for oare_id, raw_translit in zip(comp_df['oare_id'], comp_df['transliteration'])
        }

        # Documents present in both files
        overlap_ids = set(comp_df['oare_id']) & set(sent_df['text_uuid'])
        print(f"Overlapping documents: {len(overlap_ids)}")

        aligned_count = 0

        # Group the sentence rows once instead of scanning sent_df for every
        # document. sort=True visits documents in sorted id order, so the
        # order of sentence_pairs (and every seeded split built from it) is
        # reproducible; iterating the set directly depends on string hashing.
        overlapping_sents = sent_df[sent_df['text_uuid'].isin(list(overlap_ids))]

        for doc_id, doc_rows in overlapping_sents.groupby('text_uuid', sort=True):
            translit = doc_translit.get(doc_id, "")
            if not translit:
                continue

            # Sentences in document order
            doc_sents = doc_rows.sort_values('sentence_obj_in_text')
            if len(doc_sents) == 0:
                continue

            # Collect translation and first-word position for each sentence
            sent_list = []
            for _, srow in doc_sents.iterrows():
                eng = clean_english(srow.get('translation', ''))
                fw_spelling = str(srow.get('first_word_spelling', '')) if pd.notna(srow.get('first_word_spelling')) else ''
                fw_number = srow.get('first_word_number', -1)
                if pd.isna(fw_number):
                    fw_number = -1  # sentinel: position unknown
                sent_list.append({
                    'translation': eng,
                    'first_word_spelling': fw_spelling,
                    'first_word_number': int(fw_number),
                })

            # Split the document transliteration into words
            words = translit.split()

            for i, s in enumerate(sent_list):
                eng = s['translation']
                if not eng or len(eng) < 5:
                    continue

                fw_num = s['first_word_number']
                if fw_num < 1:
                    continue

                # 1-indexed word position -> 0-indexed slice start
                start_idx = fw_num - 1

                # The segment ends where the next sentence starts; when the
                # next start is unknown (or this is the last sentence) it
                # runs to the end of the document.
                if i + 1 < len(sent_list) and sent_list[i + 1]['first_word_number'] > 0:
                    end_idx = sent_list[i + 1]['first_word_number'] - 1
                else:
                    end_idx = len(words)

                if start_idx >= len(words):
                    continue
                end_idx = min(end_idx, len(words))

                # An empty slice (end <= start) cleans to "" and is dropped
                # by the length check below.
                akk_segment = ' '.join(words[start_idx:end_idx])
                akk_clean = clean_akkadian(akk_segment)

                if akk_clean and len(akk_clean) > 3:
                    sentence_pairs.append({
                        'source': akk_clean,
                        'target': eng,
                    })
                    aligned_count += 1

        print(f"Sentence-aligned pairs extracted: {aligned_count}")

        # Sentences of documents that are NOT in train.csv have English text
        # but no Akkadian source here (only the first_word_spelling
        # fragment), so they are skipped.

    else:
        print("Sentence alignment file not found, skipping.")
else:
    print("Sentence alignment disabled or no competition data.")

print(f"\nSentence-level pairs: {len(sentence_pairs)}")

# Show samples
if sentence_pairs:
    print("\n--- Sample sentence-aligned pairs ---")
    for p in sentence_pairs[:3]:
        print(f"  AKK: {p['source'][:120]}")
        print(f"  ENG: {p['target'][:120]}")
        print()

In [None]:
# ============================================================
# 5. Split document-level data into sentence-level pairs
# ============================================================
# Placeholder: documents without sentence-level alignment are inspected
# but no pairs are produced, because splitting without line numbers is
# too noisy. doc_split_pairs therefore stays empty.
# ============================================================

doc_split_pairs = []

if comp_df is not None:
    # Documents covered by the sentence file (when it was loaded)
    if USE_SENTENCE_ALIGN and 'sent_df' in dir():
        overlap_ids = set(sent_df['text_uuid'])
    else:
        overlap_ids = set()

    for _, row in comp_df.iterrows():
        translit = str(row['transliteration'])
        translation = str(row['translation'])

        already_aligned = row['oare_id'] in overlap_ids
        # Fewer than 20 words: already essentially sentence-level
        too_short = len(translit.split()) < 20
        if already_aligned or too_short:
            continue

        # Candidate English sentences: split after ., ! or ? when the next
        # chunk starts with a capital letter
        eng_sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', translation)

        if len(eng_sentences) >= 2:
            # Intentionally nothing: no reliable way to segment the
            # Akkadian side to match.
            pass

    print(f"Document-split pairs: {len(doc_split_pairs)}")

total_pairs = len(old_assyrian_data) + len(sentence_pairs) + len(general_akkadian_data)
print(f"\n=== Final Data Summary ===")
print(f"Old Assyrian doc-level: {len(old_assyrian_data)} pairs")
print(f"Sentence-aligned: {len(sentence_pairs)} pairs")
print(f"General Akkadian: {len(general_akkadian_data)} pairs")
print(f"Total: {total_pairs} pairs")

In [None]:
# ============================================================
# 6. Create bidirectional pairs for Stage 1
# ============================================================

bidirectional_data = []

if USE_BIDIRECTIONAL:
    # Reverse (English -> Akkadian) copies: general data first, then the
    # competition data. Sentence-level pairs are not reversed.
    for forward_pairs in (general_akkadian_data, old_assyrian_data):
        bidirectional_data.extend(
            {
                'source': fwd_pair['target'],  # English as source
                'target': fwd_pair['source'],  # Akkadian as target
                'direction': 'en2ak',
            }
            for fwd_pair in forward_pairs
        )
    print(f"Bidirectional reverse pairs: {len(bidirectional_data)}")

# Tag every forward pair in place with its direction
for forward_pairs in (general_akkadian_data, old_assyrian_data, sentence_pairs):
    for fwd_pair in forward_pairs:
        fwd_pair['direction'] = 'ak2en'

In [None]:
# ============================================================
# Prepare DataFrames for each stage
# ============================================================

# Stage 2: Old Assyrian (document-level + sentence-level).
# Built first so its validation split can be held out of Stage 1 below.
stage2_all_list = old_assyrian_data + sentence_pairs
stage2_all = pd.DataFrame(stage2_all_list)
stage2_all = stage2_all.drop_duplicates(subset=['source', 'target']).reset_index(drop=True)
stage2_train, stage2_val = train_test_split(stage2_all, test_size=0.1, random_state=SEED)

# Stage 1: ALL data combined (general + competition + bidirectional)
stage1_all_list = (
    general_akkadian_data
    + old_assyrian_data
    + sentence_pairs
    + bidirectional_data
)
stage1_all = pd.DataFrame(stage1_all_list)
stage1_all = stage1_all.drop_duplicates(subset=['source', 'target']).reset_index(drop=True)

# Leakage guard: the Old Assyrian pairs are part of both stages, so the
# Stage 2 validation pairs (and their reversed en2ak copies) must not be
# trained on in Stage 1. Otherwise the Stage 2 validation scores that drive
# early stopping and checkpoint selection are measured on seen examples.
held_out_pairs = set(zip(stage2_val['source'], stage2_val['target']))
held_out_pairs |= {(tgt_text, src_text) for src_text, tgt_text in held_out_pairs}
leak_mask = np.fromiter(
    (pair in held_out_pairs for pair in zip(stage1_all['source'], stage1_all['target'])),
    dtype=bool,
    count=len(stage1_all),
)
print(f"Stage 1 - removed {int(leak_mask.sum())} pairs that overlap the Stage 2 validation split")
stage1_all = stage1_all[~leak_mask].reset_index(drop=True)

stage1_train, stage1_val = train_test_split(stage1_all, test_size=0.05, random_state=SEED)
print(f"Stage 1 - Train: {len(stage1_train)}, Val: {len(stage1_val)}")
if 'direction' in stage1_train.columns:
    print(f"  ak2en: {(stage1_train['direction'] == 'ak2en').sum()}, en2ak: {(stage1_train['direction'] == 'en2ak').sum()}")

print(f"Stage 2 - Train: {len(stage2_train)}, Val: {len(stage2_val)}")

# Show samples (each guarded so an empty selection cannot raise)
if 'direction' in stage1_train.columns:
    fwd_rows = stage1_train[stage1_train['direction'] == 'ak2en']
else:
    # Without a direction column every row is a forward pair
    fwd_rows = stage1_train
if len(fwd_rows) > 0:
    fwd = fwd_rows.iloc[0]
    print(f"\n--- Sample from Stage 1 (forward) ---")
    print(f"SRC: {fwd['source'][:200]}")
    print(f"TGT: {fwd['target'][:200]}")

if USE_BIDIRECTIONAL and 'direction' in stage1_train.columns:
    rev = stage1_train[stage1_train['direction'] == 'en2ak']
    if len(rev) > 0:
        rev = rev.iloc[0]
        print(f"\n--- Sample from Stage 1 (reverse) ---")
        print(f"SRC: {rev['source'][:200]}")
        print(f"TGT: {rev['target'][:200]}")

if len(stage2_train) > 0:
    print(f"\n--- Sample from Stage 2 ---")
    row = stage2_train.iloc[0]
    print(f"SRC: {row['source'][:200]}")
    print(f"TGT: {row['target'][:200]}")

## Dataset & Model Setup

In [None]:
class AkkadianDataset(Dataset):
    """Translation pairs tokenized on access, with a direction-dependent prefix.

    Each item is a dict with ``input_ids`` and ``attention_mask`` for the
    prefixed source and ``labels`` for the target. Both sides are padded or
    truncated to their maximum length, and padding positions in ``labels``
    are set to -100.
    """

    def __init__(self, df, tokenizer, max_source_len, max_target_len,
                 prefix_ak2en=PREFIX_AK2EN, prefix_en2ak=PREFIX_EN2AK):
        # Without a 'direction' column every row is treated as ak2en.
        self.has_direction = 'direction' in df.columns
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        self.prefix_ak2en = prefix_ak2en
        self.prefix_en2ak = prefix_en2ak

    def __len__(self):
        return len(self.df)

    def _encode(self, text, max_length):
        """Tokenize *text* to exactly *max_length* positions (pad/truncate)."""
        return self.tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Reverse pairs get the English->Akkadian prefix, all others the
        # Akkadian->English one.
        is_reverse = self.has_direction and row.get('direction', 'ak2en') == 'en2ak'
        prefix = self.prefix_en2ak if is_reverse else self.prefix_ak2en

        src = self._encode(prefix + row['source'], self.max_source_len)
        tgt = self._encode(row['target'], self.max_target_len)

        # Replace padding ids by -100 in place
        labels = tgt['input_ids'].squeeze()
        labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)

        return {
            'input_ids': src['input_ids'].squeeze(),
            'attention_mask': src['attention_mask'].squeeze(),
            'labels': labels
        }

In [None]:
# Load the pretrained tokenizer and model
print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Gradient checkpointing is switched on so ByT5-base fits on a P100 16GB
model.gradient_checkpointing_enable()
print("Gradient checkpointing enabled.")

model = model.to(DEVICE)

# Count parameters in a single pass over the model
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

if DEVICE == "cuda":
    mem_alloc = torch.cuda.memory_allocated() / 1e9
    mem_reserved = torch.cuda.memory_reserved() / 1e9
    print("\n".join([
        f"GPU memory allocated: {mem_alloc:.2f} GB",
        f"GPU memory reserved: {mem_reserved:.2f} GB",
    ]))

## Evaluation Functions

In [None]:
def evaluate(model, tokenizer, val_df, batch_size=8, max_source_len=MAX_SOURCE_LEN,
             max_target_len=MAX_TARGET_LEN, beam_width=BEAM_WIDTH,
             rep_penalty=REP_PENALTY):
    """Score the model on the Akkadian->English rows of a validation frame.

    Rows whose ``direction`` is not ``'ak2en'`` are ignored; a frame without
    a ``direction`` column is evaluated in full. Generation uses beam search.

    The model is put in eval mode for generation and afterwards returned to
    whatever mode it was in on entry (also when generation raises), so
    evaluating an inference-ready model does not re-enable dropout.

    Args:
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer matching *model*.
        val_df: DataFrame with ``source`` and ``target`` columns.
        batch_size: Number of rows generated per batch.
        max_source_len: Truncation length for the prefixed source.
        max_target_len: ``max_new_tokens`` passed to ``generate``.
        beam_width: Number of beams.
        rep_penalty: Repetition penalty passed to ``generate``.

    Returns:
        ``(bleu, chrf, combined, predictions)`` where ``combined`` is
        ``sqrt(BLEU * chrF++)`` with both factors floored at 0.001. When
        there are no ak2en rows the result is ``(0.0, 0.0, 0.0, [])``.
    """
    # Filter to ak2en only for evaluation
    if 'direction' in val_df.columns:
        eval_df = val_df[val_df['direction'] == 'ak2en'].reset_index(drop=True)
    else:
        eval_df = val_df.reset_index(drop=True)

    if len(eval_df) == 0:
        print("  Warning: no ak2en pairs in validation set.")
        return 0.0, 0.0, 0.0, []

    predictions = []
    references = []

    was_training = model.training
    model.eval()
    try:
        for start in range(0, len(eval_df), batch_size):
            batch = eval_df.iloc[start:start + batch_size]
            batch_texts = [PREFIX_AK2EN + source for source in batch['source']]

            inputs = tokenizer(
                batch_texts,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=max_source_len
            ).to(DEVICE)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_target_len,
                    num_beams=beam_width,
                    repetition_penalty=rep_penalty,
                    length_penalty=1.0,
                )

            predictions.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
            references.extend(batch['target'].tolist())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    # Compute metrics (word_order=2 turns chrF into chrF++)
    bleu = sacrebleu.corpus_bleu(predictions, [references]).score
    chrf = sacrebleu.corpus_chrf(predictions, [references], word_order=2).score

    # Competition metric: geometric mean, floored to avoid a zero product
    combined = math.sqrt(max(bleu, 0.001) * max(chrf, 0.001))

    return bleu, chrf, combined, predictions


print("Evaluation functions defined.")

## Training Loop with R-Drop

In [None]:
def compute_kl_divergence(logits_p, logits_q, labels, pad_id=-100):
    """Return the symmetric KL divergence between two logit tensors.

    The per-position KL is summed over the last (vocabulary) dimension and
    averaged over the positions whose label differs from ``pad_id``. The
    result is the mean of KL(p || q) and KL(q || p).
    """
    # 1.0 where the label is a real token, 0.0 on padding
    valid = (labels != pad_id).float()
    n_valid = valid.sum().clamp(min=1)  # avoid dividing by zero

    log_p = F.log_softmax(logits_p, dim=-1)
    log_q = F.log_softmax(logits_q, dim=-1)

    def _masked_kl(log_input, log_target):
        # F.kl_div(input, target) = target * (log(target) - input), elementwise
        per_position = F.kl_div(log_input, log_target.exp(), reduction='none').sum(dim=-1)
        return (per_position * valid).sum() / n_valid

    kl_p_to_q = _masked_kl(log_q, log_p)  # KL(p || q)
    kl_q_to_p = _masked_kl(log_p, log_q)  # KL(q || p)

    return (kl_p_to_q + kl_q_to_p) / 2


def train_stage(model, tokenizer, train_df, val_df, epochs, lr, batch_size,
                grad_acc, save_dir, stage_name, warmup_steps=0,
                use_rdrop=False, rdrop_alpha=RDROP_ALPHA,
                save_all_checkpoints=False):
    """Train one stage with Adafactor, warmup + cosine decay and early stopping.

    The training loss is label-smoothed cross-entropy (LABEL_SMOOTHING)
    computed from the model's logits; with ``use_rdrop`` a symmetric KL term
    between two forward passes is added. After every epoch the model is
    evaluated with ``evaluate`` and the best-scoring weights are saved to
    ``<save_dir>/best``.

    Args:
        model: Model to train (updated in place during training).
        tokenizer: Tokenizer matching *model*; saved next to each checkpoint.
        train_df: Training pairs (``source``/``target``, optional ``direction``).
        val_df: Validation pairs passed to ``evaluate``.
        epochs: Maximum number of epochs.
        lr: Peak learning rate.
        batch_size: Per-step batch size.
        grad_acc: Number of batches accumulated per optimizer step.
        save_dir: Directory for checkpoints (created if missing).
        stage_name: Label printed in the banner.
        warmup_steps: Optimizer steps of linear warmup.
        use_rdrop: Enable R-Drop (two forward passes per batch).
        rdrop_alpha: Weight of the R-Drop KL term.
        save_all_checkpoints: If True, save a checkpoint every epoch (for Stage 2 averaging).

    Returns:
        model: Best model reloaded from ``<save_dir>/best``; the in-memory
            model if no checkpoint was written.
        best_score: Best combined metric score (0.0 if none was recorded).
        checkpoint_scores: List of (epoch, score, path) for all saved checkpoints.
    """
    print(f"\n{'=' * 60}")
    print(f"  {stage_name}")
    print(f"  Train: {len(train_df)} | Val: {len(val_df)}")
    print(f"  Epochs: {epochs} | LR: {lr} | Batch: {batch_size} x {grad_acc}")
    print(f"  R-Drop: {use_rdrop} (alpha={rdrop_alpha})")
    print(f"  Warmup: {warmup_steps} steps")
    print(f"{'=' * 60}\n")

    os.makedirs(save_dir, exist_ok=True)
    best_path = os.path.join(save_dir, "best")

    train_ds = AkkadianDataset(train_df, tokenizer, MAX_SOURCE_LEN, MAX_TARGET_LEN)
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=False
    )

    # Adafactor with a fixed external learning rate (relative_step disabled)
    optimizer = Adafactor(
        model.parameters(),
        lr=lr,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
        weight_decay=WEIGHT_DECAY,
    )

    # One scheduler step per optimizer step (i.e. per accumulation group)
    num_training_steps = math.ceil(len(train_loader) / grad_acc) * epochs

    def lr_lambda(current_step):
        # Linear warmup, then cosine decay floored at 1% of the peak LR
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(
            max(1, num_training_steps - warmup_steps)
        )
        return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Cross-entropy loss with label smoothing. The loss returned by the model
    # itself is plain cross-entropy, so the smoothed loss has to be computed
    # here from the logits for LABEL_SMOOTHING to take effect.
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=LABEL_SMOOTHING)

    def smoothed_ce(logits, target_ids):
        """Label-smoothed cross-entropy over all non-ignored positions."""
        return loss_fn(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))

    # -inf so the first finite score always produces a "best" checkpoint,
    # even when every validation score is 0.0.
    best_score = float("-inf")
    patience_counter = 0
    checkpoint_scores = []  # (epoch, score, path)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_ce_loss = 0.0
        epoch_kl_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for step, batch in enumerate(pbar):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            # labels are still passed to the model so it derives the decoder
            # inputs from them; only its logits are used for the loss.
            if use_rdrop:
                # R-Drop: two forward passes with different dropout masks
                outputs1 = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                outputs2 = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                # Average CE loss from both passes
                ce_loss = (
                    smoothed_ce(outputs1.logits, labels)
                    + smoothed_ce(outputs2.logits, labels)
                ) / 2

                # KL divergence between the two output distributions
                kl_loss = compute_kl_divergence(
                    outputs1.logits, outputs2.logits, labels
                )

                loss = ce_loss + rdrop_alpha * kl_loss
                epoch_kl_loss += kl_loss.item()
            else:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                ce_loss = smoothed_ce(outputs.logits, labels)
                loss = ce_loss

            # Scale so the accumulated gradient is the mean over the group
            scaled_loss = loss / grad_acc
            scaled_loss.backward()

            # Step after every full group and after the last batch of the epoch
            if (step + 1) % grad_acc == 0 or (step + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            epoch_ce_loss += ce_loss.item()

            postfix = {'loss': f"{epoch_loss / (step + 1):.4f}"}
            if use_rdrop:
                postfix['kl'] = f"{epoch_kl_loss / (step + 1):.4f}"
            postfix['lr'] = f"{scheduler.get_last_lr()[0]:.2e}"
            pbar.set_postfix(postfix)

        avg_loss = epoch_loss / len(train_loader)
        avg_ce = epoch_ce_loss / len(train_loader)

        # Free GPU memory before evaluation
        if DEVICE == 'cuda':
            torch.cuda.empty_cache()

        # Evaluate
        print(f"\nEvaluating...")
        bleu, chrf, combined, _ = evaluate(model, tokenizer, val_df)
        log_msg = f"Epoch {epoch + 1}: Loss={avg_loss:.4f} CE={avg_ce:.4f}"
        if use_rdrop:
            log_msg += f" KL={epoch_kl_loss / len(train_loader):.4f}"
        log_msg += f" | BLEU={bleu:.2f} | chrF++={chrf:.2f} | Combined={combined:.2f}"
        print(log_msg)

        # Save checkpoint for this epoch (if enabled)
        if save_all_checkpoints:
            ckpt_path = os.path.join(save_dir, f"epoch_{epoch + 1}")
            model.save_pretrained(ckpt_path)
            tokenizer.save_pretrained(ckpt_path)
            checkpoint_scores.append((epoch + 1, combined, ckpt_path))
            print(f"  Checkpoint saved to {ckpt_path}")

        # Early stopping / best model tracking
        if combined > best_score:
            best_score = combined
            patience_counter = 0
            model.save_pretrained(best_path)
            tokenizer.save_pretrained(best_path)
            print(f"  >>> New best! Saved to {best_path} (score={combined:.2f})")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{PATIENCE})")
            if patience_counter >= PATIENCE:
                print(f"  Early stopping triggered.")
                break

    # No epoch produced a checkpoint (e.g. epochs == 0): nothing to reload.
    if best_score == float("-inf"):
        print(f"\nNo best checkpoint was saved; returning the in-memory model.")
        return model, 0.0, checkpoint_scores

    # Reload best model
    print(f"\nLoading best model from {best_path} (score={best_score:.2f})")
    model = T5ForConditionalGeneration.from_pretrained(best_path).to(DEVICE)
    model.gradient_checkpointing_enable()

    return model, best_score, checkpoint_scores


print("Training function defined.")

## Stage 1: General Akkadian Training (with Bidirectional)

In [None]:
stage1_save = os.path.join(OUTPUT_DIR, "byt5_stage1")

# Stage 1 settings gathered in one mapping. R-Drop stays off here (too slow
# with the large Stage 1 data) and only the best checkpoint is kept.
stage1_kwargs = dict(
    train_df=stage1_train,
    val_df=stage1_val,
    epochs=STAGE1_EPOCHS,
    lr=STAGE1_LR,
    batch_size=STAGE1_BATCH,
    grad_acc=STAGE1_GRAD_ACC,
    save_dir=stage1_save,
    stage_name="STAGE 1: General Akkadian + Bidirectional",
    warmup_steps=STAGE1_WARMUP,
    use_rdrop=False,
    save_all_checkpoints=False,
)
model, stage1_score, _ = train_stage(model, tokenizer, **stage1_kwargs)

print(f"\nStage 1 complete! Best combined score: {stage1_score:.2f}")

# Release Python garbage and cached GPU blocks before Stage 2
gc.collect()
if DEVICE == 'cuda':
    torch.cuda.empty_cache()

## Stage 2: Old Assyrian Specialization (with R-Drop)

In [None]:
stage2_save = os.path.join(OUTPUT_DIR, "byt5_stage2")

# Stage 2 settings: R-Drop enabled and a checkpoint saved every epoch so the
# best ones can be averaged afterwards.
stage2_kwargs = dict(
    train_df=stage2_train,
    val_df=stage2_val,
    epochs=STAGE2_EPOCHS,
    lr=STAGE2_LR,
    batch_size=STAGE2_BATCH,
    grad_acc=STAGE2_GRAD_ACC,
    save_dir=stage2_save,
    stage_name="STAGE 2: Old Assyrian + R-Drop",
    warmup_steps=STAGE2_WARMUP,
    use_rdrop=True,
    rdrop_alpha=RDROP_ALPHA,
    save_all_checkpoints=True,
)
model, stage2_score, stage2_ckpts = train_stage(model, tokenizer, **stage2_kwargs)

print(f"\nStage 2 complete! Best combined score: {stage2_score:.2f}")
print(f"Saved {len(stage2_ckpts)} checkpoints.")
for ckpt_epoch, ckpt_score, ckpt_dir in stage2_ckpts:
    print(f"  Epoch {ckpt_epoch}: score={ckpt_score:.2f} -> {ckpt_dir}")

## Checkpoint Averaging

In [None]:
def average_checkpoints(checkpoint_paths, tokenizer_path, output_path, device=DEVICE):
    """Average the weights of multiple checkpoints into a single model.

    Each checkpoint is loaded in turn with
    ``T5ForConditionalGeneration.from_pretrained`` and its state dict is
    accumulated in float32. The element-wise mean is then loaded into a model
    built from ``checkpoint_paths[0]``, saved to ``output_path`` together with
    the tokenizer, and returned. Weight averaging is intended to reduce
    variance between checkpoints.

    Args:
        checkpoint_paths: Checkpoint locations to average; must be non-empty.
        tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``;
            the loaded tokenizer is saved next to the averaged weights.
        output_path: Directory (created if missing) receiving the averaged
            model and tokenizer.
        device: Device the averaged model is moved to before it is returned.

    Returns:
        The averaged model on ``device`` with gradient checkpointing enabled.

    Raises:
        ValueError: If ``checkpoint_paths`` is empty, or a checkpoint's
            parameter names differ from those of the first checkpoint.
    """
    checkpoint_paths = list(checkpoint_paths)
    if not checkpoint_paths:
        raise ValueError("average_checkpoints requires at least one checkpoint path")

    print(f"\nAveraging {len(checkpoint_paths)} checkpoints...")
    for p in checkpoint_paths:
        print(f"  - {p}")

    # Running float32 sum of every tensor, seeded from the first checkpoint.
    avg_state = None

    for ckpt_path in checkpoint_paths:
        m = T5ForConditionalGeneration.from_pretrained(ckpt_path)
        state = m.state_dict()

        if avg_state is None:
            # clone() first: .float() is a no-op view on float32 tensors, and
            # the accumulator must not alias the model's own parameters.
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            if state.keys() != avg_state.keys():
                raise ValueError(
                    f"Checkpoint {ckpt_path} has different parameter names "
                    f"than {checkpoint_paths[0]}; cannot average."
                )
            for k, total in avg_state.items():
                total.add_(state[k].float())

        # Drop BOTH references: `state` also points at every parameter tensor,
        # so deleting only the model would keep this checkpoint in memory
        # while the next one is loaded.
        del m, state
        gc.collect()

    # Turn the running sums into means.
    num_checkpoints = len(checkpoint_paths)
    for total in avg_state.values():
        total.div_(num_checkpoints)

    # Build a model from the first checkpoint and overwrite its weights with
    # the averages.
    model = T5ForConditionalGeneration.from_pretrained(checkpoint_paths[0])
    model.load_state_dict({k: v.to(model.dtype) for k, v in avg_state.items()})
    del avg_state
    gc.collect()

    # Save averaged model
    os.makedirs(output_path, exist_ok=True)
    model.save_pretrained(output_path)

    # Store the tokenizer alongside the weights so output_path is self-contained.
    tok = AutoTokenizer.from_pretrained(tokenizer_path)
    tok.save_pretrained(output_path)

    print(f"Averaged model saved to {output_path}")

    model = model.to(device)
    model.gradient_checkpointing_enable()
    return model


# Choose the best-scoring Stage 2 checkpoints, average them, and keep whichever
# of (averaged model, best single checkpoint) scores higher on stage2_val.
enough_ckpts = bool(stage2_ckpts) and len(stage2_ckpts) >= 2

if not enough_ckpts:
    print("\nNot enough checkpoints for averaging. Using best single checkpoint.")
    final_save = os.path.join(stage2_save, "best")
else:
    # Rank records by their score field (index 1), highest first.
    sorted_ckpts = sorted(stage2_ckpts, key=lambda record: record[1], reverse=True)
    top_ckpts = sorted_ckpts[:NUM_CKPTS_TO_AVG]
    print(f"\nTop {len(top_ckpts)} checkpoints for averaging:")
    for ep, sc, path in top_ckpts:
        print(f"  Epoch {ep}: score={sc:.2f}")

    ckpt_paths = [ckpt_dir for _ep, _sc, ckpt_dir in top_ckpts]
    final_save = os.path.join(OUTPUT_DIR, "byt5_final_averaged")

    # The tokenizer is taken from the top-ranked checkpoint directory.
    model = average_checkpoints(
        ckpt_paths,
        tokenizer_path=ckpt_paths[0],
        output_path=final_save,
    )

    # Score the averaged weights on the Stage 2 validation split.
    print("\nEvaluating averaged model...")
    bleu, chrf, combined, _ = evaluate(model, tokenizer, stage2_val)
    print(f"Averaged model: BLEU={bleu:.2f} | chrF++={chrf:.2f} | Combined={combined:.2f}")

    if combined < stage2_score:
        # Averaging hurt: fall back to the best single checkpoint on disk.
        print(f"\nAveraged model ({combined:.2f}) is worse than best single ({stage2_score:.2f}).")
        print("Using best single checkpoint instead.")
        best_path = os.path.join(stage2_save, "best")
        model = T5ForConditionalGeneration.from_pretrained(best_path)
        model = model.to(DEVICE)
        model.gradient_checkpointing_enable()
        final_save = best_path
    else:
        # Averaging matched or beat the best checkpoint: keep it and record
        # its score as the Stage 2 result.
        print(f"\nAveraged model ({combined:.2f}) >= best single ({stage2_score:.2f}).")
        print("Using averaged model.")
        stage2_score = combined

print(f"\nFinal model path: {final_save}")

## Validate: Show Sample Translations

In [None]:
# Print a few validation examples (source, reference, prediction) side by side.
print("Sample Translations from Validation Set:\n")

# When the split records a translation direction, preview only the rows
# tagged 'ak2en'; otherwise take the first rows as they are.
if 'direction' in stage2_val.columns:
    display_val = stage2_val[stage2_val['direction'] == 'ak2en'].head(10)
else:
    display_val = stage2_val.head(10)

# NOTE: this rebinds bleu/chrf/combined to metrics for the preview rows only.
bleu, chrf, combined, preds = evaluate(model, tokenizer, display_val)

# zip() stops at the shorter of (rows, predictions), so a short prediction
# list cannot cause an index error.
for i, ((_, row), pred) in enumerate(zip(display_val.iterrows(), preds), start=1):
    print(f"--- Example {i} ---")
    print(f"SRC:  {row['source'][:150]}...")
    print(f"REF:  {row['target'][:150]}...")
    print(f"PRED: {pred[:150]}...")
    print()

# These metrics cover only the previewed rows, so they are labelled as sample
# scores rather than as the final validation result.
print(f"\nSample Scores ({len(preds)} examples): BLEU={bleu:.2f} | chrF++={chrf:.2f} | Combined={combined:.2f}")

## Generate Submission

In [None]:
# Translate the competition test set and write submission.csv, when test.csv
# is available under COMP_DATA.
test_path = os.path.join(COMP_DATA, "test.csv")
if not os.path.exists(test_path):
    print("No test data found. Run on Kaggle to generate submission.")
else:
    test_df = pd.read_csv(test_path)
    print(f"Test data: {len(test_df)} rows")
    print(test_df.head())

    # Run every raw transliteration through clean_akkadian.
    test_df['clean_src'] = test_df['transliteration'].apply(clean_akkadian)

    model.eval()
    predictions = []
    eval_batch_size = 8  # smaller batch for generation on P100

    # Decoding options shared by every batch.
    generation_kwargs = dict(
        max_new_tokens=MAX_TARGET_LEN,
        num_beams=BEAM_WIDTH,
        repetition_penalty=REP_PENALTY,
        length_penalty=1.0,
    )

    batch_starts = range(0, len(test_df), eval_batch_size)
    for start in tqdm(batch_starts, desc="Generating"):
        chunk = test_df['clean_src'].iloc[start:start + eval_batch_size]
        # Each source is prefixed with the Akkadian->English task prefix.
        batch_texts = [PREFIX_AK2EN + text for text in chunk]

        inputs = tokenizer(
            batch_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_SOURCE_LEN,
        ).to(DEVICE)

        # No gradients are needed for generation.
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        predictions.extend(preds)

    # One translation per test id, written in test-set order.
    submission = pd.DataFrame({
        'id': test_df['id'],
        'translation': predictions,
    })
    sub_path = os.path.join(OUTPUT_DIR, 'submission.csv')
    submission.to_csv(sub_path, index=False)
    print(f"\nSubmission saved to {sub_path}")
    print(submission.head())

In [None]:
# ============================================================
# Kaggle dataset metadata
# ============================================================
# Kaggle expects a dataset-metadata.json file next to the uploaded files, so
# one is written into the final model directory.

dataset_metadata = dict(
    title="deep-past-byt5-base-v2",
    id="your-username/deep-past-byt5-base-v2",
    licenses=[dict(name="CC0-1.0")],
)

# Serialise first, then write the text in a single call.
meta_path = os.path.join(final_save, "dataset-metadata.json")
metadata_json = json.dumps(dataset_metadata, indent=2)
with open(meta_path, 'w') as f:
    f.write(metadata_json)
print(f"Dataset metadata saved to {meta_path}")

In [None]:
# Final run summary: scores, output location, enabled techniques, next steps.
banner = "=" * 60

summary_lines = [
    "\n" + banner,
    "  Training Complete!",
    banner,
    f"Model: {MODEL_NAME}",
    f"Stage 1 best score: {stage1_score:.2f}",
    f"Stage 2 best score: {stage2_score:.2f}",
    f"Final model: {final_save}",
    "\nImprovements used:",
    "  - ByT5-base (580M) with gradient checkpointing",
    "  - Adafactor optimizer",
    f"  - Sentence-level alignment data ({len(sentence_pairs)} pairs)",
    f"  - Bidirectional training ({len(bidirectional_data)} reverse pairs)",
    f"  - R-Drop regularization (alpha={RDROP_ALPHA})",
    f"  - Label smoothing ({LABEL_SMOOTHING})",
    "  - Cosine annealing LR with warmup",
    f"  - Checkpoint averaging (top {NUM_CKPTS_TO_AVG})",
    "\nNext steps:",
    f"  1. Upload {final_save} as a Kaggle dataset",
    "  2. Use the inference notebook to generate submission",
    "  3. Submit to competition",
]

# One print call, newline-joined, emits the same text as line-by-line prints.
print("\n".join(summary_lines))