In [None]:
import glob
import os
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pingouin as pg
import scipy.ndimage as ndimage
import scipy.stats as stats

from plotting_functions import *

# Global figure styling for every plot in this notebook: Arial text, no
# top/right axis spines, and mathtext rendered in the regular (upright) font.
plt.rcParams.update({
    "font.family": "Arial",
    'axes.spines.top': False,
    'axes.spines.right': False,
    'mathtext.default': 'regular',
})



# Fits

In [None]:
# NOTE(review): not read by any code visible here -- the analysis below always
# restricts to the cells common to every fit file in DIR. Confirm intent.
ONLY_COMMON_CELLS = True

# Directory holding the processed neural-fit .npy files.
DIR = '../neural_fitting/processed_fits/final_wt_60_shuffle_test/'

# Every fit file used below ends in this suffix; only the model prefix differs.
_FIT_SUFFIX = '_wt_60noise_200components_CVshuffle_test.npy'

# Plot label -> candidate fit files for that model (the one with the best
# validation score is chosen later). The '\n' in labels wraps the x tick text.
FILES = {
    label: [prefix + _FIT_SUFFIX]
    for label, prefix in (
        ('LN', 'LN'),
        ('Temp. Pred.\n(recurrent)', 'TP_1step_36x36_raw_1offset_2padding'),
        ('Temp. Pred.\n(feedforward)', 'rnn_refactor_ff_raw_noPCA_1offset_2padding'),
        ('Temp. Pred.\n(untrained)', 'TP_1step_untrained_36x36_raw_1offset_2padding'),
        ('Inpainting', 'rnn_refactor_final_masked_raw_noPCA_1offset_2padding'),
        ('Denoise', 'rnn_refactor_final_denoiseSNR3dataset_raw_noPCA_1offset_2padding'),
        ('Sparse\nautoencoder', 'rnn_refactor_final_autoencoder_raw_noPCA_1offset_2padding'),
    )
}

# Bar colour per model label; all temporal-prediction variants share red.
COLORS = {
    'LN': 'tab:gray',
    **dict.fromkeys(
        (
            'Temp. Pred.\n(recurrent)',
            'Temp. Pred.\n(feedforward)',
            'Temp. Pred.\n(untrained)',
        ),
        'tab:red',
    ),
    'Sparse\nautoencoder': 'tab:pink',
    'Denoise': 'tab:orange',
    'Inpainting': 'tab:purple',
}

def get_common_cells(fit_dir):
    """Return the cells whose test score is finite in every fit file of a directory.

    Every ``*.npy`` file directly inside ``fit_dir`` is loaded. Each file is
    iterated as a sequence of per-cell records indexed by ``'cell_name'``
    and ``'cc_norm_test'``. A cell is kept only if its ``cc_norm_test`` is
    finite in *all* of these files. Note that this covers every ``.npy`` in
    the directory, not just the ones a caller later plots.

    Parameters
    ----------
    fit_dir : str
        Directory containing the fit files, with or without a trailing
        path separator.

    Returns
    -------
    list
        The common cell names in sorted order. They are sorted because a
        set's iteration order for strings can change between interpreter
        runs, and callers may rely on a stable ordering.

    Raises
    ------
    FileNotFoundError
        If ``fit_dir`` contains no ``*.npy`` files.
    """
    fit_paths = sorted(glob.glob(os.path.join(fit_dir, '*.npy')))
    if not fit_paths:
        raise FileNotFoundError(f'No *.npy fit files found in {fit_dir!r}')

    common = None
    for path in fit_paths:
        # allow_pickle is required for the object arrays of records; only use
        # this on locally produced fit files, never on untrusted input.
        data = np.load(path, allow_pickle=True)
        finite_cells = {c['cell_name'] for c in data if np.isfinite(c['cc_norm_test'])}
        common = finite_cells if common is None else common & finite_cells

    return sorted(common)

def get_best_performing_layer(root_dir, paths, cells):
    """Pick the fit file with the best validation score and return its test stats.

    For each file in ``paths``, only the records whose ``'cell_name'`` is in
    ``cells`` are used. The file with the highest mean validation
    ``'cc_norm'`` (NaNs ignored) is selected, and the test statistics of
    that file are returned.

    Parameters
    ----------
    root_dir : str
        Directory containing the files in ``paths``.
    paths : sequence of str
        Fit file names relative to ``root_dir``. Presumably these are
        alternative layers of one model; confirm against the callers.
    cells : iterable
        Cell names to restrict the analysis to.

    Returns
    -------
    tuple
        ``(mean, sem, values)`` for the selected file.

        - ``mean`` is the mean test ``'cc_norm_test'``.
        - ``sem`` is ``std / sqrt(n)``.
        - ``values`` is a tuple of per-cell test scores ordered by cell name.
          Because of that ordering, ``values[i]`` refers to the same cell for
          every model called with the same ``cells``, which is what a paired
          test requires.
    """
    cell_set = set(cells)

    mn_val = []
    mn_test = []
    er_test = []
    vals_test = []

    for f in paths:
        data = np.load(os.path.join(root_dir, f), allow_pickle=True)
        records = [c for c in data if c['cell_name'] in cell_set]

        # Selection criterion: mean validation score, ignoring NaNs.
        mn_val.append(np.nanmean([c['cc_norm'] for c in records]))

        # Sort each score by the cell name stored in the *same* record.
        # Zipping against the separately ordered `cells` list would mis-pair
        # scores and cells.
        named_scores = sorted(
            ((c['cell_name'], c['cc_norm_test']) for c in records),
            key=lambda pair: pair[0],
        )
        test_data = tuple(score for _, score in named_scores)

        mn_test.append(np.mean(test_data))
        er_test.append(np.std(test_data) / (len(test_data) ** 0.5))
        vals_test.append(test_data)

    # NOTE(review): np.argmax returns the first NaN if a file has no finite
    # validation scores. Consider np.nanargmax if that can happen.
    best_layer = np.argmax(mn_val)

    return mn_test[best_layer], er_test[best_layer], vals_test[best_layer]
        
    

# NOTE(review): ONLY_COMMON_CELLS is not consulted here; the analysis always
# uses the cells common to every fit file in DIR.
COMMON_CELLS = get_common_cells(DIR)

# Parallel per-model lists, in FILES order, describing the best layer of each
# model: per-cell test values, mean, standard error, bar colour and label.
data_arr, mean_arr, error_arr, label_arr, color_arr = [], [], [], [], []

for label, fit_paths in FILES.items():
    best_mean, best_sem, best_values = get_best_performing_layer(DIR, fit_paths, COMMON_CELLS)

    data_arr.append(best_values)
    mean_arr.append(best_mean)
    error_arr.append(best_sem)
    color_arr.append(COLORS[label])
    label_arr.append(label)

# Print each model's mean test score and, for every model other than the
# recurrent temporal-prediction one, a paired t-test against it. The paired
# test assumes each entry of data_arr lists the cells in the same order.
reference_label = 'Temp. Pred.\n(recurrent)'

for model_idx, (model_label, model_values) in enumerate(zip(label_arr, data_arr)):
    reference_idx = label_arr.index(reference_label)

    print(f'{model_label} mean = {np.nanmean(model_values):.3g}')

    if model_idx != reference_idx:
        display(pg.ttest(model_values, data_arr[reference_idx], paired=True))
    print('\n\n')


# Bar chart of each model's mean test CC_norm (standard-error bars), with a
# dashed black line marking the recurrent temporal-prediction model's mean.
fig = plt.figure()
x = np.arange(len(mean_arr))
bars = plt.bar(x, mean_arr, yerr=error_arr)

reference_mean = mean_arr[label_arr.index('Temp. Pred.\n(recurrent)')]
plt.plot([-0.5, len(mean_arr) - 0.5], [reference_mean] * 2, '--', c='black')

for bar, colour in zip(bars, color_arr):
    bar.set_facecolor(colour)

plt.xticks(x, label_arr, rotation=0)
plt.ylabel('$CC_{norm}$')
format_plot(plt.gca(), fontsize=18)

fig.set_size_inches(16, 4)
plt.savefig('./figures/figure4/cc_norm_bar.pdf', bbox_inches='tight')
plt.show()