In [None]:
# Pin protobuf to 3.20.x — presumably for compatibility with the TF / transformers
# stack imported below (TODO confirm the exact reason for the pin).
# NOTE(review): `preprocess` is a separate PyPI package; this notebook only imports
# `arabert.preprocess` — confirm the extra install is actually required.
!pip install -q "protobuf==3.20.*"
!pip install -q transformers arabert preprocess

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
import functools
import itertools
import json
import pickle
import re

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from arabert.preprocess import ArabertPreprocessor

# Utils

In [None]:
# Kaggle input locations of every artefact this notebook loads.
char2idx_path, arabic_letters_map, model_path, char_embeddings_path = (
    '/kaggle/input/arabicia-3/char2idx.json',                  # char -> index JSON (see get_char_map)
    '/kaggle/input/arabic-letters-map/arabic_letters.pickle',  # pickled Arabic character collection
    '/kaggle/input/arabic-diacritizer-residual/keras/1/1/model_with_features_v2_res.keras',  # Keras model
    '/kaggle/input/embeddings-chars/keras/default/1/embedding_matrix(1).npy',  # char embedding matrix
)

In [None]:
def get_diacritics_map():
    """Return the diacritic label vocabulary and its inverse.

    Label ids 0..13 are the single diacritics and the shadda combinations, in
    the fixed order below. Id 14 is the empty string, i.e. "no diacritic".

    Returns:
        (diacritic2id, idx2label): label string -> id, and id -> label string.
    """
    ordered_labels = [
        "َ",
        "ً",
        "ُ",
        "ٌ",
        "ِ",
        "ٍ",
        "ْ",
        "ّ",
        "َّ",
        "ًّ",
        "ُّ",
        "ٌّ",
        "ِّ",
        "ٍّ",
        "",
    ]
    diacritic2id = {label: position for position, label in enumerate(ordered_labels)}
    idx2label = dict(zip(diacritic2id.values(), diacritic2id.keys()))
    return diacritic2id, idx2label

def get_char_map():
    """Load the character vocabulary from ``char2idx_path``.

    Every non-zero index is shifted down by one; index 0 is left as is.
    NOTE(review): this means source ids 0 and 1 both end up as 0 — confirm
    that collision is intended (e.g. to match the saved embedding matrix).

    Returns:
        (char2idx, idx2char): char -> shifted index, and the inverse mapping.
        When several chars share an index, the one appearing last in the file
        wins in ``idx2char``.
    """
    with open(char2idx_path, 'r', encoding='utf-8') as f:
        loaded = json.load(f)
    char2idx = {ch: (idx if idx == 0 else idx - 1) for ch, idx in loaded.items()}
    idx2char = {idx: ch for ch, idx in char2idx.items()}
    return char2idx, idx2char

In [None]:
@functools.lru_cache(maxsize=None)
def get_arabic_characters():
    """Return the pickled collection of characters treated as Arabic letters.

    The pickle at ``arabic_letters_map`` is read on the first call only; later
    calls return the same cached object. Callers such as
    arabic_only_text_and_tashkeel() run once per evaluated line, and re-reading
    plus unpickling the file on every call was pure overhead. Because the object
    is shared, callers must treat it as read-only. A change to
    ``arabic_letters_map`` after the first call is not picked up.

    NOTE(review): pickle.load can execute arbitrary code; only point
    ``arabic_letters_map`` at a trusted file.
    """
    with open(arabic_letters_map, 'rb') as f:
        arabic_letters = pickle.load(f)
    return arabic_letters

In [None]:
# Character and diacritic vocabularies (plus their inverses) used by the rest of the notebook.
(char2idx, idx2char), (diacritic2id, idx2label) = get_char_map(), get_diacritics_map()

# Preprocessing

In [None]:
# Matches one Arabic diacritic: U+064B..U+0652 = fathatan, dammatan, kasratan,
# fatha, damma, kasra, shadda, sukun.
DIACRITICS_PATTERN = re.compile(r'[\u064B-\u0652]')

In [None]:
def split_text_and_diacritics(text):
    """Separate base characters from the Arabic diacritics attached to them.

    Every character not matched by DIACRITICS_PATTERN (letters, spaces,
    punctuation) is kept in the output text and gets one label. Each diacritic
    is appended, in order, to the label of the nearest preceding base character.
    Diacritics that appear before any base character have nothing to attach to
    and are dropped.

    Returns:
        (text_without_diacritics, labels), where len(labels) equals the length
        of the returned text and every label is a (possibly empty) string.
    """
    base_chars = []
    marks = []
    for ch in text:
        if not DIACRITICS_PATTERN.match(ch):
            base_chars.append(ch)
            marks.append("")
        elif marks:
            marks[-1] += ch
    return "".join(base_chars), marks

In [None]:
# Regexes used by initial_process() to strip non-text noise from a line.
# Parenthesised "a/b" references such as "(3/12)".
numeric_pattern = r"\(\s*\d+\s*/\s*\d+\s*\)"
# A single ASCII Latin letter.
english = r"[a-zA-Z]"
# A digit run together with the whitespace around it. NOTE: since the
# surrounding whitespace is part of the match, substituting it with '' would
# join the neighbouring words.
numbers = r"\s*\d+\s*"
# List numbering such as "3 - " or "12-": digits, optional spaces, a hyphen.
# (Same whitespace caveat as `numbers`.)
numering_items = r"\s*\d+\s*[-]\s*"
# Empty bracket / quote pairs, whitespace allowed inside: (), [], {}, << >>, "", ''.
empty_brackets = r'\(\s*\)|\[\s*\]|\{\s*\}|<<\s*>>|"\s*"|\'\s*\''

def clean_punctuation_sequence(text):
    """Collapse a run of the same punctuation mark into a single mark.

    Repeats may be separated by whitespace ("! !" -> "!"). Different marks next
    to each other ("!?") are left alone. Latin and Arabic marks (، ؛ ؟) are
    both handled.
    """
    char_class = "[" + re.escape(".,:;!?'\"/،؛؟") + "]"
    return re.sub("(" + char_class + r")(?:\s*\1)+", r"\1", text)

def remove_unbalanced_brackets(text):
    """Drop unmatched brackets and all straight quote characters from ``text``.

    Supported bracket pairs: (), {}, [], <>, «». A closer is kept only if it
    matches the innermost open bracket. Closers that don't match, and openers
    never closed, are removed.

    Straight quotes (" and ') are always removed. The previous version listed
    them as both openers and closers, but the opener check ran first, so a quote
    could never close. Every quote was therefore removed anyway, and any bracket
    enclosing a quote was wrongly removed too: '(a "b" c)' became 'a b c'. Quotes
    now bypass bracket matching, so the brackets around quoted text survive:
    '(a "b" c)' -> '(a b c)'.
    NOTE(review): to *keep* balanced quotes instead, they need symmetric
    open/close handling, and the downstream char2idx vocabulary must contain them.
    """
    pair_map = {')': '(', '}': '{', ']': '[', '>': '<', '»': '«'}
    openers = set(pair_map.values())
    quotes = {'"', "'"}

    stack = []  # (opener char, index) of brackets still open
    indices_to_remove = set()

    for i, char in enumerate(text):
        if char in quotes:
            indices_to_remove.add(i)
        elif char in openers:
            stack.append((char, i))
        elif char in pair_map:
            if stack and stack[-1][0] == pair_map[char]:
                stack.pop()
            else:
                # Mismatched or stray closer; the open bracket stays on the stack.
                indices_to_remove.add(i)

    # Whatever is still open at the end was never closed.
    indices_to_remove.update(index for _, index in stack)

    return "".join(c for i, c in enumerate(text) if i not in indices_to_remove)


def initial_process(line):
    """Normalise one raw corpus line before diacritic splitting and chunking.

    Steps, in order:
      1. Replace list numbering, "(a/b)" references, Latin letters and digit runs
         with a space. The number patterns also swallow the whitespace around
         them, so deleting a match outright (the previous behaviour) fused the
         neighbouring words: "كتاب 5 صفحة" -> "كتابصفحة". Extra spaces are
         squeezed in the last step.
      2. Remove empty bracket/quote pairs.
      3. Map , ; ? to their Arabic forms, drop '/', '*' and U+200F
         (RIGHT-TO-LEFT MARK), and turn '–' into '-'.
      4. Collapse repeated punctuation, drop unbalanced brackets, squeeze
         whitespace and trim.

    Args:
        line: one line of (possibly diacritized) text.

    Returns:
        The cleaned line.
    """
    res = re.sub(numering_items, ' ', line)
    res = re.sub(numeric_pattern, ' ', res)
    res = re.sub(english, ' ', res)
    res = re.sub(numbers, ' ', res)
    res = re.sub(empty_brackets, '', res)

    # One-character substitutions. No output character is itself remapped, so a
    # single translate() pass equals the former chain of re.sub/replace calls.
    res = res.translate(str.maketrans({
        ',': '،',
        ';': '؛',
        '?': '؟',
        '/': None,
        '*': None,
        '–': '-',
        '\u200f': None,
    }))

    res = clean_punctuation_sequence(res)
    res = remove_unbalanced_brackets(res)
    res = re.sub(r"\s+", " ", res).strip()

    return res


def split_citations_raw(line):
    """Split a line in front of every citation trigger.

    Triggers are a form of "قال" followed by a colon (optionally prefixed by و/ف),
    or "قول"/"قوله" (optionally prefixed by و/ف and followed by "تعالى"). A
    trigger must start a word: it may not follow a letter or a diacritic.
    Without that guard, "قول" also matched inside words such as "يقولون" and
    cut the word in two ("ي" | "قولون").

    The returned segments concatenate back to the original line. When no
    trigger is found, the single segment is the stripped line.
    """
    qal_list = [
        "قال", "قالت", "قالوا", "قلت", "قلنا",
        "أقول", "يقول", "يقولون", "قيل", "يقال"
    ]

    qal_regex = "|".join(qal_list)

    qal_with_colon = rf"(?:و|ف)?(?:{qal_regex})\s*[:：]"

    qawloho_regex = r"(?:و|ف)?قول(?:ه)?(?:\s*تعالى)?"

    # Word-start guard: not preceded by a word character or an Arabic diacritic.
    trigger = rf"(?<![\w\u064B-\u0652])({qal_with_colon}|{qawloho_regex})"

    matches = list(re.finditer(trigger, line))
    if not matches:
        return [line.strip()]

    final_lines = []
    last_idx = 0
    for m in matches:
        start = m.start()
        if line[last_idx:start]:
            final_lines.append(line[last_idx:start])
        last_idx = start
    final_lines.append(line[last_idx:])

    return final_lines

def slide_window_raw(text, overlap=50, max_len=807):
    """Cut ``text`` into windows of at most ``max_len`` characters.

    Text that fits is returned as a single window. Otherwise each next window
    starts roughly ``max_len - overlap`` characters after the previous start:
    the search walks back from that target to the closest preceding space and
    starts right after it. If there is no space, it starts exactly at the
    target.

    Returns:
        (chunks, overlaps): ``overlaps[k]`` is how many leading characters of
        ``chunks[k]`` repeat the end of ``chunks[k-1]`` (0 for the first), so
        ``"".join(c[o:] for c, o in zip(chunks, overlaps))`` rebuilds ``text``.
    """
    text_len = len(text)
    if text_len <= max_len:
        return [text], [0]

    stride = max_len - overlap
    chunks = [text[:max_len]]
    overlaps = [0]
    start = 0

    while start + stride < text_len:
        target = start + stride
        next_start = next(
            (pos + 1 for pos in range(target, start, -1)
             if pos < text_len and text[pos] == ' '),
            target,
        )
        overlaps.append(max(0, start + max_len - next_start))
        chunks.append(text[next_start:next_start + max_len])
        start = next_start
        if start + max_len >= text_len:
            break

    return chunks, overlaps


def prepare_for_predict(path='/kaggle/input/val-only/val.txt', test=False,
                        overlap=50, max_len=807):
    """Read a corpus file and turn every line into model-sized text chunks.

    Each line is cleaned with initial_process(). Unless ``test`` is true, the
    gold diacritics are then split off with split_text_and_diacritics(), so only
    the bare text is chunked. Each line is split at citation triggers
    (split_citations_raw), and each segment is cut into overlapping windows
    (slide_window_raw).

    Args:
        path: UTF-8 text file, one sample per line.
        test: if True, lines are only cleaned (any diacritics are kept) and
            ``assertions_tashkeel`` stays empty.
        overlap, max_len: forwarded to slide_window_raw().

    Returns:
        (chunks, overlaps, recovery, assertions_text, assertions_tashkeel):
        - chunks, overlaps: the flat chunk list and its matching overlap list.
        - recovery: per line, the within-segment chunk indices
          (0 = first chunk of a segment), used to reassemble lines.
        - assertions_text: per line, the text that was chunked.
        - assertions_tashkeel: per line, the gold diacritic labels
          (empty when ``test``).
    """
    all_recovery = []
    assertions_text = []
    assertions_tashkeel = []
    curr_chunks = []
    curr_overlaps = []

    with open(path, "r", encoding="utf-8") as file:
        for raw_line in file:
            cleaned = initial_process(raw_line.strip())
            if test:
                line = cleaned
            else:
                line, tashkeel = split_text_and_diacritics(cleaned)
                assertions_tashkeel.append(tashkeel)
            assertions_text.append(line)

            recovery = []
            for seg in split_citations_raw(line):
                t_chunks, t_overlaps = slide_window_raw(seg, overlap=overlap, max_len=max_len)
                assert len(t_chunks) == len(t_overlaps), (len(t_chunks), len(t_overlaps))
                # Indices restart at 0 for every segment; post-processing relies on that.
                recovery.extend(range(len(t_chunks)))
                curr_chunks.extend(t_chunks)
                curr_overlaps.extend(t_overlaps)
            all_recovery.append(recovery)

    print(f"Generated {len(curr_chunks)} chunks.")
    return curr_chunks, curr_overlaps, all_recovery, assertions_text, assertions_tashkeel

In [None]:
# Build the model inputs from the validation file:
# - chunk strings and their overlaps,
# - per-line chunk-index lists (used for reassembly),
# - the reference text and diacritics.
chunks, overlaps, recovery, assertions_text, assertions_tashkeel = prepare_for_predict()

Generated 4277 chunks.


# Post Processing

In [None]:
def reconstruct_text_window(chunks, overlaps):
    """Stitch overlapping windows back into one string.

    The first ``overlaps[k]`` characters of ``chunks[k]`` duplicate the previous
    window and are skipped. An empty ``chunks`` gives "".
    """
    if not chunks:
        return ""
    return "".join(piece[skip:] for piece, skip in zip(chunks, overlaps))


def arabic_only_text_and_tashkeel(text, tashkeel):
    """Keep only characters in get_arabic_characters(), with their labels.

    ``tashkeel`` is indexed by character position in ``text``.

    Returns:
        (filtered_text, filtered_labels).
    """
    allowed = get_arabic_characters()
    kept_chars = []
    kept_marks = []
    for pos, ch in enumerate(text):
        if ch in allowed:
            kept_chars.append(ch)
            kept_marks.append(tashkeel[pos])
    return "".join(kept_chars), kept_marks

def post_process(chunks, overlaps, recovery):
    """Reassemble one undiacritized string per input line from its chunks.

    ``chunks`` and ``overlaps`` are flat and in line order. ``recovery[i]``
    lists, for every chunk of line ``i``, its index within its segment; a 0
    starts a new segment. Each segment is stitched with
    reconstruct_text_window(), and the segments of a line are concatenated.
    """
    def stitch(first, last):
        # ``last`` is inclusive.
        return reconstruct_text_window(chunks[first:last + 1], overlaps[first:last + 1])

    results = []
    seg_first = seg_last = 0
    for line_indices in recovery:
        pieces = []
        seen_zero = False
        for chunk_no in line_indices:
            if chunk_no != 0:
                seg_last += 1
                continue
            if seen_zero:
                pieces.append(stitch(seg_first, seg_last))
                seg_last += 1
                seg_first = seg_last
            seen_zero = True
        pieces.append(stitch(seg_first, seg_last))
        seg_last += 1
        seg_first = seg_last
        results.append("".join(pieces))
    return results

In [None]:
def reconstruct_diacritics_window(chunks, overlaps):
    """Stitch overlapping per-chunk label arrays into one 1-D array.

    The first ``overlaps[k]`` entries of ``chunks[k]`` duplicate the previous
    window and are skipped. An empty ``chunks`` gives an empty (float) array.
    """
    if not chunks:
        return np.array([])
    return np.concatenate([piece[skip:] for piece, skip in zip(chunks, overlaps)])


def post_process_diacritics(chunks, overlaps, recovery):
    """Reassemble one diacritic-id array per input line from per-chunk predictions.

    Uses the same segment bookkeeping as post_process(): ``recovery[i]`` holds
    the within-segment chunk indices of line ``i``, and a 0 starts a new
    segment. Segments are stitched with reconstruct_diacritics_window() and
    concatenated, in order, onto an initially empty integer array.
    """
    def stitch(first, last):
        # ``last`` is inclusive.
        return reconstruct_diacritics_window(
            chunks[first:last + 1], overlaps[first:last + 1]
        )

    results = []
    seg_first = seg_last = 0
    for line_indices in recovery:
        line_ids = np.array([], dtype=int)
        seen_zero = False
        for chunk_no in line_indices:
            if chunk_no != 0:
                seg_last += 1
                continue
            if seen_zero:
                line_ids = np.concatenate([line_ids, stitch(seg_first, seg_last)])
                seg_last += 1
                seg_first = seg_last
            seen_zero = True
        line_ids = np.concatenate([line_ids, stitch(seg_first, seg_last)])
        seg_last += 1
        seg_first = seg_last
        results.append(line_ids)
    return results

In [None]:
def get_finals(results, labels, tokens=True):
    """Interleave reconstructed text with its diacritics.

    ``labels`` holds one sequence per line. The sequences are flattened in
    order and consumed one label per character across all of ``results``, so
    they must cover every character (otherwise IndexError). With
    ``tokens=True`` each label is an id looked up in the module-level
    ``idx2label``; otherwise labels are used as-is.

    Returns:
        One string per entry of ``results``: each character followed by its
        diacritic string.
    """
    stream = itertools.chain.from_iterable(labels)
    if tokens:
        marks = [idx2label[label] for label in stream]
    else:
        marks = list(stream)

    finals = []
    offset = 0
    for text in results:
        finals.append("".join(ch + marks[offset + k] for k, ch in enumerate(text)))
        offset += len(text)
    return finals

# Extract Features

In [None]:

# Choose the compute device FIRST: the BERT model is moved onto it right below.
# (It used to be defined at the end of this cell, so `bert_model.to(device)`
# raised NameError on a fresh kernel.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

arabert_model_name = "aubmindlab/bert-base-arabertv02"
bert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
bert_model = AutoModel.from_pretrained(arabert_model_name)
bert_model.to(device)
bert_model.eval()  # inference only
arabert_prep = ArabertPreprocessor(model_name=arabert_model_name)

# Custom character embedding matrix, indexed by the (shifted) ids from get_char_map().
custom_char_embedding = np.load(char_embeddings_path)

def get_arabert_embeddings(sentence: str):
    """Run AraBERT over one sentence.

    Returns:
        (embeddings, tokens): the last hidden state with the batch dimension
        squeezed away, as a CPU numpy array (presumably (num_tokens,
        hidden_size) — TODO confirm), and the token strings for the same
        positions. Input longer than the tokenizer limit is truncated.
    """
    encoded = bert_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden = bert_model(**on_device).last_hidden_state

    token_list = bert_tokenizer.convert_ids_to_tokens(on_device["input_ids"][0])
    return hidden.squeeze(0).cpu().numpy(), token_list



def extract_custom_char_embeddings(char):
    """Return the custom embedding row for ``char``.

    Looks ``char`` up in the module-level ``char2idx``, which get_char_map()
    builds once near the top of the notebook. Previously get_char_map() was
    called here, so the JSON file was reopened and re-parsed for every single
    character of every chunk.

    Raises:
        KeyError: if ``char`` is not in the character vocabulary.
    """
    return custom_char_embedding[char2idx[char]]

def tokens_to_word_embeddings(tokens, embeddings):
    """Average WordPiece vectors into one vector per word.

    A token starting with "##" continues the current word; any other token
    starts a new one. Every token, including special tokens such as "[CLS]" if
    present, belongs to some group. numpy rows are converted to tensors.

    Returns:
        A tensor stacking one mean vector per word group.
    """
    groups = []
    for token, emb in zip(tokens, embeddings):
        vec = torch.tensor(emb) if isinstance(emb, np.ndarray) else emb
        if token.startswith("##") and groups:
            groups[-1].append(vec)
        else:
            groups.append([vec])
    return torch.stack([torch.mean(torch.stack(group), dim=0) for group in groups])

In [None]:
def zizo_features(sentence: str):
    """Build one feature vector per character of ``sentence``.

    Each vector is [word-level AraBERT vector | custom char embedding]. Spaces
    get a zero word vector. Other characters take the vector of the word they
    belong to: word k of ``sentence.split()`` is paired with entry k of
    tokens_to_word_embeddings().

    NOTE(review): the tokenizer call in get_arabert_embeddings does not disable
    special tokens. If it prepends [CLS] (the HF default), entry 0 is the [CLS]
    vector and every word is paired with its predecessor's vector. Punctuation
    split off by the tokenizer shifts the pairing further. Confirm this matches
    the features the model was trained on before changing it.

    Returns:
        A list with one 1-D numpy array per character.
    """
    arabert_emb, tokens = get_arabert_embeddings(sentence)
    word_vectors = tokens_to_word_embeddings(tokens, arabert_emb)
    emb_dim = word_vectors[0].shape[0]
    words = sentence.split()

    features = []
    current_word = 0
    chars_consumed = 0
    for ch in sentence:
        char_part = np.array(extract_custom_char_embeddings(ch)).flatten()

        if ch == ' ':
            context_part = np.zeros(emb_dim)
        else:
            context_part = word_vectors[current_word]
            if isinstance(context_part, torch.Tensor):
                context_part = context_part.numpy()
            chars_consumed += 1
            # Advance to the next word once every character of this one is used.
            if chars_consumed == len(words[current_word]):
                current_word += 1
                chars_consumed = 0

        features.append(np.concatenate([context_part, char_part]))

    return features


def extract_features(sentences):
    """Compute per-character feature matrices for every chunk.

    Returns:
        A list, in the same order as ``sentences``, of float16 arrays stacking
        the per-character vectors from zizo_features().
    """
    progress = tqdm(sentences, total=len(sentences), desc="extracting features")
    return [np.array(zizo_features("".join(sent)), dtype=np.float16) for sent in progress]

# Predict

In [None]:
# Load the trained diacritizer from the configured `model_path` instead of repeating
# the literal path. compile=False: the notebook only calls predict_on_batch on it.
model = tf.keras.models.load_model(model_path, compile=False)

In [None]:
INTAHA = r'\s+ا\s*هـ?\s+'  # NOTE(review): not referenced in the code shown
BATCH_SIZE = 32
PADDING_INPUT = -99999.0  # value used to pad feature rows; presumably masked by the model — TODO confirm
INPUT_DIM = 1024

def predict(text_chunks, chunk_features=None):
    """Predict diacritic ids for every chunk with the loaded Keras ``model``.

    Args:
        text_chunks: chunk strings. ``len(chunk)`` is the number of valid
            (unpadded) time steps kept from that chunk's prediction.
        chunk_features: per-chunk feature matrices aligned with
            ``text_chunks`` (e.g. the output of extract_features()), each of
            shape (len(chunk), INPUT_DIM). If omitted, the notebook-level
            ``features`` variable is used, which is what this function always
            read before the parameter existed.

    Returns:
        One 1-D array of predicted class ids (argmax over the model's last
        axis) per chunk, trimmed to the chunk length.

    Raises:
        ValueError: if the features do not line up with ``text_chunks``.
    """
    if chunk_features is None:
        chunk_features = features  # backward-compatible fallback to the global

    sentence_lengths = [len(chunk) for chunk in text_chunks]

    # A mismatch would otherwise silently mis-assign or mis-trim predictions.
    if len(chunk_features) != len(sentence_lengths):
        raise ValueError(
            f"got {len(chunk_features)} feature arrays for {len(sentence_lengths)} chunks"
        )
    for k, (feat, length) in enumerate(zip(chunk_features, sentence_lengths)):
        if len(feat) != length:
            raise ValueError(f"chunk {k}: {len(feat)} feature rows for {length} characters")

    def test_set_generator(feats, lengths):
        for i in range(len(feats)):
            yield feats[i], [lengths[i]]

    test_dataset = tf.data.Dataset.from_generator(
            lambda: test_set_generator(chunk_features, sentence_lengths),
            output_signature=(
                tf.TensorSpec(shape=(None, INPUT_DIM), dtype=tf.float32),
                tf.TensorSpec(shape=(1,), dtype=tf.int32)
            )
        ).padded_batch(BATCH_SIZE, padding_values=(PADDING_INPUT, 15))
    # The length component has a fixed shape (1,), so its padding value (15) is not expected to be used.

    all_predictions = []

    print("Starting prediction...")
    for batch_x, batch_lens in test_dataset:
        batch_probs = model.predict_on_batch(batch_x)
        batch_pred_ids = np.argmax(batch_probs, axis=-1)
        current_batch_lengths = batch_lens.numpy().flatten()

        for k in range(batch_pred_ids.shape[0]):
            # Drop the padded tail so each prediction matches its chunk length.
            valid_len = current_batch_lengths[k]
            all_predictions.append(batch_pred_ids[k][:valid_len])

    return all_predictions

In [None]:
# Per-character feature matrices, one float16 array per chunk (~6 min in the recorded run).
features = extract_features(chunks)

extracting features: 100%|██████████| 4277/4277 [06:13<00:00, 11.46it/s]


In [None]:
# Predicted diacritic-id sequences, one per chunk, trimmed to each chunk's length.
# predict() reads the notebook-level `features` computed in the previous cell.
all_predictions = predict(chunks)

Starting prediction...


In [None]:
# Reassemble each input line's undiacritized text from its overlapping chunks.
results = post_process(chunks, overlaps, recovery)

In [None]:
# Reassemble the per-line predicted diacritic ids with the same chunk/overlap bookkeeping.
pred_diac = post_process_diacritics(all_predictions, overlaps, recovery)

In [None]:
# Interleave each character with its predicted diacritic (ids mapped through idx2label).
predicted_text = get_finals(results, pred_diac)

In [None]:
# Raw validation lines; they serve as the ground truth for the accuracy computation.
with open('/kaggle/input/val-only/val.txt', "r", encoding="utf-8") as file:
    lines = list(file)

In [None]:
# Diacritic accuracy over Arabic letters only. The reference side is re-derived
# from the raw lines with the same initial_process() used for prediction.
start_index = 0  # NOTE(review): unused in the code shown
current_lines = lines
current_preds = predicted_text

# zip() below would silently drop unpaired lines; fail loudly instead.
assert len(current_lines) == len(current_preds), (len(current_lines), len(current_preds))

matches = 0
total = 0

for line_str, pred_str in zip(current_lines, current_preds):
    og_text, og_tashkeel = split_text_and_diacritics(initial_process(line_str.strip()))
    ll, og = arabic_only_text_and_tashkeel(og_text, og_tashkeel)

    pred_text, pred_tashkeel = split_text_and_diacritics(pred_str.strip())
    rr, pred = arabic_only_text_and_tashkeel(pred_text, pred_tashkeel)

    # Reference positions with no prediction still count towards `total` (as errors).
    matches += sum(o == p for o, p in zip(og, pred))
    total += len(og)

if total:
    print(f"Final acc: {matches * 100 / total:.2f}%")
else:
    print("Final acc: n/a (no Arabic letters to score)")