# Deep Past Challenge: ByT5 Two-Stage Training

**Approach**: Fine-tune `google/byt5-small` in two stages (`google/byt5-base` runs out of memory on a 16 GB P100):
1. **Stage 1**: General Akkadian (Akkademia + ORACC + competition data)
2. **Stage 2**: Old Assyrian specialization (competition data only)

**Metric**: `sqrt(BLEU * chrF++)`

In [None]:
!pip install -q sacrebleu transformers accelerate sentencepiece

In [None]:
import warnings
warnings.simplefilter("ignore")

import os
import re
import math
import unicodedata
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, T5ForConditionalGeneration
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import sacrebleu

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

# Report which accelerator was selected and its total memory.
if DEVICE == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {props.total_memory / 1e9:.1f} GB")

## Configuration

In [None]:
# ============================================================
# Paths - Kaggle mounts inputs under /kaggle/input; otherwise
# fall back to repo-relative local directories.
# ============================================================
IS_KAGGLE = os.path.exists("/kaggle/input")

# Order: competition data, Akkademia NMT files, ORACC corpus, outputs.
COMP_DATA, AKKADEMIA_PATH, ORACC_PATH, OUTPUT_DIR = (
    (
        "/kaggle/input/deep-past-initiative-machine-translation",
        "/kaggle/input/akkademia-nmt",  # Files at root: train.tr, train.en, etc.
        "/kaggle/input/oracc-akkadian-english-parallel-corpus",
        "/kaggle/working",
    )
    if IS_KAGGLE
    else (
        "data",
        "external_data/akkademia/NMT_input",
        "external_data/oracc",
        "output",
    )
)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================
# Model & Training Config
# ============================================================
# byt5-base (580M params) runs out of memory on a 16 GB P100, so the
# smaller byte-level byt5-small (300M params) is used instead.
MODEL_NAME = "google/byt5-small"

# ---- Data sources (approximate pair counts) ----------------
USE_AKKADEMIA = True    # ~50K general Akkadian pairs
USE_ORACC = True        # ~2K ORACC pairs
USE_COMP_DATA = True    # ~1.5K competition pairs (Old Assyrian)

# ---- Shared sequence settings, reduced for P100 memory -----
# ByT5 is byte-level, so these limits are roughly UTF-8 byte counts.
MAX_SOURCE_LEN = MAX_TARGET_LEN = 384
PREFIX = "translate Akkadian to English: "

# ---- Stage 1: general Akkadian ------------------------------
STAGE1_EPOCHS, STAGE1_LR = 3, 2e-4
STAGE1_BATCH, STAGE1_GRAD_ACC = 4, 8    # effective batch = 32

# ---- Stage 2: Old Assyrian specialization -------------------
STAGE2_EPOCHS, STAGE2_LR = 10, 5e-5
STAGE2_BATCH, STAGE2_GRAD_ACC = 4, 4    # effective batch = 16

# Epochs without validation improvement before stopping early.
PATIENCE = 3

# ---- Inference ----------------------------------------------
BEAM_WIDTH = 4
REP_PENALTY = 1.2

# ---- Reproducibility ----------------------------------------
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## Preprocessing

In [None]:
# Translation table for str.translate: Unicode subscript digits
# U+2080..U+2089 map to the ASCII digits 0..9.
SUBSCRIPT_MAP = str.maketrans(
    "".join(chr(0x2080 + digit) for digit in range(10)),
    "0123456789",
)

# ASCII to diacritic normalization for Akkadian transliteration
ASCII_TO_DIACRITIC = {
    "sz": "\u0161", "SZ": "\u0160", "Sz": "\u0160",
    "sh": "\u0161", "SH": "\u0160", "Sh": "\u0160",
    "s,": "\u1E63", "S,": "\u1E62",
    "t,": "\u1E6D", "T,": "\u1E6C",
    ".s": "\u1E63", ".S": "\u1E62",
    ".t": "\u1E6D", ".T": "\u1E6C",
    "h,": "\u1E2B", "H,": "\u1E2A",
    ".h": "\u1E2B", ".H": "\u1E2A",
}


def normalize_ascii(text):
    """Normalize ASCII representations to proper Akkadian diacritics.

    Replacements are applied in ``ASCII_TO_DIACRITIC`` order.

    Keys that start with a dot (``.s``, ``.t``, ``.h`` and their capitals) are
    applied only when the dot is NOT directly preceded by a word character.
    Inside compound logograms such as ``GÍN.TA`` or ``TÚG.HI.A`` the dot joins
    two signs. A plain ``str.replace`` would merge them into ``GÍNṬA`` or
    ``TÚGḪI.A`` and destroy the logogram.

    Args:
        text: transliteration string.

    Returns:
        The string with ASCII stand-ins replaced by diacritic characters.
    """
    for old, new in ASCII_TO_DIACRITIC.items():
        if old.startswith("."):
            # (?<!\w): dot at the start of the string or after a space,
            # hyphen, etc. -- not a separator glued to the previous sign.
            text = re.sub(r"(?<!\w)" + re.escape(old), new, text)
        else:
            text = text.replace(old, new)
    return text


def normalize_gaps(text):
    """Rewrite damaged-text notation as placeholder tokens.

    The rules are applied in this order:
      * ``[x]`` becomes ``<gap>``;
      * a bracketed span opening with three or more dots becomes ``<big_gap>``;
      * any remaining run of three or more dots becomes ``<big_gap>``;
      * the single-character ellipsis ``…`` becomes ``<big_gap>``.
    """
    rules = (
        (r'\[x\]', '<gap>'),
        (r'\[\.{3,}[^\]]*\]', '<big_gap>'),
        (r'\.{3,}', '<big_gap>'),
        (r'\u2026', '<big_gap>'),  # ellipsis character
    )
    for pattern, token in rules:
        text = re.sub(pattern, token, text)
    return text


def clean_akkadian(text):
    """Clean and normalize one Akkadian transliteration string.

    Returns "" for missing (NaN/None) or blank input. Otherwise the steps are:
      1. NFC-normalize the text.
      2. Drop editorial marks: ``!``, ``?`` and the half brackets ``˹ ˺``.
      3. Convert gap notation (``[x]``, ``[...]``, ``...``, ``…``) into
         ``<gap>`` / ``<big_gap>`` tokens. This must run BEFORE square
         brackets are stripped; otherwise ``[x]`` degrades to a bare ``x``
         and never becomes ``<gap>``.
      4. Strip the remaining square brackets, keeping their contents.
      5. Map ASCII stand-ins to diacritics and subscript digits to ASCII.
      6. Replace ``/``, ``:`` and every ``.`` not followed by a digit with a
         space, then collapse whitespace.
    """
    if pd.isna(text) or not str(text).strip():
        return ""
    text = unicodedata.normalize("NFC", str(text))
    # Remove editorial marks
    text = text.replace("!", "").replace("?", "")
    text = re.sub(r'[\u02F9\u02FA]', '', text)  # ˹ ˺ half-brackets
    # Gap markers are themselves written with square brackets, so they are
    # normalized before the bracket-stripping step below.
    text = normalize_gaps(text)
    # Remove square brackets but keep content
    text = re.sub(r'\[([^\]]*)\]', r'\1', text)
    # ASCII stand-ins -> diacritics; subscript digits -> ASCII digits
    text = normalize_ascii(text)
    text = text.translate(SUBSCRIPT_MAP)
    # Remove stray slashes, colons, dots that are editorial
    text = re.sub(r'[/:.](?![\d])', ' ', text)  # keep dots before numbers
    # Collapse whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def clean_english(text):
    """Return *text* trimmed with internal whitespace runs collapsed to one space.

    Missing values (NaN/None) and whitespace-only input yield "".
    """
    if pd.isna(text):
        return ""
    as_str = str(text)
    if not as_str.strip():
        return ""
    return re.sub(r'\s+', ' ', as_str).strip()


print("Preprocessing functions defined.")

# Smoke-test the cleaner on one representative transliteration line.
sample = "KIŠIB ma-nu-ba-lúm-a-šur DUMU ṣí-lá-(d)IM [x] ... i-ša-qal!"
cleaned_sample = clean_akkadian(sample)
print(f"Original: {sample}")
print(f"Cleaned:  {cleaned_sample}")

## Load & Prepare Data

In [None]:
# ============================================================
# 1. Competition data (Old Assyrian - used in both stages)
# ============================================================
old_assyrian_data = []

if USE_COMP_DATA:
    comp_df = pd.read_csv(os.path.join(COMP_DATA, "train.csv"))
    print(f"Competition data: {len(comp_df)} documents")

    # Clean both sides, then keep pairs where each side has > 10 characters.
    cleaned_pairs = (
        (clean_akkadian(raw_src), clean_english(raw_tgt))
        for raw_src, raw_tgt in zip(comp_df['transliteration'], comp_df['translation'])
    )
    old_assyrian_data.extend(
        {'source': s, 'target': t}
        for s, t in cleaned_pairs
        if s and t and len(s) > 10 and len(t) > 10
    )

    print(f"  After cleaning: {len(old_assyrian_data)} pairs")

# ============================================================
# 2. Akkademia data (mixed periods - Stage 1 only)
# ============================================================
# Each split is a pair of line-aligned files: <split>.tr holds
# transliterations and <split>.en holds the English translations.
general_akkadian_data = []

if USE_AKKADEMIA and os.path.exists(AKKADEMIA_PATH):
    for split in ['train', 'valid', 'test']:
        tr_path = os.path.join(AKKADEMIA_PATH, f'{split}.tr')
        en_path = os.path.join(AKKADEMIA_PATH, f'{split}.en')
        if not (os.path.exists(tr_path) and os.path.exists(en_path)):
            continue
        # Context managers close both handles deterministically.
        with open(tr_path, encoding='utf-8') as f_tr, open(en_path, encoding='utf-8') as f_en:
            translits = f_tr.read().splitlines()
            transls = f_en.read().splitlines()
        if len(translits) != len(transls):
            # zip() below silently stops at the shorter file; surface the
            # mismatch, since it may mean the two files are misaligned.
            print(f"  WARNING: {split}.tr has {len(translits)} lines but "
                  f"{split}.en has {len(transls)}; extra lines are ignored.")
        for tr, en in zip(translits, transls):
            src = clean_akkadian(tr)
            tgt = clean_english(en)
            if src and tgt and len(src) > 5 and len(tgt) > 5:
                general_akkadian_data.append({'source': src, 'target': tgt})
    print(f"Akkademia data: {len(general_akkadian_data)} pairs")
elif not USE_AKKADEMIA:
    print("Akkademia data disabled (USE_AKKADEMIA=False), skipping.")
else:
    print("Akkademia data not found, skipping.")

# ============================================================
# 3. ORACC parallel corpus
# ============================================================
# Appends cleaned (akkadian, english) pairs to general_akkadian_data.
if USE_ORACC:
    oracc_csv = os.path.join(ORACC_PATH, "train.csv")
    if not os.path.exists(oracc_csv):
        print("ORACC data not found, skipping.")
    else:
        oracc_df = pd.read_csv(oracc_csv)
        print(f"ORACC data: {len(oracc_df)} pairs")
        for _, oracc_row in oracc_df.iterrows():
            akk = clean_akkadian(oracc_row.get('akkadian', ''))
            eng = clean_english(oracc_row.get('english', ''))
            # Both cleaners return str, so a length check also rules out "".
            if len(akk) > 5 and len(eng) > 5:
                general_akkadian_data.append({'source': akk, 'target': eng})
        print(f"  Total general data after ORACC: {len(general_akkadian_data)} pairs")

n_old_assyrian = len(old_assyrian_data)
n_general = len(general_akkadian_data)
print("\n=== Data Summary ===")
print(f"Old Assyrian (competition): {n_old_assyrian} pairs")
print(f"General Akkadian (external): {n_general} pairs")
print(f"Total available: {n_old_assyrian + n_general} pairs")

In [None]:
# ============================================================
# Prepare DataFrames for each stage
# ============================================================
# The Old Assyrian validation split is carved out FIRST and kept out of
# Stage 1 as well. If Stage 1 trained on every competition pair, most of
# stage2_val would already have been seen, and Stage 2's early stopping and
# reported scores would be inflated by leakage.

# Stage 2: Only Old Assyrian
stage2_all = pd.DataFrame(old_assyrian_data)
stage2_all = stage2_all.drop_duplicates(subset=['source']).reset_index(drop=True)
stage2_train, stage2_val = train_test_split(stage2_all, test_size=0.1, random_state=SEED)

# Stage 1: general Akkadian + the Old Assyrian TRAIN split only
heldout_sources = set(stage2_val['source'])
stage1_all = pd.concat(
    [pd.DataFrame(general_akkadian_data), stage2_train], ignore_index=True
)
# External corpora may also contain the held-out Old Assyrian lines; drop them.
stage1_all = stage1_all[~stage1_all['source'].isin(heldout_sources)]
stage1_all = stage1_all.drop_duplicates(subset=['source']).reset_index(drop=True)
stage1_train, stage1_val = train_test_split(stage1_all, test_size=0.05, random_state=SEED)

print(f"Stage 1 - Train: {len(stage1_train)}, Val: {len(stage1_val)}")
print(f"Stage 2 - Train: {len(stage2_train)}, Val: {len(stage2_val)}")

# Show samples
print(f"\n--- Sample from Stage 1 ---")
row = stage1_train.iloc[0]
print(f"SRC: {row['source'][:200]}")
print(f"TGT: {row['target'][:200]}")
print(f"\n--- Sample from Stage 2 ---")
row = stage2_train.iloc[0]
print(f"SRC: {row['source'][:200]}")
print(f"TGT: {row['target'][:200]}")

## Dataset & Model Setup

In [None]:
class AkkadianDataset(Dataset):
    """Map-style dataset that tokenizes (source, target) rows on the fly.

    ``prefix`` is prepended to every source string. Both sides are padded and
    truncated to their configured maximum length. In the labels, positions
    equal to the tokenizer's pad id are replaced by -100.
    """

    def __init__(self, df, tokenizer, max_source_len, max_target_len, prefix):
        # Drop the caller's index; rows are read positionally via iloc.
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        self.prefix = prefix

    def __len__(self):
        return len(self.df)

    def _encode(self, text, max_len):
        """Tokenize one string into a padded/truncated batch of size 1."""
        return self.tokenizer(
            text,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        source_enc = self._encode(self.prefix + record['source'], self.max_source_len)
        target_ids = self._encode(record['target'], self.max_target_len)['input_ids'].squeeze()
        is_pad = target_ids == self.tokenizer.pad_token_id
        return {
            'input_ids': source_enc['input_ids'].squeeze(),
            'attention_mask': source_enc['attention_mask'].squeeze(),
            'labels': target_ids.masked_fill(is_pad, -100),
        }

In [None]:
# Load the pretrained tokenizer/model pair and move the model to DEVICE.
print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)

all_params = list(model.parameters())
total_params = sum(p.numel() for p in all_params)
trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Evaluation Functions

In [None]:
def evaluate(model, tokenizer, val_df, batch_size=16):
    """Translate ``val_df`` with beam search and score it with the competition metric.

    Args:
        model: seq2seq model exposing ``generate`` (a T5 model here).
        tokenizer: tokenizer matching ``model``.
        val_df: DataFrame with ``source`` and ``target`` text columns.
        batch_size: number of rows passed to ``generate`` at a time.

    Returns:
        ``(bleu, chrf, combined, predictions)``:
          * ``bleu``: corpus BLEU;
          * ``chrf``: corpus chrF++ (chrF with word n-grams up to order 2);
          * ``combined``: their geometric mean, with each score floored at
            0.001 so the result stays strictly positive;
          * ``predictions``: decoded outputs in ``val_df`` row order.

    On return the model is put back into whatever train/eval mode it was in
    on entry, even if generation raises.
    """
    was_training = model.training
    model.eval()
    sources = val_df['source'].tolist()
    references = val_df['target'].tolist()
    predictions = []

    try:
        for start in range(0, len(sources), batch_size):
            batch_texts = [PREFIX + src for src in sources[start:start + batch_size]]
            inputs = tokenizer(
                batch_texts,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=MAX_SOURCE_LEN
            ).to(DEVICE)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_TARGET_LEN,
                    num_beams=BEAM_WIDTH,
                    repetition_penalty=REP_PENALTY,
                    length_penalty=1.0,
                )
            predictions.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    finally:
        model.train(was_training)

    # Compute metrics (word_order=2 turns chrF into chrF++)
    bleu = sacrebleu.corpus_bleu(predictions, [references]).score
    chrf = sacrebleu.corpus_chrf(predictions, [references], word_order=2).score

    # Competition metric: geometric mean
    combined = math.sqrt(max(bleu, 0.001) * max(chrf, 0.001))
    return bleu, chrf, combined, predictions


print("Evaluation functions defined.")

## Training Loop

In [None]:
def train_stage(model, tokenizer, train_df, val_df, epochs, lr, batch_size, grad_acc,
                save_path, stage_name):
    """Fine-tune ``model`` for one stage with early stopping on the validation score.

    After each epoch the model is scored with ``evaluate`` on ``val_df``.
    Whenever the combined score improves, the model and tokenizer are saved to
    ``save_path``. Training stops after ``PATIENCE`` consecutive epochs without
    improvement.

    Args:
        model: seq2seq model to train in place (a T5 model here).
        tokenizer: tokenizer used for the dataset and saved with checkpoints.
        train_df, val_df: DataFrames with ``source`` and ``target`` columns.
        epochs: maximum number of passes over ``train_df``.
        lr: initial learning rate, cosine-annealed down to ``lr * 0.01``.
        batch_size: micro-batch size per forward/backward pass.
        grad_acc: micro-batches accumulated per optimizer step.
        save_path: directory the best checkpoint is written to.
        stage_name: label shown in the stage banner.

    Returns:
        ``(model, best_score)``. ``model`` is the best checkpoint reloaded from
        ``save_path``, or the in-memory model if no checkpoint was written
        (e.g. ``epochs == 0``). ``best_score`` is its combined validation score.
    """
    print(f"\n{'='*60}")
    print(f"  {stage_name}")
    print(f"  Train: {len(train_df)} | Val: {len(val_df)}")
    print(f"  Epochs: {epochs} | LR: {lr} | Batch: {batch_size} x {grad_acc}")
    print(f"{'='*60}\n")

    train_ds = AkkadianDataset(train_df, tokenizer, MAX_SOURCE_LEN, MAX_TARGET_LEN, PREFIX)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    n_batches = len(train_loader)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    num_steps = math.ceil(n_batches / grad_acc) * epochs
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=lr * 0.01)

    # The training loss is the model's own loss computed from ``labels``
    # (outputs.loss); label positions set to -100 by AkkadianDataset are
    # ignored. No label smoothing is applied.

    best_score = 0.0
    patience_counter = 0
    saved_checkpoint = False

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for step, batch in enumerate(pbar):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Average over the micro-batches actually in this accumulation
            # window: the final window of an epoch can be shorter than
            # grad_acc and would otherwise be under-weighted.
            window_start = (step // grad_acc) * grad_acc
            window_size = min(grad_acc, n_batches - window_start)
            (outputs.loss / window_size).backward()

            if (step + 1) % grad_acc == 0 or (step + 1) == n_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            epoch_loss += outputs.loss.item()
            pbar.set_postfix({'loss': f"{epoch_loss / (step + 1):.4f}"})

        avg_loss = epoch_loss / max(n_batches, 1)

        # Evaluate
        print("\nEvaluating...")
        bleu, chrf, combined, _ = evaluate(model, tokenizer, val_df)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f} | BLEU={bleu:.2f} | chrF++={chrf:.2f} | Combined={combined:.2f}")

        # Early stopping
        if combined > best_score:
            best_score = combined
            patience_counter = 0
            # Save best model
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            saved_checkpoint = True
            print(f"  >>> New best! Saved to {save_path}")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{PATIENCE})")
            if patience_counter >= PATIENCE:
                print("  Early stopping triggered.")
                break

    # Release the optimizer (and its per-parameter state) before a second
    # copy of the weights is loaded onto the device.
    del optimizer, scheduler

    if not saved_checkpoint:
        print(f"\nNo checkpoint was saved to {save_path}; returning the current model.")
        return model, best_score

    # Reload best model
    print(f"\nLoading best model from {save_path} (score={best_score:.2f})")
    model = T5ForConditionalGeneration.from_pretrained(save_path).to(DEVICE)
    return model, best_score


print("Training function defined.")

## Stage 1: General Akkadian Training

In [None]:
stage1_save = os.path.join(OUTPUT_DIR, "byt5_stage1")

# Stage 1 hyper-parameters, passed to train_stage as keyword arguments.
stage1_config = {
    "train_df": stage1_train,
    "val_df": stage1_val,
    "epochs": STAGE1_EPOCHS,
    "lr": STAGE1_LR,
    "batch_size": STAGE1_BATCH,
    "grad_acc": STAGE1_GRAD_ACC,
    "save_path": stage1_save,
    "stage_name": "STAGE 1: General Akkadian Pre-training",
}
model, stage1_score = train_stage(model, tokenizer, **stage1_config)

print(f"\nStage 1 complete! Best combined score: {stage1_score:.2f}")

## Stage 2: Old Assyrian Specialization

In [None]:
stage2_save = os.path.join(OUTPUT_DIR, "byt5_stage2_final")

# Stage 2 continues from the Stage 1 weights on competition data only.
stage2_config = {
    "train_df": stage2_train,
    "val_df": stage2_val,
    "epochs": STAGE2_EPOCHS,
    "lr": STAGE2_LR,
    "batch_size": STAGE2_BATCH,
    "grad_acc": STAGE2_GRAD_ACC,
    "save_path": stage2_save,
    "stage_name": "STAGE 2: Old Assyrian Specialization",
}
model, stage2_score = train_stage(model, tokenizer, **stage2_config)

print(f"\nStage 2 complete! Best combined score: {stage2_score:.2f}")

## Validate: Show Sample Translations

In [None]:
# Show some validation translations next to their references.
print("Sample Translations from Validation Set:\n")
sample_df = stage2_val.head(10)
bleu, chrf, combined, preds = evaluate(model, tokenizer, sample_df)

for i, ((_, row), pred) in enumerate(zip(sample_df.iterrows(), preds), start=1):
    print(f"--- Example {i} ---")
    print(f"SRC: {row['source'][:150]}...")
    print(f"REF: {row['target'][:150]}...")
    print(f"PRED: {pred[:150]}...")
    print()

# These scores cover only the examples printed above, not the whole
# validation set; the full-set figure is the best Stage 2 score.
print(f"\nScores on these {len(sample_df)} examples: BLEU={bleu:.2f} | chrF++={chrf:.2f} | Combined={combined:.2f}")
print(f"Best full Stage 2 validation score: {stage2_score:.2f}")

## Generate Submission (if test data available)

In [None]:
test_path = os.path.join(COMP_DATA, "test.csv")
if not os.path.exists(test_path):
    print("No test data found. Run on Kaggle to generate submission.")
else:
    test_df = pd.read_csv(test_path)
    print(f"Test data: {len(test_df)} rows")
    print(test_df.head())

    # Apply the same cleaning as training, then add the task prefix.
    test_df['clean_src'] = test_df['transliteration'].apply(clean_akkadian)
    prompts = [PREFIX + t for t in test_df['clean_src']]

    # Generate translations in chunks of 16 rows
    model.eval()
    predictions = []
    for start in tqdm(range(0, len(prompts), 16), desc="Generating"):
        enc = tokenizer(
            prompts[start:start + 16],
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_SOURCE_LEN
        ).to(DEVICE)
        with torch.no_grad():
            generated = model.generate(
                **enc,
                max_new_tokens=MAX_TARGET_LEN,
                num_beams=BEAM_WIDTH,
                repetition_penalty=REP_PENALTY,
                length_penalty=1.0,
            )
        predictions.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))

    # Create submission
    submission = pd.DataFrame({'id': test_df['id'], 'translation': predictions})
    sub_path = os.path.join(OUTPUT_DIR, 'submission.csv')
    submission.to_csv(sub_path, index=False)
    print(f"\nSubmission saved to {sub_path}")
    print(submission.head())

In [None]:
banner = "=" * 60
print("\n" + banner)
print("  Training Complete!")
print(banner)
for label, score in (("Stage 1 best score", stage1_score), ("Stage 2 best score", stage2_score)):
    print(f"{label}: {score:.2f}")
print(f"Model saved to: {stage2_save}")

print("\nNext steps:")
next_steps = [
    "Upload the saved model as a Kaggle dataset",
    "Use the inference notebook to generate submission",
    "Submit to competition",
]
for number, step_text in enumerate(next_steps, start=1):
    print(f"  {number}. {step_text}")