1. Imports and Initial Setup

In [52]:
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
import numpy as np
from lime.lime_text import LimeTextExplainer
import shap
import matplotlib.pyplot as plt

2. Model Loading and Setup

In [53]:

# Load your saved best model
model_path = "./best_model_xlm-roberta-base"
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Ensure model is in evaluation mode
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

label_list = list(model.config.id2label.values())


SHAP Implementation Section

In [None]:

def shap_explanation(text, model, tokenizer):
    """Simplified SHAP explanation for space-tokenized text"""
    print("\nGenerating SHAP-style explanation...")

    words = space_tokenizer(text)
    num_words = len(words)

    print("Feature importance by position:")
    print("{:15} {:10} {}".format("Word", "Position", "Importance"))
    print("-" * 40)

    # Create a simple positional importance analysis
    for i, word in enumerate(words):
        # Create modified versions of the text
        original_pred = predict_ner(text)[1][i]

        # Create text with this word masked
        masked_words = words.copy()
        masked_words[i] = "[MASK]"
        masked_text = " ".join(masked_words)
        masked_pred = predict_ner(masked_text)[1][i]

        # Calculate importance as prediction change
        importance = abs(original_pred - masked_pred)

        print("{:15} {:10} {:.2f}".format(
            word,
            i,
            importance
        ))

    print("\nKey:")
    print("Importance = How much prediction changes when word is masked")


2. LIME Implementation Section

In [63]:

def lime_explanation(text, model, tokenizer, target_label_idx):
    """Simplified LIME explanation for space-tokenized text"""
    print(f"\nGenerating LIME-style explanation for {label_list[target_label_idx]}...")

    words = space_tokenizer(text)
    num_words = len(words)

    # Create neighborhood of examples by removing one word at a time
    neighborhood = []
    for i in range(num_words):
        modified_words = words.copy()
        modified_words[i] = "[MASK]"
        neighborhood.append(" ".join(modified_words))

    # Get predictions for all examples
    predictions = []
    for example in [text] + neighborhood:
        _, preds = predict_ner(example)
        predictions.append(preds)

    # Calculate importance for each position
    importance_scores = []
    original_preds = predictions[0]
    for i in range(num_words):
        # Focus on the target label's prediction at this position
        original_score = (original_preds[i] == target_label_idx)
        modified_score = (predictions[i+1][i] == target_label_idx)
        importance = abs(original_score - modified_score)
        importance_scores.append((words[i], i, importance))

    # Sort by importance
    importance_scores.sort(key=lambda x: x[2], reverse=True)

    print("\nTop influential words:")
    print("{:15} {:10} {}".format("Word", "Position", "Influence"))
    print("-" * 40)
    for word, pos, imp in importance_scores[:min(5, num_words)]:
        print("{:15} {:10} {:.2f}".format(word, pos, imp))



3. Custom Tokenizer Setup

In [56]:
# Custom space-based tokenizer
def space_tokenizer(text):
    return text.split()

# Wrapper to make compatible with transformers and SHAP
class SpaceTokenizerWrapper:
    def tokenize(self, text):
        return space_tokenizer(text)

    def convert_tokens_to_ids(self, tokens):
        return [i for i in range(len(tokens))]

    def convert_ids_to_tokens(self, ids):
        return [f"token_{i}" for i in ids]

    # Add a __call__ method for SHAP compatibility
    def __call__(self, text, **kwargs):
        # This method should return an object similar to a Hugging Face tokenizer output
        # For the purpose of SHAP's Text masker, simply returning the list of words might suffice
        # or a structure that mimics the tokenizer's output.
        # Let's try returning a list of words for now, as that's what text.split() gives.
        return space_tokenizer(text)


space_tokenizer_wrapper = SpaceTokenizerWrapper()

4. Label Definitions

In [57]:
# Define NER label categories
label_list = [
    "O",
    "B-Product", "I-Product",
    "B-PRICE", "I-PRICE",
    "B-LOC", "I-LOC",
    "B-CONTACT", "I-CONTACT"
]

5. Core Prediction Functions

In [58]:
def predict_ner(text):
    """Predict NER tags using space tokenization"""
    words = space_tokenizer(text)
    input_ids = torch.tensor([[i for i in range(len(words))]]).to(device)
    attention_mask = torch.tensor([[1]*len(words)]).to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
    return words, predictions

def visualize_predictions(text):
    """Color-coded visualization of predictions"""
    words, preds = predict_ner(text)
    colors = {
        "Product": "\033[91m",  # Red
        "PRICE": "\033[92m",    # Green
        "LOC": "\033[94m",      # Blue
        "CONTACT": "\033[93m",  # Yellow
    }
    reset_color = "\033[0m"

    print("\nPrediction Visualization:")
    for word, pred in zip(words, preds):
        label = label_list[pred]
        if label == "O":
            print(word, end=" ")
        else:
            ent_type = label.split("-")[-1]
            color = colors.get(ent_type, "")
            print(f"{color}{word}{reset_color}", end=" ")
    print("\n")

6. Evaluation Functions

In [59]:
def analyze_errors(text, true_labels):
    """Compare predictions with ground truth"""
    words, preds = predict_ner(text)
    true_labels = true_labels.split()

    print("\nError Analysis:")
    print("{:20} {:15} {:15}".format("Word", "True", "Predicted"))
    print("-" * 50)

    errors = []
    for word, true, pred in zip(words, true_labels, preds):
        pred_label = label_list[pred]
        if true != pred_label:
            errors.append((word, true, pred_label))
        print("{:20} {:15} {:15}".format(word, true, pred_label))

    print("\nSummary:")
    print(f"Total words: {len(words)}")
    print(f"Errors: {len(errors)}")
    print(f"Accuracy: {(len(words)-len(errors))/len(words):.2%}")

    return errors

7. Test Cases Definition

In [60]:
test_cases = [
    {
        "text": "BARDEFU 2 IN 1 Multipurpose juicer ኳሊቲ የጁስ መፍጫ ዋጋ 6800 ብር",
        "true_labels": "B-Product I-Product I-Product I-Product I-Product I-Product O B-Product I-Product O B-PRICE I-PRICE"
    },
    {
        "text": "8000Watt ምላጮቹ ጠንካራ የሆኑ ለቤት ዋጋ 6800 ብር",
        "true_labels": "B-Product I-Product I-Product O O O B-PRICE I-PRICE I-PRICE"
    },
    {
        "text": "አድራሻ ቁ1 መገናኛ ታሜ ጋስ ህንፃ ጎን ስሪ ኤም ሲቲ ሞል 0909522840",
        "true_labels": "B-LOC I-LOC I-LOC I-LOC I-LOC I-LOC O I-LOC I-LOC I-LOC B-CONTACT"
    }
]

8. Test Execution

In [61]:
for i, test_case in enumerate(test_cases, 1):
    print(f"\n{'='*50}")
    print(f"TEST CASE {i}: {test_case['text']}")
    print(f"{'='*50}")

    # 1. Basic prediction
    visualize_predictions(test_case['text'])

    # 2. Error analysis
    if 'true_labels' in test_case:
        errors = analyze_errors(test_case['text'], test_case['true_labels'])

        if errors:
            error_word, true_label, pred_label = errors[0]
            print(f"\nFirst error: '{error_word}' (True: {true_label}, Pred: {pred_label})")

            # Add SHAP and LIME explanations for errors
            try:
                # SHAP explanation
                shap_explanation(test_case['text'], model, tokenizer)

                # LIME explanation for the true label
                true_label_idx = label_list.index(true_label)
                lime_explanation(test_case['text'], model, tokenizer, true_label_idx)
            except Exception as e:
                print(f"Interpretability failed: {str(e)}")

    # 3. General explanations even without errors
    try:
        # SHAP explanation for Product entities
        shap_explanation(test_case['text'], model, tokenizer)

        # LIME explanation for first entity type
        lime_explanation(test_case['text'], model, tokenizer, 1)  # 1 = B-Product
    except Exception as e:
        print(f"General interpretability failed: {str(e)}")

    print(f"\n{'='*50}")
    print(f"COMPLETED TEST CASE {i}")
    print(f"{'='*50}\n")


TEST CASE 1: BARDEFU 2 IN 1 Multipurpose juicer ኳሊቲ የጁስ መፍጫ ዋጋ 6800 ብር

Prediction Visualization:
[93mBARDEFU[0m [93m2[0m [93mIN[0m [93m1[0m [93mMultipurpose[0m [93mjuicer[0m [93mኳሊቲ[0m [93mየጁስ[0m [93mመፍጫ[0m [93mዋጋ[0m [93m6800[0m [93mብር[0m 


Error Analysis:
Word                 True            Predicted      
--------------------------------------------------
BARDEFU              B-Product       I-CONTACT      
2                    I-Product       I-CONTACT      
IN                   I-Product       I-CONTACT      
1                    I-Product       I-CONTACT      
Multipurpose         I-Product       I-CONTACT      
juicer               I-Product       I-CONTACT      
ኳሊቲ                  O               I-CONTACT      
የጁስ                  B-Product       I-CONTACT      
መፍጫ                  I-Product       I-CONTACT      
ዋጋ                   O               I-CONTACT      
6800                 B-PRICE         I-CONTACT      
ብር                   I-PRICE 