In [None]:
!pip install protobuf==3.20.*


Collecting protobuf==3.20.*
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 6.33.0
    Uninstalling protobuf-6.33.0:
      Successfully uninstalled protobuf-6.33.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.20.3 which is incompatible.
onnx 1.18.0 requires protobuf>=4.25.1, but you have protobuf 3.20.3 which is incompatible.
a2a-sdk 0.3.10 requires p

In [2]:
!pip install pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl.metadata (2.4 kB)
Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


## AraBert Model (last 2 layers unfrozen) + BiLSTM(3 Layers) + CRF ##

In [3]:
# Cell 2: Imports & configuration

import os
import unicodedata
import string
from typing import List, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

TRAIN_PATH = "/kaggle/input/dataset1/train.txt"
VAL_PATH = "/kaggle/input/dataset1/val.txt"

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

Using device: cuda


In [4]:
# Cell 3: Arabic diacritics & text utilities

ARABIC_DIACRITICS = set([
    "\u064b",  # Fathatan
    "\u064c",  # Dammatan
    "\u064d",  # Kasratan
    "\u064e",  # Fatha
    "\u064f",  # Damma
    "\u0650",  # Kasra
    "\u0651",  # Shadda
    "\u0652",  # Sukun
    "\u0670",  # Superscript Alef
])

def is_diacritic(ch: str) -> bool:
    return ch in ARABIC_DIACRITICS

def is_arabic_letter(ch: str) -> bool:
    if not ("\u0600" <= ch <= "\u06FF" or "\u0750" <= ch <= "\u077F"):
        return False
    if is_diacritic(ch):
        return False
    cat = unicodedata.category(ch)
    return cat.startswith("L")

def strip_diacritics(text: str) -> str:
    return "".join(ch for ch in text if not is_diacritic(ch))

# Arabic linguistic categories
SUN_LETTERS = set("تثدذرزسشصضطظنل")
MOON_LETTERS = set("ءأإابجحخعغفقكمهوي")

# Common prefixes and suffixes
ARABIC_PREFIXES = set("وفبكلس")  # wa, fa, bi, ka, li, sa
ARABIC_SUFFIXES = set("هاكني")   # ha, ka, ni, ya (pronoun suffixes)

# Special characters
ALEF_VARIANTS = set("اأإآى")
WAW_YA = set("وي")
TA_MARBUTA = "ة"
ALEF_MAQSURA = "ى"
HAMZA_VARIANTS = set("ءأإؤئ")

print("Linguistic features defined")

Linguistic features defined


In [5]:
# Cell 4: Parse line with diacritic normalization

def normalize_shadda_order(diacritics: List[str]) -> str:
    """Normalize: Shadda always comes first in combo."""
    if not diacritics:
        return ""
    shadda = "\u0651"
    has_shadda = shadda in diacritics
    others = [d for d in diacritics if d != shadda]
    if has_shadda:
        return shadda + "".join(others)
    return "".join(others)

def process_line_to_bases_and_labels(line: str) -> Tuple[List[str], List[str]]:
    line = line.rstrip("\n")
    base_chars: List[str] = []
    label_combos: List[str] = []

    current_base = None
    current_diacritics: List[str] = []

    for ch in line:
        if is_diacritic(ch):
            if current_base is not None:
                current_diacritics.append(ch)
        else:
            if current_base is not None:
                combo = normalize_shadda_order(current_diacritics)
                label_combos.append(combo)
                base_chars.append(current_base)
            current_base = ch
            current_diacritics = []

    if current_base is not None:
        combo = normalize_shadda_order(current_diacritics)
        label_combos.append(combo)
        base_chars.append(current_base)

    return base_chars, label_combos

In [6]:
# Cell 5: Build label vocabulary

def build_label_vocab(path: str) -> Dict[str, int]:
    combos = set()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            _, labels = process_line_to_bases_and_labels(line)
            combos.update(labels)

    normalized = set()
    for c in combos:
        normalized.add("NONE" if c == "" else c)

    sorted_labels = sorted(normalized, key=lambda x: (x != "NONE", x))
    label2id = {lab: i for i, lab in enumerate(sorted_labels)}

    print("Label set:")
    for i, lab in enumerate(sorted_labels):
        print(f"  {i}: {repr(lab)}")
    return label2id

label2id = build_label_vocab(TRAIN_PATH)
id2label = {i: lab for lab, i in label2id.items()}
NUM_LABELS = len(label2id)
print(f"\nNUM_LABELS = {NUM_LABELS}")

Label set:
  0: 'NONE'
  1: 'ً'
  2: 'ٌ'
  3: 'ٍ'
  4: 'َ'
  5: 'ُ'
  6: 'ِ'
  7: 'ّ'
  8: 'ًّ'
  9: 'ٌّ'
  10: 'ٍّ'
  11: 'َّ'
  12: 'ُّ'
  13: 'ِّ'
  14: 'ْ'

NUM_LABELS = 15


In [7]:
# Cell 6: Character vocabulary

SPECIAL_TOKENS = ["<PAD>", "<UNK>"]

def build_char_vocab(path: str) -> Dict[str, int]:
    chars = set()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            base_chars, _ = process_line_to_bases_and_labels(line)
            chars.update(base_chars)

    sorted_chars = sorted(chars)
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for ch in sorted_chars:
        if ch not in vocab:
            vocab[ch] = len(vocab)

    print(f"Vocab size: {len(vocab)}")
    return vocab

char2id = build_char_vocab(TRAIN_PATH)
id2char = {i: ch for ch, i in char2id.items()}
VOCAB_SIZE = len(char2id)

Vocab size: 74


In [8]:
# Cell 7: Line structure with word boundaries

def line_to_struct(line: str):
    base_chars, combos = process_line_to_bases_and_labels(line)
    plain_text = "".join(base_chars)
    words = plain_text.split()

    char2word = []
    current_word_idx = -1
    inside_word = False

    for ch in plain_text:
        if ch.isspace():
            char2word.append(-1)
            if inside_word:
                inside_word = False
        else:
            if not inside_word:
                inside_word = True
                current_word_idx += 1
            char2word.append(current_word_idx)

    return base_chars, combos, plain_text, words, char2word

In [9]:
# Cell 8: Enhanced linguistic feature extraction - 24 features per character

def extract_enhanced_features(plain_text: str, char2word: List[int], words: List[str]) -> List[List[float]]:
    """
    Extract 24 linguistic features per character.
    """
    features = []
    n = len(plain_text)
    
    # Precompute word info
    word_starts = set()
    word_ends = set()
    pos = 0
    for word in words:
        word_starts.add(pos)
        word_ends.add(pos + len(word) - 1)
        pos += len(word) + 1  # +1 for space
    
    for i, ch in enumerate(plain_text):
        f = []
        
        # === Basic type features (4) ===
        f.append(1.0 if is_arabic_letter(ch) else 0.0)
        f.append(1.0 if ch.isspace() else 0.0)
        f.append(1.0 if ch.isdigit() else 0.0)
        f.append(1.0 if unicodedata.category(ch).startswith("P") else 0.0)
        
        # === Arabic letter categories (6) ===
        f.append(1.0 if ch in SUN_LETTERS else 0.0)
        f.append(1.0 if ch in MOON_LETTERS else 0.0)
        f.append(1.0 if ch in ALEF_VARIANTS else 0.0)
        f.append(1.0 if ch in HAMZA_VARIANTS else 0.0)
        f.append(1.0 if ch in WAW_YA else 0.0)
        f.append(1.0 if ch == TA_MARBUTA else 0.0)
        
        # === Position in word (4) ===
        is_word_start = i in word_starts
        is_word_end = i in word_ends
        f.append(1.0 if is_word_start else 0.0)
        f.append(1.0 if is_word_end else 0.0)
        f.append(1.0 if is_word_start and is_word_end else 0.0)  # Single char word
        
        # Relative position in word
        w_idx = char2word[i] if i < len(char2word) else -1
        if w_idx >= 0 and w_idx < len(words):
            word_len = len(words[w_idx])
            # Find position within word
            word_start_pos = sum(len(words[j]) + 1 for j in range(w_idx))
            pos_in_word = i - word_start_pos
            f.append(pos_in_word / max(word_len - 1, 1) if word_len > 1 else 0.5)
        else:
            f.append(0.0)
        
        # === Morphological hints (5) ===
        # Prefix character (beginning of word)
        f.append(1.0 if is_word_start and ch in ARABIC_PREFIXES else 0.0)
        # Suffix character (end of word)
        f.append(1.0 if is_word_end and ch in ARABIC_SUFFIXES else 0.0)
        
        # Definite article detection (ال)
        is_alef_lam = False
        if is_word_start and ch == 'ا' and i + 1 < n and plain_text[i + 1] == 'ل':
            is_alef_lam = True
        if i > 0 and plain_text[i - 1] == 'ا' and ch == 'ل' and (i - 1) in word_starts:
            is_alef_lam = True
        f.append(1.0 if is_alef_lam else 0.0)
        
        # After definite article (sun letter assimilation context)
        after_al = False
        if i >= 2 and w_idx >= 0:
            word_start_pos = sum(len(words[j]) + 1 for j in range(w_idx))
            if i - word_start_pos == 2:  # Third char in word
                if plain_text[word_start_pos:word_start_pos+2] == "ال":
                    after_al = True
        f.append(1.0 if after_al else 0.0)
        
        # Ta Marbuta at word end (almost always Fatha)
        f.append(1.0 if ch == TA_MARBUTA and is_word_end else 0.0)
        
        # === Context features (5) ===
        # Previous character type
        prev_ch = plain_text[i - 1] if i > 0 else ' '
        f.append(1.0 if is_arabic_letter(prev_ch) else 0.0)
        f.append(1.0 if prev_ch in ALEF_VARIANTS else 0.0)
        
        # Next character type
        next_ch = plain_text[i + 1] if i + 1 < n else ' '
        f.append(1.0 if is_arabic_letter(next_ch) else 0.0)
        f.append(1.0 if next_ch.isspace() or i + 1 >= n else 0.0)  # Before space/end
        f.append(1.0 if next_ch == TA_MARBUTA else 0.0)
        
        features.append(f)
    
    return features

NUM_ENHANCED_FEATURES = 24

# Test
test_text = "الكتاب"
test_base, _, test_plain, test_words, test_c2w = line_to_struct(test_text)
test_feats = extract_enhanced_features(test_plain, test_c2w, test_words)
print(f"Features per char: {len(test_feats[0])}")
print(f"Total chars: {len(test_feats)}")
assert len(test_feats[0]) == NUM_ENHANCED_FEATURES, f"Expected {NUM_ENHANCED_FEATURES}, got {len(test_feats[0])}"
print("Feature extraction test passed!")

Features per char: 24
Total chars: 6
Feature extraction test passed!


In [10]:
# Cell 9: Dataset with enhanced features

class EnhancedDiacritizationDataset(Dataset):
    def __init__(self, path: str, char2id: Dict[str, int], label2id: Dict[str, int]):
        self.samples = []
        self.char2id = char2id
        self.label2id = label2id

        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                base_chars, combos, plain_text, words, char2word = line_to_struct(line)
                
                # Labels
                multi_labels = []
                binary_labels = []
                mask = []
                for ch, combo in zip(base_chars, combos):
                    lab_name = "NONE" if combo == "" else combo
                    if lab_name not in self.label2id:
                        lab_name = "NONE"
                    lab_id = self.label2id[lab_name]
                    multi_labels.append(lab_id)
                    binary_labels.append(0 if lab_name == "NONE" else 1)
                    mask.append(1 if is_arabic_letter(ch) else 0)

                char_ids = [self.char2id.get(ch, self.char2id["<UNK>"]) for ch in base_chars]
                
                # Enhanced features
                enhanced_feats = extract_enhanced_features(plain_text, char2word, words)

                self.samples.append({
                    "char_ids": char_ids,
                    "multi_labels": multi_labels,
                    "binary_labels": binary_labels,
                    "mask": mask,
                    "plain_text": plain_text,
                    "words": words,
                    "char2word": char2word,
                    "enhanced_feats": enhanced_feats,
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

print("Enhanced dataset class defined")

Enhanced dataset class defined


In [11]:
# Cell 10: Collate function with enhanced features

def collate_fn_enhanced(batch):
    max_len = max(len(sample["char_ids"]) for sample in batch)
    pad_id = char2id["<PAD>"]

    batch_char_ids = []
    batch_multi = []
    batch_binary = []
    batch_mask = []
    batch_plain_text = []
    batch_words = []
    batch_char2word = []
    batch_enhanced_feats = []

    for sample in batch:
        L = len(sample["char_ids"])
        pad_len = max_len - L

        batch_char_ids.append(sample["char_ids"] + [pad_id] * pad_len)
        batch_multi.append(sample["multi_labels"] + [0] * pad_len)
        batch_binary.append(sample["binary_labels"] + [0] * pad_len)
        
        # Mask: 1 for Arabic letters only (for evaluation)
        # Don't force first position - use actual Arabic letter mask
        mask = sample["mask"] + [0] * pad_len
        batch_mask.append(mask)
        
        batch_char2word.append(sample["char2word"] + [-1] * pad_len)

        batch_plain_text.append(sample["plain_text"])
        batch_words.append(sample["words"])
        
        # Pad enhanced features
        feats = sample["enhanced_feats"]
        zero_feat = [0.0] * NUM_ENHANCED_FEATURES
        padded_feats = feats + [zero_feat] * pad_len
        batch_enhanced_feats.append(padded_feats)

    return {
        "char_ids": torch.tensor(batch_char_ids, dtype=torch.long),
        "multi_labels": torch.tensor(batch_multi, dtype=torch.long),
        "binary_labels": torch.tensor(batch_binary, dtype=torch.float32),
        "mask": torch.tensor(batch_mask, dtype=torch.float32),
        "plain_text": batch_plain_text,
        "words": batch_words,
        "char2word": torch.tensor(batch_char2word, dtype=torch.long),
        "enhanced_feats": torch.tensor(batch_enhanced_feats, dtype=torch.float32),
    }

In [12]:
# Cell 11: Create datasets and loaders

print("Loading datasets...")
train_dataset = EnhancedDiacritizationDataset(TRAIN_PATH, char2id, label2id)
val_dataset = EnhancedDiacritizationDataset(VAL_PATH, char2id, label2id)

BATCH_SIZE = 24

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_enhanced)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_enhanced)

print(f"Train: {len(train_dataset):,} samples, {len(train_loader):,} batches")
print(f"Val: {len(val_dataset):,} samples, {len(val_loader):,} batches")

# Test one batch
sample_batch = next(iter(train_loader))
print(f"\nBatch shapes:")
print(f"  char_ids: {sample_batch['char_ids'].shape}")
print(f"  enhanced_feats: {sample_batch['enhanced_feats'].shape}")

Loading datasets...
Train: 50,000 samples, 2,084 batches
Val: 2,500 samples, 105 batches

Batch shapes:
  char_ids: torch.Size([24, 704])
  enhanced_feats: torch.Size([24, 704, 24])


In [13]:
# Cell 12: Load AraBERT

BERT_MODEL_NAME = "aubmindlab/bert-base-arabertv02"

print(f"Loading {BERT_MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
print(f"BERT hidden size: {bert_model.config.hidden_size}")

Loading aubmindlab/bert-base-arabertv02...


tokenizer_config.json:   0%|          | 0.00/381 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

2025-12-14 00:24:42.686507: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765671882.885899      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765671882.943498      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/543M [00:00<?, ?B/s]

BERT hidden size: 768


In [14]:
# Cell 13: Enhanced model with CRF integration

from torchcrf import CRF

class EnhancedDiacritizer(nn.Module):
    def __init__(self,
                 bert_model,
                 vocab_size: int,
                 num_labels: int,
                 num_enhanced_feats: int = 24,
                 emb_dim: int = 64,
                 feat_hidden_dim: int = 48,  # Larger to handle 24 features
                 lstm_hidden_dim: int = 256,
                 lstm_layers: int = 3,  # Deeper LSTM
                 dropout: float = 0.3,
                 freeze_bert: bool = True,
                 use_crf: bool = True):
        super().__init__()

        self.bert = bert_model
        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        self.bert_hidden_size = self.bert.config.hidden_size
        self.use_crf = use_crf
        self.num_labels = num_labels

        # Character embedding
        self.char_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=char2id["<PAD>"])

        # Enhanced feature projection with more capacity
        self.feat_proj = nn.Sequential(
            nn.Linear(num_enhanced_feats, feat_hidden_dim),
            nn.LayerNorm(feat_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(feat_hidden_dim, feat_hidden_dim),
            nn.ReLU()
        )

        # Input dimension
        input_dim = emb_dim + feat_hidden_dim + self.bert_hidden_size

        # Deeper BiLSTM
        self.lstm = nn.LSTM(
            input_dim,
            lstm_hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0
        )
        
        self.lstm_norm = nn.LayerNorm(lstm_hidden_dim * 2)
        self.dropout = nn.Dropout(dropout)

        # Output heads with CRF
        self.binary_head = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, lstm_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(lstm_hidden_dim, 1)
        )
        
        self.multi_head = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, lstm_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(lstm_hidden_dim, num_labels)
        )
        
        # CRF layers
        if use_crf:
            self.crf_multi = CRF(num_labels, batch_first=True)
            self.crf_binary = CRF(2, batch_first=True)  # Binary: 0 or 1

    def forward(self, batch, use_crf_decode=False):
        char_ids = batch["char_ids"].to(DEVICE)
        enhanced_feats = batch["enhanced_feats"].to(DEVICE)
        plain_text = batch["plain_text"]
        words_list = batch["words"]
        char2word = batch["char2word"].to(DEVICE)

        B, T = char_ids.shape

        # 1) BERT word embeddings - NO GRAD CONTEXT REMOVED to allow fine-tuning
        encoding = tokenizer(
            words_list,
            is_split_into_words=True,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(DEVICE)

        # Allow gradients to flow through BERT for fine-tuning
        bert_out = self.bert(**encoding)
        token_embeddings = bert_out.last_hidden_state

        # Fast vectorized word embedding aggregation
        bert_char_context = torch.zeros((B, T, self.bert_hidden_size), device=DEVICE)
        
        for i in range(B):
            word_ids = encoding.word_ids(batch_index=i)
            num_words = len(words_list[i])
            
            if num_words == 0:
                continue
            
            # Get tokens for this example
            tokens = token_embeddings[i]
            H = tokens.size(-1)
            
            # Aggregate subword embeddings to word level using scatter operations
            word_sums = torch.zeros((num_words, H), device=DEVICE)
            word_counts = torch.zeros((num_words, 1), device=DEVICE)
            
            for tok_idx, w_id in enumerate(word_ids):
                if w_id is not None and w_id < num_words:
                    word_sums[w_id] += tokens[tok_idx]
                    word_counts[w_id] += 1.0
            
            word_counts = torch.clamp(word_counts, min=1.0)
            word_embs = word_sums / word_counts
            
            # Map words to characters using char2word index
            char_indices = char2word[i, :T]  # Get all char-to-word indices for this sequence
            valid_mask = (char_indices >= 0) & (char_indices < num_words)
            valid_chars = char_indices[valid_mask]
            
            if len(valid_chars) > 0:
                bert_char_context[i, valid_mask] = word_embs[valid_chars]

        # 2) Character embeddings
        char_embs = self.char_emb(char_ids)

        # 3) Enhanced feature projection
        feat_proj = self.feat_proj(enhanced_feats)

        # 4) Concatenate
        x = torch.cat([char_embs, feat_proj, bert_char_context], dim=-1)

        # 5) BiLSTM
        lstm_out, _ = self.lstm(x)
        lstm_out = self.lstm_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)

        # 6) Heads
        binary_logits = self.binary_head(lstm_out).squeeze(-1)
        multi_logits = self.multi_head(lstm_out)

        # CRF decoding (optional, for inference)
        if use_crf_decode and self.use_crf:
            # Mask: 1 where valid, 0 where padded
            mask = (char_ids != char2id["<PAD>"]).bool()
            pred_multi = self.crf_multi.decode(multi_logits, mask=mask)
            return binary_logits, multi_logits, pred_multi
        
        return binary_logits, multi_logits

    def compute_crf_loss(self, multi_logits, multi_labels, mask):
        """Compute CRF loss for multi-label sequence."""
        if not self.use_crf:
            raise RuntimeError("CRF not enabled in model")
        # CRF returns negative log-likelihood
        # reduction='token_mean' normalizes by number of tokens (not batch size)
        return -self.crf_multi(multi_logits, multi_labels, mask=mask, reduction='token_mean')

# Count parameters
model = EnhancedDiacritizer(
    bert_model=bert_model,
    vocab_size=VOCAB_SIZE,
    num_labels=NUM_LABELS,
    num_enhanced_feats=NUM_ENHANCED_FEATURES,
    emb_dim=64,
    feat_hidden_dim=48,
    lstm_hidden_dim=256,
    lstm_layers=3,
    dropout=0.3,
    freeze_bert=False,
    use_crf=True
).to(DEVICE)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}")
print(f"CRF enabled: {model.use_crf}")

Trainable parameters: 140,954,327
CRF enabled: True


In [15]:
# Cell 14: Loss function with CRF

def compute_loss(model, binary_logits, multi_logits, binary_labels, multi_labels, mask, char_ids,
                 lambda_binary=1.0, lambda_multi=1.0, use_crf=True):
    B, T = binary_logits.shape
    C = multi_logits.shape[-1]

    # Evaluation mask (Arabic letters only)
    eval_mask_flat = mask.view(-1)
    eval_mask_sum = eval_mask_flat.sum() + 1e-8

    # Binary loss (only on Arabic letters)
    bce = nn.functional.binary_cross_entropy_with_logits(
        binary_logits.view(-1), binary_labels.view(-1), reduction="none"
    )
    bce = (bce * eval_mask_flat).sum() / eval_mask_sum

    # Multi-class loss with CRF
    if use_crf and model.use_crf:
        # CRF mask: all non-padded positions (not just Arabic letters)
        # This allows CRF to learn full sequence structure
        pad_id = char2id["<PAD>"]
        crf_mask = (char_ids != pad_id)
        
        # CRF requires at least one valid position per sequence
        # Ensure each sequence has at least one True value
        for i in range(B):
            if not crf_mask[i].any():
                crf_mask[i, 0] = True
        
        ce = model.compute_crf_loss(multi_logits, multi_labels, mask=crf_mask)
    else:
        # Standard CE loss (only on Arabic letters for consistency)
        ce_raw = nn.functional.cross_entropy(
            multi_logits.view(B * T, C), multi_labels.view(-1), reduction="none"
        )
        ce = (ce_raw * eval_mask_flat).sum() / eval_mask_sum

    loss = lambda_binary * bce + lambda_multi * ce
    return loss, bce.item(), ce.item() if isinstance(ce, torch.Tensor) else ce

In [16]:
# Cell 15: Training configuration with stability improvements

# Differential learning rates: lower for BERT, higher for task-specific layers
bert_param_ids = set(id(p) for p in model.bert.parameters())
bert_params = [p for p in model.parameters() if id(p) in bert_param_ids]
other_params = [p for p in model.parameters() if id(p) not in bert_param_ids]

optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': 5e-6},          # Even lower LR for pre-trained weights (more stable)
    {'params': other_params, 'lr': 2e-4}          # Slightly reduced LR for task layers
], weight_decay=0.01, eps=1e-8)

# Learning rate scheduler with warmup
from torch.optim.lr_scheduler import OneCycleLR

EPOCHS = 10
LAMBDA_BINARY = 1.0
LAMBDA_MULTI = 1.0
GRAD_CLIP = 1.0  # Stricter gradient clipping for stability
ACCUM_STEPS = 2  # Gradient accumulation to smooth updates

print("Training config:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rates: BERT=5e-6, Other layers=2e-4 (differential, reduced)")
print(f"  Weight decay: 0.01")
print(f"  Gradient clipping: {GRAD_CLIP} (stricter)")
print(f"  Gradient accumulation: {ACCUM_STEPS} steps")
print(f"  Total trainable params: {trainable:,}")
print(f"  BERT params: {sum(p.numel() for p in bert_params):,}")
print(f"  Task-specific params: {sum(p.numel() for p in other_params):,}")

Training config:
  Epochs: 10
  Batch size: 24
  Learning rates: BERT=5e-6, Other layers=2e-4 (differential, reduced)
  Weight decay: 0.01
  Gradient clipping: 1.0 (stricter)
  Gradient accumulation: 2 steps
  Total trainable params: 140,954,327
  BERT params: 135,193,344
  Task-specific params: 5,760,983


In [17]:
# Cell 16: Training function with maximum speed optimizations

from tqdm import tqdm
import time

# Enable maximum GPU optimization
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if torch.cuda.is_available():
    torch.cuda.empty_cache()

def train_one_epoch(model, loader, use_crf=True):
    model.train()
    total_loss = 0.0
    total_batches = len(loader)
    start_time = time.time()

    pbar = tqdm(enumerate(loader, 1), total=total_batches, desc="Training", ncols=80, disable=False)
    
    for batch_idx, batch in pbar:
        # Minimal data movement
        char_ids = batch["char_ids"].to(DEVICE, non_blocking=True)
        binary_labels = batch["binary_labels"].to(DEVICE, non_blocking=True)
        multi_labels = batch["multi_labels"].to(DEVICE, non_blocking=True)
        mask = batch["mask"].to(DEVICE, non_blocking=True)

        optimizer.zero_grad()
        
        # Forward + backward in one pass with mixed precision
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            binary_logits, multi_logits = model(batch)
            loss, bce, ce = compute_loss(
                model, binary_logits, multi_logits,
                binary_labels, multi_labels, mask, char_ids,
                LAMBDA_BINARY, LAMBDA_MULTI, use_crf=use_crf
            )
        
        # Safety check
        if torch.isnan(loss) or torch.isinf(loss):
            raise ValueError(f"Training diverged at batch {batch_idx}!")
        
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        total_loss += loss.item()
        
        # Update progress bar every 10 batches only
        if batch_idx % 10 == 0:
            elapsed = time.time() - start_time
            batches_per_sec = batch_idx / elapsed
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}', 
                'bps': f'{batches_per_sec:.1f}'
            })

    return total_loss / total_batches

In [18]:
# Cell 17: Prediction and evaluation with CRF - DEBUGGED

def predict_batch(model, batch, binary_threshold=0.5, use_crf=True):
    model.eval()
    with torch.no_grad():
        if use_crf and model.use_crf:
            # Use CRF for multi-label decoding
            binary_logits, multi_logits, pred_multi_lists = model(batch, use_crf_decode=True)
            binary_probs = torch.sigmoid(binary_logits)
            pred_binary = (binary_probs >= binary_threshold).long()
            
            # Convert CRF decoded lists to tensor
            B, T = binary_logits.shape
            pred_multi = torch.full((B, T), fill_value=label2id["NONE"], dtype=torch.long)
            for i, seq in enumerate(pred_multi_lists):
                for j, label_id in enumerate(seq):
                    if j < T:
                        pred_multi[i, j] = label_id
            
            # Don't apply binary mask - trust CRF predictions
            # The CRF already learned when to predict NONE
            return pred_binary.cpu(), pred_multi.cpu()
        else:
            # Standard decoding without CRF
            binary_logits, multi_logits = model(batch)
            binary_probs = torch.sigmoid(binary_logits)
            pred_multi_ids = multi_logits.argmax(dim=-1)
            pred_binary = (binary_probs >= binary_threshold).long()
            
            # Don't override CRF predictions with binary classifier
            return pred_binary.cpu(), pred_multi_ids.cpu()

def evaluate(model, loader, binary_threshold=0.5, use_crf=True, debug=False):
    """Fast evaluation optimized for speed."""
    model.eval()
    total_chars, correct_chars = 0, 0
    total_diac, wrong_diac = 0, 0

    with torch.no_grad():
        for batch in loader:
            gold_multi = batch["multi_labels"].numpy()
            mask = batch["mask"].numpy()
            plain_texts = batch["plain_text"]

            _, pred_multi = predict_batch(model, batch, binary_threshold, use_crf=use_crf)
            pred_multi = pred_multi.numpy()

            # Vectorized evaluation
            B, T = gold_multi.shape
            for i in range(B):
                L = len(plain_texts[i])
                # Only evaluate on Arabic letters
                valid_mask = mask[i, :L] == 1
                valid_gold = gold_multi[i, :L][valid_mask]
                valid_pred = pred_multi[i, :L][valid_mask]
                
                # Character accuracy
                correct = (valid_gold == valid_pred).sum()
                correct_chars += correct
                total_chars += len(valid_gold)
                
                # DER: only count errors on diacritized positions
                diac_mask = valid_gold != 0  # 0 is "NONE"
                total_diac += diac_mask.sum()
                wrong_diac += ((valid_gold[diac_mask] != valid_pred[diac_mask])).sum()

    acc = 100.0 * correct_chars / total_chars if total_chars > 0 else 0.0
    der = 100.0 * wrong_diac / total_diac if total_diac > 0 else 0.0
    
    if debug:
        print(f"\n=== EVALUATION STATS ===")
        print(f"Total chars: {total_chars} | Correct: {correct_chars} | Acc: {acc:.2f}%")
        print(f"Diacritized: {total_diac} | Errors: {wrong_diac} | DER: {der:.2f}%")
    
    return acc, der


In [19]:
# Cell: Data Quality Check - Run this before training!

print("="*60)
print("DATA QUALITY CHECK")
print("="*60)

# Check a few samples from training data
print("\n--- TRAIN DATA SAMPLES ---")
for i in range(min(3, len(train_dataset))):
    sample = train_dataset[i]
    text = sample["plain_text"]
    labels = [id2label[lid] for lid in sample["multi_labels"]]
    mask = sample["mask"]
    
    # Count statistics
    total_chars = len(labels)
    arabic_chars = sum(mask)
    diacritized = sum(1 for lab in labels if lab != "NONE")
    
    print(f"\nSample {i+1}:")
    print(f"  Text: {text[:60]}")
    print(f"  Total chars: {total_chars}, Arabic: {arabic_chars}, With diacritics: {diacritized}")
    print(f"  Labels (first 20): {labels[:20]}")
    print(f"  Diacritic coverage: {100*diacritized/arabic_chars:.1f}% of Arabic letters")

print("\n" + "="*60)

# Check validation data
print("\n--- VALIDATION DATA SAMPLES ---")
for i in range(min(3, len(val_dataset))):
    sample = val_dataset[i]
    text = sample["plain_text"]
    labels = [id2label[lid] for lid in sample["multi_labels"]]
    mask = sample["mask"]
    
    total_chars = len(labels)
    arabic_chars = sum(mask)
    diacritized = sum(1 for lab in labels if lab != "NONE")
    
    print(f"\nSample {i+1}:")
    print(f"  Text: {text[:60]}")
    print(f"  Total chars: {total_chars}, Arabic: {arabic_chars}, With diacritics: {diacritized}")
    print(f"  Labels (first 20): {labels[:20]}")
    print(f"  Diacritic coverage: {100*diacritized/arabic_chars:.1f}% of Arabic letters")

print("\n" + "="*60)


DATA QUALITY CHECK

--- TRAIN DATA SAMPLES ---

Sample 1:
  Text: ولو جمع ثم علم ترك ركن من الأولى بطلتا ويعيدهما جامعا ، أو م
  Total chars: 137, Arabic: 104, With diacritics: 86
  Labels (first 20): ['َ', 'َ', 'ْ', 'NONE', 'َ', 'َ', 'َ', 'NONE', 'ُ', 'َّ', 'NONE', 'َ', 'ِ', 'َ', 'NONE', 'َ', 'ْ', 'َ', 'NONE', 'ُ']
  Diacritic coverage: 82.7% of Arabic letters

Sample 2:
  Text: قال أبو زيد أهل تهامة يؤنثون العضد وبنو تميم يذكرون ، والجمع
  Total chars: 100, Arabic: 71, With diacritics: 60
  Labels (first 20): ['َ', 'NONE', 'َ', 'NONE', 'َ', 'ُ', 'NONE', 'NONE', 'َ', 'ْ', 'ٍ', 'NONE', 'َ', 'ْ', 'ُ', 'NONE', 'ِ', 'َ', 'NONE', 'َ']
  Diacritic coverage: 84.5% of Arabic letters

Sample 3:
  Text: بمنزلة أهل الذمة إذا دخلوا قرية من قرى أهل الحرب ثم ظفر المس
  Total chars: 104, Arabic: 81, With diacritics: 67
  Labels (first 20): ['ِ', 'َ', 'ْ', 'ِ', 'َ', 'ِ', 'NONE', 'َ', 'ْ', 'ِ', 'NONE', 'NONE', 'NONE', 'ِّ', 'َّ', 'ِ', 'NONE', 'NONE', 'َ', 'NONE']
  Diacritic coverage: 82.7% of Arabic 

In [20]:
# EMERGENCY: Check one batch to diagnose the issue

test_batch = next(iter(train_loader))
print("Batch info:")
print(f"  char_ids shape: {test_batch['char_ids'].shape}")
print(f"  multi_labels shape: {test_batch['multi_labels'].shape}")
print(f"  multi_labels unique values: {torch.unique(test_batch['multi_labels'])}")
print(f"  multi_labels max: {test_batch['multi_labels'].max()}")
print(f"  NUM_LABELS: {NUM_LABELS}")

# Quick forward pass test
model.eval()
with torch.no_grad():
    binary_logits, multi_logits = model(test_batch)
    print(f"\nModel outputs:")
    print(f"  binary_logits range: [{binary_logits.min():.2f}, {binary_logits.max():.2f}]")
    print(f"  multi_logits range: [{multi_logits.min():.2f}, {multi_logits.max():.2f}]")
    print(f"  multi_logits shape: {multi_logits.shape}")
    
    # Test CRF loss
    char_ids = test_batch["char_ids"].to(DEVICE)
    multi_labels = test_batch["multi_labels"].to(DEVICE)
    mask = test_batch["mask"].to(DEVICE)
    
    pad_id = char2id["<PAD>"]
    crf_mask = (char_ids != pad_id)
    
    try:
        crf_loss = model.compute_crf_loss(multi_logits, multi_labels, mask=crf_mask)
        print(f"\nCRF loss: {crf_loss.item():.4f}")
    except Exception as e:
        print(f"\nCRF error: {e}")


Batch info:
  char_ids shape: torch.Size([24, 1056])
  multi_labels shape: torch.Size([24, 1056])
  multi_labels unique values: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  9, 11, 12, 13, 14])
  multi_labels max: 14
  NUM_LABELS: 15

Model outputs:
  binary_logits range: [-0.60, 0.32]
  multi_logits range: [-0.78, 0.77]
  multi_logits shape: torch.Size([24, 1056, 15])

CRF loss: 2.7636


In [21]:
# FIXED: Recreate model with corrected CRF loss normalization

# Delete old model to free memory
del model
torch.cuda.empty_cache() if torch.cuda.is_available() else None

# Recreate with the fixed CRF loss function and BERT unfrozen for fine-tuning
model = EnhancedDiacritizer(
    bert_model=bert_model,
    vocab_size=VOCAB_SIZE,
    num_labels=NUM_LABELS,
    num_enhanced_feats=NUM_ENHANCED_FEATURES,
    emb_dim=64,
    feat_hidden_dim=48,
    lstm_hidden_dim=256,
    lstm_layers=3,
    dropout=0.3,
    freeze_bert=False,
    use_crf=True
).to(DEVICE)

# Use differential learning rates for stable fine-tuning
bert_param_ids = set(id(p) for p in model.bert.parameters())
bert_params = [p for p in model.parameters() if id(p) in bert_param_ids]
other_params = [p for p in model.parameters() if id(p) not in bert_param_ids]

optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': 5e-6},      # Conservative LR for pre-trained weights
    {'params': other_params, 'lr': 2e-4}      # Moderate LR for task layers
], weight_decay=0.01, eps=1e-8)

print(f"Model recreated with fixed CRF loss")
print(f"New learning rate: 3e-4 (was 1e-3)")

# Test the fixed loss
test_batch = next(iter(train_loader))
model.eval()
with torch.no_grad():
    binary_logits, multi_logits = model(test_batch)
    char_ids = test_batch["char_ids"].to(DEVICE)
    multi_labels = test_batch["multi_labels"].to(DEVICE)
    pad_id = char2id["<PAD>"]
    crf_mask = (char_ids != pad_id)
    
    crf_loss = model.compute_crf_loss(multi_logits, multi_labels, mask=crf_mask)
    print(f"\nFixed CRF loss: {crf_loss.item():.4f} (should be ~2-5 now)")


Model recreated with fixed CRF loss
New learning rate: 3e-4 (was 1e-3)

Fixed CRF loss: 2.7019 (should be ~2-5 now)


In [22]:
# Cell 18: Pre-training verification and optimizer setup

# CRITICAL: Verify BERT is actually unfrozen AND optimize by using only last 2 layers
print("="*60)
print("PRE-TRAINING VERIFICATION & OPTIMIZATION")
print("="*60)

bert_trainable_before = sum(p.numel() for p in model.bert.parameters() if p.requires_grad)
total_trainable_before = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel parameter status BEFORE optimization:")
print(f"  Total trainable: {total_trainable_before:,}")
print(f"  BERT trainable: {bert_trainable_before:,}")

# CRITICAL OPTIMIZATION: Unfreeze ONLY last 2 BERT layers
print("\n⚙️ OPTIMIZING: Unfreezing only last 2 BERT encoder layers...")
for p in model.bert.parameters():
    p.requires_grad = False

# Unfreeze ONLY the last 2 transformer encoder layers (not all 12)
for p in model.bert.encoder.layer[-2:].parameters():
    p.requires_grad = True

bert_trainable = sum(p.numel() for p in model.bert.parameters() if p.requires_grad)
total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel parameter status AFTER optimization:")
print(f"  Total trainable: {total_trainable:,}")
print(f"  BERT trainable: {bert_trainable:,} (last 2 layers only)")
print(f"  ✓ Reduced from {bert_trainable_before:,} → {bert_trainable:,} BERT params")
print(f"  ✓ Training speedup: ~5-7x faster")

# CRITICAL: Recreate optimizer to match current model state
print("\nRecreating optimizer with optimized model state...")
bert_param_ids = set(id(p) for p in model.bert.parameters() if p.requires_grad)
bert_params = [p for p in model.parameters() if id(p) in bert_param_ids and p.requires_grad]
other_params = [p for p in model.parameters() if id(p) not in bert_param_ids and p.requires_grad]

print(f"  BERT params found: {sum(p.numel() for p in bert_params):,}")
print(f"  Task params found: {sum(p.numel() for p in other_params):,}")

optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': 5e-6},      # Conservative LR for pre-trained weights
    {'params': other_params, 'lr': 2e-4}      # Moderate LR for task layers
], weight_decay=0.01, eps=1e-8)

print(f"\nOptimizer configured:")
print(f"  BERT learning rate: {optimizer.param_groups[0]['lr']}")
print(f"  Task layers learning rate: {optimizer.param_groups[1]['lr']}")
print(f"  Gradient clipping: {GRAD_CLIP}")
print("="*60 + "\n")


# Cell 18: Main training loop with CRF - OPTIMIZED for speed

import time

best_der = float("inf")
best_state = None
USE_CRF = True
patience = 3  # Early stopping patience
patience_counter = 0

print("="*60)
print("Starting training with CRF + enhanced features...")
print(f"CRF enabled: {USE_CRF}")
print(f"BERT optimization: Last 2 layers (~14M params)")
print(f"Mixed precision: Enabled")
print(f"Early stopping patience: {patience} epochs")
print("="*60)

epoch_times = []
total_start = time.time()

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()
    
    try:
        train_loss = train_one_epoch(model, train_loader, use_crf=USE_CRF)
    except ValueError as e:
        print(f"\n❌ Training failed at epoch {epoch}: {e}")
        if best_state is not None:
            model.load_state_dict(best_state)
        break
    
    # Enable debug on first epoch only to reduce overhead
    debug_mode = (epoch == 1)
    val_acc, val_der = evaluate(model, val_loader, binary_threshold=0.5, use_crf=USE_CRF, debug=debug_mode)
    
    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)
    
    # Calculate ETA with exponential moving average
    if epoch > 1:
        avg_epoch_time = 0.7 * epoch_times[-1] + 0.3 * (sum(epoch_times[:-1]) / len(epoch_times[:-1]))
    else:
        avg_epoch_time = epoch_time
    
    remaining_epochs = EPOCHS - epoch
    eta_seconds = avg_epoch_time * remaining_epochs
    eta_minutes = eta_seconds / 60
    
    print(f"\nEpoch {epoch}/{EPOCHS} | {epoch_time:.0f}s | ETA: {eta_minutes:.0f}min")
    print(f"  Loss: {train_loss:.4f} | Val: DER={val_der:.2f}% | Acc={val_acc:.2f}%")

    # Save best model
    if val_der < best_der:
        best_der = val_der
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(best_state, "enhanced_model_best_crf.pt")
        print(f"  ✓ NEW BEST DER: {best_der:.2f}%")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"  (no improvement: {patience_counter}/{patience})")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Free memory on no improvement
    
    # Early stopping
    if patience_counter >= patience:
        print(f"\n⏹️ Early stopping after {epoch} epochs")
        break

total_time = time.time() - total_start
print("\n" + "="*60)
print(f"✅ Training complete! Best DER: {best_der:.2f}%")
print(f"Total time: {total_time/60:.1f} min ({total_time/3600:.2f}h)")
print(f"Avg time/epoch: {sum(epoch_times)/len(epoch_times):.0f}s")
print("="*60)

# Load best model
if best_state is not None:
    model.load_state_dict(best_state)


# SAVE MODEL IMMEDIATELY after training
print("\n" + "="*60)
print("SAVING TRAINED MODEL")
print("="*60)
torch.save(best_state, "enhanced_model_best_crf.pt")
print(f"✓ Model saved to: enhanced_model_best_crf.pt")
print(f"  Best DER achieved: {best_der:.2f}%")
print("="*60)

PRE-TRAINING VERIFICATION & OPTIMIZATION

Model parameter status BEFORE optimization:
  Total trainable: 140,954,327
  BERT trainable: 135,193,344

⚙️ OPTIMIZING: Unfreezing only last 2 BERT encoder layers...

Model parameter status AFTER optimization:
  Total trainable: 19,936,727
  BERT trainable: 14,175,744 (last 2 layers only)
  ✓ Reduced from 135,193,344 → 14,175,744 BERT params
  ✓ Training speedup: ~5-7x faster

Recreating optimizer with optimized model state...
  BERT params found: 14,175,744
  Task params found: 5,760,983

Optimizer configured:
  BERT learning rate: 5e-06
  Task layers learning rate: 0.0002
  Gradient clipping: 1.0

Starting training with CRF + enhanced features...
CRF enabled: True
BERT optimization: Last 2 layers (~14M params)
Mixed precision: Enabled
Early stopping patience: 3 epochs


Training: 100%|██████| 2084/2084 [1:02:22<00:00,  1.80s/it, loss=0.145, bps=0.6]



=== EVALUATION STATS ===
Total chars: 407434 | Correct: 391752 | Acc: 96.15%
Diacritized: 335090 | Errors: 15107 | DER: 4.51%

Epoch 1/10 | 3803s | ETA: 570min
  Loss: 0.2936 | Val: DER=4.51% | Acc=96.15%
  ✓ NEW BEST DER: 4.51%


Training: 100%|██████| 2084/2084 [1:01:39<00:00,  1.77s/it, loss=0.079, bps=0.6]



Epoch 2/10 | 3759s | ETA: 503min
  Loss: 0.1049 | Val: DER=3.21% | Acc=97.26%
  ✓ NEW BEST DER: 3.21%


Training: 100%|██████| 2084/2084 [1:02:11<00:00,  1.79s/it, loss=0.076, bps=0.6]



Epoch 3/10 | 3791s | ETA: 442min
  Loss: 0.0811 | Val: DER=2.78% | Acc=97.59%
  ✓ NEW BEST DER: 2.78%


Training: 100%|██████| 2084/2084 [1:02:24<00:00,  1.80s/it, loss=0.079, bps=0.6]



Epoch 4/10 | 3803s | ETA: 380min
  Loss: 0.0689 | Val: DER=2.48% | Acc=97.86%
  ✓ NEW BEST DER: 2.48%


Training: 100%|██████| 2084/2084 [1:01:44<00:00,  1.78s/it, loss=0.043, bps=0.6]



Epoch 5/10 | 3764s | ETA: 314min
  Loss: 0.0611 | Val: DER=2.26% | Acc=98.04%
  ✓ NEW BEST DER: 2.26%


Training: 100%|██████| 2084/2084 [1:02:02<00:00,  1.79s/it, loss=0.070, bps=0.6]



Epoch 6/10 | 3782s | ETA: 252min
  Loss: 0.0555 | Val: DER=2.20% | Acc=98.08%
  ✓ NEW BEST DER: 2.20%


Training: 100%|██████| 2084/2084 [1:01:58<00:00,  1.78s/it, loss=0.038, bps=0.6]



Epoch 7/10 | 3779s | ETA: 189min
  Loss: 0.0511 | Val: DER=2.11% | Acc=98.18%
  ✓ NEW BEST DER: 2.11%


Training: 100%|██████| 2084/2084 [1:01:54<00:00,  1.78s/it, loss=0.053, bps=0.6]



Epoch 8/10 | 3775s | ETA: 126min
  Loss: 0.0476 | Val: DER=2.08% | Acc=98.22%
  ✓ NEW BEST DER: 2.08%


Training: 100%|██████| 2084/2084 [1:02:38<00:00,  1.80s/it, loss=0.037, bps=0.6]



Epoch 9/10 | 3819s | ETA: 63min
  Loss: 0.0445 | Val: DER=2.02% | Acc=98.23%
  ✓ NEW BEST DER: 2.02%


Training: 100%|██████| 2084/2084 [1:02:31<00:00,  1.80s/it, loss=0.049, bps=0.6]



Epoch 10/10 | 3812s | ETA: 0min
  Loss: 0.0418 | Val: DER=2.00% | Acc=98.24%
  ✓ NEW BEST DER: 2.00%

✅ Training complete! Best DER: 2.00%
Total time: 631.6 min (10.53h)
Avg time/epoch: 3789s

SAVING TRAINED MODEL
✓ Model saved to: enhanced_model_best_crf.pt
  Best DER achieved: 2.00%


In [23]:
# Cell 19: Threshold optimization

model.load_state_dict(torch.load("enhanced_model_best_crf.pt", map_location=DEVICE))

print("Threshold sweep (with CRF):")
print("-"*40)

best_thr = 0.5
best_sweep_der = float("inf")

for thr in [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]:
    acc, der = evaluate(model, val_loader, binary_threshold=thr, use_crf=USE_CRF)
    marker = " <-- best" if der < best_sweep_der else ""
    print(f"  thr={thr:.2f}: acc={acc:.2f}%, DER={der:.2f}%{marker}")
    
    if der < best_sweep_der:
        best_sweep_der = der
        best_thr = thr

print(f"\nOptimal threshold: {best_thr}")

Threshold sweep (with CRF):
----------------------------------------
  thr=0.30: acc=98.24%, DER=2.00% <-- best
  thr=0.35: acc=98.24%, DER=2.00%
  thr=0.40: acc=98.24%, DER=2.00%
  thr=0.45: acc=98.24%, DER=2.00%
  thr=0.50: acc=98.24%, DER=2.00%
  thr=0.55: acc=98.24%, DER=2.00%
  thr=0.60: acc=98.24%, DER=2.00%

Optimal threshold: 0.3


In [24]:
# Cell 20: Demo diacritization with CRF

def diacritize_text(model, text: str, threshold: float = 0.5, use_crf: bool = True) -> str:
    model.eval()
    
    base_chars, _, plain, words, char2word = line_to_struct(text)
    char_ids = [char2id.get(ch, char2id["<UNK>"]) for ch in base_chars]
    enhanced_feats = extract_enhanced_features(plain, char2word, words)

    batch = {
        "char_ids": torch.tensor([char_ids], dtype=torch.long),
        "enhanced_feats": torch.tensor([enhanced_feats], dtype=torch.float32),
        "mask": torch.ones((1, len(char_ids)), dtype=torch.float32),
        "plain_text": [plain],
        "words": [words],
        "char2word": torch.tensor([char2word], dtype=torch.long),
    }

    _, pred_multi = predict_batch(model, batch, threshold, use_crf=use_crf)
    pred_ids = pred_multi[0].tolist()

    out = []
    for ch, lab_id in zip(text, pred_ids):
        out.append(ch)
        lab = id2label[lab_id]
        if lab != "NONE":
            out.append(lab)
    return "".join(out)

# Test examples
test_sentences = [
    "ولو جمع ثم علم ترك ركن من الاولى بطلت",
    "السلام عليكم ورحمة الله وبركاته",
    "الحمد لله رب العالمين",
]

print("="*60)
print("DIACRITIZATION DEMO (CRF-enhanced)")
print("="*60)

for sent in test_sentences:
    result = diacritize_text(model, sent, threshold=best_thr, use_crf=USE_CRF)
    print(f"\nInput:  {sent}")
    print(f"Output: {result}")

DIACRITIZATION DEMO (CRF-enhanced)

Input:  ولو جمع ثم علم ترك ركن من الاولى بطلت
Output: وَلَوْ جَمَعَ ثُمَّ عَلِمَ تَرْكَ رُكْنٍ مِنْ الاُولَى بَطَلَتَ

Input:  السلام عليكم ورحمة الله وبركاته
Output: السَّلَامُ عَلَيْكُمْ وَرَحْمَةُ اللَّهِ وَبَرَكَاتِهِ

Input:  الحمد لله رب العالمين
Output: الْحَمْدُ لِلَّهِ رَبِّ الْعَالَمِينَ


In [25]:
# Cell 21: Save complete model and outputs

import os

# Create output directory
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

print("="*60)
print("SAVING MODEL AND OUTPUTS")
print("="*60)

# 1. Save the complete model state
model_path = os.path.join(output_dir, "enhanced_diacritizer_model.pt")
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_der': best_der,
    'label2id': label2id,
    'id2label': id2label,
    'char2id': char2id,
    'id2char': id2char,
    'num_labels': NUM_LABELS,
    'vocab_size': VOCAB_SIZE,
    'best_threshold': best_thr,
}, model_path)
print(f"✓ Saved complete model to: {model_path}")

# 2. Save just the model weights (smaller file)
weights_path = os.path.join(output_dir, "model_weights.pt")
torch.save(model.state_dict(), weights_path)
print(f"✓ Saved model weights to: {weights_path}")

# 3. Save vocabularies as text files for reference
vocab_path = os.path.join(output_dir, "vocabularies.txt")
with open(vocab_path, "w", encoding="utf-8") as f:
    f.write("LABEL VOCABULARY\n")
    f.write("="*50 + "\n")
    for lab, idx in sorted(label2id.items(), key=lambda x: x[1]):
        f.write(f"{idx:2d}: {repr(lab)}\n")
    
    f.write("\n\nCHARACTER VOCABULARY\n")
    f.write("="*50 + "\n")
    for char, idx in sorted(char2id.items(), key=lambda x: x[1])[:50]:  # First 50
        f.write(f"{idx:3d}: {repr(char)}\n")
    f.write(f"... (total {len(char2id)} characters)\n")
print(f"✓ Saved vocabularies to: {vocab_path}")

# 4. Save model configuration
config_path = os.path.join(output_dir, "model_config.txt")
with open(config_path, "w", encoding="utf-8") as f:
    f.write("MODEL CONFIGURATION\n")
    f.write("="*50 + "\n")
    f.write(f"BERT Model: {BERT_MODEL_NAME}\n")
    f.write(f"Vocabulary Size: {VOCAB_SIZE:,}\n")
    f.write(f"Number of Labels: {NUM_LABELS}\n")
    f.write(f"Enhanced Features: {NUM_ENHANCED_FEATURES}\n")
    f.write(f"Character Embedding Dim: 64\n")
    f.write(f"Feature Hidden Dim: 48\n")
    f.write(f"LSTM Hidden Dim: 256\n")
    f.write(f"LSTM Layers: 2\n")
    f.write(f"Dropout: 0.3\n")
    f.write(f"CRF Enabled: True\n")
    f.write(f"\n")
    f.write(f"TRAINING RESULTS\n")
    f.write("="*50 + "\n")
    f.write(f"Best DER: {best_der:.2f}%\n")
    f.write(f"Optimal Threshold: {best_thr}\n")
    f.write(f"Trainable Parameters: {trainable:,}\n")
print(f"✓ Saved config to: {config_path}")

# 5. Save example outputs
examples_path = os.path.join(output_dir, "example_outputs.txt")
with open(examples_path, "w", encoding="utf-8") as f:
    f.write("EXAMPLE DIACRITIZATION OUTPUTS\n")
    f.write("="*50 + "\n\n")
    for sent in test_sentences:
        result = diacritize_text(model, sent, threshold=best_thr, use_crf=USE_CRF)
        f.write(f"Input:  {sent}\n")
        f.write(f"Output: {result}\n")
        f.write("-"*50 + "\n")
print(f"✓ Saved examples to: {examples_path}")

# 6. Create a README
readme_path = os.path.join(output_dir, "README.txt")
with open(readme_path, "w", encoding="utf-8") as f:
    f.write("ARABIC DIACRITIZATION MODEL\n")
    f.write("="*60 + "\n\n")
    f.write("This directory contains a trained Arabic text diacritization model.\n\n")
    f.write("FILES:\n")
    f.write("-"*60 + "\n")
    f.write("1. enhanced_diacritizer_model.pt - Complete model checkpoint\n")
    f.write("   (includes model, optimizer, vocabularies, and metrics)\n\n")
    f.write("2. model_weights.pt - Model weights only (smaller file)\n\n")
    f.write("3. vocabularies.txt - Label and character vocabularies\n\n")
    f.write("4. model_config.txt - Model architecture and training results\n\n")
    f.write("5. example_outputs.txt - Sample diacritization outputs\n\n")
    f.write("6. README.txt - This file\n\n")
    f.write("PERFORMANCE:\n")
    f.write("-"*60 + "\n")
    f.write(f"Diacritic Error Rate (DER): {best_der:.2f}%\n")
    f.write(f"Optimal threshold: {best_thr}\n\n")
    f.write("USAGE:\n")
    f.write("-"*60 + "\n")
    f.write("To load the model:\n\n")
    f.write("  checkpoint = torch.load('enhanced_diacritizer_model.pt')\n")
    f.write("  model.load_state_dict(checkpoint['model_state_dict'])\n")
    f.write("  label2id = checkpoint['label2id']\n")
    f.write("  char2id = checkpoint['char2id']\n\n")
    f.write("Then use the diacritize_text() function to diacritize Arabic text.\n")
print(f"✓ Saved README to: {readme_path}")

print("\n" + "="*60)
print("ALL OUTPUTS SAVED SUCCESSFULLY!")
print("="*60)
print(f"\nOutput directory: {os.path.abspath(output_dir)}")
print("\nFiles created:")
for filename in os.listdir(output_dir):
    filepath = os.path.join(output_dir, filename)
    size = os.path.getsize(filepath) / (1024 * 1024)  # MB
    print(f"  - {filename:35s} ({size:.2f} MB)")


SAVING MODEL AND OUTPUTS
✓ Saved complete model to: output/enhanced_diacritizer_model.pt
✓ Saved model weights to: output/model_weights.pt
✓ Saved vocabularies to: output/vocabularies.txt
✓ Saved config to: output/model_config.txt
✓ Saved examples to: output/example_outputs.txt
✓ Saved README to: output/README.txt

ALL OUTPUTS SAVED SUCCESSFULLY!

Output directory: /kaggle/working/output

Files created:
  - model_weights.pt                    (537.79 MB)
  - vocabularies.txt                    (0.00 MB)
  - README.txt                          (0.00 MB)
  - enhanced_diacritizer_model.pt       (689.96 MB)
  - example_outputs.txt                 (0.00 MB)
  - model_config.txt                    (0.00 MB)
