In [1]:
import numpy as np
import os
import json
import cv2
from glob import glob
from matplotlib import pyplot as plt
import operator
from scipy.stats import pearsonr,spearmanr

In [2]:
# Folder holding the per-image result files (*.npy) that the cells below load.
data_dir = './gradcam_result_single_proto//'
# Root folder of the source photos; the rendering cell reads <img_dir>/<group>/<id>.jpg.
img_dir = '../autism_photo_taking/'
# Root folder for the rendered overlays.
save_dir = './visualization/visualization_cnn'
# Maps the integer stored under 'category' in a result file to a folder name.
idx2category = {0:'indoor',1:'outdoor',2:'people'}

# Create one output folder per group. makedirs also creates missing parent
# folders (save_dir is nested, so plain os.mkdir fails on a fresh checkout),
# and exist_ok=True makes re-running this cell harmless.
for category in ['ASD','Ctrl']:
    os.makedirs(os.path.join(save_dir,category),exist_ok=True)

In [3]:
# Shared overlay routine so that every visualization is rendered the same way.
def overlay_heatmap(img,att,cmap=plt.cm.jet,alpha=0.5):
    """Blend a colour-mapped attention map over an image.

    Args:
        img: H x W x 3 image array (presumably BGR, as returned by
            cv2.imread -- the colour-mapped heatmap is reordered to match).
        att: 2-D float attention map, expected to be scaled to [0, 1].
            It is copied, so the caller's array is left untouched. If its
            size differs from ``img`` it is resized to the image size.
        cmap: callable mapping a uint8 array to per-pixel colour values in
            [0, 1] whose first three channels are R, G, B (a matplotlib
            colormap).
        alpha: blending weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    # Work on a private copy: the thresholding below writes into the array.
    att = np.array(att,copy=True)
    # The element-wise blend needs matching sizes; cv2.resize takes (width, height).
    if att.shape[:2] != img.shape[:2]:
        att = cv2.resize(att,(img.shape[1],img.shape[0]))
    # Flatten everything below 0.5 to a small constant, then soften the edges.
    att[att<0.5] = 0.01
    att = cv2.blur(att,(10,10)) # originally 35
    colorized = cmap(np.uint8(att*255))
    # [:,:,2::-1] picks channels 2,1,0: drops alpha and reverses RGB -> BGR.
    overlaid = np.uint8(img*(1-alpha)+colorized[:,:,2::-1]*255*alpha)
    return overlaid

In [100]:
# visualize the gradient-based explanation maps (single)
# For every result file: normalise its importance map, overlay it on the
# matching photo and save the result under <save_dir>/<group>/<category>/.
files = glob(os.path.join(data_dir,'*.npy'))
for cur_file in files:
    cur_id = os.path.splitext(os.path.basename(cur_file))[0]
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    cur_data = np.load(cur_file,allow_pickle=True).item()
    cur_label = 'ASD' if cur_data['label'] == 1 else 'Ctrl'
    cur_category = idx2category[cur_data['category']]

    cam_map = cv2.resize(cur_data['pixel_importance'],(224,224))
    cam_max = cam_map.max()
    if not cam_max > 0:
        # No positive importance anywhere: nothing meaningful to draw.
        continue
    # Scale so the peak is 1 (out of place, so integer maps work as well).
    cam_map = cam_map/cam_max

    img_path = os.path.join(img_dir,cur_label,cur_id+'.jpg')
    cur_img = cv2.imread(img_path)
    if cur_img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        print('Skipping %s: cannot read image %s' %(cur_id,img_path))
        continue

    # Create the target folder (and any missing parents) only when we
    # actually have something to write.
    out_dir = os.path.join(save_dir,cur_label,cur_category)
    os.makedirs(out_dir,exist_ok=True)

    cur_map = overlay_heatmap(cur_img,cam_map)
    cv2.imwrite(os.path.join(out_dir,cur_id+'.jpg'),cur_map)

In [1]:
# analyze the importance of different prototypes (single-image)
# overall_proto[group][prototype id] -> list of 'acc' values of the images
# assigned to that prototype.
overall_proto = dict()
num_proto = 20 # number of asd and ctrl prototypes
for label in ['asd','ctrl']:
    overall_proto[label] = {proto_id: [] for proto_id in range(num_proto)}

files = glob(os.path.join(data_dir,'*.npy'))
for cur_file in files:
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    cur_data = np.load(cur_file,allow_pickle=True).item()
    cur_label = 'asd' if cur_data['label'] else 'ctrl'

    cur_proto = cur_data['cluster_assignment']
    cur_importance = cur_data['acc']
    overall_proto[cur_label][cur_proto].append(cur_importance)

# overall importance on different prototypes
for label in ['asd','ctrl']:
    print('Overall importance on top-%d prototypes for group %s' %(num_proto,label))
    res = [] # per-prototype mean importance
    var = [] # per-prototype standard deviation (np.std, despite the name)
    for cur in range(num_proto):
        cur_scores = overall_proto[label][cur]
        if len(cur_scores) == 0:
            # Prototype never assigned in this group.
            res.append(0)
            var.append(0)
        else:
            res.append(np.mean(cur_scores))
            var.append(np.std(cur_scores))

    # Normalise both lists by the sum of the means so the means add up to 1.
    # The sum is taken once, before `res` is overwritten.
    total = np.sum(res)
    if total == 0:
        # No samples (or all-zero importance): dividing would only yield NaNs.
        print('No importance recorded for group %s' %label)
        continue
    var = [cur/total for cur in var]
    res = [cur/total for cur in res]

    # Prototypes from most to least important.
    top_idx = np.argsort(res)[::-1]
    for i in range(num_proto):
        print('%d: %.3f %.3f' %(top_idx[i],res[top_idx[i]], var[top_idx[i]]))
