# Deep Past Challenge: ByT5 Inference v2

Improvements over v1:
- Tuned beam search (num_beams=8, no_repeat_ngram_size=3, repetition_penalty=1.2, length_penalty=1.1)
- Post-processing (repetition removal, whitespace/unicode normalization)
- MBR decoding option (generate N samples per sentence, add the beam-search output, rerank all candidates with chrF++)

In [None]:
# sacrebleu is not available offline, so chrF++ is implemented inline further down (used for MBR reranking).

In [None]:
import os
import re
import math
import unicodedata
from collections import Counter
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    # Report which accelerator (index 0) the notebook is running on.
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# ============================================================
# Configuration
# ============================================================
# The Kaggle runtime mounts its datasets under /kaggle/input; use that as the switch.
IS_KAGGLE = os.path.exists('/kaggle/input')

# (competition data directory, fine-tuned checkpoint directory)
COMP_DATA, MODEL_PATH = (
    ('/kaggle/input/deep-past-initiative-machine-translation',
     '/kaggle/input/byt5-akkadian-final')
    if IS_KAGGLE
    else ('data', 'trained_model/byt5_stage2_final')
)

# Task prefix prepended to every source sentence before tokenization.
PREFIX = 'translate Akkadian to English: '
MAX_SOURCE_LEN = 512  # tokenizer truncation length for inputs
MAX_TARGET_LEN = 512  # passed to generate() as max_new_tokens

# --- Beam Search (tuned) ---
# NOTE(review): if this checkpoint uses a byte-level vocabulary (the model path
# suggests ByT5), NO_REPEAT_NGRAM = 3 forbids any 3-token (3-byte) sequence from
# occurring twice in one output, and REP_PENALTY penalizes every repeated byte.
# Confirm these values were validated on held-out data.
NUM_BEAMS = 8
LENGTH_PENALTY = 1.1
REP_PENALTY = 1.2
NO_REPEAT_NGRAM = 3
BEAM_BATCH_SIZE = 8  # smaller batch for larger beam

# --- MBR Decoding ---
USE_MBR = True
MBR_NUM_SAMPLES = 16     # sampled candidates per sentence
MBR_TEMPERATURE = 1.0    # sampling temperature
MBR_EPSILON = 0.02       # epsilon sampling cutoff
MBR_BATCH_SIZE = 4       # batch size during sampling (each produces MBR_NUM_SAMPLES)

print('Config loaded.')

In [None]:
# ============================================================
# Preprocessing (must match training)
# ============================================================
# Translation table turning Unicode subscript digits (U+2080..U+2089) into ASCII digits.
SUBSCRIPT_MAP = str.maketrans(
    '\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089',
    '0123456789',
)

# ASCII spellings -> diacritic characters. The pairs are applied one after the
# other in exactly this order by normalize_ascii, so the order is significant.
ASCII_TO_DIACRITIC = dict([
    # "sz" / "sh" spellings
    ('sz', '\u0161'), ('SZ', '\u0160'), ('Sz', '\u0160'),
    ('sh', '\u0161'), ('SH', '\u0160'), ('Sh', '\u0160'),
    # trailing-comma spellings for s and t
    ('s,', '\u1E63'), ('S,', '\u1E62'),
    ('t,', '\u1E6D'), ('T,', '\u1E6C'),
    # leading-dot spellings for s and t
    ('.s', '\u1E63'), ('.S', '\u1E62'),
    ('.t', '\u1E6D'), ('.T', '\u1E6C'),
    # trailing-comma and leading-dot spellings for h
    ('h,', '\u1E2B'), ('H,', '\u1E2A'),
    ('.h', '\u1E2B'), ('.H', '\u1E2A'),
])


def normalize_ascii(text):
    """Replace ASCII digraph spellings with their diacritic characters.

    Every key of ``ASCII_TO_DIACRITIC`` is substituted in turn, in the
    table's insertion order, so a later replacement sees the result of the
    earlier ones.

    Args:
        text: String to normalize.

    Returns:
        The string with all table keys replaced by their mapped values.
    """
    normalized = text
    for ascii_form in ASCII_TO_DIACRITIC:
        normalized = normalized.replace(ascii_form, ASCII_TO_DIACRITIC[ascii_form])
    return normalized


def normalize_gaps(text):
    """Rewrite damage/lacuna notation as ``<gap>`` / ``<big_gap>`` tokens.

    Substitutions are applied in order:
      1. a literal ``[x]`` becomes ``<gap>``;
      2. a bracket that opens with three or more dots becomes ``<big_gap>``;
      3. any remaining run of three or more dots becomes ``<big_gap>``;
      4. the ellipsis character U+2026 becomes ``<big_gap>``.

    Args:
        text: String to rewrite.

    Returns:
        The string with gap notation replaced by tokens.
    """
    substitutions = (
        (r'\[x\]', '<gap>'),
        (r'\[\.{3,}[^\]]*\]', '<big_gap>'),
        (r'\.{3,}', '<big_gap>'),
        (r'\u2026', '<big_gap>'),
    )
    for pattern, token in substitutions:
        text = re.sub(pattern, token, text)
    return text


def clean_akkadian(text):
    """Normalize one raw transliteration string for the model.

    Steps, in order: NFC normalization; removal of ``!`` and ``?``; removal
    of U+02F9 / U+02FA; unwrapping of ``[...]`` brackets (content kept);
    ASCII-to-diacritic replacement; gap-token replacement; subscript digits
    to ASCII digits; ``/``, ``:`` and ``.`` replaced by a space unless a
    digit follows; whitespace collapsed and trimmed.

    Args:
        text: Raw cell value; may be missing (NaN/None) or blank.

    Returns:
        The cleaned string, or ``''`` for missing or blank input.
    """
    if pd.isna(text):
        return ''
    cleaned = str(text)
    if not cleaned.strip():
        return ''

    cleaned = unicodedata.normalize('NFC', cleaned)
    for mark in ('!', '?'):
        cleaned = cleaned.replace(mark, '')
    cleaned = re.sub(r'[\u02F9\u02FA]', '', cleaned)
    # NOTE(review): brackets are unwrapped here, before normalize_gaps runs, so
    # its '[x]' and '[...]' patterns can only match when brackets are nested.
    # Left as-is because preprocessing must match training - confirm against
    # the training pipeline before changing the order.
    cleaned = re.sub(r'\[([^\]]*)\]', r'\1', cleaned)
    cleaned = normalize_gaps(normalize_ascii(cleaned))
    cleaned = cleaned.translate(SUBSCRIPT_MAP)
    # Separator characters become spaces unless they are followed by a digit.
    cleaned = re.sub(r'[/:.](?![\d])', ' ', cleaned)
    return re.sub(r'\s+', ' ', cleaned).strip()

print('Preprocessing functions defined.')

In [None]:
# ============================================================
# Post-Processing Pipeline
# ============================================================

def remove_repeated_phrases(text, max_ngram=8):
    """Collapse back-to-back duplicated word n-grams.

    For every n-gram size from ``max_ngram`` down to 2, the word list is
    scanned left to right; whenever the next ``n`` words are immediately
    followed by an identical ``n`` words, the first copy is dropped.

    Args:
        text: Sentence to clean.
        max_ngram: Largest n-gram size to check.

    Returns:
        The cleaned sentence joined with single spaces. Inputs with fewer
        than four words are returned unchanged.
    """
    tokens = text.split()
    if len(tokens) < 4:
        return text

    # Largest phrases first, so long repeats are removed before their parts.
    for size in reversed(range(2, max_ngram + 1)):
        kept = []
        total = len(tokens)
        pos = 0
        while pos < total:
            first_end = pos + size
            second_end = first_end + size
            duplicated = (
                second_end <= total
                and tokens[pos:first_end] == tokens[first_end:second_end]
            )
            if duplicated:
                # Jump over the first copy; the second is re-examined next.
                pos = first_end
            else:
                kept.append(tokens[pos])
                pos += 1
        tokens = kept

    return ' '.join(tokens)


def remove_trailing_repetition(text):
    """Cut off a tail that already occurred earlier in the sentence.

    Tail windows of 3 up to (but excluding) ``min(len(words) // 2, 20)``
    words are compared against every earlier start position. The widest
    window found earlier is removed from the end.

    Args:
        text: Sentence to clean.

    Returns:
        The sentence without the repeated tail, joined with single spaces.
        Inputs with fewer than ten words are returned unchanged.
    """
    tokens = text.split()
    total = len(tokens)
    if total < 10:
        return text

    keep = total
    for width in range(3, min(total // 2, 20)):
        suffix = tokens[total - width:]
        # Does the same word sequence start anywhere before the tail itself?
        seen_before = any(
            tokens[start:start + width] == suffix
            for start in range(total - width)
        )
        if seen_before:
            keep = min(keep, total - width)

    return ' '.join(tokens[:keep])


def normalize_whitespace(text):
    """Apply NFC normalization and tidy spacing.

    Whitespace runs collapse to one space and the ends are trimmed; then
    whitespace before ``. , ; : ! ? )`` and after ``(`` or ``{`` is removed.

    Args:
        text: String to tidy.

    Returns:
        The tidied string.
    """
    tidy = re.sub(r'\s+', ' ', unicodedata.normalize('NFC', text)).strip()
    # Fix spacing around punctuation
    for pattern in (r'\s+([.,;:!?)])', r'([({])\s+'):
        tidy = re.sub(pattern, r'\1', tidy)
    return tidy


def postprocess(text):
    """Run the complete clean-up chain on one generated translation.

    The steps run in a fixed order: consecutive-phrase de-duplication,
    trailing-repetition removal, then whitespace/unicode normalization.

    Args:
        text: Raw decoded model output.

    Returns:
        The cleaned translation string.
    """
    pipeline = (
        remove_repeated_phrases,
        remove_trailing_repetition,
        normalize_whitespace,
    )
    for step in pipeline:
        text = step(text)
    return text


# Test
sample = 'When you wrote me as follows: When you wrote me as follows: When you wrote me as follows: I wrote my assistance.'
print(f'Before: {sample}')
print(f'After:  {postprocess(sample)}')

In [None]:
# ============================================================
# Load Model
# ============================================================
print(f'Loading model from {MODEL_PATH}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Load the weights, move them to the selected device and switch to eval mode;
# both .to() and .eval() return the module itself, so the calls can be chained.
model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH).to(DEVICE).eval()
print('Model loaded successfully!')

In [None]:
# ============================================================
# Load Test Data
# ============================================================
test_df = pd.read_csv(os.path.join(COMP_DATA, 'test.csv'))
print(f'Test data: {len(test_df)} rows')
print(test_df.head())

# Apply the same cleaning used for the training inputs to every test row.
test_df['clean_src'] = [clean_akkadian(raw) for raw in test_df['transliteration']]
print(f'\nSample cleaned:')
print(test_df[['transliteration', 'clean_src']].head())

In [None]:
# ============================================================
# Inline chrF++ Implementation (no external deps)
# ============================================================

def extract_char_ngrams(text, n):
    """Return every contiguous character n-gram of ``text`` (spaces included).

    Returns an empty list when ``text`` is shorter than ``n``.
    """
    return [text[i:i+n] for i in range(len(text) - n + 1)]


def extract_word_ngrams(text, n):
    """Return every n-gram of whitespace-separated words, joined by one space.

    Returns an empty list when ``text`` has fewer than ``n`` words.
    """
    words = text.split()
    return [' '.join(words[i:i+n]) for i in range(len(words) - n + 1)]


def _ngram_precision_recall(hyp_ngrams, ref_ngrams):
    """Return (precision, recall) of clipped n-gram matches, or None if either list is empty."""
    if not hyp_ngrams or not ref_ngrams:
        return None
    # Counter intersection keeps the minimum count per n-gram (clipped matches).
    matched = sum((Counter(hyp_ngrams) & Counter(ref_ngrams)).values())
    return matched / len(hyp_ngrams), matched / len(ref_ngrams)


def chrf_score(hypothesis, reference, char_order=6, word_order=2, beta=2):
    """Compute a chrF++-style score (0-100) between hypothesis and reference.

    Precision and recall are computed for character n-grams of orders
    1..``char_order`` and word n-grams of orders 1..``word_order``, averaged
    over all orders for which both strings have at least one n-gram, and
    combined into a single F-beta score.

    An order with zero overlap contributes 0 to the averages. Previously such
    orders were dropped from the average, so strings that merely shared
    single characters (e.g. ``'ab'`` vs ``'ba'``) scored 100.

    Unlike sacreBLEU's default, character n-grams here include spaces and
    words are split on whitespace only (punctuation is not separated).

    Args:
        hypothesis: Candidate string.
        reference: String to compare against.
        char_order: Highest character n-gram order.
        word_order: Highest word n-gram order (0 gives plain chrF).
        beta: Recall weight of the F-score (recall counts ``beta`` times
            as much as precision).

    Returns:
        Score in the range 0.0-100.0; 0.0 when either string is empty or
        nothing matches.
    """
    if not hypothesis or not reference:
        return 0.0

    precision_sum = 0.0
    recall_sum = 0.0
    effective_orders = 0

    # Character orders first, then the word orders that make it chrF++.
    for extract, max_order in ((extract_char_ngrams, char_order),
                               (extract_word_ngrams, word_order)):
        for n in range(1, max_order + 1):
            stats = _ngram_precision_recall(extract(hypothesis, n),
                                            extract(reference, n))
            if stats is None:
                # One side is too short for this order: not a meaningful zero.
                continue
            precision_sum += stats[0]
            recall_sum += stats[1]
            effective_orders += 1

    if effective_orders == 0:
        return 0.0

    precision = precision_sum / effective_orders
    recall = recall_sum / effective_orders
    if precision + recall == 0:
        return 0.0

    beta_sq = beta ** 2
    return 100 * (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


def mbr_select(candidates, utility_fn=chrf_score):
    """Select the candidate with the highest expected utility (MBR decoding).

    Each distinct candidate is scored against all *other* candidates in the
    pool, used as pseudo-references. Duplicates are kept as references, so a
    string that was generated many times carries proportionally more weight.
    Distinct candidates are evaluated in first-seen order and ties go to the
    earliest one, which makes the result independent of hash ordering.

    Args:
        candidates: Non-empty sequence of candidate strings.
        utility_fn: Callable ``(hypothesis, reference) -> float``; higher
            means more similar.

    Returns:
        The selected candidate string.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    candidates = list(candidates)
    total = len(candidates)
    if total == 0:
        raise ValueError('mbr_select() needs at least one candidate')

    # Counter keeps insertion order, giving a deterministic unique list.
    counts = Counter(candidates)
    unique = list(counts)
    if len(unique) == 1:
        return unique[0]

    scores = np.zeros(len(unique))
    for i, hyp in enumerate(unique):
        for j, ref in enumerate(unique):
            # A hypothesis is never compared with its own copy, only with
            # the remaining copies of the same string.
            weight = counts[ref] - (1 if i == j else 0)
            if weight > 0:
                scores[i] += weight * utility_fn(hyp, ref)
        scores[i] /= (total - 1)

    # np.argmax returns the first maximum, i.e. the earliest candidate.
    return unique[int(np.argmax(scores))]


# Quick test
print(f"chrF++ test: {chrf_score('the cat sat on the mat', 'the cat is on the mat'):.1f}")
print('MBR functions defined.')

In [None]:
# ============================================================
# Generate Translations
# ============================================================
# Builds `predictions`: one decoded string per row of test_df, in row order.
# With USE_MBR each sentence gets one beam-search output plus MBR_NUM_SAMPLES
# sampled outputs, and mbr_select picks among them; otherwise the beam-search
# output is used directly.
import time
start_time = time.time()

predictions = []
n_test = len(test_df)

if USE_MBR:
    print(f'Using MBR decoding with {MBR_NUM_SAMPLES} samples per sentence')
    print(f'Processing {n_test} sentences...')
    
    for i in tqdm(range(0, n_test, MBR_BATCH_SIZE), desc='MBR Translating'):
        batch_end = min(i + MBR_BATCH_SIZE, n_test)
        batch_texts = [PREFIX + t for t in test_df['clean_src'].iloc[i:batch_end]]
        # The last batch may hold fewer than MBR_BATCH_SIZE sentences.
        batch_size_actual = len(batch_texts)
        
        inputs = tokenizer(
            batch_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_SOURCE_LEN
        ).to(DEVICE)
        
        # Generate beam search candidate (1 per sentence)
        with torch.no_grad():
            beam_outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_TARGET_LEN,
                num_beams=NUM_BEAMS,
                repetition_penalty=REP_PENALTY,
                length_penalty=LENGTH_PENALTY,
                no_repeat_ngram_size=NO_REPEAT_NGRAM,
            )
        beam_preds = tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)
        
        # Generate sampled candidates for MBR: MBR_NUM_SAMPLES sequences are
        # requested per input via num_return_sequences.
        # NOTE(review): no random seed is set in this cell, so the sampled
        # candidates presumably differ between runs - confirm if
        # reproducibility matters.
        with torch.no_grad():
            sample_outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_TARGET_LEN,
                do_sample=True,
                temperature=MBR_TEMPERATURE,
                epsilon_cutoff=MBR_EPSILON,
                num_return_sequences=MBR_NUM_SAMPLES,
                repetition_penalty=REP_PENALTY,
            )
        sample_preds = tokenizer.batch_decode(sample_outputs, skip_special_tokens=True)
        
        # MBR select for each sentence
        for j in range(batch_size_actual):
            # Gather all candidates: beam + samples.
            # The indexing assumes the samples of input j occupy the
            # contiguous slice [j * MBR_NUM_SAMPLES, (j + 1) * MBR_NUM_SAMPLES)
            # of the generate() output - TODO confirm against the library docs.
            candidates = [beam_preds[j]]
            for k in range(MBR_NUM_SAMPLES):
                candidates.append(sample_preds[j * MBR_NUM_SAMPLES + k])
            
            best = mbr_select(candidates, utility_fn=chrf_score)
            predictions.append(best)
        
        # Progress timing: every 50th batch, print elapsed time and an ETA
        # extrapolated from the average sentences-per-second so far.
        if (i // MBR_BATCH_SIZE) % 50 == 0 and i > 0:
            elapsed = time.time() - start_time
            rate = i / elapsed
            remaining = (n_test - i) / rate if rate > 0 else 0
            print(f'  [{i}/{n_test}] Elapsed: {elapsed/60:.1f}min, ETA: {remaining/60:.1f}min')

else:
    print(f'Using beam search (beams={NUM_BEAMS}, length_penalty={LENGTH_PENALTY})')
    
    for i in tqdm(range(0, n_test, BEAM_BATCH_SIZE), desc='Translating'):
        batch_texts = [PREFIX + t for t in test_df['clean_src'].iloc[i:i+BEAM_BATCH_SIZE]]
        
        inputs = tokenizer(
            batch_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_SOURCE_LEN
        ).to(DEVICE)
        
        # Same beam-search settings as the beam candidate in the MBR branch.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_TARGET_LEN,
                num_beams=NUM_BEAMS,
                repetition_penalty=REP_PENALTY,
                length_penalty=LENGTH_PENALTY,
                no_repeat_ngram_size=NO_REPEAT_NGRAM,
            )
        
        preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        predictions.extend(preds)

# Total wall-clock time for whichever branch ran.
elapsed = time.time() - start_time
print(f'\nGenerated {len(predictions)} translations in {elapsed/60:.1f} minutes')

In [None]:
# ============================================================
# Apply Post-Processing
# ============================================================
print('Applying post-processing...')
processed = list(map(postprocess, predictions))

# Show before/after for the first few outputs that post-processing changed
for i, (raw, cleaned) in enumerate(zip(predictions[:3], processed[:3])):
    if raw == cleaned:
        continue
    print(f'\n--- Example {i} ---')
    print(f'Before: {raw[:200]}')
    print(f'After:  {cleaned[:200]}')

# From here on `predictions` refers to the post-processed strings.
predictions = processed
print(f'\nPost-processing complete.')

In [None]:
# ============================================================
# Create Submission
# ============================================================
# One row per test id, paired positionally with the generated translations.
submission = pd.DataFrame({'id': test_df['id'], 'translation': predictions})

# Replace any missing translation with an empty string before writing.
submission['translation'] = submission['translation'].fillna('')

submission.to_csv('submission.csv', index=False)
for report in ('Submission saved to submission.csv',
               f'Shape: {submission.shape}',
               submission.head(10)):
    print(report)

In [None]:
# Show sample translations: the first (up to) 10 sources next to their outputs,
# each truncated to 200 characters.
print('\n' + '='*60)
print('  Sample Translations')
print('='*60)
for i in range(min(10, len(test_df))):
    # str() keeps this display loop from crashing on a non-string (e.g. NaN)
    # transliteration, which clean_akkadian accepts and maps to ''.
    src_text = str(test_df.iloc[i]["transliteration"])
    print(f'\n--- Test {i} ---')
    print(f'SRC:  {src_text[:200]}')
    print(f'PRED: {predictions[i][:200]}')