In [None]:
from pipelines.data_preparation_pipeline import DataPreparationPipeline
from fake_news_classifier import FakeNewsClassifier
from utils.utils import load_config, set_device
import shap
import lime
import numpy as np
import torch
from transformers import BertTokenizer
from data.testing_dataset import TestingDataset
from torch.utils.data import DataLoader

In [None]:
# Build the train / test / validation splits from the data-preparation config.
# NOTE(review): the unpacking order assumes run() returns (train, test, val) — confirm.
data_prep_config_path = "configs/pipelines_config/data_preparation_config.json"
data_preparation_pipeline = DataPreparationPipeline(data_prep_config_path)
train_data, test_data, val_data = data_preparation_pipeline.run()

# Restore the best checkpoint. The second constructor argument is presumably the
# number of output classes (run_shap's label_map has 7 entries) — confirm against
# FakeNewsClassifier.
classifier_config_path = "configs/classifier_config.json"
fake_news_classifier = FakeNewsClassifier(classifier_config_path, 7)
model = fake_news_classifier.load_pretrained("models/best_model.pth")

In [None]:
config = load_config("configs/pipelines_config/explainability_config.json")
device = set_device()

# Unwrap to the classifier module that predict_proba calls, and put it on the same
# device that predict_proba moves the input tensors to (a no-op if it is already there).
# NOTE: this rebinds ``model``, so the cell is not idempotent — re-run the previous
# cell before re-running this one, otherwise ``.classifier`` is looked up on the
# already-unwrapped module.
model = model.classifier.to(device)

# Read from the explainability config. run_shap does not use these (it samples a
# fixed number of texts per class) — TODO wire them in or drop them.
num_background = config["num_background"]
num_explain_samples = config["num_explain_samples"]

# Validation texts and their integer class ids, used by run_shap for sampling.
contents = val_data["content"].to_list()
labels = val_data["label"].to_list()
# NOTE(review): presumably must match the tokenizer the classifier was trained with — confirm.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
def predict_proba(texts, batch_size=16):
    """Return softmax class probabilities for a sequence of raw texts.

    Serves as the black-box prediction function for SHAP. Because it returns a
    NumPy array, it is also directly usable as LIME's ``classifier_fn``. It relies on
    the notebook-level ``model``, ``tokenizer`` and ``device`` defined in earlier cells.

    Args:
        texts: Iterable of strings. SHAP's text masker passes a NumPy array of
            strings; every item is converted to a plain ``str`` first.
        batch_size: Number of texts scored per forward pass.

    Returns:
        np.ndarray of shape ``(len(texts), n_classes)`` with one probability row per
        input text, in input order. An empty input yields an array of shape ``(0, 0)``.
    """
    texts = [str(text) for text in texts]
    test_dataset = TestingDataset(texts, tokenizer)
    # SHAP invokes this function many times on small batches. Spawning worker
    # processes on every call costs far more than it saves, so load in-process.
    # Pinned memory only helps when copying to a CUDA device.
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
    )

    model.eval()
    probability_chunks = []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            logits = model(input_ids, attention_mask, token_type_ids)
            probability_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())

    if not probability_chunks:
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(probability_chunks, axis=0)

In [None]:
def run_shap(samples_per_class=10, seed=None):
    """Explain the classifier on a class-balanced sample of validation texts with SHAP.

    Draws up to ``samples_per_class`` texts (without replacement) for each label in
    ``label_map``. It explains them with ``shap.Explainer`` over ``predict_proba``,
    using a text masker built from the BERT tokenizer, and renders SHAP's text plot.
    It relies on the notebook-level ``contents``, ``labels``, ``tokenizer`` and
    ``predict_proba``.

    Args:
        samples_per_class: Maximum number of texts drawn per class (all of them if
            the class has fewer).
        seed: Optional seed for a dedicated NumPy ``Generator``. ``None`` keeps
            drawing from NumPy's global random state, as before.

    Returns:
        The SHAP explanation computed for the sampled texts.

    Raises:
        KeyError: if ``labels`` contains a value that is not a key of ``label_map``.
    """
    print("Starting SHAP explainability analysis...")

    label_map = {
        0: "reliable",
        1: "bias",
        2: "conspiracy",
        3: "fake",
        4: "rumor",
        5: "unreliable",
        6: "other",
    }

    # ``seed=None`` keeps using the global state, so a notebook-level np.random.seed
    # still governs sampling. A seed gives a reproducible, isolated generator.
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Group the texts by class id, preserving label_map order.
    class_contents = {class_id: [] for class_id in label_map}
    for content, label in zip(contents, labels):
        class_contents[label].append(content)

    # Class-balanced sample without replacement. Sampling indices rather than the
    # texts avoids NumPy copying a whole class into a fixed-width unicode array, and
    # keeps the sampled items plain ``str``.
    sampled_contents = []
    for class_texts in class_contents.values():
        if not class_texts:
            continue  # nothing to draw from this class
        picked = rng.choice(
            len(class_texts),
            size=min(samples_per_class, len(class_texts)),
            replace=False,
        )
        sampled_contents.extend(class_texts[index] for index in picked)

    masker = shap.maskers.Text(tokenizer)
    output_names = list(label_map.values())
    explainer = shap.Explainer(predict_proba, masker, output_names=output_names)
    shap_values = explainer(sampled_contents)

    print("SHAP values computed successfully.")

    # shap.plots.text renders the HTML itself and returns None, so don't print() it.
    shap.plots.text(shap_values)

    return shap_values


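# Superseded draft of run_shap (commented out, never executed): it additionally keeps
# only samples whose max |SHAP value| exceeds the 90th percentile before plotting.
# Kept for reference only.
# def run_shap():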
#     print("Starting SHAP explainability analysis...")

#     # Assuming 'contents' and 'labels' (the class labels) are defined
#     # 'contents' should be the text data, and 'labels' should be the corresponding class labels
#     label_map = {
#         0: "reliable",
#         1: "bias",
#         2: "conspiracy",
#         3: "fake",
#         4: "rumor",
#         5: "unreliable",
#         6: "other",
#     }

#     # Group the contents by class
#     class_contents = {i: [] for i in range(len(label_map))}
#     for content, label in zip(contents, labels):
#         class_contents[label].append(content)

#     # Sample 10 contents per class
#     sampled_contents = []
#     for class_label, class_data in class_contents.items():
#         # Sample 10 items per class, or all if fewer than 10
#         sampled_contents.extend(
#             np.random.choice(class_data, size=min(10, len(class_data)), replace=False)
#         )

#     # The number of explain samples will be the total number of sampled contents
#     num_explain_samples = len(sampled_contents)

#     # Create a masker for text
#     masker = shap.maskers.Text(tokenizer)

#     # Initialize SHAP Explainer
#     output_names = [label_map[i] for i in range(len(label_map))]
#     explainer = shap.Explainer(predict_proba, masker, output_names=output_names)
#     shap_values = explainer(sampled_contents)

#     print("SHAP values computed successfully.")

#     # Calculate the absolute SHAP values for each token
#     shap_values_array = np.array([np.abs(sv.values).max() for sv in shap_values])

#     # Calculate the threshold based on the 90th percentile (you can change this to 95, etc.)
#     threshold = np.percentile(shap_values_array, 90)
#     print(f"Threshold (90th percentile): {threshold}")

#     # Create filtered shap_values based on the threshold
#     filtered_shap_values = shap.Explanation(
#         values=[sv.values for sv in shap_values if np.abs(sv.values).max() > threshold],
#         base_values=[
#             sv.base_values for sv in shap_values if np.abs(sv.values).max() > threshold
#         ],
#         data=[sv.data for sv in shap_values if np.abs(sv.values).max() > threshold],
#         feature_names=shap_values[0].feature_names,
#         output_names=shap_values[0].output_names,
#     )

#     # Plot the filtered SHAP values using the text plot
#     print(shap.plots.text(filtered_shap_values))

#     return shap_values

In [None]:
# Run the class-balanced SHAP analysis; the returned explanation is kept in
# ``shap_values`` for further inspection.
shap_values = run_shap()