In [None]:

import os
import torch
import torch.nn as nn
import numpy as np
import shap
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt


# Inference device. The checkpoint is loaded onto this device further down.
device = torch.device("cpu")
print(f"Using device: {device}")

# Machine-specific default locations. Each one can be overridden with an
# environment variable, so the notebook runs on another machine without code edits.
model_path = os.environ.get(
    "PNEUMONIA_MODEL_PATH",
    r"C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\best_pneumonia_densenet121.pt",
)
dataset_dir = os.environ.get(
    "PNEUMONIA_DATASET_DIR",
    r"D:\datasets\chest_xray\val\PNEUMONIA",  # pneumonia images
)
save_dir = os.environ.get(
    "PNEUMONIA_SHAP_SAVE_DIR",
    r"C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1",
)

# Output folder for the SHAP plots. exist_ok makes re-running this cell safe.
os.makedirs(save_dir, exist_ok=True)
# Output-index order of the classifier head.
# NOTE(review): presumably matches the label order used at training time; confirm.
class_names = ['NORMAL', 'PNEUMONIA']


# DenseNet-121 backbone with a custom two-layer classifier head.
# The head must match the architecture the checkpoint was trained with,
# otherwise the (strict) load_state_dict below raises on mismatched keys/shapes.
model = models.densenet121(weights=None)  # random init; every weight is overwritten by the checkpoint
num_features = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, len(class_names))
)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. This avoids executing arbitrary pickled code from
# the checkpoint file and silences torch's FutureWarning about the default.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)  # keep parameters on the same device the inputs are sent to
model.eval()  # inference mode: disables Dropout so predictions are deterministic


# Standard ImageNet per-channel normalization constants, in RGB order.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def preprocess_input(images):
    """Turn a batch of channels-last pixel arrays into a normalized NCHW float tensor.

    Args:
        images: NumPy array shaped (N, H, W, C) with pixel values in [0, 255].
            The last axis must broadcast against the 3-element ``mean``/``std``.

    Returns:
        ``torch.float32`` tensor shaped (N, C, H, W). Values are scaled to
        [0, 1], then standardized per channel with ``mean`` and ``std``.
    """
    unit_range = images / 255.0
    standardized = (unit_range - mean) / std
    channels_first = np.transpose(standardized, (0, 3, 1, 2))
    return torch.tensor(channels_first).float()


def f(x):
    """Prediction function handed to the SHAP explainer.

    Args:
        x: NumPy batch of images in channels-last layout (N, H, W, C) with
            0-255 pixel values, as consumed by ``preprocess_input``.

    Returns:
        NumPy array of shape (N, num_classes) holding softmax class probabilities.
    """
    # Send the batch to the same device as the model. Without this, changing
    # `device` to a GPU fails with a device-mismatch error.
    x_tensor = preprocess_input(x).to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        out = model(x_tensor)
        probs = torch.softmax(out, dim=1)
    return probs.cpu().numpy()


# How many images to explain, and the input resolution fed to the model.
NUM_IMAGES = 50
IMAGE_SIZE = (224, 224)
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Sort so the selection of the first NUM_IMAGES files is reproducible across runs.
all_images = sorted(
    os.path.join(dataset_dir, name)
    for name in os.listdir(dataset_dir)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)
selected_images = all_images[:NUM_IMAGES]
if not selected_images:
    # Fail with a clear message instead of numpy's "need at least one array to stack".
    raise FileNotFoundError(f"No {IMAGE_EXTENSIONS} images found in {dataset_dir}")


def _load_rgb(path, size=IMAGE_SIZE):
    """Load one image as an RGB uint8 array resized to ``size``, closing the file afterwards."""
    with Image.open(path) as img:
        return np.array(img.convert("RGB").resize(size))


# Batch shaped (N, 224, 224, 3), uint8, channels-last.
X = np.stack([_load_rgb(p) for p in selected_images])


# Image masker whose hidden regions are filled by shap's "inpaint_telea" mode
# (inpainting) rather than a constant value. Together with the callable `f`,
# shap.Explainer selects a PartitionExplainer, as the run log shows.
masker = shap.maskers.Image("inpaint_telea", X[0].shape)
explainer = shap.Explainer(f, masker, output_names=class_names)

# Cost knobs. In the recorded CPU run each image took roughly 6-8 minutes.
MAX_EVALS = 5000
BATCH_SIZE = 20
# Set True to resume an interrupted run without recomputing already-saved plots.
SKIP_EXISTING = False

for idx, img_array in enumerate(X):
    filename = os.path.basename(selected_images[idx])
    # Note: keeps the source extension, e.g. "shap_x.jpeg.png", to match existing outputs.
    save_path = os.path.join(save_dir, f"shap_{filename}.png")
    if SKIP_EXISTING and os.path.exists(save_path):
        print(f"Skipping (exists): {save_path}")
        continue

    # Rank outputs by model score (descending) and keep the top 2.
    # With two classes this explains both NORMAL and PNEUMONIA.
    shap_values = explainer(
        img_array[np.newaxis, ...],
        max_evals=MAX_EVALS,
        batch_size=BATCH_SIZE,
        outputs=shap.Explanation.argsort.flip[:2],
    )

    # shap.image_plot creates its own figure. A plt.figure() opened beforehand
    # stays empty and is never closed, leaking one "0 Axes" figure per image.
    # Instead, take the figure image_plot drew into and close exactly that one.
    shap.image_plot(shap_values, show=False)
    fig = plt.gcf()
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print(f"Saved: {save_path}")


Using device: cpu


  model.load_state_dict(torch.load(model_path, map_location=device))


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [07:03, 423.67s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1946_bacteria_4874.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [06:24, 384.69s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1946_bacteria_4875.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [06:08, 368.94s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1947_bacteria_4876.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [05:54, 354.41s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1949_bacteria_4880.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [06:04, 364.12s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1950_bacteria_4881.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [06:14, 374.05s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1951_bacteria_4882.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [07:12, 432.91s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1952_bacteria_4883.jpeg.png


  0%|          | 0/4998 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [08:02, 482.68s/it]              


Saved: C:\Users\Ekaansh\OneDrive\Desktop\AB\research\SHAP\finding\shap_pred_1\shap_person1954_bacteria_4886.jpeg.png


<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>