In [1]:
#!/usr/bin/env python
# preprocess_answerable_muril_wiki.py

import json
import os
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

# Input and output paths
input_json = "telugu_wiki.json"  # Your single input JSON file
out_dir = "processed_telugu_wiki_muril"  # Output directory
os.makedirs(out_dir, exist_ok=True)

max_length = 512
model_tokenizer_name = "google/muril-large-cased"  # MuRIL Large

###############################################
# 1) Filter out unanswerable QAs
###############################################
def filter_answerable_squad(input_path):
    """
    Returns a new SQuAD JSON dict containing only QAs where is_impossible=False
    with at least one answer.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    new_data = {
        "version": data.get("version", "filtered_telugu_wiki"),
        "data": []
    }
    for article in data["data"]:
        new_paragraphs = []
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            new_qas = []
            for qa in paragraph["qas"]:
                if not qa.get("is_impossible", False) and qa.get("answers"):
                    new_qas.append(qa)
            if new_qas:
                new_paragraphs.append({
                    "context": context,
                    "qas": new_qas
                })
        if new_paragraphs:
            new_data["data"].append({
                "title": article.get("title", ""),
                "paragraphs": new_paragraphs
            })
    return new_data

###############################################
# 2) Build offset-based examples
###############################################
def build_answerable_examples(squad_data, tokenizer, max_length=384):
    """
    For each answerable QA:
      - tokenize question+context
      - find start/end token indices
      - store offset_mapping, context, gold_text, etc.
    """
    examples_out = []
    for article in tqdm(squad_data["data"], desc="Processing articles"):
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                ans = qa["answers"][0]
                ans_start = ans["answer_start"]
                ans_text = ans["text"]
                ans_end = ans_start + len(ans_text)

                enc = tokenizer(
                    qa["question"],
                    context,
                    max_length=max_length,
                    truncation="only_second",
                    return_offsets_mapping=True,
                    return_tensors="pt",
                    padding="max_length"
                )

                input_ids = enc["input_ids"][0]
                attention_mask = enc["attention_mask"][0]
                offset_mapping = enc["offset_mapping"][0].tolist()

                # find start/end token indices
                start_token = None
                end_token = None
                for i, (start_char, end_char) in enumerate(offset_mapping):
                    if start_char <= ans_start < end_char:
                        start_token = i
                    if start_char < ans_end <= end_char:
                        end_token = i

                # fallback if mismatch
                if start_token is None or end_token is None or end_token < start_token:
                    start_token = 0
                    end_token = 0

                ex_item = {
                    "id": qa["id"],
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "start_positions": torch.tensor(start_token, dtype=torch.long),
                    "end_positions": torch.tensor(end_token, dtype=torch.long),
                    "offset_mapping": offset_mapping,
                    "context": context,
                    "gold_text": ans_text
                }
                examples_out.append(ex_item)
    return examples_out

def main():
    print(f"Using tokenizer: {model_tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_name)

    # Filter and build examples
    print("Filtering + building wiki data (TELUGU) ...")
    wiki_data_raw = filter_answerable_squad(input_json)
    wiki_examples = build_answerable_examples(wiki_data_raw, tokenizer, max_length)
    print(f"Wiki answerable examples size: {len(wiki_examples)}")

    # Save as .pt
    output_file = os.path.join(out_dir, "wiki_examples.pt")
    torch.save(wiki_examples, output_file)

    print(f"\nSaved processed file to {output_file}")
    print("Done! Telugu wiki preprocessing completed with MuRIL Large.")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Using tokenizer: google/muril-large-cased
Filtering + building wiki data (TELUGU) ...


Processing articles: 100%|██████████| 199/199 [00:04<00:00, 45.63it/s]


Wiki answerable examples size: 947

Saved processed file to processed_telugu_wiki_muril/wiki_examples.pt
Done! Telugu wiki preprocessing completed with MuRIL Large.


In [3]:
#!/usr/bin/env python
# evaluate_tydiqa_telugu_muril.py

import os
import re
import torch
import numpy as np
from datasets import Dataset
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

############################################
# 1) Adjust paths and parameters
############################################
DATA_DIR   = "processed_telugu_wiki_muril"
DATA_PATH  = os.path.join(DATA_DIR, "wiki_examples.pt")
MODEL_PATH = "./final_muril_tel_answerable_v2"  # your fine-tuned MuRIL QA model folder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use the same MuRIL model name as in preprocessing
MURIL_TOKENIZER = "google/muril-large-cased"

############################################
# 2) Load Data & Model
############################################
print("[INFO] Loading processed dataset...")
examples_list = torch.load(DATA_PATH)
dataset = Dataset.from_list(examples_list)
print(f"[INFO] Loaded {len(dataset)} total examples for evaluation.")

print(f"[INFO] Loading fine-tuned model from {MODEL_PATH}...")
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MURIL_TOKENIZER)
model.to(device)
model.eval()

############################################
# 3) Define Postprocessing & Metrics
############################################
def normalize_text(s):
    """Lower text and remove punctuation, articles, and extra whitespace."""
    s = s.lower()
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    s = re.sub(r"[^\w\s]", "", s)
    s = " ".join(s.split())
    return s

def exact_match(pred, gold):
    return 1.0 if normalize_text(pred) == normalize_text(gold) else 0.0

def f1_score(pred, gold):
    pred_tokens = normalize_text(pred).split()
    gold_tokens = normalize_text(gold).split()
    common      = set(pred_tokens) & set(gold_tokens)
    num_same    = len(common)

    if len(pred_tokens) == 0 or len(gold_tokens) == 0:
        return 1.0 if pred_tokens == gold_tokens else 0.0
    precision = num_same / len(pred_tokens)
    recall    = num_same / len(gold_tokens)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def postprocess_qa_predictions(examples, start_logits, end_logits):
    """Convert model logits into final text predictions."""
    predictions = {}
    for i, ex in enumerate(examples):
        if i >= len(start_logits) or i >= len(end_logits):
            predictions[ex["id"]] = ""
            continue

        offsets = ex["offset_mapping"]
        context = ex["context"]

        # Identify best start/end
        best_start = int(np.argmax(start_logits[i]))
        best_end   = int(np.argmax(end_logits[i]))

        # Validate token indices
        if best_start >= len(offsets) or best_end >= len(offsets) or best_start > best_end:
            predictions[ex["id"]] = ""
            continue

        # Convert token offsets to char positions, then slice context
        start_char = offsets[best_start][0]
        end_char   = offsets[best_end][1]
        pred_text  = context[start_char:end_char]

        predictions[ex["id"]] = pred_text
    return predictions

def compute_metrics(logits_tuple, examples):
    """
    logits_tuple => (start_logits, end_logits)
    examples => the raw examples with gold_text
    """
    start_logits, end_logits = logits_tuple

    # Convert to numpy if still Tensors
    if isinstance(start_logits, torch.Tensor):
        start_logits = start_logits.cpu().numpy()
    if isinstance(end_logits, torch.Tensor):
        end_logits = end_logits.cpu().numpy()

    # Postprocess predictions
    preds = postprocess_qa_predictions(examples, start_logits, end_logits)

    # Compute EM / F1
    total_em, total_f1, count = 0.0, 0.0, 0
    for ex in examples:
        gold = ex["gold_text"]
        pred = preds.get(ex["id"], "")
        total_em += exact_match(pred, gold)
        total_f1 += f1_score(pred, gold)
        count += 1

    em = 100.0 * total_em / count
    f1 = 100.0 * total_f1 / count
    return {"exact_match": em, "f1": f1}

############################################
# 4) Inference / Evaluation
############################################
print("[INFO] Running inference on each example...")
start_logits_list, end_logits_list = [], []

with torch.no_grad():
    for ex in examples_list:
        input_ids      = ex["input_ids"].unsqueeze(0).to(device)
        attention_mask = ex["attention_mask"].unsqueeze(0).to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        start_logits_list.append(outputs.start_logits.cpu().numpy())
        end_logits_list.append(outputs.end_logits.cpu().numpy())

start_logits_all = np.concatenate(start_logits_list, axis=0)
end_logits_all   = np.concatenate(end_logits_list,   axis=0)

print("[INFO] Computing final metrics...")
metrics = compute_metrics((start_logits_all, end_logits_all), examples_list)

############################################
# 5) Print results
############################################
print("\n===== TeWiki QA (Telugu) - MuRIL Evaluation =====")
print(f"Exact Match (EM): {metrics['exact_match']:.2f}")
print(f"F1 Score:         {metrics['f1']:.2f}")
print("================================================\n")

[INFO] Loading processed dataset...


  examples_list = torch.load(DATA_PATH)


[INFO] Loaded 947 total examples for evaluation.
[INFO] Loading fine-tuned model from ./final_muril_tel_answerable_v2...
[INFO] Running inference on each example...
[INFO] Computing final metrics...

===== TeWiki QA (Telugu) - MuRIL Evaluation =====
Exact Match (EM): 72.65
F1 Score:         86.74

