In [1]:
import os
import sys

# Make the project root (parent of the notebook's working directory) importable
# so the `lib.*` imports in later cells resolve. The path is resolved to an
# absolute path once, at cell-run time. It is appended only when missing, so
# re-running this cell does not pile up duplicate sys.path entries.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
from lib.model import get_model


# Prefer the first CUDA GPU when one is available; otherwise fall back to CPU.
# torch device strings use the 'cuda' type. 'gpu:0' is not a valid device type,
# and torch.device would raise on any machine where CUDA is available.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
# `half` is enabled only when CUDA is in use (presumably fp16 inference —
# confirm against lib.model.get_model).
model = get_model(half=use_cuda, device=device)
# NOTE(review): presumably caps the number of detections returned per image
# (YOLOv5 AutoShape-style attribute) — confirm against lib.model.get_model.
model.max_det = 1
print(f'using device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


usging device: cpu


In [3]:
from PIL import Image
import numpy as np


def make_prediction(img: "Image.Image", predictor=None):
    """Run the object detector on a single image and return its raw output.

    Args:
        img: Image to run detection on (annotated as a PIL image).
        predictor: Callable used for inference. Defaults to the notebook-level
            ``model`` created by ``get_model``. Pass one explicitly to avoid
            relying on notebook global state (e.g. in tests).

    Returns:
        Whatever ``predictor(img)`` returns. This is not an ``np.ndarray``: the
        notebook reads ``.xyxy`` from the model's output, so the result is a
        detections object.
    """
    if predictor is None:
        # Looked up at call time, like the original global reference.
        predictor = model
    return predictor(img)

In [4]:
import os
from typing import Iterable, Iterator, Optional, Tuple


# Directory holding original-<i>.npy, patched-<i>.npy and random-<i>.npy.
# Set the DATASET_PATH environment variable instead of editing this
# machine-specific absolute path; the previous hard-coded value is the fallback.
# TODO: replace the fallback with a path relative to the project.
DATASET_PATH = os.environ.get('DATASET_PATH', '/Users/szymonsadkowski/Downloads/images')
NUM_IMAGES = 100


def iter_images(
    start: int = 0,
    end: int = NUM_IMAGES,
    dataset_path: Optional[str] = None,
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, int]]:
    """Yield the three stored variants of each dataset image, divided by 255.

    For every index ``i`` in ``range(start, end)``, this loads
    ``original-{i}.npy``, ``patched-{i}.npy`` and ``random-{i}.npy`` from
    ``dataset_path``. It yields ``(original, patched, random_patched, i)`` with
    each array divided by 255. This assumes the saved arrays hold 0-255 pixel
    values, giving floats in [0, 1] — TODO confirm the dtype of the saved files.

    Args:
        start: First image index (inclusive).
        end: Last image index (exclusive). Defaults to ``NUM_IMAGES``.
        dataset_path: Directory containing the ``.npy`` files. Defaults to the
            module-level ``DATASET_PATH``, which is looked up at call time.

    Raises:
        FileNotFoundError: If one of the three files for an index is missing.
    """
    root = DATASET_PATH if dataset_path is None else dataset_path
    for i in range(start, end):
        org = np.load(os.path.join(root, f'original-{i}.npy'))
        d_patched = np.load(os.path.join(root, f'patched-{i}.npy'))
        random_patched = np.load(os.path.join(root, f'random-{i}.npy'))
        yield org / 255, d_patched / 255, random_patched / 255, i

In [5]:
from lib.segmentation import SuperPixler
from lib.explanations import calculate_shap
from lib.image import img_float_to_uint


In [6]:
# Number of SHAP samples per explanation (passed as `nsamples` to calculate_shap).
N_SAMPLES = 60
# Backward-compatible alias for the original misspelled name used by later cells.
N_SAMPES = N_SAMPLES
# Passed as `batch_size` to calculate_shap.
BATCH_SIZE = 250
# Superpixel settings passed to SuperPixler as n_segments / sigma / compactness.
# NOTE(review): these look like skimage SLIC parameters — confirm in lib.segmentation.
N_SEGMENTS = 30
SIGMA = 0
COMPACTNESS = 60


def calc_avg_patch_contrib(contrib: np.ndarray) -> float:
    """Arithmetic mean of all contribution values (computed over the flattened input)."""
    mean_value = np.mean(contrib, axis=None)
    return mean_value

def calc_var_patch_contrib(patch_contr: np.ndarray) -> float:
    """Population variance (ddof=0) of all contribution values."""
    # ddof=0 is numpy's default and is spelled out here for clarity.
    spread = np.var(patch_contr, ddof=0)
    return spread

def calc_mean_abs_patch_contrib_diff(contrib1: np.ndarray, contrib2: np.ndarray) -> float:
    """Mean of the element-wise absolute difference between two contribution arrays."""
    delta = contrib1 - contrib2
    magnitude = np.absolute(delta)
    return np.mean(magnitude, axis=None)

def calc_var_abs_patch_contrib_diff(contrib1: np.ndarray, contrib2: np.ndarray) -> float:
    """Population variance (ddof=0) of the element-wise absolute difference of two arrays."""
    delta = contrib1 - contrib2
    magnitude = np.absolute(delta)
    return np.var(magnitude, ddof=0)

def calc_most_positive_patch_invariant(contrib1: np.ndarray, contrib2: np.ndarray) -> bool:
    """Return whether both arrays have their largest contribution at the same position.

    Positions are compared as flat indices (``np.argmax`` without an axis), so
    ties resolve to the first occurrence in each array. The comparison is
    converted to a Python ``bool``: comparing two ``np.argmax`` results yields
    ``np.bool_``, which is not a ``bool`` instance and breaks ``is True`` checks.
    """
    return bool(np.argmax(contrib1) == np.argmax(contrib2))

def calc_most_negative_patch_invariant(contrib1: np.ndarray, contrib2: np.ndarray) -> bool:
    """Return whether both arrays have their smallest contribution at the same position.

    Positions are compared as flat indices (``np.argmin`` without an axis), so
    ties resolve to the first occurrence in each array. The comparison is
    converted to a Python ``bool``: comparing two ``np.argmin`` results yields
    ``np.bool_``, which is not a ``bool`` instance and breaks ``is True`` checks.
    """
    return bool(np.argmin(contrib1) == np.argmin(contrib2))

def normalize_contrib(contrib: np.ndarray) -> np.ndarray:
    """Scale contributions so their absolute values sum to 1 (L1 normalization).

    Signs are preserved. If every contribution is zero, or the input is empty,
    the L1 norm is 0. Dividing by it would yield NaNs and a RuntimeWarning, so a
    float zero array of the same shape is returned instead.
    """
    total = np.sum(np.abs(contrib))
    if total == 0:
        return np.zeros_like(contrib, dtype=float)
    return contrib / total

In [9]:
from itertools import combinations  # NOTE(review): not used in this cell

contribs = []


def _shap_values(superpixler, target_bbox):
    """Run calculate_shap for one image variant using the notebook-wide settings."""
    return calculate_shap(model, superpixler, target_bbox, nsamples=N_SAMPES, batch_size=BATCH_SIZE, half=use_cuda)


for original, dpatched, random_patched, img_id in iter_images():  # float arrays in [0, 1]
    print(f'processing image {img_id}')

    superpixler = SuperPixler(original, n_segments=N_SEGMENTS, sigma=SIGMA, compactness=COMPACTNESS)
    # Detections on the clean image. Column 4 is read as the confidence and
    # columns 0-3 as the box. NOTE(review): assumes YOLOv5-style xyxy rows — confirm.
    detections = model(img_float_to_uint(original)).xyxy[0].cpu().numpy()
    if len(detections) == 0:
        # np.argmax on an empty column raises ValueError and would abort the
        # whole run. There is no target box to explain, so skip this image.
        print(f'no detections for image {img_id}, skipping')
        continue
    most_conf_det = np.argmax(detections[:, 4])
    target_bbox = detections[most_conf_det, :4].reshape(1, 4)

    # The same target box is explained for all three variants. The superpixler
    # built from the original image is reused by swapping its `image`
    # (presumably keeping the original segmentation — confirm in lib.segmentation).
    org_shap_v = _shap_values(superpixler, target_bbox)

    superpixler.image = dpatched
    dpatch_shap_v = _shap_values(superpixler, target_bbox)

    superpixler.image = random_patched
    rand_patch_shap_v = _shap_values(superpixler, target_bbox)

    # Store the raw SHAP values per variant; no metrics are computed here.
    contrib = {'img_id': img_id}
    for patch_contr, name in [(org_shap_v, 'org'), (dpatch_shap_v, 'dpatch'), (rand_patch_shap_v, 'rand')]:
        contrib[f'shap_contrib_{name}'] = patch_contr

    contribs.append(contrib)


processing image 0
processing image 1
processing image 2
processing image 3
processing image 4
processing image 5
processing image 6
processing image 7
processing image 8
processing image 9
processing image 10
processing image 11
processing image 12
processing image 13
processing image 14
processing image 15
processing image 16
processing image 17
processing image 18
processing image 19
processing image 20
processing image 21
processing image 22
processing image 23
processing image 24
processing image 25
processing image 26
processing image 27
processing image 28
processing image 29
processing image 30
processing image 31
processing image 32
processing image 33
processing image 34
processing image 35
processing image 36
processing image 37
processing image 38
processing image 39
processing image 40
processing image 41
processing image 42
processing image 43
processing image 44
processing image 45
processing image 46
processing image 47
processing image 48
processing image 49
processing

In [11]:
import pandas as pd

# Tabulate the collected results: one row per entry in `contribs`, with an
# 'img_id' column and one 'shap_contrib_<variant>' column per image variant.
df = pd.DataFrame(data=contribs)