In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import sys
import os
from PIL import Image
sys.path.append("..")
sys.path.append("../backbone")
from backbones import DenseNet201
from dataloader import create_dataloader
import torchvision.transforms as transforms
from transform import Transform
from data import RetinaDataset
import shap
# Release cached GPU memory left over from a previous run of this notebook.
torch.cuda.empty_cache()

# Set device to GPU if available, else use CPU. (Previously hard-coded to
# "cuda:0", which failed on machines without a GPU despite this comment.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# checkpoints. On newer PyTorch releases where weights_only defaults to True,
# loading a whole pickled nn.Module may additionally need weights_only=False.
model = torch.load('models/dn1.pth', map_location=device)
model.to(device)
# Inference only: switch any BatchNorm/Dropout layers to eval behaviour so
# predictions are deterministic and the many SHAP forward passes do not
# update BatchNorm running statistics.
model.eval()
print(device)
print(f"Current device: {torch.cuda.get_device_name(torch.cuda.current_device())}" if torch.cuda.is_available() else "Current device: CPU")

In [None]:
data_dir = '../../../data/GT-main'
batch_size = 16        # dataloader batch size
image_size = 384
num_labels = 21
num_workers = 4
idx = 10               # which sample to display / explain (dataset index and position in the first batch)
phase = 'test'
# Normalisation statistics (the standard ImageNet values); also used later to
# undo normalisation for plotting. Presumably these match what Transform applies.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# One value per class, aligned with class_names (presumably per-class decision
# thresholds; not used by the visible cells).
threshold = [0.65, 0.7, 0.49, 0.43, 0.66, 0.78, 0.26, 0.21, 0.38, 0.82, 0.64, 0.9, 0.38, 0.59, 0.86, 0.31, 0.51, 0.69, 0.91, 0.23, 0.37]
class_names = ["DR","NORMAL","MH","ODC","TSLN","ARMD","DN","MYA","BRVO","ODP","CRVO","CNV","RS","ODE","LS","CSR","HTR","ASR","CRS","OTHER","RB"]

# The per-class lists must stay aligned with the number of model outputs.
assert len(class_names) == num_labels, "class_names must have one entry per label"
assert len(threshold) == num_labels, "threshold must have one entry per label"

# Use the configured split instead of repeating the 'test' literal.
test_dataloader = create_dataloader(data_dir=data_dir, batch_size=batch_size, num_workers=num_workers, size=image_size, phase=phase)

In [None]:
# NOTE(review): `transform` is not used by any visible cell (the dataset below
# is built with transform=None); kept in case other cells rely on it.
transform = Transform(size=image_size, phase=phase)
# Untransformed dataset, used here only to show the original image.
dataset = RetinaDataset(data_dir=data_dir, split=phase, transform=None)
# Fetch the sample once; indexing dataset[idx] twice invoked __getitem__
# (and presumably the image load) two times for the same sample.
sample = dataset[idx]
image, label = sample[0], sample[1]
print(label)
plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
def predict(x):
    """Run the global ``model`` on a channels-last image batch and return sigmoid scores.

    Used both for a direct single-image prediction and as the model callback
    passed to ``shap.Explainer``.

    Args:
        x: array-like or torch tensor of shape (N, H, W, C), channels-last as
           the permute below requires (e.g. numpy batches from the SHAP masker,
           or a tensor built in this notebook).

    Returns:
        torch.Tensor on ``device`` holding an element-wise sigmoid of the model
        output (presumably shape (N, num_labels)). It carries no autograd graph.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() raises for
    # tensor inputs; casting to the model's parameter dtype keeps e.g. float64
    # masker output from mismatching the weights.
    param_dtype = next(model.parameters()).dtype
    batch = torch.as_tensor(x, dtype=param_dtype, device=device)
    # (N, H, W, C) -> (N, C, H, W) for the convolutional model.
    batch = batch.permute(0, 3, 1, 2)
    # Inference only: without no_grad every SHAP evaluation batch would build
    # and keep an autograd graph of the whole network.
    with torch.no_grad():
        return torch.sigmoid(model(batch))

In [None]:
# Grab only the first batch from the test loader.
images, labels = next(iter(test_dataloader))
# Sample `idx` of that batch, moved to channels-last (presumably from (C, H, W)),
# the layout predict() and the SHAP image masker work with.
image = images[idx].permute(1, 2, 0)
print("Shape of images:", image.shape)
# Prepend a batch axis for a single-image forward pass.
output = predict(image[None])
print(output)

In [6]:
# Initialize Shapley Explainer
# SHAP budget: max_evals bounds the number of masked model evaluations and
# dominates run time (an earlier run of this cell was interrupted).
shap_max_evals = 10000
# Forward-pass batch size used by SHAP; unrelated to the dataloader's batch_size.
shap_batch_size = 50

# Masked-out image regions are replaced with a blurred version of the image.
masker = shap.maskers.Image("blur(64,64)", shape=image.shape)
explainer = shap.Explainer(predict, masker, output_names=class_names)

print("Type of explainer:", type(explainer))

try:
    # Explain the 4 highest-scoring outputs; keep in sync with the top-4
    # summary printed after the plot.
    shap_values = explainer(
        image.unsqueeze(0),
        max_evals=shap_max_evals,
        batch_size=shap_batch_size,
        outputs=shap.Explanation.argsort.flip[:4],
    )
finally:
    # Free cached GPU memory even if the long explanation is interrupted.
    torch.cuda.empty_cache()

KeyboardInterrupt: 

In [None]:
# Sanity-check the shapes of the explained input and of its attributions.
data_shape, values_shape = shap_values.data.shape, shap_values.values.shape
print(data_shape, values_shape)

In [None]:
# Undo the Normalize(mean, std) preprocessing so the explained image is drawn
# with its original colours. Normalize(m', s') with m' = -mean/std and
# s' = 1/std computes (x - m') / s' = x * std + mean.
inv_transform = transforms.Compose([
    # Normalize works on channels-first data: (N, H, W, C) -> (N, C, H, W).
    transforms.Lambda(lambda x: x.permute(0, 3, 1, 2)),
    transforms.Normalize(
        mean=(-1 * np.array(mean) / np.array(std)).tolist(),
        std=(1 / np.array(std)).tolist(),
    ),
    # Back to channels-last for shap.image_plot: (N, C, H, W) -> (N, H, W, C).
    transforms.Lambda(lambda x: x.permute(0, 2, 3, 1)),
])

# Un-normalised pixels of the single explained image (channels-last).
# Assumes shap_values.data is the torch tensor that was passed to the explainer.
shap_data = inv_transform(shap_values.data).cpu().numpy()[0]
# The explained outputs sit on the last axis of shap_values.values[0]; split
# them into one attribution array per output, the list form image_plot takes.
shap_val = list(np.moveaxis(shap_values.values[0], -1, 0))

In [None]:
# Summarise instead of dumping whole arrays: shapes and value ranges are what
# is useful for checking the de-normalised image and its attributions.
print(f"pixel data: shape {shap_data.shape}, range [{shap_data.min():.4g}, {shap_data.max():.4g}]")
for output_pos, attribution in enumerate(shap_val):
    print(f"output {output_pos}: shape {attribution.shape}, range [{attribution.min():.4g}, {attribution.max():.4g}]")

In [None]:
# Plot the un-normalised image next to the attribution map of each explained output.
# NOTE(review): true_labels receives the raw label entry of sample idx
# (presumably a multi-hot tensor), not class names; mapping it through
# class_names would give a readable row title.
shap.image_plot(
    shap_values=shap_val,
    pixel_values=shap_data,
    labels=shap_values.output_names,
    true_labels=[labels[idx][:]],
)

# Top-k predicted classes for the same image; k matches the number of outputs
# explained by the SHAP cell.
top_k = 4
sorted_output = torch.sort(output[0], descending=True)
top_preds = sorted_output.values[:top_k].detach().cpu().numpy()
top_indices = sorted_output.indices[:top_k].cpu().numpy()
formatted_preds = ', '.join(f'{pred:.4f}' for pred in top_preds)

# class_id (not the global idx) indexes class_names here.
print(f'Top Predictions: {formatted_preds}\nIndices: {top_indices}\nClasses: {[class_names[class_id] for class_id in top_indices]}')