In [None]:
# Cell 1: Install dependencies
!pip install transformers==4.45.0 accelerate -q

## AraBERT (Frozen) + BiLSTM ##

In [None]:
# Cell 2: Imports & configuration

import os
import unicodedata
import string
from typing import List, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)

# Data locations, relative to the notebook's working directory.
TRAIN_PATH = "dataset/train.txt"
VAL_PATH = "dataset/val.txt"

# Seed NumPy and PyTorch (CPU, plus every CUDA device when present).
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

In [None]:
# Cell 3: Arabic diacritics & text utilities

ARABIC_DIACRITICS = {
    "\u064b",  # Fathatan
    "\u064c",  # Dammatan
    "\u064d",  # Kasratan
    "\u064e",  # Fatha
    "\u064f",  # Damma
    "\u0650",  # Kasra
    "\u0651",  # Shadda
    "\u0652",  # Sukun
    "\u0670",  # Superscript Alef
}


def is_diacritic(ch: str) -> bool:
    """Return True if ``ch`` is one of the marks listed in ARABIC_DIACRITICS."""
    return ch in ARABIC_DIACRITICS


def is_arabic_letter(ch: str) -> bool:
    """Return True for a letter (Unicode category L*) of the Arabic block
    (U+0600-U+06FF) or Arabic Supplement block (U+0750-U+077F).

    Diacritics from ARABIC_DIACRITICS are never letters, and neither are
    Arabic digits or punctuation, because their category is not L*.
    """
    in_arabic_blocks = ("\u0600" <= ch <= "\u06FF") or ("\u0750" <= ch <= "\u077F")
    if not in_arabic_blocks or is_diacritic(ch):
        return False
    return unicodedata.category(ch)[0] == "L"


def strip_diacritics(text: str) -> str:
    """Return ``text`` with every ARABIC_DIACRITICS mark removed."""
    return "".join(filter(lambda c: not is_diacritic(c), text))

# Letter classes consumed as binary features by the feature extractor.
# Sun letters assimilate the lam of the definite article; moon letters do not.
SUN_LETTERS = {c for c in "تثدذرزسشصضطظنل"}
MOON_LETTERS = {c for c in "ءأإابجحخعغفقكمهوي"}

# Single letters that commonly begin (clitic prefixes) or end (pronoun suffixes) a word.
ARABIC_PREFIXES = {c for c in "وفبكلس"}  # wa, fa, bi, ka, li, sa
ARABIC_SUFFIXES = {c for c in "هاكني"}   # letters of pronoun suffixes (-ha, -ka, -ni, -ya)

# Individual letters / letter groups with special orthographic behaviour.
ALEF_VARIANTS = {c for c in "اأإآى"}
WAW_YA = {c for c in "وي"}
TA_MARBUTA = "ة"
ALEF_MAQSURA = "ى"
HAMZA_VARIANTS = {c for c in "ءأإؤئ"}

print("Linguistic features defined")

In [None]:
# Cell 4: Parse line with diacritic normalization

def normalize_shadda_order(diacritics: List[str]) -> str:
    """Canonicalise the marks on one base character into a single label string.

    Any shadda (U+0651) is collapsed to one occurrence and placed first; the
    remaining marks keep their original relative order. Empty (or falsy)
    input yields "".
    """
    if not diacritics:
        return ""
    shadda = "\u0651"
    leading = shadda if shadda in diacritics else ""
    trailing = "".join(mark for mark in diacritics if mark != shadda)
    return leading + trailing

def process_line_to_bases_and_labels(line: str) -> Tuple[List[str], List[str]]:
    """Split a diacritized line into base characters and per-character labels.

    Every non-diacritic character (letters, spaces, punctuation, ...) becomes a
    base; the diacritics that follow it are attached to it and canonicalised
    with normalize_shadda_order. Diacritics appearing before the first base
    character are discarded. A single trailing newline is ignored.

    Returns:
        (base_chars, label_combos): equal-length lists; a combo is "" when the
        base carries no diacritic.
    """
    groups: List[Tuple[str, List[str]]] = []
    for ch in line.rstrip("\n"):
        if not is_diacritic(ch):
            groups.append((ch, []))
        elif groups:
            groups[-1][1].append(ch)
        # else: orphan diacritic before any base character -> dropped

    base_chars = [base for base, _ in groups]
    label_combos = [normalize_shadda_order(marks) for _, marks in groups]
    return base_chars, label_combos

In [None]:
# Cell 5: Build label vocabulary

def build_label_vocab(path: str) -> Dict[str, int]:
    """Collect every diacritic combo seen in ``path`` and assign integer ids.

    The empty combo (no diacritic) is named "NONE" and, when present, sorts
    first so it gets id 0, which is the value collate_fn_enhanced pads
    multi_labels with. The remaining combos follow in string order. The
    label table is printed as a side effect.
    """
    seen_combos = set()
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                _, combos = process_line_to_bases_and_labels(stripped)
                seen_combos.update(combos)

    label_names = {combo if combo else "NONE" for combo in seen_combos}
    ordered = sorted(label_names, key=lambda name: (name != "NONE", name))

    print("Label set:")
    for idx, name in enumerate(ordered):
        print(f"  {idx}: {repr(name)}")
    return {name: idx for idx, name in enumerate(ordered)}

label2id = build_label_vocab(TRAIN_PATH)
id2label = {idx: name for name, idx in label2id.items()}
NUM_LABELS = len(label2id)
print(f"\nNUM_LABELS = {NUM_LABELS}")

In [None]:
# Cell 6: Character vocabulary

SPECIAL_TOKENS = ["<PAD>", "<UNK>"]

def build_char_vocab(path: str) -> Dict[str, int]:
    """Map every base character seen in ``path`` to an integer id.

    SPECIAL_TOKENS take the first ids (<PAD> = 0, <UNK> = 1); the remaining
    characters follow in sorted code-point order. The vocab size is printed.
    """
    seen_chars = set()
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                bases, _ = process_line_to_bases_and_labels(stripped)
                seen_chars.update(bases)

    vocab = dict(zip(SPECIAL_TOKENS, range(len(SPECIAL_TOKENS))))
    for ch in sorted(seen_chars):
        vocab.setdefault(ch, len(vocab))

    print(f"Vocab size: {len(vocab)}")
    return vocab

char2id = build_char_vocab(TRAIN_PATH)
id2char = {idx: ch for ch, idx in char2id.items()}
VOCAB_SIZE = len(char2id)

In [None]:
# Cell 7: Line structure with word boundaries

def line_to_struct(line: str):
    """Decompose a diacritized line into model-ready pieces.

    Returns:
        base_chars: base characters (see process_line_to_bases_and_labels)
        combos:     diacritic combo string per base character
        plain_text: "".join(base_chars)
        words:      plain_text.split()
        char2word:  per character of plain_text, the index of the word it
                    belongs to, or -1 for whitespace
    """
    base_chars, combos = process_line_to_bases_and_labels(line)
    plain_text = "".join(base_chars)
    words = plain_text.split()

    char2word = []
    word_no = -1
    prev_was_space = True  # so the first non-space character opens word 0
    for ch in plain_text:
        if ch.isspace():
            char2word.append(-1)
            prev_was_space = True
            continue
        if prev_was_space:
            word_no += 1
            prev_was_space = False
        char2word.append(word_no)

    return base_chars, combos, plain_text, words, char2word

In [None]:
# Cell 8: Enhanced linguistic feature extraction - 24 features per character

def extract_enhanced_features(plain_text: str, char2word: List[int], words: List[str]) -> List[List[float]]:
    """
    Extract 24 linguistic features per character of ``plain_text``.

    Args:
        plain_text: undiacritized line, one entry per base character.
        char2word: per character, the index of its word, or -1 for whitespace
            (as built by line_to_struct).
        words: plain_text.split(); used for word lengths.

    Returns:
        One list of 24 floats per character, in this order:
          0-3   is Arabic letter, is space, is digit, is punctuation
          4-9   sun letter, moon letter, alef variant, hamza variant,
                waw/ya, ta marbuta
          10-13 word start, word end, single-char word,
                relative position in word
          14-18 prefix letter at word start, suffix letter at word end,
                part of a word-initial "ال", third letter of a word that
                starts with "ال", ta marbuta at word end
          19-23 prev is Arabic letter, prev is alef variant,
                next is Arabic letter, next is space / end of text,
                next is ta marbuta
    """
    n = len(plain_text)

    # Word spans are read from char2word. The previous "sum(len(word) + 1)"
    # arithmetic assumed exactly one space between words and none at the
    # start, so leading, repeated or non-space whitespace shifted every
    # positional / article feature after it. This is also O(n) instead of
    # recomputing a prefix sum per character.
    word_start_pos: Dict[int, int] = {}
    word_end_pos: Dict[int, int] = {}
    for i, w in enumerate(char2word[:n]):
        if w >= 0:
            word_start_pos.setdefault(w, i)
            word_end_pos[w] = i
    word_starts = set(word_start_pos.values())
    word_ends = set(word_end_pos.values())

    features: List[List[float]] = []
    for i, ch in enumerate(plain_text):
        w_idx = char2word[i] if i < len(char2word) else -1
        start = word_start_pos.get(w_idx)  # None for whitespace (w_idx == -1)
        is_word_start = i in word_starts
        is_word_end = i in word_ends
        prev_ch = plain_text[i - 1] if i > 0 else ' '
        next_ch = plain_text[i + 1] if i + 1 < n else ' '

        # Relative position in word: 0.0 at the first letter, 1.0 at the last,
        # 0.5 for single-letter words, 0.0 for whitespace.
        if start is not None and w_idx < len(words):
            word_len = len(words[w_idx])
            rel_pos = (i - start) / max(word_len - 1, 1) if word_len > 1 else 0.5
        else:
            rel_pos = 0.0

        # Either letter of a word-initial definite article "ال".
        is_alef_lam = (
            (is_word_start and ch == 'ا' and i + 1 < n and plain_text[i + 1] == 'ل')
            or (i > 0 and plain_text[i - 1] == 'ا' and ch == 'ل' and (i - 1) in word_starts)
        )
        # Letter right after a word-initial "ال" (sun letter assimilation context).
        after_al = start is not None and i - start == 2 and plain_text[start:start + 2] == "ال"

        f = [
            # === Basic type features (4) ===
            float(is_arabic_letter(ch)),
            float(ch.isspace()),
            float(ch.isdigit()),
            float(unicodedata.category(ch).startswith("P")),
            # === Arabic letter categories (6) ===
            float(ch in SUN_LETTERS),
            float(ch in MOON_LETTERS),
            float(ch in ALEF_VARIANTS),
            float(ch in HAMZA_VARIANTS),
            float(ch in WAW_YA),
            float(ch == TA_MARBUTA),
            # === Position in word (4) ===
            float(is_word_start),
            float(is_word_end),
            float(is_word_start and is_word_end),  # Single char word
            rel_pos,
            # === Morphological hints (5) ===
            float(is_word_start and ch in ARABIC_PREFIXES),
            float(is_word_end and ch in ARABIC_SUFFIXES),
            float(is_alef_lam),
            float(after_al),
            float(ch == TA_MARBUTA and is_word_end),  # Ta Marbuta at word end
            # === Context features (5) ===
            float(is_arabic_letter(prev_ch)),
            float(prev_ch in ALEF_VARIANTS),
            float(is_arabic_letter(next_ch)),
            float(next_ch.isspace() or i + 1 >= n),  # Before space/end
            float(next_ch == TA_MARBUTA),
        ]
        features.append(f)

    return features

# Width of each per-character feature vector returned by extract_enhanced_features
# (4 type + 6 letter-category + 4 position + 5 morphology + 5 context).
NUM_ENHANCED_FEATURES = 24

# Smoke test: run the line -> struct -> features pipeline on one word and
# check that the produced feature width matches NUM_ENHANCED_FEATURES.
test_text = "الكتاب"
test_base, _, test_plain, test_words, test_c2w = line_to_struct(test_text)
test_feats = extract_enhanced_features(test_plain, test_c2w, test_words)
print(f"Features per char: {len(test_feats[0])}")
print(f"Total chars: {len(test_feats)}")
assert len(test_feats[0]) == NUM_ENHANCED_FEATURES, f"Expected {NUM_ENHANCED_FEATURES}, got {len(test_feats[0])}"
print("Feature extraction test passed!")

In [None]:
# Cell 9: Dataset with enhanced features

class EnhancedDiacritizationDataset(Dataset):
    """In-memory dataset of diacritized lines, pre-encoded for the diacritizer.

    Every non-empty (stripped) line of ``path`` is decomposed with
    line_to_struct and stored as a dict with:

    * ``char_ids``       - ids from ``char2id`` (unknown characters -> ``<UNK>``)
    * ``multi_labels``   - ids from ``label2id`` (combos missing from the
                           vocab -> ``NONE``)
    * ``binary_labels``  - 0 if the (possibly remapped) label is NONE, else 1
    * ``mask``           - 1 for Arabic letters, 0 for everything else
    * ``enhanced_feats`` - output of extract_enhanced_features
    * ``plain_text`` / ``words`` / ``char2word`` - kept for BERT alignment

    Attributes:
        num_unknown_labels: number of characters whose gold combo was absent
            from ``label2id`` and was therefore relabelled NONE. A non-zero
            value (possible for a split scored with a vocab built on another
            split) means those gold labels no longer match the source text.
    """

    def __init__(self, path: str, char2id: Dict[str, int], label2id: Dict[str, int]):
        self.samples = []
        self.char2id = char2id
        self.label2id = label2id
        self.num_unknown_labels = 0

        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                self.samples.append(self._encode_line(line))

        if self.num_unknown_labels:
            # These used to be remapped silently, hiding their effect on metrics.
            print(f"Warning: {self.num_unknown_labels:,} characters in {path} have "
                  f"diacritic combos missing from the label vocab; mapped to NONE")

    def _encode_line(self, line: str) -> dict:
        """Convert one stripped, non-empty line into a sample dict."""
        base_chars, combos, plain_text, words, char2word = line_to_struct(line)

        # Labels
        multi_labels = []
        binary_labels = []
        mask = []
        for ch, combo in zip(base_chars, combos):
            lab_name = "NONE" if combo == "" else combo
            if lab_name not in self.label2id:
                self.num_unknown_labels += 1
                lab_name = "NONE"
            multi_labels.append(self.label2id[lab_name])
            binary_labels.append(0 if lab_name == "NONE" else 1)
            # Only Arabic letters are trained/scored; spaces, digits, punctuation are masked out.
            mask.append(1 if is_arabic_letter(ch) else 0)

        char_ids = [self.char2id.get(ch, self.char2id["<UNK>"]) for ch in base_chars]

        # Enhanced features
        enhanced_feats = extract_enhanced_features(plain_text, char2word, words)

        return {
            "char_ids": char_ids,
            "multi_labels": multi_labels,
            "binary_labels": binary_labels,
            "mask": mask,
            "plain_text": plain_text,
            "words": words,
            "char2word": char2word,
            "enhanced_feats": enhanced_feats,
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

print("Enhanced dataset class defined")

In [None]:
# Cell 10: Collate function with enhanced features

def collate_fn_enhanced(batch):
    """Right-pad a list of dataset samples to the longest ``char_ids`` and stack.

    Padding values: ``<PAD>`` for char ids, 0 for labels and mask, -1 for
    char2word, and an all-zero feature vector for enhanced features. Every
    field of a sample is padded by the same amount, computed from its
    ``char_ids`` length. ``plain_text`` and ``words`` are passed through as
    Python lists.
    """
    longest = max(len(sample["char_ids"]) for sample in batch)
    pad_id = char2id["<PAD>"]
    blank_feat = [0.0] * NUM_ENHANCED_FEATURES

    char_rows, multi_rows, binary_rows, mask_rows, c2w_rows, feat_rows = [], [], [], [], [], []
    for sample in batch:
        extra = longest - len(sample["char_ids"])
        char_rows.append(sample["char_ids"] + [pad_id] * extra)
        multi_rows.append(sample["multi_labels"] + [0] * extra)
        binary_rows.append(sample["binary_labels"] + [0] * extra)
        mask_rows.append(sample["mask"] + [0] * extra)
        c2w_rows.append(sample["char2word"] + [-1] * extra)
        feat_rows.append(sample["enhanced_feats"] + [blank_feat] * extra)

    return {
        "char_ids": torch.tensor(char_rows, dtype=torch.long),
        "multi_labels": torch.tensor(multi_rows, dtype=torch.long),
        "binary_labels": torch.tensor(binary_rows, dtype=torch.float32),
        "mask": torch.tensor(mask_rows, dtype=torch.float32),
        "plain_text": [sample["plain_text"] for sample in batch],
        "words": [sample["words"] for sample in batch],
        "char2word": torch.tensor(c2w_rows, dtype=torch.long),
        "enhanced_feats": torch.tensor(feat_rows, dtype=torch.float32),
    }

In [None]:
# Cell 11: Create datasets and loaders

print("Loading datasets...")
# Both splits are encoded with the char/label vocabularies built from TRAIN_PATH.
train_dataset = EnhancedDiacritizationDataset(TRAIN_PATH, char2id, label2id)
val_dataset = EnhancedDiacritizationDataset(VAL_PATH, char2id, label2id)

BATCH_SIZE = 8

# Only the training loader shuffles; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_enhanced)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_enhanced)

print(f"Train: {len(train_dataset):,} samples, {len(train_loader):,} batches")
print(f"Val: {len(val_dataset):,} samples, {len(val_loader):,} batches")

# Sanity check: collate one training batch and show the tensor shapes.
# NOTE(review): iterating the shuffled loader here presumably draws from the
# global torch RNG, shifting the shuffle order of the first training epoch.
sample_batch = next(iter(train_loader))
print(f"\nBatch shapes:")
print(f"  char_ids: {sample_batch['char_ids'].shape}")
print(f"  enhanced_feats: {sample_batch['enhanced_feats'].shape}")

In [None]:
# Cell 12: Load AraBERT

# Pretrained Arabic BERT checkpoint (loaded via from_pretrained; presumably
# fetched from the Hugging Face Hub or the local cache).
BERT_MODEL_NAME = "aubmindlab/bert-base-arabertv02"

print(f"Loading {BERT_MODEL_NAME}...")
# `tokenizer` is read as a module-level global inside EnhancedDiacritizer.forward.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
print(f"BERT hidden size: {bert_model.config.hidden_size}")

In [None]:
# Cell 13: Enhanced model with refined feature integration

class EnhancedDiacritizer(nn.Module):
    """Character-level Arabic diacritizer.

    For every character three vectors are concatenated and fed to a BiLSTM:
      * a learned character embedding (``emb_dim``),
      * an MLP projection of the hand-crafted features (``feat_hidden_dim``),
      * the BERT vector of the word the character belongs to (mean of its
        sub-word vectors; zeros for whitespace / padding).
    Two heads read the normalised BiLSTM output: a binary "has a diacritic"
    logit and a ``num_labels``-way diacritic-combo classifier.

    Uses the module-level ``tokenizer``, ``DEVICE`` and ``char2id``.

    Args:
        bert_model: encoder (e.g. the AutoModel loaded above) whose output
            exposes ``last_hidden_state``.
        vocab_size: size of the character vocabulary.
        num_labels: number of diacritic-combo classes.
        num_enhanced_feats: width of ``batch["enhanced_feats"]``.
        emb_dim, feat_hidden_dim, lstm_hidden_dim, lstm_layers, dropout:
            layer sizes / regularisation.
        freeze_bert: if True, BERT weights get ``requires_grad=False``, BERT
            runs without autograd, and BERT stays in eval mode (dropout off)
            even while the rest of the model trains. If False, gradients flow
            into BERT so it is fine-tuned.
    """

    def __init__(self,
                 bert_model,
                 vocab_size: int,
                 num_labels: int,
                 num_enhanced_feats: int = 24,
                 emb_dim: int = 64,
                 feat_hidden_dim: int = 48,  # Larger to handle 24 features
                 lstm_hidden_dim: int = 256,
                 lstm_layers: int = 2,  # Deeper LSTM
                 dropout: float = 0.3,
                 freeze_bert: bool = True):
        super().__init__()

        self.bert = bert_model
        self.freeze_bert = freeze_bert
        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
            # A frozen feature extractor should be deterministic: no dropout.
            self.bert.eval()

        self.bert_hidden_size = self.bert.config.hidden_size

        # Character embedding (<PAD> is the padding index)
        self.char_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=char2id["<PAD>"])

        # Enhanced feature projection with more capacity
        self.feat_proj = nn.Sequential(
            nn.Linear(num_enhanced_feats, feat_hidden_dim),
            nn.LayerNorm(feat_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(feat_hidden_dim, feat_hidden_dim),
            nn.ReLU()
        )

        # Input dimension
        input_dim = emb_dim + feat_hidden_dim + self.bert_hidden_size

        # Deeper BiLSTM
        self.lstm = nn.LSTM(
            input_dim,
            lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0
        )

        self.lstm_norm = nn.LayerNorm(lstm_hidden_dim * 2)
        self.dropout = nn.Dropout(dropout)

        # Output heads
        self.binary_head = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, lstm_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(lstm_hidden_dim, 1)
        )

        self.multi_head = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, lstm_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(lstm_hidden_dim, num_labels)
        )

    def train(self, mode: bool = True):
        """Set train/eval mode; a frozen BERT is always left in eval mode.

        nn.Module.train() recurses into every child, so without this override
        ``model.train()`` re-enabled BERT's dropout. Training would then see
        noisy "frozen" features while evaluation saw clean ones.
        """
        super().train(mode)
        if getattr(self, "freeze_bert", False):
            self.bert.eval()
        return self

    def _word_embeddings(self, words_list: List[List[str]]) -> List[torch.Tensor]:
        """Encode each sample's words with BERT and mean-pool sub-words per word.

        Returns one (num_words, H) tensor per sample, or a (1, H) zero tensor
        for a sample with no words. Words whose sub-words were all truncated
        away get a zero vector.
        """
        encoding = tokenizer(
            words_list,
            is_split_into_words=True,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(DEVICE)

        # Block autograd through BERT only when it is frozen. The previous
        # unconditional no_grad() meant freeze_bert=False never fine-tuned BERT.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_bert):
            token_embeddings = self.bert(**encoding).last_hidden_state
        H = token_embeddings.size(-1)

        word_embs_list = []
        for i, words in enumerate(words_list):
            num_words = len(words)
            if num_words == 0:
                word_embs_list.append(torch.zeros((1, H), device=DEVICE))
                continue

            pairs = [(t, w) for t, w in enumerate(encoding.word_ids(batch_index=i))
                     if w is not None and w < num_words]
            tok_idx = torch.tensor([t for t, _ in pairs], dtype=torch.long, device=DEVICE)
            word_idx = torch.tensor([w for _, w in pairs], dtype=torch.long, device=DEVICE)

            sums = torch.zeros((num_words, H), device=DEVICE)
            sums = sums.index_add(0, word_idx, token_embeddings[i].index_select(0, tok_idx).to(sums.dtype))
            counts = torch.zeros((num_words, 1), device=DEVICE)
            counts = counts.index_add(0, word_idx, torch.ones((len(pairs), 1), device=DEVICE))
            word_embs_list.append(sums / torch.clamp(counts, min=1.0))
        return word_embs_list

    def _chars_from_words(self, word_embs_list: List[torch.Tensor], char2word: torch.Tensor) -> torch.Tensor:
        """Give every character its word's vector -> (B, T, H).

        Characters with char2word == -1 (whitespace / padding) or an
        out-of-range word index get zeros. Vectorised with gather instead of
        a per-character Python loop with a host sync for each .item().
        """
        B, T = char2word.shape
        padded = nn.utils.rnn.pad_sequence(word_embs_list, batch_first=True)  # (B, W_max, H)
        n_words = torch.tensor([w.size(0) for w in word_embs_list], device=char2word.device)
        valid = (char2word >= 0) & (char2word < n_words.unsqueeze(1))  # (B, T)
        idx = char2word.clamp(min=0, max=padded.size(1) - 1)
        gathered = padded.gather(1, idx.unsqueeze(-1).expand(B, T, padded.size(-1)))
        return gathered.masked_fill(~valid.unsqueeze(-1), 0.0)

    def forward(self, batch):
        """Run the model on a collated batch.

        Reads ``char_ids`` (B, T), ``enhanced_feats`` (B, T, F), ``words`` and
        ``char2word`` (B, T) from ``batch``.

        Returns:
            (binary_logits (B, T), multi_logits (B, T, num_labels)).
        """
        char_ids = batch["char_ids"].to(DEVICE)
        enhanced_feats = batch["enhanced_feats"].to(DEVICE)
        words_list = batch["words"]
        char2word = batch["char2word"].to(DEVICE)

        # 1) BERT word embeddings, broadcast onto their characters
        word_embs_list = self._word_embeddings(words_list)
        bert_char_context = self._chars_from_words(word_embs_list, char2word)

        # 2) Character embeddings
        char_embs = self.char_emb(char_ids)

        # 3) Enhanced feature projection
        feat_proj = self.feat_proj(enhanced_feats)

        # 4) Concatenate
        x = torch.cat([char_embs, feat_proj, bert_char_context], dim=-1)

        # 5) BiLSTM
        lstm_out, _ = self.lstm(x)
        lstm_out = self.lstm_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)

        # 6) Heads
        binary_logits = self.binary_head(lstm_out).squeeze(-1)
        multi_logits = self.multi_head(lstm_out)

        return binary_logits, multi_logits

# Build the model (BERT frozen), move it to DEVICE, then count trainable parameters.
model = EnhancedDiacritizer(
    bert_model=bert_model,
    vocab_size=VOCAB_SIZE,
    num_labels=NUM_LABELS,
    num_enhanced_feats=NUM_ENHANCED_FEATURES,
    emb_dim=64,
    feat_hidden_dim=48,
    lstm_hidden_dim=256,
    lstm_layers=2,
    dropout=0.3,
    freeze_bert=True
).to(DEVICE)

# With freeze_bert=True the BERT weights have requires_grad=False, so only the
# non-BERT layers are counted here.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}")

In [None]:
# Cell 14: Loss function

def compute_loss(binary_logits, multi_logits, binary_labels, multi_labels, mask,
                 lambda_binary=1.0, lambda_multi=1.0):
    """Masked, weighted sum of the binary-head BCE and the multi-head CE.

    Both per-character losses are averaged over positions where ``mask`` is
    set (a 1e-8 epsilon guards an all-zero mask).

    Args:
        binary_logits: (B, T) logits; binary_labels: (B, T) float targets.
        multi_logits: (B, T, C) logits; multi_labels: (B, T) class ids.
        mask: (B, T) weights, 1 for scored characters.
        lambda_binary, lambda_multi: weights of the two terms.

    Returns:
        (total loss tensor, bce as float, ce as float).
    """
    batch_size, seq_len = binary_logits.shape
    num_classes = multi_logits.shape[-1]
    F = nn.functional

    weights = mask.view(-1)
    denom = weights.sum() + 1e-8

    bce_per_char = F.binary_cross_entropy_with_logits(
        binary_logits.view(-1), binary_labels.view(-1), reduction="none"
    )
    ce_per_char = F.cross_entropy(
        multi_logits.view(batch_size * seq_len, num_classes), multi_labels.view(-1), reduction="none"
    )
    bce = (bce_per_char * weights).sum() / denom
    ce = (ce_per_char * weights).sum() / denom

    total = lambda_binary * bce + lambda_multi * ce
    return total, bce.item(), ce.item()

In [None]:
# Cell 15: Training configuration

# Optimisation hyper-parameters, defined once so the optimizer and the summary
# printout below cannot drift apart (the summary used to hard-code the values).
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01

# Only parameters that require gradients (i.e. not the frozen BERT) are optimised.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

EPOCHS = 10
LAMBDA_BINARY = 1.0
LAMBDA_MULTI = 1.0
GRAD_CLIP = 5.0

print("Training config:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE:g}")
print(f"  Weight decay: {WEIGHT_DECAY:g}")

In [None]:
# Cell 16: Training function

def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader``.

    Uses the module-level ``optimizer``, ``LAMBDA_BINARY``/``LAMBDA_MULTI``
    loss weights and ``GRAD_CLIP`` (gradient-norm clipping) for each batch.

    Returns:
        Per-batch means of (total loss, bce, ce).
    """
    model.train()
    running = [0.0, 0.0, 0.0]  # total loss, bce, ce
    n_batches = 0

    for n_batches, batch in enumerate(loader, start=1):
        target_binary = batch["binary_labels"].to(DEVICE)
        target_multi = batch["multi_labels"].to(DEVICE)
        char_mask = batch["mask"].to(DEVICE)

        optimizer.zero_grad()
        logits_binary, logits_multi = model(batch)

        loss, bce, ce = compute_loss(
            logits_binary, logits_multi,
            target_binary, target_multi, char_mask,
            LAMBDA_BINARY, LAMBDA_MULTI
        )

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        running[0] += loss.item()
        running[1] += bce
        running[2] += ce

    return tuple(total / n_batches for total in running)

In [None]:
# Cell 17: Prediction and evaluation

def predict_batch(model, batch, binary_threshold=0.5):
    """Predict diacritic label ids for one collated batch.

    The binary head gates the multi-class head: wherever
    sigmoid(binary_logit) < binary_threshold the label is forced to NONE;
    elsewhere the argmax combo is kept.

    Returns:
        (pred_binary, pred_multi_ids) as CPU long tensors shaped like the
        binary logits.
    """
    model.eval()
    with torch.no_grad():
        binary_logits, multi_logits = model(batch)
        has_diacritic = torch.sigmoid(binary_logits) >= binary_threshold
        best_combo = multi_logits.argmax(dim=-1)
        gated = best_combo.masked_fill(~has_diacritic, label2id["NONE"])
        return has_diacritic.long().cpu(), gated.cpu()

def evaluate(model, loader, binary_threshold=0.5):
    """Character accuracy and DER (both in %) over ``loader``.

    Only positions with a non-zero mask that lie within the sample's
    ``plain_text`` are scored.
      * accuracy = correctly labelled / all scored characters
      * DER      = wrongly labelled / scored characters whose *gold* label
                   is not NONE; predictions on gold-NONE characters do not
                   affect DER.
    Either metric is 0.0 when its denominator is zero.
    """
    model.eval()
    n_scored = n_correct = n_diac = n_wrong_diac = 0

    with torch.no_grad():
        for batch in loader:
            gold = batch["multi_labels"].numpy()
            char_mask = batch["mask"].numpy()

            _, pred = predict_batch(model, batch, binary_threshold)
            pred = pred.numpy()

            lengths = np.array([len(text) for text in batch["plain_text"]])
            within_text = np.arange(gold.shape[1])[None, :] < lengths[:, None]
            scored = within_text & (char_mask != 0)
            matches = gold == pred

            none_ids = [idx for idx, name in id2label.items() if name == "NONE"]
            gold_has_diac = ~np.isin(gold, none_ids)

            n_scored += int(scored.sum())
            n_correct += int((scored & matches).sum())
            n_diac += int((scored & gold_has_diac).sum())
            n_wrong_diac += int((scored & gold_has_diac & ~matches).sum())

    acc = 100.0 * n_correct / n_scored if n_scored > 0 else 0.0
    der = 100.0 * n_wrong_diac / n_diac if n_diac > 0 else 0.0
    return acc, der

In [None]:
# Cell 18: Main training loop

# Train for EPOCHS epochs. After each epoch, evaluate on the val split at
# threshold 0.5 and keep the checkpoint with the lowest DER, both in memory
# (best_state) and on disk (enhanced_model_best.pt).
best_der = float("inf")
best_state = None

print("="*60)
print("Starting training with enhanced features...")
print("="*60)

for epoch in range(1, EPOCHS + 1):
    train_loss, train_bce, train_ce = train_one_epoch(model, train_loader)
    val_acc, val_der = evaluate(model, val_loader, binary_threshold=0.5)
    
    print(f"\nEpoch {epoch}/{EPOCHS}:")
    print(f"  Train: loss={train_loss:.4f}, bce={train_bce:.4f}, ce={train_ce:.4f}")
    print(f"  Val: acc={val_acc:.2f}%, DER={val_der:.2f}%")

    if val_der < best_der:
        best_der = val_der
        # .cpu().clone() so the snapshot does not alias the live parameters
        # that later optimizer steps will modify.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        torch.save(best_state, "enhanced_model_best.pt")
        print(f"  ✓ New best DER: {best_der:.2f}%")

print("\n" + "="*60)
print(f"Training complete! Best DER: {best_der:.2f}%")
print("="*60)

In [None]:
# Cell 19: Threshold optimization

# Reload the lowest-DER checkpoint written by the training loop.
model.load_state_dict(torch.load("enhanced_model_best.pt", map_location=DEVICE))

print("Threshold sweep:")
print("-"*40)

# Sweep the binary-head threshold; each value re-runs full validation.
# NOTE(review): the threshold is tuned on the same val split that was used for
# checkpoint selection, so best_sweep_der is an optimistic estimate.
# NOTE(review): `evaluate` counts only letters whose gold label is not NONE
# toward DER, so spurious diacritics on gold-NONE letters are not penalised.
# This may bias the sweep toward lower thresholds; check `acc` as well.
best_thr = 0.5
best_sweep_der = float("inf")

for thr in [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]:
    acc, der = evaluate(model, val_loader, binary_threshold=thr)
    marker = " <-- best" if der < best_sweep_der else ""
    print(f"  thr={thr:.2f}: acc={acc:.2f}%, DER={der:.2f}%{marker}")
    
    if der < best_sweep_der:
        best_sweep_der = der
        best_thr = thr

print(f"\nOptimal threshold: {best_thr}")

In [None]:
# Cell 20: Demo diacritization

def diacritize_text(model, text: str, threshold: float = 0.5) -> str:
    """Diacritize a line of Arabic text with ``model``.

    Any diacritics already present in ``text`` are discarded (the line is
    reduced to its base characters first) and replaced by the predictions.
    Diacritics are only attached to Arabic letters. Spaces, digits and
    punctuation are masked out during training, so their predictions are
    meaningless.

    Args:
        model: a trained EnhancedDiacritizer.
        text: input line.
        threshold: binary-head threshold passed to predict_batch.

    Returns:
        The diacritized line, or "" when ``text`` has no base characters.
    """
    model.eval()

    base_chars, _, plain, words, char2word = line_to_struct(text)
    if not base_chars:
        # Nothing to diacritize; also avoids running the model on a zero-length batch.
        return ""

    char_ids = [char2id.get(ch, char2id["<UNK>"]) for ch in base_chars]
    enhanced_feats = extract_enhanced_features(plain, char2word, words)

    batch = {
        "char_ids": torch.tensor([char_ids], dtype=torch.long),
        "enhanced_feats": torch.tensor([enhanced_feats], dtype=torch.float32),
        "mask": torch.ones((1, len(char_ids)), dtype=torch.float32),
        "plain_text": [plain],
        "words": [words],
        "char2word": torch.tensor([char2word], dtype=torch.long),
    }

    _, pred_multi = predict_batch(model, batch, threshold)
    pred_ids = pred_multi[0].tolist()

    out = []
    # Align predictions with base characters, not the raw text: if `text`
    # already contained diacritics, zipping with it shifted every prediction.
    for ch, lab_id in zip(base_chars, pred_ids):
        out.append(ch)
        lab = id2label[lab_id]
        if lab != "NONE" and is_arabic_letter(ch):
            out.append(lab)
    return "".join(out)

# Test examples
# Undiacritized sample sentences for a qualitative check of the trained model.
test_sentences = [
    "ولو جمع ثم علم ترك ركن من الاولى بطلت",
    "السلام عليكم ورحمة الله وبركاته",
    "الحمد لله رب العالمين",
]

print("="*60)
print("DIACRITIZATION DEMO")
print("="*60)

# Uses the threshold chosen by the validation sweep above.
for sent in test_sentences:
    result = diacritize_text(model, sent, threshold=best_thr)
    print(f"\nInput:  {sent}")
    print(f"Output: {result}")