In [None]:
import os
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy.misc import imresize
import imageio
import time
from node import Node
import model
import cv2

# Directory holding the trained model's prototype images and nearest-neighbour dumps.
model_dir = "saved_models_8protos1/"

# Prototypes allotted to each child class of every internal node of the hierarchy.
num_protos_per_class = 8

nearest_train = model_dir + "nearest_train/"
nearest_test = model_dir + "nearest_test/"

# Class hierarchy: root -> coarse classes -> (optional intermediate nodes) -> ImageNet leaf classes.
root = Node("root")
root.add_children(['animal','vehicle','everyday_object','weapon','scuba_diver'])
root.add_children_to('animal',['non_primate','primate'])
root.add_children_to('non_primate',['African_elephant','giant_panda','lion'])
root.add_children_to('primate',['capuchin','gibbon','orangutan'])
root.add_children_to('vehicle',['ambulance','pickup','sports_car'])
root.add_children_to('everyday_object',['laptop','sandal','wine_bottle'])
root.add_children_to('weapon',['assault_rifle','rifle'])
root.assign_all_descendents()

IDcoarse_names = root.children_names()

# Fine (ImageNet) class names: one sub-directory per class. Only directories are
# kept so that a stray file (e.g. .DS_Store) cannot shift every label index.
# NOTE(review): sorted directory order is assumed to match the label ids used when
# the neighbour class_id.npy files were written -- confirm against the training loader.
imagenet_train_dir = "../datasets/imagenet/train"
IDfine_names = sorted(
    entry for entry in os.listdir(imagenet_train_dir)
    if os.path.isdir(os.path.join(imagenet_train_dir, entry))
)
label2name = dict(enumerate(IDfine_names))
# Fine label id -> name of the root child (coarse class) that contains it.
IDfineLabel2coarseName = {label: root.closest_descendent_for(fine_name).name
                          for label, fine_name in label2name.items()}

# When True, the visualisation cells below export neighbour images to paper_images/knn/.
save = False

### Latent-space quality metric: what fraction of each prototype's nearest neighbors belong to the prototype's own class?

In [None]:
# For every internal node, score each prototype by the fraction of its saved
# nearest neighbours whose fine class lies under the prototype's own child class,
# then average those fractions over the node's prototypes.
# Prototype j of a node belongs to child j // num_protos_per_class.


def _fraction_correct(proto_node, proto_name, neighbor_names):
    """Return the fraction of `neighbor_names` that fall in the prototype's class.

    If `proto_node.has_logits()` (treated as an internal node), a neighbour counts
    when its fine class is one of the node's descendents; otherwise (a leaf) it
    must equal `proto_name` exactly. An empty `neighbor_names` yields NaN
    (numpy mean of an empty list).
    """
    if proto_node.has_logits():
        correct = [n in proto_node.descendents for n in neighbor_names]
    else:  # leaf node
        correct = [n == proto_name for n in neighbor_names]
    return np.mean(correct)


parent2stat_train_raw = {}
parent2stat_test_raw = {}

for parent in root.nodes_with_children():
    n_protos = num_protos_per_class * parent.num_children()
    children_names = [child.name for child in parent.children]
    train_scores, test_scores = [], []

    for j in range(n_protos):
        train_class_idx = np.load(os.path.join(nearest_train, parent.name, str(j), "class_id.npy"))
        test_class_idx = np.load(os.path.join(nearest_test, parent.name, str(j), "class_id.npy"))

        # The first saved train neighbour is the prototype itself, so skip it.
        train_names = [label2name[idx] for idx in train_class_idx[1:]]
        test_names = [label2name[idx] for idx in test_class_idx]

        proto_name = children_names[j // num_protos_per_class]
        proto_node = root.get_node(proto_name)

        train_scores.append(_fraction_correct(proto_node, proto_name, train_names))
        test_scores.append(_fraction_correct(proto_node, proto_name, test_names))

    parent2stat_train_raw[parent.name] = np.mean(train_scores)
    parent2stat_test_raw[parent.name] = np.mean(test_scores)

# Per-node scores are rounded for display only; the overall scores average the
# unrounded values so the printed precision is genuine.
parent2stat_train = {k: np.round(v, 2) for k, v in parent2stat_train_raw.items()}
parent2stat_test = {k: np.round(v, 2) for k, v in parent2stat_test_raw.items()}
train_overall = np.mean(list(parent2stat_train_raw.values()))
test_overall = np.mean(list(parent2stat_test_raw.values()))

print("train")
print(parent2stat_train)
print("train overall %.4f \n" % train_overall)

print("test")
print(parent2stat_test)
print("test overall %.4f" % test_overall)




## Coarse level: all prototypes of one coarse class

In [None]:
# Mapping exposed by the root node for its children (presumably child name -> label id; verify in node.py).
root_children_to_labels = root.children_to_labels
print(root_children_to_labels)

In [None]:
# Coarse level: show every prototype of one child class of the root, followed by
# its K nearest neighbours (neighbour image + heatmap overlay). Neighbours are
# titled with their coarse class, i.e. the child of the root they fall under.
name = "root"
K = 5
node = root.get_node(name)
show = range(0, K)

coarse_class_idx = 0          # which child of `name` to inspect (0 .. node.num_children() - 1)
splits_to_show = ("TEST",)    # use ("TRAIN", "TEST") to also display the nearest training neighbours
save_proto_id = 32            # prototype whose neighbours are exported when `save` is True
knn_out_dir = "paper_images/knn"
split_roots = {"TRAIN": nearest_train, "TEST": nearest_test}

# Prototype ids are grouped per child: child c owns ids c*P .. (c+1)*P - 1.
start_ids = np.arange(0, node.num_children() * num_protos_per_class, num_protos_per_class)
start_id = start_ids[coarse_class_idx]

for proto_id in range(start_id, start_id + num_protos_per_class):
    print('\n\n')
    print('----------------------------------------------------------------------------------------------------')

    proto_path = os.path.join(model_dir, name + "_prototypes", "prototype-img" + str(proto_id) + ".png")
    plt.figure()
    plt.imshow(imageio.imread(proto_path))
    plt.title("Prototype " + str(proto_id))
    plt.show()

    for split in splits_to_show:
        split_dir = os.path.join(split_roots[split], name, str(proto_id))
        class_idx = np.load(os.path.join(split_dir, "class_id.npy"))

        print("\n-------------%s neighbors---------------" % split)
        for k in show:
            nn_path = os.path.join(split_dir, "nearest-" + str(k) + ".png")
            nn_label = IDfineLabel2coarseName[class_idx[k]]
            plt.figure()
            plt.imshow(imageio.imread(nn_path))
            plt.title("NN %s: %s  -- Near to %s proto %s" % (k, nn_label, name, str(proto_id)))
            plt.show()

            heat_path = os.path.join(split_dir, "nearest-" + str(k) + "_original_with_heatmap.png")
            heat_img = imageio.imread(heat_path)
            plt.figure()
            plt.imshow(heat_img)
            plt.title("NN %s heatmap" % k)
            plt.show()

            if save and proto_id == save_proto_id:
                # plt.imsave does not create missing folders, so make sure it exists.
                os.makedirs(knn_out_dir, exist_ok=True)
                org_img = imageio.imread(os.path.join(split_dir, "nearest-" + str(k) + "_original.png"))
                tag = "%s_%s_proto_%s" % (name, split.lower(), str(proto_id))
                plt.imsave(os.path.join(knn_out_dir, "%s_org_%s.jpg" % (tag, str(k))), org_img)
                plt.imsave(os.path.join(knn_out_dir, "%s_heat_%s.jpg" % (tag, str(k))), heat_img)






## Fine level: all prototypes of one child class of an internal node

In [None]:
# Internal node whose prototypes the next cell walks through ("weapon" is another option).
name = "vehicle"
node = root.get_node(name)
K = 5  # nearest neighbours displayed per prototype in the next cell
print(node.children_to_labels)

In [None]:
# Fine level: show every prototype of one child of `name` (selected in the cell
# above), followed by its K nearest neighbours and their heatmap overlays.
child_idx = 0                  # which child of `name` to inspect
splits_to_show = ("TRAIN",)    # use ("TRAIN", "TEST") to also display the nearest test neighbours
neighbor_label = label2name    # fine ImageNet class; use IDfineLabel2coarseName for the root-level class
save_proto_id = None           # e.g. 1 to export that prototype's neighbours when `save` is True
knn_out_dir = "paper_images/knn"
split_roots = {"TRAIN": nearest_train, "TEST": nearest_test}

# Prototype ids are grouped per child: child c owns ids c*P .. (c+1)*P - 1.
fine_node = root.get_node(name)
start_ids = np.arange(0, fine_node.num_children() * num_protos_per_class, num_protos_per_class)
start_id = start_ids[child_idx]

# Iterate over all num_protos_per_class prototypes of the selected child.
for proto_id in range(start_id, start_id + num_protos_per_class):
    print('----------------------------------------------------------------------------------------------------')

    proto_path = os.path.join(model_dir, name + "_prototypes", "prototype-img" + str(proto_id) + ".png")
    plt.figure()
    plt.imshow(imageio.imread(proto_path))
    plt.title("Prototype " + str(proto_id))
    plt.show()

    for split in splits_to_show:
        split_dir = os.path.join(split_roots[split], name, str(proto_id))
        class_idx = np.load(os.path.join(split_dir, "class_id.npy"))

        print("\n-------------%s neighbors---------------" % split)
        for k in range(0, K):
            nn_path = os.path.join(split_dir, "nearest-" + str(k) + ".png")
            nn_label = neighbor_label[class_idx[k]]
            plt.figure()
            plt.imshow(imageio.imread(nn_path))
            plt.title("NN %s: %s  -- Near to %s proto %s" % (k, nn_label, name, str(proto_id)))
            plt.show()

            heat_path = os.path.join(split_dir, "nearest-" + str(k) + "_original_with_heatmap.png")
            heat_img = imageio.imread(heat_path)
            plt.figure()
            plt.imshow(heat_img)
            plt.title("NN %s heatmap" % k)
            plt.show()

            if save and proto_id == save_proto_id:
                # plt.imsave does not create missing folders, so make sure it exists.
                os.makedirs(knn_out_dir, exist_ok=True)
                org_img = imageio.imread(os.path.join(split_dir, "nearest-" + str(k) + "_original.png"))
                tag = "%s_%s_proto_%s" % (name, split.lower(), str(proto_id))
                plt.imsave(os.path.join(knn_out_dir, "%s_org_%s.jpg" % (tag, str(k))), org_img)
                plt.imsave(os.path.join(knn_out_dir, "%s_heat_%s.jpg" % (tag, str(k))), heat_img)

