# Production Final Plan

Goal: Achieve a medal via a clean, alignment-safe ensemble using only the strongest, well-aligned streams and robust consensus rules.

Key assets (alignment-safe only):
- submission_charfusion_512_single_seed_alignsafe.csv
- submission_tokendp_512.csv  (improved Token-DP from Cell 24)
- submission_tokenselect_512single_or_384_lambda012.csv
- Optional reference: submission_zeroshot_xlmr_squad2.csv for ultra-safe micro-merge (low weight, override only when high-agreement and short).

Do NOT use misaligned artifacts (e.g., 3-seed 512 averaged logits). Exclude 384-only streams from the Hindi (HI) pool unless experts insist.

Ensemble variants to implement (fast, deterministic, alignment-safe):
1) Per-language consensus (base):
   - Candidates: {charfusion_512_single_seed_alignsafe, tokendp_512, tokenselect_512-or-384}.
   - Voting: mean Jaccard (punct-insensitive), with per-language length regularization.
   - Length regularization (initial): lambda_hi=0.02, lambda_ta=0.045 (tunable).
   - Overrides:
     * Numeric/date micro-override if all candidates numeric/date-like and len <= cap (HI<=18, TA<=16).
     * Tie rules: if A != B but norm_text(A) == norm_text(B), pick the raw string from the higher-priority stream; if TS == TD, apply the tie-break preference order (tokenselect > tokendp > charfusion).
     * Per-language confidence gate: if consensus score < thresh_lang, fall back to Token-DP (thresh_hi ~= 0.35, thresh_ta ~= 0.38; tunable).

2) Per-id hedge:
   - Compute consensus as in (1) and keep Token-DP as backup.
   - If per-id confidence < CONF_THRESH_LANG, use Token-DP; else use consensus.
   - Apply the same tie-break and numeric/date overrides.

3) Meta-hedge:
   - Combine outputs from (1), (2), and pure Token-DP.
   - If answers are text-normalization-equal, snap to the more confident candidate; otherwise pick majority; break ties with per-language length reg.
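
The meta-hedge in (3) is not implemented in the cells below, so here is a minimal sketch of how it could look. The names `norm_text` and `meta_hedge`, the use of input order as a stand-in for confidence, and the default `lam` are all assumptions for illustration, not the final rule set.

```python
import re
import unicodedata as ud
from collections import defaultdict

def norm_text(s):
    """NFC-normalize, then collapse punctuation and whitespace runs to one space."""
    s = ud.normalize('NFC', s or '')
    s = ''.join(' ' if (c.isspace() or ud.category(c).startswith('P')) else c for c in s)
    return re.sub(r'\s+', ' ', s).strip()

def meta_hedge(answers, lam=0.02):
    """answers: list of (name, text), ordered from most to least trusted (assumed).
    Returns the (name, text) pair that is selected."""
    groups = defaultdict(list)  # normalized form -> members, in trust order
    for name, text in answers:
        groups[norm_text(text)].append((name, text))
    groups.pop('', None)  # never vote for an empty answer
    if not groups:
        return answers[0]

    def rank(item):
        _, members = item
        # More votes first; among equal vote counts, the shorter raw string wins
        # via the length penalty. Remaining ties fall to the group seen first.
        return (-len(members), lam * len(members[0][1]))

    _, members = min(groups.items(), key=rank)
    return members[0]  # snap to the most trusted raw string in the winning group

picked = meta_hedge([
    ('consensus', 'मुंबई, महाराष्ट्र'),
    ('per_id', 'मुंबई महाराष्ट्र'),
    ('token_dp', 'मुंबई'),
])
print(picked[0])  # the first two differ only in punctuation, so they form the majority
```

Grouping by normalized form handles the "text-normalization-equal" case and the majority case in one pass; a real version would likely replace the trust ordering with an actual per-id confidence.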

Diagnostics for each run:
- mean_len overall and by language, histogram tail checks.
- Route percentages (% handled by consensus vs token-DP, % numeric overrides).
- Sanity: no empty strings unless CLS intended; print 10 random samples per language.
- Performance hygiene: no long loops without logging; always print elapsed time.
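
The cells below print only mean lengths, so the tail and empty-string checks listed above could be covered by a small helper along these lines. This is a sketch: the function name `length_diagnostics`, the chosen quantiles, and the `'unknown'` language label are assumptions.

```python
import pandas as pd

def length_diagnostics(sub, lang_by_id, tail_q=(0.90, 0.99)):
    """sub: DataFrame with 'id' and 'PredictionString'.
    lang_by_id: dict mapping id -> 'hi' / 'ta'.
    Returns one summary row per language."""
    df = sub.copy()
    df['PredictionString'] = df['PredictionString'].fillna('').astype(str)
    df['len'] = df['PredictionString'].str.len()
    df['language'] = df['id'].map(lang_by_id).fillna('unknown')
    rows = []
    for lang, g in df.groupby('language'):
        row = {
            'language': lang,
            'n': len(g),
            'mean_len': g['len'].mean(),
            'n_empty': int((g['len'] == 0).sum()),  # should be 0 unless CLS is intended
            'max_len': int(g['len'].max()),
        }
        for q in tail_q:
            row[f'q{round(q * 100)}'] = g['len'].quantile(q)  # tail check
        rows.append(row)
    return pd.DataFrame(rows).set_index('language')

demo = pd.DataFrame({'id': ['a', 'b', 'c'],
                     'PredictionString': ['६४', None, '1900']})
print(length_diagnostics(demo, {'a': 'hi', 'b': 'hi', 'c': 'ta'}))
```

Comparing `q99` and `max_len` between variants is one cheap way to spot length drift that the mean alone hides.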

Submission flow:
- Generate submission files for (1), (2), and Meta-hedge (3).
- Point submission.csv to best candidate and submit.

Open questions for experts:
- Confirm final candidate set (drop 384-only in HI?).
- Recommended per-language lambdas and confidence thresholds to stabilize mean_len ~10.6–10.9.
- Any extra micro-rules (e.g., safe containment, digit normalization on TA only, punctuation snapping) that improved late LB.
- Known problematic patterns to avoid (e.g., length drift triggers) and exact preferred tiebreak order.
- Whether to include ultra-safe zero-shot stream as override on high-agreement, short answers.

Next steps:
- Implement Variant (1) and (2) as small, clean functions; log diagnostics.
- Build Meta-hedge (3); evaluate mean_len and routes; prepare for quick submissions.

In [63]:
import pandas as pd, numpy as np, unicodedata as ud, re, time, random, os, sys

t0 = time.time()

# Locked params from expert advice (kept for parity with the consensus cell; this majority-fallback cell does not read them)
LAMBDA_HI = 0.020
LAMBDA_TA = 0.042
GATE_HI = 0.33
GATE_TA = 0.37

# Filepaths
fp_ts = 'submission_tokenselect_512single_or_384_lambda012.csv'
fp_td = 'submission_tokendp_512.csv'
fp_cf = 'submission_charfusion_512_single_seed_alignsafe.csv'

sub_ts = pd.read_csv(fp_ts)
sub_td = pd.read_csv(fp_td)
sub_cf = pd.read_csv(fp_cf)
test_df = pd.read_csv('test.csv')

# Heuristic language detection (script-based) — test may lack language column
def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'

col = next((c for c in ('question', 'question_text', 'context') if c in test_df.columns), None)
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'

# Align and sanity
for df in (sub_ts, sub_td, sub_cf):
    assert 'id' in df.columns and 'PredictionString' in df.columns
    df['id'] = df['id'].astype(str)
test_df['id'] = test_df['id'].astype(str)

# Merge
base = test_df[['id']].merge(sub_ts.rename(columns={'PredictionString':'TS'}), on='id', how='left')\
                   .merge(sub_td.rename(columns={'PredictionString':'TD'}), on='id', how='left')\
                   .merge(sub_cf.rename(columns={'PredictionString':'CF'}), on='id', how='left')

# Context map for expansion
ctx_col = 'context' if 'context' in test_df.columns else None
id2ctx = dict(zip(test_df['id'], test_df[ctx_col])) if ctx_col else {}

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

def expand_in_context(ans, ctx, cap, majority_key):
    if not ans or not isinstance(ctx, str) or not ctx:
        return ans
    i = ctx.find(ans)
    if i == -1 or ctx.rfind(ans) != i:
        return ans  # require unique occurrence
    L0, R0 = i, i + len(ans)
    left_adj = (L0 > 0 and not is_space_or_punct(ctx[L0-1]))
    right_adj = (R0 < len(ctx) and not is_space_or_punct(ctx[R0]))
    if not (left_adj or right_adj):
        return ans
    L, R = L0, R0
    while L > 0 and not is_space_or_punct(ctx[L-1]):
        L -= 1
    while R < len(ctx) and not is_space_or_punct(ctx[R]):
        R += 1
    expanded = ctx[L:R]
    if ctx.count(expanded) > 1: return ans
    if len(expanded) > cap:
        return ans
    if majority_key and majority_key not in norm_basic(expanded):
        return ans
    return expanded

# Utilities
PUNCT_CLASS = ''.join(chr(i) for i in range(sys.maxunicode) if ud.category(chr(i)).startswith('P'))
PUNCT_RE = re.compile(rf"[\s{re.escape(PUNCT_CLASS)}]+")  # raw f-string: '\s' in a plain f-string is an invalid escape
DANDA = '\u0964'  # Hindi danda
TA_PULLI = '\u0bcd'  # Tamil pulli/virama
DEV_VIRAMA = '\u094d'

def nfc(s):
    try:
        return ud.normalize('NFC', s)
    except Exception:
        return s

def norm_basic(s):
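    if not isinstance(s, str):
        # NaN read by pandas is not guaranteed to be the np.nan singleton, so test by value
        s = '' if (s is None or (isinstance(s, float) and np.isnan(s))) else str(s)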
    s = nfc(s.strip())
    # punctuation-insensitive: collapse punctuation+spaces to single space
    s = PUNCT_RE.sub(' ', s)
    s = re.sub(r'\s+', ' ', s).strip()
    return s

# Digit maps for HI/TA to ASCII
HI_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0966\u0967\u0968\u0969\u096a\u096b\u096c\u096d\u096e\u096f')}
TA_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0be6\u0be7\u0be8\u0be9\u0bea\u0beb\u0bec\u0bed\u0bee\u0bef')}

def ascii_digits(s):
    s = s.translate(HI_DIGITS).translate(TA_DIGITS)
    return s

SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')

def is_numeric_or_date_like(s, lang, cap):
    if not isinstance(s, str) or not s:
        return False
    s2 = nfc(s)
    s2 = ascii_digits(s2)
    # keep only digits and common separators, normalize runs of seps to '-'
    s2 = SEP_RE.sub('-', s2).strip('-')
    if not s2:
        return False
    if len(s2) > cap:
        return False
    return bool(NUM_CHARS_RE.match(s2))

def maybe_trim_trailing_mark(ans, candidates_norm):
    # minimal snap: trim single trailing danda/virama/pulli only if trimmed normalized form matches another candidate
    if not ans:
        return ans
    trimmed = None
    if ans.endswith(DANDA) or ans.endswith(TA_PULLI) or ans.endswith(DEV_VIRAMA):
        trimmed = ans[:-1]
    if trimmed is not None:
        if norm_basic(trimmed) in candidates_norm:
            return trimmed
    return ans

# Priority order
PRIORITY = {'TS': 3, 'TD': 2, 'CF': 1}

id2lang = dict(zip(test_df['id'], test_df['language']))

out = []
routes = {'majority':0, 'fallback_td':0, 'numeric_override':0, 'norm_tie':0, 'trim_snap':0, 'expansion':0, 'substr_snap':0}

for i, row in base.iterrows():
    if (i % 500 == 0) and i:
        print(f'Processed {i}/{len(base)} in {time.time()-t0:.1f}s', flush=True)
    _id = row['id']
    lang = id2lang.get(_id, 'hi')
    ts = row['TS'] if isinstance(row['TS'], str) else ''
    td = row['TD'] if isinstance(row['TD'], str) else ''
    cf = row['CF'] if isinstance(row['CF'], str) else ''
    n_ts, n_td, n_cf = norm_basic(ts), norm_basic(td), norm_basic(cf)

    # majority by normalized form
    counts = {}
    for k, n in (('TS', n_ts), ('TD', n_td), ('CF', n_cf)):
        counts.setdefault(n, []).append(k)

    chosen = None
    chosen_src = None
    # any normalized string supported by >=2 sources
    majority_key = None
    for n, srcs in counts.items():
        if len(srcs) >= 2 and n != '':
            # pick highest-priority source string among srcs
            srcs_sorted = sorted(srcs, key=lambda x: -PRIORITY[x])
            top = srcs_sorted[0]
            if top == 'TS':
                chosen, chosen_src = ts, 'TS'    
            elif top == 'TD':
                chosen, chosen_src = td, 'TD'
            else:
                chosen, chosen_src = cf, 'CF'
            majority_key = n
            routes['majority'] += 1
            # Cap-aware tie within majority (expert one-liner)
            cap = 18 if lang == 'hi' else 16
            if chosen_src == 'TS' and majority_key is not None and len(ts) > cap and len(td) <= cap and n_td == majority_key:
                chosen, chosen_src = td, 'TD'
            # Extended: TS too long; TD is strict substring under norm
            if chosen_src == 'TS' and len(ts) > cap and len(td) <= cap and (n_td and n_td in n_ts and n_td != n_ts): chosen, chosen_src = td, 'TD'
            # Tamil snap: trim trailing pulli when it matches majority norm
            if lang=='ta' and chosen and chosen.endswith(TA_PULLI) and majority_key and norm_basic(chosen[:-1])==majority_key: chosen=chosen[:-1]
            break

    if chosen is None:
        # no majority: fallback to TD
        chosen, chosen_src = td, 'TD'
        routes['fallback_td'] += 1

    # normalization tie rule: if chosen != some candidate but normalized equals, snap to higher-priority stream's raw string
    # compare against TS then CF with priority
    if chosen_src != 'TS' and norm_basic(chosen) == n_ts and ts:
        chosen, chosen_src = ts, 'TS'
        routes['norm_tie'] += 1
    elif chosen_src not in ('TS','TD') and norm_basic(chosen) == n_td and td:
        # TD outranks CF
        chosen, chosen_src = td, 'TD'
        routes['norm_tie'] += 1

    # minimal punctuation/virama snap if it aligns to another candidate
    cand_norm_set = {n_ts, n_td, n_cf}
    snapped = maybe_trim_trailing_mark(chosen, cand_norm_set)
    if snapped != chosen:
        chosen = snapped
        routes['trim_snap'] += 1

    # numeric/date micro-override last
    cap = 18 if lang == 'hi' else 16
    if is_numeric_or_date_like(ts, lang, cap) and is_numeric_or_date_like(td, lang, cap) and is_numeric_or_date_like(cf, lang, cap):
        chosen, chosen_src = td, 'TD'
        routes['numeric_override'] += 1

    # Substring safety snaps before expansion
    if ctx_col:
        ctx = id2ctx.get(_id, '')
        if isinstance(ctx, str) and ctx and chosen:
            # If chosen not in context but TD is, snap to TD
            if ctx.find(chosen) == -1 and td and ctx.find(td) != -1:
                chosen, chosen_src = td, 'TD'
                routes['substr_snap'] += 1
            # Generalized: if chosen not in context but any of TS/TD/CF is present, pick highest-priority present
            elif ctx.find(chosen) == -1:
                present = []
                if ts and ctx.find(ts) != -1: present.append('TS')
                if td and ctx.find(td) != -1: present.append('TD')
                if cf and ctx.find(cf) != -1: present.append('CF')
                if present:
                    best = sorted(present, key=lambda x: -PRIORITY[x])[0]
                    if best == 'TS':
                        chosen, chosen_src = ts, 'TS'
                    elif best == 'TD':
                        chosen, chosen_src = td, 'TD'
                    else:
                        chosen, chosen_src = cf, 'CF'
                    routes['substr_snap'] += 1

    # safe word-boundary expansion on chosen (post-numeric, post-pulli)
    if ctx_col:
        ctx = id2ctx.get(_id, None)
        if chosen and not is_numeric_or_date_like(chosen, lang, cap) and isinstance(ctx, str) and ctx:
            expanded = expand_in_context(chosen, ctx, cap, majority_key)
            if expanded != chosen:
                chosen = expanded
                routes['expansion'] += 1

    out.append((_id, chosen))

sub = pd.DataFrame(out, columns=['id','PredictionString'])

# Diagnostics
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
mean_len = sub['len'].mean()
mean_len_hi = df_lang.loc[df_lang['language']=='hi','len'].mean()
mean_len_ta = df_lang.loc[df_lang['language']=='ta','len'].mean()
print(f'mean_len overall: {mean_len:.2f} | HI: {mean_len_hi:.2f} | TA: {mean_len_ta:.2f}', flush=True)
total = len(sub)
print('Routes %: ' + ', '.join(f"{k}={v/total*100:.1f}%" for k,v in routes.items()), flush=True)

# Random samples
for lg in ('hi','ta'):
    samp = df_lang[df_lang['language']==lg].sample(min(5, (df_lang['language']==lg).sum()), random_state=42)
    print(f'\nSamples {lg}:', flush=True)
    for _, r in samp.iterrows():
        print(r['id'], '->', r['PredictionString'][:120])

out_fp = 'submission_primary_majority_fallback.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
print('Wrote', out_fp, 'in', f'{time.time()-t0:.1f}s')

# Point submission.csv to this output
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('submission.csv updated')

mean_len overall: 10.79 | HI: 11.01 | TA: 10.14
Routes %: majority=86.6%, fallback_td=13.4%, numeric_override=8.9%, norm_tie=0.0%, trim_snap=0.0%, expansion=1.8%, substr_snap=1.8%

Samples hi:
9b04631cf -> ३० जनवरी २०१५
be799d365 -> मुंबई, महाराष्ट्र
33d679522 -> ६४
8e10fecdf -> मार्क्स-एंगेल्स
0c35b67ae -> २८ सितम्बर १९०७

Samples ta:
b151705b8 -> 1900
5e1f9bca8 -> லிஸ்பன்
921a348f2 -> 5488
1df390d9a -> சார்லஸ் ராபர்ட் டார்வின
57a56c43f -> புறணி
Wrote submission_primary_majority_fallback.csv in 0.3s
submission.csv updated
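
To see which answers the numeric/date micro-override in the cell above can fire on, the following self-contained sketch re-declares the same digit maps and regexes and runs the check on a few of the sampled outputs. The string `'12/05/1998'` is a hypothetical input, and the `lang` argument is dropped here because the original check never reads it.

```python
import re
import unicodedata as ud

# Devanagari and Tamil digits mapped to ASCII, as in the cell above
HI_DIGITS = {ord(c): ord('0') + i for i, c in enumerate(
    '\u0966\u0967\u0968\u0969\u096a\u096b\u096c\u096d\u096e\u096f')}
TA_DIGITS = {ord(c): ord('0') + i for i, c in enumerate(
    '\u0be6\u0be7\u0be8\u0be9\u0bea\u0beb\u0bec\u0bed\u0bee\u0bef')}
SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')

def is_numeric_or_date_like(s, cap):
    if not isinstance(s, str) or not s:
        return False
    s2 = ud.normalize('NFC', s).translate(HI_DIGITS).translate(TA_DIGITS)
    s2 = SEP_RE.sub('-', s2).strip('-')  # runs of separators become a single '-'
    return bool(s2) and len(s2) <= cap and bool(NUM_CHARS_RE.match(s2))

for text in ['६४', '1900', '12/05/1998', '३० जनवरी २०१५', 'லிஸ்பன்']:
    print(repr(text), is_numeric_or_date_like(text, cap=18))
```

Pure digit strings and separator-delimited dates pass; a date that spells out the month, such as `'३० जनवरी २०१५'`, does not, so the override leaves those answers to the majority/fallback logic.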


In [89]:
import pandas as pd, numpy as np, unicodedata as ud, re, time, random, sys

t1 = time.time()

# Locked params
LAMBDA_HI = 0.020
LAMBDA_TA = 0.032
GATE_HI = 0.33
GATE_TA = 0.37

fp_ts = 'submission_tokenselect_512single_or_384_lambda012.csv'
fp_td = 'submission_tokendp_512.csv'
fp_cf = 'submission_charfusion_512_single_seed_alignsafe.csv'

sub_ts = pd.read_csv(fp_ts)
sub_td = pd.read_csv(fp_td)
sub_cf = pd.read_csv(fp_cf)
test_df = pd.read_csv('test.csv')

# Heuristic language detection (script-based) — test may lack language column
def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'

col = next((c for c in ('question', 'question_text', 'context') if c in test_df.columns), None)
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'

for df in (sub_ts, sub_td, sub_cf):
    df['id'] = df['id'].astype(str)
test_df['id'] = test_df['id'].astype(str)

base = test_df[['id']].merge(sub_ts.rename(columns={'PredictionString':'TS'}), on='id', how='left')\
                   .merge(sub_td.rename(columns={'PredictionString':'TD'}), on='id', how='left')\
                   .merge(sub_cf.rename(columns={'PredictionString':'CF'}), on='id', how='left')

# Context map (used by the numeric override and Tamil boundary checks; expansion itself is disabled in this no-exp variant)
ctx_col = 'context' if 'context' in test_df.columns else None
id2ctx = dict(zip(test_df['id'], test_df[ctx_col])) if ctx_col else {}

# Utilities (lighter copies of the helpers from the previous cell)
PUNCT_CLASS = ''.join(chr(i) for i in range(sys.maxunicode) if ud.category(chr(i)).startswith('P'))
PUNCT_RE = re.compile(rf"[\s{re.escape(PUNCT_CLASS)}]+")  # raw f-string: '\s' in a plain f-string is an invalid escape
DANDA = '\u0964'
TA_PULLI = '\u0bcd'
DEV_VIRAMA = '\u094d'

def nfc(s):
    try:
        return ud.normalize('NFC', s)
    except Exception:
        return s

def norm_basic(s):
    if not isinstance(s, str):
        s = ''
    s = nfc(s.strip())
    s = PUNCT_RE.sub(' ', s)
    s = re.sub(r'\s+', ' ', s).strip()
    return s

def jaccard_words(a, b):
    a = norm_basic(a)
    b = norm_basic(b)
    if not a and not b:
        return 1.0
    A = set(a.split())
    B = set(b.split())
    if not A and not B:
        return 1.0
    if not A or not B:
        return 0.0
    inter = len(A & B)
    union = len(A | B)
    return inter / union if union else 0.0

# Numeric/date helpers
HI_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0966\u0967\u0968\u0969\u096a\u096b\u096c\u096d\u096e\u096f')}
TA_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0be6\u0be7\u0be8\u0be9\u0bea\u0beb\u0bec\u0bed\u0bee\u0bef')}
def ascii_digits(s):
    return s.translate(HI_DIGITS).translate(TA_DIGITS)
SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')
def is_numeric_or_date_like(s, lang, cap):
    if not isinstance(s, str) or not s:
        return False
    s2 = ascii_digits(nfc(s))
    s2 = SEP_RE.sub('-', s2).strip('-')
    if not s2 or len(s2) > cap:
        return False
    return bool(NUM_CHARS_RE.match(s2))
def maybe_trim_trailing_mark(ans, candidates_norm):
    if not ans:
        return ans
    trimmed = None
    if ans.endswith(DANDA) or ans.endswith(TA_PULLI) or ans.endswith(DEV_VIRAMA):
        trimmed = ans[:-1]
    if trimmed is not None and norm_basic(trimmed) in candidates_norm:
        return trimmed
    return ans

# Expansion helpers
def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

# No expansion variant: keep function but do not call it
def expand_in_context(ans, ctx, cap):
    return ans

PRIORITY = {'TS': 3, 'TD': 2, 'CF': 1}
id2lang = dict(zip(test_df['id'], test_df['language']))

out = []
routes = {'consensus_used':0, 'fallback_td':0, 'trim_snap':0, 'numeric_override':0, 'expansion':0, 'ta_boundary_extend':0}

for i, row in base.iterrows():
    if (i % 500 == 0) and i:
        print(f'Processed {i}/{len(base)} in {time.time()-t1:.1f}s', flush=True)
    _id = row['id']
    lang = id2lang.get(_id, 'hi')
    lam = LAMBDA_HI if lang == 'hi' else LAMBDA_TA
    gate = GATE_HI if lang == 'hi' else GATE_TA
    cap = 18 if lang == 'hi' else 16

    cand = {'TS': row['TS'] if isinstance(row['TS'], str) else '',
            'TD': row['TD'] if isinstance(row['TD'], str) else '',
            'CF': row['CF'] if isinstance(row['CF'], str) else ''}
    norm = {k: norm_basic(v) for k, v in cand.items()}

    # Majority by normalized text
    votes = {}
    for k, n in norm.items():
        votes.setdefault(n, []).append(k)
    majority_norm = None
    maj_srcs = []
    for n, srcs in votes.items():
        if n and len(srcs) >= 2:
            majority_norm = n
            maj_srcs = sorted(srcs, key=lambda x: -PRIORITY[x])
            break

    # Score each candidate by mean jaccard vs others minus length penalty
    scores = {}
    keys = ['TS','TD','CF']
    for k in keys:
        others = [cand[o] for o in keys if o != k]
        jac = 0.0
        for o in others:
            jac += jaccard_words(cand[k], o)
        jac /= 2.0
        score = jac - lam * len(cand[k])
        scores[k] = score

    # Pick consensus candidate
    if majority_norm is not None:
        best_src = maj_srcs[0]
        consensus_ans = cand[best_src]
        consensus_score = scores[best_src]
    else:
        max_score = max(scores.values())
        best_srcs = [k for k, v in scores.items() if abs(v - max_score) < 1e-9]
        best_src = sorted(best_srcs, key=lambda x: -PRIORITY[x])[0]
        consensus_ans = cand[best_src]
        consensus_score = scores[best_src]

    # Micro post-processing: minimal trailing mark snap and numeric/date override (tight, safe)
    snapped = maybe_trim_trailing_mark(consensus_ans, set(norm.values()))
    if snapped != consensus_ans:
        consensus_ans = snapped
        routes['trim_snap'] += 1

    # Tight numeric override: require all three numeric-like AND TD ascii form present uniquely with clean boundaries
    if is_numeric_or_date_like(cand['TS'], lang, cap) and is_numeric_or_date_like(cand['TD'], lang, cap) and is_numeric_or_date_like(cand['CF'], lang, cap):
        if ctx_col:
            ctx = id2ctx.get(_id, '')
        else:
            ctx = ''
        if isinstance(ctx, str) and ctx:
            td_ascii = ascii_digits(cand['TD'])
            st = ctx.find(td_ascii)
            if st != -1:
                ed = st + len(td_ascii)
                left = (st == 0) or is_space_or_punct(ctx[st-1])
                right = (ed == len(ctx)) or is_space_or_punct(ctx[ed])
                if left and right and len(cand['TD']) <= cap and ctx.count(td_ascii) == 1:
                    consensus_ans = cand['TD']
                    routes['numeric_override'] += 1

    # Tamil boundary-extend (very safe): TA only, short, non-numeric, unique, clean boundaries, cap<=16
    if lang == 'ta' and isinstance(consensus_ans, str) and consensus_ans and len(consensus_ans) < 16:
        if not is_numeric_or_date_like(consensus_ans, lang, cap):
            ctx = id2ctx.get(_id, '') if ctx_col else ''
            if isinstance(ctx, str) and ctx:
                for k in ['TS', 'TD']:
                    v = cand.get(k, '')
                    if (v and len(v) > len(consensus_ans) and len(v) <= cap and (consensus_ans in v)):
                        st = ctx.find(v)
                        if st != -1 and ctx.count(v) == 1:
                            ed = st + len(v)
                            left_ok = (st == 0) or is_space_or_punct(ctx[st-1])
                            right_ok = (ed == len(ctx)) or is_space_or_punct(ctx[ed])
                            if left_ok and right_ok:
                                consensus_ans = v
                                routes['ta_boundary_extend'] += 1
                                break

    # No expansion in this variant (safer late-LB)
    # if ctx_col:
    #     ctx = id2ctx.get(_id, None)
    #     if consensus_ans and not is_numeric_or_date_like(consensus_ans, lang, cap) and isinstance(ctx, str) and ctx:
    #         expanded = expand_in_context(consensus_ans, ctx, cap)
    #         if expanded != consensus_ans:
    #             consensus_ans = expanded
    #             routes['expansion'] += 1

    if consensus_score >= gate:
        out.append((_id, consensus_ans))
        routes['consensus_used'] += 1
    else:
        out.append((_id, cand['TD']))
        routes['fallback_td'] += 1

sub = pd.DataFrame(out, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
total = len(sub)
print('Routes %: ' + ', '.join(f"{k}={v/total*100:.1f}%" for k,v in routes.items()), flush=True)

out_fp = 'submission_consensus_with_gates_plusmicros_noexp.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
print('Wrote', out_fp, 'in', f'{time.time()-t1:.1f}s')
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('submission.csv updated → consensus_with_gates_plusmicros_noexp')

mean_len overall: 10.71 | HI: 11.01 | TA: 9.82
Routes %: consensus_used=59.8%, fallback_td=40.2%, trim_snap=0.0%, numeric_override=1.8%, expansion=0.0%, ta_boundary_extend=0.0%


Wrote submission_consensus_with_gates_plusmicros_noexp.csv in 0.3s
submission.csv updated → consensus_with_gates_plusmicros_noexp
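
The gate in the cell above compares `mean Jaccard - lambda * len` against a fixed threshold, which is easier to reason about with the arithmetic worked through. The sketch below assumes whitespace-token Jaccard on already-normalized strings (a simplification of `jaccard_words`) and uses hypothetical candidate strings; the lambda and gate are the HI values from the cell.

```python
def jaccard(a, b):
    A, B = set(a.split()), set(b.split())
    if not A and not B:
        return 1.0
    if not A or not B:
        return 0.0
    return len(A & B) / len(A | B)

def gated_score(cands, key, lam):
    """Mean Jaccard of cands[key] against the other candidates, minus a length penalty."""
    others = [v for k, v in cands.items() if k != key]
    mean_jac = sum(jaccard(cands[key], o) for o in others) / len(others)
    return mean_jac - lam * len(cands[key])

LAMBDA_HI, GATE_HI = 0.020, 0.33

# Hypothetical candidates: full agreement vs. two-of-three with a disjoint third
unanimous = {'TS': 'महात्मा गांधी', 'TD': 'महात्मा गांधी', 'CF': 'महात्मा गांधी'}
two_of_three = {'TS': 'महात्मा गांधी', 'TD': 'महात्मा गांधी', 'CF': 'नेहरू'}

for name, cands in (('unanimous', unanimous), ('two_of_three', two_of_three)):
    score = gated_score(cands, 'TS', LAMBDA_HI)
    print(name, 'passes gate' if score >= GATE_HI else 'falls back to TD')
```

With a disjoint third candidate the mean Jaccard of a two-of-three answer is 0.5, so under the HI settings it clears the gate only when `0.5 - 0.02 * len >= 0.33`, that is, for answers of at most 8 characters. Longer majority answers route to Token-DP, which returns the same string whenever TD is part of the majority.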


In [22]:
# Variant: Majority-fallback, NO expansion, strict caps, substring safety
import pandas as pd, numpy as np, unicodedata as ud, re, time, random, sys

t2 = time.time()

fp_ts = 'submission_tokenselect_512single_or_384_lambda012.csv'
fp_td = 'submission_tokendp_512.csv'
fp_cf = 'submission_charfusion_512_single_seed_alignsafe.csv'

sub_ts = pd.read_csv(fp_ts)
sub_td = pd.read_csv(fp_td)
sub_cf = pd.read_csv(fp_cf)
test_df = pd.read_csv('test.csv')

for df in (sub_ts, sub_td, sub_cf):
    df['id'] = df['id'].astype(str)
test_df['id'] = test_df['id'].astype(str)

def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'

col = next((c for c in ('question', 'question_text', 'context') if c in test_df.columns), None)
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'

base = test_df[['id']].merge(sub_ts.rename(columns={'PredictionString':'TS'}), on='id', how='left')\
                   .merge(sub_td.rename(columns={'PredictionString':'TD'}), on='id', how='left')\
                   .merge(sub_cf.rename(columns={'PredictionString':'CF'}), on='id', how='left')

ctx_col = 'context' if 'context' in test_df.columns else None
id2ctx = dict(zip(test_df['id'], test_df[ctx_col])) if ctx_col else {}

DANDA='\u0964'; TA_PULLI='\u0bcd'; DEV_VIRAMA='\u094d'
PUNCT_CLASS = ''.join(chr(i) for i in range(sys.maxunicode) if ud.category(chr(i)).startswith('P'))
PUNCT_RE = re.compile(rf"[\s{re.escape(PUNCT_CLASS)}]+")  # raw f-string: '\s' in a plain f-string is an invalid escape

def nfc(s):
    try: return ud.normalize('NFC', s)
    except Exception: return s

def norm_basic(s):
    if not isinstance(s, str): s = ''
    s = nfc(s.strip())
    s = PUNCT_RE.sub(' ', s)
    s = re.sub(r'\s+', ' ', s).strip()
    return s

HI_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0966\u0967\u0968\u0969\u096a\u096b\u096c\u096d\u096e\u096f')}
TA_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0be6\u0be7\u0be8\u0be9\u0bea\u0beb\u0bec\u0bed\u0bee\u0bef')}
def ascii_digits(s):
    return s.translate(HI_DIGITS).translate(TA_DIGITS)
SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')
def is_numeric_or_date_like(s, lang, cap):
    if not isinstance(s, str) or not s: return False
    s2 = ascii_digits(nfc(s)); s2 = SEP_RE.sub('-', s2).strip('-')
    if not s2 or len(s2)>cap: return False
    return bool(NUM_CHARS_RE.match(s2))

def maybe_trim_trailing_mark(ans, candidates_norm):
    if not ans: return ans
    trimmed = None
    if ans.endswith(DANDA) or ans.endswith(TA_PULLI) or ans.endswith(DEV_VIRAMA):
        trimmed = ans[:-1]
    if trimmed is not None and norm_basic(trimmed) in candidates_norm:
        return trimmed
    return ans

PRIORITY = {'TS':3,'TD':2,'CF':1}
id2lang = dict(zip(test_df['id'], test_df['language']))

out=[]
routes={'majority':0,'fallback_td':0,'numeric_override':0,'norm_tie':0,'trim_snap':0,'cap_swap':0,'substr_snap':0}

for i, row in base.iterrows():
    if (i % 500 == 0) and i:
        print(f'Processed {i}/{len(base)} in {time.time()-t2:.1f}s', flush=True)
    _id = row['id']
    lang = id2lang.get(_id, 'hi')
    cap = 18 if lang=='hi' else 16
    ts = row['TS'] if isinstance(row['TS'], str) else ''
    td = row['TD'] if isinstance(row['TD'], str) else ''
    cf = row['CF'] if isinstance(row['CF'], str) else ''
    n_ts, n_td, n_cf = norm_basic(ts), norm_basic(td), norm_basic(cf)

    counts = {}
    for k, n in (('TS', n_ts), ('TD', n_td), ('CF', n_cf)):
        counts.setdefault(n, []).append(k)

    chosen=None; chosen_src=None; majority_key=None
    for n, srcs in counts.items():
        if n and len(srcs)>=2:
            top = sorted(srcs, key=lambda x: -PRIORITY[x])[0]
            chosen = ts if top=='TS' else (td if top=='TD' else cf)
            chosen_src = top
            majority_key = n
            routes['majority']+=1
            break

    if chosen is None:
        chosen, chosen_src = td, 'TD'
        routes['fallback_td']+=1

    cand_norm_set = {n_ts, n_td, n_cf}
    snapped = maybe_trim_trailing_mark(chosen, cand_norm_set)
    if snapped != chosen:
        chosen = snapped
        routes['trim_snap']+=1

    if is_numeric_or_date_like(ts, lang, cap) and is_numeric_or_date_like(td, lang, cap) and is_numeric_or_date_like(cf, lang, cap):
        chosen, chosen_src = td, 'TD'
        routes['numeric_override']+=1

    # Strict cap enforcement by swapping to shorter candidate that matches majority norm or is substring-safe
    if len(chosen) > cap:
        # prefer TD then TS then CF if within cap and either matches majority norm or is a shorter substring of chosen under normalization
        n_chosen = norm_basic(chosen)
        for src, txt, nrm in (('TD', td, n_td), ('TS', ts, n_ts), ('CF', cf, n_cf)):
            if txt and len(txt) <= cap:
                # require a non-empty norm: '' is a substring of every string
                if (majority_key and nrm == majority_key) or (nrm and nrm in n_chosen):
                    chosen, chosen_src = txt, src
                    routes['cap_swap']+=1
                    break
        # else keep chosen as-is (avoid truncation not guaranteed in-context)

    # Substring safety: ensure output occurs in context when available; if not, prefer TD if it occurs
    if ctx_col:
        ctx = id2ctx.get(_id, '')
        if isinstance(ctx, str) and ctx:
            if chosen and ctx.find(chosen) == -1:
                if td and ctx.find(td) != -1:
                    chosen, chosen_src = td, 'TD'
                    routes['substr_snap']+=1

    out.append((_id, chosen))

sub = pd.DataFrame(out, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
total = len(sub)
print('Routes %: ' + ', '.join(f"{k}={v/total*100:.1f}%" for k,v in routes.items()), flush=True)

out_fp = 'submission_primary_majority_noexp_strictcaps.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
print('Wrote', out_fp, 'in', f'{time.time()-t2:.1f}s')
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('submission.csv updated → majority_noexp_strictcaps')

mean_len overall: 10.71 | HI: 11.01 | TA: 9.82
Routes %: majority=86.6%, fallback_td=13.4%, numeric_override=8.9%, norm_tie=0.0%, trim_snap=0.0%, cap_swap=0.0%, substr_snap=0.0%


Wrote submission_primary_majority_noexp_strictcaps.csv in 0.3s
submission.csv updated → majority_noexp_strictcaps


In [24]:
# Logit-level decoder: XLM-R large (alignment-safe), per-language length prior and caps
import numpy as np, pandas as pd, json, time, unicodedata as ud, re, sys

t = time.time()
LOGITS_DIR = 'xlmr_large_test_logits'  # alignment-safe XLM-R large 384
START_GLOB = [f'{LOGITS_DIR}/test_start_logits_f{i}.npy' for i in range(5)]
END_GLOB   = [f'{LOGITS_DIR}/test_end_logits_f{i}.npy'   for i in range(5)]
OFFSET_FP  = f'{LOGITS_DIR}/test_offset_mapping.npy'
EID_FP     = f'{LOGITS_DIR}/test_example_id.json'

# Params
K = 20  # beam size for starts/ends
LAMBDA_HI = 0.020
LAMBDA_TA = 0.042
CAP_HI = 18
CAP_TA = 16
BOUNDARY_BONUS = 0.12  # increased bonus per expert nudge

# Load test and language detection (script-based)
test_df = pd.read_csv('test.csv')
def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = next((c for c in ('question', 'question_text', 'context') if c in test_df.columns), None)
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'
test_df['id'] = test_df['id'].astype(str)
id2lang = dict(zip(test_df['id'], test_df['language']))
id2ctx = dict(zip(test_df['id'], test_df['context'])) if 'context' in test_df.columns else {}

# Load logits and metadata
starts = [np.load(fp) for fp in START_GLOB]
ends   = [np.load(fp) for fp in END_GLOB]
start_logits = np.mean(np.stack(starts, axis=0), axis=0)  # [Nfeat, L]
end_logits   = np.mean(np.stack(ends,   axis=0), axis=0)  # [Nfeat, L]
offsets = np.load(OFFSET_FP, allow_pickle=True)            # list/array of [L, 2] pairs (object arrays)
with open(EID_FP, 'r') as f:
    example_ids = json.load(f)  # list of example ids per feature row

assert len(start_logits) == len(end_logits) == len(offsets) == len(example_ids), 'Mismatched logits/meta lengths'
N = len(example_ids)
print(f'Loaded logits: features={N}, seq_len={start_logits.shape[1]}', flush=True)

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

def ensure_offset_array(off_raw):
    # Convert possibly 1D object array of tuples/lists to (L,2) int32 array safely
    if isinstance(off_raw, np.ndarray) and off_raw.ndim == 2 and off_raw.shape[1] == 2:
        return off_raw.astype(np.int32, copy=False)
    # off_raw is likely 1D object array/list of pairs
    try:
        seq = off_raw.tolist() if hasattr(off_raw, 'tolist') else list(off_raw)
    except Exception:
        return None
    pairs = []
    for p in seq:
        try:
            a = int(p[0]) if (p is not None and len(p) > 0 and p[0] is not None) else 0
            b = int(p[1]) if (p is not None and len(p) > 1 and p[1] is not None) else 0
        except Exception:
            a, b = 0, 0
        pairs.append((a, b))
    arr = np.asarray(pairs, dtype=np.int32)
    if arr.ndim != 2 or arr.shape[1] != 2:
        return None
    return arr

# Decode best span per example id across all feature windows
best_by_eid = {}  # eid -> (score, (start_char, end_char))
for i in range(N):
    if i and (i % 1000 == 0):
        print(f'Processed {i}/{N} features in {time.time()-t:.1f}s', flush=True)
    eid = str(example_ids[i])
    ctx = id2ctx.get(eid, '')
    lang = id2lang.get(eid, 'hi')
    lam = LAMBDA_HI if lang == 'hi' else LAMBDA_TA
    cap = CAP_HI if lang == 'hi' else CAP_TA
    off = ensure_offset_array(offsets[i])
    if off is None:
        continue
    s_log = start_logits[i].copy()
    e_log = end_logits[i].copy()

    # Mask invalid tokens: where offsets are (0,0), and CLS at position 0
    valid = (off[:,0] + off[:,1]) > 0
    if valid.shape[0] > 0:
        valid[0] = False  # drop CLS/no-answer
    s_log[~valid] = -1e30
    e_log[~valid] = -1e30

    # Top-K indices
    k_eff = int(min(K, int(valid.sum()))) if valid.ndim == 1 else K
    if k_eff <= 0:
        continue
    try:
        # argpartition needs kth < len(array); the except branch covers the out-of-bounds case
        s_idx = np.argpartition(-s_log, k_eff)[:k_eff]
        e_idx = np.argpartition(-e_log, k_eff)[:k_eff]
    except ValueError:
        s_idx = np.argsort(-s_log)[:k_eff]
        e_idx = np.argsort(-e_log)[:k_eff]

    # Evaluate combinations
    local_best = None
    for si in s_idx:
        if si < 0 or si >= valid.shape[0] or not valid[si]:
            continue
        for ei in e_idx:
            if ei < 0 or ei >= valid.shape[0] or not valid[ei] or ei < si:
                continue
            st_char = int(off[si, 0])
            ed_char = int(off[ei, 1])
            if ed_char <= st_char:
                continue
            span_len_chars = ed_char - st_char
            if span_len_chars <= 0 or span_len_chars > cap:
                continue
            score = float(s_log[si] + e_log[ei] - lam * span_len_chars)
            # Word-boundary bonus (safe only adds bonus, never penalizes)
            if isinstance(ctx, str) and ctx:
                Lb = (st_char == 0) or is_space_or_punct(ctx[st_char-1])
                Rb = (ed_char >= len(ctx)) or is_space_or_punct(ctx[ed_char:ed_char+1])
                if Lb and Rb:
                    score += BOUNDARY_BONUS
            if (local_best is None) or (score > local_best[0]):
                local_best = (score, (st_char, ed_char))

    if local_best is None:
        continue
    # Keep the best across all windows for this example id
    if (eid not in best_by_eid) or (local_best[0] > best_by_eid[eid][0]):
        best_by_eid[eid] = local_best

# Build submission
pred_rows = []
missing = 0
for eid in test_df['id'].astype(str).tolist():
    ctx = id2ctx.get(eid, '')
    if eid in best_by_eid and isinstance(ctx, str) and ctx:
        _, (st_char, ed_char) = best_by_eid[eid]
        st_char = max(0, min(len(ctx), int(st_char)))
        ed_char = max(st_char, min(len(ctx), int(ed_char)))
        ans = ctx[st_char:ed_char]
        pred_rows.append((eid, ans))
    else:
        pred_rows.append((eid, ''))
        missing += 1

sub = pd.DataFrame(pred_rows, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
print('Missing preds:', missing, 'Elapsed:', f'{time.time()-t:.1f}s', flush=True)

out_fp = 'submission_xlmr_logitdecoder_lenprior_caps.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('Wrote', out_fp, 'and updated submission.csv')

Loaded logits: features=1921, seq_len=384


Processed 1000/1921 features in 1.5s


mean_len overall: 9.65 | HI: 9.98 | TA: 8.68
Missing preds: 0 Elapsed: 2.1s


Wrote submission_xlmr_logitdecoder_lenprior_caps.csv and updated submission.csv
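
The decoder cell above picks candidate start and end tokens with `np.argpartition(-s_log, k_eff)[:k_eff]` and falls back to `np.argsort` on a `ValueError`. As a minimal sketch of the same top-K selection on toy data, the helper below (`top_k_indices` is a hypothetical name, not part of the notebook) uses `kth = k - 1`, which stays in bounds even when `k` equals the array length, so no fallback branch is needed. The returned indices are unordered, as in the original cell.

```python
import numpy as np

def top_k_indices(scores, k):
    """Return the indices of the k largest scores, in no particular order."""
    k = min(int(k), scores.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.int64)
    # Partitioning the negated scores at position k - 1 places the k largest
    # original scores in the first k slots of the result.
    return np.argpartition(-scores, k - 1)[:k]

# Toy logits (illustrative values only)
toy = np.array([0.1, 2.5, -1.0, 3.2, 0.7])
print(sorted(top_k_indices(toy, 2).tolist()))  # [1, 3]
```

Because the loop that follows scores every (start, end) pair anyway, the missing order inside the top-K set does not affect which span wins.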


In [21]:
# Logit-level decode for XLM-R and MuRIL separately, then per-id choose higher-scoring span
import numpy as np, pandas as pd, json, time, unicodedata as ud, re, sys

t = time.time()
CAP_HI = 18
CAP_TA = 16
LAMBDA_HI = 0.020
LAMBDA_TA = 0.042
K = 20
BOUNDARY_BONUS = 0.05

test_df = pd.read_csv('test.csv')
def detect_lang(text):
    if not isinstance(text, str): return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = 'question' if 'question' in test_df.columns else ('question_text' if 'question_text' in test_df.columns else ('context' if 'context' in test_df.columns else None))
if col: test_df['language'] = test_df[col].apply(detect_lang)
else: test_df['language'] = 'hi'
test_df['id'] = test_df['id'].astype(str)
id2lang = dict(zip(test_df['id'], test_df['language']))
id2ctx = dict(zip(test_df['id'], test_df['context'])) if 'context' in test_df.columns else {}

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

def ensure_offset_array(off_raw):
    if isinstance(off_raw, np.ndarray) and off_raw.ndim == 2 and off_raw.shape[1] == 2:
        return off_raw.astype(np.int32, copy=False)
    try:
        seq = off_raw.tolist() if hasattr(off_raw, 'tolist') else list(off_raw)
    except Exception:
        return None
    pairs = []
    for p in seq:
        try:
            a = int(p[0]) if (p is not None and len(p)>0 and p[0] is not None) else 0
            b = int(p[1]) if (p is not None and len(p)>1 and p[1] is not None) else 0
        except Exception:
            a, b = 0, 0
        pairs.append((a, b))
    arr = np.asarray(pairs, dtype=np.int32)
    if arr.ndim != 2 or arr.shape[1] != 2: return None
    return arr

def decode_dir(logits_dir, name):
    starts = [np.load(f'{logits_dir}/test_start_logits_f{i}.npy') for i in range(5)]
    ends   = [np.load(f'{logits_dir}/test_end_logits_f{i}.npy') for i in range(5)]
    start_logits = np.mean(np.stack(starts, axis=0), axis=0)
    end_logits   = np.mean(np.stack(ends,   axis=0), axis=0)
    offsets = np.load(f'{logits_dir}/test_offset_mapping.npy', allow_pickle=True)
    with open(f'{logits_dir}/test_example_id.json','r') as f:
        example_ids = json.load(f)
    N = len(example_ids)
    assert len(start_logits)==len(end_logits)==len(offsets)==N, f'mismatch {name}'
    print(f'[{name}] features={N}, seq_len={start_logits.shape[1]}', flush=True)
    best = {}  # eid -> (score, st, ed)
    for i in range(N):
        if i and (i%1000==0):
            print(f'[{name}] {i}/{N} in {time.time()-t:.1f}s', flush=True)
        eid = str(example_ids[i])
        ctx = id2ctx.get(eid, '')
        lang = id2lang.get(eid, 'hi')
        lam = LAMBDA_HI if lang=='hi' else LAMBDA_TA
        cap = CAP_HI if lang=='hi' else CAP_TA
        off = ensure_offset_array(offsets[i])
        if off is None: continue
        s_log = start_logits[i].copy(); e_log = end_logits[i].copy()
        valid = (off[:,0] + off[:,1]) > 0
        if valid.shape[0] > 0: valid[0] = False
        s_log[~valid] = -1e30; e_log[~valid] = -1e30
        k_eff = int(min(K, int(valid.sum()))) if valid.ndim==1 else K
        if k_eff <= 0: continue
        try:
            s_idx = np.argpartition(-s_log, k_eff)[:k_eff]
            e_idx = np.argpartition(-e_log, k_eff)[:k_eff]
        except ValueError:
            s_idx = np.argsort(-s_log)[:k_eff]
            e_idx = np.argsort(-e_log)[:k_eff]
        local_best = None
        for si in s_idx:
            if si<0 or si>=valid.shape[0] or not valid[si]: continue
            for ei in e_idx:
                if ei<0 or ei>=valid.shape[0] or not valid[ei] or ei<si: continue
                st_char = int(off[si,0]); ed_char = int(off[ei,1])
                if ed_char <= st_char: continue
                span_len = ed_char - st_char
                if span_len <= 0 or span_len > cap: continue
                score = float(s_log[si] + e_log[ei] - lam*span_len)
                if isinstance(ctx, str) and ctx:
                    Lb = (st_char==0) or is_space_or_punct(ctx[st_char-1])
                    Rb = (ed_char>=len(ctx)) or is_space_or_punct(ctx[ed_char:ed_char+1])
                    if Lb and Rb: score += BOUNDARY_BONUS
                if (local_best is None) or (score>local_best[0]):
                    local_best = (score, st_char, ed_char)
        if local_best is None: continue
        if (eid not in best) or (local_best[0] > best[eid][0]):
            best[eid] = local_best
    # build predictions and scores
    preds = {}; scores = {}
    for eid in test_df['id']:
        ctx = id2ctx.get(eid, '')
        if (eid in best) and isinstance(ctx, str) and ctx:
            sc, st, ed = best[eid]
            st = max(0, min(len(ctx), int(st))); ed = max(st, min(len(ctx), int(ed)))
            preds[eid] = ctx[st:ed]; scores[eid] = float(sc)
        else:
            preds[eid] = ''; scores[eid] = -1e9
    return preds, scores

# Decode XLM-R and MuRIL
xlmr_preds, xlmr_scores = decode_dir('xlmr_large_test_logits', 'xlmr')
muril_preds, muril_scores = decode_dir('muril_large_test_logits', 'muril')

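# Per-id choose higher scoring; ensure substring validity
# Note: scores are raw logit sums (plus prior/bonus) from two different models and are compared without calibration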
out_rows = []
for eid in test_df['id']:
    ctx = id2ctx.get(eid, '')
    ax, sx = xlmr_preds.get(eid, ''), xlmr_scores.get(eid, -1e9)
    am, sm = muril_preds.get(eid, ''), muril_scores.get(eid, -1e9)
    cand = [('xlmr', ax, sx), ('muril', am, sm)]
    # prefer answers that occur in context
    cand_valid = [c for c in cand if isinstance(ctx, str) and ctx and c[1] and ctx.find(c[1])!=-1]
    use = None
    if cand_valid:
        use = max(cand_valid, key=lambda x: x[2])
    else:
        use = max(cand, key=lambda x: x[2])
    out_rows.append((eid, use[1]))

sub = pd.DataFrame(out_rows, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
out_fp = 'submission_choose_xlmr_or_muril_logitdecoder.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('Wrote', out_fp, 'and updated submission.csv in', f'{time.time()-t:.1f}s')

[xlmr] features=1921, seq_len=384


[xlmr] 1000/1921 in 1.6s


[muril] features=1513, seq_len=384


[muril] 1000/1513 in 3.4s


mean_len overall: 9.54 | HI: 9.94 | TA: 8.36
Wrote submission_choose_xlmr_or_muril_logitdecoder.csv and updated submission.csv in 3.7s
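
The cell above reports only mean lengths, so it does not show how often each model was actually chosen. A minimal diagnostic sketch, in the spirit of the route percentages listed in the plan: `selection_report` is a hypothetical helper that takes two `{id: prediction}` and two `{id: score}` dicts shaped like the outputs of `decode_dir`, and returns the agreement rate and the share of ids where the first model's score wins. Ties go to the first model, matching how `max` resolves ties over the `cand` list. The toy inputs are illustrative only.

```python
def selection_report(preds_a, scores_a, preds_b, scores_b):
    """Summarise agreement and pick share for two per-id prediction dicts."""
    ids = sorted(set(preds_a) & set(preds_b))
    n = len(ids)
    if n == 0:
        return {'n': 0, 'agree': 0.0, 'pick_a': 0.0}
    # Exact string agreement between the two models' answers
    agree = sum(preds_a[i] == preds_b[i] for i in ids)
    # Share of ids where model A's score is at least model B's
    pick_a = sum(scores_a[i] >= scores_b[i] for i in ids)
    return {'n': n, 'agree': agree / n, 'pick_a': pick_a / n}

# Toy example with two ids
report = selection_report(
    {'1': 'x', '2': 'y'}, {'1': 1.0, '2': 0.5},
    {'1': 'x', '2': 'z'}, {'1': 0.2, '2': 0.9},
)
print(report)  # {'n': 2, 'agree': 0.5, 'pick_a': 0.5}
```

A pick share far from 0.5 may indicate that the two models' scores sit on different scales rather than that one model is better on those ids.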


In [20]:
# XLM-R logit decoder: quick lambda sweep to target safer mean lengths
import numpy as np, pandas as pd, json, time, unicodedata as ud, re, sys

t = time.time()
LOGITS_DIR = 'xlmr_large_test_logits'
START_GLOB = [f'{LOGITS_DIR}/test_start_logits_f{i}.npy' for i in range(5)]
END_GLOB   = [f'{LOGITS_DIR}/test_end_logits_f{i}.npy'   for i in range(5)]
OFFSET_FP  = f'{LOGITS_DIR}/test_offset_mapping.npy'
EID_FP     = f'{LOGITS_DIR}/test_example_id.json'

CAP_HI = 18
CAP_TA = 16
K = 20
BOUNDARY_BONUS = 0.05

test_df = pd.read_csv('test.csv')
def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = 'question' if 'question' in test_df.columns else ('question_text' if 'question_text' in test_df.columns else ('context' if 'context' in test_df.columns else None))
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'
test_df['id'] = test_df['id'].astype(str)
id2lang = dict(zip(test_df['id'], test_df['language']))
id2ctx = dict(zip(test_df['id'], test_df['context'])) if 'context' in test_df.columns else {}

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

def ensure_offset_array(off_raw):
    if isinstance(off_raw, np.ndarray) and off_raw.ndim == 2 and off_raw.shape[1] == 2:
        return off_raw.astype(np.int32, copy=False)
    try:
        seq = off_raw.tolist() if hasattr(off_raw, 'tolist') else list(off_raw)
    except Exception:
        return None
    pairs = []
    for p in seq:
        try:
            a = int(p[0]) if (p is not None and len(p) > 0 and p[0] is not None) else 0
            b = int(p[1]) if (p is not None and len(p) > 1 and p[1] is not None) else 0
        except Exception:
            a, b = 0, 0
        pairs.append((a, b))
    arr = np.asarray(pairs, dtype=np.int32)
    if arr.ndim != 2 or arr.shape[1] != 2:
        return None
    return arr

def decode_once(lambda_hi, lambda_ta):
    starts = [np.load(fp) for fp in START_GLOB]
    ends   = [np.load(fp) for fp in END_GLOB]
    start_logits = np.mean(np.stack(starts, axis=0), axis=0)
    end_logits   = np.mean(np.stack(ends,   axis=0), axis=0)
    offsets = np.load(OFFSET_FP, allow_pickle=True)
    with open(EID_FP, 'r') as f:
        example_ids = json.load(f)
    N = len(example_ids)
    best_by_eid = {}
    for i in range(N):
        eid = str(example_ids[i])
        ctx = id2ctx.get(eid, '')
        lang = id2lang.get(eid, 'hi')
        lam = lambda_hi if lang == 'hi' else lambda_ta
        cap = CAP_HI if lang == 'hi' else CAP_TA
        off = ensure_offset_array(offsets[i])
        if off is None:
            continue
        s_log = start_logits[i].copy()
        e_log = end_logits[i].copy()
        valid = (off[:,0] + off[:,1]) > 0
        if valid.shape[0] > 0:
            valid[0] = False
        s_log[~valid] = -1e30
        e_log[~valid] = -1e30
        k_eff = int(min(K, int(valid.sum()))) if valid.ndim == 1 else K
        if k_eff <= 0:
            continue
        try:
            s_idx = np.argpartition(-s_log, k_eff)[:k_eff]
            e_idx = np.argpartition(-e_log, k_eff)[:k_eff]
        except ValueError:
            s_idx = np.argsort(-s_log)[:k_eff]
            e_idx = np.argsort(-e_log)[:k_eff]
        local_best = None
        for si in s_idx:
            if si < 0 or si >= valid.shape[0] or not valid[si]:
                continue
            for ei in e_idx:
                if ei < 0 or ei >= valid.shape[0] or not valid[ei] or ei < si:
                    continue
                st_char = int(off[si,0]); ed_char = int(off[ei,1])
                if ed_char <= st_char:
                    continue
                span_len = ed_char - st_char
                if span_len <= 0 or span_len > cap:
                    continue
                score = float(s_log[si] + e_log[ei] - lam * span_len)
                if isinstance(ctx, str) and ctx:
                    Lb = (st_char == 0) or is_space_or_punct(ctx[st_char-1])
                    Rb = (ed_char >= len(ctx)) or is_space_or_punct(ctx[ed_char:ed_char+1])
                    if Lb and Rb:
                        score += BOUNDARY_BONUS
                if (local_best is None) or (score > local_best[0]):
                    local_best = (score, st_char, ed_char)
        if local_best is None:
            continue
        if (eid not in best_by_eid) or (local_best[0] > best_by_eid[eid][0]):
            best_by_eid[eid] = local_best
    pred_rows = []
    for eid in test_df['id']:
        ctx = id2ctx.get(eid, '')
        if eid in best_by_eid and isinstance(ctx, str) and ctx:
            _, st, ed = best_by_eid[eid]
            st = max(0, min(len(ctx), int(st))); ed = max(st, min(len(ctx), int(ed)))
            pred_rows.append((eid, ctx[st:ed]))
        else:
            pred_rows.append((eid, ''))
    sub = pd.DataFrame(pred_rows, columns=['id','PredictionString'])
    sub['len'] = sub['PredictionString'].astype(str).str.len()
    df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
    mean_all = sub['len'].mean(); mean_hi = df_lang.loc[df_lang['language']=='hi','len'].mean(); mean_ta = df_lang.loc[df_lang['language']=='ta','len'].mean()
    return sub[['id','PredictionString']], mean_all, mean_hi, mean_ta

grid = [
    (0.015, 0.038),
    (0.018, 0.040),
    (0.020, 0.042),
    (0.022, 0.045),
    (0.025, 0.048),
]

for lam_hi, lam_ta in grid:
    sub, m_all, m_hi, m_ta = decode_once(lam_hi, lam_ta)
    out_fp = f'submission_xlmr_logitdecoder_lenprior_caps_hi{lam_hi:.3f}_ta{lam_ta:.3f}.csv'
    sub.to_csv(out_fp, index=False)
    print(f'lambda_hi={lam_hi:.3f} lambda_ta={lam_ta:.3f} -> mean_len: all {m_all:.2f} | HI {m_hi:.2f} | TA {m_ta:.2f} -> wrote {out_fp}', flush=True)

print('Lambda sweep done in', f'{time.time()-t:.1f}s')

lambda_hi=0.015 lambda_ta=0.038 -> mean_len: all 9.65 | HI 9.98 | TA 8.68 -> wrote submission_xlmr_logitdecoder_lenprior_caps_hi0.015_ta0.038.csv


lambda_hi=0.018 lambda_ta=0.040 -> mean_len: all 9.65 | HI 9.98 | TA 8.68 -> wrote submission_xlmr_logitdecoder_lenprior_caps_hi0.018_ta0.040.csv


lambda_hi=0.020 lambda_ta=0.042 -> mean_len: all 9.65 | HI 9.98 | TA 8.68 -> wrote submission_xlmr_logitdecoder_lenprior_caps_hi0.020_ta0.042.csv


lambda_hi=0.022 lambda_ta=0.045 -> mean_len: all 9.65 | HI 9.98 | TA 8.68 -> wrote submission_xlmr_logitdecoder_lenprior_caps_hi0.022_ta0.045.csv


lambda_hi=0.025 lambda_ta=0.048 -> mean_len: all 9.65 | HI 9.98 | TA 8.68 -> wrote submission_xlmr_logitdecoder_lenprior_caps_hi0.025_ta0.048.csv


Lambda sweep done in 10.1s
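
All five sweep settings print the same mean lengths (9.65 overall, 9.98 HI, 8.68 TA). One plausible reading, consistent with the "so length prior bites" note in the z-normalized decoder cell, is that the linear penalty `lam * span_len` is small next to typical gaps between raw logit sums. The sketch below only does the arithmetic: it bounds how much the penalty on any admissible span can move across the grid. `max_penalty_shift` is a hypothetical helper; the grid and caps are copied from the cell above.

```python
grid = [(0.015, 0.038), (0.018, 0.040), (0.020, 0.042), (0.022, 0.045), (0.025, 0.048)]
CAP_HI, CAP_TA = 18, 16

def max_penalty_shift(lams, cap):
    """Upper bound on the change in lam * span_len over the grid for span_len <= cap."""
    return (max(lams) - min(lams)) * cap

shift_hi = max_penalty_shift([g[0] for g in grid], CAP_HI)
shift_ta = max_penalty_shift([g[1] for g in grid], CAP_TA)
print(f'HI: {shift_hi:.2f}  TA: {shift_ta:.2f}')  # HI: 0.18  TA: 0.16
```

Under this bound, two candidate spans can swap rank across the sweep only if their scores differ by less than about 0.18 (HI) or 0.16 (TA) logit units, which would explain an unchanged output if such near-ties are rare.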


In [38]:
# Improved XLM-R logit decoder: start-only top-K, forward end scan, per-feature z-norm so length prior bites; do not touch submission.csv
import numpy as np, pandas as pd, json, time, unicodedata as ud, re, sys

t = time.time()
LOGITS_DIR = 'xlmr_large_test_logits'
START_GLOB = [f'{LOGITS_DIR}/test_start_logits_f{i}.npy' for i in range(5)]
END_GLOB   = [f'{LOGITS_DIR}/test_end_logits_f{i}.npy'   for i in range(5)]
OFFSET_FP  = f'{LOGITS_DIR}/test_offset_mapping.npy'
EID_FP     = f'{LOGITS_DIR}/test_example_id.json'

# Per-expert: bump caps slightly to allow useful longer spans without drift
CAP_HI = 19
CAP_TA = 17
K_START = 120  # top-K for starts only; ends are scanned forward until char-cap
BOUNDARY_BONUS = 0.12
TGT_HI = 11
TGT_TA = 10

test_df = pd.read_csv('test.csv')
def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = 'question' if 'question' in test_df.columns else ('question_text' if 'question_text' in test_df.columns else ('context' if 'context' in test_df.columns else None))
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'
test_df['id'] = test_df['id'].astype(str)
id2lang = dict(zip(test_df['id'], test_df['language']))
id2ctx = dict(zip(test_df['id'], test_df['context'])) if 'context' in test_df.columns else {}

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

def ensure_offset_array(off_raw):
    if isinstance(off_raw, np.ndarray) and off_raw.ndim == 2 and off_raw.shape[1] == 2:
        return off_raw.astype(np.int32, copy=False)
    try:
        seq = off_raw.tolist() if hasattr(off_raw, 'tolist') else list(off_raw)
    except Exception:
        return None
    pairs = []
    for p in seq:
        try:
            a = int(p[0]) if (p is not None and len(p) > 0 and p[0] is not None) else 0
            b = int(p[1]) if (p is not None and len(p) > 1 and p[1] is not None) else 0
        except Exception:
            a, b = 0, 0
        pairs.append((a, b))
    arr = np.asarray(pairs, dtype=np.int32)
    if arr.ndim != 2 or arr.shape[1] != 2:
        return None
    return arr

def znorm_row(x, valid_mask):
    v = x[valid_mask]
    if v.size == 0:
        return x
    m = v.mean()
    s = v.std()
    y = x.copy()
    if (not np.isfinite(s)) or (s < 1e-6):
        y[valid_mask] = v - m
        return y
    y[valid_mask] = (v - m) / (s + 1e-6)
    return y

def decode_once(lambda_hi, lambda_ta, tag):
    starts = [np.load(fp) for fp in START_GLOB]
    ends   = [np.load(fp) for fp in END_GLOB]
    start_logits = np.mean(np.stack(starts, axis=0), axis=0)
    end_logits   = np.mean(np.stack(ends,   axis=0), axis=0)
    offsets = np.load(OFFSET_FP, allow_pickle=True)
    with open(EID_FP, 'r') as f:
        example_ids = json.load(f)
    N = len(example_ids)
    best_by_eid = {}  # eid -> (score, st_char, ed_char, span_len)
    for i in range(N):
        if i and (i % 1000 == 0):
            print(f'{i}/{N} in {time.time()-t:.1f}s', flush=True)
        eid = str(example_ids[i])
        ctx = id2ctx.get(eid, '')
        lang = id2lang.get(eid, 'hi')
        lam = lambda_hi if lang == 'hi' else lambda_ta
        tgt = TGT_HI if lang == 'hi' else TGT_TA
        cap = CAP_HI if lang == 'hi' else CAP_TA
        off = ensure_offset_array(offsets[i])
        if off is None:
            continue
        s_log = start_logits[i].copy(); e_log = end_logits[i].copy()
        valid = (off[:,0] + off[:,1]) > 0
        if valid.shape[0] > 0:
            valid[0] = False  # drop CLS
        s_log[~valid] = -1e30; e_log[~valid] = -1e30
        # z-normalize over valid tokens so lambda competes with logits
        s_log = znorm_row(s_log, valid)
        e_log = znorm_row(e_log, valid)
        k_eff = int(min(K_START, int(valid.sum()))) if valid.ndim == 1 else K_START
        if k_eff <= 0:
            continue
        try:
            s_idx = np.argpartition(-s_log, k_eff)[:k_eff]
        except ValueError:
            s_idx = np.argsort(-s_log)[:k_eff]
        local_best = None
        L = valid.shape[0]
        for si in s_idx:
            if si < 0 or si >= L or not valid[si]:
                continue
            st_char = int(off[si,0])
            # forward scan ends until char-cap exceeded; offsets are monotonic so we can break early
            for ei in range(si, L):
                if not valid[ei]:
                    continue
                ed_char = int(off[ei,1])
                if ed_char <= st_char:
                    continue
                span_len = ed_char - st_char
                if span_len > cap:
                    break
                score = float(s_log[si] + e_log[ei] - lam * abs(span_len - tgt))
                if isinstance(ctx, str) and ctx:
                    Lb = (st_char == 0) or is_space_or_punct(ctx[st_char-1])
                    Rb = (ed_char >= len(ctx)) or is_space_or_punct(ctx[ed_char:ed_char+1])
                    if Lb and Rb:
                        score += BOUNDARY_BONUS
                if (local_best is None) or (score > local_best[0]):
                    local_best = (score, st_char, ed_char, span_len)
        if local_best is None:
            continue
        if (eid not in best_by_eid) or (local_best[0] > best_by_eid[eid][0]):
            best_by_eid[eid] = local_best
    pred_rows = []
    for eid in test_df['id']:
        ctx = id2ctx.get(eid, '')
        if (eid in best_by_eid) and isinstance(ctx, str) and ctx:
            _, st, ed, _ = best_by_eid[eid]
            st = max(0, min(len(ctx), int(st))); ed = max(st, min(len(ctx), int(ed)))
            pred_rows.append((eid, ctx[st:ed]))
        else:
            pred_rows.append((eid, ''))
    sub = pd.DataFrame(pred_rows, columns=['id','PredictionString'])
    sub['len'] = sub['PredictionString'].astype(str).str.len()
    df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
    m_all = sub['len'].mean(); m_hi = df_lang.loc[df_lang['language']=='hi','len'].mean(); m_ta = df_lang.loc[df_lang['language']=='ta','len'].mean()
    out_fp = f'submission_xlmr_logitdecoder_lenprior_caps_{tag}_scan.csv'
    sub[['id','PredictionString']].to_csv(out_fp, index=False)
    print(f'{tag}_scan: mean_len all {m_all:.2f} | HI {m_hi:.2f} | TA {m_ta:.2f} -> {out_fp}', flush=True)
    return out_fp, m_all, m_hi, m_ta

# Per-expert: run caps-bumped settings closest to target means
grid = [
    (1.4, 1.8, 'hi1.40_ta1.80_caps19_17'),
    (1.6, 2.0, 'hi1.60_ta2.00_caps19_17'),
]

results = []
for lam_hi, lam_ta, tag in grid:
    fp, m_all, m_hi, m_ta = decode_once(lam_hi, lam_ta, tag)
    results.append((fp, m_all))

print('Done. Not updating submission.csv. Elapsed:', f'{time.time()-t:.1f}s')

1000/1921 in 3.7s


hi1.40_ta1.80_caps19_17_scan: mean_len all 10.57 | HI 10.96 | TA 9.39 -> submission_xlmr_logitdecoder_lenprior_caps_hi1.40_ta1.80_caps19_17_scan.csv


1000/1921 in 10.4s


hi1.60_ta2.00_caps19_17_scan: mean_len all 10.53 | HI 10.90 | TA 9.39 -> submission_xlmr_logitdecoder_lenprior_caps_hi1.60_ta2.00_caps19_17_scan.csv


Done. Not updating submission.csv. Elapsed: 12.9s
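
The cell above replaces the linear penalty with a target-length prior, `lam * abs(span_len - tgt)`, applied to z-normalized logits. A minimal sketch of how that prior can change the winning span: `span_score` is a hypothetical stand-in for the scoring line in `decode_once` (boundary bonus omitted), and the candidate values are toy numbers assumed to be already z-normalized.

```python
def span_score(s, e, span_len, lam, tgt):
    """Start score + end score minus a penalty on distance from the target length."""
    return s + e - lam * abs(span_len - tgt)

# (start z-score, end z-score, span length in chars); toy values
cands = [(2.0, 2.0, 3), (1.8, 1.9, 10)]

# Without the prior, the higher logit sum wins (the 3-char span)
best_no_prior = max(cands, key=lambda c: span_score(*c, lam=0.0, tgt=11))
# With lam=1.4 and tgt=11 (the HI settings above), the 10-char span wins
best_prior = max(cands, key=lambda c: span_score(*c, lam=1.4, tgt=11))

print(best_no_prior[2], best_prior[2])  # 3 10
```

With `lam` above 1 on unit-variance scores, each character of distance from the target costs more than one standard deviation of logit, so the prior can override the logits; this matches the mean lengths above moving toward `TGT_HI` and `TGT_TA`.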


In [97]:
# 4-stream no-expansion ensemble simplified: majority over all 4, boundary presence gate, numeric override enabled, substring safety, TA boundary-extend
import pandas as pd, numpy as np, unicodedata as ud, re, time, sys, os

t4 = time.time()

# Inputs
fp_ts = 'submission_tokenselect_512single_or_384_lambda012.csv'
fp_td = 'submission_tokendp_512.csv'
fp_cf = 'submission_charfusion_512_single_seed_alignsafe.csv'
# Choose an LD file produced by Cell 7; prefer newest *_caps19_17_scan (z-norm + forward scan) first
fp_ld_candidates = [
    'submission_xlmr_logitdecoder_lenprior_caps_hi1.40_ta1.80_caps19_17_scan.csv',
    'submission_xlmr_logitdecoder_lenprior_caps_hi1.60_ta2.00_caps19_17_scan.csv',
    'submission_xlmr_logitdecoder_lenprior_caps_hi1.40_ta1.80_scan.csv',
    'submission_xlmr_logitdecoder_lenprior_caps_hi1.60_ta2.00_scan.csv',
    'submission_xlmr_logitdecoder_lenprior_caps_hi1.20_ta1.60_scan.csv',
    'submission_xlmr_logitdecoder_lenprior_caps_hi1.00_ta1.40_scan.csv'
]
fp_ld = next((c for c in fp_ld_candidates if os.path.exists(c)), None)
if fp_ld is None:
    fp_ld = 'submission_xlmr_logitdecoder_lenprior_caps_hi0.020_ta0.042_K200.csv' if os.path.exists('submission_xlmr_logitdecoder_lenprior_caps_hi0.020_ta0.042_K200.csv') else 'submission_xlmr_logitdecoder_lenprior_caps.csv'

sub_ts = pd.read_csv(fp_ts)
sub_td = pd.read_csv(fp_td)
sub_cf = pd.read_csv(fp_cf)
sub_ld = pd.read_csv(fp_ld)
test_df = pd.read_csv('test.csv')

for df in (sub_ts, sub_td, sub_cf, sub_ld):
    assert 'id' in df.columns and 'PredictionString' in df.columns
    df['id'] = df['id'].astype(str)
test_df['id'] = test_df['id'].astype(str)

def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = 'question' if 'question' in test_df.columns else ('question_text' if 'question_text' in test_df.columns else ('context' if 'context' in test_df.columns else None))
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'

base = (test_df[['id']]
    .merge(sub_ts.rename(columns={'PredictionString':'TS'}), on='id', how='left')
    .merge(sub_ld.rename(columns={'PredictionString':'LD'}), on='id', how='left')
    .merge(sub_td.rename(columns={'PredictionString':'TD'}), on='id', how='left')
    .merge(sub_cf.rename(columns={'PredictionString':'CF'}), on='id', how='left'))

ctx_col = 'context' if 'context' in test_df.columns else None
id2ctx = dict(zip(test_df['id'], test_df[ctx_col])) if ctx_col else {}
id2lang = dict(zip(test_df['id'], test_df['language']))

PUNCT_CLASS = ''.join(chr(i) for i in range(sys.maxunicode) if ud.category(chr(i)).startswith('P'))
PUNCT_RE = re.compile(rf"[\s{re.escape(PUNCT_CLASS)}]+")  # raw f-string so \s is not treated as an invalid string escape
DANDA='\u0964'; TA_PULLI='\u0bcd'; DEV_VIRAMA='\u094d'

def nfc(s):
    try: return ud.normalize('NFC', s)
    except Exception: return s

def norm_basic(s):
    if not isinstance(s, str): s = ''
    s = nfc(s.strip())
    # remove ZWJ/ZWNJ
    s = s.replace('\u200c','').replace('\u200d','')
    # punctuation+spaces collapse
    s = PUNCT_RE.sub(' ', s)
    s = re.sub(r'\s+', ' ', s).strip()
    return s

HI_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0966\u0967\u0968\u0969\u096a\u096b\u096c\u096d\u096e\u096f')}
TA_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0be6\u0be7\u0be8\u0be9\u0bea\u0beb\u0bec\u0bed\u0bee\u0bef')}
def ascii_digits(s):
    if not isinstance(s, str):
        return ''
    return s.translate(HI_DIGITS).translate(TA_DIGITS)
SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')
def is_numeric_or_date_like(s, lang, cap):
    if not isinstance(s, str) or not s: return False
    s2 = ascii_digits(nfc(s)); s2 = SEP_RE.sub('-', s2).strip('-')
    if not s2 or len(s2) > cap: return False
    return bool(NUM_CHARS_RE.match(s2))

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

def is_present_boundary(ans, ctx):
    if not (isinstance(ans,str) and ans and isinstance(ctx,str) and ctx): return False
    st = ctx.find(ans)
    if st == -1: return False
    ed = st + len(ans)
    left = (st==0) or is_space_or_punct(ctx[st-1])
    right = (ed==len(ctx)) or is_space_or_punct(ctx[ed])
    return left and right

PRIORITY = {'TS':3,'LD':2,'TD':1,'CF':0}

out = []
routes = {'majority':0,'present_priority':0,'fallback_td':0,'numeric_override':0,'substr_gate':0,'trail_shorter':0,'tamil_digit_norm':0,'ts_ld_flip':0,'tamil_rescue':0,'ta_boundary_extend':0}

for i, row in base.iterrows():
    if (i % 500 == 0) and i:
        print(f'Processed {i}/{len(base)} in {time.time()-t4:.1f}s', flush=True)
    _id = row['id']
    lang = id2lang.get(_id, 'hi')
    cap = 18 if lang=='hi' else 16
    # Micro cap-override per expert: TA uses 17
    cap_override = 18 if lang=='hi' else 17
    ctx = id2ctx.get(_id, '') if ctx_col else ''
    ts = row['TS'] if isinstance(row['TS'], str) else ''
    ld = row['LD'] if isinstance(row['LD'], str) else ''
    td = row['TD'] if isinstance(row['TD'], str) else ''
    cf = row['CF'] if isinstance(row['CF'], str) else ''
    cand = {'TS':ts,'LD':ld,'TD':td,'CF':cf}
    nrm = {k: norm_basic(v) for k,v in cand.items()}

    # Majority vote over all four streams (TS, LD, TD, CF) with TA tie pref (LD over TS when norms tie)
    counts = {}
    for k in ('TS','LD','TD','CF'):
        n = nrm[k]
        counts.setdefault(n, []).append(k)
    chosen = None; chosen_src = None; majority_key = None
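    # Groups are visited in stream order (TS, LD, TD, CF); on a 2-2 split the first non-empty group in that order wins
    for n, srcs in counts.items():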
        if n and len(srcs) >= 2:
            top = sorted(srcs, key=lambda x: -PRIORITY[x])[0]
            if lang == 'ta' and ('TS' in srcs) and ('LD' in srcs) and top == 'TS':
                top = 'LD'
            chosen, chosen_src = cand[top], top
            majority_key = n
            routes['majority'] += 1
            break

    # Presence gating: among TS/LD/TD that occur in context, prefer boundary-robust first, then priority TS>LD>TD
    if chosen is None:
        present = []
        if isinstance(ctx,str) and ctx:
            for key in ('TS','LD','TD'):
                v = cand[key]
                if v and ctx.find(v) != -1:
                    present.append(key)
        if present:
            pref = [(key, cand[key], is_present_boundary(cand[key], ctx)) for key in present]
            robust = [(key,v) for key,v,rob in pref if rob]
            pool = robust if robust else [(key,v) for key,v,_rob in pref]
            best = sorted(pool, key=lambda kv: -PRIORITY[kv[0]])[0]
            chosen, chosen_src = best[1], best[0]
            routes['present_priority'] += 1
        else:
            # nothing TS/LD/TD in context: allow CF if present, else TD
            if isinstance(ctx,str) and ctx and cf and ctx.find(cf) != -1:
                chosen, chosen_src = cf, 'CF'
                routes['present_priority'] += 1
            else:
                chosen, chosen_src = td, 'TD'
                routes['fallback_td'] += 1

    # Strengthened TS->LD flip: fires only when TS is bad and LD is clean, occurs at most once in the context,
    # fits the cap, and is strictly shorter than TS
    if chosen_src == 'TS' and isinstance(ld, str) and ld and isinstance(ctx, str) and ctx:
        ts_in_ctx = (ctx.find(ts) != -1)
        ts_too_long = len(ts) > cap
        st_ts = ctx.find(ts); ed_ts = st_ts + len(ts) if st_ts != -1 else -1
        ts_left_bad = (st_ts > 0) and (not is_space_or_punct(ctx[st_ts-1]))
        ts_right_bad = (0 <= ed_ts < len(ctx)) and (not is_space_or_punct(ctx[ed_ts]))
        ts_not_clean = (st_ts == -1) or ts_left_bad or ts_right_bad
        ts_bad = ts_too_long or (not ts_in_ctx) or ts_not_clean

        st_ld = ctx.find(ld); ed_ld = st_ld + len(ld) if st_ld != -1 else -1
        ld_present = (st_ld != -1)
        ld_rare = ld_present and (ctx.count(ld) <= 1)
        # Guard on ld_present so ctx[st_ld-1] is never evaluated with st_ld == -1
        ld_left_ok = ld_present and ((st_ld == 0) or is_space_or_punct(ctx[st_ld-1]))
        ld_right_ok = (ed_ld == len(ctx)) or (0 <= ed_ld < len(ctx) and is_space_or_punct(ctx[ed_ld]))
        ld_clean = ld_present and ld_left_ok and ld_right_ok

        if ts_bad and ld_clean and ld_rare and (len(ld) <= cap) and (len(ld) < len(ts)):
            chosen, chosen_src = ld, 'LD'
            routes['ts_ld_flip'] += 1

    # Trailing mark rule when TS and LD normalize equal but raw differ: pick shorter if longer ends with terminal mark
    if nrm.get('TS','') and (nrm.get('TS') == nrm.get('LD')) and ts and ld and (ts != ld):
        t_short, t_long = (ts, ld) if len(ts) < len(ld) else (ld, ts)
        if t_long and (t_long.endswith(DANDA) or t_long.endswith(DEV_VIRAMA) or t_long.endswith(TA_PULLI) or (ud.category(t_long[-1]).startswith('P'))):
            if chosen in (ts, ld) and chosen != t_short:
                chosen = t_short
                routes['trail_shorter'] += 1

    # Hindi danda trim: ultra-safe
    if lang == 'hi' and isinstance(chosen, str) and chosen.endswith(DANDA) and isinstance(ctx, str) and ctx:
        trimmed = chosen[:-1].strip()
        if trimmed and is_present_boundary(trimmed, ctx):
            chosen = trimmed
            routes['trail_shorter'] += 1

    # Tamil rescue: stricter to cut overfiring
    if lang == 'ta' and isinstance(chosen, str) and chosen and len(chosen) < 5 and isinstance(ts, str) and ts:
        if len(ts) <= cap and isinstance(ctx, str) and ctx and ctx.count(ts) == 1:
            st = ctx.find(ts)
            if st != -1:
                ed = st + len(ts)
                left = (st == 0) or is_space_or_punct(ctx[st-1])
                right = (ed == len(ctx)) or is_space_or_punct(ctx[ed])
                if left and right:
                    chosen, chosen_src = ts, 'TS'
                    routes['tamil_rescue'] += 1

    # Numeric/date override (enabled, tight): require TS, LD, TD all numeric-like; pick TD if its ASCII-digit form
    # has a clean boundary match and occurs at most twice in the context
    if is_numeric_or_date_like(ts, lang, cap_override) and is_numeric_or_date_like(ld, lang, cap_override) and is_numeric_or_date_like(td, lang, cap_override):
        if isinstance(ctx,str) and ctx:
            td_ascii = ascii_digits(td)
            st = ctx.find(td_ascii)
            if st != -1:
                ed = st + len(td_ascii)
                left = (st==0) or is_space_or_punct(ctx[st-1])
                right = (ed==len(ctx)) or is_space_or_punct(ctx[ed])
                if left and right and len(td) <= cap_override and ctx.count(td_ascii) <= 2:
                    chosen, chosen_src = td, 'TD'
                    routes['numeric_override'] += 1

    # Tamil boundary-extend (guarded, safer): literal containment, unique occurrence, +1 growth, clean boundaries, cap_override
    if lang == 'ta' and isinstance(chosen, str) and chosen and isinstance(ctx, str) and ctx and len(chosen) < cap_override:
        if not is_numeric_or_date_like(chosen, lang, cap_override):
            choices = []
            for k in ('TS','LD','TD'):
                v = cand.get(k, '')
                if not v or len(v) <= len(chosen) or len(v) > cap_override:
                    continue
                if len(v) > len(chosen) + 1:
                    continue
                if ctx.count(v) != 1:
                    continue
                if not is_present_boundary(v, ctx):
                    continue
                if chosen not in v:
                    continue
                choices.append((len(v), PRIORITY.get(k,0), k, v))
            if choices:
                choices.sort()
                _lv, _pr, k_best, v_best = choices[-1]
                chosen, chosen_src = v_best, k_best
                routes['ta_boundary_extend'] += 1

    # Substring safety: if chosen not present but some others are, snap to highest-priority present among TS/LD/TD, else CF; if none is present, keep chosen
    if isinstance(ctx,str) and ctx and chosen:
        if ctx.find(chosen) == -1:
            present = []
            for key in ('TS','LD','TD'):
                v = cand[key]
                if v and ctx.find(v) != -1:
                    present.append(key)
            if not present and cf and ctx.find(cf) != -1:
                present = ['CF']
            if present:
                best = sorted(present, key=lambda x: -PRIORITY[x])[0]
                chosen, chosen_src = cand[best], best
                routes['substr_gate'] += 1

    # Tamil digit normalization for final matching: if TA and ascii_digits(chosen) is in ctx but chosen is not, switch
    if lang == 'ta' and isinstance(ctx,str) and ctx and isinstance(chosen,str) and chosen:
        ch_ascii = ascii_digits(chosen)
        if (ctx.find(ch_ascii) != -1) and (ctx.find(chosen) == -1):
            chosen = ch_ascii
            routes['tamil_digit_norm'] += 1

    out.append((_id, chosen))

sub = pd.DataFrame(out, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
total = len(sub)
print('Routes %: ' + ', '.join(f"{k}={v/total*100:.1f}%" for k,v in routes.items()), flush=True)

out_fp = 'submission_4stream_noexp_TS_LD_TD_CF.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
print('Wrote', out_fp, 'in', f'{time.time()-t4:.1f}s')
print('Note: Not updating submission.csv here.')

mean_len overall: 10.92 | HI: 11.07 | TA: 10.46
Routes %: majority=89.3%, present_priority=10.7%, fallback_td=0.0%, numeric_override=0.9%, substr_gate=3.6%, trail_shorter=0.0%, tamil_digit_norm=0.0%, ts_ld_flip=1.8%, tamil_rescue=2.7%, ta_boundary_extend=0.9%


Wrote submission_4stream_noexp_TS_LD_TD_CF.csv in 0.3s
Note: Not updating submission.csv here.
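
The majority route dominates the log above, so it helps to see it in isolation. The following is a minimal sketch of that step only, not the full cell: `simple_norm` is a hypothetical stand-in for `norm_basic`, and `majority_pick` is an illustrative name. It groups the four streams by normalized text, takes the first group with at least two members, and applies the Tamil preference for LD over TS.

```python
import re
import unicodedata as ud

# Priority mirrors the cell above: higher wins among streams that agree.
PRIORITY = {'TS': 3, 'LD': 2, 'TD': 1, 'CF': 0}

def simple_norm(s):
    """Hypothetical stand-in for norm_basic: NFC, punctuation -> space, collapse whitespace."""
    s = ud.normalize('NFC', s.strip()) if isinstance(s, str) else ''
    s = ''.join(' ' if ud.category(c).startswith('P') else c for c in s)
    return re.sub(r'\s+', ' ', s).strip()

def majority_pick(cand, lang='hi'):
    """Return (text, source) for the first normalized form shared by >= 2 streams, else (None, None)."""
    groups = {}
    for key in ('TS', 'LD', 'TD', 'CF'):
        groups.setdefault(simple_norm(cand.get(key, '')), []).append(key)
    for norm, srcs in groups.items():
        if norm and len(srcs) >= 2:
            top = max(srcs, key=PRIORITY.__getitem__)
            # Tamil tie preference: LD over TS when both are in the agreeing group.
            if lang == 'ta' and top == 'TS' and 'LD' in srcs:
                top = 'LD'
            return cand[top], top
    return None, None

cand = {'TS': 'Delhi.', 'LD': 'Delhi', 'TD': 'Mumbai', 'CF': 'Delhi'}
print(majority_pick(cand, 'hi'))
print(majority_pick(cand, 'ta'))
```

Note that the raw text of the winning stream is returned, not the normalized form, so trailing punctuation survives this step and is left to the later trimming rules.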


In [30]:
# Logit-level blend: XLM-R (384) + MuRIL (384) with per-feature z-norm, targeted length prior, and safe caps
import numpy as np, pandas as pd, json, time, unicodedata as ud, re, sys, os

t = time.time()
DIR_X = 'xlmr_large_test_logits'
DIR_M = 'muril_large_test_logits'
STARTS_X = [f'{DIR_X}/test_start_logits_f{i}.npy' for i in range(5)]
ENDS_X   = [f'{DIR_X}/test_end_logits_f{i}.npy'   for i in range(5)]
STARTS_M = [f'{DIR_M}/test_start_logits_f{i}.npy' for i in range(5)]
ENDS_M   = [f'{DIR_M}/test_end_logits_f{i}.npy'   for i in range(5)]
OFF_X = f'{DIR_X}/test_offset_mapping.npy'
EID_X = f'{DIR_X}/test_example_id.json'

CAP_HI = 18
CAP_TA = 16
K = 100  # moderate beam, faster than 200 but enough diversity
BOUNDARY_BONUS = 0.12
LAM_HI = 0.020
LAM_TA = 0.042
TGT_HI = 11
TGT_TA = 10

test_df = pd.read_csv('test.csv')
def detect_lang(text):
    if not isinstance(text, str):
        return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = 'question' if 'question' in test_df.columns else ('question_text' if 'question_text' in test_df.columns else ('context' if 'context' in test_df.columns else None))
if col:
    test_df['language'] = test_df[col].apply(detect_lang)
else:
    test_df['language'] = 'hi'
test_df['id'] = test_df['id'].astype(str)
id2lang = dict(zip(test_df['id'], test_df['language']))
id2ctx = dict(zip(test_df['id'], test_df['context'])) if 'context' in test_df.columns else {}

def ensure_offset_array(off_raw):
    if isinstance(off_raw, np.ndarray) and off_raw.ndim == 2 and off_raw.shape[1] == 2:
        return off_raw.astype(np.int32, copy=False)
    try:
        seq = off_raw.tolist() if hasattr(off_raw, 'tolist') else list(off_raw)
    except Exception:
        return None
    pairs = []
    for p in seq:
        try:
            a = int(p[0]) if (p is not None and len(p) > 0 and p[0] is not None) else 0
            b = int(p[1]) if (p is not None and len(p) > 1 and p[1] is not None) else 0
        except Exception:
            a, b = 0, 0
        pairs.append((a, b))
    arr = np.asarray(pairs, dtype=np.int32)
    if arr.ndim != 2 or arr.shape[1] != 2:
        return None
    return arr

def znorm_row(x):
    m = x.mean()
    s = x.std()
    if not np.isfinite(s) or s < 1e-6:
        return x*0.0
    return (x - m) / s

def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')

# Load and average folds for each model
sx = np.mean(np.stack([np.load(p) for p in STARTS_X], axis=0), axis=0)
ex = np.mean(np.stack([np.load(p) for p in ENDS_X],   axis=0), axis=0)
sm = np.mean(np.stack([np.load(p) for p in STARTS_M], axis=0), axis=0)
em = np.mean(np.stack([np.load(p) for p in ENDS_M],   axis=0), axis=0)
off = np.load(OFF_X, allow_pickle=True)
with open(EID_X, 'r') as f:
    eids = json.load(f)
N = len(eids)
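# NOTE: this check compares sx with ex and sm with em, but never sx with sm.
# XLM-R and MuRIL use different tokenizers, so their feature counts can differ;
# that is the likely cause of the IndexError in this cell's output.
assert sx.shape == ex.shape and sm.shape == em.shape and sx.shape[0]==N == len(off), 'shape mismatch'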
L = sx.shape[1]
print(f'Blending logits: N={N}, L={L}', flush=True)

# Per-feature z-norm, then average models
S = np.empty_like(sx); E = np.empty_like(ex)
for i in range(N):
    S[i] = 0.5*znorm_row(sx[i]) + 0.5*znorm_row(sm[i])
    E[i] = 0.5*znorm_row(ex[i]) + 0.5*znorm_row(em[i])

# Decode best span per example id across all feature windows
best_by_eid = {}  # eid -> (score, st_char, ed_char, span_len)
for i in range(N):
    if i and (i % 1000 == 0):
        print(f'{i}/{N} in {time.time()-t:.1f}s', flush=True)
    eid = str(eids[i])
    ctx = id2ctx.get(eid, '')
    lang = id2lang.get(eid, 'hi')
    lam = LAM_HI if lang=='hi' else LAM_TA
    tgt = TGT_HI if lang=='hi' else TGT_TA
    cap = CAP_HI if lang=='hi' else CAP_TA
    off_i = ensure_offset_array(off[i])
    if off_i is None:
        continue
    s_log = S[i].copy(); e_log = E[i].copy()
    valid = (off_i[:,0] + off_i[:,1]) > 0
    if valid.shape[0] > 0:
        valid[0] = False  # drop CLS
    s_log[~valid] = -1e30; e_log[~valid] = -1e30
    k_eff = int(min(K, int(valid.sum()))) if valid.ndim==1 else K
    if k_eff <= 0:
        continue
    try:
        s_idx = np.argpartition(-s_log, k_eff)[:k_eff]
        e_idx = np.argpartition(-e_log, k_eff)[:k_eff]
    except ValueError:
        s_idx = np.argsort(-s_log)[:k_eff]
        e_idx = np.argsort(-e_log)[:k_eff]
    local_best = None
    for si in s_idx:
        if si < 0 or si >= valid.shape[0] or not valid[si]:
            continue
        for ei in e_idx:
            if ei < 0 or ei >= valid.shape[0] or not valid[ei] or ei < si:
                continue
            st_char = int(off_i[si,0]); ed_char = int(off_i[ei,1])
            if ed_char <= st_char:
                continue
            span_len = ed_char - st_char
            if span_len <= 0 or span_len > cap:
                continue
            score = float(s_log[si] + e_log[ei] - lam * abs(span_len - tgt))
            if isinstance(ctx, str) and ctx:
                Lb = (st_char == 0) or is_space_or_punct(ctx[st_char-1])
                Rb = (ed_char >= len(ctx)) or is_space_or_punct(ctx[ed_char:ed_char+1])
                if Lb and Rb:
                    score += BOUNDARY_BONUS
            if (local_best is None) or (score > local_best[0]):
                local_best = (score, st_char, ed_char, span_len)
    if local_best is None:
        continue
    if (eid not in best_by_eid) or (local_best[0] > best_by_eid[eid][0]):
        best_by_eid[eid] = local_best

# Build submission
pred_rows = []
for eid in test_df['id']:
    ctx = id2ctx.get(eid, '')
    if eid in best_by_eid and isinstance(ctx, str) and ctx:
        _, st, ed, _ = best_by_eid[eid]
        st = max(0, min(len(ctx), int(st))); ed = max(st, min(len(ctx), int(ed)))
        pred_rows.append((eid, ctx[st:ed]))
    else:
        pred_rows.append((eid, ''))
sub = pd.DataFrame(pred_rows, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
out_fp = 'submission_logitblend_xlmr_muril_targetprior.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
print('Wrote', out_fp, 'Elapsed:', f'{time.time()-t:.1f}s')
print('Note: Not updating submission.csv; primary remains majority-fallback.')

Blending logits: N=1921, L=384


IndexError: index 1513 is out of bounds for axis 0 with size 1513
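
The traceback is consistent with one of the arrays indexed by `i` having only 1513 rows while the loop runs to N=1921. The assert in the cell never compares the XLM-R arrays against the MuRIL arrays, so a mismatch would pass it. Below is a minimal sketch of a stricter pre-flight check; `check_blend_shapes` and the label names are hypothetical, not part of the notebook.

```python
import numpy as np

def check_blend_shapes(arrays, n_expected):
    """Raise a descriptive error unless every array has n_expected rows and all shapes agree.

    `arrays` maps a label (e.g. 'sx', 'sm') to an ndarray; the labels are only used in the message.
    """
    shapes = {name: a.shape for name, a in arrays.items()}
    wrong_rows = [name for name, shp in shapes.items() if shp[0] != n_expected]
    if wrong_rows or len(set(shapes.values())) != 1:
        raise ValueError(f'logit shape mismatch (expected {n_expected} rows): {shapes}')
    return True

# Toy arrays standing in for averaged fold logits.
ok = {'sx': np.zeros((4, 3)), 'sm': np.zeros((4, 3))}
print(check_blend_shapes(ok, 4))
```

If the two models really do produce different feature counts, a row-wise blend is not meaningful even when the check is bypassed, because feature `i` of one model need not cover the same context window as feature `i` of the other.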

In [96]:
# Point submission.csv to 4-stream no-expansion ensemble and write
import pandas as pd, os
fp = 'submission_4stream_noexp_TS_LD_TD_CF.csv'
assert os.path.exists(fp), f'Missing {fp}'
sub = pd.read_csv(fp)
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('submission.csv updated ->', fp)

submission.csv updated -> submission_4stream_noexp_TS_LD_TD_CF.csv
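
The cell above only checks that the source file exists before overwriting submission.csv. A slightly stronger guard, sketched here under the assumption that ids in the submission should match test.csv exactly, could look like the following; `validate_submission` is a hypothetical helper, not something the notebook defines.

```python
import pandas as pd

def validate_submission(sub, test_ids):
    """Return a list of problems; an empty list means the frame looks safe to write."""
    missing_cols = {'id', 'PredictionString'} - set(sub.columns)
    if missing_cols:
        return [f'missing columns: {sorted(missing_cols)}']
    problems = []
    ids = sub['id'].astype(str)
    if ids.duplicated().any():
        problems.append('duplicate ids')
    if set(ids) != set(map(str, test_ids)):
        problems.append('id set differs from test.csv')
    return problems

# Toy frame standing in for the ensemble output.
demo = pd.DataFrame({'id': ['a', 'b'], 'PredictionString': ['x', 'y']})
print(validate_submission(demo, ['a', 'b']))
```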


In [87]:
# LD-majority ensemble: LD1 (hi1.40_ta1.80_caps19_17_scan), LD2 (hi1.60_ta2.00_caps19_17_scan), TD; TS backup; CF last resort
import pandas as pd, numpy as np, unicodedata as ud, re, os, time, sys

t5 = time.time()

fp_ld1 = 'submission_xlmr_logitdecoder_lenprior_caps_hi1.40_ta1.80_caps19_17_scan.csv'
fp_ld2 = 'submission_xlmr_logitdecoder_lenprior_caps_hi1.60_ta2.00_caps19_17_scan.csv'
fp_td  = 'submission_tokendp_512.csv'
fp_ts  = 'submission_tokenselect_512single_or_384_lambda012.csv'
fp_cf  = 'submission_charfusion_512_single_seed_alignsafe.csv'

avail = {
    'LD1': os.path.exists(fp_ld1),
    'LD2': os.path.exists(fp_ld2),
    'TD': os.path.exists(fp_td),
    'TS': os.path.exists(fp_ts),
    'CF': os.path.exists(fp_cf),
}
assert avail['LD1'] or avail['LD2'], 'Missing LD files; run Cell 7 first.'
assert avail['TD'], 'Missing TD file.'

sub_ld1 = pd.read_csv(fp_ld1) if avail['LD1'] else None
sub_ld2 = pd.read_csv(fp_ld2) if avail['LD2'] else None
sub_td  = pd.read_csv(fp_td)
sub_ts  = pd.read_csv(fp_ts) if avail['TS'] else None
sub_cf  = pd.read_csv(fp_cf) if avail['CF'] else None
test_df = pd.read_csv('test.csv')

for df in [d for d in (sub_ld1, sub_ld2, sub_td, sub_ts, sub_cf) if d is not None]:
    assert 'id' in df.columns and 'PredictionString' in df.columns
    df['id'] = df['id'].astype(str)
test_df['id'] = test_df['id'].astype(str)

def detect_lang(text):
    if not isinstance(text, str): return 'hi'
    return 'ta' if any('\u0B80' <= c <= '\u0BFF' for c in text) else 'hi'
col = 'question' if 'question' in test_df.columns else ('question_text' if 'question_text' in test_df.columns else ('context' if 'context' in test_df.columns else None))
if col: test_df['language'] = test_df[col].apply(detect_lang)
else: test_df['language'] = 'hi'

base = test_df[['id']]
if avail['LD1']: base = base.merge(sub_ld1.rename(columns={'PredictionString':'LD1'}), on='id', how='left')
if avail['LD2']: base = base.merge(sub_ld2.rename(columns={'PredictionString':'LD2'}), on='id', how='left')
base = base.merge(sub_td.rename(columns={'PredictionString':'TD'}), on='id', how='left')
if avail['TS']: base = base.merge(sub_ts.rename(columns={'PredictionString':'TS'}), on='id', how='left')
if avail['CF']: base = base.merge(sub_cf.rename(columns={'PredictionString':'CF'}), on='id', how='left')

ctx_col = 'context' if 'context' in test_df.columns else None
id2ctx = dict(zip(test_df['id'], test_df[ctx_col])) if ctx_col else {}
id2lang = dict(zip(test_df['id'], test_df['language']))

PUNCT_CLASS = ''.join(chr(i) for i in range(sys.maxunicode) if ud.category(chr(i)).startswith('P'))
PUNCT_RE = re.compile(rf"[\s{re.escape(PUNCT_CLASS)}]+")  # raw f-string: '\s' is an invalid escape in a plain string literal
DANDA='\u0964'; TA_PULLI='\u0bcd'; DEV_VIRAMA='\u094d'

def nfc(s):
    try: return ud.normalize('NFC', s)
    except Exception: return s
def norm_basic(s):
    if not isinstance(s, str): s = ''
    s = nfc(s.strip())
    s = s.replace('\u200c','').replace('\u200d','')
    s = PUNCT_RE.sub(' ', s)
    s = re.sub(r'\s+', ' ', s).strip()
    return s
def norm_spaces_only(s):
    s = nfc(s if isinstance(s, str) else '')
    return re.sub(r'\s+', ' ', s).strip()
def is_space_or_punct(c):
    return c.isspace() or ud.category(c).startswith('P')
def is_present_boundary(ans, ctx):
    if not (isinstance(ans,str) and ans and isinstance(ctx,str) and ctx): return False
    st = ctx.find(ans)
    if st == -1: return False
    ed = st + len(ans)
    left = (st==0) or is_space_or_punct(ctx[st-1])
    right = (ed==len(ctx)) or is_space_or_punct(ctx[ed])
    return left and right
HI_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0966\u0967\u0968\u0969\u096a\u096b\u096c\u096d\u096e\u096f')}
TA_DIGITS = {ord(c): ord('0')+i for i, c in enumerate('\u0be6\u0be7\u0be8\u0be9\u0bea\u0beb\u0bec\u0bed\u0bee\u0bef')}
def ascii_digits(s):
    if not isinstance(s, str): return ''
    return s.translate(HI_DIGITS).translate(TA_DIGITS)
SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')
def is_numeric_or_date_like(s, lang, cap):
    if not isinstance(s, str) or not s: return False
    s2 = ascii_digits(nfc(s)); s2 = SEP_RE.sub('-', s2).strip('-')
    if not s2 or len(s2) > cap: return False
    return bool(NUM_CHARS_RE.match(s2))

PRIORITY = {'LD1':4,'LD2':3,'TD':2,'TS':1,'CF':0}

out = []
routes = {'majority':0,'present_priority':0,'fallback_td':0,'numeric_override':0,'substr_gate':0,'boundary_snap':0,'ta_zwj_snap':0,'ta_boundary_extend':0}

for i, row in base.iterrows():
    if (i % 500 == 0) and i:
        print(f'Processed {i}/{len(base)} in {time.time()-t5:.1f}s', flush=True)
    _id = row['id']
    lang = id2lang.get(_id, 'hi')
    # minimally-relaxed caps for overrides/micros per expert set
    cap_override = 18 if lang == 'hi' else 20
    ctx = id2ctx.get(_id, '') if ctx_col else ''
    cand = {}
    for k in ('LD1','LD2','TD','TS','CF'):
        if k in row.index:
            cand[k] = row[k] if isinstance(row[k], str) else ''
    nrm = {k: norm_basic(v) for k,v in cand.items()}

    # Majority over LD1, LD2, TD (tie-break by priority only, so the micro-rules can still act later)
    counts = {}
    for k in ('LD1','LD2','TD'):
        if k in nrm:
            counts.setdefault(nrm[k], []).append(k)
    chosen=None; chosen_src=None; majority_key=None
    for n, srcs in counts.items():
        if n and len(srcs)>=2:
            top = sorted(srcs, key=lambda x: -PRIORITY.get(x, 0))[0]
            chosen, chosen_src = cand[top], top
            majority_key = n
            routes['majority'] += 1
            break

    # Presence gate among LD1/LD2/TD; TS if none present; CF last
    if chosen is None:
        present = []
        if isinstance(ctx,str) and ctx:
            for key in ('LD1','LD2','TD'):
                if key in cand and cand[key] and ctx.find(cand[key])!=-1:
                    present.append(key)
        if present:
            pref = [(key, cand[key], is_present_boundary(cand[key], ctx)) for key in present]
            robust = [(key,v) for key,v,rob in pref if rob]
            pool = robust if robust else [(key,v) for key,v,_rob in pref]
            best = sorted(pool, key=lambda kv: -PRIORITY[kv[0]])[0]
            chosen, chosen_src = best[1], best[0]
            routes['present_priority'] += 1
        else:
            # Backup order: TS if present in context, then CF if present, then TD as the unconditional fallback.
            # CF is checked before TD because TD is always loaded, so a CF branch placed after it would be unreachable.
            if 'TS' in cand and cand['TS'] and isinstance(ctx,str) and ctx and ctx.find(cand['TS'])!=-1:
                chosen, chosen_src = cand['TS'], 'TS'
                routes['present_priority'] += 1
            elif 'CF' in cand and cand['CF'] and isinstance(ctx,str) and ctx and ctx.find(cand['CF'])!=-1:
                chosen, chosen_src = cand['CF'], 'CF'
                routes['present_priority'] += 1
            else:
                chosen, chosen_src = cand.get('TD',''), 'TD'
                routes['fallback_td'] += 1

    # Tight numeric/date override (minimally-relaxed: allow 2-of-3 numeric-like with TD numeric-like)
    cap_num = cap_override
    num_flags = {k: (k in cand and is_numeric_or_date_like(cand[k], lang, cap_num)) for k in ('LD1','LD2','TD')}
    if sum(1 for v in num_flags.values() if v) >= 2 and ('TD' in cand and num_flags['TD']):
        if isinstance(ctx,str) and ctx:
            td_ascii = ascii_digits(cand['TD'])
            st = ctx.find(td_ascii)
            if st != -1:
                ed = st + len(td_ascii)
                left = (st==0) or is_space_or_punct(ctx[st-1])
                right = (ed==len(ctx)) or is_space_or_punct(ctx[ed])
                if left and right and len(cand['TD']) <= cap_num and ctx.count(td_ascii) <= 2:
                    chosen, chosen_src = cand['TD'], 'TD'
                    routes['numeric_override'] += 1

    # === Micro-rules: tighten to reduce boundary_snap firing to ~1–2% and nudge TA length ===
    def _tight_norm(x):
        x = nfc(x if isinstance(x, str) else '')
        x = re.sub(r'\s+', ' ', x).strip()
        return x

    TERMINAL_MARKS = {DANDA, DEV_VIRAMA, TA_PULLI}
    def _is_terminal(c):
        return (c in TERMINAL_MARKS) or ud.category(c).startswith('P')
    def _strip_one_terminal(s):
        if isinstance(s, str) and s and _is_terminal(s[-1]):
            return s[:-1]
        return s

    # 1) Boundary Snap (only when chosen is boundary-bad; no absent-case; no shortening)
    if isinstance(ctx, str) and ctx and chosen:
        chosen_present = (ctx.find(chosen) != -1)
        chosen_bad = (chosen_present and (not is_present_boundary(chosen, ctx)))
        if chosen_bad:
            Lc = len(chosen)
            chosen_tight = _tight_norm(chosen)
            chosen_tight_trim = _tight_norm(_strip_one_terminal(chosen)) if Lc > 1 else chosen_tight
            chosen_digits = ascii_digits(chosen)
            chosen_is_num = is_numeric_or_date_like(chosen, lang, cap_override)
            for k in ('LD1', 'LD2', 'TD'):
                alt = cand.get(k, '')
                if not alt:
                    continue
                La = len(alt)
                # Allow equal length or up to +2 longer (all subject to the normalization check below); disallow shorter to avoid length drift
                if not ((La == Lc) or (La == Lc + 1) or (La == Lc + 2)):
                    continue
                if La > cap_override:
                    continue
                if ctx.count(alt) > 3:
                    continue
                alt_tight = _tight_norm(alt)
                alt_tight_trim = _tight_norm(_strip_one_terminal(alt)) if La > 1 else alt_tight
                # normalization: equality or trimmed equality; containment only if alt is longer
                norm_eq = (alt_tight == chosen_tight) or (alt_tight == chosen_tight_trim) or (alt_tight_trim == chosen_tight)
                norm_contain = (La > Lc) and ((chosen_tight in alt_tight) or (chosen_tight_trim in alt_tight))
                if not (norm_eq or norm_contain):
                    continue
                # digits constraint: equal OR both non-numeric-like
                alt_digits = ascii_digits(alt)
                alt_is_num = is_numeric_or_date_like(alt, lang, cap_override)
                digits_ok = (alt_digits == chosen_digits) or (not chosen_is_num and not alt_is_num)
                if not digits_ok:
                    continue
                if is_present_boundary(alt, ctx):
                    chosen, chosen_src = alt, k
                    routes['boundary_snap'] += 1
                    break

    # 2) Tamil ZWJ/ZWNJ Snap (very low-fire; slight rarity relax)
    if lang == 'ta' and isinstance(chosen, str) and isinstance(ctx, str) and ctx and chosen:
        if ('\u200c' in chosen or '\u200d' in chosen):
            stripped = chosen.replace('\u200c', '').replace('\u200d', '')
            if stripped != chosen and len(stripped) <= cap_override:
                need = (ctx.find(chosen) == -1) or (not is_present_boundary(chosen, ctx))
                if need and ctx.count(stripped) <= 3 and is_present_boundary(stripped, ctx):
                    chosen = stripped
                    routes['ta_zwj_snap'] += 1

    # 3) Tamil Boundary-Extend (guarded; encourage slightly longer clean spans)
    if lang == 'ta' and isinstance(chosen, str) and isinstance(ctx, str) and ctx and chosen and len(chosen) < cap_override:
        guard_short = len(chosen) <= 12
        if guard_short:
            ct = _tight_norm(chosen)
            ct_trim = _tight_norm(_strip_one_terminal(chosen)) if len(chosen) > 1 else ct
            choices = []
            for k in ('LD1', 'LD2', 'TD', 'TS'):
                v = cand.get(k, '')
                if not v or len(v) <= len(chosen) or len(v) > cap_override:
                    continue
                if not is_present_boundary(v, ctx):
                    continue
                vt = _tight_norm(v)
                # relaxed containment: allow tight-norm containment or literal containment
                contains_ok = (chosen in v) or (ct in vt) or (ct_trim in vt)
                if not contains_ok:
                    continue
                equal_or_off_by1 = (vt == ct) or (len(v) > 1 and _tight_norm(v[:-1]) == ct) or (len(chosen) > 1 and vt == ct_trim)
                if equal_or_off_by1 and ctx.count(v) <= 6:
                    choices.append((len(v), PRIORITY.get(k, 0), k, v))
            if choices:
                choices.sort()
                _lv, _pr, k_best, v_best = choices[-1]
                chosen, chosen_src = v_best, k_best
                routes['ta_boundary_extend'] += 1
    # === end micro-rules ===

    # Substring safety: if chosen not present but others are, snap to highest-priority present among LD1/LD2/TD, else TS, else CF
    if isinstance(ctx,str) and ctx and chosen:
        if ctx.find(chosen) == -1:
            present = []
            for key in ('LD1','LD2','TD','TS'):
                if key in cand and cand[key] and ctx.find(cand[key])!=-1:
                    present.append(key)
            if not present and 'CF' in cand and cand['CF'] and ctx.find(cand['CF'])!=-1:
                present = ['CF']
            if present:
                best = sorted(present, key=lambda x: -PRIORITY[x])[0]
                chosen, chosen_src = cand[best], best
                routes['substr_gate'] += 1

    out.append((_id, chosen))

sub = pd.DataFrame(out, columns=['id','PredictionString'])
sub['len'] = sub['PredictionString'].astype(str).str.len()
df_lang = sub.merge(test_df[['id','language']], on='id', how='left')
print(f"mean_len overall: {sub['len'].mean():.2f} | HI: {df_lang.loc[df_lang['language']=='hi','len'].mean():.2f} | TA: {df_lang.loc[df_lang['language']=='ta','len'].mean():.2f}")
total = len(sub)
print('Routes %: ' + ', '.join(f"{k}={v/total*100:.1f}%" for k,v in routes.items()), flush=True)

out_fp = 'submission_ld_majority_LD1_LD2_TD.csv'
sub[['id','PredictionString']].to_csv(out_fp, index=False)
sub[['id','PredictionString']].to_csv('submission.csv', index=False)
print('Wrote', out_fp, 'and updated submission.csv in', f'{time.time()-t5:.1f}s')

mean_len overall: 10.58 | HI: 11.02 | TA: 9.25
Routes %: majority=92.9%, present_priority=7.1%, fallback_td=0.0%, numeric_override=0.9%, substr_gate=0.0%, boundary_snap=0.9%, ta_zwj_snap=0.0%, ta_boundary_extend=0.0%


Wrote submission_ld_majority_LD1_LD2_TD.csv and updated submission.csv in 0.3s
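
The numeric/date override depends on two small helpers from the cell above, `ascii_digits` and `is_numeric_or_date_like`. The sketch below repeats them in self-contained form (dropping the NFC step, which does not affect these inputs) to show what the gate accepts; the sample strings are illustrative and not taken from the data.

```python
import re

# Devanagari and Tamil digit code points mapped to ASCII '0'..'9', as in the cell above.
HI_DIGITS = {0x0966 + i: ord('0') + i for i in range(10)}
TA_DIGITS = {0x0BE6 + i: ord('0') + i for i in range(10)}
SEP_RE = re.compile(r'[\./\-,:\s]+')
NUM_CHARS_RE = re.compile(r'^[0-9\-]+$')

def ascii_digits(s):
    """Translate Devanagari and Tamil digits to ASCII; other characters pass through."""
    if not isinstance(s, str):
        return ''
    return s.translate(HI_DIGITS).translate(TA_DIGITS)

def is_numeric_or_date_like(s, cap):
    """True if s is only digits and separators once separators collapse to '-', within cap chars."""
    if not isinstance(s, str) or not s:
        return False
    s2 = SEP_RE.sub('-', ascii_digits(s)).strip('-')
    if not s2 or len(s2) > cap:
        return False
    return bool(NUM_CHARS_RE.match(s2))

for sample in ('\u0967\u096f\u096a\u096d', '15/08/1947', 'Delhi 1947'):
    print(repr(ascii_digits(sample)), is_numeric_or_date_like(sample, 18))
```

Because the cap is applied after separators collapse, a date such as `15/08/1947` counts as 10 characters, well inside either language's `cap_override`.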
