# Optional-Diacritic CTC vs Regular CTC (Arabic)

This notebook demonstrates and compares:

- **Regular CTC** decoding/training objective
- **Optional-diacritic constrained CTC** using a GTN WFST

Goal: provide a rigorous, reproducible comparison for research discussion.

## What this notebook gives you

1. Formal intuition and equations
2. A working implementation with Wav2Vec2 Arabic CTC emissions
3. Side-by-side scoring (regular vs constrained)
4. Decoding and error-analysis hooks
5. Evaluation protocol for supervisor-facing evidence


## 1) Hypothesis and Positioning

### Hypothesis

Constraining CTC alignments with a graph that allows an optional diacritic after each base Arabic character should:

- reduce invalid/unlikely orthographic sequences,
- improve behavior when diacritics are inconsistently present,
- preserve compatibility with CTC acoustic models.

### About novelty

This notebook demonstrates a technically coherent constrained-CTC formulation. Whether it is *novel* in the publication sense requires literature review against prior graph-constrained CTC/ASR work.

Use this notebook to show:
- a clear formulation,
- reproducible implementation,
- empirical deltas under controlled settings.


## 2) Formulation

Let $X$ be acoustic features and $Y$ target tokens.

### Regular CTC

$$\mathcal{L}_{CTC}(X,Y) = -\log \sum_{\pi \in \mathcal{A}_{CTC}(Y)} p(\pi|X)$$

In GTN terms:

- Build CTC target graph $C(Y)$,
- Build emissions graph $E(X)$ from frame log-probs,
- Loss: $-\mathrm{forward}(C \circ E)$, where $\mathrm{forward}$ is the log-sum-exp of path weights over all accepting paths.

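To sanity-check the GTN score without GTN, $-\mathrm{forward}(C \circ E)$ can be recomputed with the textbook CTC alpha (forward) recursion. The sketch below is a plain NumPy reference; the function name `ctc_neg_log_likelihood` is ours, not part of GTN. It walks the same blank-interleaved state sequence that `create_ctc_target_graph` builds (stay, advance, or skip a blank between different labels), so on the same `log_probs` it should agree with `regular_loss` up to float32 rounding.

```python
from typing import Sequence

import numpy as np


def ctc_neg_log_likelihood(log_probs: np.ndarray, target: Sequence[int], blank: int) -> float:
    """-log p(Y|X) by the CTC alpha (forward) recursion in log space.

    log_probs: [T, V] per-frame log-probabilities; target: label ids without blanks.
    """
    num_frames = log_probs.shape[0]
    ext = [blank]
    for y in target:
        ext += [y, blank]  # blank-interleaved states, length 2L+1
    num_states = len(ext)

    alpha = np.full(num_states, -np.inf)
    alpha[0] = log_probs[0, ext[0]]  # start in the leading blank...
    if num_states > 1:
        alpha[1] = log_probs[0, ext[1]]  # ...or directly on the first label
    for t in range(1, num_frames):
        prev = alpha.copy()
        for s in range(num_states):
            terms = [prev[s]]  # stay (self-loop)
            if s >= 1:
                terms.append(prev[s - 1])  # advance one state
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(prev[s - 2])  # skip a blank between different labels
            alpha[s] = np.logaddexp.reduce(terms) + log_probs[t, ext[s]]

    # Accept in the last label state or in the trailing blank.
    final = alpha[-1] if num_states == 1 else np.logaddexp(alpha[-1], alpha[-2])
    return float(-final)
```

In the notebook, `ctc_neg_log_likelihood(log_probs, target_ids, blank_idx)` is the comparison to make against `regular_loss`; `torch.nn.functional.ctc_loss(..., reduction="sum")` on the same inputs is a second independent check.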
### Optional-Diacritic Constrained CTC

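Introduce a WFST $A$ that passes every CTC symbol through unchanged and, after each base character, optionally inserts one diacritic. At the label level, let $\mathcal{D}(Y)$ be the set of sequences obtained from $Y$ by inserting at most one diacritic after each base character (so $Y \in \mathcal{D}(Y)$). The intended objective is

$$\mathcal{L}_{opt}(X,Y) = -\log \sum_{Y' \in \mathcal{D}(Y)} \sum_{\pi \in \mathcal{A}_{CTC}(Y')} p(\pi|X)$$

GTN computation:

- $C_{opt} = C \circ A$
- loss: $-\mathrm{forward}(C_{opt} \circ E)$

Two properties of this construction matter when interpreting results:

- **It only adds alignments.** Every arc of $A$ has weight 0 and the skip arc keeps every regular path, so the alignments of $C_{opt}$ are a superset of those of $C$ and $\mathcal{L}_{opt}(X,Y) \le \mathcal{L}_{CTC}(X,Y)$ for every utterance.
- **Insertion happens at the frame level.** Because $A$ is composed with the CTC graph $C$ rather than with $Y$, a diacritic can be inserted after *any* frame that emits a base character. Each inserted diacritic therefore occupies exactly one frame directly after a base-character frame, and repeated frames of the same base character can each be followed by one (base, diacritic, base collapses to a sequence outside $\mathcal{D}(Y)$). $C \circ A$ thus approximates the label-level objective above; an exact version would insert the optional diacritics into $Y$ before applying the CTC topology.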

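One way to make $\mathcal{L}_{opt}$ concrete is to enumerate it directly: build every diacritized variant of $Y$, then sum the probability of every frame path whose CTC collapse is one of those variants. Each path collapses to exactly one label sequence, so the per-variant path sets are disjoint and their masses simply add. The brute-force sketch below uses a toy vocabulary with hypothetical ids (blank 0, base letters 1 and 2, diacritic 3) and enumerates all $V^T$ paths, so it is a definition check for a handful of frames, not an implementation.

```python
import itertools

import numpy as np


def collapse(path, blank):
    """CTC collapse B(pi): merge repeated frames, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return tuple(out)


def diacritized_variants(target, base_ids, diacritic_ids):
    """D(Y): every sequence with at most one diacritic after each base label (Y included)."""
    options = [[(y,)] + ([(y, d) for d in diacritic_ids] if y in base_ids else [])
               for y in target]
    return {tuple(itertools.chain.from_iterable(c)) for c in itertools.product(*options)}


def brute_force_loss(log_probs, label_seqs, blank):
    """-log of the total probability of frame paths whose collapse lies in label_seqs."""
    T, V = log_probs.shape
    total = -np.inf
    for path in itertools.product(range(V), repeat=T):
        if collapse(path, blank) in label_seqs:
            total = np.logaddexp(total, log_probs[np.arange(T), list(path)].sum())
    return float(-total)


# Toy example (hypothetical ids): blank=0, base letters 1 and 2, diacritic 3.
rng = np.random.default_rng(0)
log_probs = rng.normal(size=(5, 4))
log_probs -= np.logaddexp.reduce(log_probs, axis=1, keepdims=True)  # normalize each frame

target = (1, 2)
variants = diacritized_variants(target, base_ids={1, 2}, diacritic_ids=[3])
loss_ctc = brute_force_loss(log_probs, {target}, blank=0)
loss_opt = brute_force_loss(log_probs, variants, blank=0)
print(sorted(variants))
print(f"L_CTC = {loss_ctc:.4f}, L_opt = {loss_opt:.4f}")
```

With `label_seqs = {target}` the enumeration is an exact oracle for the regular CTC loss; with the variant set it gives the label-level $\mathcal{L}_{opt}$, which is always at most $\mathcal{L}_{CTC}$ because $Y$ is one of the variants.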
## 3) Setup

In [None]:
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np
import torch
import torchaudio
import gtn

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

print("torch:", torch.__version__)
print("gtn:", getattr(gtn, "__version__", "unknown"))
print("cuda available:", torch.cuda.is_available())


In [None]:
# Optional helper deps for large-scale evaluation (uncomment if needed)
# !pip install datasets evaluate jiwer


## 4) Model and data configuration

Set the model ID and one audio/reference pair. A single utterance is only a smoke test; for evidence, run the batch protocol in Section 8 on a held-out set.


In [None]:
MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Replace with your own file and reference
AUDIO_PATH = "path/to/arabic_audio.wav"
TARGET_TEXT = "سلام"

ARABIC_DIACRITICS = ["َ", "ً", "ُ", "ٌ", "ِ", "ٍ", "ْ", "ّ"]

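Combining marks are hard to read and easy to corrupt in source files. The cell below spells the eight marks as `\u` escapes and prints their Unicode names and categories. It assumes the list above is fatha, fathatan, damma, dammatan, kasra, kasratan, sukun, shadda, in that order; `set(DIACRITIC_CODEPOINTS) == set(ARABIC_DIACRITICS)` is a quick integrity check of that assumption.

```python
import unicodedata

# The eight marks in ARABIC_DIACRITICS, written as escapes so they survive copy/paste.
DIACRITIC_CODEPOINTS = [
    "\u064e",  # fatha
    "\u064b",  # fathatan
    "\u064f",  # damma
    "\u064c",  # dammatan
    "\u0650",  # kasra
    "\u064d",  # kasratan
    "\u0652",  # sukun
    "\u0651",  # shadda
]

for ch in DIACRITIC_CODEPOINTS:
    # Category "Mn" = nonspacing combining mark; it attaches to the preceding letter.
    print(f"U+{ord(ch):04X}  {unicodedata.category(ch)}  {unicodedata.name(ch)}")
```

Shadda regularly stacks with a short vowel (for example shadda plus fatha on one letter), so a graph that allows at most one mark per base character cannot produce that combination; this is worth tracking in the diacritic-specific error analysis.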

In [None]:
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

blank_idx = model.config.pad_token_id
if blank_idx is None:
    raise RuntimeError("No pad_token_id found; cannot infer CTC blank index.")

print("Loaded model:", MODEL_ID)
print("CTC blank index:", blank_idx)


## 5) Core GTN utilities

These are the key building blocks for regular and constrained CTC.


In [None]:
def load_audio_16k(path: str) -> torch.Tensor:
    wav, sr = torchaudio.load(path)
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    return wav.squeeze(0)


def logits_from_wav(wav: torch.Tensor) -> torch.Tensor:
    inputs = processor(wav.numpy(), sampling_rate=16000, return_tensors="pt")
    input_values = inputs.input_values.to(DEVICE)
    with torch.no_grad():
        logits = model(input_values).logits[0]  # [T, V]
    return logits.cpu()


def build_emissions_graph(log_probs: np.ndarray) -> gtn.Graph:
    t, v = log_probs.shape
    g = gtn.linear_graph(t, v, calc_grad=False)
    g.set_weights(log_probs.astype(np.float32).reshape(-1))
    return g


def create_ctc_target_graph(target_ids: List[int], blank_idx: int) -> gtn.Graph:
    l = len(target_ids)
    u = 2 * l + 1
    g = gtn.Graph(False)
    for p in range(u):
        idx = (p - 1) // 2
        g.add_node(p == 0, p == u - 1 or p == u - 2)
        label = target_ids[idx] if (p % 2) else blank_idx
        g.add_arc(p, p, label)
        if p > 0:
            g.add_arc(p - 1, p, label)
        if p % 2 and p > 1 and label != target_ids[idx - 1]:
            g.add_arc(p - 2, p, label)
    return g


def build_optional_diacritic_wfst(
    ctc_symbols: List[int],
    base_symbol_ids: List[int],
    diacritic_ids: List[int],
) -> gtn.Graph:
    """Two-state transducer that copies CTC symbols and may insert one diacritic after a base char.

    s0: start/accept state between symbols; s1: a base character was just emitted.
    """
    g = gtn.Graph(False)
    s0 = g.add_node(True, True)
    s1 = g.add_node(False, False)

    # A diacritic is never treated as a base character, even if the caller passes it in.
    base_set = set(base_symbol_ids) - set(diacritic_ids)

    # Pass-through for all symbols; base chars move to s1
    for sym in ctc_symbols:
        if sym in base_set:
            g.add_arc(s0, s1, sym, sym, 0.0)
        else:
            g.add_arc(s0, s0, sym, sym, 0.0)

    # Optional insertion: epsilon input -> diacritic output
    for d in diacritic_ids:
        g.add_arc(s1, s0, gtn.epsilon, d, 0.0)

    # Optional skip (no diacritic)
    g.add_arc(s1, s0, gtn.epsilon, gtn.epsilon, 0.0)
    return g

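To see exactly which frame sequences the optional-diacritic transducer admits, the pure-Python sketch below mirrors its arcs (it does not call GTN) and lists every output it can produce for one input alignment. The ids are hypothetical: blank 0, one base letter 1, one diacritic 9.

```python
from typing import Iterable, Set, Tuple


def wfst_outputs(frames: Tuple[int, ...], base: Set[int],
                 diacritics: Iterable[int]) -> Set[Tuple[int, ...]]:
    """Every output sequence the two-state optional-diacritic transducer
    can produce for one input sequence of frame labels."""
    diacritics = list(diacritics)
    outputs = {()}
    for sym in frames:
        if sym in base:
            # s0 --sym:sym--> s1, then s1 --eps:eps--> s0 (skip) or s1 --eps:d--> s0 (insert d)
            outputs = ({o + (sym,) for o in outputs}
                       | {o + (sym, d) for o in outputs for d in diacritics})
        else:
            outputs = {o + (sym,) for o in outputs}  # s0 --sym:sym--> s0
    return outputs


def collapse(path: Tuple[int, ...], blank: int) -> Tuple[int, ...]:
    """CTC collapse: merge repeated frames, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return tuple(out)


# Hypothetical ids: blank=0, base letter=1, diacritic=9.
# (1, 1, 0) is one CTC alignment of the single-letter target (1,).
for out in sorted(wfst_outputs((1, 1, 0), base={1}, diacritics=[9])):
    print(out, "->", collapse(out, blank=0))
```

The input `1 1 0` is a valid alignment of the one-letter target `(1,)`, yet the output `1 9 1 0` collapses to `(1, 9, 1)`. Because the transducer sits between the frame-level CTC graph and the emissions, insertions can land between repeated frames of the same letter; inserting diacritics into the target label sequence before building the CTC topology would exclude such paths.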

## 6) Single-utterance demonstration

Compute both losses from the same emissions.


In [None]:
audio_file = Path(AUDIO_PATH)
if not audio_file.exists():
    raise FileNotFoundError(f"Update AUDIO_PATH first: {audio_file}")

wav = load_audio_16k(str(audio_file))
logits = logits_from_wav(wav)
log_probs = torch.log_softmax(logits, dim=-1).numpy()  # [T, V]

# Reference text -> tokenizer ids
target_ids = processor.tokenizer(TARGET_TEXT, add_special_tokens=False).input_ids
if not target_ids:
    raise ValueError("Target text tokenized to empty sequence.")

# Build regular CTC loss
E = build_emissions_graph(log_probs)
C = create_ctc_target_graph(target_ids, blank_idx=blank_idx)
regular_loss = gtn.negate(gtn.forward_score(gtn.compose(C, E))).item()

# Build constrained CTC loss
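vocab = processor.tokenizer.get_vocab()
diacritic_ids = sorted({vocab[d] for d in ARABIC_DIACRITICS if d in vocab})
if not diacritic_ids:
    print("Warning: no diacritics in the model vocab; the constrained loss will equal the regular loss.")

# Only letters may carry a diacritic: exclude blank, word delimiter, and diacritics themselves.
word_delim_id = vocab.get(getattr(processor.tokenizer, "word_delimiter_token", "|"))
base_ids = sorted(set(target_ids) - set(diacritic_ids) - {blank_idx, word_delim_id})
ctc_symbols = sorted(set([blank_idx] + target_ids))

A = build_optional_diacritic_wfst(ctc_symbols, base_ids, diacritic_ids)
C_opt = gtn.compose(C, A)
opt_loss = gtn.negate(gtn.forward_score(gtn.compose(C_opt, E))).item()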

print("Target text             :", TARGET_TEXT)
print("Target ids              :", target_ids)
print("Diacritic ids in vocab  :", diacritic_ids)
print(f"Regular CTC loss        : {regular_loss:.4f}")
print(f"Optional-Diacritic loss : {opt_loss:.4f}")


In [None]:
# Greedy decode reference (model output, independent from GTN scoring)
pred_ids = torch.argmax(logits, dim=-1).unsqueeze(0)
pred_text = processor.batch_decode(pred_ids)[0]
print("Greedy model decode:", pred_text)

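`processor.batch_decode` returns text, which hides whether the model emits diacritic ids at all. A minimal id-level best-path decoder (our own helper implementing standard greedy CTC decoding, not the Hugging Face implementation) makes that visible:

```python
import numpy as np


def greedy_ctc_ids(log_probs: np.ndarray, blank: int) -> list:
    """Best-path CTC decode at the id level: frame-wise argmax, merge repeats, drop blanks."""
    best = log_probs.argmax(axis=-1).tolist()
    out, prev = [], None
    for tok in best:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out


# In the notebook (assumes the Section 6 variables exist):
# emitted = set(greedy_ctc_ids(log_probs, blank_idx))
# print("diacritic ids emitted:", emitted & set(diacritic_ids))

# Toy check with blank=0:
toy = np.log(np.array([
    [0.1, 0.8, 0.1],  # frame 0 -> 1
    [0.1, 0.8, 0.1],  # frame 1 -> 1 (repeat, merged)
    [0.8, 0.1, 0.1],  # frame 2 -> blank
    [0.1, 0.8, 0.1],  # frame 3 -> 1 (after a blank, so a new token)
    [0.1, 0.1, 0.8],  # frame 4 -> 2
]))
print(greedy_ctc_ids(toy, blank=0))  # [1, 1, 2]
```

If the decoded ids rarely or never contain any of `diacritic_ids`, the insertion arcs are likely to carry little probability mass for this checkpoint, and the two losses are likely to be close.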

## 7) Visualization hook (constraint graph)

Export the optional-diacritic WFST to DOT for inspection.


In [None]:
inv_vocab = {idx: tok for tok, idx in vocab.items()}
symbols = {i: inv_vocab.get(i, str(i)) for i in sorted(set(ctc_symbols + diacritic_ids))}
symbols[gtn.epsilon] = "ε"

out_dot = "optional_diacritic_wfst.dot"
gtn.draw(A, out_dot, symbols, symbols)
print("Wrote:", out_dot)
print("Tip: dot -Tpng optional_diacritic_wfst.dot -o optional_diacritic_wfst.png")


## 8) Batch evaluation protocol (for convincing evidence)

For a supervisor-ready comparison, do this on a held-out Arabic set:

1. Fix model checkpoint and preprocessing.
2. Use same references for both conditions.
3. Compute per-utterance regular and constrained losses.
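4. Decode under each condition and compare CER/WER (or Arabic-normalized CER). The greedy `pred` in `score_one` does not depend on the constraint graph, so a CER/WER difference requires a constrained decoder or a checkpoint fine-tuned with the constrained loss.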
5. Run paired significance tests.

Recommended outputs:
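- table of mean/median loss delta (non-positive by construction, because the zero-weight WFST only adds alignments, so it measures how much probability mass the relaxation adds, not accuracy),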
- CER/WER deltas,
- examples where constrained CTC helps/hurts,
- ablations on diacritic set size.

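For step 4, a minimal sketch of Arabic-normalized CER: drop the marks U+064B to U+0652 (assumed to be the same eight marks as `ARABIC_DIACRITICS`) and compute character-level Levenshtein distance. The normalization choice is an assumption; reporting both raw and normalized CER shows whether differences come from diacritics or from base letters. `jiwer` (from the optional-deps cell) also provides a `cer` function for the raw case.

```python
import unicodedata

# Assumption: "Arabic-normalized" means dropping the marks U+064B..U+0652
# (tanween, short vowels, shadda, sukun).
ARABIC_MARKS = {chr(cp) for cp in range(0x064B, 0x0653)}


def strip_diacritics(text: str) -> str:
    return "".join(ch for ch in unicodedata.normalize("NFC", text) if ch not in ARABIC_MARKS)


def edit_distance(a: str, b: str) -> int:
    """Character-level Levenshtein distance (two-row dynamic programming)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution or match
        prev = cur
    return prev[-1]


def cer(ref: str, hyp: str, normalize: bool = False) -> float:
    if normalize:
        ref, hyp = strip_diacritics(ref), strip_diacritics(hyp)
    return edit_distance(ref, hyp) / max(len(ref), 1)


ref = "\u0633\u064e\u0644\u0627\u0645"  # sin + fatha, lam, alef, mim
hyp = "\u0633\u0644\u0627\u0645"        # same letters, no diacritic
print(cer(ref, hyp), cer(ref, hyp, normalize=True))  # 0.2 0.0
```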

In [None]:
@dataclass
class UtteranceResult:
    uid: str
    target: str
    regular_loss: float
    opt_loss: float
    pred: str


def score_one(audio_path: str, target_text: str) -> UtteranceResult:
    wav = load_audio_16k(audio_path)
    logits = logits_from_wav(wav)
    log_probs = torch.log_softmax(logits, dim=-1).numpy()

    target_ids = processor.tokenizer(target_text, add_special_tokens=False).input_ids
    if not target_ids:
        raise ValueError(f"Empty target tokenization for: {target_text}")

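    E = build_emissions_graph(log_probs)
    C = create_ctc_target_graph(target_ids, blank_idx)

    reg = gtn.negate(gtn.forward_score(gtn.compose(C, E))).item()

    # Derive vocab-dependent ids here so this cell does not rely on Section 6 having run.
    vocab = processor.tokenizer.get_vocab()
    diacritic_ids = sorted({vocab[d] for d in ARABIC_DIACRITICS if d in vocab})
    word_delim_id = vocab.get(getattr(processor.tokenizer, "word_delimiter_token", "|"))
    # Only letters may carry a diacritic: exclude blank, word delimiter, and diacritics.
    base_ids = sorted(set(target_ids) - set(diacritic_ids) - {blank_idx, word_delim_id})

    ctc_symbols = sorted(set([blank_idx] + target_ids))
    A = build_optional_diacritic_wfst(ctc_symbols, base_ids, diacritic_ids)
    C_opt = gtn.compose(C, A)
    opt = gtn.negate(gtn.forward_score(gtn.compose(C_opt, E))).item()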

    pred_ids = torch.argmax(logits, dim=-1).unsqueeze(0)
    pred = processor.batch_decode(pred_ids)[0]

    return UtteranceResult(
        uid=Path(audio_path).stem,
        target=target_text,
        regular_loss=reg,
        opt_loss=opt,
        pred=pred,
    )


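The next cell expects `manifest` to be a list of dicts with `"audio"` and `"text"` keys. If the references live in a JSON Lines file (one object per line; an assumed format, not something the notebook requires), a small hypothetical loader could fill it:

```python
import json


def load_manifest(path: str) -> list:
    """Read a JSON Lines manifest: one {"audio": ..., "text": ...} object per line."""
    manifest = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            ex = json.loads(line)
            if "audio" not in ex or "text" not in ex:
                raise ValueError(f"line {lineno}: expected 'audio' and 'text' keys")
            manifest.append({"audio": ex["audio"], "text": ex["text"]})
    return manifest
```

Usage: `manifest = load_manifest("heldout.jsonl")`, where the file name is a placeholder.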
In [None]:
# Example manifest format:
# manifest = [
#   {"audio": "/path/a.wav", "text": "..."},
#   {"audio": "/path/b.wav", "text": "..."},
# ]

manifest = []  # fill this with your evaluation set

results = []
for ex in manifest:
    try:
        results.append(score_one(ex["audio"], ex["text"]))
    except Exception as e:
        print("failed:", ex.get("audio", "?"), str(e))

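if results:
    # opt_loss <= regular_loss by construction (zero-weight arcs only add paths),
    # so these deltas are <= 0 up to float rounding.
    deltas = [r.opt_loss - r.regular_loss for r in results]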
    print("N:", len(results))
    print("Mean delta (opt - regular):", float(np.mean(deltas)))
    print("Median delta:", float(np.median(deltas)))
    print("Min/Max delta:", float(np.min(deltas)), float(np.max(deltas)))
else:
    print("No results yet. Fill manifest and rerun.")

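For step 5 of the protocol, a paired bootstrap is one simple significance check: resample utterances with replacement, recompute the mean per-utterance difference, and read off a percentile interval. It is meant for CER/WER pairs rather than `deltas`, since the loss delta is non-positive by construction. `scipy.stats.wilcoxon` on the same paired differences is a common non-parametric alternative. The inputs below are random toy values, not results.

```python
import numpy as np


def paired_bootstrap(a, b, n_boot: int = 10_000, alpha: float = 0.05, seed: int = 0):
    """Mean of the paired differences (b - a) and a percentile bootstrap CI.

    Utterances are resampled with replacement, keeping each (a_i, b_i) pair together.
    """
    diff = np.asarray(b, dtype=float) - np.asarray(a, dtype=float)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(diff), size=(n_boot, len(diff)))
    boot_means = diff[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(diff.mean()), float(lo), float(hi)


# Toy inputs only (not results): random per-utterance scores for the same 50 utterances.
toy_rng = np.random.default_rng(1)
cer_regular = toy_rng.uniform(0.05, 0.40, size=50)
cer_constrained = cer_regular + toy_rng.normal(0.0, 0.02, size=50)

mean_delta, ci_lo, ci_hi = paired_bootstrap(cer_regular, cer_constrained)
print(f"mean delta = {mean_delta:+.4f}, 95% CI [{ci_lo:+.4f}, {ci_hi:+.4f}]")
```

An interval that excludes 0 suggests a consistent per-utterance difference; resampling whole utterances (rather than frames or characters) respects the fact that errors within one utterance are correlated.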

## 9) Ablation matrix

Run these to strengthen your case:

1. **No constraint** (regular CTC baseline)
2. **Optional all diacritics** (current)
3. **Optional subset of diacritics**
4. **Penalty-weighted insertions** (non-zero arc penalties)
5. **Constraint only at decode** vs **constraint in training loss**

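For item 4, one label-level way to write the penalty-weighted objective: let $\mathcal{D}(Y)$ be the variants of $Y$ with at most one diacritic after each base character and $k(Y')$ the number of diacritics inserted to obtain $Y'$. GTN adds arc weights along a path as log-scores, so putting weight $-\lambda$ (with $\lambda \ge 0$) on every $\epsilon{:}d$ insertion arc multiplies a path's probability by $e^{-\lambda}$ for each diacritic it inserts:

$$\mathcal{L}_{opt}^{(\lambda)}(X,Y) = -\log \sum_{Y' \in \mathcal{D}(Y)} e^{-\lambda\, k(Y')} \sum_{\pi \in \mathcal{A}_{CTC}(Y')} p(\pi|X)$$

At $\lambda = 0$ this is the zero-weight objective. The loss is non-decreasing in $\lambda$ and tends to $\mathcal{L}_{CTC}(X,Y)$ as $\lambda \to \infty$, because only the $Y' = Y$ term ($k = 0$) keeps its full weight. Sweeping $\lambda$ therefore interpolates between ablations 1 and 2.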
Track:
- CER/WER
- diacritic-specific error rates
- loss curves and calibration
- runtime overhead (composition/scoring)


## 10) Supervisor-ready summary template

Use this structure in your report/slides:

1. Problem: Arabic diacritics are optional/inconsistent in labels.
2. Method: Constrained CTC using optional-diacritic WFST in GTN.
3. Theory: CTC target graph composed with linguistic graph.
4. Implementation: Wav2Vec2 emissions + GTN graph scoring.
5. Results: paired comparison against same checkpoint/data.
6. Analysis: when and why constraints help; failure cases.
7. Claim scope: method is technically valid; novelty requires literature-grounded positioning.
