In [2]:
import numpy as np
import os
from os.path import join as joinpath
from matplotlib import pyplot as plt
import pandas as pd

In [35]:
FIRST_DIM_TARGET_SIZE = 10

def compute_correlation_from_activity_matrix(activity_matrix):
    if activity_matrix.shape[0] != FIRST_DIM_TARGET_SIZE:
        activity_matrix = activity_matrix[0]
    # ensure activity_matrix is a 2D array
    activity_matrix = activity_matrix.reshape(activity_matrix.shape[0], -1)
    # compute correlation score for matrix of shape [T, N_NEURONS]
    corr_matrix = np.corrcoef(activity_matrix.T)
    corr_matrix[np.isnan(corr_matrix)] = 0
    return corr_matrix

In [None]:
def compute_average_layer_correlation_scores(root_data_path : str, save_path : str = None):
    layer_correlations_correct = {}
    layer_correlations_incorrect = {}
    global_sample_count = 0
    for  label_dir in os.listdir(root_data_path):
        # label level
        print(f"Label: {label_dir}")
        label_dir_path = joinpath(root_data_path, label_dir)
        if not os.path.isdir(label_dir_path):
            continue
        # sample level
        if label_dir == "label_3":
            # skip label 3, as full noise samples result in predicting label 3
            continue
        for num_samples, sample_dir in enumerate(os.listdir(label_dir_path)):
            if num_samples >= 2:
                global_sample_count += num_samples + 1
                break
            print(f"Sample {num_samples + 1} ")
            sample_dir_path = joinpath(label_dir_path, sample_dir)
            if not os.path.isdir(sample_dir_path):
                continue
            
            original_actiities_dir = joinpath(sample_dir_path, 'correct_original')
            incorrect_noise_activities_dir = [joinpath(sample_dir_path, f) 
                                                for f in os.listdir(sample_dir_path)
                                                if 'incorrect' in f and 'noise' in f][0]
            correct_layer_activity_files = [joinpath(original_actiities_dir, f) 
                                            for f in os.listdir(original_actiities_dir)
                                            if 'spike' in f]
            for layer_activity in correct_layer_activity_files:
                layer_name, layer_num = layer_activity.split('_')[-2:]
                layer_num = layer_num.split('.')[0] 
                layer_name = f"{layer_name}_{layer_num}"
                corr_matrix = compute_correlation_from_activity_matrix(np.load(layer_activity))
                if layer_name not in layer_correlations_correct:
                    layer_correlations_correct[layer_name] = corr_matrix
                else:
                    layer_correlations_correct[layer_name] += corr_matrix

            incorrect_layer_activity_files = [joinpath(incorrect_noise_activities_dir, f) 
                                                for f in os.listdir(incorrect_noise_activities_dir)
                                                if 'spike' in f]
            for layer_activity in incorrect_layer_activity_files:
                layer_name, layer_num = layer_activity.split('_')[-2:]
                layer_num = layer_num.split('.')[0] 
                layer_name = f"{layer_name}_{layer_num}"
                corr_matrix = compute_correlation_from_activity_matrix(np.load(layer_activity))
                if layer_name not in layer_correlations_incorrect:
                    layer_correlations_incorrect[layer_name] = corr_matrix
                else:
                    layer_correlations_incorrect[layer_name] += corr_matrix


    # average the correlation scores and save them to npy arrays
    for layer_name in layer_correlations_correct:
        layer_correlations_correct[layer_name] /= num_samples
        
    for layer_name in layer_correlations_incorrect:
        layer_correlations_incorrect[layer_name] /= num_samples

    return layer_correlations_correct, layer_correlations_incorrect

Label: label_5
Sample 1 


  c /= stddev[:, None]
  c /= stddev[None, :]


Sample 2 
Label: label_6
Sample 1 
Sample 2 
Label: label_4
Sample 1 
Sample 2 
Label: label_2
Sample 1 
Sample 2 
Label: label_7
Sample 1 
Sample 2 
Label: label_1
Sample 1 
Sample 2 
Label: label_9
Sample 1 
Sample 2 
Label: label_8
Sample 1 
Sample 2 
Label: label_3
Label: label_0
Sample 1 
Sample 2 
Label: adversarial_test_results.json


In [37]:
# plot pairs of correct and incorrect layer correlations
import seaborn as sns
fig, axs = plt.subplots(len(layer_correlations_correct), 2, figsize=(10, 5 * len(layer_correlations_correct)))

for i, layer_name in enumerate(layer_correlations_correct):
    correct_corr = layer_correlations_correct[layer_name]
    incorrect_corr = layer_correlations_incorrect[layer_name]
    sns.heatmap(correct_corr, ax=axs[i, 0],
                 cmap='viridis', cbar=True, vmin=-1, vmax=1)
    axs[i, 0].set_title(f'Correct Layer Correlation - {layer_name}')
    sns.heatmap(incorrect_corr, ax=axs[i, 1],
                 cmap='viridis', cbar=True, vmin=-1, vmax=1)
    
    axs[i, 1].set_title(f'Incorrect Layer Correlation - {layer_name}')
    

KeyboardInterrupt: 

Error in callback <function flush_figures at 0x75b6f3953250> (for post_execute), with arguments args (),kwargs {}:


KeyboardInterrupt: 