# Text Normalization (English) - Medal Plan and Experiment Log

## Goal
- Win a medal (>= bronze) on Kaggle's Text Normalization Challenge - English.

## High-level Plan
1. Data understanding: load train/test/sample_submission, inspect schema and target.
2. Baselines:
   - Identity baseline (predict `before` -> `after` as same or majority mapping).
   - Memorization: freq-based mapping from `before` (optionally conditioned on class/context) to most-common `after`.
3. Modeling:
   - Rule-based/regex normalizers for categories: numbers, dates, times, money, measures, ordinals, telephone, electronic, etc.
   - Hybrid: backoff to memorized mapping; else identity if class says `PLAIN`/`PUNCT`.
   - Optional ML: character-level seq2seq or edit-distance based transducer (if time).
4. CV strategy: time-safe split by sequence/utterance id; metric is accuracy over tokens; replicate Kaggle eval locally.
5. Iterate with error analysis by class; prioritize high-volume error buckets.

## Experiment Log
- [T0] Setup, load data, schema inspection.
- [T1] Build identity and memorize baselines; local eval.
- [T2] Add deterministic rules for common semiotic classes; iterate.
- [T3] Full train fit and inference; create submission.csv; refine until medal range.

We will request expert reviews at major milestones (plan, EDA, baseline, modeling, and if score stalls).

In [2]:
# T0: Setup, load data, schema inspection
import os, sys, gc, time, json, math, textwrap, zipfile, psutil
from pathlib import Path
import pandas as pd
import numpy as np

pd.set_option('display.max_colwidth', 200)
pd.set_option('display.width', 200)

DATA_TRAIN = 'en_train.csv.zip'
DATA_TEST = 'en_test_2.csv.zip'
DATA_SUB = 'en_sample_submission_2.csv.zip'

def mem():
    p = psutil.Process(os.getpid())
    return f"RAM used: {p.memory_info().rss/1e9:.2f} GB"

t0 = time.time()
print('CWD:', os.getcwd())
print('Files:', sorted(os.listdir()))
print(mem())

def read_csv_zip(path):
    # pandas can read zip directly
    return pd.read_csv(path)

print('\nLoading train...'); t = time.time()
train = read_csv_zip(DATA_TRAIN)
print('Train shape:', train.shape, '| elapsed:', f"{time.time()-t:.2f}s")
print(mem())

print('\nTrain columns:', list(train.columns))
print('\nTrain head:')
display(train.head(10))

# Optional quick peeks if present
for col in ['class', 'semiotic_class']:
    if col in train.columns:
        vc = train[col].value_counts().head(20)
        print(f"\nTop {col} counts:\n", vc)

for col in ['before', 'after']:
    if col in train.columns:
        nulls = train[col].isna().sum()
        print(f"Nulls in {col}:", nulls)

print('\nLoading test...'); t = time.time()
test = read_csv_zip(DATA_TEST)
print('Test shape:', test.shape, '| elapsed:', f"{time.time()-t:.2f}s")
print('Test columns:', list(test.columns))

print('\nLoading sample submission...'); t = time.time()
sub = read_csv_zip(DATA_SUB)
print('Sample submission shape:', sub.shape, '| elapsed:', f"{time.time()-t:.2f}s")
print('Sample submission columns:', list(sub.columns))

print('\nDtypes (train):')
print(train.dtypes)

print('\nBasic checks complete. Total elapsed:', f"{time.time()-t0:.2f}s")

CWD: /app/agent_run_states/text-normalization-challenge-english-language-spray-20250909-022048
Files: ['agent_metadata', 'description.md', 'docker_run.log', 'en_sample_submission_2.csv.zip', 'en_test_2.csv.zip', 'en_train.csv.zip', 'main.ipynb', 'requirements.txt', 'task.txt']
RAM used: 0.17 GB

Loading train...


Train shape: (8924976, 5) | elapsed: 6.33s
RAM used: 1.13 GB

Train columns: ['sentence_id', 'token_id', 'class', 'before', 'after']

Train head:


Unnamed: 0,sentence_id,token_id,class,before,after
0,0,0,PLAIN,Brillantaisia,Brillantaisia
1,0,1,PLAIN,is,is
2,0,2,PLAIN,a,a
3,0,3,PLAIN,genus,genus
4,0,4,PLAIN,of,of
5,0,5,PLAIN,plant,plant
6,0,6,PLAIN,in,in
7,0,7,PLAIN,family,family
8,0,8,PLAIN,Acanthaceae,Acanthaceae
9,0,9,PUNCT,.,.



Top class counts:
 class
PLAIN         6617274
PUNCT         1692369
DATE           232281
LETTERS        137499
CARDINAL       120265
VERBATIM        70160
MEASURE         13285
ORDINAL         11386
DECIMAL          8828
MONEY            5510
DIGIT            4929
ELECTRONIC       4690
TELEPHONE        3623
TIME             1318
FRACTION         1086
ADDRESS           473
Name: count, dtype: int64


Nulls in before: 144


Nulls in after: 140

Loading test...


Test shape: (993465, 3) | elapsed: 0.31s
Test columns: ['sentence_id', 'token_id', 'before']

Loading sample submission...


Sample submission shape: (993465, 2) | elapsed: 0.60s
Sample submission columns: ['id', 'after']

Dtypes (train):
sentence_id     int64
token_id        int64
class          object
before         object
after          object
dtype: object

Basic checks complete. Total elapsed: 8.57s


In [4]:
# T1: Baselines, mapping, and first submission
import time
t = time.time()

# 1) Verify submission ID format against sample
print('Verifying submission id format...')
sub_head = sub.head(3).copy()
tmp_ids = (test['sentence_id'].astype(str) + '_' + test['token_id'].astype(str)).head(3).tolist()
print('Sample sub head ids:', sub_head['id'].tolist())
print('Constructed head ids:', tmp_ids)
id_format_ok = True
for a,b in zip(sub_head['id'].tolist(), tmp_ids):
    if a != b:
        id_format_ok = False
        break
print('ID format matches sample:', id_format_ok)

# 2) Handle nulls: drop only for training map building; keep test but fill for safety
n_before_null = train['before'].isna().sum()
n_after_null = train['after'].isna().sum()
print(f'Nulls -> train.before: {n_before_null}, train.after: {n_after_null}')
train_clean = train.dropna(subset=['before','after']).copy()
print('Train after dropna:', train_clean.shape)

# 3) Identity baseline on train
identity_acc = (train_clean['before'] == train_clean['after']).mean()
print(f'Identity baseline accuracy (train): {identity_acc:.6f}')

# 4) Memorization map: before -> most frequent after
print('Building primary memorization map...')
t_map = time.time()
# Using value_counts idxmax for speed and determinism
mode_after = train_clean.groupby('before', sort=False)['after'].agg(lambda s: s.value_counts().idxmax())
memo_map = mode_after.to_dict()
print(f'Map size: {len(memo_map):,} | elapsed: {time.time()-t_map:.2f}s')

# 5) Lowercase backoff map
print('Building lowercase backoff map...')
t_low = time.time()
train_clean['_before_lower'] = train_clean['before'].str.lower()
mode_after_lower = train_clean.groupby('_before_lower', sort=False)['after'].agg(lambda s: s.value_counts().idxmax())
memo_map_lower = mode_after_lower.to_dict()
print(f'Lower map size: {len(memo_map_lower):,} | elapsed: {time.time()-t_low:.2f}s')
train_clean.drop(columns=['_before_lower'], inplace=True)

# 6) Inference on test: primary -> lowercase -> identity
print('Inferring on test...')
test_inf = test.copy()
test_inf['before'] = test_inf['before'].fillna('')
pred = test_inf['before'].map(memo_map)
mask_na = pred.isna()
if mask_na.any():
    pred2 = test_inf.loc[mask_na, 'before'].str.lower().map(memo_map_lower)
    pred.loc[mask_na] = pred2
mask_na2 = pred.isna()
if mask_na2.any():
    pred.loc[mask_na2] = test_inf.loc[mask_na2, 'before']

# 7) Build submission in the same order as sample
test_ids = test_inf['sentence_id'].astype(str) + '_' + test_inf['token_id'].astype(str)
submission = pd.DataFrame({'id': test_ids, 'after': pred.astype(str)})
print('Submission shape:', submission.shape)

# Ensure row count matches sample and first few ids align
assert submission.shape[0] == sub.shape[0], f'Row count mismatch: {submission.shape[0]} vs sample {sub.shape[0]}'
print('First 3 ids (ours):', submission['id'].head(3).tolist())
print('First 3 ids (sample):', sub['id'].head(3).tolist())

# 8) Save
out_path = 'submission.csv'
submission.to_csv(out_path, index=False)
print('Saved to', out_path)
print('T1 total elapsed:', f'{time.time()-t:.2f}s')

Verifying submission id format...


Sample sub head ids: ['0_0', '0_1', '0_2']
Constructed head ids: ['0_0', '0_1', '0_2']
ID format matches sample: True


Nulls -> train.before: 144, train.after: 140


Train after dropna: (8924832, 5)
Identity baseline accuracy (train): 0.933502
Building primary memorization map...


Map size: 455,220 | elapsed: 87.14s
Building lowercase backoff map...


Lower map size: 404,763 | elapsed: 79.28s


Inferring on test...


Submission shape: (993465, 2)
First 3 ids (ours): ['0_0', '0_1', '0_2']
First 3 ids (sample): ['0_0', '0_1', '0_2']


Saved to submission.csv
T1 total elapsed: 173.55s


In [6]:
# T1b: Quick local CV with GroupShuffleSplit by sentence_id for memo baseline
import time
from sklearn.model_selection import GroupShuffleSplit

t0 = time.time()
print('Starting GroupShuffleSplit CV (80/20) ...')
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
groups = train['sentence_id'].values
train_idx, val_idx = next(gss.split(train, groups=groups))
print('Train idx size:', len(train_idx), 'Val idx size:', len(val_idx))

trn = train.iloc[train_idx].dropna(subset=['before','after']).copy()
val = train.iloc[val_idx].copy()
print('Train (clean) shape:', trn.shape, 'Val shape:', val.shape)

t = time.time()
print('Building maps on train split ...')
mode_after_cv = trn.groupby('before', sort=False)['after'].agg(lambda s: s.value_counts().idxmax())
memo_map_cv = mode_after_cv.to_dict()
trn['_before_lower'] = trn['before'].str.lower()
mode_after_lower_cv = trn.groupby('_before_lower', sort=False)['after'].agg(lambda s: s.value_counts().idxmax())
memo_map_lower_cv = mode_after_lower_cv.to_dict()
print('Maps built | sizes:', len(memo_map_cv), len(memo_map_lower_cv), '| elapsed:', f'{time.time()-t:.2f}s')
trn.drop(columns=['_before_lower'], inplace=True)

print('Predicting on validation ...')
val_bef = val['before'].fillna('')
pred_val = val_bef.map(memo_map_cv)
na1 = pred_val.isna()
if na1.any():
    pred_val.loc[na1] = val_bef.loc[na1].str.lower().map(memo_map_lower_cv)
na2 = pred_val.isna()
if na2.any():
    pred_val.loc[na2] = val_bef.loc[na2]

acc = (pred_val.values == val['after'].values).mean()
print(f'CV accuracy (memo + lower backoff): {acc:.6f}')

# Error analysis by class (top 20)
err_mask = pred_val.values != val['after'].values
if 'class' in val.columns:
    err_classes = val.loc[err_mask, 'class'].value_counts().head(20)
    print('\nTop error classes (val):')
    print(err_classes)
else:
    print('Class column not available in validation set.')

print('T1b total elapsed:', f'{time.time()-t0:.2f}s')

Starting GroupShuffleSplit CV (80/20) ...


Train idx size: 7136041 Val idx size: 1788935


Train (clean) shape: (7135924, 5) Val shape: (1788935, 5)
Building maps on train split ...


Maps built | sizes: 395355 351290 | elapsed: 144.49s


Predicting on validation ...


CV accuracy (memo + lower backoff): 0.986828

Top error classes (val):
class
PLAIN         8352
DATE          4871
LETTERS       3447
CARDINAL      1908
MEASURE       1120
ELECTRONIC     689
TELEPHONE      647
MONEY          608
DECIMAL        577
DIGIT          527
ORDINAL        324
TIME           191
FRACTION       144
VERBATIM       119
ADDRESS         40
Name: count, dtype: int64
T1b total elapsed: 149.98s


In [26]:
# T2: Class-aware memoization + class-guarded deterministic rules, CV and inference
import re, time

t0 = time.time()
print('Building class-aware maps from full train ...')
tr = train.dropna(subset=['before','after','class']).copy()

# before -> most frequent class
cls_map = tr.groupby('before', sort=False)['class'].agg(lambda s: s.value_counts().idxmax()).to_dict()
cls_map_lower = tr.assign(_bl=tr['before'].str.lower()).groupby('_bl', sort=False)['class'].agg(lambda s: s.value_counts().idxmax()).to_dict()

# (before, class) -> most frequent after
pair_map = tr.groupby(['before','class'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
tr['_bl'] = tr['before'].str.lower()
pair_map_lower = tr.groupby(['_bl','class'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
tr.drop(columns=['_bl'], inplace=True)
print('Map sizes | cls:', len(cls_map), 'pair:', len(pair_map))

# Reuse simple memo maps from T1 if available; otherwise build quickly
if 'memo_map' not in globals():
    memo_map = tr.groupby('before', sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
if 'memo_map_lower' not in globals():
    memo_map_lower = tr.assign(_bl=tr['before'].str.lower()).groupby('_bl', sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()

# ---------- Helpers ----------
digit_words = {'0':'zero','1':'one','2':'two','3':'three','4':'four','5':'five','6':'six','7':'seven','8':'eight','9':'nine'}
ones = ['zero','one','two','three','four','five','six','seven','eight','nine']
teens = ['ten','eleven','twelve','thirteen','fourteen','fifteen','sixteen','seventeen','eighteen','nineteen']
tens_words = ['', '', 'twenty','thirty','forty','fifty','sixty','seventy','eighty','ninety']

def two_digit(n):
    if n < 10: return ones[n]
    if n < 20: return teens[n-10]
    t, r = divmod(n, 10)
    return tens_words[t] + ('' if r==0 else ' ' + ones[r])

def three_digit(n):
    h, r = divmod(n, 100)
    if h == 0: return two_digit(r)
    if r == 0: return ones[h] + ' hundred'
    return ones[h] + ' hundred ' + two_digit(r)

def int_to_words(num):
    # supports up to billions comfortably
    num = int(num)
    if num == 0: return 'zero'
    parts = []
    billions, rem = divmod(num, 10**9)
    millions, rem = divmod(rem, 10**6)
    thousands, rem = divmod(rem, 1000)
    if billions: parts.append(three_digit(billions) + ' billion')
    if millions: parts.append(three_digit(millions) + ' million')
    if thousands: parts.append(three_digit(thousands) + ' thousand')
    if rem: parts.append(three_digit(rem))
    return ' '.join(parts)

def ordinal_word(n):
    n = int(n)
    special = {1:'first',2:'second',3:'third',4:'fourth',5:'fifth',6:'sixth',7:'seventh',8:'eighth',9:'ninth',10:'tenth',11:'eleventh',12:'twelfth',13:'thirteenth'}
    if n in special: return special[n]
    if n < 20: return two_digit(n) + 'th'
    t, r = divmod(n, 10)
    base = tens_words[t]
    if r == 0: return base + 'ieth' if t in [2,3,4,5,8,9] else base + 'th'
    return base + ' ' + special.get(r, ones[r] + 'th')

def normalize_decimal(tok):
    # 3.14 -> three point one four; 0.05 -> zero point zero five; .5 -> zero point five
    if tok.count('.') != 1: return None
    a,b = tok.split('.')
    a = a.replace(',', '')
    if a == '': a = '0'
    if not re.fullmatch(r'-?\d+', a) or not re.fullmatch(r'\d+', b):
        return None
    sign = 'minus ' if a.startswith('-') else ''
    if a.startswith('-'): a = a[1:]
    left = int_to_words(int(a))
    right = ' '.join(digit_words[ch] for ch in b)
    return (sign + left + ' point ' + right).strip()

month_names = {str(i).zfill(2): m for i,m in enumerate(['','january','february','march','april','may','june','july','august','september','october','november','december']) if i}
month_names.update({str(i): m for i,m in enumerate(['','january','february','march','april','may','june','july','august','september','october','november','december']) if i})

def year_to_words(y):
    y = int(y)
    if y < 1000: return int_to_words(y)
    if 1900 <= y <= 1999:
        h, t = divmod(y, 100)
        return int_to_words(h) + ' ' + two_digit(t)
    if 2000 <= y <= 2009:
        return 'two thousand' if y == 2000 else 'two thousand ' + int_to_words(y-2000)
    if 2010 <= y <= 2099:
        return 'two thousand ' + int_to_words(y-2000)
    return int_to_words(y)

def normalize_date(tok):
    # YYYY-MM-DD / MM/DD/YYYY / DD-MM-YYYY / M/D/YY / YYYY.MM.DD
    if re.fullmatch(r'\d{4}-\d{1,2}-\d{1,2}', tok):
        y,m,d = tok.split('-')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    if re.fullmatch(r'\d{1,2}/\d{1,2}/\d{4}', tok):
        m,d,y = tok.split('/')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    if re.fullmatch(r'\d{1,2}-\d{1,2}-\d{4}', tok):
        d,m,y = tok.split('-')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    if re.fullmatch(r'\d{1,2}/\d{1,2}/\d{2}', tok):
        m,d,yy = tok.split('/')
        y = '20' + yy if int(yy) <= 30 else '19' + yy
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    if re.fullmatch(r'\d{4}\.\d{1,2}\.\d{1,2}', tok):
        y,m,d = tok.split('.')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    return None

def normalize_letters(tok):
    # Avoid altering very short/common capitalized PLAIN tokens
    if tok in {'I','A','AM','PM','US','U.S.','U.S','UK','U.K.','UK.'}:
        return None
    if re.fullmatch(r'[A-Z]{2,}', tok):
        return ' '.join(list(tok.lower()))
    if re.fullmatch(r'([A-Z]\.){2,}[A-Z]?\.?', tok):
        return ' '.join(ch.lower() for ch in tok if ch.isalpha())
    if 'AT&T' in tok:
        return 'a t and t'
    return None

def normalize_digit(tok):
    if not re.fullmatch(r'\d+', tok): return None
    return ' '.join(digit_words[ch] for ch in tok)

def normalize_cardinal(tok):
    s = tok.replace(',', '')
    if not re.fullmatch(r'-?\d+', s): return None
    sign = 'minus ' if s.startswith('-') else ''
    if s.startswith('-'): s = s[1:]
    return (sign + int_to_words(int(s))).strip()

def normalize_ordinal(tok):
    m = re.fullmatch(r'(\d+)(st|nd|rd|th)', tok)
    if not m: return None
    return ordinal_word(m.group(1))

def normalize_fraction(tok):
    m = re.fullmatch(r'(\d+)/(\d+)', tok)
    if not m: return None
    a,b = int(m.group(1)), int(m.group(2))
    num = int_to_words(a)
    den = ordinal_word(b)
    if a == 1:
        spec = {2:'half',4:'quarter'}
        den_word = spec.get(b, den)
        return den_word
    else:
        if den.endswith('f'): den = den[:-1] + 'ves'
        elif den.endswith('y'): den = den[:-1] + 'ies'
        else: den = den + 's'
        return f"{num} {den}"

unit_map = {
    'ft': ('foot','feet'), 'in': ('inch','inches'), 'yd': ('yard','yards'), 'mi': ('mile','miles'),
    'lb': ('pound','pounds'), 'lbs': ('pound','pounds'), 'oz': ('ounce','ounces'),
    'kg': ('kilogram','kilograms'), 'g': ('gram','grams'), 'mg': ('milligram','milligrams'),
    'km': ('kilometer','kilometers'), 'm': ('meter','meters'), 'cm': ('centimeter','centimeters'), 'mm': ('millimeter','millimeters'),
    'l': ('liter','liters'), 'ml': ('milliliter','milliliters'),
    'mph': ('miles per hour','miles per hour'), 'km/h': ('kilometers per hour','kilometers per hour'), 'kph': ('kilometers per hour','kilometers per hour'),
    '%': ('percent','percent'), '\u00b0c': ('degrees celsius','degrees celsius'), '\u00b0f': ('degrees fahrenheit','degrees fahrenheit')
}

def normalize_measure(tok):
    t = tok.lower()
    m = re.fullmatch(r'(-?[\d,]+(?:\.\d+)?)[\s]*([a-z\u00b0/%]+(?:/[a-z]+)?)', t)
    if not m:
        m = re.fullmatch(r'(-?[\d,]+)([a-z\u00b0/%]+)', t)
    if not m: return None
    num_s, unit = m.group(1), m.group(2)
    unit = unit.strip()
    singular_plural = unit_map.get(unit)
    if not singular_plural:
        return None
    if '.' in num_s:
        spoken_num = normalize_decimal(num_s)
    else:
        spoken_num = normalize_cardinal(num_s)
    if spoken_num is None: return None
    try:
        val = float(num_s.replace(',', ''))
    except:
        val = None; 
    unit_word = singular_plural[0] if val == 1 else singular_plural[1]
    return f"{spoken_num} {unit_word}"

def normalize_money(tok):
    m = re.fullmatch(r'([$\u00a3\u20ac])\s*(\d*[\d,]*)(?:\.(\d{1,2}))?', tok)
    if not m: return None
    sym, a, c = m.groups()
    a = (a or '').replace(',', '')
    cur = {'$':'dollar','\u00a3':'pound','\u20ac':'euro'}[sym]
    out = []
    have_dollars = bool(a) and int(a) > 0
    cents_val = None
    if c is not None:
        cents_val = int(c.ljust(2,'0'))
    if have_dollars:
        out.append(int_to_words(int(a)) + (' ' + cur + ('s' if int(a)!=1 else '')))
    if cents_val is not None and cents_val > 0:
        out.append(int_to_words(cents_val) + (' cent' + ('s' if cents_val!=1 else '')))
    if not out and (cents_val is not None and cents_val > 0):
        return int_to_words(cents_val) + (' cent' + ('s' if cents_val!=1 else ''))
    if not out:
        return None
    return ' and '.join(out) if len(out) == 2 else out[0]

def normalize_time(tok):
    m = re.fullmatch(r'(\d{1,2}):(\d{2})', tok)
    if not m: return None
    hh = int(m.group(1)); mm = int(m.group(2))
    hour = int_to_words(hh)
    if mm == 0: return hour + " o'clock"
    if 1 <= mm <= 9:
        return hour + ' oh ' + int_to_words(mm)
    return hour + ' ' + int_to_words(mm)

def normalize_telephone(tok):
    t = tok
    digits = re.sub(r'\D', '', t)
    if not digits: return None
    spoken = ' '.join(digit_words[ch] for ch in digits)
    return spoken

def normalize_electronic(tok):
    t = tok.strip()
    # emails: require at least one dot after '@'
    if '@' in t:
        at_idx = t.find('@')
        if '.' in t[at_idx+1:]:
            parts = re.split(r'(@|\.)', t)
            spoken = []
            for p in parts:
                if p == '@': spoken.append('at')
                elif p == '.': spoken.append('dot')
                else: spoken.append(p)
            return ' '.join(spoken).replace('  ', ' ').strip()
    # urls/domains: require slash with letters/digits or www. prefix
    if t.lower().startswith('www.') or re.search(r'/[A-Za-z0-9]', t):
        rep = t.replace('.', ' dot ').replace('/', ' slash ').replace('-', ' dash ')
        rep = re.sub(r'\bwww\b', 'w w w', rep)
        return ' '.join(rep.split())
    return None

# Lightweight detectors (used only if class not in pair maps)
pat_decimal = re.compile(r'^-?[\d,]+\.\d+$')
pat_cardinal = re.compile(r'^-?[\d,]+$')
pat_ordinal = re.compile(r'^\d+(st|nd|rd|th)$')
pat_fraction = re.compile(r'^\d+/\d+$')
pat_date1 = re.compile(r'^\d{4}-\d{1,2}-\d{1,2}$')
pat_date2 = re.compile(r'^\d{1,2}/\d{1,2}/\d{4}$')
pat_date3 = re.compile(r'^\d{1,2}-\d{1,2}-\d{4}$')
pat_date4 = re.compile(r'^\d{1,2}/\d{1,2}/\d{2}$')
pat_date5 = re.compile(r'^\d{4}\.\d{1,2}\.\d{1,2}$')
pat_letters = re.compile(r'^[A-Z]{2,}$')
pat_phone = re.compile(r'^[+\d][\d\s().-]{5,}$')
pat_money = re.compile(r'^[$\u00a3\u20ac]')
pat_time = re.compile(r'^\d{1,2}:\d{2}$')
pat_measure = re.compile(r'^-?[\d,]+(?:\.\d+)?\s?[a-zA-Z\u00b0/%]+')

def infer_class(tok, prior=None):
    if prior: return prior
    # High-certainty symbols first
    if pat_money.search(tok): return 'MONEY'
    if pat_time.fullmatch(tok): return 'TIME'
    if pat_fraction.fullmatch(tok): return 'FRACTION'
    if pat_decimal.fullmatch(tok) or tok.startswith('.') and re.fullmatch(r'\.\d+', tok): return 'DECIMAL'
    if pat_ordinal.fullmatch(tok): return 'ORDINAL'
    if pat_cardinal.fullmatch(tok): return 'CARDINAL'
    # Dates
    if pat_date1.fullmatch(tok) or pat_date2.fullmatch(tok) or pat_date3.fullmatch(tok) or pat_date4.fullmatch(tok) or pat_date5.fullmatch(tok): return 'DATE'
    # LETTERS guarded by prior elsewhere; still detect pattern
    if pat_letters.fullmatch(tok): return 'LETTERS'
    # Telephone
    if pat_phone.fullmatch(tok): return 'TELEPHONE'
    # Electronic tightened
    t = tok.strip()
    if ('@' in t and '.' in t[t.find('@')+1:]) or t.lower().startswith('www.') or re.search(r'/[A-Za-z0-9]', t):
        return 'ELECTRONIC'
    if pat_measure.match(tok.lower()): return 'MEASURE'
    return None

def apply_rules(tok, cls):
    if cls == 'DATE':
        return normalize_date(tok)
    if cls == 'LETTERS':
        return normalize_letters(tok)
    if cls == 'DECIMAL':
        return normalize_decimal(tok)
    if cls == 'CARDINAL':
        return normalize_cardinal(tok)
    if cls == 'ORDINAL':
        return normalize_ordinal(tok)
    if cls == 'DIGIT':
        return normalize_digit(tok)
    if cls == 'MEASURE':
        return normalize_measure(tok)
    if cls == 'TELEPHONE':
        return normalize_telephone(tok)
    if cls == 'ELECTRONIC':
        return normalize_electronic(tok)
    if cls == 'MONEY':
        return normalize_money(tok)
    if cls == 'TIME':
        return normalize_time(tok)
    if cls == 'FRACTION':
        return normalize_fraction(tok)
    return None

def cascade_predict_tokens(df_before_series):
    bef = df_before_series.fillna('')
    # Predict prior class for all tokens
    prior = bef.map(cls_map)
    need_lower = prior.isna()
    if need_lower.any():
        prior.loc[need_lower] = bef.loc[need_lower].str.lower().map(cls_map_lower)

    # Stage 1: class-aware exact memo via pair maps; fallback to token memo
    pred = pd.Series(index=bef.index, dtype=object)
    # exact (before, prior) where prior is known
    has_prior = prior.notna()
    if has_prior.any():
        idxp = prior.index[has_prior]
        keys = list(zip(bef.loc[idxp].tolist(), prior.loc[idxp].tolist()))
        pred.loc[idxp] = [pair_map.get(k) for k in keys]
        # lower pair for remaining with prior
        rem = pred.loc[idxp].isna()
        if rem.any():
            idxp2 = idxp[rem.values]
            keys2 = list(zip(bef.loc[idxp2].str.lower().tolist(), prior.loc[idxp2].tolist()))
            pred.loc[idxp2] = [pair_map_lower.get(k) for k in keys2]
    # fill remaining with simple memo
    miss = pred.isna()
    if miss.any():
        pred.loc[miss] = bef.loc[miss].map(memo_map)

    # Stage 2: class-guarded rules (skip for predicted PLAIN/PUNCT)
    miss2 = pred.isna()
    if miss2.any():
        idx2 = pred.index[miss2]
        toks = bef.loc[idx2]
        prs = prior.loc[idx2]
        outs = []
        for tk, pr in zip(toks.tolist(), prs.tolist()):
            if pr in ('PLAIN','PUNCT'):
                outs.append(None)
            else:
                ic = infer_class(tk, pr)
                outs.append(apply_rules(tk, ic) if ic not in ('PLAIN','PUNCT', None) else None)
        pred.loc[idx2] = outs

    # Stage 3: lowercase memo
    miss3 = pred.isna()
    if miss3.any():
        pred.loc[miss3] = bef.loc[miss3].str.lower().map(memo_map_lower)

    # Stage 4: identity (and safeguard for predicted PLAIN/PUNCT to avoid unwanted changes)
    miss4 = pred.isna()
    if miss4.any():
        pred.loc[miss4] = bef.loc[miss4]
    if prior.notna().any():
        mask_plain = prior.isin(['PLAIN','PUNCT']) & (pred != bef)
        if mask_plain.any():
            pred.loc[mask_plain] = bef.loc[mask_plain]
    return pred

# ----- Local CV using same cascade -----
print('Running local CV (single split) with cascade ...')
from sklearn.model_selection import GroupShuffleSplit
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=123)
groups = train['sentence_id'].values
tr_idx, va_idx = next(gss.split(train, groups=groups))
va = train.iloc[va_idx].copy()
pred_va = cascade_predict_tokens(va['before'])
cv_acc = (pred_va.values == va['after'].values).mean()
print(f'Cascade CV accuracy: {cv_acc:.6f}')
err_mask = pred_va.values != va['after'].values
print('Errors remaining:', int(err_mask.sum()))
if 'class' in va.columns:
    print('Top error classes after rules:')
    print(va.loc[err_mask, 'class'].value_counts().head(15))

# ----- Inference on test with cascade -----
print('Inferring on test with cascade ...')
pred_test = cascade_predict_tokens(test['before'])
test_ids = test['sentence_id'].astype(str) + '_' + test['token_id'].astype(str)
submission2 = pd.DataFrame({'id': test_ids, 'after': pred_test.astype(str)})
submission2.to_csv('submission.csv', index=False)
print('Wrote improved submission.csv')
print('T2 total elapsed:', f'{time.time()-t0:.2f}s')

In [68]:
# T2b: Full-train guarded cascade with confidence-gated maps (mirrors leak-free CV logic) to generate submission
import time, re
t0 = time.time()

SAFE_RULE_CLASSES = {'CARDINAL','DECIMAL','ORDINAL','DIGIT','FRACTION','MONEY','MEASURE','TIME','DATE'}
GUARDED_RULE_CLASSES = {'LETTERS','ELECTRONIC','TELEPHONE','VERBATIM'}

def is_word_like(tok: str) -> bool:
    if not isinstance(tok, str):
        return False
    has_lower = any(c.isalpha() and c.islower() for c in tok)
    has_upper = any(c.isalpha() and c.isupper() for c in tok)
    return bool(tok.istitle() or (has_lower and has_upper) or has_lower)

def count_digits(tok: str) -> int:
    return sum(ch.isdigit() for ch in tok) if isinstance(tok, str) else 0

# Strong LETTERS heuristic
LETTERS_EXCEPT = {'I','A','AM','PM','US','U.S.','U.S','UK','U.K.','UK.','EU','U.N.','UN','TV','OK','OK.','AI','ML','CEO'}
def is_strong_letters(tok: str) -> bool:
    if tok in LETTERS_EXCEPT: return False
    if re.fullmatch(r'[A-Z]{3,}', tok): return True
    if re.fullmatch(r'([A-Z]\.){3,}[A-Z]?\.?', tok): return True
    if re.fullmatch(r'[A-Z]+-[A-Z]+', tok): return True
    return False

# Safety: robust number/ordinal helpers mirroring Cell 6
SCALE_WORDS = ['', 'thousand', 'million', 'billion', 'trillion', 'quadrillion', 'quintillion']
ONES_WORDS = ['zero','one','two','three','four','five','six','seven','eight','nine']
TEENS_WORDS = ['ten','eleven','twelve','thirteen','fourteen','fifteen','sixteen','seventeen','eighteen','nineteen']
TENS_WORDS = ['', '', 'twenty','thirty','forty','fifty','sixty','seventy','eighty','ninety']

def _two_digit_safe(n:int) -> str:
    if n < 10: return ONES_WORDS[n]
    if n < 20: return TEENS_WORDS[n-10]
    t, r = divmod(n, 10)
    return TENS_WORDS[t] + ('' if r==0 else ' ' + ONES_WORDS[r])

def _three_digit_safe(n:int) -> str:
    h, r = divmod(n, 100)
    if h == 0: return _two_digit_safe(r) if r else ''
    if r == 0: return ONES_WORDS[h] + ' hundred'
    return ONES_WORDS[h] + ' hundred ' + _two_digit_safe(r)

def int_to_words_safe(num_in) -> str:
    s = str(num_in).replace(',', '')
    sign = ''
    if s.startswith('-'): sign, s = 'minus ', s[1:]
    if not s.isdigit():
        try:
            s = str(int(float(s)))
        except:
            return ''
    n = int(s)
    if n == 0: return 'zero'
    parts = []
    idx = 0
    while n > 0 and idx < len(SCALE_WORDS):
        n, chunk = divmod(n, 1000)
        if chunk:
            words = _three_digit_safe(chunk)
            if SCALE_WORDS[idx]:
                parts.append(words + ' ' + SCALE_WORDS[idx])
            else:
                parts.append(words)
        idx += 1
    return (sign + ' '.join(reversed([p for p in parts if p]))).strip()

def ordinal_word_safe(n):
    try:
        n = int(n)
    except:
        return None
    try:
        return _ordinalize_words(int_to_words(n))
    except NameError:
        words = int_to_words_safe(n)
        ws = words.split()
        if not ws: return words
        last = ws[-1]
        if last.endswith('y'):
            ws[-1] = last[:-1] + 'ieth'
        elif last in ('one','two','three','five','eight','nine','twelve'):
            repl = {'one':'first','two':'second','three':'third','five':'fifth','eight':'eighth','nine':'ninth','twelve':'twelfth'}
            ws[-1] = repl[last]
        else:
            ws[-1] = last + 'th'
        return ' '.join(ws)

def normalize_fraction_safe(tok):
    m = re.fullmatch(r'(\d+)/(\d+)', tok or '')
    if not m: return None
    a,b = int(m.group(1)), int(m.group(2))
    num = int_to_words_safe(a)
    den = ordinal_word_safe(b)
    if a == 1:
        spec = {2:'half',4:'quarter'}
        den_word = spec.get(b, den)
        return den_word
    else:
        if den.endswith('f'): den = den[:-1] + 'ves'
        elif den.endswith('y'): den = den[:-1] + 'ies'
        else: den = den + 's'
        return f"{num} {den}"

def normalize_cardinal_safe(tok):
    s = str(tok).replace(',', '')
    sign = 'minus ' if s.startswith('-') else ''
    if s.startswith('-'): s = s[1:]
    if not re.fullmatch(r'\d+', s):
        return None
    return (sign + int_to_words_safe(int(s))).strip()

def normalize_decimal_safe(tok):
    if tok.count('.') != 1: return None
    a,b = tok.split('.')
    a = a.replace(',', '')
    if a == '': a = '0'
    if not re.fullmatch(r'-?\d+', a) or not re.fullmatch(r'\d+', b):
        return None
    sign = 'minus ' if a.startswith('-') else ''
    if a.startswith('-'): a = a[1:]
    left = int_to_words_safe(int(a))
    right = ' '.join({'0':'zero','1':'one','2':'two','3':'three','4':'four','5':'five','6':'six','7':'seven','8':'eight','9':'nine'}[ch] for ch in b)
    return (sign + left + ' point ' + right).strip()

def apply_rules_safe(tok, cls):
    if cls == 'FRACTION':
        return normalize_fraction_safe(tok)
    if cls == 'CARDINAL':
        return normalize_cardinal_safe(tok)
    if cls == 'ORDINAL':
        m = re.fullmatch(r'(\d+)(st|nd|rd|th)', tok or '')
        return ordinal_word_safe(m.group(1)) if m else None
    if cls == 'DECIMAL':
        return normalize_decimal_safe(tok)
    return apply_rules(tok, cls)

def build_maps(df):
    d = df.dropna(subset=['before','after','class']).copy()
    # cls maps with confidence
    grp_cls = d.groupby(['before','class']).size().rename('cnt').reset_index()
    tot_cls = grp_cls.groupby('before')['cnt'].sum().rename('tot')
    top_cls = grp_cls.sort_values(['before','cnt'], ascending=[True, False]).drop_duplicates('before')
    top_cls = top_cls.merge(tot_cls, on='before', how='left')
    top_cls['conf'] = top_cls['cnt'] / top_cls['tot']
    cls_map_loc = dict(zip(top_cls['before'], top_cls['class']))
    cls_conf_map_loc = dict(zip(top_cls['before'], top_cls['conf']))
    d['_bl'] = d['before'].str.lower()
    grp_cls_l = d.groupby(['_bl','class']).size().rename('cnt').reset_index()
    tot_cls_l = grp_cls_l.groupby('_bl')['cnt'].sum().rename('tot')
    top_cls_l = grp_cls_l.sort_values(['_bl','cnt'], ascending=[True, False]).drop_duplicates('_bl')
    top_cls_l = top_cls_l.merge(tot_cls_l, on='_bl', how='left')
    top_cls_l['conf'] = top_cls_l['cnt'] / top_cls_l['tot']
    cls_map_lower_loc = dict(zip(top_cls_l['_bl'], top_cls_l['class']))
    cls_conf_map_lower_loc = dict(zip(top_cls_l['_bl'], top_cls_l['conf']))

    # (before, class) -> after
    pair_map_loc = d.groupby(['before','class'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
    pair_map_lower_loc = d.groupby(['_bl','class'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()

    # memo maps with confidence
    grp_memo = d.groupby(['before','after']).size().rename('cnt').reset_index()
    tot_memo = grp_memo.groupby('before')['cnt'].sum().rename('tot')
    top_memo = grp_memo.sort_values(['before','cnt'], ascending=[True, False]).drop_duplicates('before')
    top_memo = top_memo.merge(tot_memo, on='before', how='left')
    top_memo['conf'] = top_memo['cnt'] / top_memo['tot']
    memo_map_loc = dict(zip(top_memo['before'], top_memo['after']))
    memo_conf_loc = dict(zip(top_memo['before'], top_memo['conf']))

    grp_memo_l = d.groupby(['_bl','after']).size().rename('cnt').reset_index()
    tot_memo_l = grp_memo_l.groupby('_bl')['cnt'].sum().rename('tot')
    top_memo_l = grp_memo_l.sort_values(['_bl','cnt'], ascending=[True, False]).drop_duplicates('_bl')
    top_memo_l = top_memo_l.merge(tot_memo_l, on='_bl', how='left')
    top_memo_l['conf'] = top_memo_l['cnt'] / top_memo_l['tot']
    memo_map_lower_loc = dict(zip(top_memo_l['_bl'], top_memo_l['after']))
    memo_conf_lower_loc = dict(zip(top_memo_l['_bl'], top_memo_l['conf']))

    return (cls_map_loc, cls_map_lower_loc, pair_map_loc, pair_map_lower_loc, memo_map_loc, memo_map_lower_loc,
            cls_conf_map_loc, cls_conf_map_lower_loc, memo_conf_loc, memo_conf_lower_loc)

def cascade_predict_tokens_with_maps(bef_series, maps, memo_conf_thr=0.70, cls_conf_thr=0.75):
    # Strategic order: pair_map -> guarded rules -> memo -> lower_memo -> identity -> final veto
    (cls_map_loc, cls_map_lower_loc, pair_map_loc, pair_map_lower_loc, memo_map_loc, memo_map_lower_loc,
     cls_conf_map_loc, cls_conf_map_lower_loc, memo_conf_loc, memo_conf_lower_loc) = maps
    bef = bef_series.fillna('')

    # Prior class and confidence
    prior = bef.map(cls_map_loc)
    prior_conf = bef.map(cls_conf_map_loc)
    need_lower = prior.isna()
    if need_lower.any():
        bef_l = bef.loc[need_lower].str.lower()
        prior.loc[need_lower] = bef_l.map(cls_map_lower_loc)
        prior_conf.loc[need_lower] = bef_l.map(cls_conf_map_lower_loc)

    # Stage 1: pair maps (gate by prior_conf >= 0.85)
    pred = pd.Series(index=bef.index, dtype=object)
    has_prior = prior.notna()
    if has_prior.any():
        conf_ok = prior_conf.fillna(0.0) >= 0.85
        idxp = prior.index[has_prior & conf_ok]
        if len(idxp) > 0:
            keys = list(zip(bef.loc[idxp].tolist(), prior.loc[idxp].tolist()))
            pred.loc[idxp] = [pair_map_loc.get(k) for k in keys]
            rem = pred.loc[idxp].isna()
            if rem.any():
                idxp2 = idxp[rem.values]
                keys2 = list(zip(bef.loc[idxp2].str.lower().tolist(), prior.loc[idxp2].tolist()))
                pred.loc[idxp2] = [pair_map_lower_loc.get(k) for k in keys2]

    # Stage 2: guarded rules with SAFE passthrough and relaxed LETTERS
    miss2 = pred.isna()
    if miss2.any():
        idx2 = pred.index[miss2]
        toks = bef.loc[idx2].tolist()
        prs = prior.loc[idx2].tolist()
        prc = prior_conf.loc[idx2].fillna(0.0).tolist()
        outs = []
        for tk, pr, pc in zip(toks, prs, prc):
            ic = infer_class(tk)  # infer independently
            # If prior is PLAIN/PUNCT, only skip for non-SAFE inferred classes
            if pr in ('PLAIN','PUNCT') and (ic not in SAFE_RULE_CLASSES):
                outs.append(None); continue
            # Special allowance: strong LETTERS acronyms (non word-like) may fire regardless of prior/pc
            special_letters = (ic == 'LETTERS' and is_strong_letters(tk) and not is_word_like(tk))
            # Guarded classes gating (unless special_letters)
            if ic in GUARDED_RULE_CLASSES and not special_letters:
                # word-like veto for guarded
                if is_word_like(tk):
                    outs.append(None); continue
                # confidence gating for guarded
                if pc < cls_conf_thr:
                    if ic == 'ELECTRONIC':
                        t = tk.strip()
                        cond_email = (t.count('@') == 1 and '.' in t[t.find('@')+1:])
                        cond_web = t.lower().startswith('www.') or re.search(r'/[A-Za-z0-9]', t)
                        if not (cond_email or cond_web):
                            outs.append(None); continue
                    if ic == 'TELEPHONE' and count_digits(tk) < 7:
                        outs.append(None); continue
            # apply rules
            if ic in SAFE_RULE_CLASSES:
                outs.append(apply_rules_safe(tk, ic))
            elif special_letters:
                outs.append(apply_rules_safe(tk, ic))
            elif ic in GUARDED_RULE_CLASSES and pr == ic:
                outs.append(apply_rules_safe(tk, ic))
            else:
                outs.append(None)
        pred.loc[idx2] = outs

    # Stage 3: memo with tighter, class-aware gating
    miss3 = pred.isna()
    if miss3.any():
        idx3 = pred.index[miss3]
        outs3 = []
        for i in idx3:
            tk = bef.loc[i]
            pr = prior.loc[i]
            pc = float(prior_conf.loc[i]) if prior_conf.loc[i] == prior_conf.loc[i] else 0.0
            if pr in ('PLAIN','PUNCT'):
                outs3.append(None); continue
            cand = memo_map_loc.get(tk)
            conf = float(memo_conf_loc.get(tk, 0.0))
            thr = 0.95 if pr in GUARDED_RULE_CLASSES else memo_conf_thr
            if pr in GUARDED_RULE_CLASSES and is_word_like(tk):
                outs3.append(None); continue
            if cand is not None and conf >= thr and pc >= cls_conf_thr:
                outs3.append(cand)
            else:
                outs3.append(None)
        pred.loc[idx3] = outs3

    # Stage 4: lowercase memo with tighter, class-aware gating
    miss4 = pred.isna()
    if miss4.any():
        idx4 = pred.index[miss4]
        outs4 = []
        for i in idx4:
            tk = bef.loc[i]
            pr = prior.loc[i]
            pc = float(prior_conf.loc[i]) if prior_conf.loc[i] == prior_conf.loc[i] else 0.0
            if pr in ('PLAIN','PUNCT'):
                outs4.append(None); continue
            tkl = str(tk).lower()
            cand2 = memo_map_lower_loc.get(tkl)
            conf2 = float(memo_conf_lower_loc.get(tkl, 0.0))
            thr2 = 0.96 if pr in GUARDED_RULE_CLASSES else memo_conf_thr
            if pr in GUARDED_RULE_CLASSES and is_word_like(tk):
                outs4.append(None); continue
            if cand2 is not None and conf2 >= thr2 and pc >= cls_conf_thr:
                outs4.append(cand2)
            else:
                outs4.append(None)
        pred.loc[idx4] = outs4

    # Stage 5: identity
    miss5 = pred.isna()
    if miss5.any():
        pred.loc[miss5] = bef.loc[miss5]

    # Stage 6: final vetoes - revert casing-only changes on word-like tokens
    try:
        pred_l = pred.fillna('').astype(str).str.lower()
        bef_l = bef.fillna('').astype(str).str.lower()
        casing_only = (pred_l == bef_l) & (pred != bef) & bef.apply(is_word_like)
        if casing_only.any():
            pred.loc[casing_only] = bef.loc[casing_only]
    except Exception:
        pass

    return pred

print('Building full-train maps with confidence ...')
tm = time.time()
maps_full = build_maps(train)
print('Maps built in', f'{time.time()-tm:.2f}s')

print('Inferring on test with guarded, confidence-gated cascade ...')
tp = time.time()
pred_test = cascade_predict_tokens_with_maps(test['before'], maps_full, memo_conf_thr=0.70, cls_conf_thr=0.75)
print('Test inference elapsed:', f'{time.time()-tp:.2f}s')

test_ids = test['sentence_id'].astype(str) + '_' + test['token_id'].astype(str)
submission_v1 = pd.DataFrame({'id': test_ids, 'after': pred_test.astype(str)})
submission_v1.to_csv('submission.csv', index=False)
print('Wrote submission.csv (guarded + confidence-gated, full-train maps)')
print('T2b elapsed:', f'{time.time()-t0:.2f}s')

Building full-train maps with confidence ...


Maps built in 206.68s
Inferring on test with guarded, confidence-gated cascade ...


Test inference elapsed: 6.89s


Wrote submission.csv (guarded + confidence-gated, full-train maps)
T2b elapsed: 215.11s


In [66]:
# T3: Leak-free CV (train-split maps) + confidence gating + strict guarding + final veto
import time, re
from sklearn.model_selection import GroupShuffleSplit

SAFE_RULE_CLASSES = {'CARDINAL','DECIMAL','ORDINAL','DIGIT','FRACTION','MONEY','MEASURE','TIME','DATE'}
GUARDED_RULE_CLASSES = {'LETTERS','ELECTRONIC','TELEPHONE','VERBATIM'}

def is_word_like(tok: str) -> bool:
    if not isinstance(tok, str):
        return False
    has_lower = any(c.isalpha() and c.islower() for c in tok)
    has_upper = any(c.isalpha() and c.isupper() for c in tok)
    return bool(tok.istitle() or (has_lower and has_upper) or has_lower)

def count_digits(tok: str) -> int:
    return sum(ch.isdigit() for ch in tok) if isinstance(tok, str) else 0

# Adjust LETTERS output to lowercase spaced letters
def normalize_letters(tok):
    if re.fullmatch(r'[A-Z]{2,}', tok):
        return ' '.join(list(tok.lower()))
    if re.fullmatch(r'([A-Z]\.){2,}[A-Z]?\.?', tok):
        return ' '.join(ch.lower() for ch in tok if ch.isalpha())
    if 'AT&T' in tok:
        return 'A T and T'.lower()
    return None

# Strong LETTERS heuristic
LETTERS_EXCEPT = {'I','A','AM','PM','US','U.S.','U.S','UK','U.K.','UK.','EU','U.N.','UN','TV','OK','OK.','AI','ML'}
def is_strong_letters(tok):
    if tok in LETTERS_EXCEPT: return False
    if re.fullmatch(r'[A-Z]{3,}', tok): return True
    if re.fullmatch(r'([A-Z]\.){3,}[A-Z]?\.?', tok): return True
    return False

# month name helpers for context and date normalization reused from support cell 9
months_full = ['january','february','march','april','may','june','july','august','september','october','november','december']
months_abbr = ['jan','feb','mar','apr','may','jun','jul','aug','sep','sept','oct','nov','dec','jan.','feb.','mar.','apr.','jun.','jul.','aug.','sep.','sept.','oct.','nov.','dec.']
month_set = set(months_full + months_abbr)

# Safety: robust number-to-words (handles large numbers by 3-digit groups) and safe ordinal/fraction
SCALE_WORDS = ['', 'thousand', 'million', 'billion', 'trillion', 'quadrillion', 'quintillion']
ONES_WORDS = ['zero','one','two','three','four','five','six','seven','eight','nine']
TEENS_WORDS = ['ten','eleven','twelve','thirteen','fourteen','fifteen','sixteen','seventeen','eighteen','nineteen']
TENS_WORDS = ['', '', 'twenty','thirty','forty','fifty','sixty','seventy','eighty','ninety']

def _two_digit_safe(n:int) -> str:
    if n < 10: return ONES_WORDS[n]
    if n < 20: return TEENS_WORDS[n-10]
    t, r = divmod(n, 10)
    return TENS_WORDS[t] + ('' if r==0 else ' ' + ONES_WORDS[r])

def _three_digit_safe(n:int) -> str:
    h, r = divmod(n, 100)
    if h == 0: return _two_digit_safe(r) if r else ''
    if r == 0: return ONES_WORDS[h] + ' hundred'
    return ONES_WORDS[h] + ' hundred ' + _two_digit_safe(r)

def int_to_words_safe(num_in) -> str:
    s = str(num_in).replace(',', '')
    sign = ''
    if s.startswith('-'): sign, s = 'minus ', s[1:]
    if not s.isdigit():
        try:
            s = str(int(float(s)))
        except:
            return ''
    n = int(s)
    if n == 0: return 'zero'
    parts = []
    idx = 0
    while n > 0 and idx < len(SCALE_WORDS):
        n, chunk = divmod(n, 1000)
        if chunk:
            words = _three_digit_safe(chunk)
            if SCALE_WORDS[idx]:
                parts.append(words + ' ' + SCALE_WORDS[idx])
            else:
                parts.append(words)
        idx += 1
    return (sign + ' '.join(reversed([p for p in parts if p]))).strip()

# Safety wrappers to avoid ordinal_word IndexError from older defs
def ordinal_word_safe(n):
    try:
        n = int(n)
    except:
        return None
    try:
        return _ordinalize_words(int_to_words(n))
    except NameError:
        words = int_to_words_safe(n)
        ws = words.split()
        if not ws:
            return words
        last = ws[-1]
        if last.endswith('y'):
            ws[-1] = last[:-1] + 'ieth'
        elif last in ('one','two','three','five','eight','nine','twelve'):
            repl = {'one':'first','two':'second','three':'third','five':'fifth','eight':'eighth','nine':'ninth','twelve':'twelfth'}
            ws[-1] = repl[last]
        else:
            ws[-1] = last + 'th'
        return ' '.join(ws)

def normalize_fraction_safe(tok):
    m = re.fullmatch(r'(\d+)/(\d+)', tok or '')
    if not m: return None
    a,b = int(m.group(1)), int(m.group(2))
    num = int_to_words_safe(a)
    den = ordinal_word_safe(b)
    if a == 1:
        spec = {2:'half',4:'quarter'}
        den_word = spec.get(b, den)
        return den_word
    else:
        if den.endswith('f'): den = den[:-1] + 'ves'
        elif den.endswith('y'): den = den[:-1] + 'ies'
        else: den = den + 's'
        return f"{num} {den}"

def normalize_cardinal_safe(tok):
    s = str(tok).replace(',', '')
    sign = 'minus ' if s.startswith('-') else ''
    if s.startswith('-'): s = s[1:]
    if not re.fullmatch(r'\d+', s):
        return None
    return (sign + int_to_words_safe(int(s))).strip()

def normalize_decimal_safe(tok):
    if tok.count('.') != 1: return None
    a,b = tok.split('.')
    a = a.replace(',', '')
    if a == '': a = '0'
    if not re.fullmatch(r'-?\d+', a) or not re.fullmatch(r'\d+', b):
        return None
    sign = 'minus ' if a.startswith('-') else ''
    if a.startswith('-'): a = a[1:]
    left = int_to_words_safe(int(a))
    right = ' '.join({'0':'zero','1':'one','2':'two','3':'three','4':'four','5':'five','6':'six','7':'seven','8':'eight','9':'nine'}[ch] for ch in b)
    return (sign + left + ' point ' + right).strip()

def apply_rules_safe(tok, cls):
    if cls == 'FRACTION':
        return normalize_fraction_safe(tok)
    if cls == 'CARDINAL':
        return normalize_cardinal_safe(tok)
    if cls == 'ORDINAL':
        m = re.fullmatch(r'(\d+)(st|nd|rd|th)', tok or '')
        return ordinal_word_safe(m.group(1)) if m else None
    if cls == 'DECIMAL':
        return normalize_decimal_safe(tok)
    # fall back to global apply_rules for other classes (from support cell 9)
    return apply_rules(tok, cls)

def build_maps(df):
    # returns maps + confidence for cls and memo variants
    d = df.dropna(subset=['before','after','class']).copy()
    # cls maps with confidence
    grp_cls = d.groupby(['before','class']).size().rename('cnt').reset_index()
    tot_cls = grp_cls.groupby('before')['cnt'].sum().rename('tot')
    top_cls = grp_cls.sort_values(['before','cnt'], ascending=[True, False]).drop_duplicates('before')
    top_cls = top_cls.merge(tot_cls, on='before', how='left')
    top_cls['conf'] = top_cls['cnt'] / top_cls['tot']
    cls_map_loc = dict(zip(top_cls['before'], top_cls['class']))
    cls_conf_map_loc = dict(zip(top_cls['before'], top_cls['conf']))
    d['_bl'] = d['before'].str.lower()
    grp_cls_l = d.groupby(['_bl','class']).size().rename('cnt').reset_index()
    tot_cls_l = grp_cls_l.groupby('_bl')['cnt'].sum().rename('tot')
    top_cls_l = grp_cls_l.sort_values(['_bl','cnt'], ascending=[True, False]).drop_duplicates('_bl')
    top_cls_l = top_cls_l.merge(tot_cls_l, on='_bl', how='left')
    top_cls_l['conf'] = top_cls_l['cnt'] / top_cls_l['tot']
    cls_map_lower_loc = dict(zip(top_cls_l['_bl'], top_cls_l['class']))
    cls_conf_map_lower_loc = dict(zip(top_cls_l['_bl'], top_cls_l['conf']))

    # (before, class) -> after (no confidence used here)
    pair_map_loc = d.groupby(['before','class'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
    pair_map_lower_loc = d.groupby(['_bl','class'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()

    # memo maps with confidence
    grp_memo = d.groupby(['before','after']).size().rename('cnt').reset_index()
    tot_memo = grp_memo.groupby('before')['cnt'].sum().rename('tot')
    top_memo = grp_memo.sort_values(['before','cnt'], ascending=[True, False]).drop_duplicates('before')
    top_memo = top_memo.merge(tot_memo, on='before', how='left')
    top_memo['conf'] = top_memo['cnt'] / top_memo['tot']
    memo_map_loc = dict(zip(top_memo['before'], top_memo['after']))
    memo_conf_loc = dict(zip(top_memo['before'], top_memo['conf']))

    grp_memo_l = d.groupby(['_bl','after']).size().rename('cnt').reset_index()
    tot_memo_l = grp_memo_l.groupby('_bl')['cnt'].sum().rename('tot')
    top_memo_l = grp_memo_l.sort_values(['_bl','cnt'], ascending=[True, False]).drop_duplicates('_bl')
    top_memo_l = top_memo_l.merge(tot_memo_l, on='_bl', how='left')
    top_memo_l['conf'] = top_memo_l['cnt'] / top_memo_l['tot']
    memo_map_lower_loc = dict(zip(top_memo_l['_bl'], top_memo_l['after']))
    memo_conf_lower_loc = dict(zip(top_memo_l['_bl'], top_memo_l['conf']))

    return (cls_map_loc, cls_map_lower_loc, pair_map_loc, pair_map_lower_loc, memo_map_loc, memo_map_lower_loc,
            cls_conf_map_loc, cls_conf_map_lower_loc, memo_conf_loc, memo_conf_lower_loc)

def cascade_predict_tokens_with_maps(bef_series, maps, memo_conf_thr=0.70, cls_conf_thr=0.75):
    # Strategic order: pair_map -> guarded rules -> memo -> lower_memo -> identity -> final veto
    (cls_map_loc, cls_map_lower_loc, pair_map_loc, pair_map_lower_loc, memo_map_loc, memo_map_lower_loc,
     cls_conf_map_loc, cls_conf_map_lower_loc, memo_conf_loc, memo_conf_lower_loc) = maps
    bef = bef_series.fillna('')

    # Prior class and confidence
    prior = bef.map(cls_map_loc)
    prior_conf = bef.map(cls_conf_map_loc)
    need_lower = prior.isna()
    if need_lower.any():
        bef_l = bef.loc[need_lower].str.lower()
        prior.loc[need_lower] = bef_l.map(cls_map_lower_loc)
        prior_conf.loc[need_lower] = bef_l.map(cls_conf_map_lower_loc)

    # Stage 1: pair maps (gate by prior_conf >= 0.85)
    pred = pd.Series(index=bef.index, dtype=object)
    has_prior = prior.notna()
    if has_prior.any():
        conf_ok = prior_conf.fillna(0.0) >= 0.85
        idxp = prior.index[has_prior & conf_ok]
        if len(idxp) > 0:
            keys = list(zip(bef.loc[idxp].tolist(), prior.loc[idxp].tolist()))
            pred.loc[idxp] = [pair_map_loc.get(k) for k in keys]
            rem = pred.loc[idxp].isna()
            if rem.any():
                idxp2 = idxp[rem.values]
                keys2 = list(zip(bef.loc[idxp2].str.lower().tolist(), prior.loc[idxp2].tolist()))
                pred.loc[idxp2] = [pair_map_lower_loc.get(k) for k in keys2]

    # Stage 2: guarded rules with SAFE passthrough and relaxed LETTERS
    miss2 = pred.isna()
    if miss2.any():
        idx2 = pred.index[miss2]
        toks = bef.loc[idx2].tolist()
        prs = prior.loc[idx2].tolist()
        prc = prior_conf.loc[idx2].fillna(0.0).tolist()
        outs = []
        for tk, pr, pc in zip(toks, prs, prc):
            ic = infer_class(tk)  # infer independently
            # If prior is PLAIN/PUNCT, only skip for non-SAFE inferred classes
            if pr in ('PLAIN','PUNCT') and (ic not in SAFE_RULE_CLASSES):
                outs.append(None); continue
            # Special allowance: strong LETTERS acronyms (non word-like) may fire regardless of prior/pc
            special_letters = (ic == 'LETTERS' and is_strong_letters(tk) and not is_word_like(tk))
            # Guarded classes gating (unless special_letters)
            if ic in GUARDED_RULE_CLASSES and not special_letters:
                # word-like veto for guarded
                if is_word_like(tk):
                    outs.append(None); continue
                # confidence gating for guarded
                if pc < cls_conf_thr:
                    if ic == 'ELECTRONIC':
                        t = tk.strip()
                        cond_email = (t.count('@') == 1 and '.' in t[t.find('@')+1:])
                        cond_web = t.lower().startswith('www.') or re.search(r'/[A-Za-z0-9]', t)
                        if not (cond_email or cond_web):
                            outs.append(None); continue
                    if ic == 'TELEPHONE' and count_digits(tk) < 7:
                        outs.append(None); continue
            # apply rules
            if ic in SAFE_RULE_CLASSES:
                outs.append(apply_rules_safe(tk, ic))
            elif special_letters:
                outs.append(apply_rules_safe(tk, ic))
            elif ic in GUARDED_RULE_CLASSES and pr == ic:
                outs.append(apply_rules_safe(tk, ic))
            else:
                outs.append(None)
        pred.loc[idx2] = outs

    # Stage 3: memo with tighter, class-aware gating
    miss3 = pred.isna()
    if miss3.any():
        idx3 = pred.index[miss3]
        outs3 = []
        for i in idx3:
            tk = bef.loc[i]
            pr = prior.loc[i]
            pc = float(prior_conf.loc[i]) if prior_conf.loc[i] == prior_conf.loc[i] else 0.0
            if pr in ('PLAIN','PUNCT'):
                outs3.append(None); continue
            cand = memo_map_loc.get(tk)
            conf = float(memo_conf_loc.get(tk, 0.0))
            # dynamic threshold: stricter for guarded priors
            thr = 0.95 if pr in GUARDED_RULE_CLASSES else memo_conf_thr
            # extra veto for word-like tokens under guarded prior
            if pr in GUARDED_RULE_CLASSES and is_word_like(tk):
                outs3.append(None); continue
            if cand is not None and conf >= thr and pc >= cls_conf_thr:
                outs3.append(cand)
            else:
                outs3.append(None)
        pred.loc[idx3] = outs3

    # Stage 4: lowercase memo with tighter, class-aware gating
    miss4 = pred.isna()
    if miss4.any():
        idx4 = pred.index[miss4]
        outs4 = []
        for i in idx4:
            tk = bef.loc[i]
            pr = prior.loc[i]
            pc = float(prior_conf.loc[i]) if prior_conf.loc[i] == prior_conf.loc[i] else 0.0
            if pr in ('PLAIN','PUNCT'):
                outs4.append(None); continue
            tkl = str(tk).lower()
            cand2 = memo_map_lower_loc.get(tkl)
            conf2 = float(memo_conf_lower_loc.get(tkl, 0.0))
            thr2 = 0.96 if pr in GUARDED_RULE_CLASSES else memo_conf_thr
            if pr in GUARDED_RULE_CLASSES and is_word_like(tk):
                outs4.append(None); continue
            if cand2 is not None and conf2 >= thr2 and pc >= cls_conf_thr:
                outs4.append(cand2)
            else:
                outs4.append(None)
        pred.loc[idx4] = outs4

    # Stage 5: identity
    miss5 = pred.isna()
    if miss5.any():
        pred.loc[miss5] = bef.loc[miss5]

    # Stage 6: final vetoes
    # Revert casing-only changes on word-like tokens
    try:
        pred_l = pred.fillna('').astype(str).str.lower()
        bef_l = bef.fillna('').astype(str).str.lower()
        casing_only = (pred_l == bef_l) & (pred != bef) & bef.apply(is_word_like)
        if casing_only.any():
            pred.loc[casing_only] = bef.loc[casing_only]
    except Exception:
        pass

    return pred

# Context post-processing disabled (confirmed to slightly hurt CV)

print('Leak-free CV: building maps on train split only ...')
t0 = time.time()
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=2025)
groups = train['sentence_id'].values
tr_idx, va_idx = next(gss.split(train, groups=groups))
trn = train.iloc[tr_idx].copy()
val = train.iloc[va_idx].copy()
print('Split sizes:', trn.shape, val.shape)
tm = time.time()
maps = build_maps(trn)
print('Maps built (leak-free) in', f"{time.time()-tm:.2f}s")
print('Predicting validation with guarded cascade + confidence gating ...')
tp = time.time()
pred_val = cascade_predict_tokens_with_maps(val['before'], maps, memo_conf_thr=0.70, cls_conf_thr=0.75)
print('Val prediction elapsed:', f"{time.time()-tp:.2f}s")
acc = (pred_val.values == val['after'].values).mean()
print(f'Leak-free CV accuracy (guarded+conf): {acc:.6f}')
err_mask = pred_val.values != val['after'].values
print('Errors remaining:', int(err_mask.sum()))
print('Top error classes:')
print(val.loc[err_mask, 'class'].value_counts().head(15))
print('T3 elapsed:', f"{time.time()-t0:.2f}s")

Leak-free CV: building maps on train split only ...


Split sizes: (7140066, 5) (1784910, 5)


Maps built (leak-free) in 178.96s
Predicting validation with guarded cascade + confidence gating ...


Val prediction elapsed: 14.00s
Leak-free CV accuracy (guarded+conf): 0.991772
Errors remaining: 14687
Top error classes:
class
DATE          4962
PLAIN         3463
LETTERS       1780
CARDINAL       824
TELEPHONE      685
ELECTRONIC     652
DIGIT          635
MEASURE        603
ORDINAL        260
DECIMAL        234
MONEY          213
VERBATIM       186
TIME           106
ADDRESS         53
FRACTION        31
Name: count, dtype: int64
T3 elapsed: 195.00s


In [39]:
# T4: Context-aware post-processing (month + day ordinals, year in dates) and new submission
import time, re
t0 = time.time()

months_full = ['january','february','march','april','may','june','july','august','september','october','november','december']
months_abbr = ['jan','feb','mar','apr','may','jun','jul','aug','sep','sept','oct','nov','dec']
month_set = set(months_full + months_abbr)

def is_int_str(s):
    return re.fullmatch(r'\d+', s) is not None

def context_postprocess(df_tokens, pred_series):
    # df_tokens must have columns: sentence_id, token_id, before
    pred = pred_series.copy()
    df = df_tokens[['sentence_id','token_id','before']].copy()
    df['pred'] = pred.values
    df['before_l'] = df['before'].fillna('').astype(str)
    df['before_l'] = df['before_l']
    # Build previous tokens within sentence
    df = df.sort_values(['sentence_id','token_id'])
    df['prev_before'] = df.groupby('sentence_id')['before_l'].shift(1)
    df['prev2_before'] = df.groupby('sentence_id')['before_l'].shift(2)
    df['prev_pred'] = df.groupby('sentence_id')['pred'].shift(1)

    # Rule 1: Day-of-month ordinals when preceded by a month name/abbr
    # If current token is integer 1..31 and previous token is a month token
    def day_to_ordinal_safe(x):
        try:
            v = int(x)
            if 1 <= v <= 31:
                return ordinal_word(v)
        except:
            return None
        return None

    mask_day = df['before_l'].str.fullmatch(r'\d{1,2}')
    prev_is_month = df['prev_before'].str.lower().isin(month_set)
    m1 = mask_day & prev_is_month
    if m1.any():
        df.loc[m1, 'pred'] = df.loc[m1, 'before_l'].apply(day_to_ordinal_safe)

    # Rule 2: Year words when part of date context: if token is 4-digit year and prev token is month or prev2 is month with prev comma
    def is_year_token(s):
        if re.fullmatch(r'\d{4}', s):
            y = int(s)
            return 1000 <= y <= 2099
        return False

    mask_year = df['before_l'].apply(is_year_token)
    prev_is_comma = df['prev_before'] == ','
    prev_is_month2 = df['prev2_before'].str.lower().isin(month_set)
    prev_is_month1 = df['prev_before'].str.lower().isin(month_set)
    m2 = mask_year & (prev_is_month1 | (prev_is_comma & prev_is_month2))
    if m2.any():
        df.loc[m2, 'pred'] = df.loc[m2, 'before_l'].apply(year_to_words)

    # Return updated prediction aligned to original index order
    df = df.sort_index()
    return df['pred']

print('Generating base predictions with cascade v1 ...')
pred_test_base = cascade_predict_tokens_v1(test['before'])
print('Applying context post-processing ...')
pred_test_ctx = context_postprocess(test[['sentence_id','token_id','before']], pred_test_base)

test_ids = test['sentence_id'].astype(str) + '_' + test['token_id'].astype(str)
sub_ctx = pd.DataFrame({'id': test_ids, 'after': pred_test_ctx.astype(str)})
sub_ctx.to_csv('submission.csv', index=False)
print('Context-enhanced submission.csv written')
print('T4 elapsed:', f'{time.time()-t0:.2f}s')

Generating base predictions with cascade v1 ...


Applying context post-processing ...


Context-enhanced submission.csv written
T4 elapsed: 6.50s


In [20]:
# T5: Add context-aware memo (prev/before/next) to reduce PLAIN errors; infer on test and save submission
import time
t0 = time.time()
print('Building context-aware maps from full train ...')
tr_ctx = train.dropna(subset=['before','after']).copy()
tr_ctx = tr_ctx.sort_values(['sentence_id','token_id'])
tr_ctx['prev_before'] = tr_ctx.groupby('sentence_id')['before'].shift(1).fillna('')
tr_ctx['next_before'] = tr_ctx.groupby('sentence_id')['before'].shift(-1).fillna('')

# (prev, before) -> after
prev_map = tr_ctx.groupby(['prev_before','before'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
# (before, next) -> after
next_map = tr_ctx.groupby(['before','next_before'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
# lowercase backoffs
tr_ctx['_prev_l'] = tr_ctx['prev_before'].str.lower()
tr_ctx['_bef_l'] = tr_ctx['before'].str.lower()
tr_ctx['_next_l'] = tr_ctx['next_before'].str.lower()
prev_map_lower = tr_ctx.groupby(['_prev_l','_bef_l'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
next_map_lower = tr_ctx.groupby(['_bef_l','_next_l'], sort=False)['after'].agg(lambda s: s.value_counts().idxmax()).to_dict()
print('Context maps sizes | prev:', len(prev_map), 'next:', len(next_map))

def cascade_predict_tokens_v2(df_tokens):
    # df_tokens must have sentence_id, token_id, before
    df = df_tokens[['sentence_id','token_id','before']].copy()
    df = df.sort_values(['sentence_id','token_id'])
    bef = df['before'].fillna('')
    # Stage 1: exact memo
    pred = bef.map(memo_map)
    miss = pred.isna()
    if miss.any():
        idx = pred.index[miss]
        sub_df = df.loc[idx].copy()
        sub_df['prev_before'] = sub_df.groupby('sentence_id')['before'].shift(1).fillna('')
        sub_df['next_before'] = sub_df.groupby('sentence_id')['before'].shift(-1).fillna('')
        # prev map exact
        keys_prev = list(zip(sub_df['prev_before'].tolist(), sub_df['before'].tolist()))
        vals_prev = [prev_map.get(k) for k in keys_prev]
        pred.loc[idx] = vals_prev
        # fill remaining via next map exact
        miss2 = pred.loc[idx].isna()
        if miss2.any():
            idx2 = idx[miss2.values]
            sub2 = df.loc[idx2]
            sub2_prev = sub2.groupby('sentence_id')['before'].shift(1).fillna('')
            sub2_next = sub2.groupby('sentence_id')['before'].shift(-1).fillna('')
            keys_next = list(zip(sub2['before'].tolist(), sub2_next.tolist()))
            pred.loc[idx2] = [next_map.get(k) for k in keys_next]
        # lowercase backoffs for context
        miss3 = pred.loc[idx].isna()
        if miss3.any():
            idx3 = idx[miss3.values]
            sub3 = df.loc[idx3]
            prev_l = sub3.groupby('sentence_id')['before'].shift(1).fillna('').str.lower()
            bef_l = sub3['before'].str.lower()
            next_l = sub3.groupby('sentence_id')['before'].shift(-1).fillna('').str.lower()
            keys_prev_l = list(zip(prev_l.tolist(), bef_l.tolist()))
            vals_prev_l = [prev_map_lower.get(k) for k in keys_prev_l]
            pred.loc[idx3] = vals_prev_l
            miss4 = pred.loc[idx3].isna()
            if miss4.any():
                idx4 = idx3[miss4.values]
                sub4 = df.loc[idx4]
                bef4_l = sub4['before'].str.lower()
                next4_l = sub4.groupby('sentence_id')['before'].shift(-1).fillna('').str.lower()
                keys_next_l = list(zip(bef4_l.tolist(), next4_l.tolist()))
                pred.loc[idx4] = [next_map_lower.get(k) for k in keys_next_l]
    # Stage 2: class-aware pair/rules (reuse earlier infer/rules with prior from cls_map) for remaining
    missA = pred.isna()
    if missA.any():
        miss_idx = missA[missA].index
        sub_bef = bef.loc[miss_idx]
        prior = sub_bef.map(cls_map)
        need_lower = prior.isna()
        if need_lower.any():
            prior.loc[need_lower] = sub_bef.loc[need_lower].str.lower().map(cls_map_lower)
        # (before, class) maps
        key_exact = list(zip(sub_bef.tolist(), prior.tolist()))
        vals = [pair_map.get(k) for k in key_exact]
        out2 = pd.Series(vals, index=miss_idx, dtype=object)
        rem = out2.isna()
        if rem.any():
            idxr = out2.index[rem]
            keys2 = list(zip(sub_bef.loc[idxr].str.lower().tolist(), prior.loc[idxr].tolist()))
            out2.loc[idxr] = [pair_map_lower.get(k) for k in keys2]
        rem2 = out2.isna()
        if rem2.any():
            idx2 = out2.index[rem2]
            toks = sub_bef.loc[idx2].tolist()
            prs = prior.loc[idx2].tolist()
            finals = []
            for tk, pr in zip(toks, prs):
                ic = infer_class(tk, pr)
                finals.append(apply_rules(tk, ic) if ic not in ('PLAIN','PUNCT', None) else None)
            out2.loc[idx2] = finals
        pred.loc[miss_idx] = out2.values
    # Stage 3: lowercase memo
    missB = pred.isna()
    if missB.any():
        pred.loc[missB] = bef.loc[missB].str.lower().map(memo_map_lower)
    # Stage 4: identity
    missC = pred.isna()
    if missC.any():
        pred.loc[missC] = bef.loc[missC]
    return pred

print('Inferring on test with context-aware cascade v2 ...')
pred_test_v2 = cascade_predict_tokens_v2(test[['sentence_id','token_id','before']])
test_ids = test['sentence_id'].astype(str) + '_' + test['token_id'].astype(str)
submission_ctx = pd.DataFrame({'id': test_ids, 'after': pred_test_v2.astype(str)})
submission_ctx.to_csv('submission.csv', index=False)
print('Wrote submission.csv with context-aware memo v2')
print('T5 elapsed:', f'{time.time()-t0:.2f}s')

In [27]:
# T3-support: Lightweight rule and detector definitions (no heavy map building)
import re

digit_words = {'0':'zero','1':'one','2':'two','3':'three','4':'four','5':'five','6':'six','7':'seven','8':'eight','9':'nine'}
ones = ['zero','one','two','three','four','five','six','seven','eight','nine']
teens = ['ten','eleven','twelve','thirteen','fourteen','fifteen','sixteen','seventeen','eighteen','nineteen']
tens_words = ['', '', 'twenty','thirty','forty','fifty','sixty','seventy','eighty','ninety']

def two_digit(n):
    if n < 10: return ones[n]
    if n < 20: return teens[n-10]
    t, r = divmod(n, 10)
    return tens_words[t] + ('' if r==0 else ' ' + ones[r])

def three_digit(n):
    h, r = divmod(n, 100)
    if h == 0: return two_digit(r)
    if r == 0: return ones[h] + ' hundred'
    return ones[h] + ' hundred ' + two_digit(r)

def int_to_words(num):
    num = int(num)
    if num == 0: return 'zero'
    parts = []
    billions, rem = divmod(num, 10**9)
    millions, rem = divmod(rem, 10**6)
    thousands, rem = divmod(rem, 1000)
    if billions: parts.append(three_digit(billions) + ' billion')
    if millions: parts.append(three_digit(millions) + ' million')
    if thousands: parts.append(three_digit(thousands) + ' thousand')
    if rem: parts.append(three_digit(rem))
    return ' '.join(parts)

def _ordinalize_words(words: str) -> str:
    irregular = {
        'one':'first','two':'second','three':'third','four':'fourth','five':'fifth','six':'sixth','seven':'seventh','eight':'eighth','nine':'ninth',
        'ten':'tenth','eleven':'eleventh','twelve':'twelfth','thirteen':'thirteenth','fourteen':'fourteenth','fifteen':'fifteenth','sixteen':'sixteenth','seventeen':'seventeenth','eighteen':'eighteenth','nineteen':'nineteenth',
        'twenty':'twentieth','thirty':'thirtieth','forty':'fortieth','fifty':'fiftieth','sixty':'sixtieth','seventy':'seventieth','eighty':'eightieth','ninety':'ninetieth'
    }
    scale = {'hundred','thousand','million','billion'}
    ws = words.split()
    if not ws:
        return words
    if ws[-1] in scale:
        ws[-1] = ws[-1] + 'th'
        return ' '.join(ws)
    last = ws[-1]
    if last in irregular:
        ws[-1] = irregular[last]
    elif last.endswith('y'):
        ws[-1] = last[:-1] + 'ieth'
    else:
        ws[-1] = last + 'th'
    return ' '.join(ws)

def ordinal_word(n):
    n = int(n)
    # Robust ordinalization for any positive integer
    if n <= 0:
        return int_to_words(n)  # fallback
    return _ordinalize_words(int_to_words(n))

def normalize_decimal(tok):
    if tok.count('.') != 1: return None
    a,b = tok.split('.')
    a = a.replace(',', '')
    if a == '': a = '0'
    if not re.fullmatch(r'-?\d+', a) or not re.fullmatch(r'\d+', b):
        return None
    sign = 'minus ' if a.startswith('-') else ''
    if a.startswith('-'): a = a[1:]
    left = int_to_words(int(a))
    right = ' '.join(digit_words[ch] for ch in b)
    return (sign + left + ' point ' + right).strip()

month_names = {str(i).zfill(2): m for i,m in enumerate(['','january','february','march','april','may','june','july','august','september','october','november','december']) if i}
month_names.update({str(i): m for i,m in enumerate(['','january','february','march','april','may','june','july','august','september','october','november','december']) if i})

# Month name handling (full and abbreviations, with optional dot)
MONTH_CANON = ['january','february','march','april','may','june','july','august','september','october','november','december']
MONTH_ABBR = {'jan':'january','feb':'february','mar':'march','apr':'april','may':'may','jun':'june','jul':'july','aug':'august','sep':'september','sept':'september','oct':'october','nov':'november','dec':'december'}
WEEKDAY = ['monday','tuesday','wednesday','thursday','friday','saturday','sunday']
MONTH_ALT_PATTERN = r'(?:jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec|january|february|march|april|june|july|august|september|october|november|december)\.?'

def canon_month(s):
    s0 = s.strip().lower().rstrip('.')
    if s0 in MONTH_CANON: return s0
    if s0 in MONTH_ABBR: return MONTH_ABBR[s0]
    return None

def year_to_words(y):
    y = int(y)
    if y < 1000: return int_to_words(y)
    if 1900 <= y <= 1999:
        h, t = divmod(y, 100)
        return int_to_words(h) + ' ' + two_digit(t)
    if 2000 <= y <= 2009:
        return 'two thousand' if y == 2000 else 'two thousand ' + int_to_words(y-2000)
    if 2010 <= y <= 2099:
        return 'two thousand ' + int_to_words(y-2000)
    return int_to_words(y)

def decade_to_words(y):
    y = int(y)
    century = y // 100
    decade = (y % 100) // 10 * 10
    if 1900 <= y <= 1990:
        return int_to_words(century) + ' ' + two_digit(decade).replace('y', 'ies')
    if 2000 <= y <= 2090:
        if decade == 0:
            return 'two thousands'
        return 'two thousand ' + two_digit(decade).replace('y', 'ies')
    return int_to_words(y)

def _unambiguous_mdy(d1, d2, y):
    try:
        a = int(d1); b = int(d2)
        y_int = int(y)
    except:
        return None
    if len(y) == 2:
        y4 = int('20'+y) if int(y) <= 30 else int('19'+y)
    else:
        y4 = y_int
    if a <= 12 and b <= 12:
        return None
    if a <= 12 and 1 <= b <= 31:
        return (a, b, y4)
    if b <= 12 and 1 <= a <= 31:
        return (b, a, y4)
    return None

def normalize_date(tok):
    # YYYY-MM-DD
    if re.fullmatch(r'\d{4}-\d{1,2}-\d{1,2}', tok):
        y,m,d = tok.split('-')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    # YYYY/MM/DD
    if re.fullmatch(r'\d{4}/\d{1,2}/\d{1,2}', tok):
        y,m,d = tok.split('/')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    # DD-MM-YYYY or MM-DD-YYYY (unambiguous guard)
    if re.fullmatch(r'\d{1,2}-\d{1,2}-\d{4}', tok):
        a,b,y = tok.split('-')
        if int(a) <= 12 and int(b) <= 12:
            return None
        if int(a) <= 12:
            m,d = a,b
        else:
            d,m = a,b
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    # M/D/YY or M/D/YYYY with ambiguity guard
    if re.fullmatch(r'\d{1,2}/\d{1,2}/\d{2,4}', tok):
        a,b,y = tok.split('/')
        parsed = _unambiguous_mdy(a,b,y)
        if not parsed: return None
        m,d,y4 = parsed
        mname = month_names.get(str(m).zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y4)}"
    # YYYY.MM.DD or DD.MM.YYYY
    if re.fullmatch(r'\d{4}\.\d{1,2}\.\d{1,2}', tok):
        y,m,d = tok.split('.')
        mname = month_names.get(m.zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y)}"
    if re.fullmatch(r'\d{1,2}\.\d{1,2}\.\d{2,4}', tok):
        a,b,y = tok.split('.')
        parsed = _unambiguous_mdy(a,b,y)
        if not parsed: return None
        m,d,y4 = parsed
        mname = month_names.get(str(m).zfill(2))
        if not mname: return None
        return f"{mname} {ordinal_word(int(d))} {year_to_words(y4)}"
    # Month name patterns
    s = tok.strip()
    m = re.fullmatch(r'(?:(?:Mon|Tue|Tues|Wed|Thu|Thur|Thurs|Fri|Sat|Sun|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday),?\s+)?(' + MONTH_ALT_PATTERN + r')\s+(\d{1,2})(?:,)?(?:\s+(\d{4}))?$', s, re.IGNORECASE)
    if m:
        mon_raw, d, y = m.group(1), m.group(2), m.group(3)
        mon = canon_month(mon_raw)
        if not mon: return None
        out = f"{mon} {ordinal_word(int(d))}"
        if y: out += ' ' + year_to_words(y)
        return out
    m = re.fullmatch(r'^(\d{1,2})\s+(' + MONTH_ALT_PATTERN + r')(?:,)?(?:\s+(\d{4}))?$', s, re.IGNORECASE)
    if m:
        d, mon_raw, y = m.group(1), m.group(2), m.group(3)
        mon = canon_month(mon_raw)
        if not mon: return None
        out = f"{mon} {ordinal_word(int(d))}"
        if y: out += ' ' + year_to_words(y)
        return out
    m = re.fullmatch(r'(\d{4})s', tok)
    if m:
        y = int(m.group(1))
        base = decade_to_words(y)
        return base
    return None

LETTERS_EXCEPT = {'I','A','AM','PM','US','U.S.','U.S','UK','U.K.','UK.','EU','U.N.','UN','TV','OK','OK.','AI','ML','CEO'}

def normalize_letters(tok):
    if tok in LETTERS_EXCEPT:
        return None
    if re.fullmatch(r'[A-Z]{3,}', tok):
        return ' '.join(list(tok.lower()))
    if re.fullmatch(r'([A-Z]\.){3,}[A-Z]?\.?', tok):
        return ' '.join(ch.lower() for ch in tok if ch.isalpha())
    if re.fullmatch(r'[A-Z]+-[A-Z]+', tok):
        letters = tok.replace('-', '')
        return ' '.join(list(letters.lower()))
    if 'AT&T' in tok:
        return 'a t and t'
    return None

def normalize_digit(tok):
    if not re.fullmatch(r'\d+', tok): return None
    return ' '.join(digit_words[ch] for ch in tok)

def normalize_cardinal(tok):
    s = tok.replace(',', '')
    if not re.fullmatch(r'-?\d+', s): return None
    sign = 'minus ' if s.startswith('-') else ''
    if s.startswith('-'): s = s[1:]
    return (sign + int_to_words(int(s))).strip()

def normalize_ordinal(tok):
    m = re.fullmatch(r'(\d+)(st|nd|rd|th)', tok)
    if not m: return None
    return ordinal_word(m.group(1))

def normalize_fraction(tok):
    m = re.fullmatch(r'(\d+)/(\d+)', tok)
    if not m: return None
    a,b = int(m.group(1)), int(m.group(2))
    num = int_to_words(a)
    den = ordinal_word(b)
    if a == 1:
        spec = {2:'half',4:'quarter'}
        den_word = spec.get(b, den)
        return den_word
    else:
        if den.endswith('f'): den = den[:-1] + 'ves'
        elif den.endswith('y'): den = den[:-1] + 'ies'
        else: den = den + 's'
        return f"{num} {den}"

unit_map = {
    'ft': ('foot','feet'), 'in': ('inch','inches'), 'yd': ('yard','yards'), 'mi': ('mile','miles'),
    'lb': ('pound','pounds'), 'lbs': ('pound','pounds'), 'oz': ('ounce','ounces'),
    'kg': ('kilogram','kilograms'), 'g': ('gram','grams'), 'mg': ('milligram','milligrams'),
    'km': ('kilometer','kilometers'), 'm': ('meter','meters'), 'cm': ('centimeter','centimeters'), 'mm': ('millimeter','millimeters'),
    'l': ('liter','liters'), 'ml': ('milliliter','milliliters'),
    'mph': ('miles per hour','miles per hour'), 'km/h': ('kilometers per hour','kilometers per hour'), 'kph': ('kilometers per hour','kilometers per hour'),
    '%': ('percent','percent'), '\u00b0c': ('degrees celsius','degrees celsius'), '\u00b0f': ('degrees fahrenheit','degrees fahrenheit')
}

def normalize_measure(tok):
    t = tok.lower()
    m = re.fullmatch(r'(-?[\d,]+(?:\.\d+)?)[\s]*([a-z\u00b0/%]+(?:/[a-z]+)?)', t)
    if not m:
        m = re.fullmatch(r'(-?[\d,]+)([a-z\u00b0/%]+)', t)
    if not m: return None
    num_s, unit = m.group(1), m.group(2)
    unit = unit.strip()
    singular_plural = unit_map.get(unit)
    if not singular_plural:
        return None
    if '.' in num_s:
        spoken_num = normalize_decimal(num_s)
    else:
        spoken_num = normalize_cardinal(num_s)
    if spoken_num is None: return None
    try:
        val = float(num_s.replace(',', ''))
    except:
        val = None
    unit_word = singular_plural[0] if val == 1 else singular_plural[1]
    return f"{spoken_num} {unit_word}"

def normalize_money(tok):
    m = re.fullmatch(r'([$\u00a3\u20ac])\s*(\d*[\d,]*)(?:\.(\d{1,2}))?', tok)
    if not m: return None
    sym, a, c = m.groups()
    a = (a or '').replace(',', '')
    cur = {'$':'dollar','\u00a3':'pound','\u20ac':'euro'}[sym]
    out = []
    have_dollars = bool(a) and int(a) > 0
    cents_val = None
    if c is not None:
        cents_val = int(c.ljust(2,'0'))
    if have_dollars:
        out.append(int_to_words(int(a)) + (' ' + cur + ('s' if int(a)!=1 else '')))
    if cents_val is not None and cents_val > 0:
        out.append(int_to_words(cents_val) + (' cent' + ('s' if cents_val!=1 else '')))
    if not out and (cents_val is not None and cents_val > 0):
        return int_to_words(cents_val) + (' cent' + ('s' if cents_val!=1 else ''))
    if not out:
        return None
    return ' and '.join(out) if len(out) == 2 else out[0]

def normalize_time(tok):
    m = re.fullmatch(r'(\d{1,2}):(\d{2})', tok)
    if not m: return None
    hh = int(m.group(1)); mm = int(m.group(2))
    hour = int_to_words(hh)
    if mm == 0: return hour + " o'clock"
    if 1 <= mm <= 9:
        return hour + ' oh ' + int_to_words(mm)
    return hour + ' ' + int_to_words(mm)

def normalize_telephone(tok):
    t = tok
    digits = re.sub(r'\D', '', t)
    if len(digits) < 7: return None
    spoken = ' '.join(digit_words[ch] for ch in digits)
    return spoken

def normalize_electronic(tok):
    t = tok.strip()
    # emails: require exactly one '@' and at least one dot after '@'
    if t.count('@') == 1:
        at_idx = t.find('@')
        if '.' in t[at_idx+1:]:
            parts = re.split(r'(@|\.)', t)
            spoken = []
            for p in parts:
                if p == '@': spoken.append('at')
                elif p == '.': spoken.append('dot')
                else: spoken.append(p)
            return ' '.join(spoken).replace('  ', ' ').strip()
    # urls/domains: require www. prefix or a slash with alnum after
    if t.lower().startswith('www.') or re.search(r'/[A-Za-z0-9]', t):
        rep = t.replace('.', ' dot ').replace('/', ' slash ').replace('-', ' dash ')
        rep = re.sub(r'\bwww\b', 'w w w', rep)
        return ' '.join(rep.split())
    return None

# Detectors (tightened LETTERS and DATE ambiguity guard, add month-name dates)
pat_decimal = re.compile(r'^-?[\d,]+\.\d+$')
pat_cardinal = re.compile(r'^-?[\d,]+$')
pat_ordinal = re.compile(r'^\d+(st|nd|rd|th)$')
pat_fraction = re.compile(r'^\d+/\d+$')
pat_date_y_m_d_dash = re.compile(r'^\d{4}-\d{1,2}-\d{1,2}$')
pat_date_y_m_d_slash = re.compile(r'^\d{4}/\d{1,2}/\d{1,2}$')
pat_date_two_sep = re.compile(r'^(\d{1,2})[/-](\d{1,2})[/-](\d{2,4})$')
pat_date_dot_y_m_d = re.compile(r'^\d{4}\.\d{1,2}\.\d{1,2}$')
pat_date_dot_two_sep = re.compile(r'^(\d{1,2})\.(\d{1,2})\.(\d{2,4})$')
pat_decade = re.compile(r'^\d{4}s$')
pat_letters_strong = re.compile(r'^(?:[A-Z]{3,}|(?:[A-Z]\.){3,}[A-Z]?\.?|[A-Z]+-[A-Z]+)$')
pat_phone = re.compile(r'^[+\d][\d\s().-]{6,}$')  # at least ~7 digits overall
pat_money = re.compile(r'^[$\u00a3\u20ac]')
pat_time = re.compile(r'^\d{1,2}:\d{2}$')
pat_measure = re.compile(r'^-?[\d,]+(?:\.\d+)?\s?[a-zA-Z\u00b0/%]+')
# Month-name date detectors
pat_month_name_1 = re.compile(r'(?:(?:Mon|Tue|Tues|Wed|Thu|Thur|Thurs|Fri|Sat|Sun|Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday),?\s+)?(' + MONTH_ALT_PATTERN + r')\s+\d{1,2}(?:,)?(?:\s+\d{4})?$', re.IGNORECASE)
pat_month_name_2 = re.compile(r'^\d{1,2}\s+(' + MONTH_ALT_PATTERN + r')(?:,)?(?:\s+\d{4})?$', re.IGNORECASE)

def infer_class(tok, prior=None):
    if prior: return prior
    if pat_money.search(tok): return 'MONEY'
    if pat_time.fullmatch(tok): return 'TIME'
    if pat_fraction.fullmatch(tok): return 'FRACTION'
    if pat_decimal.fullmatch(tok) or tok.startswith('.') and re.fullmatch(r'\.\d+', tok): return 'DECIMAL'
    if pat_ordinal.fullmatch(tok): return 'ORDINAL'
    if pat_cardinal.fullmatch(tok): return 'CARDINAL'
    # DATE: unambiguous numeric forms or decades
    if pat_date_y_m_d_dash.fullmatch(tok) or pat_date_y_m_d_slash.fullmatch(tok) or pat_date_dot_y_m_d.fullmatch(tok):
        return 'DATE'
    m = pat_date_two_sep.fullmatch(tok)
    if m:
        a,b,y = m.groups()
        if _unambiguous_mdy(a,b,y):
            return 'DATE'
    m = pat_date_dot_two_sep.fullmatch(tok)
    if m:
        a,b,y = m.groups()
        if _unambiguous_mdy(a,b,y):
            return 'DATE'
    if pat_decade.fullmatch(tok):
        return 'DATE'
    # Month-name dates
    if pat_month_name_1.fullmatch(tok) or pat_month_name_2.fullmatch(tok):
        return 'DATE'
    # LETTERS strong only, with blacklist exceptions
    if tok in LETTERS_EXCEPT:
        pass
    else:
        if pat_letters_strong.fullmatch(tok):
            return 'LETTERS'
    # Telephone (require at least ~7 digits total)
    if pat_phone.fullmatch(tok):
        return 'TELEPHONE'
    # Electronic tightened
    t = tok.strip()
    if (t.count('@') == 1 and '.' in t[t.find('@')+1:]) or t.lower().startswith('www.') or re.search(r'/[A-Za-z0-9]', t):
        return 'ELECTRONIC'
    if pat_measure.match(tok.lower()): return 'MEASURE'
    return None

def apply_rules(tok, cls):
    if cls == 'DATE': return normalize_date(tok)
    if cls == 'LETTERS': return normalize_letters(tok)
    if cls == 'DECIMAL': return normalize_decimal(tok)
    if cls == 'CARDINAL': return normalize_cardinal(tok)
    if cls == 'ORDINAL': return normalize_ordinal(tok)
    if cls == 'DIGIT': return normalize_digit(tok)
    if cls == 'MEASURE': return normalize_measure(tok)
    if cls == 'TELEPHONE': return normalize_telephone(tok)
    if cls == 'ELECTRONIC': return normalize_electronic(tok)
    if cls == 'MONEY': return normalize_money(tok)
    if cls == 'TIME': return normalize_time(tok)
    if cls == 'FRACTION': return normalize_fraction(tok)
    return None

In [52]:
# T6: Targeted error analysis for DATE / LETTERS / CARDINAL on leak-free CV
import re, time
t0 = time.time()
assert 'val' in globals() and 'pred_val' in globals(), 'Run Cell 6 (CV) first to populate val/pred_val'

df_err = val.copy()
df_err = df_err.assign(pred=pred_val.values)
mask_err = df_err['pred'] != df_err['after']
df_err = df_err.loc[mask_err, ['sentence_id','token_id','class','before','after','pred']].copy()
df_err = df_err.sort_values(['sentence_id','token_id'])
df_err['prev_before'] = df_err.groupby('sentence_id')['before'].shift(1)
df_err['next_before'] = df_err.groupby('sentence_id')['before'].shift(-1)
df_err['prev2_before'] = df_err.groupby('sentence_id')['before'].shift(2)

print('Error counts (top 10):')
print(df_err['class'].value_counts().head(10))

def preview(cls_name, n=10):
    print(f"\n--- {cls_name} sample errors ---")
    display(df_err[df_err['class']==cls_name].head(n))

preview('DATE', 12)
preview('LETTERS', 12)
preview('CARDINAL', 12)

# DATE diagnostics
date_err = df_err[df_err['class']=='DATE'].copy()
date_err['before_s'] = date_err['before'].astype(str)
pat_slash = re.compile(r'^\d{1,4}/\d{1,2}/\d{1,4}$')
pat_dash = re.compile(r'^\d{1,4}-\d{1,2}-\d{1,4}$')
pat_dot  = re.compile(r'^\d{1,4}\.\d{1,2}\.\d{1,4}$')
pat_decade = re.compile(r'^\d{4}s$')
def bucket_date(s):
    if pat_decade.fullmatch(s): return 'decade'
    if pat_slash.fullmatch(s): return 'slash'
    if pat_dash.fullmatch(s): return 'dash'
    if pat_dot.fullmatch(s): return 'dot'
    return 'other'
date_err['bucket'] = date_err['before_s'].apply(bucket_date)
print('\nDATE error buckets:')
print(date_err['bucket'].value_counts())

# Ambiguous M/D/Y vs D/M/Y check (both parts <=12)
def ambiguous_mdy(s):
    m = re.fullmatch(r'(?:(\d{1,2})[/-](\d{1,2})[/-](\d{2,4}))', s or '')
    if not m: return False
    a,b,y = m.groups()
    try:
        a=int(a); b=int(b)
    except:
        return False
    return a<=12 and b<=12
date_err['ambig_mdy'] = date_err['before_s'].apply(ambiguous_mdy)
print('\nDATE ambiguous M/D/Y count:', int(date_err['ambig_mdy'].sum()))

# Month context around 4-digit years in errors
month_tokens = set(['january','february','march','april','may','june','july','august','september','october','november','december',
                    'jan','feb','mar','apr','jun','jul','aug','sep','sept','oct','nov','dec','jan.','feb.','mar.','apr.','jun.','jul.','aug.','sep.','sept.','oct.','nov.','dec.'])
def is_year4(s):
    return bool(re.fullmatch(r'\d{4}', str(s) or ''))
yr_err = date_err[date_err['before'].apply(is_year4)].copy()
yr_err['prev_is_month'] = yr_err['prev_before'].str.lower().isin(month_tokens)
yr_err['prev2_is_month'] = yr_err['prev2_before'].str.lower().isin(month_tokens)
print('\nDATE 4-digit year errors with month context:')
print('prev month:', int(yr_err['prev_is_month'].sum()), '| prev2 month:', int(yr_err['prev2_is_month'].sum()))

# LETTERS diagnostics
let_err = df_err[df_err['class']=='LETTERS'].copy()
def letters_bucket(s):
    s2 = str(s or '')
    if re.fullmatch(r'[A-Z]{3,}', s2): return 'ALLCAPS>=3'
    if re.fullmatch(r'([A-Z]\.){3,}[A-Z]?\.?', s2): return 'DOTTED>=3'
    if re.fullmatch(r'[A-Z]+-[A-Z]+', s2): return 'HYPHEN_CAPS'
    if re.fullmatch(r'[A-Za-z]+', s2) and any(c.islower() for c in s2): return 'wordlike_mixed'
    return 'other'
let_err['bucket'] = let_err['before'].apply(letters_bucket)
print('\nLETTERS error buckets:')
print(let_err['bucket'].value_counts())

# CARDINAL diagnostics
card_err = df_err[df_err['class']=='CARDINAL'].copy()
def is_four_digit_yearish(s):
    if not re.fullmatch(r'\d{4}', str(s) or ''): return False
    v = int(s)
    return 1000 <= v <= 2099
card_err['four_digit_yearish'] = card_err['before'].apply(is_four_digit_yearish)
print('\nCARDINAL 4-digit yearish errors:', int(card_err['four_digit_yearish'].sum()))
card_yr = card_err[card_err['four_digit_yearish']].copy()
card_yr['prev_is_month'] = card_yr['prev_before'].str.lower().isin(month_tokens)
card_yr['prev2_is_month'] = card_yr['prev2_before'].str.lower().isin(month_tokens)
print('...with month context -> prev:', int(card_yr['prev_is_month'].sum()), 'prev2:', int(card_yr['prev2_is_month'].sum()))

print('\nT6 elapsed:', f"{time.time()-t0:.2f}s")

Error counts (top 10):
class
PLAIN         6606
DATE          4921
LETTERS       3546
CARDINAL      1889
MEASURE       1127
TELEPHONE      682
ELECTRONIC     652
MONEY          621
DECIMAL        577
DIGIT          527
Name: count, dtype: int64

--- DATE sample errors ---


Unnamed: 0,sentence_id,token_id,class,before,after,pred,prev_before,next_before,prev2_before
454,34,1,DATE,7 December 2015,the seventh of december twenty fifteen,7 December 2015,,,
1342,106,0,DATE,"January 30, 1987",january thirtieth nineteen eighty seven,"January 30, 1987",,,
5289,427,5,DATE,"February 14, 1999",february fourteenth nineteen ninety nine,"February 14, 1999",,,
5996,485,1,DATE,"Feb. 6, 2008",february sixth two thousand eight,"Feb. 6, 2008",,,
10755,881,5,DATE,2006-05-23,the twenty third of may two thousand six,2006-05-23,,,
13272,1094,1,DATE,"Thursday, April 19",thursday april nineteenth,"Thursday, April 19",,,
16084,1310,5,DATE,"March 3, 1884",march third eighteen eighty four,"March 3, 1884",,,
19544,1568,6,DATE,"August 31, 1945",august thirty first nineteen forty five,"August 31, 1945",dr,,
21325,1713,1,DATE,"April 14, 1999",april fourteenth nineteen ninety nine,"April 14, 1999",,,
25300,2014,0,DATE,"December 26, 1993",december twenty sixth nineteen ninety three,"December 26, 1993",,,



--- LETTERS sample errors ---


Unnamed: 0,sentence_id,token_id,class,before,after,pred,prev_before,next_before,prev2_before
740,54,12,LETTERS,Ukh,u k h,Ukh,,,
2715,219,0,LETTERS,CSIPL,c s i p l,CSIPL,,MUV,
2725,219,10,LETTERS,MUV,m u v,MUV,CSIPL,,
7125,582,1,LETTERS,SSNP,s s n p,SSNP,,organised,
7817,635,0,LETTERS,FISD,f i s d,FISD,,,
11049,908,4,LETTERS,RCACC,r c a c c,RCACC,2968,,
14108,1147,13,LETTERS,Tarw,t a r w,Tarw,,,
16185,1317,6,LETTERS,NYCLU's,n y c l u's,NYCLU's,,,
16271,1321,21,LETTERS,WDTA,w d t a,WDTA,,,
16392,1332,3,LETTERS,HURINET,h u r i n e t,HURINET,,,



--- CARDINAL sample errors ---


Unnamed: 0,sentence_id,token_id,class,before,after,pred,prev_before,next_before,prev2_before
460,35,4,CARDINAL,1065520,one million sixty five thousand five hundred twenty,1065520,,,
3356,275,2,CARDINAL,20300,twenty thousand three hundred,20300,,,
7563,617,17,CARDINAL,235758,two hundred thirty five thousand seven hundred fifty eight,235758,,24606,
7569,617,23,CARDINAL,24606,twenty four thousand six hundred six,24606,235758.0,,
8917,731,2,CARDINAL,11339,eleven thousand three hundred thirty nine,11339,,550.8,
11048,908,3,CARDINAL,2968,two thousand nine hundred sixty eight,2968,,RCACC,
11681,964,12,CARDINAL,10044,ten thousand forty four,10044,,,
12262,1014,4,CARDINAL,00,zero,o o,,,
12510,1037,17,CARDINAL,19506,nineteen thousand five hundred six,19506,,,
16477,1339,7,CARDINAL,78,seventy eight,78,,,



DATE error buckets:
bucket
other     4287
dash       525
slash       67
dot         31
decade      11
Name: count, dtype: int64

DATE ambiguous M/D/Y count: 37

DATE 4-digit year errors with month context:
prev month: 0 | prev2 month: 0

LETTERS error buckets:
bucket
ALLCAPS>=3        2036
other              752
wordlike_mixed     642
DOTTED>=3          116
Name: count, dtype: int64

CARDINAL 4-digit yearish errors: 405
...with month context -> prev: 0 prev2: 0

T6 elapsed: 0.77s
