In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys
from PIL import Image
sys.path.append("..")
sys.path.append("../backbone")
from dataloader import create_dataloader
import torchvision.transforms as transforms
from transform import Transform
from data import RetinaDataset
import shap
# Release cached GPU memory left over from earlier runs in this kernel (no-op without CUDA).
torch.cuda.empty_cache()

# Set device to GPU if available, else use CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the pickled model and map its tensors onto `device`, so a checkpoint saved on a
# GPU can still be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On
# torch >= 2.6 `weights_only` defaults to True, which rejects a fully pickled nn.Module;
# pass weights_only=False there if loading fails.
model = torch.load('models/dn201.pth', map_location=device)
model.to(device)
# SHAP only runs inference: put layers such as dropout / batch-norm (if present) into
# evaluation mode so predictions do not depend on how masked images are batched.
model.eval()

print(device)
print(f"Current device: {torch.cuda.get_device_name(torch.cuda.current_device())}" if torch.cuda.is_available() else "Current device: CPU")

In [None]:
# Data / loader configuration for the test split.
data_dir = '../../../data/GT-main'
phase = 'test'
image_size = 384  # side length passed to the loader/transform (images are square)
num_labels = 21
batch_size, num_workers = 16, 4

# Per-channel normalization statistics (the standard ImageNet mean/std values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

test_dataloader = create_dataloader(
    data_dir=data_dir,
    phase=phase,
    size=image_size,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Test-phase preprocessing pipeline. It is built here, but the preview below loads the
# dataset with transform=None, so the previewed image has no transform applied.
transform = Transform(size=image_size, phase=phase)

# Preview the first sample of the test split.
dataset = RetinaDataset(data_dir=data_dir, split=phase, transform=None)
first_sample = dataset[0]
image = first_sample[0]

fig, ax = plt.subplots()
ax.imshow(image)
ax.axis('off')
plt.show()

In [None]:
def predict(x):
    """Score a batch of (masked) images for the SHAP explainer.

    Args:
        x: batch of images as a NumPy array or tensor, as handed over by the SHAP
            masker. It is moved to the global ``device`` before the forward pass.

    Returns:
        ``torch.sigmoid`` applied element-wise to the global ``model``'s output, as a
        tensor on ``device`` that does not track gradients.
    """
    # SHAP only needs forward passes; disabling autograd avoids building (and keeping
    # alive) a computation graph for every masked batch.
    with torch.no_grad():
        # as_tensor accepts both arrays and tensors without the copy/warning that
        # torch.tensor() incurs on tensor input.
        batch = torch.as_tensor(x, device=device)
        return torch.sigmoid(model(batch))

In [None]:
# Initialize the Shapley explainer on one batch from the test loader.
images, labels = next(iter(test_dataloader))
sample_idx = -1  # position within the batch of the image to explain
print("Shape of images:", images[sample_idx].shape)

# Masked regions are filled from a blurred copy of the image ("blur(kx,ky)" kernel size).
masker = shap.maskers.Image("blur(64,64)", shape=images[sample_idx].shape)
explainer = shap.Explainer(predict, masker)

print("Type of explainer:", type(explainer))

# max_evals caps the number of model evaluations (higher = finer attributions, slower);
# batch_size is how many masked images go through `predict` at once; `outputs` keeps
# only the 10 highest-scoring outputs for this image.
shap_values = explainer(
    images[sample_idx].unsqueeze(0),
    max_evals=100,
    batch_size=50,
    outputs=shap.Explanation.argsort.flip[:10],
)

In [None]:
# Inspect the Explanation arrays: explained input vs. SHAP attribution values.
data_shape, values_shape = shap_values.data.shape, shap_values.values.shape
print(data_shape, values_shape)
# Reordered view of the attributions for the first (only) explained sample.
print(np.transpose(shap_values.values, (0, 4, 2, 3, 1))[0])

In [None]:
# Inverse of the display-relevant parts of the test transform. Only the resize is a
# transform here; de-normalization is applied explicitly below.
inv_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # Resize back to original size
])

# `shap.image_plot` expects NHWC arrays: pixel_values of shape (n_samples, H, W, C) and
# shap_values as a list holding one (n_samples, H, W, C) array per explained output.
# NOTE(review): assumes shap_values.data is NCHW and shap_values.values is NCHW plus a
# trailing output axis, as the 5-axis transposes in the inspection cell suggest — TODO confirm.
pixels_nchw = inv_transform(torch.as_tensor(shap_values.data)).cpu().numpy()

# Undo channel normalization so the image renders with natural colors; clip to the
# [0, 1] range imshow expects for floats.
# NOTE(review): assumes the test-phase Transform normalizes with the `mean`/`std`
# from the config cell — confirm against transform.Transform.
channel_mean = np.asarray(mean, dtype=pixels_nchw.dtype).reshape(1, -1, 1, 1)
channel_std = np.asarray(std, dtype=pixels_nchw.dtype).reshape(1, -1, 1, 1)
pixels_nchw = np.clip(pixels_nchw * channel_std + channel_mean, 0.0, 1.0)

# NCHW -> NHWC for the image and for each per-output attribution map.
shap_data = pixels_nchw.transpose(0, 2, 3, 1)
shap_val = [vals.transpose(0, 2, 3, 1) for vals in np.moveaxis(shap_values.values, -1, 0)]

In [None]:
# Summarize the plotting inputs by shape instead of dumping every pixel/attribution value.
print("pixel_values shape:", np.shape(shap_data), "| shap_values shapes:", [np.shape(v) for v in shap_val])

In [None]:
# Overlay the per-output SHAP attributions on the explained image.
# true_labels must describe the image that was actually explained (images[-1] in the
# explainer cell), so take the matching row of `labels`.
shap.image_plot(
    shap_values=shap_val,
    pixel_values=shap_data,
    labels=shap_values.output_names,
    true_labels=[labels[-1]],
)

In [None]:
# # Convert Shapley values to numpy array
# shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]

# # Plot heat map for each class
# for class_idx in range(num_classes):
#     plt.figure()
#     plt.title('Shapley Heatmap for Class {}'.format(class_idx))
#     plt.imshow(shap_numpy[class_idx][0], cmap='hot', interpolation='nearest')
#     plt.axis('off')
#     plt.show()

In [None]:
# image_files = [f for f in os.listdir(data_dir) if f.endswith('.jpg') or f.endswith('.png')]

# # Define a simple dataset class to load images
# class SimpleImageDataset(torch.utils.data.Dataset):
#     def __init__(self, data_dir, image_files, transform=None):
#         self.data_dir = data_dir
#         self.image_files = image_files
#         self.transform = transform

#     def __len__(self):
#         return len(self.image_files)

#     def __getitem__(self, idx):
#         img_name = os.path.join(self.data_dir, self.image_files[idx])
#         image = Image.open(img_name).convert('RGB')
#         if self.transform:
#             image = self.transform(image)
#         return image

# # Create test dataset and dataloader
# test_dataset = SimpleImageDataset(data_dir, image_files)
# test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)