<a href="https://colab.research.google.com/github/rishindrasai/ExplainableAI_Assignment/blob/main/XAI_Lab_Assignment_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.inspection import permutation_importance
import shap
from lime.lime_text import LimeTextExplainer

# NOTE(review): `df` is not referenced by any later cell in this notebook —
# confirm this CSV is still needed. on_bad_lines='skip' silently drops
# malformed rows, so the reported shape may undercount the file.
df = pd.read_csv("/content/preprocessed_final.csv", engine='python', on_bad_lines='skip')
print("Dataset shape:", df.shape)
# Only a cell's final expression is rendered, so a bare `df.head()` here was
# evaluated and discarded; print the preview explicitly.
print(df.head())

# 20 Newsgroups, all categories (None). Headers, footers and quoted replies are
# stripped so the classifier has to rely on the message bodies.
categories = None
NEWSGROUPS_REMOVE = ('headers', 'footers', 'quotes')
train = fetch_20newsgroups(subset='train', categories=categories, remove=NEWSGROUPS_REMOVE)
test = fetch_20newsgroups(subset='test', categories=categories, remove=NEWSGROUPS_REMOVE)

# TF-IDF (English stop words removed, vocabulary capped at 10,000 terms) feeding
# a logistic regression. pipe.fit fits `vectorizer` and `clf` in place, and the
# cells below reuse both objects directly.
vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
clf = LogisticRegression(max_iter=1000, random_state=42)
pipe = make_pipeline(vectorizer, clf)
pipe.fit(train.data, train.target)

def predict_proba(texts):
    """Class-probability matrix from the fitted TF-IDF + logistic-regression pipeline.

    ``texts`` is a sequence of raw documents. The result is whatever
    ``pipe.predict_proba`` returns: one row per document, one column per class.
    NOTE(review): no cell in this notebook calls this wrapper; the LIME cell
    passes ``pipe.predict_proba`` directly.
    """
    class_probabilities = pipe.predict_proba(texts)
    return class_probabilities

feature_names = vectorizer.get_feature_names_out()

# Permutation importance: how much test accuracy drops when one TF-IDF column
# is shuffled. sklearn's permutation_importance needs a dense matrix and
# re-scores the model n_repeats times for every one of the 10,000 features.
# On the full test set (~7.5k docs -> ~600 MB float64 matrix) that is
# effectively unrunnable, so evaluate a seeded random subsample instead.
# Set PI_MAX_DOCS = None to use every test document.
PI_MAX_DOCS = 1000
PI_N_REPEATS = 10

# Drop non-string documents *together with* their labels. Slicing the target
# array afterwards (test.target[:len(kept_docs)]) would shift every label that
# follows the first dropped document.
processed_test_data = []
processed_test_target = []
for item, label in zip(test.data, test.target):
    if isinstance(item, str):
        processed_test_data.append(item)
        processed_test_target.append(label)
    else:
        print(f"Warning: Found non-string element in test.data: {item} (type: {type(item)}). Skipping.")
processed_test_target = np.asarray(processed_test_target)

if PI_MAX_DOCS is not None and len(processed_test_data) > PI_MAX_DOCS:
    pi_rows = np.random.default_rng(42).choice(len(processed_test_data), size=PI_MAX_DOCS, replace=False)
    processed_test_data = [processed_test_data[i] for i in pi_rows]
    processed_test_target = processed_test_target[pi_rows]

X_test = vectorizer.transform(processed_test_data).toarray()

# Score the fitted classifier directly on the vectorized test data.
# n_jobs=-1 spreads the per-feature work over all available cores.
result = permutation_importance(
    clf, X_test, processed_test_target,
    n_repeats=PI_N_REPEATS, random_state=42, scoring='accuracy', n_jobs=-1,
)

sorted_idx = result.importances_mean.argsort()[::-1]
top_n = 20
top_features = feature_names[sorted_idx[:top_n]]
top_importances = result.importances_mean[sorted_idx[:top_n]]

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(top_features[::-1], top_importances[::-1], color='skyblue')
ax.set_xlabel("Mean decrease in accuracy (Permutation Importance)")
ax.set_title(f"Top {top_n} Important Words by Permutation Importance")
fig.tight_layout()
plt.show()

# SHAP background sample: the first 100 training documents (raw text).
background = train.data[0:100]

def shap_predict(texts):
    """Return each document's probability for the class the pipeline predicts.

    Parameters
    ----------
    texts : sequence of str
        Raw documents, passed straight to ``pipe.predict_proba``.

    Returns
    -------
    numpy.ndarray, shape (len(texts),)
        The largest entry of each document's class-probability row.
    """
    probs = pipe.predict_proba(texts)
    # The predicted class is the arg-max of the probability row, so its
    # probability is simply the row maximum. This avoids a second model pass
    # (`pipe.predict`), the per-row Python loop, and indexing probability
    # columns by class *labels*, which only worked because the labels are 0..K-1.
    return probs.max(axis=1)

# SHAP. KernelExplainer cannot take raw strings as background data: shap only
# accepts numeric arrays, DataFrames or sparse matrices there. The plots below
# are labelled with TF-IDF feature names anyway, so explain the fitted
# LogisticRegression directly in TF-IDF space with LinearExplainer, which is
# exact for linear models. Note that these values explain the model's per-class
# decision score (log-odds), not the probability that `shap_predict` returns.
X_background = vectorizer.transform(background).toarray()
explainer = shap.LinearExplainer(clf, X_background)


def _shap_for_classes(values, class_idx):
    """Return, for each row i, the SHAP vector of output ``class_idx[i]``.

    shap has used two multi-output layouts: a list holding one
    (n_rows, n_features) array per class (older releases), or a single
    (n_rows, n_features, n_classes) array (newer releases). Both are handled.
    A 2-D array is taken to be single-output already and is returned unchanged.
    """
    class_idx = np.asarray(class_idx)
    rows = np.arange(len(class_idx))
    if isinstance(values, list):
        return np.stack([values[c][r] for r, c in zip(rows, class_idx)])
    values = np.asarray(values)
    if values.ndim == 2:
        return values
    # Advanced indices separated by a slice -> result shape (n_rows, n_features).
    return values[rows, :, class_idx]


doc_idx = 0
# Ensure the document for SHAP is a string
doc = [str(test.data[doc_idx])]
X_doc = vectorizer.transform(doc).toarray()
# Column index (into coef_ rows / SHAP outputs) of the class predicted for the document.
doc_class = clf.decision_function(X_doc).argmax(axis=1)
shap_values = _shap_for_classes(explainer.shap_values(X_doc), doc_class)

sample_test = test.data[:100]
# Ensure sample_test for global SHAP is a list of strings
processed_sample_test = []
for item in sample_test:
    if isinstance(item, str):
        processed_sample_test.append(item)
    else:
        print(f"Warning: Found non-string element in sample_test: {item} (type: {type(item)}). Skipping.")

X_sample = vectorizer.transform(processed_sample_test).toarray()
sample_classes = clf.decision_function(X_sample).argmax(axis=1)
shap_values_global = explainer.shap_values(X_sample)
# One row per document: the SHAP values for the class the model predicted for it.
shap_values_array = _shap_for_classes(shap_values_global, sample_classes)

shap.summary_plot(shap_values_array, features=X_sample, feature_names=feature_names, max_display=20, show=True)

# matplotlib=True renders a static figure, so shap.initjs() is not needed.
base_value = np.atleast_1d(explainer.expected_value)[doc_class[0]]
shap.force_plot(base_value, shap_values[0], X_doc[0], feature_names=feature_names, matplotlib=True)

# Human-readable labels, shared by LIME's output and explain_with_lime's report.
class_names = train.target_names
# LIME draws random perturbations of the text. Seed it so explanations are
# reproducible across runs (same seed as the classifier above).
explainer_lime = LimeTextExplainer(class_names=class_names, random_state=42)

def explain_with_lime(doc_index, num_features=10):
    """Run LIME on one test document and render the explanation inline.

    Parameters
    ----------
    doc_index : int
        Index into ``test.data`` / ``test.target``.
    num_features : int, default 10
        Number of words LIME includes in the explanation.

    Returns
    -------
    lime.explanation.Explanation
        Explanation of the class the pipeline predicts for the document.
    """
    doc_text = test.data[doc_index]
    # Ensure the document for LIME is a string
    if not isinstance(doc_text, str):
        print(f"Warning: Document at index {doc_index} is not a string (type: {type(doc_text)}). Converting to string for LIME.")
        doc_text = str(doc_text)

    # explain_instance defaults to labels=(1,), i.e. it always explains class
    # index 1 whatever the model predicts. top_labels=1 makes it explain the
    # predicted (highest-probability) class instead.
    exp = explainer_lime.explain_instance(doc_text, pipe.predict_proba, num_features=num_features, top_labels=1)
    predicted = exp.available_labels()[0]
    print(f"\nLIME explanation for document index {doc_index} (true class: {class_names[test.target[doc_index]]}):")
    print(f"Predicted class: {class_names[predicted]}")
    # `text` is a boolean flag (render the highlighted document or not), not
    # the document itself.
    exp.show_in_notebook(text=True)
    return exp

# LIME explanations for the first two test documents, in order.
exp1, exp2 = (explain_with_lime(idx) for idx in (0, 1))

# TODO: the comparative analysis is still an empty placeholder. Write up where
# the permutation-importance, SHAP and LIME results agree or disagree.
print("\nComparative Analysis Insights:")
print(" ")

Dataset shape: (535014, 6)


In [None]:
!pip install lime

Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 275.7/275.7 kB 7.4 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... done
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=3d17d749edd50400b7622d5aa0261b8c7bdeac34da2f5f9b7d457a0ae71de82e
  Stored in directory: /root/.cache/pip/wheels/e7/5d/0e/4b4fff9a47468fed5633211fb3b76d1db43fe806a17fb7486a
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1
