In [51]:
import numpy as np

def kernel(X, Y, sigma=0.25):
    """Gaussian (RBF) similarity between every point of ``X`` and every point of ``Y``.

    Entry ``[i, j]`` of the result is ``exp(-||X[i] - Y[j]||**2 / sigma**2)``,
    where ``||.||`` is the Euclidean norm taken over all elements of the
    difference ``X[i] - Y[j]``.

    Parameters
    ----------
    X : array-like
        Points indexed along the first axis.
    Y : array-like
        Points indexed along the first axis; ``Y[j]`` must broadcast
        against ``X[i]``.
    sigma : float, default 0.25
        Kernel width. Smaller values make the similarity decay faster.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(X.shape[0], Y.shape[0])``.

    Notes
    -----
    All pairwise differences are materialised at once, so peak memory is
    proportional to ``X.shape[0] * Y.shape[0] * (elements per point)``.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)

    # Pair every X[i] with every Y[j] in one broadcast:
    # (n, 1, ...) - (1, m, ...) -> (n, m, ...).
    diff = X[:, np.newaxis] - Y[np.newaxis, :]

    # Work with the squared distance directly (no sqrt followed by **2) and
    # square in floating point so integer inputs cannot overflow.
    point_axes = tuple(range(2, diff.ndim))
    sq_dist = np.sum(np.abs(diff).astype(float) ** 2, axis=point_axes)

    return np.exp(-sq_dist / sigma ** 2)

In [52]:
from skimage.segmentation import slic
import numpy as np
import matplotlib.pyplot as plt
import sys
#np.set_printoptions(threshold=sys.maxsize)
import copy
from tqdm.auto import tqdm

def create_perturbed_image_and_weights(img, kernel_size=4, max_dist=200, ratio = 0.25, num_samples=1000, kernel_width = 0.25, fill_color=None):
    """Generate LIME perturbation samples and their locality weights for one image.

    The image is split into superpixels with SLIC. Each sample is a random
    on/off assignment over those superpixels; every "off" superpixel is painted
    with ``fill_color``. Sample 0 keeps all superpixels on, i.e. it is the
    unperturbed image and serves as the reference point for the weights.

    Parameters
    ----------
    img : torch.Tensor
        Channel-first image; it is permuted with ``(1, 2, 0)`` before
        segmentation. Assumed to be a 3-channel CPU tensor -- TODO confirm.
    kernel_size, max_dist, ratio
        Accepted for interface compatibility; not used by this function.
    num_samples : int
        Number of perturbed samples (rows of ``z_prime``), including sample 0.
    kernel_width : float
        Width of the exponential kernel that turns distances into weights.
    fill_color : sequence of float, optional
        Per-channel value painted over hidden superpixels. Defaults to the
        ImageNet channel means scaled to 0-255 and truncated to integers.
        NOTE(review): that default is in 0-255 pixel scale; if ``img`` has been
        mean/std-normalised, pass a colour in the same scale as ``img``
        (e.g. zeros) -- confirm against the caller.

    Returns
    -------
    tuple
        ``(segments, z_prime, weights, perturbed_images)`` where ``segments``
        is the SLIC label map, ``z_prime`` is the ``(num_samples, n_features)``
        binary on/off matrix, ``weights`` has shape ``(1, num_samples)`` and
        ``perturbed_images`` is the stacked channel-first batch.
    """
    permuted_img = img.permute(1, 2, 0)
    segments = slic(permuted_img, n_segments=100, compactness=20, start_label=0)

    n_features = segments.max() + 1
    z_prime = np.random.randint(0, 2, size=(num_samples, n_features))
    z_prime[0, :] = 1  # sample 0 is the original, unperturbed image

    if fill_color is None:
        fill_color = (torch.tensor([0.485, 0.456, 0.406]) * 255).numpy().astype(np.uint8)
    # Converted once, outside the sampling loop; a per-channel colour broadcasts
    # over the selected pixels, so no image-sized constant (and no fixed
    # 224x224 assumption) is needed.
    fill = torch.as_tensor(np.asarray(fill_color), dtype=permuted_img.dtype, device=permuted_img.device)

    # create z with z_prime
    perturbed_images = []
    for space in tqdm(z_prime):
        temp = permuted_img.clone()
        # Look up each pixel's on/off flag through its segment label: one
        # vectorised mask per sample instead of one pass per hidden segment.
        hidden = torch.from_numpy(space[segments] == 0).to(permuted_img.device)
        temp[hidden] = fill
        perturbed_images.append(temp.permute(2, 0, 1))

    perturbed_images = torch.stack(perturbed_images)

    # Locality weights. The raw Euclidean distance between binary vectors is
    # sqrt(#hidden superpixels); with a width of 0.25 that yields
    # exp(-16 * #hidden), which underflows to ~0 for every sample except the
    # reference and makes any weighted fit degenerate. The cosine distance to
    # the all-ones reference stays in [0, 1]: cos = sqrt(kept / n_features).
    kept_fraction = z_prime.mean(axis=1)
    distances = 1.0 - np.sqrt(kept_fraction)
    weights = np.exp(-(distances ** 2) / kernel_width ** 2).reshape(1, -1)
    return (segments, z_prime, weights, perturbed_images)


In [53]:
def _model_device(model):
    """Return the device of ``model``'s first parameter, or ``None`` if it has none."""
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return None


def labels_of_perturbed_images(perturbed_images, model, num_samples = 1000, batch_size=None):
    """Run ``model`` on a batch of perturbed images without tracking gradients.

    Parameters
    ----------
    perturbed_images : torch.Tensor
        Batch of inputs; the first axis indexes samples.
    model : callable
        Model to evaluate. Inputs are moved to the device holding its
        parameters, so the call works wherever the model lives. For a callable
        without parameters the inputs go to CUDA when it is available and
        otherwise stay where they are.
    num_samples : int
        Accepted for interface compatibility; not used by this function.
    batch_size : int, optional
        When given, the inputs are evaluated in chunks of at most this many
        samples and the outputs concatenated, which bounds peak memory.
        ``None`` (default) evaluates everything in a single forward pass.

    Returns
    -------
    torch.Tensor
        Raw model outputs, on the device the model ran on.

    Raises
    ------
    ValueError
        If ``batch_size`` is given and is not positive.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError("batch_size must be a positive integer or None")

    device = _model_device(model)
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else perturbed_images.device

    with torch.no_grad():
        if batch_size is None:
            return model(perturbed_images.to(device))  # Pass the batch through the model
        # Move one chunk at a time so only `batch_size` inputs sit on the device.
        chunk_outputs = [model(chunk.to(device)) for chunk in torch.split(perturbed_images, batch_size)]
    return torch.cat(chunk_outputs)

In [54]:
from sklearn.linear_model import Ridge
from skimage.color import label2rgb
import matplotlib.pyplot as plt
import numpy as np

def LIME_explanation(z_prime, perturbed_images_labels, target_label_idx, weights, segments, original_image, num_top_features=10):
    """Fit a weighted linear surrogate and plot the most influential superpixels.

    A Ridge regression (``alpha=2.0``, with intercept) is fitted from the
    on/off matrix ``z_prime`` to the model output for ``target_label_idx``,
    using ``weights`` as per-sample weights. The ``num_top_features``
    superpixels with the largest absolute coefficients are shown in their
    original values; every other pixel is replaced by half of its channel
    mean. The figure is displayed with matplotlib; nothing is returned.

    Parameters
    ----------
    z_prime : array-like
        One row per perturbed sample, one column per superpixel.
    perturbed_images_labels : array
        Model outputs per sample; column ``target_label_idx`` is the target.
    target_label_idx : int
        Column of ``perturbed_images_labels`` to explain.
    weights : array-like
        Sample weights; flattened to one dimension before fitting.
    segments : numpy.ndarray
        Superpixel label map; labels are matched against column indices
        of ``z_prime``.
    original_image : torch.Tensor
        Image permuted with ``(1, 2, 0)`` for display -- presumably
        channel-first ``(C, H, W)``; verify against caller.
    num_top_features : int
        How many superpixels to keep highlighted.
    """
    sample_weights = np.asarray(weights).flatten()
    targets = perturbed_images_labels[:, target_label_idx]

    surrogate = Ridge(alpha=2.0, fit_intercept=True)
    surrogate.fit(z_prime, targets, sample_weight=sample_weights)

    # Rank superpixels by coefficient magnitude, largest first.
    ranking = np.argsort(-np.abs(surrogate.coef_))
    selected = ranking[:num_top_features]

    # Pixels that belong to one of the selected superpixels stay untouched.
    keep = np.isin(segments, selected)
    background = ~keep

    image_hwc = original_image.permute(1, 2, 0).cpu().numpy()
    dimmed = np.mean(image_hwc, axis=-1, keepdims=True) * 0.5

    highlighted = image_hwc.copy()
    highlighted[background] = dimmed[background]

    # Plot the result
    plt.imshow(highlighted)
    plt.axis('off')
    plt.title("Highlighted Explanation with Original Colors")
    plt.show()



In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json
import os 

# Load the pre-trained ResNet18 model
# NOTE(review): recent torchvision releases prefer the `weights=` argument over
# `pretrained=True`; kept as-is because the installed version is not visible here.
model = models.resnet18(pretrained=True)
model.eval()  # Set model to evaluation mode

# Move the model to the GPU once, up front, instead of on every loop iteration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Define the image preprocessing transformations.
# The pipeline is split in two so the un-normalised image stays available:
# perturbation and plotting work on pixel values, the model on normalised ones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

to_unit_tensor = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
preprocess = transforms.Compose([to_unit_tensor, normalize])

# Broadcastable (1, C, 1, 1) statistics for normalising a whole batch at once.
batch_mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1)
batch_std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1)

# Load the ImageNet class index mapping
with open("imagenet_class_index.json") as f:
    class_idx = json.load(f)
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
idx2synset = [class_idx[str(k)][0] for k in range(len(class_idx))]
id2label = {v[0]: v[1] for v in class_idx.values()}

idx2label_explanation = list(idx2label)
idx2synset_explanation = list(idx2synset)
id2label_explanation = dict(id2label)

imagenet_path = './imagenet_samples'

# List of image file paths (hidden files skipped); sorted because os.listdir
# returns entries in arbitrary order.
image_paths = sorted(f for f in os.listdir(imagenet_path) if not f.startswith('.'))
#image_paths = [image_paths[4]]
for img_path in image_paths:
    # Open and preprocess the image; the context manager closes the file handle.
    my_img = os.path.join(imagenet_path, img_path)
    with Image.open(my_img) as image_file:
        input_image = image_file.convert('RGB')
    unit_tensor = to_unit_tensor(input_image)   # values in [0, 1]
    input_tensor = normalize(unit_tensor)       # what the model expects
    input_batch = input_tensor.unsqueeze(0).to(device)  # Create a mini-batch as expected by the model

    # Perform inference
    with torch.no_grad():
        output = model(input_batch)
    # Get the predicted class index
    predicted_idx = torch.argmax(output, dim=1).item()
    predicted_class = predicted_idx
    predicted_synset = idx2synset[predicted_idx]
    predicted_label = idx2label[predicted_idx]
    print(f"Predicted label: {predicted_synset} ({predicted_label})")

    # Perturb in 0-255 pixel space: create_perturbed_image_and_weights paints
    # hidden superpixels with the ImageNet mean colour expressed in 0-255
    # scale, so handing it the normalised tensor would inject values far
    # outside the range the model was trained on.
    segments, z_prime, weights, perturbed_images = create_perturbed_image_and_weights(unit_tensor * 255, num_samples = 500)

    # Bring the perturbed batch back into the model's normalised input space.
    model_inputs = (perturbed_images / 255 - batch_mean) / batch_std

    perturbed_outputs = labels_of_perturbed_images(model_inputs, model)
    perturbed_outputs_np = torch.softmax(perturbed_outputs, dim=1).cpu().numpy()

    target_label_idx = predicted_class

    # Plot with the [0, 1] image so the figure shows true colours; normalised
    # values fall outside the range imshow can display for float data.
    LIME_explanation(
        z_prime=z_prime,
        perturbed_images_labels=perturbed_outputs_np,
        target_label_idx=target_label_idx,
        weights=weights,
        segments=segments,
        original_image=unit_tensor,
        num_top_features=30
    )

    



Predicted label: n03250847 (drumstick)


  0%|          | 0/500 [00:00<?, ?it/s]