In [1]:
# ============================================================
# AKKADIAN TRANSLATION - ULTRA-OPTIMIZED INFERENCE
# Output: submission.csv with columns [id, translation]
# ============================================================

# Probe whether optimum's BetterTransformer can be imported and report it.
# This only prints a status line: the inference engine re-imports it when it
# actually applies the transform, and falls back to the plain model on failure.
try:
    from optimum.bettertransformer import BetterTransformer
    print("✅ BetterTransformer available!")
except ImportError:
    print("❌ BetterTransformer NOT available")

import os
# Runtime tuning environment variables. They are set right after `import os`
# and before torch / transformers are imported further down, so that those
# libraries can see them when they initialise. Keep this ordering.
# Thread-count hints for the OpenMP / MKL CPU math backends.
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'
# '0' requests asynchronous CUDA kernel launches (presumably already the
# default; set explicitly here).
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
# NOTE(review): may have no effect on recent PyTorch builds — confirm.
os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'
# NOTE(review): the DataLoader below may fork worker processes
# (num_workers=4); confirm the tokenizer in use is safe with this set to 'true'.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import re
import logging
import warnings
from pathlib import Path
from typing import List
from dataclasses import dataclass
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm.auto import tqdm
import json
import random

# Silence every Python warning (including library deprecation notices) for
# the remainder of the run.
warnings.filterwarnings('ignore')

# Print a short description of the runtime: PyTorch version, CUDA
# availability and, when present, the first GPU with its memory in GB.
_has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__} | CUDA: {_has_cuda}")
if _has_cuda:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {_gpu_name} | {_gpu_mem_gb:.1f}GB")

# ============================================================
# CONFIGURATION
# ============================================================
@dataclass
class UltraConfig:
    """All settings for the ByT5 Akkadian -> English inference run.

    After dataclass initialisation, ``__post_init__``:
      * adds a ``device`` attribute (CUDA when available, otherwise CPU),
      * creates ``output_dir`` (including parents) if it is missing,
      * on hosts without CUDA, switches off ``use_mixed_precision`` and
        ``use_better_transformer``.
    """

    # --- Input / output locations ---
    test_data_path: str = "/kaggle/input/deep-past-initiative-machine-translation/test.csv"
    model_path: str = "/kaggle/input/final-byt5/byt5-akkadian-optimized-34x"
    output_dir: str = "/kaggle/working/"

    # --- Tokenisation and data loading ---
    max_length: int = 512   # tokenizer truncation length for model inputs
    batch_size: int = 8
    num_workers: int = 4    # DataLoader worker processes

    # --- Beam-search generation (tuned values; leave unchanged) ---
    num_beams: int = 8
    max_new_tokens: int = 512
    length_penalty: float = 1.3
    early_stopping: bool = True
    no_repeat_ngram_size: int = 0   # only forwarded to generate() when > 0

    # --- Optional speed-ups and behaviour switches ---
    use_mixed_precision: bool = True
    use_better_transformer: bool = True
    use_bucket_batching: bool = True
    use_vectorized_postproc: bool = True   # NOTE(review): not read elsewhere in this notebook as shown
    use_adaptive_beams: bool = True
    use_auto_batch_size: bool = False      # NOTE(review): not read elsewhere in this notebook as shown

    tiny_dataset_threshold: int = 64   # fewer rows than this -> no bucketing, no workers
    num_buckets: int = 4
    aggressive_postprocessing: bool = True
    checkpoint_freq: int = 100             # NOTE(review): not read elsewhere in this notebook as shown

    def __post_init__(self):
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_ok else "cpu")
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        if cuda_ok:
            return
        # Both switches target the CUDA code path of this script.
        self.use_mixed_precision = False
        self.use_better_transformer = False

# Single configuration instance shared by the rest of the notebook.
# Constructing it also creates output_dir (UltraConfig.__post_init__).
config = UltraConfig()

# ============================================================
# LOGGING
# ============================================================
def setup_logging(output_dir: str):
    """Route INFO-level logging to stderr and to ``<output_dir>/inference.log``.

    Safe to call repeatedly (e.g. when the notebook cell is re-run):
    ``force=True`` makes ``basicConfig`` remove *and close* every handler
    already attached to the root logger, so a previous run's log-file handle
    is not leaked.

    Args:
        output_dir: Directory that receives ``inference.log``; created
            (with parents) if it does not exist.

    Returns:
        The logger named after this module.
    """
    log_dir = Path(output_dir)
    log_dir.mkdir(exist_ok=True, parents=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            # Explicit UTF-8: log messages contain non-ASCII characters
            # (emoji) that a non-UTF-8 locale encoding cannot write.
            logging.FileHandler(log_dir / 'inference.log', encoding='utf-8'),
        ],
        force=True,
    )
    return logging.getLogger(__name__)

# Module-wide logger; records go to stderr and <output_dir>/inference.log.
logger = setup_logging(config.output_dir)

# ============================================================
# PREPROCESSOR
# ============================================================
class OptimizedPreprocessor:
    """Normalise gap markers in source transliterations before tokenisation."""

    def __init__(self):
        self.patterns = {
            # Runs of 3+ dots, or of the single-character ellipsis "…".
            # Written as \u2026 so the pattern survives re-encoding of this
            # file (a mis-decoded literal here silently never matches).
            'big_gap': re.compile(r'(\.{3,}|\u2026+|\u2026\u2026)'),
            # "xx", "xxx", ... or a lone "x" with whitespace on both sides.
            # NOTE(review): the whitespace around a lone "x" is consumed, so
            # "a x b" becomes "a<gap>b"; presumably this mirrors the training
            # preprocessing — confirm before changing.
            'small_gap': re.compile(r'(xx+|\s+x\s+)'),
        }

    def preprocess_batch(self, texts: List[str]) -> List[str]:
        """Replace gap notations with ``<big_gap>`` / ``<gap>`` tokens.

        Args:
            texts: Raw transliterations; missing values (NaN/None) become "".

        Returns:
            The processed strings, same length and order as ``texts``.
        """
        s = pd.Series(texts).fillna("").astype(str)
        # Big gaps first so "..." is not partially consumed by other rules.
        s = s.str.replace(self.patterns['big_gap'], '<big_gap>', regex=True)
        s = s.str.replace(self.patterns['small_gap'], '<gap>', regex=True)
        return s.tolist()

# ============================================================
# POSTPROCESSOR
# ============================================================
class VectorizedPostprocessor:
    """Clean decoded model translations in bulk with pandas string operations.

    Always applied: non-string / whitespace-only entries become "", ḫ/Ḫ are
    mapped to h/H, subscript digits to ASCII digits, and whitespace is
    collapsed and stripped.

    With ``aggressive=True`` additionally: gap markers are normalised to
    ``<gap>`` / ``<big_gap>``, some parenthetical annotations are removed,
    forbidden characters are deleted (gap tokens protected), decimal
    fractions .5/.25/.75 become ½/¼/¾, immediately repeated words and
    2-4 word n-grams are collapsed, and spacing before punctuation, repeated
    punctuation and surrounding '-' are tidied.
    """

    def __init__(self, aggressive: bool = True):
        self.aggressive = aggressive
        self.patterns = {
            'gap': re.compile(r'(\[x\]|\(x\)|\bx\b)', re.I),
            # \u2026 is the single-character ellipsis "…".
            'big_gap': re.compile(r'(\.{3,}|\u2026|\[\.+\])'),
            'annotations': re.compile(r'\((fem|plur|pl|sing|singular|plural|\?|!)\..\s*\w*\)', re.I),
            'repeated_words': re.compile(r'\b(\w+)(?:\s+\1\b)+'),
            'whitespace': re.compile(r'\s+'),
            'punct_space': re.compile(r'\s+([.,:])'),
            'repeated_punct': re.compile(r'([.,])\1+'),
        }
        # Non-ASCII characters are spelled as escapes so these tables survive
        # re-encoding of the file; a mis-decoded table has unequal lengths and
        # makes str.maketrans raise ValueError.
        # Subscript digits ₀..₉ -> 0..9.
        self.subscript_trans = str.maketrans(
            '\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089', '0123456789')
        # ḫ / Ḫ -> h / H.
        self.special_chars_trans = str.maketrans('\u1e2b\u1e2a', 'hH')
        # Deleted outright: ! ? ( ) " — < > ⌈ ⌋ ⌊ [ ] + ʾ / ;
        self.forbidden_trans = str.maketrans('', '', '!?()"\u2014<>\u2308\u230b\u230a[]+\u02be/;')
        # Decimal fractions -> ½ ¼ ¾. The standalone "0.x" rule must run
        # before the generic "N.x" rule: the latter also matches "0.x" and
        # would yield "0½", leaving the "0.x" rule unreachable.
        self.fraction_rules = [
            (re.compile(r'\b0\.5\b'), '\u00bd'),
            (re.compile(r'(\d+)\.5\b'), '\\1\u00bd'),
            (re.compile(r'\b0\.25\b'), '\u00bc'),
            (re.compile(r'(\d+)\.25\b'), '\\1\u00bc'),
            (re.compile(r'\b0\.75\b'), '\u00be'),
            (re.compile(r'(\d+)\.75\b'), '\\1\u00be'),
        ]
        # Repeated n-grams, longest first (4, 3, 2 words).
        self.ngram_patterns = [
            re.compile(r'\b((?:\w+\s+){' + str(n - 1) + r'}\w+)(?:\s+\1\b)+')
            for n in range(4, 1, -1)
        ]

    def postprocess_batch(self, translations: List[str]) -> List[str]:
        """Return cleaned copies of ``translations``.

        Args:
            translations: Decoded model outputs. Non-string entries (e.g.
                None / NaN) and whitespace-only strings become "".

        Returns:
            Cleaned strings, same length and order as the input.
        """
        items = list(translations)
        if not items:
            return []
        # Normalise up front and force object dtype so the ``.str`` accessor
        # is always valid (an all-NaN input would otherwise be a float Series).
        s = pd.Series(
            [t if isinstance(t, str) and t.strip() else "" for t in items],
            dtype=object,
        )

        s = s.str.translate(self.special_chars_trans)
        s = s.str.translate(self.subscript_trans)
        s = s.str.replace(self.patterns['whitespace'], ' ', regex=True).str.strip()

        if not self.aggressive:
            return s.tolist()

        s = s.str.replace(self.patterns['gap'], '<gap>', regex=True)
        s = s.str.replace(self.patterns['big_gap'], '<big_gap>', regex=True)
        s = s.str.replace('<gap> <gap>', '<big_gap>', regex=False)
        s = s.str.replace('<big_gap> <big_gap>', '<big_gap>', regex=False)
        s = s.str.replace(self.patterns['annotations'], '', regex=True)

        # '<' and '>' are forbidden characters: swap gap tokens for
        # placeholders while deleting, then restore them with spacing.
        s = s.str.replace('<gap>', '\x00GAP\x00', regex=False)
        s = s.str.replace('<big_gap>', '\x00BIG\x00', regex=False)
        s = s.str.translate(self.forbidden_trans)
        s = s.str.replace('\x00GAP\x00', ' <gap> ', regex=False)
        s = s.str.replace('\x00BIG\x00', ' <big_gap> ', regex=False)

        for pattern, replacement in self.fraction_rules:
            s = s.str.replace(pattern, replacement, regex=True)

        # Collapse immediately repeated words, then repeated 4/3/2-grams.
        s = s.str.replace(self.patterns['repeated_words'], r'\1', regex=True)
        for pattern in self.ngram_patterns:
            s = s.str.replace(pattern, r'\1', regex=True)

        s = s.str.replace(self.patterns['punct_space'], r'\1', regex=True)
        s = s.str.replace(self.patterns['repeated_punct'], r'\1', regex=True)
        s = s.str.replace(self.patterns['whitespace'], ' ', regex=True)
        s = s.str.strip().str.strip('-').str.strip()
        return s.tolist()

# ============================================================
# DATASET & SAMPLER
# ============================================================
class AkkadianDataset(Dataset):
    """Map-style dataset yielding ``(sample_id, prompted_text)`` pairs.

    Reads the ``id`` and ``transliteration`` columns of ``df``, runs all
    transliterations through ``preprocessor.preprocess_batch`` once up
    front, and prefixes every result with the task prompt.
    """

    def __init__(self, df: pd.DataFrame, preprocessor):
        self.sample_ids = df['id'].tolist()
        prompt = "translate Akkadian to English: "
        cleaned = preprocessor.preprocess_batch(df['transliteration'].tolist())
        self.input_texts = [prompt + text for text in cleaned]
        logger.info(f"Dataset: {len(self.sample_ids)} samples")

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, i):
        sample_id = self.sample_ids[i]
        text = self.input_texts[i]
        return sample_id, text

class BucketBatchSampler(Sampler):
    """Batch sampler that groups samples of similar word count.

    Sample indices are sorted by the whitespace-separated word count of each
    sample's text and cut into up to ``num_buckets`` contiguous buckets; the
    last bucket absorbs any remainder. Each bucket is then emitted as
    consecutive batches of at most ``batch_size`` indices. The order is
    deterministic (no shuffling).
    """

    def __init__(self, dataset, batch_size: int, num_buckets: int = 4):
        word_counts = [len(text.split()) for _, text in dataset]
        by_length = sorted(range(len(word_counts)), key=word_counts.__getitem__)
        per_bucket = max(1, len(by_length) // num_buckets)

        buckets = []
        for b in range(num_buckets):
            start = b * per_bucket
            if not by_length[start:]:
                # Fewer samples than buckets: nothing left to place here.
                continue
            # The final bucket runs to the end so no index is dropped.
            stop = start + per_bucket if b < num_buckets - 1 else None
            buckets.append(by_length[start:stop])
        self.buckets = buckets
        self.batch_size = batch_size

    def __iter__(self):
        size = self.batch_size
        for bucket in self.buckets:
            yield from (bucket[lo:lo + size] for lo in range(0, len(bucket), size))

    def __len__(self):
        size = self.batch_size
        # Ceiling division per bucket.
        return sum(-(-len(bucket) // size) for bucket in self.buckets)

# ============================================================
# INFERENCE ENGINE
# ============================================================
class UltraInferenceEngine:
    """Load the seq2seq model once and translate DataFrames in batches.

    The engine owns a preprocessor (applied to transliterations), a
    postprocessor (applied to decoded translations), the model and its
    tokenizer. ``run_inference`` does not modify ``self.config``; per-call
    adjustments (tiny-dataset simplifications, adaptive beam width) are kept
    in local variables.
    """

    def __init__(self, config: "UltraConfig"):
        self.config = config
        self.preprocessor = OptimizedPreprocessor()
        self.postprocessor = VectorizedPostprocessor(aggressive=config.aggressive_postprocessing)
        self._load_model()

    def _load_model(self):
        """Load model and tokenizer from ``config.model_path``; optionally apply BetterTransformer."""
        logger.info(f"Loading model: {self.config.model_path}")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_path).to(self.config.device).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        logger.info(f"Model: {sum(p.numel() for p in self.model.parameters()):,} params")

        if self.config.use_better_transformer and torch.cuda.is_available():
            try:
                from optimum.bettertransformer import BetterTransformer
                self.model = BetterTransformer.transform(self.model)
                logger.info("✅ BetterTransformer applied")
            except Exception as e:
                # Best-effort optimisation: keep the plain model and record
                # that the transform is not in effect.
                logger.warning(f"⚠️ BetterTransformer skipped: {e}")
                self.config.use_better_transformer = False

    def _collate_fn(self, batch):
        """Split ``(id, text)`` pairs into an id list and one padded tokenizer batch."""
        ids = [s[0] for s in batch]
        texts = [s[1] for s in batch]
        tok = self.tokenizer(texts, max_length=self.config.max_length, padding=True,
                             truncation=True, return_tensors="pt")
        return ids, tok

    def _beams_for_batch(self, attn_mask) -> int:
        """Return the beam width to use for one batch.

        With adaptive beams enabled, a batch whose mean non-padded input
        length (sum of ``attn_mask`` per row) is below 100 uses half the
        configured beams, floored at 4 but never above ``config.num_beams``.
        """
        full = self.config.num_beams
        if not self.config.use_adaptive_beams:
            return full
        avg_len = attn_mask.sum(dim=1).float().mean().item()
        if avg_len < 100:
            return min(full, max(4, full // 2))
        return full

    def _build_loader(self, dataset, use_buckets: bool, num_workers: int) -> DataLoader:
        """Create the DataLoader: length-bucketed batches or plain sequential ones."""
        # Pinned host memory only benefits host -> GPU copies.
        pin = self.config.device.type == "cuda"
        if use_buckets:
            sampler = BucketBatchSampler(dataset, self.config.batch_size, self.config.num_buckets)
            return DataLoader(dataset, batch_sampler=sampler, num_workers=num_workers,
                              collate_fn=self._collate_fn, pin_memory=pin)
        return DataLoader(dataset, batch_size=self.config.batch_size, shuffle=False,
                          num_workers=num_workers, collate_fn=self._collate_fn, pin_memory=pin)

    def run_inference(self, test_df: pd.DataFrame) -> pd.DataFrame:
        """Translate every row of ``test_df``.

        Args:
            test_df: DataFrame with the ``id`` and ``transliteration``
                columns read by ``AkkadianDataset``.

        Returns:
            DataFrame with columns ``['id', 'translation']``, one row per
            input row, in the same order as ``test_df``.
        """
        from contextlib import nullcontext

        logger.info("🚀 Starting inference")

        # Tiny datasets: skip bucketing and worker processes. Kept local so
        # the shared config is not silently changed for later calls.
        tiny = len(test_df) < self.config.tiny_dataset_threshold
        use_buckets = self.config.use_bucket_batching and not tiny
        num_workers = 0 if tiny else self.config.num_workers

        dataset = AkkadianDataset(test_df, self.preprocessor)
        loader = self._build_loader(dataset, use_buckets, num_workers)

        gen_cfg = {
            "max_new_tokens": self.config.max_new_tokens,
            "length_penalty": self.config.length_penalty,
            "early_stopping": self.config.early_stopping,
            "use_cache": True,
            "num_beams": self.config.num_beams,
        }
        if self.config.no_repeat_ngram_size > 0:
            gen_cfg["no_repeat_ngram_size"] = self.config.no_repeat_ngram_size

        results = []
        with torch.inference_mode():
            for batch_ids, tok in tqdm(loader, desc="🚀 Translating"):
                input_ids = tok.input_ids.to(self.config.device)
                attn_mask = tok.attention_mask.to(self.config.device)
                gen_cfg["num_beams"] = self._beams_for_batch(attn_mask)

                # fp16 autocast on CUDA when enabled (UltraConfig turns it
                # off on CPU-only hosts).
                amp_ctx = (torch.autocast(device_type="cuda", dtype=torch.float16)
                           if self.config.use_mixed_precision else nullcontext())
                with amp_ctx:
                    outputs = self.model.generate(input_ids=input_ids, attention_mask=attn_mask, **gen_cfg)

                translations = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                cleaned = self.postprocessor.postprocess_batch(translations)
                results.extend(zip(batch_ids, cleaned))

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Bucket batching visits samples in length order; restore input order.
        position = {sample_id: pos for pos, sample_id in enumerate(dataset.sample_ids)}
        results.sort(key=lambda row: position[row[0]])

        logger.info("✅ Inference complete")
        return pd.DataFrame(results, columns=['id', 'translation'])

# ============================================================
# MAIN EXECUTION
# ============================================================
# Load the test split, translate it and write submission.csv.
logger.info(f"Loading test data: {config.test_data_path}")
test_df = pd.read_csv(config.test_data_path, encoding='utf-8')
logger.info(f"✅ Loaded {len(test_df)} samples")

engine = UltraInferenceEngine(config)
results_df = engine.run_inference(test_df)

# Save submission.csv
output_path = Path(config.output_dir) / 'submission.csv'
results_df.to_csv(output_path, index=False)
logger.info(f"✅ Saved: {output_path}")

# Validation: every input row should have exactly one translation row.
if len(results_df) != len(test_df):
    logger.warning(f"Row count mismatch: {len(results_df)} translations for {len(test_df)} inputs")

print("\n" + "="*60)
print("📊 SUBMISSION SUMMARY")
print("="*60)
print(f"Total rows: {len(results_df)}")
print(f"Empty translations: {results_df['translation'].str.strip().eq('').sum()}")
print(f"Length range: [{results_df['translation'].str.len().min()}, {results_df['translation'].str.len().max()}]")
print("\nFirst 3:")
print(results_df.head(3).to_string(index=False))
print("="*60)


❌ BetterTransformer NOT available


2026-02-01 10:58:53,968 - INFO - Loading test data: /kaggle/input/deep-past-initiative-machine-translation/test.csv
2026-02-01 10:58:53,982 - INFO - ✅ Loaded 4 samples
2026-02-01 10:58:53,983 - INFO - Loading model: /kaggle/input/final-byt5/byt5-akkadian-optimized-34x


PyTorch: 2.8.0+cu126 | CUDA: True
GPU: Tesla T4 | 15.6GB


2026-02-01 10:58:56.440110: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769943536.651258      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769943536.710689      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769943537.209247      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769943537.209292      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769943537.209295      24 computation_placer.cc:177] computation placer alr

üöÄ Translating:   0%|          | 0/1 [00:00<?, ?it/s]

2026-02-01 10:59:33,512 - INFO - ✅ Inference complete
2026-02-01 10:59:33,520 - INFO - ✅ Saved: /kaggle/working/submission.csv



üìä SUBMISSION SUMMARY
Total rows: 4
Empty translations: 0
Length range: [70, 242]

First 3:
 id                                                                                                                                                                                                                                        translation
  0                                                                                                                                                                             Thus says the Kanesh colony: Speak to our messengers every single day.
  1                                                             As for the tablet of the City, you wrote to me in the tablet of the City. On this day, whoever informs me or does not detain me in the proceedings, the colony of Kanesh will receive.
  2 As soon as you hear our letter, there, either he gave anything to the palace, or he did not give anything until us, or he did not give anything until us. Let us