In [None]:
import torch
import cv2
import os
import numpy as np
import pandas as pd
# print('np',np.__version__)
from functools import partial

import matplotlib
# matplotlib.use('Agg')  # Set a non-interactive backend
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
# import datetime
from datetime import datetime

import time

import chime
# NOTE(review): `chime` is presumably used for audible notifications when long
# cells finish; no call to it is visible in this section. This selects its sound theme.
chime.theme('chime')


In [None]:
from matplotlib import rc

# Render all figure text through LaTeX (needs a working LaTeX installation) and
# load the `color` package in the preamble (presumably for \color{...} in labels).
rc('text',usetex=True)
rc('text.latex', preamble='\\usepackage{color}')

In [None]:
use_cuda = torch.cuda.is_available()
import torchvision
use_mps = ('mps' in dir(torch.backends)) and torch.backends.mps.is_available()

# Prefer CUDA, then Apple-silicon MPS, then CPU.
if   use_cuda:    device = torch.device("cuda")
elif use_mps:     device = torch.device("mps")
else:             device = torch.device("cpu")

# torch.cuda.get_device_name() raises when no CUDA device exists, so only query
# it on CUDA machines; otherwise report the selected backend instead.
device_name = torch.cuda.get_device_name() if use_cuda else str(device)

print(f'{"Torch CUDA":<25} {use_cuda}')
print(f'{"device:":<25} {device}')
print(f'{"Device Name:":<25} {device_name}')

# params.device = device


# Parameters Setting

In [None]:
from dataclasses import dataclass

@dataclass
class Parameters:
    """Experiment configuration shared across the notebook."""
    dummy                   : bool  =   False
    exp_no                  : str   =   ''
    model_name              : str   =   'yolov8s'   # overwritten in the model cell
    # 'ideal' makes load_image build constant tensors; any other value is a real model
    model_type              : str   =   'real'
    background_type         : str   =   'gray'      # used in output folder / file names
    top_classes_n           : int   =   3
    batch_size              : int   =   16
    seed                    : int   =   12345
    # Holds a torch.device (not a str): the device chosen in the cell above.
    device                  : torch.device = device
    submission_version      : bool  =   False
    data_subset             : int   =   300
params = Parameters()

@dataclass
class Paths:
    """Filesystem locations used by the notebook; most are filled in by later cells."""
    dummy           : bool  = False
    path_datasets   : str  =   ''
    DS_name_img : str = ''
    DS_name_annotation : str = ''
    codes_local : str  = ''
    path_ds_main : str = ''            # dataset root (set in the dataset cell)
    path_repos : str = ''
    # Evaluated once, when the class is defined: the notebook's working directory.
    path_codes : str = os.getcwd()
    path_utils : str = ''
    results_path : str = ''
    results_path_single : str = ''
    results_path_selected : str = ''
    plotsIoU_path       : str =''
    plots_path : str =''
    image_dir : str = ''
    annotation_file : str = ''
    total_imgs : str =  ''             # later set to an int (number of COCO images)
    csv_filename : str = ''
    # Set from get_subpaths(); previously assigned without being declared as a field.
    path_csv : str = ''


paths = Paths()

# Seeds torch only; load_image re-seeds NumPy's global RNG itself.
torch.manual_seed(params.seed)

# Black-Box Model
-   https://docs.ultralytics.com/models/yolov8/


In [None]:
# import ultralytics
from ultralytics import YOLO

# Checkpoint to explain. Alternatives that were tried:
#   params.model_name = 'yolov8s'      (YOLOv8 detection)
#   params.model_name = 'yolo11n-cls'  (YOLO11 classification)
# Task variants: YOLOv8-cls -> classification, YOLOv8-seg -> instance segmentation.
params.model_name = 'yolo11s'

model = YOLO(params.model_name + '.pt')
class_names = model.names  # class id -> class name

print('\n'.join(f'{label:<15}{value}'
                for label, value in (('Model Name', params.model_name),
                                     ('Num_Classes', len(class_names)))))

In [None]:
# Print the ultralytics model summary; the value returned by info() is kept in `a`
# and shown as the cell output.
a = model.info()
a

### Paths Setting
This cell sets the paths used to load the MS COCO images and annotations.

In [None]:
# Paths to the dataset
from pycocotools.coco import COCO

# Dataset root. Override with the COCO_ROOT environment variable instead of
# editing this hard-coded (Windows) default.
paths.path_ds_main = os.environ.get('COCO_ROOT', 'D:\\DS\\MS_COCO')
# Alternative split: "unlabeled2017" images with annotations/image_info_unlabeled2017.json
paths.image_dir = os.path.join(paths.path_ds_main, "val2017")
# os.path.join instead of hand-written '//' separators keeps this portable.
paths.annotation_file = os.path.join(paths.path_ds_main, "annotations", "annotations_trainval2017",
                                     "annotations", "instances_val2017.json")

# Load COCO dataset (builds the image / annotation / category indices)
coco = COCO(paths.annotation_file)

# Get all image IDs
image_ids = coco.getImgIds()

paths.total_imgs = len(image_ids)

print('total images:\t', paths.total_imgs)

# Results Path Setting

In [None]:
def get_file_modify_time(filepath=None):
    """Print and return the last-modification time of `filepath`.

    Args:
        filepath: path to a file or directory. ``None`` (the default) or a path
            that does not exist is tolerated and yields ``None``.

    Returns:
        datetime.datetime | None: the local modification time, or ``None`` when
        `filepath` is ``None`` or missing (nothing is printed in that case).
    """
    # os.path.exists(None) raises TypeError, so the default must be guarded.
    if filepath is None or not os.path.exists(filepath):
        return None
    datestamp = datetime.fromtimestamp(os.path.getmtime(filepath))
    print('Modified Date/Time:', datestamp)
    return datestamp

In [None]:
def get_subpaths(main_path=None,suffix=None, create_dirs =True, verbose_print=True):
    """Build (and optionally create) the output directory tree of one experiment.

    Layout::

        <main_path>/results/<suffix>              results_path
        <main_path>/results/<suffix>/selected     results_path_selected
        <main_path>/results/<suffix>/single       results_path_single
        <main_path>/results/<suffix>/plots        plots_path
        <main_path>/results/<suffix>/plots_IoU    plotsIoU_path
        <main_path>/csv                           path_csv

    Args:
        main_path: root directory; required.
        suffix: experiment sub-folder under ``results``; ``None`` places the
            outputs directly in ``<main_path>/results``.
        create_dirs: create every directory (``exist_ok=True``) when True.
        verbose_print: print the resulting paths when True.

    Returns:
        tuple: (results_path, results_path_selected, results_path_single,
        plots_path, plotsIoU_path, path_csv).

    Raises:
        ValueError: if `main_path` is None.
    """
    if main_path is None:
        raise ValueError('get_subpaths: main_path is required')
    results_root = os.path.join(main_path, 'results')
    results_path = results_root if suffix is None else os.path.join(results_root, suffix)
    results_path_selected = os.path.join(results_path, 'selected')
    results_path_single   = os.path.join(results_path, 'single')
    plots_path            = os.path.join(results_path, 'plots')
    plotsIoU_path         = os.path.join(results_path, 'plots_IoU')
    path_csv              = os.path.join(main_path, 'csv')
    all_paths = (results_path, results_path_selected, results_path_single,
                 plots_path, plotsIoU_path, path_csv)
    if create_dirs:
        for directory in all_paths:
            os.makedirs(directory, exist_ok=True)
    if verbose_print:
        print('-'*120)
        print('results_path\t\t',results_path)
        print('results_path_single\t',results_path_single)
        print('results_path_selected\t',results_path_selected)
        print('plots_path\t\t',plots_path)
        print('plotsIoU_path\t\t',plotsIoU_path)
        print('paths for CSV file\t',path_csv)
        print('-'*120)
    return all_paths


In [None]:
# Experiment sub-folder name, e.g. "yolo11s_gray".
suffix = '_'.join((params.model_name, params.background_type))

path_current = os.getcwd()
(paths.results_path, paths.results_path_selected, paths.results_path_single,
 paths.plots_path, paths.plotsIoU_path, paths.path_csv) = get_subpaths(main_path=path_current,
                                                                       suffix=suffix)


# Plot Class Distribution for the MS COCO Validation Set


In [None]:
from collections import defaultdict

# Get all category IDs and names
categories = coco.loadCats(coco.getCatIds())
category_names = [cat['name'] for cat in categories]
category_ids = [cat['id'] for cat in categories]
# Category id -> name, built once instead of one loadCats() call per category
category_id_to_name = dict(zip(category_ids, category_names))

# Count annotations per category with a single loadAnns() call over all ids
category_counts = defaultdict(int)
for annotation in coco.loadAnns(coco.getAnnIds()):
    category_counts[annotation['category_id']] += 1

# Map category IDs to names for plotting
category_counts_named = {category_id_to_name[cat_id]: count for cat_id, count in category_counts.items()}

# Sort by counts (descending) for better visualization
sorted_counts = dict(sorted(category_counts_named.items(), key=lambda item: item[1], reverse=True))

# Plotting with the explicit Figure/Axes interface; pass lists, not dict views
fig, ax = plt.subplots(figsize=(20, 8))
bars = ax.bar(list(sorted_counts.keys()), list(sorted_counts.values()), color='skyblue')

# Annotate each bar with its count, just above the bar
for bar in bars:
    yval = bar.get_height()  # Height of the bar (count value)
    ax.text(
        bar.get_x() + bar.get_width() / 2,  # X-coordinate: center of the bar
        yval + 10,  # Y-coordinate: slightly above the bar
        str(yval),  # Text (count value)
        ha='center',
        va='bottom',
        rotation=90,
        fontsize=9
    )

ax.set_xlabel('Categories', fontsize=12)
ax.set_ylabel('Number of Annotations', fontsize=12)
ax.set_title('Class-wise Distribution of Annotations in MS COCO Validation Set', fontsize=15)
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
plt.show()

# TEST IMAGE INFO

In [None]:
# Image used for the single-image walkthrough (139 was an earlier choice).
# fixed_choice = 139
fixed_choice = 113235

# loadImgs returns a list; take its single record (file_name, width, height, id, ...)
image_info = coco.loadImgs(fixed_choice)[0]
print(image_info)

image_path = os.path.join(paths.image_dir, image_info['file_name'])
print(image_path,image_info['file_name'])
# Sanity check that the image file exists on disk
print(image_path, os.path.exists(image_path))


In [None]:
def _image_size_record(img_id):
    """Return {'image_no', 'width', 'height'} for one COCO image id."""
    info = coco.loadImgs(img_id)[0]
    return {'image_no': img_id, 'width': info['width'], 'height': info['height']}


# Width/height of every image in the split. Building the list through a helper
# keeps the per-image variables local. The notebook-level `image_info` / `image_no`
# (the "current image" used by later cells) are no longer silently overwritten
# with the last image of the split.
img_size_stats = [_image_size_record(img_id) for img_id in tqdm(coco.getImgIds())]
df_stats_size = pd.DataFrame(img_size_stats)
df_stats_size.head()

In [None]:
# Summary statistics (count, mean, std, min, quartiles, max) of the image-size table.
df_stats_size.describe()

## Check Segmentation Map Availability

In [None]:
# Image to inspect (same id as `fixed_choice`; 139 was an earlier choice).
# `image_no` is also read as a global by the plotting functions (file-name prefix).
# image_no = 139
image_no = 113235

image_info = coco.loadImgs(image_no)[0]

# All annotations (any category) attached to this image
ann_ids = coco.getAnnIds(imgIds=image_info['id'])
annotations = coco.loadAnns(ann_ids)
# True if at least one annotation carries a 'segmentation' entry
has_segmentation = any('segmentation' in ann for ann in annotations)
print(f"Segmentation annotations present: {has_segmentation}")

## Load Annotation Boxes

In [None]:
# Category id -> category name lookup for the whole dataset
categories = coco.loadCats(coco.getCatIds())
coco_categories = {cat['id']: cat['name'] for cat in categories}

# Get annotations for the selected image (`image_info` from the cell above)
ann_ids = coco.getAnnIds(imgIds=image_info['id'])
annotations = coco.loadAnns(ann_ids)

In [None]:
# Check if segmentation annotations are present (repeats the check above)
has_segmentation = any('segmentation' in ann for ann in annotations)
print(f"Segmentation annotations present: {has_segmentation}")

# Print the first available segmentation (a polygon list or an RLE dict)
if has_segmentation:
    for ann in annotations:
        if 'segmentation' in ann:
            print(f"Segmentation Annotation: {ann['segmentation']}")
            break

In [None]:
# # def get_
# category_name = "person"  # Replace with the desired category name
# category_ids = coco.getCatIds(catNms=[category_name])
# # annotation_ids = coco.getAnnIds(imgIds=image_info['id'])
# annotation_ids = coco.getAnnIds(imgIds=image_info['id'], catIds=category_ids)

# annotations = coco.loadAnns(annotation_ids)
# print(f"Image:{image_info['id']} has {len(annotations)} annotations")

In [None]:
from torchvision import transforms

# HWC numpy array -> CHW float tensor (ToTensor also scales uint8 inputs to [0, 1]).
model_preprocess = transforms.Compose([transforms.ToTensor()])

## FN: Load Image

In [None]:
from skimage.filters import gaussian
from scipy.ndimage import gaussian_filter

# Replacement-background names understood by load_image; bg_type='full' uses all
# of them, in this order.
_BACKGROUND_TYPES = ('black', 'gray', 'white', 'blurred', 'noise')


def _make_backgrounds(image):
    """Return {name: background} for every entry of _BACKGROUND_TYPES.

    Every background has the shape of `image`. The global NumPy RNG is re-seeded
    with 0 first, so the 'noise' background is identical on every call.
    """
    np.random.seed(0)
    noise = np.clip(np.random.normal(128, 128, size=image.shape), 0, 255).astype(np.uint8)
    return {
        'black':   np.full_like(image, 0),
        'gray':    np.full_like(image, 127),
        'white':   np.full_like(image, 255),
        # gaussian() returns floats in [0, 1] for integer input, hence *255; this
        # background therefore stays float64 in the 0-255 range.
        'blurred': gaussian(image, 8, channel_axis=-1) * 255,
        'noise':   (gaussian(noise, 2.0, channel_axis=-1) * 255).astype(np.uint8),
    }


def _chw_shape(hwc_shape):
    """(H, W, C) -> (C, H, W), the tensor layout produced by transforms.ToTensor."""
    height, width, channels = hwc_shape
    return (channels, height, width)


def load_image(fname,params,im_size=None,bg_type='black'):
    """Load the image to explain and build its replacement background(s).

    Sets the notebook globals:
        image_to_explain             -- RGB uint8 (H, W, 3) image
        image_to_explain_preproc     -- a copy of it, or a ones tensor (C, H, W)
                                        when params.model_type == 'ideal'
        background_image_set         -- np.array of the selected backgrounds
        background_image_preproc_set -- list of per-background (C, H, W) tensors
        background_tensors           -- those tensors stacked and moved to `device`

    Args:
        fname: image path readable by cv2.imread.
        params: Parameters; only `model_type` is read.
        im_size: optional (width, height) to resize to (cv2.resize order).
        bg_type: 'black', 'gray', 'white', 'blurred', 'noise', or 'full' (all five).

    Raises:
        ValueError: unknown `bg_type`. Previously `background_image_set` was left
            unassigned, so stale globals from an earlier call were silently reused.
        FileNotFoundError: cv2.imread could not read `fname`.
    """
    global image_to_explain, image_to_explain_preproc
    global background_tensors, background_image_set, background_image_preproc_set

    if bg_type != 'full' and bg_type not in _BACKGROUND_TYPES:
        raise ValueError(f"Unknown bg_type {bg_type!r}; expected one of "
                         f"{_BACKGROUND_TYPES + ('full',)}")

    img_bgr = cv2.imread(str(fname))
    if img_bgr is None:  # imread signals failure by returning None
        raise FileNotFoundError(f'cv2.imread could not read image: {fname}')
    image_to_explain = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    if im_size is not None:
        image_to_explain = cv2.resize(image_to_explain, im_size)

    if params.model_type == 'ideal':
        # (C, H, W) like ToTensor; the previous tuple(reversed(shape)) gave
        # (C, W, H), which differs for non-square COCO images.
        image_to_explain_preproc = torch.ones(_chw_shape(image_to_explain.shape))
    else:
        image_to_explain_preproc = image_to_explain.copy()

    backgrounds = _make_backgrounds(image_to_explain)
    selected = _BACKGROUND_TYPES if bg_type == 'full' else (bg_type,)
    background_image_set = np.array([backgrounds[name] for name in selected])

    if params.model_type == 'ideal':
        background_image_preproc_set = [torch.zeros(_chw_shape(bkgnd.shape))
                                        for bkgnd in background_image_set]
    else:
        background_image_preproc_set = [model_preprocess(bkgnd.astype(np.float32)/255.0)
                                        for bkgnd in background_image_set]

    background_tensors = torch.cat([torch.unsqueeze(bk_p, dim=0)
                                    for bk_p in background_image_preproc_set]).to(device)
    


# Load Ground Truth

In [None]:
def _image_id_from_path(image_no):
    """Parse the COCO image id from a path such as '.../000000113235.jpg'.

    Both '\\' and '/' separators are accepted, so Windows- and POSIX-style paths work.
    """
    file_name = image_no.replace('\\', '/').rsplit('/', 1)[-1]
    return int(file_name.split('.')[0])


def get_annotation(coco,image_no,category_name=None):
    """Return the COCO annotations of one image, optionally for one category.

    Args:
        coco: a pycocotools ``COCO`` index.
        image_no: COCO image id (int) or an image path whose file name is the
            zero-padded id (e.g. ``.../000000113235.jpg``).
        category_name: restrict to this category name; ``None`` returns all.

    Returns:
        list[dict]: the annotation records from ``coco.loadAnns``.

    Raises:
        ValueError: if `category_name` matches no COCO category. An empty
            ``catIds`` list means "no filter" to ``getAnnIds``, which would
            otherwise silently return every annotation of the image.
    """
    if isinstance(image_no, str):
        image_no = _image_id_from_path(image_no)

    image_info = coco.loadImgs(image_no)[0]
    if category_name is None:
        annotation_ids = coco.getAnnIds(imgIds=image_info['id'])
    else:
        category_ids = coco.getCatIds(catNms=[category_name])
        if not category_ids:
            raise ValueError(f'Unknown COCO category name: {category_name!r}')
        annotation_ids = coco.getAnnIds(imgIds=image_info['id'], catIds=category_ids)
    return coco.loadAnns(annotation_ids)


def create_gt(coco,image_no,category_name=None, verbose=False, target_shape=None):
    """Rasterise the segmentation annotations of one image into masks.

    Args:
        coco, image_no, category_name: as for get_annotation.
        verbose: print the annotation count when the image has any.
        target_shape: (height, width) the masks are resized to (nearest
            neighbour). Defaults to the shape of the global ``image_to_explain``.

    Returns:
        (mask, category_mask, annotations):
            mask          -- uint8 0/1 union of all selected objects
            category_mask -- uint8 COCO category id of the object covering each
                             pixel (0 = none; later annotations overwrite earlier)
            annotations   -- the annotation records used
    """
    if isinstance(image_no, str):
        image_no = _image_id_from_path(image_no)
    # Use this image's own metadata. The previous version read the notebook-global
    # `image_info`, which can describe a different image.
    image_info = coco.loadImgs(image_no)[0]
    annotations = get_annotation(coco, image_no, category_name=category_name)
    if verbose and len(annotations) > 0:
        print(f"Image:{image_info['id']} has {len(annotations)} annotations")

    height, width = image_info['height'], image_info['width']
    mask = np.zeros((height, width), dtype=np.uint8)
    category_mask = np.zeros((height, width), dtype=np.uint8)

    # Combine all masks for this image
    for ann in annotations:
        if 'segmentation' not in ann:
            continue
        category_id = ann['category_id']
        segmentation = ann['segmentation']
        if isinstance(segmentation, list):  # Polygon format
            for seg in segmentation:
                pts = np.array(seg).reshape(-1, 2).astype(np.int32)
                cv2.fillPoly(mask, [pts], color=1)
                cv2.fillPoly(category_mask, [pts], color=category_id)
        elif isinstance(segmentation, dict):  # RLE format
            decoded_mask = coco.annToMask(ann)
            # Keep `mask` binary like the polygon branch (`+=` produced counts).
            mask[decoded_mask > 0] = 1
            category_mask[decoded_mask > 0] = category_id

    if target_shape is None:
        target_shape = image_to_explain.shape[:2]
    target_h, target_w = target_shape[0], target_shape[1]
    # Resize masks to match the actual image dimensions
    if mask.shape[:2] != (target_h, target_w):
        mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
        category_mask = cv2.resize(category_mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
    return mask, category_mask, annotations

In [None]:
def load_groundtruth(coco,image_no,fixed_category=None):
    """Load the ground-truth masks of one image into notebook globals.

    Sets:
        ground_truth          -- bool (H, W) mask, True where an object of
                                 `fixed_category` (any category when None) lies
        weighted_ground_truth -- float (H, W): the binary mask blurred with a
                                 Gaussian (sigma 16 px) times the mask itself, so
                                 pixels deeper inside an object weigh more
        annotations           -- the COCO annotation records used

    Args:
        coco: pycocotools COCO index.
        image_no: COCO image id or image path (see get_annotation).
        fixed_category: category name to restrict the ground truth to.
    """
    global ground_truth,weighted_ground_truth,annotations

    _, category_mask, annotations = create_gt(coco, image_no, category_name=fixed_category)
    # Proper boolean conversion. The previous `ground_truth.dtype = 'bool'`
    # reinterpreted the uint8 category ids in place, producing non-canonical
    # booleans (raw byte values such as 62).
    ground_truth = category_mask > 0
    # Weight from the binary mask. Computing it on the category-id mask scaled the
    # map by category_id**2, so weights were not comparable across categories.
    weighted_ground_truth = gaussian_filter(ground_truth.astype(float), 16) * ground_truth
    
    

#     # mask, category_mask,annotations = create_gt(coco,fixed_choice)
# # plot_gt(image_to_explain,annotations)

# # plot_mask(category_mask,category_name = fixed_category)


In [None]:
# # mask, category_mask,annotations = create_gt(coco,fixed_choice,category_name = None)
# mask, category_mask,annotations = create_gt(coco,fixed_choice,category_name = 'tv')

# plot_mask(category_mask,category_name = fixed_category)

# Reference: COCO API demo notebook
- https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoDemo.ipynb

In [None]:
# https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoDemo.ipynb
# load and display instance annotations
# plt.imshow(I); plt.axis('off')
# annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
# anns = coco.loadAnns(annIds)
# coco.showAnns(anns)

In [None]:
# # get all images containing given categories, select one at random
# catIds = coco.getCatIds(catNms=['person','dog','skateboard']);
# imgIds = coco.getImgIds(catIds=catIds );
# imgIds = coco.getImgIds(imgIds = [324158])
# img = coco.loadImgs(imgIds[np.random.randint(0,len(imgIds))])[0]

In [None]:
# # load and display caption annotations
# annIds = coco_caps.getAnnIds(imgIds=img['id']);
# anns = coco_caps.loadAnns(annIds)
# coco_caps.showAnns(anns)
# plt.imshow(I); plt.axis('off'); plt.show()

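### Masking Functions
Two prediction functions are used to compute Shapley values:
- predict_yolo: per-class scores of the model on a single image
- predict_yolo_masked: per-class scores on masked versions of the image, averaged over the replacement backgrounds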

In [None]:
def predict_yolo(x,verbose=False, yolo_model=None, num_classes=None, rgb_input=True):
    """Per-class score vector of a YOLO model for one image.

    Args:
        x: image to predict on. The callers in this notebook pass RGB HWC numpy
            arrays (uint8, or float in the 0-255 range).
        verbose: forwarded to ``predict``.
        yolo_model: model to use; defaults to the notebook-global ``model``.
        num_classes: output length for detection models; defaults to
            ``len(yolo_model.names)`` (80 for the COCO detectors, as before).
        rgb_input: ultralytics treats numpy inputs as BGR (OpenCV order). When
            True and `x` is an HWC 3-channel array, its channels are flipped so
            the RGB images built by load_image are seen with correct colours.

    Returns:
        np.ndarray: for detection models, the highest box confidence per class
        (0 where the class is not detected); for classification models
        (no boxes), the class probabilities.
    """
    if yolo_model is None:
        yolo_model = model
    if rgb_input and isinstance(x, np.ndarray) and x.ndim == 3 and x.shape[-1] == 3:
        # RGB -> BGR; contiguous copy because some OpenCV ops reject negative strides
        x = np.ascontiguousarray(x[..., ::-1])

    res = yolo_model.predict(source=x, verbose=verbose)[0]
    boxes = getattr(res, 'boxes', None)
    if boxes is None:
        # Classification checkpoints (e.g. '*-cls') return probabilities, not boxes.
        return res.probs.data.cpu().numpy()

    if num_classes is None:
        num_classes = len(yolo_model.names)
    p = np.zeros(num_classes)
    classes = boxes.cls.cpu().numpy().astype(int)
    confs = boxes.conf.cpu().numpy()
    # Several boxes may share a class: keep the maximum confidence. Plain
    # assignment kept the last box, which in NMS output (sorted by descending
    # confidence) is the least confident one.
    np.maximum.at(p, classes, confs)
    return p

In [None]:
def predict_yolo_masked(masks,verbose=False):
    """Score masked versions of the global `image_to_explain` with predict_yolo.

    For every mask, pixels where the mask is truthy keep `image_to_explain`.
    The other pixels are taken from each background in the global
    `background_image_set`. The per-class scores are averaged over the backgrounds.

    Args:
        masks: iterable of (H, W) or (H, W, 3) masks.
        verbose: forwarded to predict_yolo.

    Returns:
        np.ndarray with one row of per-class scores per mask.
    """
    imglst_preds = []
    for mask in masks:
        mask = np.asarray(mask)
        # A 2-D mask broadcasts over the channel axis through a trailing singleton
        # axis. It does not depend on the background, so build it once per mask.
        # (The 3-D branch used to print the mask shape on every call: debug leftover.)
        mask3 = mask if mask.ndim == 3 else mask[..., np.newaxis]
        preds = [predict_yolo(np.where(mask3, image_to_explain, repl), verbose=verbose)
                 for repl in background_image_set]
        imglst_preds.append(np.mean(preds, axis=0))

    return np.array(imglst_preds)

In [None]:
# def get_bg_values(predict_yolo, ground_truth, predicted_cls, class_names, verbose=False):
#     if verbose:
#         print(f'Input shape: {ground_truth.shape}, dtype: {ground_truth.dtype}, '
#               f'type: {type(ground_truth)}, max value: {np.max(ground_truth)}')
    
#     # Convert 2D mask to 3D if needed
#     if len(ground_truth.shape) == 2:
#         ground_truth = np.stack([ground_truth.astype(np.uint8) * 255]*3, axis=-1)
#         if verbose:
#             print(f'Converted to 3D: {ground_truth.shape}')
    
#     # Process ground truth
#     predicted_fG = predict_yolo(ground_truth)
#     if isinstance(predicted_fG, list):  # Handle batch prediction
#         predicted_fG = predicted_fG[0]
    
#     # Extract score for ground truth
#     if hasattr(predicted_fG, 'probs'):  # Classification model
#         f_G = float(predicted_fG.probs[predicted_cls])
#     else:  # Detection model
#         boxes = getattr(predicted_fG, 'boxes', None)
#         if boxes is not None:
#             cls_mask = boxes.cls == predicted_cls
#             f_G = float(boxes.conf[cls_mask].mean()) if any(cls_mask) else 0.0
#         else:
#             f_G = 0.0
    
#     # Process background mask (inverse of ground truth)
#     background_mask = np.logical_not(ground_truth if ground_truth.dtype == bool else ground_truth > 0)
#     background_mask = background_mask.astype(np.uint8) * 255
#     if len(background_mask.shape) == 2:
#         background_mask = np.stack([background_mask]*3, axis=-1)
    
#     predicted_fB = predict_yolo(background_mask)
#     if isinstance(predicted_fB, list):
#         predicted_fB = predicted_fB[0]
    
#     # Extract score for background
#     if hasattr(predicted_fB, 'probs'):  # Classification model
#         f_B = float(predicted_fB.probs[predicted_cls])
#     else:  # Detection model
#         boxes = getattr(predicted_fB, 'boxes', None)
#         if boxes is not None:
#             cls_mask = boxes.cls == predicted_cls
#             f_B = float(boxes.conf[cls_mask].mean()) if any(cls_mask) else 0.0
#         else:
#             f_B = 0.0
    
#     if verbose:
#         print(f'Class: {class_names[predicted_cls]} ({predicted_cls})')
#         print(f'Ground Truth Score (f_G): {f_G:.4f}')
#         print(f'Background Score (f_B): {f_B:.4f}')
#         print('-' * 50)
    
#     return f_G, f_B

In [None]:
def get_bg_values_yolo(f_masked, ground_truth, predicted_cls, class_names, verbose=False):
    """
    YOLOv8 version of get_bg_values.

    Reads the score of class `predicted_cls` twice: once when only the
    ground-truth region is kept (f_G), and once when only its complement is
    kept (f_B).

    Args:
        f_masked: callable taking a batch of masks (leading batch axis) and
            returning one score vector per mask, e.g. predict_yolo_masked.
        ground_truth: boolean mask of the object region.
        predicted_cls: index of the class whose score is read.
        class_names: class-name lookup, only used when printing.
        verbose: print both scores.

    Returns:
        tuple[float, float]: (f_G, f_B).
    """
    def score_of(region):
        single_batch = np.expand_dims(region, axis=0)
        return float(f_masked(single_batch)[0][predicted_cls])

    f_G = score_of(ground_truth)
    f_B = score_of(np.logical_not(ground_truth))

    if not verbose:
        return f_G, f_B

    print(f'Class: {class_names[predicted_cls]} ({predicted_cls})')
    print(f'Ground Truth Score (f_G): {f_G:.5f}')
    print(f'Background Score (f_B): {f_B:.5f}')
    print('-' * 50)
    return f_G, f_B

In [None]:
def load_image_to_explain(fname,params,bg_type='gray', load_gt=True):
    """Load an image and cache the model outputs needed to explain it.

    Calls load_image, which sets image_to_explain and background_image_set.
    Then sets the globals:
        predicted_fS   -- per-class scores on the full image
        sorted_classes -- class indices by descending score
        sorted_probs   -- the scores in that order
        predicted_cls  -- the top-scoring class
        f_S            -- score of predicted_cls on the full image
        predicted_f0   -- per-class scores on the background(s), averaged
        f_0            -- score of predicted_cls on the background(s)

    Args:
        fname: image file path.
        params: Parameters instance, forwarded to load_image.
        bg_type: background type, forwarded to load_image.
        load_gt: currently unused; the ground truth is loaded separately with
            load_groundtruth.
    """
    global predicted_fS, predicted_f0, predicted_cls, sorted_classes, f_S, f_0, sorted_probs

    load_image(fname, params=params, bg_type=bg_type)

    # Foreground image to be explained
    predicted_fS = predict_yolo(image_to_explain)
    sorted_classes = np.flip(np.argsort(predicted_fS))
    sorted_probs   = predicted_fS[sorted_classes]
    predicted_cls  = sorted_classes[0]
    f_S = float(predicted_fS[predicted_cls])

    # Score the backgrounds in the same 0-255 range that predict_yolo_masked feeds
    # the model (an all-False mask yields exactly the background). The previous
    # `bkgnd.astype(np.float32)/255.0` was scaled by 1/255 again inside ultralytics'
    # preprocessing, i.e. a near-black image was scored.
    predicted_f0 = np.mean([predict_yolo(bkgnd) for bkgnd in background_image_set], axis=0)
    f_0 = float(predicted_f0[predicted_cls])
    

## TEST ANNOTATION AVAILABILITY

In [None]:

# Walk-through on the test image: load it with a gray background, take the
# model's top class as the category to explain, then load that category's GT.
load_image_to_explain(image_path,params, bg_type='gray')
# fixed_category = 'tv'
fixed_category = class_names[predicted_cls]
print('fixed_category: ',fixed_category)

# class_names[predicted_cls]

# get_annotation parses the COCO image id from image_path's file name
load_groundtruth(coco,image_path,fixed_category=fixed_category)
# load_groundtruth(coco,image_path)


In [None]:
def plot_img_gt_bg(save_fig=False, selected_ext='png',title=None,destroy_fig=False):
    """Show the input image, its ground truth and every replacement background in one row.

    Reads the notebook globals image_to_explain, ground_truth,
    background_image_set, image_no, params and paths.

    Args:
        save_fig: save the figure into paths.results_path_single.
        selected_ext: extension of the saved file.
        title: optional tag; switches to the '<image_no>_<bg>_heatmap_<title>' file name.
        destroy_fig: close the figure after saving.
    """
    fig, axs = plt.subplots(1, 2 + len(background_image_set))

    panels = [(image_to_explain, {}, 'Input'),
              (ground_truth, {'cmap': 'Reds'}, 'Ground\ntruth')]
    panels += [(bg.astype(np.uint8), {}, f'Repl {k}')
               for k, bg in enumerate(background_image_set)]
    for axis, (picture, imshow_kwargs, caption) in zip(axs, panels):
        axis.imshow(picture, **imshow_kwargs)
        axis.set_title(caption)
        axis.set_axis_off()

    if save_fig:
        tag = 'bg_gt' if title is None else f'heatmap_{title}'
        suffix = f'{image_no}_{params.background_type}_{tag}.{selected_ext}'
        plt.savefig(f'{paths.results_path_single}//{suffix}', dpi=200, transparent=True,
                    bbox_inches='tight', pad_inches=0.02)
    if destroy_fig:
        plt.close(fig)
    plt.show()

In [None]:
# Save (into paths.results_path_single) and show the input / ground-truth / background overview.
plot_img_gt_bg(save_fig=True)

In [None]:
# Sanity check of the cached prediction for the test image
print('predicted_fS shape:\t', predicted_fS.shape)
print('predicted_cls:\t\t',    predicted_cls, class_names[predicted_cls])
print('f_S:\t\t\t', f_S)

In [None]:
# Raw ultralytics results for the test image (consumed by plot_predictions).
# ultralytics treats numpy arrays as BGR while image_to_explain is RGB, so flip
# the channels first (contiguous copy, as some OpenCV ops reject negative strides).
results = model(np.ascontiguousarray(image_to_explain[..., ::-1]))
# results[0].show()

# Functions: Plotting

In [None]:
def plot_gt(image_,annotation,fixed_category,save_fig=False,fig_size = (3,3),title=None,selected_ext='png',destroy_fig=False):
    """Show `image_` with the COCO bounding boxes and category names of `annotation`.

    Args:
        image_: image to draw on.
        annotation: list of COCO annotation dicts ('bbox' is [x, y, width, height];
            'category_id' is looked up in the global `coco`). This argument used
            to be ignored: the loop iterated over the global `annotations` instead.
        fixed_category: category name used in the saved file name.
        save_fig: save into paths.results_path_single (prefix: global image_no).
        fig_size: matplotlib figure size.
        title: optional tag; switches to the '<image_no>_<bg>_heatmap_<title>' file name.
        selected_ext: extension of the saved file.
        destroy_fig: close the figure after saving.
    """
    fig = plt.figure(figsize=fig_size)
    plt.imshow(image_)
    plt.axis('off')
    ax = plt.gca()

    # Add annotations as overlays
    for ann in annotation:
        x, y, width, height = ann['bbox']
        category = coco.loadCats(ann['category_id'])[0]['name']
        ax.add_patch(plt.Rectangle((x, y), width, height, edgecolor='green', facecolor='none', linewidth=2))
        ax.text(x, y - 5, category, color='white', fontsize=12, bbox=dict(facecolor='green', alpha=0.5))
    if save_fig:
        if title is None:
            suffix = f'{image_no}_bb_{fixed_category}.{selected_ext}'
        else:
            suffix = f'{image_no}_{params.background_type}_heatmap_{title}.{selected_ext}'
        plt.savefig(f'{paths.results_path_single}//{suffix}',dpi=200,transparent=True,bbox_inches='tight', pad_inches=0.02)
    if destroy_fig:
        plt.close(fig)
    plt.show()

def plot_mask(category_mask,category_name=None,save_fig=False,fig_size = (3,3),title=None,selected_ext='png',destroy_fig=False):
    """Show a ground-truth mask.

    Args:
        category_mask: (H, W) mask, either category ids or binary.
        category_name: when None, the mask is drawn with a categorical colormap
            and a colour bar; otherwise it is drawn in grayscale. Also used in the
            saved file name ('all' when None). Previously the file name used the
            global `fixed_category` instead of this argument.
        save_fig: save into paths.results_path_single (prefix: global image_no).
        fig_size: matplotlib figure size.
        title: optional tag; switches to the '<image_no>_<bg>_heatmap_<title>' file name.
        selected_ext: extension of the saved file.
        destroy_fig: close the figure after saving.
    """
    fig, ax = plt.subplots(figsize=fig_size)

    if category_name is None:
        image_artist = ax.imshow(category_mask, cmap='tab20', interpolation='nearest')
        fig.colorbar(image_artist, ax=ax, orientation='vertical')  # colour bar for the ids
    else:
        ax.imshow(category_mask, cmap='gray', interpolation='nearest')
    ax.axis('off')
    if save_fig:
        if title is None:
            name_tag = category_name if category_name is not None else 'all'
            suffix = f'{image_no}_mask_{name_tag}.{selected_ext}'
        else:
            suffix = f'{image_no}_{params.background_type}_heatmap_{title}.{selected_ext}'
        plt.savefig(f'{paths.results_path_single}//{suffix}',dpi=200,transparent=True,bbox_inches='tight', pad_inches=0.02)
    if destroy_fig:
        plt.close(fig)
    plt.show()


In [None]:
# Overlay the boxes of already-computed ultralytics results on an image
def plot_predictions(image, results, category_name=None, filter_preds=True, line_thickness=2, exp_type='demo', save_fig=False, fig_size=(3, 3), title=None, selected_ext='png', destroy_fig=False):
    """Draw ultralytics detection boxes and '<label> <confidence>' tags over `image`.

    Args:
        image: image to draw on (a copy is displayed).
        results: iterable of ultralytics Results objects with `.boxes`.
        category_name: with `filter_preds`, draw only boxes of this class name.
        filter_preds: enable filtering by `category_name`.
        line_thickness: box line width.
        exp_type: 'demo' saves to paths.results_path_single, otherwise paths.plots_path.
        save_fig: save '<image_no>_<title or yolo_predictions>.<ext>'.
        fig_size: matplotlib figure size.
        title: file-name tag.
        selected_ext: extension of the saved file.
        destroy_fig: close the figure after saving.
    """
    canvas = image.copy()
    if exp_type == 'demo':
        save_path = paths.results_path_single
    else:
        save_path = paths.plots_path

    fig = plt.figure(figsize=fig_size)
    plt.imshow(canvas)
    plt.axis('off')

    for result in results:
        detections = zip(result.boxes.xyxy.cpu().numpy(),   # xyxy boxes
                         result.boxes.cls.cpu().numpy(),    # class ids
                         result.boxes.conf.cpu().numpy())   # confidences
        id_to_label = model.names  # class id -> name (MS COCO classes)

        for box, class_id, score in detections:
            label = labels_name = id_to_label[int(class_id)]
            if filter_preds and category_name and labels_name != category_name:
                continue
            left, top, right, bottom = map(int, box)
            plt.gca().add_patch(plt.Rectangle((left, top), right - left, bottom - top,
                                              edgecolor='darkred', facecolor='none',
                                              linewidth=line_thickness))
            plt.text(left, top - 5, f"{label} {score:.2f}", color='white', fontsize=12,
                     bbox=dict(facecolor='darkred', alpha=0.5))

    if save_fig:
        save_filename = f"{save_path}//{image_no}_{title or 'yolo_predictions'}.{selected_ext}"
        plt.savefig(save_filename, dpi=200, transparent=True, bbox_inches='tight', pad_inches=0.02)

    if destroy_fig:
        plt.close(fig)

    plt.show()


In [None]:
def fun_plot_heatmaps(methods_,heatmaps,exp_type=None,draw_gt_contour=True,title = None,plot_title=False,ttl=None,fontsize=14,plot_colorbar=False,save_fig=True,selected_ext='png',dpi=100,transparent=True,destroy_fig=False):
    """Plot the input, its ground truth and one saliency heatmap per method in a row.

    Reads the notebook globals image_to_explain, ground_truth, image_no, params
    and paths. It also uses `shap_bpt.shapley_values_colormap`; `shap_bpt` is not
    defined in this section (NOTE(review): assumed imported elsewhere — confirm).

    Args:
        methods_: sequence of 3-tuples whose first item is the method name (the
            key into `heatmaps`); the other two items are not used here.
        heatmaps: mapping name -> sequence whose element 0 is the 2-D heatmap.
        exp_type: 'demo' -> paths.results_path_single; 'selected' ->
            paths.results_path_selected (saved as SVG); otherwise paths.plots_path.
        draw_gt_contour: overlay the ground-truth boundary on each heatmap.
        title: optional tag appended to the saved file name.
        plot_title: add method names as titles. They are added after saving,
            so the saved file has none.
        ttl: optional row label shown next to the first panel.
        fontsize: font size of the titles and row label.
        plot_colorbar: unused.
        save_fig: save the figure.
        selected_ext: ignored. The extension is 'svg' for exp_type == 'selected'
            and 'png' otherwise (kept for backward compatibility).
        dpi: resolution of the saved figure.
        transparent: transparent background of the saved figure.
        destroy_fig: close the figure instead of leaving it open.
    """
    from skimage.segmentation import mark_boundaries

    fig, axes = plt.subplots(1, len(methods_)+2, figsize=(2*(len(methods_)), 2))

    if exp_type == 'demo':
        save_path = paths.results_path_single
    elif exp_type == 'selected':
        save_path = paths.results_path_selected
    else:
        save_path = paths.plots_path
    selected_ext = 'svg' if exp_type == 'selected' else 'png'

    axes[0].imshow(image_to_explain)
    axes[1].imshow(ground_truth, cmap='binary')
    for ax in axes[:2]:
        ax.set_xticks([]) ; ax.set_yticks([])

    for ii, (name, _, _) in enumerate(methods_):
        ax = axes[ii+2]
        heatmap = heatmaps[name][0]
        # Symmetric colour range clipped at the 99th percentile of |values|
        vmax = np.quantile(np.abs(heatmap), 0.99)
        ax.imshow(heatmap, cmap=shap_bpt.shapley_values_colormap, vmin=-vmax, vmax=vmax)
        if draw_gt_contour:
            # Transparent RGBA canvas with the GT boundary drawn in opaque black
            marked_h = mark_boundaries(np.tile((255,255,255,0), (heatmap.shape[0], heatmap.shape[1], 1)),
                                       ground_truth, mode='thick', color=(0,0,0,1))
            ax.imshow(marked_h)
        ax.set_xticks([]) ; ax.set_yticks([])

    if ttl is not None:
        # set_yticklabels(str(ttl)) tried to label ticks that had just been
        # cleared, one character per label; a y-label gives the intended row label.
        axes[0].set_ylabel(str(ttl), fontsize=fontsize)
    plt.subplots_adjust(wspace=0.05,hspace=0.05)

    if save_fig:
        title_tag = '' if title is None else f'_{title}'
        suffix = f'{image_no}_{params.background_type}_heatmap{title_tag}.{selected_ext}'
        plt.savefig(f'{save_path}//{suffix}', dpi=dpi, transparent=transparent,
                    bbox_inches='tight', pad_inches=0.02)

    if plot_title:
        for ii, (name, _, _) in enumerate(methods_):
            axes[ii+2].set_title(name, fontsize=fontsize)

    if destroy_fig:
        plt.close(fig)
    plt.show()
    # (The former `del vmax, marked_h` raised NameError when draw_gt_contour was
    # False or methods_ was empty; both names are plain locals, so no cleanup is needed.)
###################################################################################################################
# Rounded, semi-transparent white text box used behind the IoU labels
props = {'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.6, 'pad': 0.2}
###################################################################################################################
def fun_plot_IoU(methods_,heatmaps,exp_type='demo',title = None,draw_gt_contour=False,plot_title=False,
                 fontsize=14,save_fig=True,selected_ext='png',dpi=100,transparent=True,destroy_fig=False,
                 text_x=11,text_y=47):
    """Plot the input image, the ground truth and one best-threshold IoU overlay per method.

    For every method name ``n`` in ``methods_`` the heatmap ``heatmaps[n][0]`` is
    rendered with ``vis_IoU`` at ``IoU[n]['max_IoU_heatmap_threshold']``. The panel
    is annotated with ``max(IoU[n]['Y'])``.

    Relies on notebook globals: ``paths``, ``params``, ``image_no``,
    ``image_to_explain``, ``ground_truth``, ``IoU``, ``vis_IoU``,
    ``mark_boundaries`` and ``props``.

    Args:
        methods_: iterable of 3-tuples whose first element is the method name.
        heatmaps: dict mapping method name -> sequence whose element 0 is the heatmap.
        exp_type: 'demo' -> ``paths.results_path_single``,
            'selected' -> ``paths.results_path_selected``,
            anything else -> ``paths.plotsIoU_path``.
        title: optional suffix appended to the saved file name.
        draw_gt_contour: overlay the ground-truth boundary on every method panel.
        plot_title: write the method name above each method panel.
        fontsize: font size for the IoU annotation and the panel titles.
        save_fig: save the figure into the directory chosen by ``exp_type``.
        selected_ext: ignored (kept for backward compatibility). The extension is
            'svg' when ``exp_type == 'selected'`` and 'png' otherwise.
        dpi, transparent: forwarded to ``plt.savefig``.
        destroy_fig: close the figure after saving.
        text_x, text_y: data coordinates of the IoU annotation.
    """
    if exp_type == 'demo':
        save_path = paths.results_path_single
    elif exp_type == 'selected':
        save_path = paths.results_path_selected
    else:
        save_path = paths.plotsIoU_path

    # The extension is driven by the experiment type; the `selected_ext` argument is overridden.
    selected_ext = 'svg' if exp_type == 'selected' else 'png'

    # Expected IoU[n] entry, e.g.:
    #  {'X': array([...]), 'Y': array([...]), 'max_IoU_heatmap_threshold': ..., 'x_best': ..., 'auc_IoU': ...}
    fig,axs = plt.subplots(1,  len(methods_)+2, figsize=(2*(len(methods_)), 2))
    axs[0].imshow(image_to_explain)
    axs[0].set_xticks([]) ; axs[0].set_yticks([])
    axs[1].imshow(ground_truth, cmap='binary')
    axs[1].set_xticks([]) ; axs[1].set_yticks([])

    for ii, (n,_,_) in tqdm(enumerate(methods_), desc='IoU',leave=False):
        ax = axs[ii+2]
        auc_IoU = IoU[n]
        img = vis_IoU(heatmaps[n][0], auc_IoU['max_IoU_heatmap_threshold'], ground_truth)
        max_IoU = np.max(auc_IoU['Y'])

        ax.imshow(img)
        if draw_gt_contour:
            # Transparent RGBA canvas with only the ground-truth boundary painted black.
            marked_h = mark_boundaries(np.tile((255,255,255,0), (heatmaps[n][0].shape[0],heatmaps[n][0].shape[1],1)), ground_truth,
                             mode='thick', color=(0,0,0,1))
            ax.imshow(marked_h)
        ax.text(text_x,text_y, f'IoU:{max_IoU:.3}', fontsize=fontsize, bbox=props,weight='bold')
        ax.set_xticks([]); ax.set_yticks([])
        if plot_title:
            ax.set_title(n, fontsize=fontsize)
    plt.subplots_adjust(wspace=0.05,hspace=0.05)
    if save_fig:
        if title is None:
            suffix = f'{image_no}_{params.background_type}_IoU.{selected_ext}'
        else:
            suffix = f'{image_no}_{params.background_type}_IoU_{title}.{selected_ext}'

        plt.savefig(f'{save_path}//{suffix}',dpi=dpi,transparent=transparent,bbox_inches='tight', pad_inches=0.02)
    if destroy_fig:
        plt.close(fig)
    plt.show()
    # No `del img, max_IoU` here: those names are unbound when `methods_` is empty,
    # and the locals are released when the function returns anyway.
###################################################################################################################
# def fun_plot_performance(aa_= 'AA-100',bpt_='BPT-100',save_fig=True,exp_type = 'demo',fontsize=14,selected_ext='png', dpi =150,transparent=True):
#     save_path = paths.results_path_single if exp_type == 'demo' else paths.plotsIoU_path
    
#     fig,axes = plt.subplots(1,5, figsize=(10,2), sharex=True, sharey=True) #3,3  figsize=(8,2.2)
#     if len(background_tensors)==1:
#         # if model_type=='ideal':    
#         aucI_aa =  aucI[aa_]
#         aucD_aa =  aucD[aa_]
#         auc_IoU_aa =  IoU[aa_]
#         ttl = 'PE'
#         # else:
#         #     aucI_aa =  aucI['Partition-100']
#         #     aucD_aa =  aucD['Partition-100']
#         #     auc_IoU_aa =  IoU['Partition-100']
#     else:
#         aucI_aa =  aucI[aa_]
#         aucD_aa =  aucD[aa_]
#         auc_IoU_aa =  IoU[aa_]
#         ttl = 'AA'
#     aucI_bpt =  aucI[bpt_]
#     aucD_bpt =  aucD[bpt_]
#     auc_IoU_bpt =  IoU[bpt_]
    
#     for i in range(5):
#         ax = axes.flat[i]
#         if i==0: # insertion/regression
#             # Xa,Ya,Ma,La      = aucI_aa['xs'],aucI_aa['ys'],aucI_aa['ms'],aucI_aa['auc_reg']
#             title='$\\mathit{AUC}^{+}$'
#             Xa,Ya,Ma,La      = aucI_aa ['xs'],aucI_aa ['ys'],aucI_aa ['ms'],aucI_aa ['auc_reg']
#             Xb,Yb,Mb,Lb      = aucI_bpt['xs'],aucI_bpt['ys'],aucI_bpt['ms'],aucI_bpt['auc_reg']
#             # Xb,Yb,Mb,Lb = aucI_bpt[0], aucI_bpt[1], aucI_bpt[2], aucI_bpt[3]
#         elif i==1: # deletion/regression
#             title='$\\mathit{AUC}^{-}$'
#             Xa,Ya,Ma,La      = aucD_aa ['xs'],aucD_aa ['ys'],aucD_aa ['ms'],aucD_aa ['auc_reg']
#             Xb,Yb,Mb,Lb      = aucD_bpt['xs'],aucD_bpt['ys'],aucD_bpt['ms'],aucD_bpt['auc_reg']
            
#             # Xa,Ya,Ma,La = aucD_aa[0], aucD_aa[1], aucD_aa[2], aucD_aa[3]
#             # Xb,Yb,Mb,Lb = aucD_bpt[0], aucD_bpt[1], aucD_bpt[2], aucD_bpt[3]
#     #     elif i==2: # insertion/efficiency
#     #         title='$\\mathit{AUC}^{+}_\\mathrm{eff}$'
#     #         Xa,Ya,Ma,La = aucI_aa[0]*100, aucI_aa[1]-aucI_aa[2], aucI_aa[2], aucI_aa[4]
#     #         Xb,Yb,Mb,Lb = aucI_bpt[0]*100, aucI_bpt[1]-aucI_bpt[2], aucI_bpt[2], aucI_bpt[4]
#     #     elif i==3: # deletion/efficiency
#     #         title='$\\mathit{AUC}^{-}_\\mathrm{eff}$'
#     #         Xa,Ya,Ma,La = aucD_aa[0]*100, aucD_aa[1]-aucD_aa[2], aucD_aa[2], aucD_aa[4]
#     #         Xb,Yb,Mb,Lb = aucD_bpt[0]*100, aucD_bpt[1]-aucD_bpt[2], aucD_bpt[2], aucD_bpt[4]
#         elif i==2: # insertion/error
#             title='$\\mathit{MSE}^{+}$'
#             Xa,Ya,Ma,La      = aucI_aa ['xs'],aucI_aa ['ys'],aucI_aa ['ms'],aucI_aa ['auc_reg']
#             Xb,Yb,Mb,Lb      = aucI_bpt['xs'],aucI_bpt['ys'],aucI_bpt['ms'],aucI_bpt['auc_reg']

#             Xa,Ya,Ma,La = aucI_aa ['xs'], (aucI_aa ['ys']- aucI_aa  ['ms'])**2, aucI_aa  ['ms'], aucI_aa  ['auc_mse']
#             Xb,Yb,Mb,Lb = aucI_bpt ['xs'],(aucI_bpt ['ys']-aucI_bpt ['ms'])**2, aucI_bpt ['ms'], aucI_bpt ['auc_mse']
            
#             # Xa,Ya,Ma,La = aucI_aa[0], (aucI_aa[1]-aucI_aa[2])**2, aucI_aa[2], aucI_aa[5]
#             # Xb,Yb,Mb,Lb = aucI_bpt[0], (aucI_bpt[1]-aucI_bpt[2])**2, aucI_bpt[2], aucI_bpt[5]
#         elif i==3: # deletion/error
#             title='$\\mathit{MSE}^{-}$'
#             Xa,Ya,Ma,La = aucI_aa ['xs'], (aucI_aa  ['ys']-aucI_aa  ['ms'])**2, aucI_aa  ['ms'], aucI_aa  ['auc_mse']
#             Xb,Yb,Mb,Lb = aucI_bpt ['xs'],(aucI_bpt ['ys']-aucI_bpt ['ms'])**2, aucI_bpt ['ms'], aucI_bpt ['auc_mse']
            
#             # Xa,Ya,Ma,La = aucD_aa[0], (aucD_aa[1]-aucD_aa[2])**2, aucD_aa[2], aucD_aa[5]
#             # Xb,Yb,Mb,Lb = aucD_bpt[0], (aucD_bpt[1]-aucD_bpt[2])**2, aucD_bpt[2], aucD_bpt[5]
#         elif i==4: # IoU
#             title='$\\mathit{AU}{\!-\!}\\mathit{IoU}$'
#             Xa,Ya,Ma,La = auc_IoU_aa['X'], auc_IoU_aa['Y'], None, auc_IoU_aa['auc_IoU']     #[4]
#             Xb,Yb,Mb,Lb = auc_IoU_bpt['X'], auc_IoU_bpt['Y'], None, auc_IoU_bpt['auc_IoU']  #[4]
#             xma, yma = auc_IoU_aa['x_best'], np.max(auc_IoU_aa['Y'])
#             xmb, ymb = auc_IoU_bpt['x_best'], np.max(auc_IoU_bpt['Y'])
            
#     #     ymax = max(np.max(Ya), np.max(Yb))
        
#         Sa, Sb = ('\\textbf', '') if La<Lb else ('', '\\textbf')
#         if i in [0,4]:   Sa,Sb=Sb,Sa
#         ax.plot(Xa, Ya, c='#2465a6', label=f'{Sa}{{AA}} {La:.4}')
#         ax.fill_between(Xa, Ya, color='#2465a6', alpha=0.15, hatch='///')
#         # ax.scatter(Xa, Ya, c='black', s=5)
#         ax.plot(Xb, Yb, c='#9d2f4d', label=f'{Sb}{{BPT}} {Lb:.4}', alpha=0.80)
#         ax.fill_between(Xb, Yb, color='#9d2f4d', alpha=0.15, hatch='\\\\\\')
        
#         if i==4:
#             ax.scatter(xma, yma, s=40, color='blue')
#             ax.scatter(xmb, ymb, s=40, color='red')
#     #     ax.set_ylim(-0.05, round(ymax+0.1, 1))
#         # ax.scatter(Xb, Yb, c='black', s=5)
    
#     #     if i < 2:
#     #         ax.plot(Xa, Ma, c='blue', lw=1, ls=':')
#     #         ax.plot(Xb, Mb, c='red', lw=1, ls=':')
    
#         # if i in [0,1]:
#         if i < 2:
#             ax.axhline(f_S, ls='--', c='grey', zorder=0)
            
#         ax.axhline(0, c='lightgrey', zorder=0)
#         ax.legend(borderpad=0.2, labelspacing=0.1, loc='upper right' if i>=1 else 'lower right')#, bbox_to_anchor=(1,0))
#         ax.set_title(title, fontsize=fontsize)
    
#     # axes[0].set_xlabel('\% pixels inserted', fontsize=14)
#     # axes[1].set_xlabel('\% pixels deleted', fontsize=14)
#     # axes[2].set_xlabel('\% pixels inserted', fontsize=14)
#     # axes[3].set_xlabel('\% pixels deleted', fontsize=14)
    
#     plt.tight_layout()
#     suffix = f'{image_no}_{params.background_type}_PM.{selected_ext}'
#     if save_fig:
#         plt.savefig(f'{save_path}//{suffix}',dpi=dpi,transparent=transparent,bbox_inches='tight', pad_inches=0.02)
#         # plt.savefig(f'{paths.results_path_single}/five_metrics_{params.background_type}_2.png', dpi=150, bbox_inches='tight', pad_inches=0.02)
#     # plt.savefig(f'{results_path_single}/five_metrics_{background_type}_2.svg', dpi=150, bbox_inches='tight', pad_inches=0.02)
    
#     plt.show()

In [None]:
# def fun_plot_performance(aucI,aucD,IoU,budget_set='100',path=None,save_fig=True,fontsize=14):

#     fig,axes = plt.subplots(1,9, figsize=(20,2.5), sharex=True)#, sharey=True) #3,3  figsize=(8,2.2)
#     if len(background_tensors)==1:
#         if params.model_type=='ideal':    
#             aucI_aa =  aucI['AA-100']
#             aucD_aa =  aucD['AA-100']
#             auc_IoU_aa =  IoU['AA-100']
#             ttl = 'PE'
#         else:
#             aucI_aa =  aucI['AA-100']
#             aucD_aa =  aucD['AA-100']
#             auc_IoU_aa =  IoU['AA-100']
#     else:
#         aucI_aa =  aucI['AA-100']
#         aucD_aa =  aucD['AA-100']
#         auc_IoU_aa =  IoU['AA-100']
#         ttl = 'AA'
#     aucI_bpt =  aucI['BPT-100']
#     aucD_bpt =  aucD['BPT-100']
#     auc_IoU_bpt =  IoU['BPT-100']
    
#     for i in range(9):
#         ax = axes.flat[i]
#         ############  AUC
#         if i==0: # insertion/regression
#             # Xa,Ya,Ma,La      = aucI_aa['xs'],aucI_aa['ys'],aucI_aa['ms'],aucI_aa['auc_reg']
#             title='$\\mathit{AUC}^{+}$'
#             Xa,Ya,Ma,La      = aucI_aa ['xs'],aucI_aa ['ys'],aucI_aa ['ms'],aucI_aa ['auc']
#             Xb,Yb,Mb,Lb      = aucI_bpt['xs'],aucI_bpt['ys'],aucI_bpt['ms'],aucI_bpt['auc']
#             # Xb,Yb,Mb,Lb = aucI_bpt[0], aucI_bpt[1], aucI_bpt[2], aucI_bpt[3]
#         elif i==1: # deletion/regression
#             title='$\\mathit{AUC}^{-}$'
#             Xa,Ya,Ma,La      = aucD_aa ['xs'],aucD_aa ['ys'],aucD_aa ['ms'],aucD_aa ['auc']
#             Xb,Yb,Mb,Lb      = aucD_bpt['xs'],aucD_bpt['ys'],aucD_bpt['ms'],aucD_bpt['auc']


#         ############  AUC_Adj
#         elif i==2: # insertion/error
#             title='$\\mathit{AUC-Adj}^{+}$'
#             Xa,Ya,Ma,La      = aucI_aa ['xs'],aucI_aa ['y_adj'],aucI_aa ['ms'],aucI_aa ['auc_adj']
#             Xb,Yb,Mb,Lb      = aucI_bpt['xs'],aucI_bpt['y_adj'],aucI_bpt['ms'],aucI_bpt['auc_adj']
#         elif i==3: # deletion/error
#             title='$\\mathit{AUC-Adj}^{-}$'
#             Xa,Ya,Ma,La      = aucD_aa ['xs'],aucD_aa ['y_adj'],aucD_aa ['ms'],aucD_aa ['auc_adj']
#             Xb,Yb,Mb,Lb      = aucD_bpt['xs'],aucD_bpt['y_adj'],aucD_bpt['ms'],aucD_bpt['auc_adj']



#         ############  AUC_r
#         elif i==4: # insertion/error
#             title='$\\mathit{AUC-r}^{+}$'
#             Xa,Ya,Ma,La      = aucI_aa ['xs'],aucI_aa ['ysr'],aucI_aa ['ms'],aucI_aa ['auc_r']
#             Xb,Yb,Mb,Lb      = aucI_bpt['xs'],aucI_bpt['ysr'],aucI_bpt['ms'],aucI_bpt['auc_r']            
        
#         elif i==5: # deletion/error
#             title='$\\mathit{AUC-r}^{-}$'
#             Xa,Ya,Ma,La      = aucD_aa ['xs'],aucD_aa ['ysr'],aucD_aa ['ms'],aucD_aa ['auc_r']
#             Xb,Yb,Mb,Lb      = aucD_bpt['xs'],aucD_bpt['ysr'],aucD_bpt['ms'],aucD_bpt['auc_r']

        

#         ############  AUC_Adj_r


#         elif i==6: # insertion/error
#             title='$\\mathit{AUC-Adj-r}^{+}$'
#             Xa,Ya,Ma,La      = aucI_aa ['xs'],aucI_aa ['y_adjr'],aucI_aa ['ms'],aucI_aa ['auc_adjr']
#             Xb,Yb,Mb,Lb      = aucI_bpt['xs'],aucI_bpt['y_adjr'],aucI_bpt['ms'],aucI_bpt['auc_adjr']
#         elif i==7: # deletion/error
#             title='$\\mathit{AUC-Adj-r}^{-}$'
#             Xa,Ya,Ma,La      = aucD_aa ['xs'],aucD_aa ['y_adjr'],aucD_aa ['ms'],aucD_aa ['auc_adjr']
#             Xb,Yb,Mb,Lb      = aucD_bpt['xs'],aucD_bpt['y_adjr'],aucD_bpt['ms'],aucD_bpt['auc_adjr']
            
#         elif i==8: # IoU
#             # return {'X':X2, 'Y':IoU2,
#             #  'max_IoU_heatmap_threshold':Th[best_pt], 
#             # 'max_IoU_score':IoU2[best_pt], 'x_best':X2[best_pt], 'auc_IoU':auc_IoU}
#             title='$\\mathit{AU}{\!-\!}\\mathit{IoU}$'
#             Xa,Ya,Ma,La = auc_IoU_aa['X'], auc_IoU_aa['Y'], None, auc_IoU_aa['x_best']
#             Xb,Yb,Mb,Lb = auc_IoU_bpt['X'], auc_IoU_bpt['Y'], None, auc_IoU_bpt['x_best']

#             xma, yma = auc_IoU_aa['max_IoU_score'], np.max(auc_IoU_aa['Y'])
#             xmb, ymb = auc_IoU_bpt['max_IoU_score'], np.max(auc_IoU_bpt['Y'])
            
#         # print(i,'  len(Xb): ',len(Xb),'  len(Yb): ', len(Yb), len(Xb)==len(Yb), '  len(Xa): ',len(Xa),'  len(Ya): ', len(Ya), len(Xa)==len(Ya))
#     #     ymax = max(np.max(Ya), np.max(Yb))
        
#         Sa, Sb = ('\\textbf', '') if La<Lb else ('', '\\textbf')
#         if i in [0,2,4,6, 8]:   Sa,Sb=Sb,Sa
#         ax.plot(Xa, Ya, c='#2465a6', label=f'{Sa}{{AA}} {La:.4}')
#         ax.fill_between(Xa, Ya, color='#2465a6', alpha=0.15, hatch='///')
#         # ax.scatter(Xa, Ya, c='black', s=5)
#         ax.plot(Xb, Yb, c='#9d2f4d', label=f'{Sb}{{BPT}} {Lb:.4}', alpha=0.80)
#         ax.fill_between(Xb, Yb, color='#9d2f4d', alpha=0.15, hatch='\\\\\\')
        
#         if i==8:
#             ax.scatter(xma, yma, s=40, color='blue')
#             ax.scatter(xmb, ymb, s=40, color='red')

#         if i < 4:
#             ax.axhline(f_S, ls='--', c='grey', zorder=0)
            
#         ax.axhline(0, c='lightgrey', zorder=0)
#         ax.legend(borderpad=0.2, labelspacing=0.1, loc='upper right' if i>=1 else 'lower right')#, bbox_to_anchor=(1,0))
#         ax.set_title(title, fontsize=fontsize)
    
#     # axes[0].set_xlabel('\% pixels inserted', fontsize=14)
#     # axes[1].set_xlabel('\% pixels deleted', fontsize=14)
#     # axes[2].set_xlabel('\% pixels inserted', fontsize=14)
#     # axes[3].set_xlabel('\% pixels deleted', fontsize=14)
    
#     plt.tight_layout()
#     if save_fig:
#         if path is not None:
#             plt.savefig(f'{path}/{image_n}_five_metrics_{params.background_type}_2.png', dpi=150, bbox_inches='tight', pad_inches=0.02)
#         else:
#             print('path is None')
#     # plt.savefig(f'{path}/five_metrics_{background_type}_2.svg', dpi=150, bbox_inches='tight', pad_inches=0.02)
    
#     plt.show()

In [None]:
# Render all figure text through LaTeX (the titles/legends below use \mathit and
# \textbf markup) and load the `color` package into the LaTeX preamble.
from matplotlib import rc

rc('text', usetex=True)
rc('text.latex', preamble='\\usepackage{color}')

def fun_plot_performance(aucI, aucD, IoU,
                         list_variants=('BPT-100', 'AA-100', 'LIME-100'),
                         budget_set='100',
                         path=None,
                         file_name='',
                         save_fig=True,
                         fontsize=14,
                         set_title=False,
                         plot_lime=False,
                         ttl=None,
                         layout='row'  # 'row' or '2rows'
                         ):
    """Plot insertion/deletion AUC curves and the IoU curve for two or three XAI variants.

    One panel is drawn per entry of ``plot_defs``. It is followed by one panel
    showing ``image_to_explain``, titled with the predicted class and ``f_S``.

    Relies on notebook globals: ``image_to_explain``, ``class_names``,
    ``predicted_cls`` and ``f_S``.

    Args:
        aucI, aucD: dicts mapping variant name -> insertion / deletion metric dict.
            Keys read: 'xs', and the curve/metric key pairs in ``auc_keys`` below.
        IoU: dict mapping variant name -> dict with 'X', 'Y', 'x_best', 'auc_IoU'.
        list_variants: variant names. The first two are always plotted (blue, red);
            the third (teal) only when ``plot_lime`` is True. Legend labels are
            the part of each name before the first '-' (e.g. 'BPT-100' -> 'BPT').
        budget_set: unused; kept for backward compatibility.
        path, file_name: output directory and file-name prefix used by ``save_fig``.
        save_fig: save a PNG when True and ``path`` is not None.
        fontsize: font size of the metric panel titles.
        set_title: add ``ttl`` as the figure suptitle.
        plot_lime: also plot the third variant.
        ttl: optional title. Its text before the first ',' prefixes the image-panel title.
        layout: 'row' (one row) or '2rows' (panels filled row-major over two rows).

    Raises:
        ValueError: if ``layout`` is neither 'row' nor '2rows'.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # One entry per metric panel; 'from' selects insertion (aucI), deletion (aucD) or IoU data.
    plot_defs = [
        {'title': '$\\mathit{AUC}^{+}$',         'typee': 'auc',        'from': 'aucI'}, # 0
        {'title': '$\\mathit{AUC}^{-}$',         'typee': 'auc',        'from': 'aucD'}, # 1
        {'title': '$\\mathit{AUC}^{+}-clip$',     'typee': 'auc_clip',   'from': 'aucI'}, # 2
        {'title': '$\\mathit{AUC}^{-}-clip$',     'typee': 'auc_clip',   'from': 'aucD'}, # 3
        {'title': '$\\mathit{AUC^{+}-Adj}$',     'typee': 'auc_adj',    'from': 'aucI'}, # 4
        {'title': '$\\mathit{AUC^{-}-Adj}$',     'typee': 'auc_adj',    'from': 'aucD'}, # 5
        {'title': '$\\mathit{AU}{\\!-\\!}\\mathit{IoU}$', 'typee': 'auc_iou', 'from': 'iou'}, # 6
        {'title': '$\\mathit{AUC}^{+}-r$',       'typee': 'auc_r',      'from': 'aucI'}, # 7
        {'title': '$\\mathit{AUC}^{-}-r$',       'typee': 'auc_r',      'from': 'aucD'}, # 8
        {'title': '$\\mathit{AUC}^{+}-clip-r$', 'typee': 'auc_clip_r', 'from': 'aucI'}, # 9
        {'title': '$\\mathit{AUC}^{-}-clip-r$', 'typee': 'auc_clip_r', 'from': 'aucD'}, # 10
        {'title': '$\\mathit{AUC^{+}-Adj-r}$',   'typee': 'auc_adj_r',  'from': 'aucI'}, # 11
        {'title': '$\\mathit{AUC^{-}-Adj-r}$',   'typee': 'auc_adj_r',  'from': 'aucD'}, # 12
    ]
    # (curve key, scalar metric key) inside an aucI/aucD entry, per panel type.
    auc_keys = {
        'auc':        ('ys',      'auc'),
        'auc_r':      ('ysr',     'auc_r'),
        'auc_adj':    ('y_adj',   'auc_adj'),
        'auc_adj_r':  ('y_adjr',  'auc_adjr'),
        'auc_clip':   ('y_clip',  'auc_clip'),
        'auc_clip_r': ('y_clipr', 'auc_clipr'),
    }
    save_version = list_variants[0].split('-')[-1]

    # Variants drawn in every panel; the third one is only fetched when plot_lime is set.
    variants = list(list_variants[:3 if plot_lime else 2])
    labels = [name.split('-')[0] for name in variants]
    line_colors = ['#2465a6', '#9d2f4d', '#01B0A7']
    line_alphas = [None, 0.80, 0.80]
    hatches = ['///', '\\\\\\', '\\\\\\']
    marker_colors = ['blue', 'red', '#01B0A7']

    # Metric panels plus one extra panel for the input image, so it never covers a metric.
    n_panels = len(plot_defs) + 1
    if layout == 'row':
        nrows, ncols = 1, n_panels
        fig_width, fig_height = 25, 2.5
    elif layout == '2rows':
        nrows, ncols = 2, int(np.ceil(n_panels / 2))
        fig_width, fig_height = 16, 5
    else:
        raise ValueError("layout must be either 'row' or '2rows'")

    fig, axes = plt.subplots(nrows, ncols, figsize=(fig_width, fig_height))
    axes = axes.flatten()

    sources = {'aucI': aucI, 'aucD': aucD}

    def get_curve(name, config):
        """Return (x, y, scalar metric, best point or None) of variant `name` for one panel."""
        if config['from'] == 'iou':
            iou_ = IoU[name]
            return iou_['X'], iou_['Y'], iou_['auc_IoU'], (iou_['x_best'], np.max(iou_['Y']))
        auc_ = sources[config['from']][name]
        y_key, metric_key = auc_keys[config['typee']]
        return auc_['xs'], auc_[y_key], auc_[metric_key], None

    for i, config in enumerate(plot_defs):
        ax = axes[i]
        curves = [get_curve(name, config) for name in variants]

        # Insertion AUCs and IoU: higher is better. Deletion AUCs: lower is better.
        metrics = [metric for _, _, metric, _ in curves]
        best = max(metrics) if config['from'] in ('aucI', 'iou') else min(metrics)

        for k, (X, Y, metric, peak) in enumerate(curves):
            bold = '\\textbf' if metric == best else ''
            ax.plot(X, Y, c=line_colors[k], label=f'{bold}{{{labels[k]}}} {metric:.4f}', alpha=line_alphas[k])
            ax.fill_between(X, Y, color=line_colors[k], alpha=0.15, hatch=hatches[k])
            if peak is not None:
                ax.scatter(*peak, s=40, color=marker_colors[k])

        if i < 4:
            ax.axhline(0.0, ls='--', c='grey', zorder=0)  # placeholder for f_S - f_0
        ax.axhline(0, c='lightgrey', zorder=0)
        ax.legend(borderpad=0.2, labelspacing=0.1, loc='upper right' if i >= 1 else 'lower right')
        ax.set_title(config['title'], fontsize=fontsize)

    img_ax = axes[len(plot_defs)]
    img_ax.imshow(image_to_explain); img_ax.set_xticks([]); img_ax.set_yticks([])
    prefix = f'{ttl.split(",")[0]} - ' if ttl is not None else ''
    img_ax.set_title(f'{prefix}{class_names[predicted_cls]} {f_S}')
    # Hide any leftover grid cells (possible in the '2rows' layout).
    for spare_ax in axes[len(plot_defs) + 1:]:
        spare_ax.axis('off')

    if set_title:
        plt.suptitle(f'{ttl}', fontsize=16)
    plt.tight_layout()

    if save_fig and path is not None:
        plt.savefig(f'{path}/{file_name}_{save_version}_paired_auc_metrics.png', dpi=150, bbox_inches='tight', pad_inches=0.02)
    plt.show()


In [None]:
def plot_perform_iou(methods_, save_fig=True, exp_type='demo', fontsize=12, selected_ext='png', dpi=150, transparent=True):
    """Plot each method's IoU curve next to the input image and the ground truth.

    Each method panel shows ``IoU[n]['Y']`` against ``IoU[n]['X']``. The area is
    hatched, the point ``(x_best, max(Y))`` is marked, and the legend shows
    ``IoU[n]['auc_IoU']``.

    Relies on notebook globals: ``paths``, ``params``, ``image_no``,
    ``image_to_explain``, ``ground_truth``, ``IoU`` and ``f_S``.

    Args:
        methods_: iterable of 3-tuples whose first element is the method name.
        save_fig: save the figure.
        exp_type: 'demo' -> ``paths.results_path_single``, otherwise ``paths.plotsIoU_path``.
        fontsize: font size for legends and titles.
        selected_ext: file extension of the saved figure.
        dpi, transparent: forwarded to ``plt.savefig``.
    """
    save_path = paths.results_path_single if exp_type == 'demo' else paths.plotsIoU_path

    n_cols = len(methods_) + 2  # input image + ground truth + one panel per method
    # constrained_layout handles spacing. plt.subplots_adjust is incompatible with it
    # (matplotlib warns and ignores it), so it is not called.
    fig, axs = plt.subplots(1, n_cols, figsize=(2 * n_cols-2, 2), constrained_layout=True)

    axs[0].imshow(image_to_explain)
    axs[0].set_xticks([])
    axs[0].set_yticks([])

    axs[1].imshow(ground_truth, cmap='binary')
    axs[1].set_xticks([])
    axs[1].set_yticks([])

    for ii, (n, _, _) in tqdm(enumerate(methods_), desc='IoU', leave=False):
        ax = axs[ii + 2]
        auc_IoU = IoU[n]
        Xa, Ya, La = auc_IoU['X'], auc_IoU['Y'], auc_IoU['auc_IoU']
        xma, yma = auc_IoU['x_best'], np.max(auc_IoU['Y'])

        ax.plot(Xa, Ya, c='#2465a6', label=f'{La:.4}')
        ax.fill_between(Xa, Ya, color='#2465a6', alpha=0.15, hatch='///')
        ax.scatter(xma, yma, s=40, color='blue')
        ax.axhline(f_S, ls='--', c='grey', zorder=0)
        ax.axhline(0, c='lightgrey', zorder=0)
        ax.legend(borderpad=0.2, labelspacing=0.1, loc='upper right', fontsize=fontsize)
        # '\\!' spelled out: '\!' is an invalid string escape (SyntaxWarning on Python 3.12+).
        ax.set_title('$\\mathit{AU}{\\!-\\!}\\mathit{IoU}$', fontsize=fontsize)

    if save_fig:
        suffix = f'{image_no}_{params.background_type}_PM_iou.{selected_ext}'
        plt.savefig(f'{save_path}//{suffix}', dpi=dpi, transparent=transparent, bbox_inches='tight', pad_inches=0.02)
        print(save_path)

    plt.show()


## Filtered plots: a reduced set of XAI methods for the paper figures

In [None]:
def fun_plot_filtered_heatmaps(methods_, heatmaps, exp_type=None, draw_gt_contour=True, title=None, 
                               plot_title=False, ttl=None, fontsize=14, plot_colorbar=False, save_fig=True, 
                               selected_ext='svg', dpi=200, transparent=True, destroy_fig=False, filter_methods=None):
    """Plot the input image, the ground truth and the heatmaps of a filtered subset of methods.

    Each heatmap ``heatmaps[n][0]`` is drawn with ``shap_bpt.shapley_values_colormap``
    on a symmetric color range ``[-v, v]``, where ``v`` is the 0.99 quantile of its
    absolute values.

    Relies on notebook globals: ``paths``, ``params``, ``image_no``,
    ``image_to_explain``, ``ground_truth``, ``shap_bpt`` and ``mark_boundaries``.

    Args:
        methods_: iterable of 3-tuples ``(name, c, v)``.
        heatmaps: dict mapping method name -> sequence whose element 0 is the heatmap.
        exp_type: 'demo' -> ``paths.results_path_single``,
            'selected' -> ``paths.results_path_selected``,
            anything else -> ``paths.plots_path``.
        draw_gt_contour: overlay the ground-truth boundary on every heatmap.
        title: optional suffix appended to the saved file name.
        plot_title: write the method name above each heatmap.
        ttl: optional y-axis label placed on the input-image panel.
        fontsize: font size for titles and the y-axis label.
        plot_colorbar: unused; kept for backward compatibility.
        save_fig: save the figure.
        selected_ext: ignored (kept for backward compatibility). The extension is
            'svg' when ``exp_type == 'selected'`` and 'png' otherwise.
        dpi, transparent: forwarded to ``plt.savefig``.
        destroy_fig: close the figure after saving.
        filter_methods: collection of method names to keep; None keeps all.
    """
    # Filter methods based on user selection
    selected_methods = [(n, c, v) for n, c, v in methods_ if filter_methods is None or n in filter_methods]
    num_methods = len(selected_methods)

    if num_methods == 0:
        print("No methods selected for filtering. Exiting.")
        return

    if exp_type == 'demo':
        save_path = paths.results_path_single
    elif exp_type == 'selected':
        save_path = paths.results_path_selected
    else:
        save_path = paths.plots_path

    # The extension is driven by the experiment type; the `selected_ext` argument is overridden.
    selected_ext = 'svg' if exp_type == 'selected' else 'png'

    # Set up subplots: input + ground truth + selected methods
    num_subplots = num_methods + 2
    fig, axes = plt.subplots(1, num_subplots, figsize=(2 * num_subplots, 2))

    axes[0].imshow(image_to_explain)
    axes[0].set_xticks([]) 
    axes[0].set_yticks([])

    axes[1].imshow(ground_truth, cmap='binary')
    axes[1].set_xticks([]) 
    axes[1].set_yticks([])

    for i, (n, c, _) in enumerate(selected_methods):
        ax = axes[i + 2]
        vmax = np.quantile(np.abs(heatmaps[n][0]), 0.99)
        ax.imshow(heatmaps[n][0], cmap=shap_bpt.shapley_values_colormap, vmin=-vmax, vmax=vmax)

        if draw_gt_contour:
            # Transparent RGBA canvas with only the ground-truth boundary painted black.
            marked_h = mark_boundaries(
                np.tile((255, 255, 255, 0), (heatmaps[n][0].shape[0], heatmaps[n][0].shape[1], 1)), 
                ground_truth, mode='thick', color=(0, 0, 0, 1)
            )
            ax.imshow(marked_h)

        ax.set_xticks([]) 
        ax.set_yticks([])

        if plot_title:
            ax.set_title(n, fontsize=fontsize)

    # Optional y-axis label. set_yticklabels would raise here: the ticks were cleared
    # above, so the number of labels would not match the number of ticks.
    if ttl is not None:
        axes[0].set_ylabel(str(ttl), fontsize=fontsize)

    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    if save_fig:
        if title is None:
            suffix = f'{image_no}_{params.background_type}_filtered_heatmap.{selected_ext}'
        else:
            suffix = f'{image_no}_{params.background_type}_filtered_heatmap_{title}.{selected_ext}'
        
        print('Saving to:', save_path)
        plt.savefig(f'{save_path}//{suffix}', dpi=dpi, transparent=transparent, bbox_inches='tight', pad_inches=0.02)

    if destroy_fig:
        plt.close(fig)

    plt.show()
    # No trailing `del vmax, marked_h`: `marked_h` is unbound when draw_gt_contour=False.
################################################################################################################



#### fun_plot_filtered_IoU

In [None]:
def fun_plot_filtered_IoU(methods_, heatmaps, exp_type='demo', title=None, draw_gt_contour=False, 
                           plot_title=False, fontsize=14, save_fig=True, selected_ext='svg', dpi=200, 
                           transparent=True, destroy_fig=False, filter_methods=None,
                           text_x=20, text_y=60):
    """Plot the input image, the ground truth and best-IoU overlays for a filtered subset of methods.

    For every kept method ``n`` the heatmap ``heatmaps[n][0]`` is rendered with
    ``vis_IoU`` at ``IoU[n]['max_IoU_heatmap_threshold']``. The panel is annotated
    with ``max(IoU[n]['Y'])``.

    Relies on notebook globals: ``paths``, ``params``, ``image_no``,
    ``image_to_explain``, ``ground_truth``, ``IoU``, ``vis_IoU``,
    ``mark_boundaries`` and ``props``.

    Args:
        methods_: iterable of 3-tuples ``(name, c, v)``.
        heatmaps: dict mapping method name -> sequence whose element 0 is the heatmap.
        exp_type: 'demo' -> ``paths.results_path_single``,
            'selected' -> ``paths.results_path_selected``,
            anything else -> ``paths.plotsIoU_path``.
        title: optional suffix appended to the saved file name.
        draw_gt_contour: overlay the ground-truth boundary on every method panel.
        plot_title: write the method name above each panel.
        fontsize: font size for the annotation and titles.
        save_fig: save the figure.
        selected_ext: ignored (kept for backward compatibility). The extension is
            'svg' when ``exp_type == 'selected'`` and 'png' otherwise.
        dpi, transparent: forwarded to ``plt.savefig``.
        destroy_fig: close the figure after saving.
        filter_methods: collection of method names to keep; None keeps all.
        text_x, text_y: data coordinates of the IoU annotation (same role as in ``fun_plot_IoU``).
    """
    selected_methods = [(n, c, v) for n, c, v in methods_ if filter_methods is None or n in filter_methods]
    num_methods = len(selected_methods)

    if num_methods == 0:
        print("No methods selected for filtering. Exiting.")
        return

    if exp_type == 'demo':
        save_path = paths.results_path_single
    elif exp_type == 'selected':
        save_path = paths.results_path_selected
    else:
        save_path = paths.plotsIoU_path

    # The extension is driven by the experiment type; the `selected_ext` argument is overridden.
    selected_ext = 'svg' if exp_type == 'selected' else 'png'
    
    # Set up subplots: input + ground truth + selected methods
    num_subplots = num_methods + 2
    fig, axs = plt.subplots(1, num_subplots, figsize=(2 * num_subplots, 2))

    axs[0].imshow(image_to_explain)
    axs[0].set_xticks([]) 
    axs[0].set_yticks([])

    axs[1].imshow(ground_truth, cmap='binary')
    axs[1].set_xticks([]) 
    axs[1].set_yticks([])

    for i, (n, _, _) in tqdm(enumerate(selected_methods), desc='IoU', leave=False):
        ax = axs[i + 2]
        auc_IoU = IoU[n]
        img = vis_IoU(heatmaps[n][0], auc_IoU['max_IoU_heatmap_threshold'], ground_truth)
        max_IoU = np.max(auc_IoU['Y'])

        ax.imshow(img)

        if draw_gt_contour:
            # Transparent RGBA canvas with only the ground-truth boundary painted black.
            marked_h = mark_boundaries(
                np.tile((255, 255, 255, 0), (heatmaps[n][0].shape[0], heatmaps[n][0].shape[1], 1)), 
                ground_truth, mode='thick', color=(0, 0, 0, 1)
            )
            ax.imshow(marked_h)

        ax.text(text_x, text_y, f'IoU:{max_IoU:.3}', fontsize=fontsize, bbox=props, weight='bold')
        ax.set_xticks([]) 
        ax.set_yticks([])

        if plot_title:
            ax.set_title(n, fontsize=fontsize)

    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    if save_fig:
        if title is None:
            suffix = f'{image_no}_{params.background_type}_filtered_IoU.{selected_ext}'
        else:
            suffix = f'{image_no}_{params.background_type}_filtered_IoU_{title}.{selected_ext}'
        
        print('Saving to:', save_path)
        plt.savefig(f'{save_path}//{suffix}', dpi=dpi, transparent=transparent, bbox_inches='tight', pad_inches=0.02)

    if destroy_fig:
        plt.close(fig)

    plt.show()

#### plot_perform_filtered_iou

In [None]:
def plot_perform_filtered_iou(methods_, save_fig=True, exp_type='demo', fontsize=12, selected_ext='png', 
                              dpi=150, transparent=True, filter_methods=None):
    """Plot IoU curves for a filtered subset of methods next to the input image and the ground truth.

    Each method panel shows ``IoU[n]['Y']`` against ``IoU[n]['X']``. The area is
    hatched, the point ``(x_best, max(Y))`` is marked, and the legend shows
    ``IoU[n]['auc_IoU']``.

    Relies on notebook globals: ``paths``, ``params``, ``image_no``,
    ``image_to_explain``, ``ground_truth``, ``IoU`` and ``f_S``.

    Args:
        methods_: iterable of 3-tuples ``(name, c, v)``.
        save_fig: save the figure.
        exp_type: 'demo' -> ``paths.results_path_single``, otherwise ``paths.plotsIoU_path``.
        fontsize: font size for legends and titles.
        selected_ext: file extension of the saved figure.
        dpi, transparent: forwarded to ``plt.savefig``.
        filter_methods: collection of method names to keep; None keeps all.
    """
    selected_methods = [(n, c, v) for n, c, v in methods_ if filter_methods is None or n in filter_methods]
    num_methods = len(selected_methods)

    if num_methods == 0:
        print("No methods selected for filtering. Exiting.")
        return

    save_path = paths.results_path_single if exp_type == 'demo' else paths.plotsIoU_path

    # Input image + ground truth + one panel per kept method. constrained_layout handles
    # spacing; plt.subplots_adjust is incompatible with it (warned and ignored), so it is not called.
    num_subplots = num_methods + 2
    fig, axs = plt.subplots(1, num_subplots, figsize=(2 * num_subplots - 2, 2), constrained_layout=True)

    axs[0].imshow(image_to_explain)
    axs[0].set_xticks([])
    axs[0].set_yticks([])

    axs[1].imshow(ground_truth, cmap='binary')
    axs[1].set_xticks([])
    axs[1].set_yticks([])

    for i, (n, _, _) in tqdm(enumerate(selected_methods), desc='IoU', leave=False):
        ax = axs[i + 2]
        auc_IoU = IoU[n]
        Xa, Ya, La = auc_IoU['X'], auc_IoU['Y'], auc_IoU['auc_IoU']
        xma, yma = auc_IoU['x_best'], np.max(auc_IoU['Y'])

        ax.plot(Xa, Ya, c='#2465a6', label=f'{La:.4}')
        ax.fill_between(Xa, Ya, color='#2465a6', alpha=0.15, hatch='///')
        ax.scatter(xma, yma, s=40, color='blue')
        ax.axhline(f_S, ls='--', c='grey', zorder=0)
        ax.axhline(0, c='lightgrey', zorder=0)
        ax.legend(borderpad=0.2, labelspacing=0.1, loc='upper right', fontsize=fontsize)
        # '\\!' spelled out: '\!' is an invalid string escape (SyntaxWarning on Python 3.12+).
        ax.set_title('$\\mathit{AU}{\\!-\\!}\\mathit{IoU}$', fontsize=fontsize)

    if save_fig:
        suffix = f'{image_no}_{params.background_type}_filtered_PM_iou.{selected_ext}'
        plt.savefig(f'{save_path}//{suffix}', dpi=dpi, transparent=transparent, bbox_inches='tight', pad_inches=0.02)
        print('Saved at:', save_path)

    plt.show()


#### fun_plot_performance_filtered

In [None]:
# LaTeX titles of the five panels drawn by `fun_plot_performance_filtered`, left to right.
_PERFORMANCE_PANEL_TITLES = (
    '$\\mathit{AUC}^{+}$',
    '$\\mathit{AUC}^{-}$',
    '$\\mathit{MSE}^{+}$',
    '$\\mathit{MSE}^{-}$',
    '$\\mathit{AU}{\\!-\\!}\\mathit{IoU}$',
)


def _performance_panel_series(data, panel):
    """Return (x, y, summary_metric) for one method's curve in panel `panel` (0-4).

    `data` holds the method's 'aucI', 'aucD' and 'IoU' result dicts. The MSE
    panels (2, 3) plot the squared gap between the measured curve 'ys' and the
    heatmap mass 'ms'.
    """
    if panel in (0, 1):
        res = data['aucI' if panel == 0 else 'aucD']
        return res['xs'], res['ys'], res['auc_reg']
    if panel in (2, 3):
        res = data['aucI' if panel == 2 else 'aucD']
        return res['xs'], (res['ys'] - res['ms']) ** 2, res['auc_mse']
    res = data['IoU']
    return res['X'], res['Y'], res['auc_IoU']


def fun_plot_performance_filtered(selected_methods, aa_='AA-500', bpt_='BPT-500', save_fig=True, exp_type='demo', fontsize=14, selected_ext='png', dpi=150, transparent=True):
    """Plot AUC+, AUC-, MSE+, MSE- and AU-IoU curves of several methods side by side.

    Curves are read from the global result dicts `aucI`, `aucD` and `IoU`. A
    method is plotted only if it has an entry in all three. Each legend entry
    shows the method name and that panel's summary metric: `auc_reg` for the
    AUC panels, `auc_mse` for the MSE panels and `auc_IoU` for the IoU panel.

    Args:
        selected_methods: method names to plot, e.g. 'AA-500'.
        aa_: method drawn in blue.
        bpt_: method drawn in red; every other method is drawn in black.
        save_fig: also write the figure to disk as
            `<save_path>/performance_plot.<selected_ext>`.
        exp_type: 'demo' saves under `paths.results_path_single`; anything
            else saves under `paths.plotsIoU_path`.
        fontsize: font size of the panel titles.
        selected_ext: file extension, which also selects the saved format.
        dpi, transparent: forwarded to `savefig`.
    """
    save_path = paths.results_path_single if exp_type == 'demo' else paths.plotsIoU_path

    # Keep only the methods that have results for every metric.
    auc_data = {
        method: {'aucI': aucI[method], 'aucD': aucD[method], 'IoU': IoU[method]}
        for method in selected_methods
        if method in aucI and method in aucD and method in IoU
    }

    method_colors = {aa_: '#2465a6', bpt_: '#9d2f4d'}

    fig, axes = plt.subplots(1, len(_PERFORMANCE_PANEL_TITLES), figsize=(10, 2), sharex=True, sharey=True)
    for i, (ax, title) in enumerate(zip(axes.flat, _PERFORMANCE_PANEL_TITLES)):
        for method, data in auc_data.items():
            color = method_colors.get(method, '#000000')  # black for any other method
            Xa, Ya, La = _performance_panel_series(data, i)
            ax.plot(Xa, Ya, c=color, label=f"{method} {La:.4}", alpha=0.8)
            ax.fill_between(Xa, Ya, color=color, alpha=0.15)
        ax.set_title(title, fontsize=fontsize)
        # legend() with no labelled artists only emits a warning, so skip it.
        if auc_data:
            ax.legend(loc='upper right' if i >= 1 else 'lower right')

    fig.tight_layout()

    if save_fig:
        # The dot matters: matplotlib infers the output format from the extension.
        fig.savefig(f"{save_path}/performance_plot.{selected_ext}", dpi=dpi, transparent=transparent, bbox_inches="tight", pad_inches=0.02)

    plt.show()


## Predict Image

In [None]:
# Run the detector once on the image to explain; `results` is reused by the cells below.
results = model.predict(
    image_to_explain,
    verbose=True,
)

In [None]:
def get_top_k_classes(results, k=3):
    """Return the k most confident detections as (class_id, class_name, confidence).

    Every box of every result in `results` is considered. Class names come
    from the global `model.names`. A class can appear more than once if
    several of its boxes are among the top k.

    Returns:
        list of (int, str, float) tuples ordered by confidence, highest first.
    """
    names = model.names
    ranked = []
    for result_ in results:
        ranked.extend(
            (int(box.cls.item()), names[int(box.cls.item())], box.conf.item())
            for box in result_.boxes
        )
        # Keep only the k best seen so far.
        ranked = sorted(ranked, key=lambda entry: entry[2], reverse=True)[:k]
    return ranked

In [None]:
top_k_classes = get_top_k_classes(results, k=params.top_classes_n)
# Each entry is (class_id, class_name, confidence), as built by get_top_k_classes.
print("Top-k Classes (Class ID, Class Name, Confidence):", top_k_classes)

print([class_names[int(cls)] for cls, _,_ in top_k_classes])
print('-'*120)

print('top_k_classes \t', top_k_classes)

# PLOT PREDICTION

In [None]:
# Draw the detections for `fixed_category` (set elsewhere in the notebook) and save the figure.
plot_predictions(
    image_to_explain,
    results,
    category_name=fixed_category,
    fig_size=(5, 5),
    save_fig=True,
)

# FUNC:   MASKING

In [None]:
def predict_yolo_masked(masks):
    """Score the YOLO model on masked copies of the global `image_to_explain`.

    For every mask, pixels where the mask is truthy keep the value of
    `image_to_explain`. All other pixels are taken, in turn, from each image in
    the global `background_image_set`. The per-class score vectors returned by
    `predict_yolo` are then averaged over the backgrounds.

    Args:
        masks: iterable of masks. A 2-D mask is stacked into three channels.
            A mask with more dimensions is passed to `np.where` as-is.

    Returns:
        np.ndarray with one row per mask, each row being the background-averaged
        output of `predict_yolo`. Returns an empty array when `masks` is empty.
    """
    imglst_preds = []
    for mask in masks:
        # The channel broadcast depends only on the mask, so build it once per
        # mask rather than once per (mask, background) pair.
        mask3 = np.stack([mask, mask, mask], axis=2) if len(mask.shape) == 2 else mask
        preds = [predict_yolo(np.where(mask3, image_to_explain, repl)) for repl in background_image_set]
        imglst_preds.append(np.mean(preds, axis=0))
    return np.array(imglst_preds)


def predict_yolo(x, verbose=False, num_classes=80):
    """Run the YOLO model on one image and return a dense per-class score vector.

    Args:
        x: image array, handed directly to the global `model.predict`.
        verbose: forwarded to `model.predict`.
        num_classes: length of the returned vector. Defaults to 80, the size
            the notebook has always used (presumably COCO; confirm for other
            models).

    Returns:
        np.ndarray of shape (num_classes,). Entry c is the highest confidence
        among the detected boxes of class c, or 0 if there is none.
    """
    res = model.predict(x, verbose=verbose)[0]
    scores = np.zeros(num_classes)
    classes = res.boxes.cls.cpu().numpy().astype(int)
    confs = res.boxes.conf.cpu().numpy()
    # Several boxes may share a class. Keep the best one rather than whichever
    # box happens to come last in the detection list.
    np.maximum.at(scores, classes, confs)
    torch.cuda.empty_cache()
    return scores


## TEST BLACKBOX FUNCTION

In [None]:
# Smoke test of the black-box wrapper: 100 all-zero masks (every pixel taken
# from the background) should yield one score row per mask.
print(image_to_explain.shape)
bg_ls = np.zeros((100,
                  image_to_explain.shape[0],
                  image_to_explain.shape[1],
                  image_to_explain.shape[2]))
print(bg_ls.shape)

pred = predict_yolo_masked(bg_ls)
print(pred.shape)

torch.cuda.empty_cache()

# XAI

In [None]:
# num_explained_classes = 2
# batch_size            = 32
# del batch_size

## 1. BPT

In [None]:
import shap_bpt

print("shap_bpt version:", shap_bpt.__version__)  # record the library version next to the results
def get_bpt_heatmaps(num_explained_classes=1, num_samples=100, verbose=False, batch_size=64):
    """Explain `image_to_explain` with shap_bpt using the 'BPT' partition method.

    `predict_yolo_masked` is the black-box value function. Returns whatever
    `Explainer.explain_instance` produces for `num_samples` evaluations.
    """
    bpt_explainer = shap_bpt.Explainer(
        predict_yolo_masked,
        image_to_explain,
        num_explained_classes=num_explained_classes,
        verbose=verbose,
    )
    heatmaps = bpt_explainer.explain_instance(num_samples, method='BPT', batch_size=batch_size)
    return heatmaps

## 2. Axis-Aligned

In [None]:
def get_aa_heatmaps(num_explained_classes=1, num_samples=100, verbose=False, batch_size=64, verbose_plot=False):
    """Explain `image_to_explain` with shap_bpt using the axis-aligned ('AA') partition.

    Same setup as the BPT variant, with `verbose_plot` also forwarded to
    `explain_instance`.
    """
    aa_explainer = shap_bpt.Explainer(
        predict_yolo_masked,
        image_to_explain,
        num_explained_classes=num_explained_classes,
        verbose=verbose,
    )
    heatmaps = aa_explainer.explain_instance(
        num_samples,
        method='AA',
        batch_size=batch_size,
        verbose_plot=verbose_plot,
    )
    return heatmaps

### Partition SHAP

In [None]:
import shap as shap

def shap_predict(img):
    """Black-box adapter handed to `shap.Explainer`; forwards the batch to `predict_yolo_masked`."""
    return predict_yolo_masked(img)


def get_pe_heatmaps(num_samples=1000, batch_size=32, num_explained_classes=4):
    """Partition-SHAP heatmaps for the highest-scoring output classes.

    Requires exactly one tensor in the global `background_tensors`. That tensor
    is the masking value of `shap.maskers.Image`.

    Args:
        num_samples: `max_evals` budget of the partition explainer.
        batch_size: batch size for the calls to `shap_predict`.
        num_explained_classes: how many top outputs to explain, ranked by
            score in descending order. Defaults to 4.

    Returns:
        np.ndarray with the explained-class axis first. The SHAP values are
        summed over image axis 2 (presumably the channel axis of an (H, W, C)
        image; confirm).
    """
    assert len(background_tensors) == 1
    shapMasker = shap.maskers.Image(background_tensors[0].detach().cpu().numpy(), image_to_explain.shape)
    shapPartExpl = shap.Explainer(shap_predict, shapMasker,
                                  output_names=[class_names[i] for i in range(len(class_names))],
                                  algorithm="partition")
    explanation = shapPartExpl(np.expand_dims(image_to_explain, axis=0),
                               max_evals=num_samples, batch_size=batch_size,
                               outputs=shap.Explanation.argsort.flip[:num_explained_classes])
    # [0] drops the batch axis added above. The sum collapses axis 2, then the
    # class axis (now axis 2) is moved to the front.
    shap_values_pe = np.moveaxis(np.sum(explanation.values[0], axis=2), 2, 0)
    del shapMasker, shapPartExpl, explanation
    return shap_values_pe

### GradientExplainer SHAP
 - https://shap-lrjball.readthedocs.io/en/latest/generated/shap.GradientExplainer.html

In [None]:
def get_gradexpl_heatmap(use_abs=True,num_explained_classes=2):
    """SHAP GradientExplainer heatmaps; not available for the YOLO detector.

    `shap.GradientExplainer` needs a differentiable model that maps an image
    tensor to class scores. In this notebook the detector is only reached
    through `model.predict`, which returns detection results (boxes carrying
    `cls` and `conf`). Calling `model.predict()` and passing its output as the
    model cannot produce a gradient explanation.

    This function therefore fails loudly instead of returning `None` to the
    evaluation code. The parameters are kept so existing method configurations
    still bind.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "get_gradexpl_heatmap: shap.GradientExplainer needs a differentiable "
        "model returning class scores; YOLO's model.predict returns detections."
    )

## 5. LIME
 - https://github.com/marcotcr/lime

In [None]:
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm
from skimage.segmentation import mark_boundaries

In [None]:
def get_segment_number(image, md):
    """Count the quickshift superpixels of `image` for a given `max_dist` (`md`).

    The other quickshift settings are fixed: kernel_size=4, ratio=0.2,
    random_seed=1234.
    """
    segmenter = SegmentationAlgorithm('quickshift', kernel_size=4, max_dist=md, ratio=0.2, random_seed=1234)
    labels = segmenter(image)
    return np.unique(labels).size


def search_segment_number(image, target_seg_no, init_max_dist=100, max_iter=3):
    """Bisect quickshift's `max_dist` towards `target_seg_no` superpixels.

    The search starts from the interval [0, init_max_dist]. At each step it
    segments at the midpoint. If the target lies between the midpoint's count
    and the lower end's count, the upper end moves to the midpoint; otherwise
    the lower end does.

    It stops after `max_iter` steps, or earlier once the upper end's count
    equals the target exactly.

    Returns:
        The `max_dist` of the final upper end. This is an approximation when
        the target was not hit exactly.
    """
    lo_dist, hi_dist = 0, init_max_dist
    lo_count = get_segment_number(image, lo_dist)
    hi_count = get_segment_number(image, hi_dist)
    steps = 0
    while steps < max_iter and hi_count != target_seg_no:
        steps += 1
        mid_dist = (lo_dist + hi_dist) / 2.0
        mid_count = get_segment_number(image, mid_dist)
        target_in_lower_half = mid_count <= target_seg_no <= lo_count
        if target_in_lower_half:
            hi_count, hi_dist = mid_count, mid_dist
        else:
            lo_count, lo_dist = mid_count, mid_dist
    return hi_dist

In [None]:
def format_lime_heatmaps(segments, expl):
    """Turn a LIME explanation into one rescaled heatmap per explained label.

    Each pixel gets the LIME importance of its superpixel. Superpixels without
    an importance get 0. Each heatmap is then rescaled so that its sum is
    roughly `predicted_fS[clsid] - predicted_f0[clsid]`, using the global
    score tables.

    Args:
        segments: integer superpixel label map, with the same labels LIME uses
            as feature ids.
        expl: LIME explanation exposing `top_labels` and `local_exp`, where
            `local_exp[label]` is a list of (segment_id, importance) pairs.

    Returns:
        np.ndarray of float32 with shape (len(expl.top_labels), *segments.shape).
    """
    eps = 1e-8  # keeps the rescaling finite when the importances sum to 0
    segments = np.asarray(segments)
    # Map every pixel to the index of its label once; each heatmap is then a
    # single gather, instead of one full-image comparison per superpixel.
    labels, inverse = np.unique(segments, return_inverse=True)
    inverse = inverse.reshape(segments.shape)
    position = {label: idx for idx, label in enumerate(labels.tolist())}

    class_heatmaps = []
    for clsid in expl.top_labels:
        weights = np.zeros(labels.shape[0], dtype=np.float64)
        for segm, importance in expl.local_exp[clsid]:
            idx = position.get(segm)
            if idx is not None:  # ids absent from `segments` cover no pixels
                weights[idx] += importance
        heatmap = weights[inverse].astype(np.float32)
        heatmap = heatmap * (predicted_fS[clsid] - predicted_f0[clsid]) / (np.sum(heatmap) + eps)
        class_heatmaps.append(heatmap)
    return np.array(class_heatmaps)

In [None]:
def lime_predict(img):
    """LIME `classifier_fn`: pass the batch straight to `predict_yolo_masked`.

    NOTE(review): predict_yolo_masked treats each array it receives as a mask
    (truthy pixels keep the image, the rest take the background). Stock LIME
    passes perturbed images to `classifier_fn`. Confirm that the LIME build in
    use passes masks.
    """
    scores = predict_yolo_masked(img)
    return scores

def get_lime_heatmaps(num_segments=100, batch_size=32,num_samples=1000,num_explained_classes=2, use_stratification=False,verbose=False):
    """LIME heatmaps averaged over the preprocessed backgrounds.

    The global `image_to_explain` is over-segmented with quickshift. Its
    `max_dist` is tuned by `search_segment_number` so the segment count
    approaches `num_segments`.

    For every background in `background_image_preproc_set`, a LIME explanation
    of `image_to_explain_preproc` is computed with that background as
    `hide_color`. The explanation is converted to heatmaps by
    `format_lime_heatmaps`, and the results are averaged over backgrounds.

    Args:
        num_segments: requested number of superpixels (an approximate target).
        batch_size: LIME classifier batch size.
        num_samples: number of LIME perturbation samples.
        num_explained_classes: forwarded to LIME as `top_labels`.
        use_stratification: currently unused; the matching LIME argument is
            disabled.
        verbose: print segmentation timing and show LIME's progress bar.

    Returns:
        np.ndarray: the mean over backgrounds of the `format_lime_heatmaps`
        outputs.
    """
    requested_segments = num_segments
    if verbose:
        start_time = datetime.now()
    segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
                                            max_dist=search_segment_number(image_to_explain, num_segments),
                                            ratio=0.2, random_seed=1234)
    segments = segmentation_fn(image_to_explain)
    # The search only approximates the target, so count what was actually produced.
    num_segments = len(np.unique(segments))
    if verbose:
        end_time = datetime.now()
        print(f'{requested_segments} {"segments Took":<20} : {end_time - start_time} and computed {num_segments} segments')

    def segments_getter(img):
        # Reuse the precomputed segmentation for every LIME call.
        return segments

    heatmap_list = []
    for bg_c in background_image_preproc_set:
        torch.cuda.empty_cache()
        lime_explainer = lime_image.LimeImageExplainer(random_state=1234)
        lime_expl = lime_explainer.explain_instance(image_to_explain_preproc,
                                                    lime_predict,
                                                    batch_size=batch_size,
                                                    top_labels=num_explained_classes,
                                                    segmentation_fn=segments_getter,
                                                    progress_bar=verbose,
                                                    # CHW tensor -> HWC array; .cpu() so GPU tensors work too
                                                    hide_color=bg_c.permute(1, 2, 0).detach().cpu().numpy(),
                                                    num_samples=num_samples)
        # The LIME build in use may return a tuple; the explanation object is then at index 2.
        if isinstance(lime_expl, tuple):
            lime_expl = lime_expl[2]
        heatmap_list.append(format_lime_heatmaps(segments, lime_expl))
        del lime_expl, lime_explainer
        torch.cuda.empty_cache()
    return np.mean(heatmap_list, axis=0)

In [None]:
# # https://arxiv.org/pdf/2305.20052.pdf
# # https://github.com/chasewalker26/Integrated-Decision-Gradients/tree/main
# import sys
# paths_repo = 'E:/PHD/datacloud_data/repos' 
# sys.path.append(f'{paths_repo}/Integrated-Decision-Gradients/util/attribution_methods')

# from saliencyMethods import IDG
# def get_idg_heatmaps(use_abs=True):
#     # global predicted_fS, predicted_f0
#     steps = 50
#     batch_si = 25 # default 25
#     baseline = 0
#     heatmaps = []
#     torch.cuda.empty_cache()
#     for clsid in sorted_classes[:num_explained_classes]:
#         heatmap = idg = IDG(torch.unsqueeze(image_to_explain_tensor, dim=0), model, 
#                             steps, batch_si, baseline, device, predicted_cls)
#         heatmap = idg.detach().cpu().numpy()
#         heatmap = np.mean(heatmap, axis=0) # reduce to one attribution per pixel
#         # normalize
#         if use_abs:
#             heatmap = np.abs(heatmap)
#         # heatmap -= np.min(heatmap)
#         heatmap = heatmap * (predicted_fS[clsid] - predicted_f0[clsid]) / np.sum(heatmap)
#         heatmaps.append(heatmap)
#         del heatmap
#     torch.cuda.empty_cache()
#     del steps,baseline,batch_si
#     return np.array(heatmaps)

### GradientShap

In [None]:
# # https://github.com/jacobgil/pytorch-grad-cam
# # https://captum.ai/api/gradient_shap.html
# from captum.attr import GradientShap
# def get_gradshap_captum_heatmaps(n_samples=50):
#     if params.pretrained_model_type == 'vit_LRP':
#         gradient_shap = rand_img_dist = clsid = None
#         return np.ones((1,224,224))
#     torch.cuda.empty_cache()
#     rand_img_dist = torch.cat([image_to_explain_tensor.unsqueeze(0) * 0, image_to_explain_tensor.unsqueeze(0) * 1])
#     gradient_shap = GradientShap(model)
    
#     heatmaps = []
#     for clsid in sorted_classes[:num_explained_classes]:
#         clsid = torch.tensor(clsid)
        
#         heatmap = gradient_shap.attribute(image_to_explain_tensor.unsqueeze(0),
#                                           n_samples=n_samples,
#                                           stdevs=0.0001,
#                                           baselines=rand_img_dist,
#                                           target=clsid)
#         heatmap = np.sum(heatmap.squeeze().cpu().detach().numpy(), axis=0)
#         heatmap = np.abs(heatmap)
#         heatmap = heatmap * (predicted_fS[clsid] - predicted_f0[clsid]) / np.sum(heatmap)
#         heatmaps.append(heatmap)
#         del heatmap
#     torch.cuda.empty_cache()
#     del gradient_shap,rand_img_dist,clsid
#     return np.array(heatmaps)

# COMBINE XAI

In [None]:
# Explainer registry: each entry is (name, plot colour, functools.partial that
# computes heatmaps). The numeric suffix of every name is its sample budget.
verbose = False

methods = [
    (f'BPT-{budget}', colour,
     partial(get_bpt_heatmaps, num_samples=budget, batch_size=64, verbose=verbose))
    for budget, colour in ((100, 'xkcd:light pink'),
                           (500, 'xkcd:bright pink'),
                           (1000, 'xkcd:deep pink'))
]

# Axis-aligned (AA) partition explainer (intended for the multiple-background setting).
methods_aa = [
    (f'AA-{budget}', 'xkcd:bright blue',
     partial(get_aa_heatmaps, num_samples=budget, batch_size=64, verbose=verbose))
    for budget in (100, 500, 1000)
]

# LIME: the requested superpixel count is one tenth of the sample budget.
methods_lime = [
    (f'LIME-{budget}', colour,
     partial(get_lime_heatmaps, num_segments=budget / 10, num_samples=budget, verbose=verbose))
    for budget, colour in ((100, 'xkcd:bright lime'),
                           (500, 'xkcd:kermit green'),
                           (1000, 'xkcd:dark lime green'))
]

methods += methods_aa
methods += methods_lime

for n, _, _ in methods:
    print(n)

# FUNC:   XAI 

In [None]:
# How many of the top predicted classes to explain.
num_explained_classes = 1

In [None]:
# #     name,                 color,                  functor
# verbose = False
# # if pretrained_model_type  == 'swin_trans_vit':
# #     verbose = True

# methods = [
#     ('BPT-100',         'xkcd:light pink',     partial(get_bpt_heatmaps, num_samples=100,verbose=verbose)),
#     ('BPT-500',         'xkcd:bright pink',     partial(get_bpt_heatmaps, num_samples=500,verbose=verbose)),
#     ('BPT-1000',         'xkcd:deep pink',     partial(get_bpt_heatmaps, num_samples=1000,verbose=verbose)),
#     ]
# methods_pe = [ # if single background
#     ('Partition-100',   'xkcd:bluish',     partial(get_pe_heatmaps, num_samples=100)),
#     ('Partition-500',   'xkcd:cerulean',   partial(get_pe_heatmaps, num_samples=500)),
#     ('Partition-1000',   'xkcd:soft blue', partial(get_pe_heatmaps, num_samples=1000))
#     ]
# methods_aa = [ # if multiple backgrounds
#     ('AA-100', 'xkcd:bright blue',     partial(get_aa_heatmaps, num_samples=100, verbose=verbose)),
#     ('AA-500', 'xkcd:bright blue',     partial(get_aa_heatmaps, num_samples=500, verbose=verbose)),
#     ('AA-1000', 'xkcd:bright blue',    partial(get_aa_heatmaps, num_samples=1000, verbose=verbose)),
# ]
# methods_aa_huge = [
#     ('AA-5000', 'xkcd:bright blue',      partial(get_aa_heatmaps, num_samples=5000, verbose=verbose)),
#     ('AA-10000', 'xkcd:bright blue',     partial(get_aa_heatmaps, num_samples=10000, verbose=verbose)),
# ]
# methods_lime = [
#     ('LIME-100',        'xkcd:bright lime',     partial(get_lime_heatmaps, num_segments=100/5, num_samples=100,verbose=verbose)),
#     ('LIME-500',        'xkcd:kermit green',   partial(get_lime_heatmaps, num_segments=500/5, num_samples=500,verbose=verbose)),
#     ('LIME-1000',        'xkcd:dark lime green',partial(get_lime_heatmaps, num_segments=1000/5, num_samples=1000,verbose=verbose))
#     ]

# methods_cam = [
#     ('aIDG',         'xkcd:indigo',          partial(get_idg_heatmaps, use_abs=True)),
#     ('aGradExpl',    'red',                  partial(get_gradexpl_heatmap, use_abs=True))
#     ]

# methods_ShapGradE = [
#     ('ShapGradE',     'xkcd:camel',            partial(get_gradshap_captum_heatmaps, n_samples=20)),    
# ]

# # methods_limeSAM = [
# #     ('LIMESAM',        'xkcd:bright lime',     partial(get_limeSAM_heatmaps, num_samples=500, verbose=verbose)),
# #     ]



# # methods_LRP_ViT = [
# #         ('LRP',        'xkcd:bright lime',     get_heatmaps_LRP_ViT),
    
# #     ]
# # methods_LRP = [
# #     ('LRP',        'xkcd:bright lime',     get_LRP_captum_heatmaps),
    
# #     ]

# # method_gradcam_vit_heatmaps = [
# #     ('GradCAM',     'xkcd:camel',            get_gradcam_heatmaps_for_swin_vit), #get_gradcam_vit_heatmaps),
# #     ]

# # if params.pretrained_model_type == 'vit_LRP':
# #     methods_gradcam = [
# #     ('GradCAM',     'xkcd:camel',            get_gradcam_vit_heatmaps),
# #     ]
# # else:
# #     methods_gradcam = [
# #     ('GradCAM',     'xkcd:camel',            get_gradcam_heatmaps),
# #     ]
    

# methods += methods_pe if len(background_tensors) == 1 and params.model_type!='ideal' else methods_aa
# # if params.model_type == 'ideal':
# #     methods += methods_aa_huge

# methods += methods_lime
# # methods += methods_limeSAM
# ######## CAM   #############    
# # if pretrained_model_type  != 'swin_trans_vit' :

# # if params.pretrained_model_type  == 'swin_trans_vit' or params.pretrained_model_type == 'vit_LRP':
# #     methods += methods_LRP_ViT
# # else:
# #     methods += methods_LRP

# # if params.pretrained_model_type  == 'swin_trans_vit' or params.pretrained_model_type == 'vit_LRP':
# #     methods += method_gradcam_vit_heatmaps
# # else:
# #     methods += methods_gradcam

# if params.model_type != 'ideal':
#     methods += methods_cam

# methods += methods_ShapGradE
# for n,_,functt in methods:
#     print(n, functt)

In [None]:
# import numpy as np
# # import random

# rng = np.random.default_rng(12345)
# rng.normal()

# Updated saliency_to_auc

In [None]:
# def saliency_to_auc(heatmap, batch_size=4, method='del', num_samples=101, add_noise=True):
#     xs, ys, ms, masks,qs = [], [], [], [],[]

    
#     if add_noise:
#         rng = np.random.default_rng(12345)
#         heatmap = heatmap + rng.normal(0.0, 0.000000001, size=heatmap.shape)

#     #heatmap = gaussian(heatmap, 2.0)
#     for i in np.linspace(start=1.0, stop=0.0, num=num_samples):
#         if method=='del':
#             epsilon = (1 if i==0.0 else 0)
#             q = (np.quantile(heatmap, q=i) - epsilon)
#             m = heatmap <= q
#             nx = (1.0 - np.sum(m) / m.size)
#         elif method=='ins':
#             epsilon = (1 if i==1.0 else 0)
#             q = (np.quantile(heatmap, q=i) + epsilon)
#             m = heatmap >= q
#             nx = (np.sum(m) / m.size)
#         else:
#             raise Exception()
            
#         if len(xs)==0 or nx != xs[-1]:
#             xs.append(nx)
#             masks.append(m)
#             ms.append(np.sum(heatmap[m]))
#             qs.append(q)
#             if len(masks) >= batch_size:
#                 y = predict_yolo_masked(np.array(masks))[:, predicted_cls]
#                 ys.extend(y)
#                 masks = []

#     if len(masks) > 0:
#         y = predict_yolo_masked(np.array(masks))[:, predicted_cls]
#         ys.extend(y)
    
#     xs, ys = np.array(xs), np.array(ys)
#     auc_reg, auc_eff, auc_mse = 0.0, 0.0, 0.0
#     assert(len(xs) == len(ys))
#     # compute the area under the curve - use the rectangle method
#     for i in range(1, len(xs)):
#         delta_x = abs(xs[i] - xs[i-1])
#         if delta_x > 0:
#             auc_reg += abs(delta_x * 0.5*(ys[i-1] + ys[i])) # base * height
#             auc_eff += abs(delta_x * 0.5*(ys[i-1] + ys[i] - ms[i-1] - ms[i])) # base * height
#             auc_mse += abs(delta_x * 0.5*(ys[i-1] + ys[i] - ms[i-1] - ms[i])**2) # base * height^2

#     # return xs, ys, ms, auc_reg, auc_eff, auc_mse
#     return {'xs':xs, 'ys':ys, 'ms':ms, 'qs':qs, 'auc_reg':auc_reg, 'auc_eff':auc_eff, 'auc_mse':auc_mse,'method':method}

In [None]:
# def saliency_to_auc(heatmap, f_S, f_0, predicted_cls, batch_size=4, method='del', num_samples=101, 
#                     rule='trapezoid'):

#     # print(f'Computing AUC with method={method}, batch_size={batch_size}, num_samples={num_samples}, rule={rule}')
#     # print('from saliency_to_auc',heatmap.shape, heatmap.dtype, heatmap.min(), heatmap.max())
#     assert isinstance(heatmap, np.ndarray)
#     assert len(heatmap.shape)==2 and np.issubdtype(heatmap.dtype, np.floating)

#     nu_max = max(f_S, f_0)
#     nu_min = min(f_S, f_0)

#     xs, ys, ms, masks, qs = [], [], [], [], []
#     for i, value in enumerate(np.linspace(start=1.0, stop=0.0, num=num_samples)):
#         if method=='del':
#             epsilon = (1 if value==0.0 else 0)
#             q = (np.quantile(heatmap, q=value) - epsilon)
#             m = heatmap <= q
#             nx = (1.0 - np.sum(m) / m.size)
#         elif method=='ins':
#             epsilon = (1 if value==1.0 else 0)
#             q = (np.quantile(heatmap, q=value) + epsilon)
#             m = heatmap >= q
#             nx = (np.sum(m) / m.size)
#         else:
#             raise Exception()
            
#         # add a new datapoint on the curve
#         if len(xs)==0 or nx != xs[-1]: 
#             assert m.dtype==bool and len(m.shape)==2
#             xs.append(nx)
#             masks.append(m)
#             ms.append(np.sum(heatmap[m]))
#             qs.append(q)

#         # evaluate the characteristic function
#         if len(masks) >= batch_size or (len(masks)>0 and i==(num_samples-1)):
#             y = predict_yolo_masked(np.array(masks))[:, predicted_cls]
#             ys.extend(y)
#             masks = []

#     assert len(masks)==0    
#     xs, ys = np.array(xs), np.array(ys)
#     assert(len(xs) == len(ys))

#     # compute considering under/over shoots
#     overshoot_max = np.maximum(0, ys - nu_max) # overshoot for values exceeding the maximum
#     overshoot_min = np.maximum(0, nu_min - ys) # overshoot for values below the minimum
#     # adjust ys with the overshoot. Clip it inside the admitted range
#     y_adjusted = np.clip(ys - 2*overshoot_max + 2*overshoot_min, nu_min, nu_max)

#     # rescaling
#     ys_rescaled = (ys - nu_min) / (nu_max - nu_min)
#     y_adjusted_rescaled = (y_adjusted - nu_min) / (nu_max - nu_min)

#     auc, auc_r, auc_mae, auc_mse, auc_adj, auc_adjr = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

#     curve_range = range(1, len(xs)) if rule=='trapezoid' else range(len(xs))

#     # compute the area under the curve with the midpoint Riemann sum (i.e. the trapezoidal rule)
#     for i in curve_range:
#         if rule=='trapezoid':
#             delta_x = abs(xs[i] - xs[i-1])
#             assert delta_x > 0
#             y_mid   =         0.5*(ys[i-1] + ys[i])
#             y_r_mid =         0.5*(ys_rescaled[i-1] + ys_rescaled[i])
#             err_mid = y_mid - 0.5*(ms[i-1] - ms[i])
#             y_adj_mid =       0.5*(y_adjusted[i-1] + y_adjusted[i])
#             y_adjr_mid =      0.5*(y_adjusted_rescaled[i-1] + y_adjusted_rescaled[i])
#         else: # rectangles
#             delta_x = 1.0/num_samples if i==len(xs)-1 else abs(xs[i+1] - xs[i])
#             assert delta_x > 0
#             y_mid   =         ys[i]
#             y_r_mid =         ys_rescaled[i]
#             err_mid = y_mid - ms[i]
#             y_adj_mid =       y_adjusted[i]
#             y_adjr_mid =      y_adjusted_rescaled[i]


#         auc += abs(delta_x * y_mid) # base * height
#         auc_r += abs(delta_x * y_r_mid) # base * height
#         # auc_eff += abs(delta_x * err_mid) # base * height
#         auc_mae += abs(delta_x * err_mid) # base * height
#         auc_mse += abs(delta_x * (err_mid**2)) # base * height^2
#         auc_adj += abs(delta_x * y_adj_mid)
#         auc_adjr += abs(delta_x * y_adjr_mid)

#     return {'xs':xs, 'ys':ys, 'ms':ms, 'qs':qs, 'ysr':ys_rescaled,
#             'y_adj':y_adjusted, 'y_adjr':y_adjusted_rescaled, 
#             'method':method, #'class_id':class_id,
#             'auc':auc, 'auc_r':auc_r, #'auc_eff':auc_eff, 
#             'auc_mae':auc_mae, 'auc_mse':auc_mse, 'auc_rmse':np.sqrt(auc_mse), 
#             'auc_adj':auc_adj, 'auc_adjr':auc_adjr}

In [None]:
def saliency_to_auc(nu, heatmap, f_S, f_0, predicted_cls, batch_size=4, method='del', num_samples=101, rule='trapezoid'):
    """Deletion / insertion curve of a saliency heatmap and the areas under it.

    The heatmap is thresholded at ``num_samples`` evenly spaced quantile levels going
    from 1.0 down to 0.0:

    * ``method='del'``: the mask keeps the pixels with ``heatmap <= q``; ``x`` is the
      fraction of pixels removed (most salient first).
    * ``method='ins'``: the mask keeps the pixels with ``heatmap >= q``; ``x`` is the
      fraction of pixels kept (most salient first).

    A threshold that yields the same ``x`` as the previous point is skipped. Masks are
    passed to ``nu`` in batches of ``batch_size`` and column ``predicted_cls`` of its
    output is the curve value ``y``.

    The curve is then rebased so that ``f_0`` maps to 0 (sign-flipped when
    ``f_S <= f_0``), rescaled by ``|f_S - f_0|`` and integrated over ``x`` with the
    trapezoidal rule (``rule='trapezoid'``) or with rectangles (any other value).

    Args:
        nu: callable mapping a boolean mask batch ``(B, H, W)`` to scores indexable
            as ``[:, predicted_cls]``.
        heatmap: 2-D floating-point saliency map.
        f_S, f_0: reference scores f(S) and f(0) (presumably the score of the full
            input and of the fully masked input).
        predicted_cls: output column read from ``nu``.
        batch_size: number of masks per ``nu`` call.
        method: ``'del'`` or ``'ins'``.
        num_samples: number of quantile levels.
        rule: ``'trapezoid'``, or anything else for rectangles.

    Returns:
        dict with the curves (``xs``, ``ys``, rescaled ``ysr``, clipped
        ``y_clip``/``y_clipr``, overshoot-adjusted ``y_adj``/``y_adjr``), the heatmap
        mass inside each mask ``ms``, the thresholds ``qs`` and the areas ``auc*``.
        ``auc_mae``/``auc_mse``/``auc_rmse`` integrate the gap between the rebased
        curve and ``ms``.

    Raises:
        ValueError: if ``method`` is neither ``'del'`` nor ``'ins'``.

    NOTE(review): if ``f_S == f_0`` the rescaled curves divide by zero (inf/nan).
    """
    assert isinstance(heatmap, np.ndarray)
    assert len(heatmap.shape)==2 and np.issubdtype(heatmap.dtype, np.floating)
    if method not in ('del', 'ins'):
        raise ValueError(f"unknown method {method!r}; expected 'del' or 'ins'")

    xs, ys, ms, masks, qs = [], [], [], [], []
    for i, value in enumerate(np.linspace(start=1.0, stop=0.0, num=num_samples)):
        if method=='del':
            # at the last level push q below the minimum so that every pixel is removed
            epsilon = (1 if value==0.0 else 0)
            q = (np.quantile(heatmap, q=value) - epsilon)
            m = heatmap <= q
            nx = (1.0 - np.sum(m) / m.size)
        else:  # 'ins'
            # at the first level push q above the maximum so that no pixel is inserted
            epsilon = (1 if value==1.0 else 0)
            q = (np.quantile(heatmap, q=value) + epsilon)
            m = heatmap >= q
            nx = (np.sum(m) / m.size)

        # add a new datapoint on the curve only when x actually moves
        if len(xs)==0 or nx != xs[-1]:
            assert m.dtype==bool and len(m.shape)==2
            xs.append(nx)
            masks.append(m)
            ms.append(np.sum(heatmap[m]))
            qs.append(q)

        # evaluate the characteristic function on a full batch, or on the leftovers at the end
        if len(masks) >= batch_size or (len(masks)>0 and i==(num_samples-1)):
            y = nu(np.array(masks))[:, predicted_cls]
            ys.extend(y)
            masks = []

    assert len(masks)==0
    xs, ys = np.array(xs), np.array(ys)
    assert(len(xs) == len(ys))

    # compute considering under/over shoots
    if f_S > f_0:
        overshoot_max = np.maximum(0, ys - f_S) # overshoot for values exceeding the maximum f(S)
        overshoot_min = np.maximum(0, f_0 - ys) # overshoot for values below the minimum f(0)
    else: # f(S) <= f(0)
        overshoot_max = np.maximum(0, ys - f_0) # overshoot for values exceeding the maximum f(0)
        overshoot_min = np.maximum(0, f_S - ys) # overshoot for values below the minimum f(S)

    # clip ys, no overshoots
    y_clipped = np.clip(ys, min(f_S, f_0), max(f_S, f_0))
    # adjust ys with the overshoot. Clip it inside the admitted range
    y_adjusted = np.clip(ys - 2*overshoot_max + 2*overshoot_min, min(f_S, f_0), max(f_S, f_0))

    # rebase to f(0)
    if f_S > f_0:
        flipped = False
        ys = ys - f_0
        y_clipped = y_clipped - f_0
        y_adjusted = y_adjusted - f_0
    else: # f(S) <= f(0)
        flipped = True
        ys = f_0 - ys
        y_clipped = f_0 - y_clipped
        y_adjusted = f_0 - y_adjusted

    # rescaling
    ys_rescaled = ys / abs(f_S - f_0)
    y_clipped_rescaled = y_clipped / abs(f_S - f_0)
    y_adjusted_rescaled = y_adjusted / abs(f_S - f_0)

    auc, auc_r, auc_mae, auc_mse, auc_adj, auc_adjr, auc_clip, auc_clipr = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

    curve_range = range(1, len(xs)) if rule=='trapezoid' else range(len(xs))

    # area under the curve: trapezoidal rule, or rectangles anchored at each point
    for i in curve_range:
        if rule=='trapezoid':
            delta_x = abs(xs[i] - xs[i-1])
            assert delta_x > 0
            y_mid   =         0.5*(ys[i-1] + ys[i])
            y_r_mid =         0.5*(ys_rescaled[i-1] + ys_rescaled[i])
            # gap between the curve and the heatmap mass, both averaged over the segment
            err_mid = y_mid - 0.5*(ms[i-1] + ms[i])
            y_clip_mid =      0.5*(y_clipped[i-1] + y_clipped[i])
            y_clipr_mid =     0.5*(y_clipped_rescaled[i-1] + y_clipped_rescaled[i])
            y_adj_mid =       0.5*(y_adjusted[i-1] + y_adjusted[i])
            y_adjr_mid =      0.5*(y_adjusted_rescaled[i-1] + y_adjusted_rescaled[i])
        else: # rectangles
            delta_x = 1.0/num_samples if i==len(xs)-1 else abs(xs[i+1] - xs[i])
            assert delta_x > 0
            y_mid   =         ys[i]
            y_r_mid =         ys_rescaled[i]
            err_mid = y_mid - ms[i]
            y_clip_mid =      y_clipped[i]
            y_clipr_mid =     y_clipped_rescaled[i]
            y_adj_mid =       y_adjusted[i]
            y_adjr_mid =      y_adjusted_rescaled[i]

        auc += abs(delta_x * y_mid) # base * height
        auc_r += abs(delta_x * y_r_mid) # base * height
        auc_mae += abs(delta_x * err_mid) # base * height
        auc_mse += abs(delta_x * (err_mid**2)) # base * height^2
        auc_clip += abs(delta_x * y_clip_mid)
        auc_clipr += abs(delta_x * y_clipr_mid)
        auc_adj += abs(delta_x * y_adj_mid)
        auc_adjr += abs(delta_x * y_adjr_mid)

    return {'xs':xs, 'ms':ms, 'qs':qs,
            'f_0':f_0, 'f_S':f_S, 'flipped':flipped,
            'ys':ys, 'ysr':ys_rescaled,
            'y_clip':y_clipped, 'y_clipr':y_clipped_rescaled,
            'y_adj':y_adjusted, 'y_adjr':y_adjusted_rescaled,
            'method':method, 'predicted_cls':predicted_cls,
            'auc':auc, 'auc_r':auc_r,
            'auc_mae':auc_mae, 'auc_mse':auc_mse, 'auc_rmse':np.sqrt(auc_mse),
            'auc_clip':auc_clip, 'auc_clipr':auc_clipr,
            'auc_adj':auc_adj, 'auc_adjr':auc_adjr}

In [None]:
def combine_groundtruth_explanation(gtruth, heatmap, threshold):
    """Overlay a ground-truth mask and a thresholded heatmap as an RGB uint8 image.

    The red channel is 255 outside the ground truth (pixels not > 0) and 0 inside.
    Green and blue are 255 where ``heatmap < threshold`` and 0 where it is >= threshold.
    A 3-D ``gtruth`` is reduced to its first channel.
    """
    gt_plane = gtruth[:, :, 0] if gtruth.ndim == 3 else gtruth[:, :]
    outside_truth = np.where(gt_plane > 0, 0, 255)
    outside_salient = np.where(heatmap >= threshold, 0, 255)

    overlay = np.zeros(shape=list(heatmap.shape) + [3], dtype=np.uint8)
    overlay[:, :, 0] = outside_truth
    for channel in (1, 2):
        overlay[:, :, channel] = outside_salient
    return overlay
# def calc_IoU_curve(y_true, y_pred):
    
#     assert len(y_true.shape)==1 and len(y_pred.shape)==1 # assumes y_true and y_pred to be flattened arrays
#     yd = np.array(sorted(zip(y_pred, y_true), reverse=True))
#     X2   = np.zeros(len(y_pred))
#     IoU2 = np.zeros(len(y_pred))
#     Th   = np.zeros(len(y_pred))
    
#     nT = np.sum(y_true)
#     nInt = 0
#     for i in range(len(y_pred)):
#         if yd[i,1]: 
#             nInt += 1
        
#         IoU2[i] = nInt / (i + nT - nInt)
#         X2[i] = i
#         Th[i] = yd[i,0]
        
#     X2 = X2 / len(y_pred)
#     auc_IoU = 0
#     for i in range(1, len(y_pred)):
#         auc_IoU += (X2[i] - X2[i-1]) * (IoU2[i] + IoU2[i-1]) / 2.0
    
#     best_pt = np.argmax(IoU2)
    
#     if np.sum(y_pred) == 0:
#         return X2, np.zeros_like(X2), Th[best_pt], X2[best_pt], 0
#     else:
#         return X2, IoU2, Th[best_pt], X2[best_pt], auc_IoU
#         # return {'X':X2, 'Y':IoU2, 'max_IoU':Th[best_pt], 'x_best':X2[best_pt], 'auc_IoU':auc_IoU}

# Updated IoU curve (`calc_IoU_curve_imp`)

In [None]:
def _iou_curve_arrays(y_true, y_pred):
    """Rank pixels by score and return (X, IoU, Th) for every top-k selection.

    Entry i describes selecting the i+1 highest-scoring pixels (score >= Th[i]):
      X[i]   = i / n, the normalized rank of the threshold pixel,
      IoU[i] = |selection & truth| / |selection | truth|,
      Th[i]  = score of the (i+1)-th highest pixel.
    y_true is a 1-D bool array, y_pred a 1-D numeric array of the same length.
    """
    n = len(y_pred)
    if n == 0:
        raise ValueError('IoU curve: empty input')
    # Descending by score, ties broken with ground-truth pixels first (the same order
    # as sorting (score, truth) tuples in reverse).
    order = np.lexsort((y_true.astype(np.int8), y_pred))[::-1]
    n_inter = np.cumsum(y_true[order])
    n_selected = np.arange(1, n + 1)
    n_true = np.count_nonzero(y_true)
    # union = |selection| + |truth| - |intersection| >= 1, so no division by zero
    iou = n_inter / (n_selected + n_true - n_inter)
    xs = np.arange(n) / n
    thresholds = y_pred[order].astype(np.float64)
    return xs, iou, thresholds


def _trapezoid_auc(xs, ys):
    """Area under the piecewise-linear curve (xs, ys) using the trapezoidal rule."""
    return np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) / 2.0)


def calc_IoU_curve_imp(y_true, y_pred, add_noise=True):
    """IoU between a ground-truth mask and every top-k thresholding of a heatmap.

    Args:
        y_true: flattened boolean ground-truth mask.
        y_pred: flattened floating-point heatmap, same length as ``y_true``.
        add_noise: add seeded N(0, 1e-9) jitter to ``y_pred`` (presumably to break
            ties between equal heatmap values).

    Returns:
        dict with ``X`` (normalized rank), ``Y`` (IoU per threshold),
        ``max_IoU_heatmap_threshold`` / ``max_IoU_score`` / ``x_best`` at the IoU
        maximum, and ``auc_IoU`` (trapezoidal area under the curve).
    """
    assert isinstance(y_true, np.ndarray)
    assert isinstance(y_pred, np.ndarray)
    assert len(y_true.shape)==1 and len(y_pred.shape)==1 # assumes y_true and y_pred to be flattened arrays
    assert len(y_true)==len(y_pred)
    assert y_true.dtype==np.dtype('bool') and np.issubdtype(y_pred.dtype, np.floating)
    if add_noise:
        rng = np.random.default_rng(12345)
        y_pred = y_pred + rng.normal(0.0, 0.000000001, size=y_pred.shape)

    X2, IoU2, Th = _iou_curve_arrays(y_true, y_pred)
    auc_IoU = _trapezoid_auc(X2, IoU2)
    best_pt = np.argmax(IoU2)
    return {'X':X2, 'Y':IoU2, 'max_IoU_heatmap_threshold':Th[best_pt],
            'max_IoU_score':IoU2[best_pt], 'x_best':X2[best_pt], 'auc_IoU':auc_IoU}




def calc_IoU_curve(y_true, y_pred, add_noise=True):
    """Legacy tuple-returning variant of ``calc_IoU_curve_imp``.

    Any nonzero ``y_true`` entry counts as ground truth. Returns
    ``(X, IoU, best_threshold, x_best, auc_IoU)``. If the raw heatmap sums to zero,
    the IoU curve and the AUC are reported as zeros.
    """
    assert len(y_true.shape)==1 and len(y_pred.shape)==1 # assumes y_true and y_pred to be flattened arrays
    # test the raw heatmap: once the jitter below is added its sum is practically never 0
    empty_heatmap = np.sum(y_pred) == 0
    if add_noise:
        rng = np.random.default_rng(12345)
        y_pred = y_pred + rng.normal(0.0, 0.000000001, size=y_pred.shape)

    X2, IoU2, Th = _iou_curve_arrays(np.asarray(y_true).astype(bool), y_pred)
    auc_IoU = _trapezoid_auc(X2, IoU2)
    best_pt = np.argmax(IoU2)

    if empty_heatmap:
        return X2, np.zeros_like(X2), Th[best_pt], X2[best_pt], 0.0
    return X2, IoU2, Th[best_pt], X2[best_pt], auc_IoU

In [None]:
def vis_IoU(shapley_values, threshold, ground_truth,verbose=False):
    """Color-code a thresholded explanation against a ground-truth mask.

    Pixels with ``shapley_values >= threshold`` form the prediction; nonzero
    ``ground_truth`` pixels form the target. Returns a float32 image of shape
    ``ground_truth.shape + (3,)``: true positives (0, 0, 0.75), false positives
    (1, 0.6, 0.2), false negatives (1, 0.4, 1); all other pixels stay white.
    """
    pred = shapley_values.flatten() >= threshold
    # Coerce to bool: with an integer mask ``~real`` is a bitwise NOT (1 -> 254) and
    # ``pred & real`` becomes an integer fancy index instead of a boolean mask.
    real = ground_truth.flatten().astype(bool)

    image = np.full((len(pred), 3), 1.0, dtype=np.float32)
    if verbose:
        print(np.sum(pred), np.sum(real))
    image[ pred & real, : ]    = (0.0, 0.0, 0.75) # True Positives
    image[ pred & (~real), : ] = (1.0, 0.6, 0.2)  # False Positives
    image[ (~pred) & real, : ] = (1.0, 0.4, 1.0)  # False Negatives

    return image.reshape(list(ground_truth.shape) + [3])

# RUN SINGLE

In [None]:
def test_single_run(image_no = 139,filter_methods= None, fixed_category = None,plot_img_gt=False,plot_prediction = False,verbose=False):
    """Explain, evaluate and plot a single COCO image.

    Args:
        image_no: COCO image id, loaded through ``coco.loadImgs``.
        filter_methods: names of the ``methods`` entries to run; ``None`` or an empty
            list runs all of them.
        fixed_category: class name whose ground truth is loaded; ``None`` uses the top
            predicted class ``class_names[predicted_cls]``.
        plot_img_gt: also plot the image, its annotations and the ground-truth mask.
        plot_prediction: also plot the model predictions.
        verbose: print the per-method evaluation table.

    Results are written to the notebook globals ``heatmaps``, ``aucD``, ``aucI`` and
    ``IoU`` so that the plotting helpers can read them.
    """
    global heatmaps, IoU, aucD, aucI, explainer
    image_info = coco.loadImgs(image_no)[0]
    image_path = os.path.join(paths.image_dir, image_info['file_name'])
    load_image_to_explain(image_path, params, bg_type='gray')

    # predicted_cls / f_S / f_0 are presumably set by load_image_to_explain (TODO confirm)
    if fixed_category is None:
        fixed_category = class_names[predicted_cls]

    load_groundtruth(coco, image_path, fixed_category=fixed_category)
    has_segmentation = any('segmentation' in ann for ann in annotations)
    if not has_segmentation:
        return

    print(f'{params.model_name} {image_no:<10} {fixed_category:<15} {class_names[predicted_cls]:<10} {f_S:<10.5}')
    if verbose:
        print('='*90)
        print('='*90)
    if plot_img_gt:
        plot_img_gt_bg()
        plot_gt(image_to_explain, annotations, fixed_category, save_fig=True, fig_size=(5,5))
        plot_mask(ground_truth, category_name=fixed_category, save_fig=True, fig_size=(5,5))
    results = model.predict(image_to_explain)
    if plot_prediction:
        plot_predictions(image_to_explain, results, category_name=fixed_category, save_fig=True, fig_size=(5,5))
    print('-'*100)

    # an empty/None filter means "run every method"
    selected = [(n, funct) for n, _, funct in methods
                if not filter_methods or n in filter_methods]

    ################# EXPLANATION
    # Global, like the demo cell: a local `explainer` here was never used by anything.
    # TODO confirm the method functors read the global `explainer`.
    explainer = shap_bpt.Explainer(predict_yolo_masked, image_to_explain, num_explained_classes=2, verbose=True)
    heatmaps, aucD, aucI, IoU = {}, {}, {}, {}
    for n, funct in tqdm(selected, leave=False):
        start_time = datetime.now()
        heatmaps[n] = funct()
        end_time = datetime.now()
        print(f'| {n:<20} | {np.sum(heatmaps[n][0]):<20.5} | {"{}".format(end_time - start_time)} |')

    ################# EVALUATION
    if verbose:
        print(f'| {"method":<10} | {"aucI_pred":<10} | {"aucD_pred":<10} | {"aucI_mse":<10} | {"aucD_mse":<10} | {"max_IoU":<10} | {"au_IoU":<10} |')
    for n, _ in tqdm(selected, leave=False):
        aucD[n] = saliency_to_auc(predict_yolo_masked, heatmaps[n][0], f_S, f_0, predicted_cls, method='del', batch_size=params.batch_size)
        aucI[n] = saliency_to_auc(predict_yolo_masked, heatmaps[n][0], f_S, f_0, predicted_cls, method='ins', batch_size=params.batch_size)
        IoU[n]  = calc_IoU_curve(ground_truth.flatten(), heatmaps[n][0].flatten())
        if verbose:
            print(f"| {n:<10} | {aucI[n]['auc']:<10.5} | {aucD[n]['auc']:<10.5} | {aucI[n]['auc_mse']:<10.5} | {aucD[n]['auc_mse']:<10.5} | {np.max(IoU[n][1]):<10.5} | {IoU[n][4]:<10.5} |")

    ################# PLOTTING
    fun_plot_heatmaps(methods, heatmaps, destroy_fig=False, selected_ext='png', save_fig=True, plot_colorbar=False, plot_title=True)
    fun_plot_IoU(methods, heatmaps, destroy_fig=False, selected_ext='png', plot_title=False, save_fig=True, fontsize=12)
    plot_perform_iou(methods, IoU)
        ################# 

# Single-image test

In [None]:
# Flags for the single-image walkthrough below.
run_single_test = True        # set to False to skip the single-image cells

# Method names kept when USE_METHOD_FILTER is True,
# e.g. ['AA-500', 'Partition-500', 'BPT-500', 'LIME-500'].
filter_methods = []
USE_METHOD_FILTER = False     # True: run only the methods listed in filter_methods


In [None]:
# List the names of the configured explanation methods; each entry is (name, color, functor).
for n,_,_ in methods:
    print(n)

### LOAD IMAGE

In [None]:
if run_single_test:
    # Display options for the single-image walkthrough.
    verbose         =   False
    plot_prediction =   True
    plot_img_gt     =   True
    fixed_category  =   None
    # COCO image id to explain (another example used before: 1490).
    image_no        =   113235

In [None]:
# predict_yolo_masked(np.expand_dims(ground_truth, axis=0))

# predict_yolo_masked(np.expand_dims(image_to_explain, axis=0))

In [None]:
# Get background values
# f_G, f_B = get_bg_values_yolo(f_masked_yolo, ground_truth, predicted_cls, class_names, verbose=True)

In [None]:
if run_single_test:
    # Load the chosen COCO image. load_image_to_explain presumably sets the globals
    # image_to_explain, predicted_cls, f_S and f_0 used below (TODO confirm).
    # NOTE(review): bg_type='black' here, while test_single_run and the demo loop
    # use bg_type='gray' — confirm the difference is intended.
    image_info = coco.loadImgs(image_no)[0]
    image_path = os.path.join(paths.image_dir, image_info['file_name'])
    load_image_to_explain(image_path,params, bg_type='black')

    

    # Explain the model's top predicted class.
    fixed_category = class_names[predicted_cls] # 'person'
    
    # presumably sets the globals ground_truth and annotations for fixed_category (TODO confirm)
    load_groundtruth(coco,image_path,fixed_category=fixed_category)
    has_segmentation = any('segmentation' in ann for ann in annotations)
    if has_segmentation:
        print('='*90)
        
        if verbose:
            print('='*90)
            
            print('='*90)
        #--------------------------------------------------------------------------------------
        # f_G, f_B = get_bg_values(predict_yolo,ground_truth,predicted_cls,class_names, verbose= True)
        #--------------------------------------------------------------------------------------
        print(f'{params.model_name} {image_no:<10} {fixed_category:<15} {class_names[predicted_cls]:<10} {f_S:<10.5} {f_0:<10.5}')
        if plot_img_gt:
            plot_img_gt_bg()
            plot_gt(image_to_explain,annotations,fixed_category, save_fig=True,fig_size = (5,5))
            plot_mask(ground_truth,category_name = fixed_category,save_fig=True,fig_size = (5,5))
        # model predictions on the image, plotted when plot_prediction is set
        results = model.predict(image_to_explain)
        if plot_prediction:
            plot_predictions(image_to_explain,results,category_name = fixed_category, save_fig=True,fig_size = (5,5))
        # print(result)
        print('-'*100)
            

### SINGLE : EXPLANATION

In [None]:
if run_single_test:
    # One heatmap per method. heatmaps[n][0] is later passed to saliency_to_auc, which
    # requires a 2-D float array, so funct() presumably returns (heatmap, ...) (TODO confirm).
    heatmaps = {}
    # read by the evaluation cell below
    evaluate_explanation = True
    print(f'{"| Method":<20} | {"f_S":<10} | {"f_T":<10} | {"f_0":<10} | {"TIME":<10} |')
    for ii, (n,_,funct) in enumerate(tqdm(methods, leave=False, desc = 'Explanation: ')):
        # with USE_METHOD_FILTER set, only the methods named in filter_methods are run
        if n not in filter_methods and USE_METHOD_FILTER:
                # print('skipping - ',n)
                continue
        start_time = datetime.now()
        heatmaps[n] = funct()
        end_time = datetime.now()
        # f_T column: sum of the heatmap values, printed next to f_S and f_0 for comparison
        print(f'| {n:<20} | {f_S:<10.7} |{np.sum(heatmaps[n][0]):<10.7} |{f_0:<10.7} | {"{}".format(end_time - start_time)} |')


### SINGLE : EVALUATION

In [None]:
if run_single_test:
    # Evaluate every heatmap: deletion/insertion AUCs and the IoU curve against the mask.
    aucD, aucI,IoU = {}, {}, {}
    print('='*130)
    # aucX_pred = raw area ('auc'); aucX_r = area of the curve rescaled by |f_S - f_0|
    # ('auc_r'); f_T is the sum of the heatmap.
    print(f'| {"method":<10} | {"f_S":<8} | {"f_T":<8} | {"f_S-f_T":<12} | {"aucI_pred":<10} | {"aucD_pred":<10} | {"aucI_r":<10} | {"aucD_r":<10} | {"max_IoU":<10} | {"au_IoU":<10} |')
    print('='*130)
    for n,_,funct in tqdm(methods, leave=False, desc ='evaluation'):
        if n not in filter_methods and USE_METHOD_FILTER:
            continue
        if evaluate_explanation:
            aucD[n] = saliency_to_auc(predict_yolo_masked,heatmaps[n][0],f_S,f_0,predicted_cls, method='del', batch_size=params.batch_size)
            aucI[n] = saliency_to_auc(predict_yolo_masked,heatmaps[n][0],f_S,f_0,predicted_cls, method='ins', batch_size=params.batch_size)
            IoU[n]  = calc_IoU_curve_imp(ground_truth.flatten(), heatmaps[n][0].flatten())
            # implicit string concatenation: a backslash continuation inside the f-string
            # leaked the next line's indentation into the printed row
            print(f"| {n:<10} | {f_S:<8.5} | {np.sum(heatmaps[n][0]):<8.5} | {f_S - np.sum(heatmaps[n][0]):<12.7} | "
                  f"{aucI[n]['auc']:<10.5} | {aucD[n]['auc']:<10.5} | {aucI[n]['auc_r']:<10.5} | {aucD[n]['auc_r']:<10.5} | "
                  f"{np.max(IoU[n]['Y']):<10.5} | {IoU[n]['auc_IoU']:<10.5} |")


### SINGLE : PLOTTING

In [None]:
if run_single_test:
    if len(filter_methods) > 0:
        # only the methods named in filter_methods
        fun_plot_filtered_heatmaps(
            methods, heatmaps,
            filter_methods=filter_methods,
            exp_type='demo',
            selected_ext='png',
            save_fig=True,
            destroy_fig=False,
            plot_title=True,
            plot_colorbar=False,
        )
    else:
        # every method
        fun_plot_heatmaps(
            methods, heatmaps,
            exp_type='demo',
            selected_ext='png',
            save_fig=True,
            destroy_fig=False,
            plot_title=True,
            plot_colorbar=False,
        )

In [None]:
if run_single_test:
    if len(filter_methods) > 0:
        # IoU maps and summary restricted to the methods named in filter_methods
        fun_plot_filtered_IoU(
            methods, heatmaps,
            filter_methods=filter_methods,
            exp_type='demo',
            selected_ext='png',
            save_fig=True,
            destroy_fig=False,
            plot_title=True,
            fontsize=12,
        )
        plot_perform_filtered_iou(methods, exp_type='demo', filter_methods=filter_methods)
    else:
        # IoU maps and summary for every method
        fun_plot_IoU(
            methods, heatmaps,
            exp_type='demo',
            selected_ext='png',
            save_fig=True,
            destroy_fig=False,
            plot_title=True,
            fontsize=12,
        )
        plot_perform_iou(methods, exp_type='demo')

In [None]:
if run_single_test:
    # Figure title. Built inside the guard: image_no / f_S / f_0 only exist after the
    # single-test cells ran. '\\ ' is a literal backslash-space (a LaTeX space, since
    # text.usetex is enabled); the previous '\ ' was an invalid Python escape sequence.
    plt_ttl = f'{image_no}, explained_class: {predicted_cls} '\
              f'f_S:\\ {f_S:0.6}, f_0: {f_0:0.6}, delta: {abs(f_S-f_0):0.6}'

    if len(filter_methods)==0:
        fun_plot_performance(aucI,aucD,IoU, save_fig = True, ttl=plt_ttl, layout='2rows')
    else:
        fun_plot_performance_filtered(selected_methods=filter_methods)
        

# SN: FULL TEST DEMO

In [None]:
# run_demo = False
# Toggle for the full demo loop in the next cell.
run_demo = True

In [None]:
if run_demo:
    # Demo: explain, evaluate and plot the first COCO image whose top class is not 'tv'.
    num_explained_classes = 2
    # batch_size            = 32
    # the format spec belongs inside the braces; it used to print a literal ':<10'
    print(f'{"image":<10} {"total_obj":<10} {"selected_obj":<10} {"prob":<10} {"aucD_mse":<10} {"max_IoU":<10} {"au_IoU":<10}')
    for im_id,image_no in enumerate(tqdm(coco.getImgIds())):
        #----------------------------------------------------------------------------------------
        image_info = coco.loadImgs(image_no)[0]
        image_path = os.path.join(paths.image_dir, image_info['file_name'])
        load_image_to_explain(image_path,params, bg_type='gray')
        #----------------------------------------------------------------------------------------

        # skip images whose top predicted class is 'tv'
        if class_names[predicted_cls]=='tv':
            continue
        fixed_category = class_names[predicted_cls] # 'person'
        annotations = get_annotation(coco,image_path)[0]

        load_groundtruth(coco,image_path,fixed_category=fixed_category)
        #----------------------------------------------------------------------------------------
        has_segmentation = any('segmentation' in ann for ann in annotations)
        if has_segmentation:
            print(f'{image_no:<10} {fixed_category:<15} {class_names[predicted_cls]:<10} {f_S:<10.5}')
            plot_gt(image_to_explain,annotations,fixed_category, save_fig=True,fig_size = (5,5))
            #----------------------------------------------------------------------------------------
            results = model.predict(image_to_explain)

            plot_predictions(image_to_explain,results,category_name = fixed_category, save_fig=True,fig_size = (5,5))
            print('-'*100)
            #############   XAI
            explainer = shap_bpt.Explainer(predict_yolo_masked, image_to_explain, num_explained_classes=num_explained_classes, verbose=True)
            ############
            heatmaps, aucD, aucI,IoU = {}, {}, {},{}
            for n,_,funct in tqdm(methods, leave=False):
                if n not in filter_methods and USE_METHOD_FILTER:
                    continue
                print('method : ',n)
                heatmaps[n] = funct()
            #----------------------------------------------------------------------------------------
            for n,_,funct in tqdm(methods, leave=False):
                if n not in filter_methods and USE_METHOD_FILTER:
                    continue
                aucD[n] = saliency_to_auc(predict_yolo_masked,heatmaps[n][0],f_S,f_0,predicted_cls, method='del', batch_size=params.batch_size)
                aucI[n] = saliency_to_auc(predict_yolo_masked,heatmaps[n][0],f_S,f_0,predicted_cls, method='ins', batch_size=params.batch_size)
                IoU[n]  = calc_IoU_curve_imp(ground_truth.flatten(), heatmaps[n][0].flatten())
            if len(filter_methods)==0:
                fun_plot_heatmaps(methods,heatmaps,destroy_fig=False, selected_ext='png', save_fig=True,plot_colorbar=False)
                fun_plot_IoU(methods,heatmaps,destroy_fig=False, selected_ext='png',plot_title=True, save_fig=True,fontsize=12)
            else:
                fun_plot_filtered_heatmaps(methods,heatmaps, filter_methods=filter_methods,
                                           destroy_fig=False, selected_ext='png', save_fig=True,plot_colorbar=False)
                fun_plot_filtered_IoU(methods,heatmaps, filter_methods=filter_methods,
                                      destroy_fig=False, selected_ext='png',plot_title=True, save_fig=True,fontsize=12)

        # demo only: stop after the first image that was not skipped above
        break

In [None]:
# model name followed by whatever model.info() returns
print(params.model_name, model.info())

# RUN ON SELECTED IMAGES


In [None]:
# run_on_selected_images = False
# Toggle for the selected-images cells below.
run_on_selected_images = True

In [None]:
# Explanation methods for the selected-images run; each entry is (name, color, functor).
verbose = False

methods_selected = [
    # ('BPT-100',  'xkcd:light pink',  partial(get_bpt_heatmaps, num_samples=100,  batch_size=64, verbose=verbose)),
    ('BPT-500',    'xkcd:bright pink', partial(get_bpt_heatmaps, num_samples=500,  batch_size=64, verbose=verbose)),
    # ('BPT-1000', 'xkcd:deep pink',   partial(get_bpt_heatmaps, num_samples=1000, batch_size=64, verbose=verbose)),
]
methods_selected_aa = [ # if multiple backgrounds
    # ('AA-100',  'xkcd:bright blue', partial(get_aa_heatmaps, num_samples=100,  batch_size=64, verbose=verbose)),
    ('AA-500',    'xkcd:bright blue', partial(get_aa_heatmaps, num_samples=500,  batch_size=64, verbose=verbose)),
    # ('AA-1000', 'xkcd:bright blue', partial(get_aa_heatmaps, num_samples=1000, batch_size=64, verbose=verbose)),
]
# LIME: num_segments = num_samples // 10. Integer division keeps the segment count an
# int (500/10 would pass the float 50.0 to a count parameter).
methods_selected_lime = [
    # ('LIME-100',  'xkcd:bright lime',      partial(get_lime_heatmaps, num_segments=100 // 10,  num_samples=100,  verbose=verbose)),
    ('LIME-500',    'xkcd:kermit green',     partial(get_lime_heatmaps, num_segments=500 // 10,  num_samples=500,  verbose=verbose)),
    # ('LIME-1000', 'xkcd:dark lime green',  partial(get_lime_heatmaps, num_segments=1000 // 10, num_samples=1000, verbose=verbose)),
]

methods_selected +=methods_selected_aa
methods_selected += methods_selected_lime

for n,_,_ in methods_selected:
    print(n)

In [None]:
if run_on_selected_images:
    exp_type = 'selected'
    # COCO image ids chosen for qualitative inspection; notes describe the scene.
    # Earlier candidate sets: 8021, 1000, 1490 (person), 1584 (bus), 1761 (airplane),
    # 2153 (person), 2473 (ski), 3156, 3845 (occluded cup), 253002, 139, 7888, 8277,
    # 8762, 9483 (keyboard), 11051 (person), 397133; and 5037, 5060, 7816, 10764,
    # 15440, 18737, 22396.
    selected_set = [
        785,
        360661,     # multiple horses
        181666,     # two very small persons
        233771,     # RGB object on a gray image
        238866,     # gray image, multiple objects
        25560,      # tiny object in the left corner
        41888,      # multiple objects: 3 birds
        80671,      # person in an unusual pose
        85329,      # tiny object (to verify)
        87038,      # tiny object
        125211,     # zebra stripes
        252219,     # multiple objects: 3 persons
        270244,     # zebra in a forest
        286994,     # very many objects
        516316,     # multiple tiny objects
    ]
    print('exp_type',exp_type)
    print('selected_set',selected_set)
print(filter_methods)

In [None]:
if run_on_selected_images:
    # Pass over selected_set: load each image, match the top predicted class against the
    # COCO annotations and print one table row per image with a matching segmentation.
    data_to_csv =  []

    plot_heatmaps            = True
    plot_IoU                 = False #True
    compute_auc_curves       = False
    save_fig                 = True
    destroy_fig              = False 
    # NOTE(review): always False here, since this cell only runs when run_on_selected_images is True
    save_vectors_heatmaps    = False if run_on_selected_images else True
    save_vectors_evaluation  = False 
    plot_title               = False
    st_full                  = time.time()

    # disallowed_vectors = ['LRP','GradCAM','aIDG','aGradExpl','ShapGradE']
    # if model_type=='real':
    # disallowed_vectors = ['Partition-100','Partition-500','Partition-1000','aGradExpl','GradCAM','LRP']

    # NOTE(review): the appends below are commented out, so this loop only rebinds i;
    # methods_ls / methods_ls_id stay empty.
    methods_ls,methods_ls_id = [],[]
    for i_id,(i,_,_) in enumerate(methods):
        # i = i.replace('Partition','PE')
        # i = i.replace('GradCAM','GC')
        i = i.replace('-','_')
        # methods_ls.append(i)
        # methods_ls_id.append(i_id+1)
    files_success =[]
    # print('saving data :\t',paths.plotsIoU_path)
    start_time = datetime.now()
    print(f'code started:\t{start_time}')

    # `suffix` is not defined in this part of the notebook; presumably set in an earlier cell (TODO confirm)
    paths.csv_filename = f'{paths.path_csv}/csv_exp_{suffix}_{len(methods)}.csv'
    print('='*95)
    # print(f'{params.model_name} {model.info()}')
    # print('='*95)
    print(f'| {"No":<10}| {"Image no":<10} | {"predicted class":<25} | {"annotations":<10} | {"predicted prob":<12} | {"predicted class-id":<15} |')
    print('='*95)

    # images whose top predicted class has no annotation of that category
    verify_annotation_preds = []
    for im_id,image_no in enumerate(tqdm(selected_set)):
        # print(image_no)
        # break
        
        # NOTE(review): selected_set is iterated directly, so this check never skips anything
        if image_no not in selected_set:
            # print(image_no, image_no in selected_set)
        # else:
            continue
        # break
        # if im_id>1000:
        #     break
        # if image_no!=226111:
        #     continue
        # image_no = int(im_no.split('\\')[-1].split('.')[0])
        st_load = time.time()

        image_info = coco.loadImgs(image_no)[0]
        image_path = os.path.join(paths.image_dir, image_info['file_name'])
        load_image_to_explain(image_path,params, bg_type='gray')
        ########################################################################
        
        # if class_names[predicted_cls]=='tv':
        #     continue
        # print(f'top pred class === \t{class_names[predicted_cls]} : {f_S:10.5}')
        fixed_category = class_names[predicted_cls] # 'person'
        # fixed_category = 'person' #class_names[predicted_cls] # 'person'
        # all annotations of the image (any category)
        annotations_all = get_annotation(coco,image_path,category_name=None)
        

        # print(ima,len(annotations))
        if len(annotations_all)==0:
            continue
        # category names of every annotation in the image
        preds_ = []
        for annotation in annotations_all:
            bbox = annotation['bbox']  # [x, y, width, height]
            category = coco.loadCats(annotation['category_id'])[0]['name']
            preds_.append(category)

        ########################################################################
        fixed_category = class_names[predicted_cls] # 'person'
        # fixed_category = 'person' #class_names[predicted_cls] # 'person'

        annotations_selected = get_annotation(coco,image_path,category_name=fixed_category)
        
        # no annotation of the predicted category: record the mismatch and skip the image
        if len(annotations_selected)==0:
            verify_annotation_preds.append({
            'image_no':image_no,
            'annotation_count':len(annotations_selected),
            'top_predicted': class_names[predicted_cls],
            'preds': preds_,
            'True?': class_names[predicted_cls] in preds_
            })
            # print(f'| {im_id:<5} | {image_no:<10} | {fixed_category:<25} | {len(annotations_selected):<10} | {f_S:<12.5} | {predicted_cls:<15} |')
            continue
        
        load_groundtruth(coco,image_path,fixed_category=fixed_category)
        ########################################################################            
        # NOTE(review): if get_annotation returns a list of annotation dicts (as the bbox
        # loop above suggests), [0] keeps only the first annotation; the any(...) below
        # then iterates its keys and len(...) in the table counts its keys — verify.
        annotations_selected = annotations_selected[0]

        has_segmentation = any('segmentation' in ann for ann in annotations_selected)
        if not has_segmentation:
            continue

        results = model.predict(image_to_explain)
        time_load = time.time()-st_load
        # plot_predictions(image_to_explain,results,category_name = fixed_category)

        print(f'| {im_id:<10} | {image_no:<10} | {fixed_category:<25} | {len(annotations_selected):<10} | {f_S:<12.5} | {predicted_cls:<15} |')

In [None]:
if run_on_selected_images:
    # Explain a hand-picked set of COCO images (`selected_set`) with every method in
    # `methods`, optionally plot the heatmaps, and optionally score each explanation
    # with deletion/insertion AUC and IoU against the ground-truth mask.
    #
    # Reads notebook globals defined in earlier cells (coco, paths, params, model,
    # methods, filter_methods, USE_METHOD_FILTER, exp_type, suffix, class_names,
    # selected_set, shap_bpt, predict_yolo_masked and the plotting/metric helpers).
    # image_to_explain, predicted_cls, f_S, f_0 and ground_truth are read below but
    # never assigned in this cell; presumably load_image_to_explain /
    # load_groundtruth set them as globals — TODO confirm.
    data_to_csv = []

    plot_heatmaps            = True
    plot_IoU                 = False  # True
    compute_auc_curves       = False
    save_fig                 = True
    destroy_fig              = False
    save_vectors_heatmaps    = False  # always False on this branch (run_on_selected_images is truthy)
    save_vectors_evaluation  = False
    plot_title               = False
    st_full                  = time.time()

    # Not filled in this cell; kept so any later cell that reads them still finds them.
    methods_ls, methods_ls_id = [], []
    files_success = []
    start_time = datetime.now()
    print(f'code started:\t{start_time}')

    paths.csv_filename = f'{paths.path_csv}/csv_exp_{suffix}_{len(methods)}.csv'
    print('='*95)
    print(f'| {"No":<10}| {"Image no":<10} | {"predicted class":<25} | {"annotations":<10} | {"predicted prob":<12} | {"predicted class-id":<15} |')
    print('='*95)

    # Images whose top prediction has no annotation of that category.
    verify_annotation_preds = []
    for im_id, image_no in enumerate(tqdm(selected_set)):
        st_load = time.time()

        image_info = coco.loadImgs(image_no)[0]
        image_path = os.path.join(paths.image_dir, image_info['file_name'])
        load_image_to_explain(image_path, params, bg_type='gray')

        # Category names of every annotation on this image; skip unannotated images.
        annotations_all = get_annotation(coco, image_path, category_name=None)
        if len(annotations_all) == 0:
            continue
        preds_ = [coco.loadCats(annotation['category_id'])[0]['name'] for annotation in annotations_all]

        # The explained category is the model's top prediction.
        fixed_category = class_names[predicted_cls]
        annotations_selected = get_annotation(coco, image_path, category_name=fixed_category)

        if len(annotations_selected) == 0:
            verify_annotation_preds.append({
                'image_no': image_no,
                'annotation_count': len(annotations_selected),
                'top_predicted': class_names[predicted_cls],
                'preds': preds_,
                'True?': class_names[predicted_cls] in preds_,
            })
            continue

        load_groundtruth(coco, image_path, fixed_category=fixed_category)

        # NOTE(review): get_annotation's return type is not visible here. Iterating
        # annotations_all above yields dicts, so annotations_selected[0] is most likely
        # one annotation dict; then the any(...) below scans that dict's *keys*, and the
        # len(annotations_selected) printed afterwards is its key count rather than the
        # number of annotations. Behaviour kept unchanged — TODO confirm.
        annotations_selected = annotations_selected[0]
        has_segmentation = any('segmentation' in ann for ann in annotations_selected)
        if not has_segmentation:
            continue

        results = model.predict(image_to_explain)
        time_load = time.time() - st_load

        print(f'| {im_id:<10} | {image_no:<10} | {fixed_category:<25} | {len(annotations_selected):<10} | {f_S:<12.5} | {predicted_cls:<15} |')

        ############   XAI   ############
        # `explainer` is not used directly in this cell; presumably the method
        # callables in `methods` read it as a global — TODO confirm.
        explainer = shap_bpt.Explainer(predict_yolo_masked, image_to_explain, num_explained_classes=2, verbose=True)

        ############   HeatMaps   ############
        heatmaps = {}
        time_exp = {}
        for n, _, funct in tqdm(methods, desc='Explanation', leave=False):
            # Skip methods outside the active filter (same short-circuit order as before).
            if len(filter_methods) > 0 and n not in filter_methods and USE_METHOD_FILTER:
                continue
            st = time.time()
            heatmaps[n] = funct()
            time_exp[n] = time.time() - st

        if plot_heatmaps:
            if len(filter_methods) == 0:
                fun_plot_heatmaps(methods, heatmaps, exp_type=exp_type, destroy_fig=destroy_fig,
                                  save_fig=save_fig, plot_title=plot_title, dpi=100)
            else:
                fun_plot_filtered_heatmaps(methods, heatmaps, destroy_fig=destroy_fig,
                                           exp_type=exp_type, save_fig=save_fig, plot_colorbar=False,
                                           plot_title=plot_title, dpi=200,
                                           filter_methods=filter_methods)

        if compute_auc_curves:
            ############   AUCs   ############
            aucD, aucI, IoU = {}, {}, {}
            for n, _, _ in tqdm(methods, desc='Evaluation', leave=False):
                if len(filter_methods) > 0 and n not in filter_methods and USE_METHOD_FILTER:
                    continue
                # Deletion AUC
                st = time.time()
                aucD[n] = saliency_to_auc(predict_yolo_masked, heatmaps[n][0], f_S, f_0, predicted_cls, method='del', batch_size=params.batch_size)
                time_aucD = time.time() - st
                # Insertion AUC
                st = time.time()
                aucI[n] = saliency_to_auc(predict_yolo_masked, heatmaps[n][0], f_S, f_0, predicted_cls, method='ins', batch_size=params.batch_size)
                time_aucI = time.time() - st
                # IoU curve between the ground-truth mask and the heatmap
                st = time.time()
                IoU[n] = calc_IoU_curve_imp(ground_truth.flatten(), heatmaps[n][0].flatten())
                time_auc_IoU = time.time() - st

                data_selected = {'image'         : image_no,
                                 'image_size'    : image_to_explain.shape,
                                 'bg_type'       : params.background_type,
                                 'inference_time': results[0].speed['inference'],
                                 'pred_cls'      : predicted_cls,
                                 'pred_lbl'      : class_names[predicted_cls],
                                 # NOTE(review): `annotations` is not assigned in this cell —
                                 # presumably a global set by load_groundtruth; if not, this
                                 # is a stale value from an earlier cell. TODO confirm.
                                 'object_count'  : len(annotations),
                                 #----------------------------------------------------------------
                                 'f_S'           : f_S,
                                 'f_0'           : f_0,
                                 'f_T'           : np.sum(heatmaps[n][0]),
                                 'f_N'           : len(np.unique(heatmaps[n][0])),  # unique values (patches) in the explanation
                                 'method'        : n,
                                 #----------------------------------------------------------------
                                 'aucI_pred'     : aucI[n]['auc'],
                                 'aucD_pred'     : aucD[n]['auc'],
                                 'aucI_r'        : aucI[n]['auc_r'],
                                 'aucD_r'        : aucD[n]['auc_r'],
                                 'aucI_adj'      : aucI[n]['auc_adj'],
                                 'aucD_adj'      : aucD[n]['auc_adj'],
                                 'aucI_adj_r'    : aucI[n]['auc_adjr'],
                                 'aucD_adj_r'    : aucD[n]['auc_adjr'],
                                 'aucI_clip'     : aucI[n]['auc_clip'],
                                 'aucI_clipr'    : aucI[n]['auc_clipr'],
                                 'aucD_clip'     : aucD[n]['auc_clip'],
                                 'aucD_clipr'    : aucD[n]['auc_clipr'],
                                 #----------------------------------------------------------------
                                 'threshold'     : IoU[n]['max_IoU_heatmap_threshold'],
                                 'best_point'    : IoU[n]['x_best'],
                                 'max_IoU'       : np.max(IoU[n]['Y']),
                                 'auc_IoU'       : IoU[n]['auc_IoU'],
                                 #----------------------------------------------------------------
                                 'time_load'     : time_load,
                                 'time_exp'      : time_exp[n],
                                 'time_aucI'     : time_aucI,
                                 'time_aucD'     : time_aucD,
                                 'time_auc_IoU'  : time_auc_IoU,
                                 'time_total'    : time_load+time_exp[n]+time_aucI+time_aucD+time_auc_IoU
                                 }
                data_to_csv.append(data_selected)
            if plot_IoU:
                if len(filter_methods) == 0:
                    fun_plot_IoU(methods, heatmaps, exp_type=exp_type,
                                 destroy_fig=destroy_fig, save_fig=save_fig,
                                 plot_title=plot_title, dpi=100, text_x=45, text_y=110)
                else:
                    fun_plot_filtered_IoU(methods, heatmaps, filter_methods=filter_methods,
                                          exp_type=exp_type,
                                          destroy_fig=destroy_fig,
                                          plot_title=plot_title, save_fig=save_fig, fontsize=12)

    # Build the results frame once, after all images (it used to be rebuilt on every
    # image, and was left undefined/stale when no image passed the filters).
    df_selected = pd.DataFrame(data_to_csv)

    time_full = time.time() - st_full
    print(f'{"Time for selected test is":<20} : {time_full:<10.10}')
    end_time = datetime.now()
    print(f'{"Duration":<20} : {"{}".format(end_time - start_time)}')

    chime.success()

In [None]:
# data_to_csv

# FULL TEST SET

In [None]:
# Toggle: set to False to skip the full-test-set experiment in the cells below.
run_full_exp = True
if run_full_exp:
    paths.csv_filename = '{}/csv_exp_{}_{}.csv'.format(paths.path_csv, suffix, len(methods))
    print('csv file will be saved at :', paths.csv_filename)
    

In [None]:
# def get_bg_values(f_masked,ground_truth,predicted_cls,class_names):
#     global f_G, f_B
#     # evaluate the ground truth mask with the background replacement strategy for masking function
#     predicted_fG = f_masked(np.expand_dims(ground_truth, axis=0))[0]
#     f_G = float(predicted_fG[predicted_cls])
#     print(class_names[predicted_cls], f_G, predicted_cls, f_G)
#     # print('softmax prob:', np_softmax(predicted_fG)[predicted_cls])

#     # evaluate the background (negative of the ground truth mask)
#     background_mask = np.logical_not(ground_truth)
#     predicted_fB = f_masked(np.expand_dims(background_mask, axis=0))[0]
#     f_B = float(predicted_fB[predicted_cls])
#     print(class_names[predicted_cls], f_B, predicted_cls, f_B)
#     # print('softmax prob:', np_softmax(predicted_fB)[predicted_cls])

#     print()
#     print('nu(S):  ', round(f_S, 4))
#     print('nu(G):  ', round(f_G, 4))
#     print('nu(S/G):', round(f_B, 4))
#     print('nu(0):  ', round(f_0, 4))

In [None]:
if run_full_exp:
    # Explain the first `params.data_subset` COCO images with every method in
    # `methods` and score each explanation with deletion/insertion AUC and IoU
    # against the ground-truth mask. The CSV at paths.csv_filename is rewritten
    # after every processed image, so an interrupted run keeps its partial results.
    #
    # Reads notebook globals defined in earlier cells (coco, paths, params, model,
    # methods, suffix, class_names, shap_bpt, predict_yolo_masked and the
    # plotting/metric helpers). image_to_explain, predicted_cls, f_S, f_0 and
    # ground_truth are read below but never assigned in this cell; presumably
    # load_image_to_explain / load_groundtruth set them as globals — TODO confirm.
    run_on_full_first = 0  # 0 until the first image has been fully processed
    data_to_csv = []

    plot_heatmaps            = False  # True
    plot_IoU                 = False  # True
    save_fig                 = False  # True
    destroy_fig              = True
    save_vectors_heatmaps    = True
    save_vectors_evaluation  = False
    plot_title               = False
    verbose_print            = False
    st_full                  = time.time()

    full_img_subset          = coco.getImgIds()[:params.data_subset]

    # Not filled in this cell; kept so any later cell that reads them still finds them.
    methods_ls, methods_ls_id = [], []
    files_success = []
    # Assign the output path *before* announcing it, so the printed path is the one written to.
    paths.csv_filename = f'{paths.path_csv}/csv_exp_{suffix}_{len(methods)}.csv'
    print(f'Running for images: {params.data_subset}')
    print('saving data :\t',paths.plotsIoU_path)
    print(f'SAVING CSV AT :', paths.csv_filename)
    start_time = datetime.now()
    print(f'code started:\t{start_time}')

    if verbose_print:
        print('='*95)
        print(f'{params.model_name} {model.info()}')
        print('='*95)
        print(f'| {"No":<10}| {"Image no":<10} | {"predicted class":<25} | {"annotations":<10} | {"predicted prob":<12} | {"predicted class-id":<15} |')
        print('='*95)

    # Images whose top prediction has no annotation of that category.
    verify_annotation_preds = []
    for im_id, image_no in enumerate(tqdm(full_img_subset)):
        # Restarted on every iteration until one image completes, so it times
        # the first *processed* image (skipped images are not counted).
        if run_on_full_first == 0:
            start_time_first = datetime.now()
        st_load = time.time()

        image_info = coco.loadImgs(image_no)[0]
        image_path = os.path.join(paths.image_dir, image_info['file_name'])
        load_image_to_explain(image_path, params, bg_type='gray')

        # Category names of every annotation on this image; skip unannotated images.
        annotations_all = get_annotation(coco, image_path, category_name=None)
        if len(annotations_all) == 0:
            continue
        preds_ = [coco.loadCats(annotation['category_id'])[0]['name'] for annotation in annotations_all]

        # The explained category is the model's top prediction.
        fixed_category = class_names[predicted_cls]
        annotations_selected = get_annotation(coco, image_path, category_name=fixed_category)

        if len(annotations_selected) == 0:
            verify_annotation_preds.append({
                'image_no': image_no,
                'annotation_count': len(annotations_selected),
                'top_predicted': class_names[predicted_cls],
                'preds': preds_,
                'True?': class_names[predicted_cls] in preds_,
            })
            continue

        load_groundtruth(coco, image_path, fixed_category=fixed_category)

        # NOTE(review): get_annotation's return type is not visible here. Iterating
        # annotations_all above yields dicts, so annotations_selected[0] is most likely
        # one annotation dict; then the any(...) below scans that dict's *keys*, and the
        # len(annotations_selected) printed afterwards is its key count rather than the
        # number of annotations. Behaviour kept unchanged — TODO confirm.
        annotations_selected = annotations_selected[0]
        has_segmentation = any('segmentation' in ann for ann in annotations_selected)
        if not has_segmentation:
            continue

        results = model.predict(image_to_explain)
        time_load = time.time() - st_load

        if verbose_print:
            print(f'| {im_id:<5} | {image_no:<10} | {fixed_category:<25} | {len(annotations_selected):<10} | {f_S:<12.5} | {predicted_cls:<15} |')

        ############   XAI   ############
        # `explainer` is not used directly in this cell; presumably the method
        # callables in `methods` read it as a global — TODO confirm.
        explainer = shap_bpt.Explainer(predict_yolo_masked, image_to_explain, num_explained_classes=2, verbose=True)

        ############   HeatMaps   ############
        heatmaps = {}
        time_exp = {}
        for n, _, funct in tqdm(methods, desc='Explanation', leave=False):
            st = time.time()
            heatmaps[n] = funct()
            time_exp[n] = time.time() - st

        if plot_heatmaps:
            fun_plot_heatmaps(methods, heatmaps, exp_type='full', destroy_fig=destroy_fig, save_fig=save_fig, plot_title=plot_title, dpi=100)

        ############   AUCs   ############
        aucD, aucI, IoU = {}, {}, {}
        for n, _, _ in tqdm(methods, desc='Evaluation', leave=False):
            # Deletion AUC
            st = time.time()
            aucD[n] = saliency_to_auc(predict_yolo_masked, heatmaps[n][0], f_S, f_0, predicted_cls, method='del', batch_size=params.batch_size)
            time_aucD = time.time() - st
            # Insertion AUC
            st = time.time()
            aucI[n] = saliency_to_auc(predict_yolo_masked, heatmaps[n][0], f_S, f_0, predicted_cls, method='ins', batch_size=params.batch_size)
            time_aucI = time.time() - st
            # IoU curve between the ground-truth mask and the heatmap
            st = time.time()
            IoU[n] = calc_IoU_curve_imp(ground_truth.flatten(), heatmaps[n][0].flatten())
            time_auc_IoU = time.time() - st

            data = {'image'         : image_no,
                    'image_size'    : image_to_explain.shape,
                    'bg_type'       : params.background_type,
                    'inference_time': results[0].speed['inference'],
                    'pred_cls'      : predicted_cls,
                    'pred_lbl'      : class_names[predicted_cls],
                    # NOTE(review): `annotations` is not assigned in this cell —
                    # presumably a global set by load_groundtruth; if not, this is
                    # a stale value from an earlier cell. TODO confirm.
                    'object_count'  : len(annotations),
                    'f_S'           : f_S,
                    'f_0'           : f_0,
                    'delta_f'       : f_S-f_0,
                    'f_T'           : np.sum(heatmaps[n][0]),
                    'f_N'           : len(np.unique(heatmaps[n][0])),  # unique values (patches) in the explanation
                    'method'        : n,
                    #--------------------------------------------------------------
                    'aucI_pred'     : aucI[n]['auc'],
                    'aucD_pred'     : aucD[n]['auc'],
                    'aucI_r'        : aucI[n]['auc_r'],
                    'aucD_r'        : aucD[n]['auc_r'],
                    'aucI_adj'      : aucI[n]['auc_adj'],
                    'aucD_adj'      : aucD[n]['auc_adj'],
                    'aucI_adj_r'    : aucI[n]['auc_adjr'],
                    'aucD_adj_r'    : aucD[n]['auc_adjr'],
                    'aucI_clip'     : aucI[n]['auc_clip'],
                    'aucI_clipr'    : aucI[n]['auc_clipr'],
                    'aucD_clip'     : aucD[n]['auc_clip'],
                    'aucD_clipr'    : aucD[n]['auc_clipr'],
                    #--------------------------------------------------------------
                    'threshold'     : IoU[n]['max_IoU_heatmap_threshold'],
                    'best_point'    : IoU[n]['x_best'],
                    'max_IoU'       : np.max(IoU[n]['Y']),
                    'auc_IoU'       : IoU[n]['auc_IoU'],
                    #--------------------------------------------------------------
                    'time_load'     : time_load,
                    'time_exp'      : time_exp[n],
                    'time_aucI'     : time_aucI,
                    'time_aucD'     : time_aucD,
                    'time_auc_IoU'  : time_auc_IoU,
                    'time_total'    : time_load+time_exp[n]+time_aucI+time_aucD+time_auc_IoU
                    }
            data_to_csv.append(data)

        if plot_IoU:
            fun_plot_IoU(methods, heatmaps, exp_type='full', destroy_fig=destroy_fig, save_fig=save_fig, plot_title=plot_title, dpi=100)

        # Checkpoint: rewrite the full CSV after each processed image.
        df = pd.DataFrame(data_to_csv)
        df.to_csv(f'{paths.csv_filename}', sep=',')

        if run_on_full_first == 0:
            end_time_first = datetime.now()
            print(f'{"Duration for single example is ":<20} : {"{}".format(end_time_first - start_time_first)}, therefor {params.data_subset} images, it can take {"{}".format( (end_time_first - start_time_first)*params.data_subset)}')
            run_on_full_first += 1

    # Always (re)build `df` so later cells (e.g. `df.head()`) see this run's
    # results even when no image was processed.
    df = pd.DataFrame(data_to_csv)

    time_full = time.time() - st_full

    print(f'{"Time for test is":<20} : {time_full:<10.10}')

    end_time = datetime.now()
    print(f'{"Duration":<20} : {"{}".format(end_time - start_time)}')

    chime.success()

## END FULL TEST

In [None]:
# Preview the first rows of the results produced by the full-test-set run.
df.head(n=5)

In [None]:
# Show where the results CSV lives, then display when it was last modified.
print(f'{paths.csv_filename}')
get_file_modify_time(filepath=paths.csv_filename)

In [None]:
# Wall-clock summary of the most recent experiment cell (start_time / end_time
# are set there via datetime.now()).
print('started:\t' + str(start_time))
print(f'{"Duration":<20} : {end_time - start_time}')

# END