In [11]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import numpy as np
import imutils
from tqdm import tqdm

from tensorflow.keras.layers import Input
from tensorflow.keras.models import load_model
from keras.datasets import mnist

from architectures.protoshotxai import ProtoShotXAI
from utils.ploting_function import xai_plot
from architectures.exmatchina_star import ExMatchina


In [2]:
# Location of the pretrained convolutional MNIST classifier on disk.
model_path_pretrained = '../trained_models/pretrained_conv_mnist/'

# Both explainers used later (ProtoShotXAI and ExMatchina) wrap this single Keras model.
base_model = load_model(filepath=model_path_pretrained)

In [5]:
# Load MNIST, append a channel axis and scale pixel values from [0, 255] to [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.expand_dims(x_train, axis=3) / 255

# A fixed training-set "6" serves as the query digit; keep a leading batch axis -> (1, 28, 28, 1).
query = x_train[y_train == 6][88][np.newaxis, ...]

# Rotate the query through every integer angle 0..359 degrees.
# imutils.rotate drops the channel axis, so it is re-added per frame; result is (360, 28, 28, 1).
query_rotations = np.stack(
    [imutils.rotate(query[0], angle=angle)[..., np.newaxis] for angle in range(360)],
    axis=0,
).astype(np.float64)

In [8]:
# Number of support examples drawn per class when building ProtoShotXAI prototypes.
shot = 100

# ProtoShotXAI explainer built on top of the pretrained classifier.
protoshot = ProtoShotXAI(base_model)

In [9]:

# Support sets: `shot` random training images per candidate class.
# A fixed seed makes the sampled supports (and hence the scores and accuracy) reproducible,
# and re-seeding here keeps the cell idempotent when it is re-run.
support_rng = np.random.default_rng(0)


def sample_support(digit):
    """Return `shot` randomly chosen training images of `digit` with a leading batch axis.

    With x_train prepared as (N, 28, 28, 1), the result is (1, shot, 28, 28, 1).
    """
    class_images = x_train[y_train == digit]
    chosen = support_rng.permutation(class_images.shape[0])[:shot]
    return np.expand_dims(class_images[chosen], axis=0)


support_0 = sample_support(0)
support_5 = sample_support(5)
support_6 = sample_support(6)
support_9 = sample_support(9)

# Reference labels: the classifier's own prediction for each rotation angle.
model_output = base_model.predict(query_rotations)
class_val = np.argmax(model_output, axis=1)

# ProtoShotXAI expects a leading episode axis on the query. Build it under a NEW name:
# overwriting query_rotations made a re-run of this cell pass a 5-D array to predict().
query_batch = np.expand_dims(query_rotations, axis=0)

scores_0 = protoshot.compute_score(support_0, query_batch, 0)
scores_5 = protoshot.compute_score(support_5, query_batch, 5)
scores_6 = protoshot.compute_score(support_6, query_batch, 6)
scores_9 = protoshot.compute_score(support_9, query_batch, 9)

# Classes without a support set get a very low score so argmax only picks among 0/5/6/9.
all_scores = np.full((class_val.shape[0], 10), -1e6)
for digit, scores in ((0, scores_0), (5, scores_5), (6, scores_6), (9, scores_9)):
    all_scores[:, digit] = scores
protoshotXAI_class = np.argmax(all_scores, axis=1)

# Fraction of angles where the best-scoring ProtoShotXAI class matches the model prediction.
acc = np.mean(class_val == protoshotXAI_class)

print(acc)

0.9305555555555556


In [14]:
# selected_layer: the layer to use in identifying examples.
# We recommend the layer immediately following the last convolution (e.g. flatten layer)
selected_layer = "dropout_1"
# NOTE(review): `exm` indexes all of x_train but is not used by the visible cells;
# drop it if nothing else needs it, as it is the most expensive construction here.
exm = ExMatchina(model=base_model, layer=selected_layer, examples=x_train)

# One ExMatchina index per candidate class, so each lookup returns that digit's nearest example.
exm_0 = ExMatchina(model=base_model, layer=selected_layer, examples=x_train[y_train==0])
exm_5 = ExMatchina(model=base_model, layer=selected_layer, examples=x_train[y_train==5])
exm_6 = ExMatchina(model=base_model, layer=selected_layer, examples=x_train[y_train==6])
exm_9 = ExMatchina(model=base_model, layer=selected_layer, examples=x_train[y_train==9])
exm_by_class = {0: exm_0, 5: exm_5, 6: exm_6, 9: exm_9}

n_rotations = 360
ex_matchina_scores = {digit: np.zeros(n_rotations) for digit in exm_by_class}

progress_bar = True
for irot in tqdm(range(n_rotations), disable=not progress_bar):
    # Same rotation as the query_rotations cell; imutils.rotate drops the channel axis.
    img_rot = np.expand_dims(imutils.rotate(query[0], angle=irot), axis=2)
    for digit, exm_digit in exm_by_class.items():
        _, _, results = exm_digit.return_nearest_examples(img_rot, num_examples=1)
        # Score of the single nearest example of this class.
        ex_matchina_scores[digit][irot] = results[0]

# Per-class score curves, kept under their original names for the plotting cell.
ex_matchina_score_0 = ex_matchina_scores[0]
ex_matchina_score_5 = ex_matchina_scores[5]
ex_matchina_score_6 = ex_matchina_scores[6]
ex_matchina_score_9 = ex_matchina_scores[9]

# Classes without an index get a very low score so argmax only picks among 0/5/6/9.
all_scores = np.full((n_rotations, 10), -1e6)
for digit, scores in ex_matchina_scores.items():
    all_scores[:, digit] = scores
ex_matchina_class = np.argmax(all_scores, axis=1)

# Fraction of angles where the best-scoring ExMatchina* class matches the model prediction.
acc = np.mean(class_val == ex_matchina_class)
print(acc)

Getting activations...
Getting labels...
Generating activation matrix...
Getting activations...
Getting labels...
Generating activation matrix...
Getting activations...
Getting labels...
Generating activation matrix...
Getting activations...
Getting labels...
Generating activation matrix...
Getting activations...
Getting labels...
Generating activation matrix...


100%|██████████| 360/360 [02:13<00:00,  2.71it/s]

0.9





In [13]:
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Per-class line styles shared by the ProtoShotXAI and ExMatchina* score panels,
# listed in legend order: (digit, colour, dash style or None for a solid line).
SCORE_STYLES = (
    (0, "#AB63FA", None),
    (6, "#FFA15A", "dash"),
    (5, "#19D3F3", "dashdot"),
    (9, "#FF6692", "dot"),
)
protoshot_scores = {0: scores_0, 5: scores_5, 6: scores_6, 9: scores_9}
exmatchina_scores = {
    0: ex_matchina_score_0,
    5: ex_matchina_score_5,
    6: ex_matchina_score_6,
    9: ex_matchina_score_9,
}


def score_line(color, dash):
    """Build a plotly line spec of width 5, adding `dash` only when one is given."""
    line = dict(color=color, width=5)
    if dash is not None:
        line["dash"] = dash
    return line


deg = np.arange(360)
fig = make_subplots(rows=3, cols=1, vertical_spacing = 0.07, subplot_titles=("Predicted Class for Rotating Six","ProtoShotXAI Class Scores", "ExMatchina* Class Scores"))
fig.update_annotations(font_size=24)

# Row 1: predicted class per angle, from the model and from each explainer's best score.
fig.add_trace(go.Scatter(x=deg, y=class_val,
                    mode='lines', line=dict(width=8),
                    name='Predicted Model Class'),row=1, col=1)
fig.add_trace(go.Scatter(x=deg, y=protoshotXAI_class,
                    mode='lines', line=dict(width=6, dash='dash'),
                    name='Class of Best ProtoShotXAI Score'),row=1, col=1)
fig.add_trace(go.Scatter(x=deg, y=ex_matchina_class,
                    mode='lines', line=dict(width=4, dash='dashdot'),
                    name='Class of Best ExMatchina* Score'),row=1, col=1)
fig.update_xaxes(range = [0,360], row=1, col=1)
fig.update_yaxes(title_text="Class", range = [-0.1,10], row=1, col=1)

# Row 2: ProtoShotXAI score curves (these carry the shared legend entries).
for digit, color, dash in SCORE_STYLES:
    fig.add_trace(go.Scatter(x=deg, y=protoshot_scores[digit],
                        mode='lines', line=score_line(color, dash),
                        name=f'Score for {digit}'), row=2, col=1)
fig.update_xaxes(range = [0,360], row=2, col=1)
fig.update_yaxes(title_text="ProtoShotXAI Score", range = [-0.1,1.1], row=2, col=1)

# Row 3: ExMatchina* score curves, same styles, legend suppressed to avoid duplicates.
for digit, color, dash in SCORE_STYLES:
    fig.add_trace(go.Scatter(x=deg, y=exmatchina_scores[digit],
                        mode='lines', line=score_line(color, dash),
                        name=f'Score for {digit}', showlegend=False), row=3, col=1)
fig.update_xaxes(title_text="Rotation Angle (deg)", range = [0,360], row=3, col=1)
fig.update_yaxes(title_text="ExMatchina* Score", range = [-0.1,1.1], row=3, col=1)

fig.update_layout(font=dict(size=20), title_font_size=24)

fig.show()

# write_image does not create missing directories, so make sure the target folder exists.
output_path = './results/Revolving_Six/Revolving_Six_Experiment.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
pio.write_image(fig, output_path, width=1400, height=1000)
