In [None]:
import os
import shutil

import torch
import torch.utils.data
import transforms
import torchvision.datasets as datasets
import argparse
from helpers import makedir, adjust_learning_rate
import model
import push
import train_and_test as tnt
import save
from log import create_logger
from preprocess import mean, std, preprocess_input_function, img_size
from node import Node
import time
import numpy as np



In [None]:
# Expose only GPU 0 to this process (only effective if set before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# --- experiment configuration ---
data_path = "../datasets/imagenet/"
model_path = "saved_models_8protos1/"
resume_path = model_path + "best_model_last_opt.pth"  # passed to model.vgg16_proto below

batch_size = 50
n_protos_per_class = 8
proto_dim = 32

# --- dataset locations (all under data_path) ---
train_dir = data_path + 'train/'
valid_dir = data_path + 'valid/'
test_dir = data_path + 'test/'
OOD_dir = data_path + 'OODall/test'
train_push_dir = train_dir

# Every split shares the same batch size.
train_batch_size = valid_batch_size = test_batch_size = train_push_batch_size = batch_size

# dataset setup

# Every split uses the same deterministic evaluation transform (resize, center crop
# to img_size, tensor conversion, normalisation), and no loader shuffles, so a
# given (batch, position) always refers to the same image.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


def _make_eval_loader(dataset, loader_batch_size):
    """Wrap `dataset` in an unshuffled DataLoader with the given batch size."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=loader_batch_size, shuffle=False)


# train set
train_dataset = datasets.ImageFolder(train_dir, transform_test)
train_loader = _make_eval_loader(train_dataset, train_batch_size)
# valid set
valid_dataset = datasets.ImageFolder(valid_dir, transform_test)
valid_loader = _make_eval_loader(valid_dataset, valid_batch_size)
# test set (uses the test batch size; the OOD analysis below also slices by it)
test_dataset = datasets.ImageFolder(test_dir, transform_test)
test_loader = _make_eval_loader(test_dataset, test_batch_size)
# OOD set
OOD_dataset = datasets.ImageFolder(OOD_dir, transform_test)
OOD_loader = _make_eval_loader(OOD_dataset, test_batch_size)

print('training set size: {0}'.format(len(train_loader.dataset)))
print('valid set size: {0}'.format(len(valid_loader.dataset)))
print('test set size: {0}'.format(len(test_loader.dataset)))
print('OOD set size: {0}'.format(len(OOD_loader.dataset)))
print('batch size: {0}'.format(train_batch_size))


# construct the tree
def _build_hierarchy(top_level, subtrees):
    """Return a class hierarchy rooted at a Node named "root".

    `top_level` lists the root's direct children. `subtrees` is a sequence of
    (parent_name, child_names) pairs applied in order, so a parent must already
    exist before its own children are attached. Descendants and prototype
    directories are assigned once the whole tree is built.
    """
    tree = Node("root")
    tree.add_children(top_level)
    for parent_name, child_names in subtrees:
        tree.add_children_to(parent_name, child_names)
    tree.assign_all_descendents()
    tree.assign_proto_dirs()
    return tree


# In-distribution hierarchy.
root = _build_hierarchy(
    ['animal', 'vehicle', 'everyday_object', 'weapon', 'scuba_diver'],
    [
        ('animal', ['non_primate', 'primate']),
        ('non_primate', ['African_elephant', 'giant_panda', 'lion']),
        ('primate', ['capuchin', 'gibbon', 'orangutan']),
        ('vehicle', ['ambulance', 'pickup', 'sports_car']),
        ('everyday_object', ['laptop', 'sandal', 'wine_bottle']),
        ('weapon', ['assault_rifle', 'rifle']),
    ],
)

# Out-of-distribution hierarchy: same coarse structure, different fine classes.
OODroot = _build_hierarchy(
    ['animal', 'vehicle', 'everyday_object', 'weapon', 'scuba_diver'],
    [
        ('animal', ['non_primate', 'primate']),
        ('non_primate', ['king_penguin', 'tree_frog', 'zebra']),
        ('primate', ['macaque', 'gorilla', 'chimpanzee']),
        ('vehicle', ['cab', 'forklift', 'tractor', 'mountain_bike']),
        ('everyday_object', ['golf_ball', 'wallet', 'table_lamp']),
        ('weapon', ['revolver', 'bow']),
    ],
)


# Coarse (root-level) class names.
IDcoarse_names = root.children_names()
# Fine class names must be ordered exactly like ImageFolder's integer targets.
# ImageFolder derives `classes` from sub-directories only (sorted); os.listdir
# would also return stray files (e.g. .DS_Store) and shift every label after them.
IDfine_names = list(train_dataset.classes)
# Test labels are decoded with IDfine_names below, so both splits must agree.
assert list(test_dataset.classes) == IDfine_names, "train/ and test/ class folders differ"
label2name = dict(enumerate(IDfine_names))
# Fine label -> label of the root child it falls under (presumably the coarse class
# containing it, via closest_descendent_for).
IDfineLabel2coarseLabel = {
    label: root.children_to_labels[root.closest_descendent_for(name).name]
    for label, name in enumerate(IDfine_names)
}

OODcoarse_names = OODroot.children_names()
OODfine_names = list(OOD_dataset.classes)
OODfineLabel2coarseLabel = {
    label: OODroot.children_to_labels[OODroot.closest_descendent_for(name).name]
    for label, name in enumerate(OODfine_names)
}

num_fine = len(IDfine_names)
num_coarse = len(root.children)


# Hierarchical prototype VGG16 built over `root` (resume_path is handed to the
# constructor, presumably to restore the saved weights), moved to the GPU.
vgg = model.vgg16_proto(
    root,
    pretrained=True,
    num_prototypes_per_class=n_protos_per_class,
    prototype_dimension=proto_dim,
    img_size=img_size,
    resume_path=resume_path,
).cuda()
vgg_multi = torch.nn.DataParallel(vgg)
class_specific = True

# When True, the analysis cells below write figures to paper_images/.
# NOTE(review): this rebinds `save`, shadowing the `save` module imported at the top.
save = True

In [None]:
test_acc = tnt.test(
    model=vgg_multi,
    dataloader=test_loader,
    label2name=label2name,
    class_specific=class_specific,
    log=print,
)

In [None]:
# Print, for every internal node, its classification layer's first parameter rounded to 2 dp.
names_with_children = [node.name for node in vgg.root.nodes_with_children()]
for name in names_with_children:
    print('\n' + name)
    layer = getattr(vgg, name + "_layer")
    first_param = next(iter(layer.parameters())).data
    # Round in float64, matching np.round applied to each element's .item().
    weights = np.round(first_param.cpu().numpy().astype(np.float64), 2)
    print(weights)


In [None]:
# Print every root prototype vector (rounded), preceded by its coarse class name.
# NOTE(review): assumes prototypes are stored class-major, n_protos_per_class
# consecutive vectors per coarse class — confirm against model.vgg16_proto.
root_vecs = vgg.root_prototype_vectors.detach().cpu().numpy()
for i in range(root_vecs.shape[0]):
    print(IDcoarse_names[i // n_protos_per_class])
    print([np.round(x, 2) for x in root_vecs[i, :, 0, 0]])

In [None]:

# Attach to each internal node column 0 of its saved bb.npy; below it is used as
# the push-set image index each prototype was taken from.
for node in root.nodes_with_children():
    bb_info = np.load(os.path.join(model_path + node.proto_dir, 'bb.npy'))
    node.proto_idx = bb_info[:, 0]
    
def img_id_to_name(idx, per_class=1250, names=None):
    """Return the fine class name of image `idx` in the prototype-push set.

    Assumes the push set is ordered class by class with exactly `per_class`
    images per class (1250 by default) — TODO confirm for other datasets.

    Args:
        idx: integer-valued image index. Python ints, numpy scalars, 0-d tensors
            and integral floats (e.g. read from a .npy array) are accepted.
        per_class: number of images per class in the push set.
        names: class names ordered by label; defaults to IDfine_names.

    Returns:
        names[idx // per_class].
    """
    if names is None:
        names = IDfine_names
    # int() so float-typed indices can still be used to index the list.
    return names[int(idx) // per_class]

# For every root prototype, print its position and the class of its source image.
node = root.get_node("root")
for proto_pos, proto_img_id in enumerate(node.proto_idx):
    print("{} {}".format(proto_pos, img_id_to_name(proto_img_id)))

    


## Pick an ID batch to work with

In [None]:
import copy
import matplotlib.pyplot as plt
from skimage.transform import resize 
from preprocess import undo_preprocess_input_function    

def plot_preprocessed_img(preprocessed_imgs, index):
    """Show image `index` of a preprocessed batch and return it channels-last.

    Args:
        preprocessed_imgs: batch of model-input images; slicing
            `[index:index+1]` keeps a batch dimension of 1 (assumed NCHW — TODO confirm).
        index: position of the image within the batch.

    Returns:
        The de-normalised image transposed with axes [1, 2, 0], i.e. the object
        produced by np.transpose (callers call `.numpy()` on it).
    """
    # Deep-copy in case undo_preprocess_input_function modifies its input in place.
    single_img_batch = copy.deepcopy(preprocessed_imgs[index:index + 1])
    restored = undo_preprocess_input_function(single_img_batch)
    print('image index {0} in batch'.format(index))
    channels_last = np.transpose(restored[0], [1, 2, 0])
    plt.imshow(channels_last.numpy())
    plt.show()
    return channels_last

def show_prototype(img_dir, index, original=False):
    """Display a saved prototype image and return it.

    Args:
        img_dir: directory containing the saved prototype images.
        index: prototype index embedded in the file name.
        original: if True, load 'prototype-original-img<index>.png';
            otherwise 'prototype-img<index>.png'.

    Returns:
        The image array returned by plt.imread.
    """
    # Read only the requested file, so original=True works even when the
    # non-original image is absent (and avoids a wasted read).
    prefix = 'prototype-original-img' if original else 'prototype-img'
    p_img = plt.imread(os.path.join(img_dir, prefix + str(index) + '.png'))
    plt.imshow(p_img)
    plt.show()
    return p_img



In [None]:
# (label, fine class name) pairs for the in-distribution classes.
print(list(enumerate(IDfine_names)))

In [None]:
# Pick a block of test_batch_size consecutive test images. With 50 test images per
# class stored class by class (assumed — TODO confirm), block i holds class i.
i = 2
batch_dataset = torch.utils.data.Subset(
    test_dataset, range(i * test_batch_size, (i + 1) * test_batch_size))
sample_size = 10
batch_loader = torch.utils.data.DataLoader(
    batch_dataset, batch_size=sample_size, shuffle=False)

iter_batch_loader = iter(batch_loader)

# Keep the quintile-th (1-based) chunk of sample_size images from that block.
# Use the built-in next(): recent PyTorch DataLoader iterators no longer provide
# the Python-2 style .next() method.
quintile = 2
for i in range(quintile):
    images, labels = next(iter_batch_loader)
    

In [None]:
# Show every image of the selected chunk; iterate over the actual batch length
# rather than a hard-coded 10 so a shorter (final) chunk does not raise.
for img_pos in range(len(images)):
    _ = plot_preprocessed_img(images, img_pos)

## ID Analysis

In [None]:
# Node whose prototypes and classification layer are analysed in addition to the root.
case = "weapon"
# Reuse the model's own epsilon (vgg.epsilon) for the similarity transform below.
eps = vgg.epsilon
def activations(min_distances, eps=eps):
    """Convert prototype distances into similarity scores.

    Computes log(1 + 1 / (min_distances + eps)) element-wise: smaller distances
    give larger scores, and `eps` (default: the value of `eps` when this cell ran)
    keeps the reciprocal finite at zero distance.
    """
    inverse_distance = 1 / (min_distances + eps)
    return torch.log(1 + inverse_distance)
def activation_patterns(model, x, name, eps=eps):
    """Similarity maps between node `name`'s prototypes and the features of batch `x`.

    Runs `model.conv_features` on `x`, obtains distances from
    `model.prototype_distances`, and applies the same
    log(1 + 1 / (d + eps)) transform as `activations`.
    """
    features = model.conv_features(x)
    distance_maps = model.prototype_distances(features, name)
    inverse_distance = 1 / (distance_maps + eps)
    return torch.log(1 + inverse_distance)


# Run the model once on the chosen batch; the per-node outputs it keeps
# (e.g. vgg.root.min_distances, vgg.root.logits) are read back below.
with torch.no_grad():
    images_cuda = images.cuda()
    _ = vgg(images_cuda)
    preds_root, preds_joint = vgg.get_joint_distribution()
    # Inspection only: computing these under no_grad avoids building (and keeping
    # alive on the GPU) an autograd graph through the conv features.
    root_activations = activations(vgg.root.min_distances)
    root_activation_patterns = activation_patterns(vgg, images_cuda, "root")
    case_activations = activations(vgg.root.get_node(case).min_distances)
    case_activation_patterns = activation_patterns(vgg, images_cuda, case)

# Coarse ground truth: map each fine label to its root-level label.
root_y = torch.tensor([IDfineLabel2coarseLabel[y.item()] for y in labels]).cuda()

coarse_preds = torch.argmax(preds_root, dim=1)
fine_preds = torch.argmax(preds_joint, dim=1)

coarse_correct = (coarse_preds == root_y)
fine_correct = (fine_preds == labels.cuda())

# np.where cannot read CUDA tensors, so copy the boolean masks to host memory first.
print("Batch coarse accuracy: %.2f" % (coarse_correct.sum().item() / len(labels)))
print("Indices of wrong coarse preds: ", np.where(coarse_correct.cpu().numpy() == 0)[0])
print("Batch fine accuracy: %.2f" % (fine_correct.sum().item() / len(labels)))
print("Indices of wrong fine preds: ", np.where(fine_correct.cpu().numpy() == 0)[0])

del images_cuda


In [None]:
# Indices (within the current batch) of the images to inspect.
idx_list = [1]
import cv2

case_node = vgg.root.get_node(case)
if save:
    os.makedirs("paper_images/lead", exist_ok=True)

for idx in idx_list:
    fine_pred_name = IDfine_names[fine_preds[idx]]
    fine_true_name = IDfine_names[labels[idx]]
    coarse_pred_name = IDcoarse_names[coarse_preds[idx]]
    coarse_true_name = IDcoarse_names[root_y[idx]]
    print("\nlabel: " + fine_true_name + ", " + coarse_true_name, "\npredictions: " + fine_pred_name + ", " + coarse_pred_name)
    original_img = plot_preprocessed_img(images, idx).numpy()

    if save:
        # plt.imsave rejects float RGB outside [0, 1]; de-normalisation can overshoot
        # slightly through rounding, so clip before writing.
        plt.imsave("paper_images/lead/original_%s.jpg" % str(idx), np.clip(original_img, 0, 1))

    root_logits = vgg.root.logits[idx]
    case_logits = case_node.logits[idx]
    print("root logits: ", [np.round(x.item(), 2) for x in root_logits])
    print("%s logits: " % case_node.name, [np.round(x.item(), 2) for x in case_logits])
    print("root softmax: ", [np.round(x.item(), 6) for x in torch.nn.functional.softmax(root_logits, 0)])
    print("%s softmax: " % case_node.name, [np.round(x.item(), 2) for x in torch.nn.functional.softmax(case_logits, 0)])

In [None]:
plot_thresh = .5  # used by the OOD section; not used in this cell
cmap = "jet"
n_plotted = 4  # top-ranked prototypes per node whose heat-map / prototype are shown (and saved)
lead_dir = "paper_images/lead"
if save:
    os.makedirs(lead_dir, exist_ok=True)


def _report_node_prototypes(layer_label, node_acts, node_patterns, layer, class_id,
                            proto_idx, proto_dir, gray_img, file_prefix):
    """Print every prototype of one node for one image, most activated first.

    For each prototype this prints its index, the class of its source image, its
    similarity score, its weight towards `class_id` in `layer`, and the product of
    the two. For the top `n_plotted` it also shows the heat-map over the image and
    the prototype image, saving both when `save` is set.

    Args:
        layer_label: word used in the "<label> layer connection" line ("root"/"case").
        node_acts: this image's similarity scores for the node's prototypes.
        node_patterns: this image's activation maps, indexed by prototype.
        layer: the node's classification layer (read via `layer.weight[class_id][p]`).
        class_id: row of `layer.weight` whose contributions are reported.
        proto_idx: per-prototype push-set image indices (for img_id_to_name).
        proto_dir: prototype image directory, relative to model_path.
        gray_img: grayscale image used as the heat-map background.
        file_prefix: prefix of saved file names ("" for root, "fine_" for the case node).
    """
    scores, order = torch.sort(node_acts)  # ascending
    # Ranks 1..len inclusive so the least-activated prototype is reported as well.
    for rank in range(1, len(order) + 1):
        proto_id = order[-rank].item()
        similarity_score = scores[-rank]
        connection = layer.weight[class_id][proto_id]
        print('top {0} activated prototype for this image'.format(rank))
        print('prototype index: {0}'.format(proto_id))
        print('prototype class identity: %s' % img_id_to_name(proto_idx[proto_id]))
        print('activation value (similarity score): {0:.2f}'.format(similarity_score))
        print('{0} layer connection with predicted class: {1:.2f}'.format(layer_label, connection))
        print('contribution to predicted class logit: %.2f' % (similarity_score * connection))
        pattern = node_patterns[proto_id].detach().cpu().numpy()
        # Upsample to the network input size so it aligns with the (cropped) image.
        overlayed_img = 0.4 * gray_img + 0.6 * resize(pattern, [img_size, img_size])
        if rank <= n_plotted:
            plt.imshow(overlayed_img, cmap=cmap)
            plt.show()
            prototype = show_prototype(model_path + proto_dir, proto_id, original=True)
            if save:
                plt.imsave("%s/%soverlaid_%s.jpg" % (lead_dir, file_prefix, rank), overlayed_img, cmap=cmap)
                plt.imsave("%s/%stop_proto_%s.jpg" % (lead_dir, file_prefix, rank), prototype)
        print('--------------------------------------------------------------')


for idx in idx_list:

    print("STARTING ON ID %.1d \n \n" % idx)

    fine_pred_name = IDfine_names[fine_preds[idx]]
    fine_true_name = IDfine_names[labels[idx]]
    coarse_pred_name = IDcoarse_names[coarse_preds[idx]]
    coarse_true_name = IDcoarse_names[root_y[idx]]
    print("\nlabel: " + fine_true_name + ", " + coarse_true_name, "\npredictions: " + fine_pred_name + ", " + coarse_pred_name)

    original_img = plot_preprocessed_img(images, idx).numpy()
    # torchvision's ImageFolder loads images as RGB, so use RGB (not BGR) weights.
    original_img_gray = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)

    print("\nBEGIN ROOT ANALYSIS of image %.2d \n" % idx)
    _report_node_prototypes("root", root_activations[idx], root_activation_patterns[idx],
                            vgg.root_layer, coarse_preds[idx], vgg.root.proto_idx,
                            vgg.root.proto_dir, original_img_gray, "")

    print("\nBEGIN FINE ANALYSIS of image %.2d \n" % idx)

    case_node = root.get_node(case)
    case_children = case_node.children_names()
    case_layer = getattr(vgg, case + "_layer")
    # Explain the predicted fine class when it is a child of `case`; otherwise
    # fall back to output 0 and say so.
    if fine_pred_name in case_children:
        case_pred_id = case_children.index(fine_pred_name)
    else:
        case_pred_id = 0
        print("(%s is not a child of %s; showing connections to output 0)" % (fine_pred_name, case))

    _report_node_prototypes("case", case_activations[idx], case_activation_patterns[idx],
                            case_layer, case_pred_id, case_node.proto_idx,
                            case_node.proto_dir, original_img_gray, "fine_")

    
    

## OOD Analysis

In [None]:
# (label, fine class name) pairs for the out-of-distribution classes.
print(list(enumerate(OODfine_names)))

In [None]:
# Forklift images: i = 3, quintile = 3.

# Pick a block of test_batch_size consecutive OOD images. With 50 images per
# class stored class by class (assumed — TODO confirm), block i holds OOD class i.
i = 9
batch_dataset = torch.utils.data.Subset(
    OOD_dataset, range(i * test_batch_size, (i + 1) * test_batch_size))
sample_size = 10
batch_loader = torch.utils.data.DataLoader(
    batch_dataset, batch_size=sample_size, shuffle=False)

iter_batch_loader = iter(batch_loader)

# Keep the quintile-th (1-based) chunk. Use the built-in next(): recent PyTorch
# DataLoader iterators no longer provide a .next() method.
quintile = 3
for i in range(quintile):
    images, labels = next(iter_batch_loader)
    

In [None]:
# Show every image of the selected chunk; iterate over the actual batch length
# rather than a hard-coded 10 so a shorter (final) chunk does not raise.
for img_pos in range(len(images)):
    _ = plot_preprocessed_img(images, img_pos)

In [None]:
# Node analysed alongside the root for the OOD batch.
case = "weapon"
# Reuse the model's own epsilon (vgg.epsilon); the helpers below capture it as their default.
eps = vgg.epsilon
def activations(min_distances, eps=eps):
    """Convert prototype distances into similarity scores.

    Same definition as in the ID-analysis section, repeated here (presumably so the
    OOD cells can run on their own). Computes log(1 + 1 / (min_distances + eps))
    element-wise; `eps` defaults to the value of `eps` when this cell ran.
    """
    inverse_distance = 1 / (min_distances + eps)
    return torch.log(1 + inverse_distance)
def activation_patterns(model, x, name, eps=eps):
    """Similarity maps between node `name`'s prototypes and the features of batch `x`.

    Runs `model.conv_features`, obtains distances from `model.prototype_distances`,
    and applies log(1 + 1 / (d + eps)), matching `activations`.
    """
    features = model.conv_features(x)
    distance_maps = model.prototype_distances(features, name)
    inverse_distance = 1 / (distance_maps + eps)
    return torch.log(1 + inverse_distance)


# Run the model once on the OOD batch; the per-node outputs it keeps are read below.
with torch.no_grad():
    images_cuda = images.cuda()
    _ = vgg(images_cuda)
    preds_root, preds_joint = vgg.get_joint_distribution()
    # Inspection only: no autograd graph needed for these.
    root_activations = activations(vgg.root.min_distances)
    root_activation_patterns = activation_patterns(vgg, images_cuda, "root")
    case_activations = activations(vgg.root.get_node(case).min_distances)
    case_activation_patterns = activation_patterns(vgg, images_cuda, case)

# Coarse ground truth from the OOD hierarchy. Comparing it with ID coarse
# predictions assumes both trees label their root children identically
# (same child list order) — TODO confirm.
root_y = torch.tensor([OODfineLabel2coarseLabel[y.item()] for y in labels]).cuda()

coarse_preds = torch.argmax(preds_root, dim=1)
fine_preds = torch.argmax(preds_joint, dim=1)

coarse_correct = (coarse_preds == root_y)
# np.where cannot read CUDA tensors, so copy the mask to host memory first.
coarse_correct_np = coarse_correct.cpu().numpy()

print("Batch coarse accuracy: %.2f" % (coarse_correct.sum().item() / len(labels)))
print("Indices of correct coarse preds: ", np.where(coarse_correct_np == 1)[0])
print("Indices of wrong coarse preds: ", np.where(coarse_correct_np == 0)[0])

del images_cuda


In [None]:
# Indices (within the current OOD batch) of the images to inspect.
idx_list = [7]
import cv2

case_node = vgg.root.get_node(case)

for idx in idx_list:
    # Predictions live in the ID label space; ground truth in the OOD hierarchy's.
    fine_pred_name = IDfine_names[fine_preds[idx]]
    fine_true_name = OODfine_names[labels[idx]]
    coarse_pred_name = IDcoarse_names[coarse_preds[idx]]
    coarse_true_name = OODcoarse_names[root_y[idx]]
    print("\nlabel: " + fine_true_name + ", " + coarse_true_name, "\npredictions: " + fine_pred_name + ", " + coarse_pred_name)
    original_img = plot_preprocessed_img(images, idx).numpy()
    root_logits = vgg.root.logits[idx]
    case_logits = case_node.logits[idx]
    print("root logits: ", [np.round(x.item(), 2) for x in root_logits])
    print("%s logits: " % case_node.name, [np.round(x.item(), 2) for x in case_logits])
    print("root softmax: ", [np.round(x.item(), 4) for x in torch.nn.functional.softmax(root_logits, 0)])
    print("%s softmax: " % case_node.name, [np.round(x.item(), 2) for x in torch.nn.functional.softmax(case_logits, 0)])

In [None]:
plot_thresh = .5  # prototypes with similarity above this are printed and plotted
cmap = "jet"
n_top = 4  # top-ranked prototypes summed into cont_from_top_4 (and saved)
case_dir = "paper_images/case_study"
if save:
    os.makedirs(case_dir, exist_ok=True)

for idx in idx_list:

    print("STARTING ON ID %.1d \n \n" % idx)

    # Predictions live in the ID label space; ground truth in the OOD hierarchy's.
    fine_pred_name = IDfine_names[fine_preds[idx]]
    fine_true_name = OODfine_names[labels[idx]]
    coarse_pred_name = IDcoarse_names[coarse_preds[idx]]
    coarse_true_name = OODcoarse_names[root_y[idx]]
    print("\nlabel: " + fine_true_name + ", " + coarse_true_name, "\npredictions: " + fine_pred_name + ", " + coarse_pred_name)

    original_img = plot_preprocessed_img(images, idx).numpy()
    # torchvision's ImageFolder loads images as RGB, so use RGB (not BGR) weights.
    original_img_gray = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)

    if save:
        # plt.imsave rejects float RGB outside [0, 1]; clip de-normalisation round-off.
        plt.imsave("%s/%s%s_original.jpg" % (case_dir, case, str(idx)), np.clip(original_img, 0, 1))

    print("\nBEGIN ROOT ANALYSIS of image %.2d \n" % idx)

    cont_from_top_4 = 0
    # Sum over all root prototypes of similarity * weight for the predicted class.
    # NOTE(review): excludes any bias term root_layer may have.
    whole_logit = 0
    pred_class_weights = vgg.root_layer.weight[coarse_preds[idx]]
    root_proto_dir = model_path + vgg.root.proto_dir

    array_act, sorted_indices_act = torch.sort(root_activations[idx])
    # Ranks 1..len inclusive, so the least-activated prototype also enters whole_logit.
    for i in range(1, len(sorted_indices_act) + 1):
        print(i)
        proto_id = sorted_indices_act[-i].item()
        similarity_score = array_act[-i]
        contribution = similarity_score * pred_class_weights[proto_id]
        above_thresh = similarity_score > plot_thresh
        if above_thresh:
            print('top {0} activated prototype for this image'.format(i))
            print('prototype index: {0}'.format(proto_id))
            print('activation value (similarity score): {0:.2f}'.format(similarity_score))
            print('root layer connection with predicted class: {0:.2f}'.format(pred_class_weights[proto_id]))
            print('contribution to predicted class logit: %.2f' % contribution)
        activation_pattern = root_activation_patterns[idx][proto_id].detach().cpu().numpy()
        # Upsample to the network input size so it aligns with the (cropped) image.
        overlayed_img = 0.3 * original_img_gray + 0.7 * resize(activation_pattern, [img_size, img_size])

        proto = None
        if above_thresh:
            plt.imshow(overlayed_img, cmap=cmap)
            plt.show()
            proto = show_prototype(root_proto_dir, proto_id, original=True)
            print('--------------------------------------------------------------')

        whole_logit += contribution
        if i <= n_top:
            cont_from_top_4 += contribution
            if save:
                if proto is None:
                    # Not displayed above: load this prototype's own image rather than
                    # saving a previous iteration's (or an undefined) `proto`.
                    proto = plt.imread(os.path.join(root_proto_dir, 'prototype-original-img' + str(proto_id) + '.png'))
                plt.imsave("%s/%s%s_top%s_proto.jpg" % (case_dir, case, str(idx), str(i)), proto)
                plt.imsave("%s/%s%s_top%s_heat.jpg" % (case_dir, case, str(idx), str(i)), overlayed_img, cmap=cmap)

    print("cont from top 4: %.2f \t whole logit: %.2f" % (cont_from_top_4, whole_logit))
        
#     print("\nBEGIN FINE ANALYSIS of image %.2d \n" % idx)
    
#     case_node = root.get_node(case)
#     case_children = case_node.children_names()
#     case_pred_id = case_children.index(fine_pred_name)        
#     case_layer = getattr(vgg,case+"_layer")

#     array_act, sorted_indices_act = torch.sort(case_activations[idx])
#     for i in range(1,len(sorted_indices_act)):
#         print('top {0} activated prototype for this image'.format(i))
#         print('prototype index: {0}'.format(sorted_indices_act[-i].item()))  
#         #print('prototype class identity: %s' % img_id_to_name(case_node.proto_idx[sorted_indices_act[-i]]))                
#         similarity_score = array_act[-i]
#         print('activation value (similarity score): {0:.2f}'.format(similarity_score))
#         print('case layer connection with predicted class: {0:.2f}'.format(case_layer.weight[case_pred_id][sorted_indices_act[-i].item()]))
#         print('contribution to predicted class logit: %.2f' % (array_act[-i]*case_layer.weight[case_pred_id][sorted_indices_act[-i].item()]))
#         activation_pattern = case_activation_patterns[idx][sorted_indices_act[-i].item()].detach().cpu().numpy()        
#         upsampled_activation_pattern = resize(activation_pattern, [224,224])        
#         overlayed_img = 0.3 * original_img_gray + 0.7 * upsampled_activation_pattern
#         #plt.imshow(upsampled_activation_pattern)
#         if similarity_score > plot_thresh:
#             plt.imshow(overlayed_img)
#             plt.show()
#             show_prototype(model_path + case_node.proto_dir, sorted_indices_act[-i].item())
#         print('--------------------------------------------------------------')

    
    