### Import libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import json
import sys
import os
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
from transformers.trainer_utils import EvalPrediction
from seqeval.metrics import classification_report
import shap
import lime.lime_text
from nltk.tokenize import word_tokenize
import nltk
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Make the project's local source and script directories importable from this
# notebook. The membership guard keeps the cell idempotent: re-running it
# (common while iterating with autoreload) no longer piles duplicate entries
# onto sys.path.
for _local_dir in ('../src/', '../scripts/'):
    _abs_dir = os.path.abspath(_local_dir)
    if _abs_dir not in sys.path:
        sys.path.append(_abs_dir)

In [4]:
from interpret.interpret_ner_models import load_conll, logger, load_unlabeled_messages, apply_shap, apply_lime, predict_entities

2025-07-22 18:56:05,023 - INFO - NLTK punkt tokenizer downloaded successfully


In [5]:
# Locations of every artefact this notebook reads or writes. All of them are
# expressed relative to the directory one level above the notebook.
project_root = '..'

# Inputs: labelled CoNLL file, message CSV and saved model directory.
conll_file = f'{project_root}/conLL/amharic_ner.conll'
messages_file = f'{project_root}/data/cleaned_message.csv'
model_dir = f'{project_root}/models/amharic_ner_xlmr'

# Outputs: per-message results table and the generated markdown report.
output_file = f'{project_root}/data/interpretability_results.csv'
report_file = f'{project_root}/Task5_Interpretability_Report.md'

In [6]:
# Label mappings.
# The position of each tag in this tuple is its integer class id, so the order
# is significant. NOTE(review): presumably these ids must match the ones used
# when the model in `model_dir` was trained -- confirm before reordering.
entity_tags = (
    "O",
    "B-Product",
    "I-Product",
    "B-PRICE",
    "I-PRICE",
    "B-LOC",
    "I-LOC",
)

# Forward (tag -> id) and reverse (id -> tag) lookups built from one source.
label2id = {tag: class_id for class_id, tag in enumerate(entity_tags)}
id2label = dict(enumerate(entity_tags))
logger.info(f"Label mappings: {label2id}")

2025-07-22 18:56:05,574 - INFO - Label mappings: {'O': 0, 'B-Product': 1, 'I-Product': 2, 'B-PRICE': 3, 'I-PRICE': 4, 'B-LOC': 5, 'I-LOC': 6}


In [7]:
# Load model and tokenizer.
# Both are restored from `model_dir`; the model is then placed on the GPU when
# CUDA is available and on the CPU otherwise.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForTokenClassification.from_pretrained(model_dir).to(device)
logger.info(f"Loaded model and tokenizer from {model_dir}, using device: {device}")

2025-07-22 18:56:07,650 - INFO - Loaded model and tokenizer from ../models/amharic_ner_xlmr, using device: cpu


In [8]:
# Load datasets.
# The labelled sentences come from the CoNLL file; the trailing 20% of them
# (taken in file order, without shuffling) serve as the validation set.
dataset = load_conll(conll_file)
total_sentences = len(dataset)
validation_start = int(0.8 * total_sentences)
val_dataset = dataset.select(range(validation_start, total_sentences))

# Unlabelled messages that are only run through prediction / explanation.
messages = load_unlabeled_messages(messages_file)
logger.info(f"Validation set size: {len(val_dataset)}")

2025-07-22 18:56:08,284 - INFO - Loaded 200 sentences from ../conLL/amharic_ner.conll
2025-07-22 18:56:08,284 - INFO - Unique labels: {'O', 'B-LOC', 'I-Product', 'B-PRICE', 'I-PRICE', 'B-Product'}
2025-07-22 18:56:08,715 - INFO - Loaded 50 new messages from ../data/cleaned_message.csv
2025-07-22 18:56:08,723 - INFO - Validation set size: 40


In [9]:
# Analyze validation set.
# For every validation sentence: predict entity tags, flag it as "difficult"
# when the prediction disagrees with the gold tags or when any token is
# predicted with low confidence, and record one row in `results`. SHAP and
# LIME are slow, so they only run on the first few sentences.

# A token counts as low-confidence when its best class probability is below this.
LOW_CONFIDENCE_THRESHOLD = 0.9
# Number of leading validation sentences that get SHAP/LIME explanations.
MAX_EXPLAINED_EXAMPLES = 5


def token_confidences(logits):
    """Return each token's highest class probability as a 1-D float array.

    The values returned by `predict_entities` are named `logits`; raw logits
    are unbounded, so comparing them with a 0.9 threshold is not a confidence
    test. They are therefore converted with a row-wise softmax first.
    NOTE(review): if `predict_entities` already returns probabilities (every
    row lies in [0, 1] and sums to 1) the values are used unchanged -- confirm
    which of the two it actually returns.

    Args:
        logits: per-token class scores, expected as a 2-D array-like of shape
            (tokens, classes). Anything empty or not 2-D yields an empty array.

    Returns:
        numpy.ndarray with one maximum class probability per token.
    """
    scores = np.asarray(logits, dtype=float)
    if scores.ndim != 2 or scores.size == 0:
        return np.array([])

    already_probabilities = (
        np.all(scores >= 0.0)
        and np.all(scores <= 1.0)
        and np.allclose(scores.sum(axis=1), 1.0, atol=1e-3)
    )
    if already_probabilities:
        return scores.max(axis=1)

    # Numerically stable softmax: shift each row by its maximum before exp().
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    probabilities = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return probabilities.max(axis=1)


results = []
difficult_cases = []
for idx, example in enumerate(val_dataset):
    message = " ".join(example['tokens'])
    true_labels = example['ner_tags']
    tokens, pred_labels, logits = predict_entities(message, tokenizer, model, id2label, device)

    confidences = token_confidences(logits)
    # np.any() of an empty array is False, so an empty prediction is not
    # reported as low-confidence (it is still caught by the mismatch test).
    low_confidence = bool(np.any(confidences < LOW_CONFIDENCE_THRESHOLD))

    # A different length or any differing tag both count as a mismatch.
    mismatch = list(pred_labels) != list(true_labels)

    if mismatch or low_confidence:
        difficult_cases.append({
            "message": message,
            "true_labels": true_labels,
            "pred_labels": pred_labels,
            # Key name kept as "logits" because the report cell reads it.
            "logits": confidences.tolist()
        })

    if idx < MAX_EXPLAINED_EXAMPLES:
        shap_values = apply_shap(message, tokenizer, model, id2label, device)
        lime_weights = apply_lime(message, tokenizer, model, id2label, device)
    else:
        shap_values = []
        lime_weights = []

    results.append({
        "message": message,
        "true_labels": ", ".join(true_labels),
        "pred_labels": ", ".join(pred_labels),
        "shap_values": str(shap_values),
        "lime_weights": str(lime_weights)
    })
    logger.info(f"Processed validation example {idx + 1}/{len(val_dataset)}")

2025-07-22 18:56:16,669 - ERROR - Error in SHAP explanation for message: 360 100 ዋጋ፦ 300 ብር ውስን ፍሬ ነው ያለው አድራሻ ቢሮ ቁ S05S06 ...: tuple index out of range
2025-07-22 18:57:52,883 - INFO - Computed LIME explanation for message: 360 100 ዋጋ፦ 300 ብር ውስን ፍሬ ነው ያለው አድራሻ ቢሮ ቁ S05S06 ...
2025-07-22 18:57:52,884 - INFO - Processed validation example 1/40
2025-07-22 18:57:53,277 - ERROR - Error in SHAP explanation for message: 360 100 ዋጋ፦ 300 ብር ውስን ፍሬ ነው ያለው አድራሻ ቢሮ ቁ S05S06 ...: tuple index out of range
2025-07-22 18:59:14,613 - INFO - Computed LIME explanation for message: 360 100 ዋጋ፦ 300 ብር ውስን ፍሬ ነው ያለው አድራሻ ቢሮ ቁ S05S06 ...
2025-07-22 18:59:14,615 - INFO - Processed validation example 2/40
2025-07-22 18:59:15,301 - ERROR - Error in SHAP explanation for message: ዋጋ፦ 250 ብር ውስን ፍሬ ነው ያለው አድራሻ ቢሮ ቁ S05S06 09026607...: tuple index out of range
2025-07-22 19:00:38,197 - INFO - Computed LIME explanation for message: ዋጋ፦ 250 ብር ውስን ፍሬ ነው ያለው አድራሻ ቢሮ ቁ S05S06 09026607...
2025-07-22 19:00:38,198 - INF

In [10]:
# Analyze new messages.
# Predict entity tags for every unlabelled message and append one row per
# message to `results`; SHAP and LIME only run on the first few messages.

# Number of leading messages that get SHAP/LIME explanations.
MAX_EXPLAINED_MESSAGES = 5

# The validation cell leaves exactly one row per validation sentence in
# `results`. Dropping anything after those rows makes this cell idempotent:
# re-running it no longer duplicates every message row in the saved CSV and
# in the report.
del results[len(val_dataset):]

for idx, message in enumerate(messages):
    # Only the predicted tags are needed here; tokens and scores are unused.
    _tokens, pred_labels, _logits = predict_entities(message, tokenizer, model, id2label, device)

    if idx < MAX_EXPLAINED_MESSAGES:
        shap_values = apply_shap(message, tokenizer, model, id2label, device)
        lime_weights = apply_lime(message, tokenizer, model, id2label, device)
    else:
        shap_values = []
        lime_weights = []

    results.append({
        "message": message,
        # Unlabelled data: there are no gold tags to compare against.
        "true_labels": "N/A",
        "pred_labels": ", ".join(pred_labels),
        "shap_values": str(shap_values),
        "lime_weights": str(lime_weights)
    })
    logger.info(f"Processed new message {idx + 1}/{len(messages)}")

2025-07-22 19:03:53,147 - ERROR - Error in SHAP explanation for message: የሞተ ቆዳን እንዲሁም ቆሻሻን ለማፅዳት ተመራጭ ዋጋ፦ 200 ብር ውስን ፍሬ ነው...: tuple index out of range
2025-07-22 19:05:13,859 - INFO - Computed LIME explanation for message: የሞተ ቆዳን እንዲሁም ቆሻሻን ለማፅዳት ተመራጭ ዋጋ፦ 200 ብር ውስን ፍሬ ነው...
2025-07-22 19:05:13,859 - INFO - Processed new message 1/50
2025-07-22 19:05:14,252 - ERROR - Error in SHAP explanation for message: የሞተ ቆዳን እንዲሁም ቆሻሻን ለማፅዳት ተመራጭ ዋጋ፦ 200 ብር ውስን ፍሬ ነው...: tuple index out of range
2025-07-22 19:06:36,324 - INFO - Computed LIME explanation for message: የሞተ ቆዳን እንዲሁም ቆሻሻን ለማፅዳት ተመራጭ ዋጋ፦ 200 ብር ውስን ፍሬ ነው...
2025-07-22 19:06:36,325 - INFO - Processed new message 2/50
2025-07-22 19:06:37,041 - ERROR - Error in SHAP explanation for message: 3 ተቀያያሪ ብርሀን ያለው በቻርጅ የሚሰራ ዋጋ፦ 1600 ብር ውስን ፍሬ ነው ያ...: tuple index out of range
2025-07-22 19:07:55,030 - INFO - Computed LIME explanation for message: 3 ተቀያያሪ ብርሀን ያለው በቻርጅ የሚሰራ ዋጋ፦ 1600 ብር ውስን ፍሬ ነው ያ...
2025-07-22 19:07:55,033 - INFO - Processed 

In [11]:
# Save results.
# Every row collected in `results` becomes one CSV record; the index column is
# omitted and the file is written as UTF-8 so the Amharic text round-trips.
df_results = pd.DataFrame(data=results)
df_results.to_csv(path_or_buf=output_file, encoding='utf-8', index=False)
logger.info(f"Saved interpretability results to {output_file}")

2025-07-22 19:13:49,614 - INFO - Saved interpretability results to ../data/interpretability_results.csv


In [12]:
# Generate interpretability report.
# The markdown is assembled in memory first and written in one call, so an
# error while formatting a case cannot leave a truncated report on disk.
# NOTE(review): the F1 score, inference time and conclusion below are
# hard-coded text, not values computed in this notebook -- keep them in sync
# with the Task 4 results by hand.


def _has_explanation(serialized):
    """Return True when a stringified explainer output holds real content.

    `results` stores explainer outputs via str(), so an empty list, dict or
    tuple, or a None, arrives here as '[]', '{}', '()' or 'None'.
    """
    return str(serialized).strip() not in ('', '[]', '{}', '()', 'None')


report_parts = [
    "# Named Entity Recognition Model Interpretability Report\n\n",
    "This report provides insights into the behavior of the Amharic NER model using SHAP and LIME.\n\n",
    "## 1. Model Performance\n",
    "The fine-tuned `Davlan/afro-xlmr-base` model achieved an F1 score of 0.9187 (from Task 4) on the Amharic NER dataset.\n\n",
    "## 2. Difficult Cases from Validation Set\n\n",
]

if difficult_cases:
    report_parts.append(
        f"Identified {len(difficult_cases)} difficult cases (mismatches or low-confidence predictions, logits < 0.9):\n\n"
    )
    for i, case in enumerate(difficult_cases[:5]):  # Limit to 5 for brevity
        formatted_scores = ', '.join(f'{score:.2f}' for score in case['logits'])
        report_parts.extend([
            f"### Case {i+1}:\n",
            f"- **Message:** {case['message']}\n",
            f"- **True Labels:** {', '.join(case['true_labels'])}\n",
            f"- **Predicted Labels:** {', '.join(case['pred_labels'])}\n",
            f"- **Max Logits (Confidence):** {formatted_scores}\n\n",
        ])
else:
    report_parts.append("No difficult cases identified based on the defined criteria.\n\n")

report_parts.extend([
    "## 3. SHAP and LIME Explanations\n\n",
    "SHAP and LIME were applied to the first 5 validation examples and new messages to explain token contributions.\n\n",
])

# Keep a row when EITHER explainer produced something. Filtering on SHAP alone
# discards valid LIME output whenever SHAP fails, and a stringified None would
# wrongly pass as an explanation.
explained_results = [
    row for row in results
    if _has_explanation(row.get('shap_values')) or _has_explanation(row.get('lime_weights'))
]
if explained_results:
    for i, row in enumerate(explained_results[:5]):
        shap_text = row['shap_values'] if _has_explanation(row.get('shap_values')) else "not available"
        lime_text = row['lime_weights'] if _has_explanation(row.get('lime_weights')) else "not available"
        report_parts.extend([
            f"### Explanation for Message {i+1}:\n",
            f"- **Message:** {row['message']}\n",
            f"- **Predicted Labels:** {row['pred_labels']}\n",
            f"- **SHAP Values:** {shap_text}\n",
            f"- **LIME Weights:** {lime_text}\n\n",
        ])
else:
    report_parts.append("No SHAP/LIME explanations generated (possible errors during processing).\n\n")

report_parts.extend([
    "## 4. Recommendations\n",
    "- **Increase Dataset Size**: Label 300–500 additional messages to improve robustness, especially for rare labels like `I-LOC`.\n",
    "- **Handle Ambiguity**: Add training examples with ambiguous tokens (e.g., numbers like '360' as attributes or prices).\n",
    "- **Optimize Inference**: Quantize the model to reduce inference time (0.5105s/sample from Task 4).\n",
    "- **Custom Explainers**: Develop NER-specific explainers for better token-level insights.\n\n",
    "## 5. Conclusion\n",
    "SHAP and LIME provide valuable insights into the `Davlan/afro-xlmr-base` model’s decisions for Amharic NER. Addressing dataset limitations and ambiguities will further enhance performance. Results are saved in `interpretability_results.csv`, and logs are in `interpret_ner_log.log`.\n",
])

with open(report_file, 'w', encoding='utf-8') as f:
    f.write("".join(report_parts))

logger.info(f"Generated interpretability report to {report_file}")

2025-07-22 19:14:24,904 - INFO - Generated interpretability report to ../Task5_Interpretability_Report.md
