# Instructions for comparison tool:
This tool analyzes and plots several metrics that compare the same data (feature vectors or distance matrices) as processed by two different models (neural networks, behavioral data, etc.).

**Inputs:**
*   **Input type** - choose whether to use feature vectors or distance matrices as inputs.
*   **Input file paths** - enter the paths of the CSV files holding the chosen input for each of your two models.
    *   Feature vectors file: the first row contains the name of each sample, and each column below it holds that sample's feature vector. The first column is read as a row index, not as a feature.
    *   Distance matrix file: the name of each sample appears in both the first row and the first column, so the (i, j) cell holds the distance between samples i and j.
    *   Row/column names should have the structure FirstName_LastName_ImageNumber.
*   **Model names** - enter a name for each model; the names are used to tell the two models apart in the report.
*   **Metrics** - check the metrics you want to calculate and compare the two models with (t-SNE, RDM, ROC & AUC, Kendall's τ). t-SNE is only available for feature vector inputs.
*   **Use alternative labeling** - when checked, the CSV file entered in "*Path to labeling file*" supplies the labels for your data.
    *   Its first column must contain the sample identity, i.e. the sample name without the trailing _ImageNumber (FirstName_LastName).
    *   Each other column contains a possible labeling of the data.
    *   After entering the path, choose which column to use as the effective labeling for the samples.

**Outputs:**

*   t-SNE plot
*   RDM plot
*   ROC curve and AUC
*   Kendall's τ correlation between the two models' distance matrices

**Follow the following steps:**

1) Select from the top bar Runtime -> Run all

2) Fill in the inputs

3) Click **Run!**, then find the results in a PDF file named "Comparison Report <timestamp>" in the same Drive folder as the first model's input file.







# Tool Code

In [None]:
from __future__ import print_function

import sys
import os
import pandas as pd
import numpy as np
import seaborn as sns
sns.set_style('whitegrid')
import ipywidgets as widgets
import matplotlib.pyplot as plt
import pytz
from datetime import datetime
from pandas import DataFrame
from ipywidgets import interact
from sklearn import manifold
from scipy import stats
from IPython.display import display, clear_output
from google.colab import output, drive
from ipywidgets import Output
from scipy.spatial import distance_matrix
from sklearn.preprocessing import normalize
from matplotlib.backends.backend_pdf import PdfPages
from sklearn import metrics

# Module-level settings. They are filled in from the input widgets by
# show_input_relevant_data, collect_metrics_data and collect_label, and read by
# create_metrics and the metric functions.
FIRST_NET_FV_PATH = ""   # CSV path of the 1st model's input; the PDF report is written to its folder
FIRST_MODEL_NAME = ""    # display name of the 1st model (used in plot titles/legends)
SECOND_NET_FV_PATH = ""  # CSV path of the 2nd model's input
SECOND_MODEL_NAME = ""   # display name of the 2nd model
LABELING_FILE_PATH = ""  # optional labeling CSV; "" means labels are derived from the sample names
USE_FVS = False          # True: inputs are feature vectors; False: inputs are distance matrices
CALC_TSNE = False        # which metrics to compute
CALC_RDM = False
CALC_ROC = False
CALC_KENDALL = False
LABEL_NAME = ""          # column of the labeling CSV used as the label
STYLE = {'description_width': 'initial'}  # style dict shared by the input widgets

def create_metrics():
    """Compute every selected comparison metric and write the figures to a PDF.

    Reads the module-level settings collected by the input widgets
    (``FIRST_NET_FV_PATH``, ``USE_FVS``, the ``CALC_*`` flags, ...). The report
    is written next to the first model's input file as
    ``Comparison Report <YYYY_MM_DD_HH_MM_SS>.pdf`` (Israel local time).

    :return: None
    """
    time_stamp = datetime.now(pytz.timezone('Israel')).strftime('%Y_%m_%d_%H_%M_%S')
    # dirname handles bare file names (no '/'), which a plain rsplit on '/'
    # would have turned into "<file>.csv/Comparison Report ...".
    report_dir = os.path.dirname(FIRST_NET_FV_PATH) or '.'
    report_path = os.path.join(report_dir, 'Comparison Report ' + time_stamp + '.pdf')

    if USE_FVS:
        print("Preparing data")
        fv1, fv2, dist_mat_1, dist_mat_2 = preprocess_fvs()
    else:
        # Only distance matrices are available, so there is nothing to embed.
        fv1 = fv2 = None
        dist_mat_1, dist_mat_2 = preprocess_distance_matrix()
    image_label_dict = process_labels(dist_mat_1, dist_mat_2)

    # Inputs are loaded first; the `with` block closes the PDF even if a
    # metric raises part-way through.
    with PdfPages(report_path) as pdf:
        def save(fig):
            # Close exactly the figure that was saved, not whatever is "current".
            pdf.savefig(fig)
            plt.close(fig)

        if CALC_TSNE:
            if fv1 is None:
                print("Skipping t-SNE: it needs feature vectors as input")
            else:
                print("Processing tsne")
                save(calc_tsne(fv1, fv2, image_label_dict))
        if CALC_RDM:
            print("Processing RDM")
            save(calc_RDM(dist_mat_1, dist_mat_2, image_label_dict))
        if CALC_ROC:
            print("Processing ROC")
            save(calc_ROC_and_AUC(dist_mat_1, dist_mat_2, image_label_dict))
        if CALC_KENDALL:
            print("Calculating Kendall")
            save(calc_kendall(dist_mat_1, dist_mat_2))

def preprocess_fvs():
    """Load both feature-vector CSVs and derive their Euclidean distance matrices.

    Each CSV is read with its first column as the row index. Every remaining
    column is one sample (named in the header row), and its values form that
    sample's feature vector. Each sample vector is L2-normalized before the
    pairwise distances are computed.

    :return: ``(fv1, fv2, dist_mat_1, dist_mat_2)``. ``fv1``/``fv2`` are the
        normalized feature DataFrames (one column per sample, columns = sample
        names). ``dist_mat_1``/``dist_mat_2`` are the samples x samples distance
        DataFrames, with index and columns = sample names.
    :raises ValueError: if either input path is empty.
    """
    if not FIRST_NET_FV_PATH or not SECOND_NET_FV_PATH:
        raise ValueError("One or more of the feature vector paths is empty.")

    def load(path):
        raw = pd.read_csv(path, index_col=0)
        # Samples are the columns, so normalize along axis=0 to give every
        # sample vector unit L2 norm. sklearn's axis=1 assumes samples are rows;
        # here it would instead rescale each feature across all samples.
        fv = DataFrame(normalize(raw.values, axis=0, norm='l2'), columns=raw.columns)
        samples = fv.values.T  # one row per sample
        dist = DataFrame(distance_matrix(samples, samples),
                         index=raw.columns, columns=raw.columns)
        return fv, dist

    fv1, dist_mat_1 = load(FIRST_NET_FV_PATH)
    fv2, dist_mat_2 = load(SECOND_NET_FV_PATH)
    return fv1, fv2, dist_mat_1, dist_mat_2

def preprocess_distance_matrix():
    """Load the two precomputed distance matrices.

    Each CSV must contain the sample names in its header row and in its first
    column, so that cell (i, j) holds the distance between samples i and j.

    :return: ``(dist_mat_1, dist_mat_2)`` as DataFrames indexed by sample name.
    :raises ValueError: if either input path is empty.
    """
    if not FIRST_NET_FV_PATH or not SECOND_NET_FV_PATH:
        raise ValueError("One or more of the distance matrix paths is empty.")
    dist_mat_1 = pd.read_csv(FIRST_NET_FV_PATH, index_col=0)
    dist_mat_2 = pd.read_csv(SECOND_NET_FV_PATH, index_col=0)
    return dist_mat_1, dist_mat_2

def process_labels(dist_mat_1, dist_mat_2):
    """Map every sample name appearing in either matrix to its label.

    The default label of a sample is its name with the last ``_``-separated
    part removed (``FirstName_LastName_3`` -> ``FirstName_LastName``). When
    ``LABELING_FILE_PATH`` is set, that stripped name is looked up in the
    labeling CSV (first column used as the index) under column ``LABEL_NAME``.

    :return: dict of sample name -> label.
    """
    identity_of = {}
    for matrix in (dist_mat_1, dist_mat_2):
        for sample in matrix.columns:
            # rsplit leaves names without "_" unchanged.
            identity_of[sample] = sample.rsplit('_', 1)[0]

    if LABELING_FILE_PATH:
        labeling = pd.read_csv(LABELING_FILE_PATH, index_col=0)
        return {sample: labeling.loc[identity, LABEL_NAME]
                for sample, identity in identity_of.items()}
    return identity_of

def calc_tsne(fv1, fv2, image_label_dict):
    """Plot side-by-side 2-D t-SNE embeddings of both models' feature vectors.

    :param fv1: 1st model's feature DataFrame, one column per sample (columns
        are sample names).
    :param fv2: 2nd model's feature DataFrame, same layout.
    :param image_label_dict: sample name -> label, used to colour the points.
    :return: the matplotlib Figure.
    """
    def embed(fv):
        # Transpose so rows are samples, as TSNE expects.
        samples = np.asarray(fv).T
        # sklearn's TSNE requires perplexity < n_samples (its default is 30),
        # so small datasets would otherwise fail.
        perplexity = min(30.0, max(samples.shape[0] - 1, 1))
        coords = manifold.TSNE(perplexity=perplexity).fit_transform(samples)
        labels = [image_label_dict[name] for name in fv.columns]
        return coords, labels

    fig, axes = plt.subplots(1, 2, figsize=(35, 18))
    fig.subplots_adjust(left=0.04, wspace=0.4)
    panels = ((fv1, FIRST_MODEL_NAME), (fv2, SECOND_MODEL_NAME))
    for ax, (fv, model_name) in zip(axes, panels):
        coords, labels = embed(fv)
        ax.set_title(f't-SNE for {model_name}', {'fontsize': 16})
        sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=labels, s=120, ax=ax)
        # Same legend size in both panels (previously 18 vs 16).
        ax.legend(prop=dict(size=16), loc="center left", bbox_to_anchor=(1, 0.5))
    fig.text(0.04, 0.95, 't-SNEs Comparison', transform=fig.transFigure, size=24)
    return fig

def calc_RDM(dist_matrix_1, dist_matrix_2, image_label_dict):
    """Draw both models' representational dissimilarity matrices side by side.

    Each matrix is first sorted lexicographically by its (possibly
    label-prefixed) sample names. Rows and columns then appear in the same order
    in both heatmaps, so the two panels are directly comparable.

    :return: the matplotlib Figure.
    """
    sorted_matrices = [sort_matrix_lexicographical(matrix, image_label_dict)
                       for matrix in (dist_matrix_1, dist_matrix_2)]
    fig, axes = plt.subplots(1, 2, figsize=(45, 15))
    for ax, matrix, model_name in zip(axes, sorted_matrices, (FIRST_MODEL_NAME, SECOND_MODEL_NAME)):
        create_heatmap(matrix, model_name, ax)
    heading = 'RDMs Comparison'
    fig.text(0.04, 0.95, heading, transform=fig.transFigure, size=24)
    return fig

def sort_matrix_lexicographical(df, image_label_dict):
    """Return ``df`` with its rows and columns sorted lexicographically by name.

    When a labeling file is in use, each sample name is first prefixed with its
    label (``<label>_<name>``), so samples sharing a label end up adjacent.
    Labels are formatted with ``str()``, so non-string labels (e.g. an integer
    column of the labeling CSV) no longer raise TypeError. The input DataFrame
    is never modified.

    :param df: square distance DataFrame. Assumes its index lists the same
        sample names as its columns, in the same order.
    :param image_label_dict: sample name -> label.
    """
    if LABELING_FILE_PATH:
        renamed = [f"{image_label_dict[name]}_{name}" for name in df.columns]
        df = df.copy(deep=True)
        df.columns = renamed
        df.index = renamed
    return df.sort_index(axis=0).sort_index(axis=1)

def create_heatmap(df, title, ax):
    """Draw ``df`` as a square seaborn heatmap on ``ax`` titled ``RDM for <title>``.

    Column tick labels are hidden. Row tick labels are drawn horizontally at
    font size 12.
    """
    sns.heatmap(df, ax=ax, square=True, xticklabels=False)
    ax.set_title(f'RDM for {title}', {'fontsize': 16})
    for tick_label in ax.get_yticklabels():
        tick_label.set_rotation(0)
        tick_label.set_fontsize(12)

def calc_ROC_and_AUC(dist_mat_1, dist_mat_2, image_label_dict):
    """Compute ROC curves and AUCs for both models; return the comparison figure."""
    first_curve, second_curve = (
        calc_ROC_and_AUC_for_dist_mat(matrix, image_label_dict)
        for matrix in (dist_mat_1, dist_mat_2)
    )
    return polt_ROC_and_AUC(*first_curve, *second_curve)

def calc_ROC_and_AUC_for_dist_mat(dist_mat, image_label_dict, num_thresholds=50):
    """Compute an ROC curve for "same label" classification by distance thresholding.

    Every ordered pair of distinct samples (i, j) is predicted "same label" when
    ``dist_mat[i, j] <= threshold``. The thresholds are ``num_thresholds``
    evenly spaced values from the matrix minimum to its maximum.

    Self-pairs (the diagonal) are excluded. A sample always shares its own
    label, so counting them as positives inflated the TPR and the AUC.

    :param dist_mat: square DataFrame; rows are looked up by the column names.
    :param image_label_dict: sample name -> label.
    :param num_thresholds: number of thresholds, i.e. points on the curve.
    :return: ``(fpr_list, tpr_list, auc)``.
    """
    names = list(dist_mat.columns)
    # Reorder rows to match the column order, so element [i, j] is the
    # distance between names[i] and names[j].
    distances = dist_mat.loc[names, names].to_numpy(dtype=float)
    labels = np.array([image_label_dict[name] for name in names], dtype=object)

    off_diagonal = ~np.eye(len(names), dtype=bool)
    # astype(bool): comparisons on object arrays must be real booleans before `~`.
    label_match = (labels[:, None] == labels[None, :]).astype(bool)
    same_label = label_match & off_diagonal
    diff_label = ~label_match & off_diagonal

    min_dist = dist_mat.min().min()
    max_dist = dist_mat.max().max()

    fpr_array = []
    tpr_array = []
    for threshold in np.linspace(min_dist, max_dist, num=num_thresholds):
        accepted = distances <= threshold
        # Each count is floored at 1e-5 so the rates below never divide by zero.
        tp = max(np.count_nonzero(accepted & same_label), 1e-5)
        fp = max(np.count_nonzero(accepted & diff_label), 1e-5)
        fn = max(np.count_nonzero(~accepted & same_label), 1e-5)
        tn = max(np.count_nonzero(~accepted & diff_label), 1e-5)
        fpr_array.append(fp / (fp + tn))
        tpr_array.append(tp / (tp + fn))

    auc = metrics.auc(fpr_array, tpr_array)
    return fpr_array, tpr_array, auc

def polt_ROC_and_AUC(fpr_array1, tpr_array1, auc1, fpr_array2, tpr_array2, auc2):
    """Plot both models' ROC curves on one set of axes and return the figure.

    The legend shows each model's AUC rounded to 3 decimals. The dashed diagonal
    marks a chance-level classifier. The title names the classified quantity:
    ``LABEL_NAME`` when a labeling file is used, otherwise "Identity".
    (The function name's spelling is kept because callers use it.)
    """
    fig, ax = plt.subplots(figsize=(20, 16))
    ax.plot(fpr_array1, tpr_array1, color='green',
            label=f'{FIRST_MODEL_NAME} (AUC = {auc1:.3f})')
    ax.plot(fpr_array2, tpr_array2, color='orange',
            label=f'{SECOND_MODEL_NAME} (AUC = {auc2:.3f})')
    ax.plot([0, 1], [0, 1], color='black', linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    classifier = LABEL_NAME if LABELING_FILE_PATH else "Identity"
    ax.set_title(f"{FIRST_MODEL_NAME} & {SECOND_MODEL_NAME} ROC - Classification of {classifier}",
                 size=18)
    # A single legend call (it was previously built twice).
    ax.legend(fontsize=14)
    fig.text(0.04, 0.95, 'ROC Curves Comparison', transform=fig.transFigure, size=24)
    return fig

def calc_kendall(dist_mat_1, dist_mat_2):
    """Compute Kendall's tau between the two models' distance matrices.

    Both matrices are indexed by the sample names they share, in the first
    matrix's order. The same pair of samples is therefore compared in both
    models, whatever order the CSVs list them in. Only the strict upper
    triangle is used: the diagonal (self-distances) would add constant ties,
    and the mirrored lower triangle would count every pair twice.

    :return: a matplotlib Figure showing tau and its p-value.
    :raises ValueError: if the two matrices share no sample names.
    """
    second_names = set(dist_mat_2.columns)
    shared = [name for name in dist_mat_1.columns if name in second_names]
    if not shared:
        raise ValueError("The two distance matrices have no sample names in common.")

    values_1 = dist_mat_1.loc[shared, shared].to_numpy(dtype=float)
    values_2 = dist_mat_2.loc[shared, shared].to_numpy(dtype=float)
    upper = np.triu_indices(len(shared), k=1)
    # Tuple unpacking works for both old and new scipy result types.
    tau, p_value = stats.kendalltau(values_1[upper], values_2[upper])

    result_text = (f'Got correlation of {tau} between {FIRST_MODEL_NAME} and {SECOND_MODEL_NAME}'
                   f'\n(p-value = {p_value})')
    fig = plt.figure(figsize=(20, 20))  # a Figure so the text can be saved into the PDF
    fig.text(0.04, 0.95, 'Kendalls τ', transform=fig.transFigure, size=24)
    fig.text(0.5, 0.5, result_text, transform=fig.transFigure, size=28, ha="center")
    return fig

def collect_label(label):
    """Remember the labeling-file column picked in the radio-button widget.

    Called by ``interact`` with the current selection. The value is stored in
    the module-level ``LABEL_NAME``.
    """
    global LABEL_NAME
    chosen_column = label
    LABEL_NAME = chosen_column

def collect_metrics_data(
    first_model_name,
    first_model_path,
    second_model_name,
    second_model_path,
    Choose,
    calc_tsne,
    calc_rdm,
    calc_roc,
    calc_kendall,
    use_labeling_from_file,
    labeling_file_path
    ):
    """Store the current input-widget values in the module-level settings.

    ``interact`` calls this whenever a widget changes. ``Choose`` is the
    "metrics:" header widget and is not used.

    When alternative labeling is checked and the labeling CSV exists, a radio
    selector for its label columns (every column after the first) is shown.
    Otherwise ``LABELING_FILE_PATH`` is cleared, so a previously entered file
    is not silently used on the next run.
    """
    global FIRST_NET_FV_PATH, SECOND_NET_FV_PATH, CALC_TSNE, CALC_RDM, CALC_ROC, CALC_KENDALL
    global FIRST_MODEL_NAME, SECOND_MODEL_NAME, LABELING_FILE_PATH
    FIRST_NET_FV_PATH = first_model_path
    SECOND_NET_FV_PATH = second_model_path
    FIRST_MODEL_NAME = first_model_name
    SECOND_MODEL_NAME = second_model_name
    CALC_TSNE = calc_tsne
    CALC_RDM = calc_rdm
    CALC_ROC = calc_roc
    CALC_KENDALL = calc_kendall

    # Reset first: unchecking the box or clearing the path must disable labeling.
    LABELING_FILE_PATH = ""
    if not (use_labeling_from_file and labeling_file_path):
        return
    # This runs on every keystroke, so a partially typed path is expected;
    # only read the file once it actually exists.
    if not os.path.isfile(labeling_file_path):
        print(f"Labeling file not found: {labeling_file_path}")
        return

    LABELING_FILE_PATH = labeling_file_path
    # Only the header row is needed to list the label columns.
    labels = pd.read_csv(LABELING_FILE_PATH, sep=",", nrows=0).columns[1:]
    interact(collect_label,
             label=widgets.RadioButtons(options=labels,
                                        value=None,
                                        style=STYLE,
                                        description='Choose label:',
                                        indent=False)
             )

def clear_grid(grid):
    """Close every widget in ``grid``, then reset the module-level ``GRID``.

    An empty ``grid`` is a no-op; ``GRID`` is left untouched in that case.
    """
    global GRID
    if len(grid) == 0:
        return
    for widget in grid:
        widget.close()
    GRID = []

# When the cell is re-run, close any widgets left over from the previous run.
# On the very first run GRID does not exist yet, which raises NameError; only
# that case is ignored, instead of swallowing every error with a bare except.
try:
    clear_grid(GRID)
except NameError:
    pass
GRID = []
def collect_data():
    """Show the input-type selector; picking a type displays its input form."""
    input_type_selector = widgets.RadioButtons(
        options=['Feature Vectors', 'Distance Matrix'],
        value='Feature Vectors',
        style=STYLE,
        description='Choose input type:',
        indent=False,
    )
    interact(show_input_relevant_data, input_type=input_type_selector)

def show_input_relevant_data(input_type):
    """Show the model/metric input form for the chosen input type.

    Sets ``USE_FVS`` and wires the form to ``collect_metrics_data``. The forms
    for the two input types differ only in two places:
    - the file-path descriptions;
    - the t-SNE checkbox, which is unchecked and disabled for distance
      matrices (``calc_tsne`` needs feature vectors).

    :param input_type: 'Feature Vectors' or 'Distance Matrix'.
    """
    global USE_FVS
    USE_FVS = input_type == 'Feature Vectors'
    file_kind = 'feature vectors' if USE_FVS else 'distance matrix'

    def text_box(description):
        return widgets.Text(value="", description=description, style=STYLE)

    def metric_checkbox(description, enabled=True):
        # `disabled` was previously misspelled `isabled`, which is not a
        # Checkbox trait.
        return widgets.Checkbox(value=enabled,
                                description=description,
                                disabled=not enabled,
                                indent=True)

    interact(collect_metrics_data,
             first_model_name=text_box('1st model name:'),
             first_model_path=text_box(f'1st {file_kind} file:'),
             second_model_name=text_box('2nd model name:'),
             second_model_path=text_box(f'2nd {file_kind} file:'),
             Choose=widgets.HTML(value="metrics:",),
             calc_tsne=metric_checkbox('t-SNE', enabled=USE_FVS),
             calc_rdm=metric_checkbox('RDM'),
             calc_roc=metric_checkbox('ROC & AUC'),
             calc_kendall=metric_checkbox('Kendalls τ'),
             use_labeling_from_file=widgets.Checkbox(value=False,
                                                     description='Use alternative labeling',
                                                     indent=False),
             labeling_file_path=text_box('Path to labeling file:')
             )

# All printing happens inside `with out:` so it is captured by the Output widget.
def on_button_clicked(b):
    """Handle a click on the 'Run!' button.

    Clears the previous output and runs ``create_metrics`` with the settings
    collected from the widgets. "done!" is printed inside the ``out`` context
    as well, and only after ``create_metrics`` returns. Previously it was
    printed outside ``out``, so it could appear even after a failed run whose
    traceback the Output widget had displayed.

    :param b: the clicked Button (unused).
    :return: None
    """
    with out:
        clear_output(wait=True)
        output.clear()
        clear_grid(GRID)
        create_metrics()
        print("done!")
# Widgets shared by run_tool() (which displays them) and the Run! click handler.
out = widgets.Output(layout={'border': '1px solid black'})
title_button = widgets.HTML(value="Comparison tool - analyze two NNs data")
button_run = widgets.Button(description="Run!")
button_run.on_click(on_button_clicked)
def run_tool():
    """Mount Google Drive, then display the comparison tool's widgets.

    Displays, in order: the title, the input-type selector/form, the 'Run!'
    button and the output area.
    """
    drive_root = '/content/drive'
    print("Mount Google Drive where you have your images and needed CSVs")
    drive.mount(drive_root)
    print("Done mounting Google Drive. \n")

    display(title_button)
    try:
        collect_data()
    except FileNotFoundError:
        pass
    for widget in (button_run, out):
        display(widget)



# Two neural networks comparison tool

In [None]:
# Launch the tool: mounts Google Drive, then shows the input widgets,
# the Run! button and the output area.
run_tool()

Mount Google Drive where you have your images and needed CSVs
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Done mounting Google Drive. 



HTML(value='Comparison tool - analyze two NNs data')

interactive(children=(RadioButtons(description='Choose input type:', options=('Feature Vectors', 'Distance Mat…

Button(description='Run!', style=ButtonStyle())

Output(layout=Layout(border='1px solid black'))

done!
