In [None]:
# This notebook fits MNE model parameters to stimulus spectrogram segments that have been normalized
# to fall onto a fixed-radius sphere in n-dimensional space

import numpy as np
from oe_acute import MNE2
from oe_acute import trial_utils as tu
# from oe_acute import reconstruct as rct
import os
import pickle
import glob
import logging
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
from scipy.io import wavfile
from scipy import signal
from joblib import Parallel,delayed
from importlib import reload


# Configure the root logger once. Re-running this cell used to attach an extra
# StreamHandler every time, so each message was printed N times. Tagging our
# handler with a name lets a re-run detect it and skip adding another.
logger = logging.getLogger()
_NOTEBOOK_LOG_HANDLER_NAME = 'oe_acute_notebook_stream'
if not any(h.get_name() == _NOTEBOOK_LOG_HANDLER_NAME for h in logger.handlers):
    handler = logging.StreamHandler()
    handler.set_name(_NOTEBOOK_LOG_HANDLER_NAME)
    formatter = logging.Formatter(
            '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)

from ceciestunepipe.file import bcistructure as et
from ceciestunepipe.util import sglxutil as sglu

# Default matplotlib line width for every plot in this notebook.
plt.rcParams.update({'lines.linewidth': 0.1})

# Locations of external Kilosort / npy-matlab repositories, exported as environment
# variables (presumably read by the sorting tools). NOTE(review): absolute,
# machine-specific paths — adjust on other hosts.
os.environ.update({
    "KILOSORT2_PATH": '/mnt/cube/earneodo/repos/Kilosort2',
    "KILOSORT3_PATH": '/home/finch/repos/Kilosort',
    "NPY_MATLAB_PATH": '/mnt/cube/earneodo/repos/npy-matlab',
})

In [None]:
from ceciestunepipe.file import bcistructure as et
from ceciestunepipe.util import sglxutil as sglu
from ceciestunepipe.util import syncutil as su
from ceciestunepipe.util import wavutil as wu
# Reached only if every import above succeeded; marks that point in the log.
logger.info('all modules loaded')

In [None]:
def get_window_spikes(spk_df, clu_list, start_sample, end_sample):
    """Build a binary (cluster x sample) spike raster for one time window.

    Parameters
    ----------
    spk_df : pandas.DataFrame
        Spike table with sample indices in 'times' and cluster ids in 'clusters'.
    clu_list : array-like
        Cluster ids; row i of the output corresponds to clu_list[i].
    start_sample, end_sample : int
        Half-open window [start_sample, end_sample) in the spike sample clock.

    Returns
    -------
    numpy.ndarray
        Float array of shape (len(clu_list), end_sample - start_sample) holding 1
        where a cluster spiked at that sample and 0 elsewhere. It has zero columns
        when end_sample <= start_sample.
    """
    onset = int(start_sample)
    offset = int(end_sample)
    clu_arr = np.asarray(clu_list)
    n_samples = max(offset - onset, 0)

    # Half-open window, so a spike exactly on the onset sample lands in column 0.
    # The previous `between(..., inclusive=False)` dropped that spike, and the
    # boolean `inclusive` argument no longer exists in pandas >= 2.0.
    times = spk_df['times'].to_numpy()
    spk_t = spk_df.loc[(times >= onset) & (times < offset)]

    spk_arr = np.zeros((clu_arr.size, n_samples))
    for i, clu_id in enumerate(clu_arr):
        clu_spk_t = spk_t.loc[spk_t['clusters'] == clu_id, 'times'].to_numpy().astype(np.int64)
        spk_arr[i, clu_spk_t - onset] = 1
    return spk_arr


def collect_bout(bout_dict, bout_idx, t_pre, t_post, spk_df, clu_list, mic_stream):
    """Cut the spike raster and the microphone trace for one bout.

    The window runs from (bout start + ``t_pre`` s) to (bout end + ``t_post`` s),
    where the bout end is the bout start plus ``len_ms``.

    Parameters
    ----------
    bout_dict : dict
        Must contain 's_f_ap_0' and 's_f' (sampling rates used to convert seconds to
        samples for the ap_0 and wav clocks), plus the per-bout sequences 'len_ms',
        'start_sample_ap_0' and 'start_sample_wav'.
    bout_idx : int
        Index of the bout within those per-bout sequences.
    t_pre : float
        Window start offset in seconds. It is *added* to the bout start sample, so it
        must be negative to start before the bout. NOTE(review): the config cell
        describes it as time "previous" to onset; confirm the sign convention.
    t_post : float
        Extra seconds kept after the end of the bout.
    spk_df, clu_list
        Passed through to ``get_window_spikes``.
    mic_stream : numpy.ndarray
        Microphone samples; flattened before slicing.

    Returns
    -------
    (spk_arr, mic_arr) : tuple of numpy.ndarray
        Spike raster (clusters x ap samples) and an independent copy of the
        microphone segment.
    """
    s_f_ap = bout_dict['s_f_ap_0']
    s_f_wav = bout_dict['s_f']

    # Add the bout length (in seconds) to the end of the segment. Keep it as a float:
    # the old int(...) truncated to whole seconds, so a 1.5 s bout was cut to 1 s and
    # any bout shorter than 1 s produced an empty window.
    t_post = t_post + bout_dict['len_ms'][bout_idx] / 1000.0

    start_ap = bout_dict['start_sample_ap_0'][bout_idx] + int(s_f_ap * t_pre)
    end_ap = bout_dict['start_sample_ap_0'][bout_idx] + int(s_f_ap * t_post)

    start_wav = bout_dict['start_sample_wav'][bout_idx] + int(s_f_wav * t_pre)
    end_wav = bout_dict['start_sample_wav'][bout_idx] + int(s_f_wav * t_post)

    # get the streams/spike array
    spk_arr = get_window_spikes(spk_df, clu_list, int(start_ap), int(end_ap))
    # ravel avoids copying the whole (long) stream just to take a slice. Copy the
    # slice so callers still get an array independent of mic_stream, as before.
    mic_arr = np.ravel(mic_stream)[int(start_wav):int(end_wav)].copy()

    return spk_arr, mic_arr

# Load and Format Data

## You don't need to run this section if you are loading saved MNE results

In [None]:
# Session to analyse. NOTE(review): `bird` duplicates sess_par['bird'], and
# `all_bird_sess` is not used by any later cell in this notebook.
bird = 's_b1253_21'
all_bird_sess = et.list_sessions(bird)

sess_par = {'bird': 's_b1253_21',
           'sess': '2021-07-18',
           'probe': 'probe_0',
           'sort': 'sort_0'}

# Session-level structure; ksort_folder / raw_folder are not used by later cells.
exp_struct = et.get_exp_struct(sess_par['bird'], sess_par['sess'], sort=sess_par['sort'])
ksort_folder = exp_struct['folders']['ksort']
raw_folder = exp_struct['folders']['sglx']
# Use the first SpikeGLX epoch of the session.
sess_epochs = et.list_sgl_epochs(sess_par)
epoch = sess_epochs[0]
sess_par['epoch'] = epoch
sess_par['sort'] = 'sort_0'
# Replace exp_struct with the epoch-level structure; all paths below come from it.
exp_struct = et.sgl_struct(sess_par, epoch)
sgl_folders, sgl_files = sglu.sgl_file_struct(exp_struct['folders']['sglx'])
# Per-bout metadata (start samples, lengths, sampling rates), presumably written by
# an earlier pipeline step. NOTE: pickle.load can execute arbitrary code — only load
# trusted files.
bout_dict_path = os.path.join(exp_struct['folders']['sort'], 'bout_dict_ap0.pkl')
with open(bout_dict_path, 'rb') as handle:
    bout_dict = pickle.load(handle)
wav_mic_path = exp_struct['files']['wav_mic']
# Microphone channel (presumably its sampling rate and samples).
s_f_wav, mic_stream = wu.read_wav_chan(exp_struct['files']['wav_mic'])
#logger.info('Will load cluster, spike dataframes from ' + exp_struct2['folders']['sort'])
# Cluster metadata and spike table from the sort folder.
clu_df_path = os.path.join(exp_struct['folders']['sort'], 'clu_df.pkl')
clu_df = pd.read_pickle(clu_df_path)
spk_df_path = os.path.join(exp_struct['folders']['sort'], 'spk_df.pkl')
spk_df = pd.read_pickle(spk_df_path)
print(bout_dict.keys())
clu_df.head()

In [None]:
t_pre = 0 # window start offset from the detected bout onset (seconds); collect_bout ADDS it, so use a negative value to start before onset
t_post = 0 # extra time kept after the detected bout offset (seconds)

# Cluster ids of all units assigned to nucleus 'ra' (np.unique: sorted, no duplicates).
clu_list_ra_all = np.unique(clu_df.loc[(clu_df['nucleus'].isin(['ra'])), 
                                 'cluster_id'])
#clu_list = np.unique(clu_df['cluster_id'])

clu_list = clu_list_ra_all
spk_arr_list = []
mic_arr_list = []
clu_id_arr_list = []

## get the bouts arrays into a bout_dict
# For every bout, store the spike raster (clusters x ap samples) as np.short and the
# mic segment as int16. NOTE(review): astype(np.int16) assumes integer-scaled mic
# samples (a float wav would be truncated) — confirm read_wav_chan's dtype.
for bout_idx, start in enumerate(bout_dict['start_ms']):
    spk_arr, mic_arr = collect_bout(bout_dict, bout_idx, t_pre, t_post, spk_df, clu_list, mic_stream)
    spk_arr_list.append(spk_arr.astype(np.short))
    mic_arr_list.append(mic_arr.astype(np.int16))
    clu_id_arr_list.append(np.array(clu_list))

# Record the window settings and per-bout arrays in bout_dict, then tabulate one row per bout.
bout_dict['t_pre_ms'] = t_pre * 1000
bout_dict['t_post_ms'] = t_post * 1000
bout_dict['spk_arr'] = spk_arr_list
bout_dict['mic_arr'] = mic_arr_list
bout_dict['clu_id_arr'] = clu_id_arr_list
keys_to_df = ['start_sample_nidq', 'start_sample_ap_0', 'len_ms', 'spk_arr', 'mic_arr', 'clu_id_arr']
bout_dict_df = {k: bout_dict[k] for k in keys_to_df}
bout_df = pd.DataFrame.from_dict(bout_dict_df)
# The nidq start sample is used as the bout identifier (later: stimulus name).
bout_df['bout_id'] = bout_df['start_sample_nidq']
# bout_df1['hemisphere']='right'

bout_df.head()

In [None]:
# Sanity check: bout table shape, raster shape of the LAST bout collected above, and bout count.
print(bout_df.shape)
print(spk_arr.shape)
print(len(spk_arr_list))

## Format Stimuli and Responses

In [None]:
# Input:  per-bout microphone segments and spike rasters (spike clock s_f_ap = 30 kHz).
# Output: spectrograms, plus spike responses binned onto the spectrogram's time bins.
#         Arbitrary cluster ids and all-1 cluster ratings are created so the downstream MNE code runs.

s_f_ap = 30000
ms_bin_size = int(0.001 * s_f_ap)

nums = bout_df.shape[0]
assert nums > 0, 'bout_df is empty - no bouts were collected'
spike_responses_all = []
spectrograms_all = []
stim_names = [] 
# Every bout's raster has one row per cluster in clu_list, so take the count from the
# first bout (the previous iloc[1] failed whenever the session had a single bout).
n_clusters = bout_df.iloc[0]['spk_arr'].shape[0]
clusters = list(range(n_clusters))
cluster_ratings = [1] * n_clusters

mic_fs = 25000 #change this to be what you want everything to be

def bin_spikes_ls(spk_arr, spk_len, n_bins_tgt):
    """Rebin a spike raster onto ``n_bins_tgt`` equal-width time bins.

    Each spike's sample index is expressed as a fraction of ``spk_len`` and mapped to
    bin ``floor(fraction * n_bins_tgt)``. The output is binary: several spikes of one
    cluster that fall into the same bin still give a single 1.

    Parameters
    ----------
    spk_arr : numpy.ndarray
        (n_clusters, n_samples) raster; any nonzero entry counts as a spike.
    spk_len : int
        Number of samples the raster spans (normally ``spk_arr.shape[1]``).
    n_bins_tgt : int
        Number of output time bins (e.g. spectrogram columns).

    Returns
    -------
    numpy.ndarray
        Integer array of shape (n_clusters, n_bins_tgt) containing 0/1.
    """
    spike_clusters, spike_times = np.nonzero(spk_arr)
    fractional_spike_times = spike_times / spk_len
    # Builtin int: the np.int alias was removed in NumPy 1.24 (AttributeError).
    spike_bins = (fractional_spike_times * n_bins_tgt).astype(int)
    sxx_binned_spikes = np.zeros((spk_arr.shape[0], n_bins_tgt), dtype=int)
    sxx_binned_spikes[spike_clusters, spike_bins] = 1
    return sxx_binned_spikes

# For each bout: compute the mic spectrogram and bin spikes onto its time bins.
for x in range(nums):
    single_spk_arr_temp = bout_df.iloc[x]['spk_arr']
    # NOTE(review): fs_orig is hard-coded to 25000 instead of mic_fs / s_f_wav — confirm it matches the wav rate.
    mic_to_sxx = MNE2.preprocess_sig_Kozlov(bout_df.iloc[x]['mic_arr'], fs_orig=25000)
    n_bins_tgt = mic_to_sxx.shape[1] # get number of spectrogram time bins
    #single_spk_arr_ms_temp = pu.coarse(single_spk_arr_temp, ms_bin_size)
    # Result shape (n_clusters, n_bins_tgt, 1); the trailing singleton axis is presumably the trial axis MNE2 expects.
    sxx_binned_spikes = bin_spikes_ls(single_spk_arr_temp, single_spk_arr_temp.shape[1], n_bins_tgt)[..., np.newaxis]
    spike_responses_all.append(sxx_binned_spikes)
    spectrograms_all.append(np.copy(mic_to_sxx))
    stim_names.append(np.copy(bout_df.iloc[x]['bout_id']))


# Remember to Change Output File Name!

In [None]:
# GLOBAL parameters for this notebook 

#### MNE PARAMETERS ####
# exp_path = '/mnt/cube/btheilma/experiments/B1240/block3/'
# sort_path =  '/mnt/cube/btheilma/sorting/B1240/block3/'
# stim_path = '/home/AD/btheilma/MET_prediction_expt/'
# output_folder = '/home/AD/btheilma/scratch/stim_recon_project_out/MNE_pfinals/'


# Also the folder where MNE fits and reconstructions are written/read below.
exp_path = '/mnt/cube/lstanwic/reducted_data/s_b1253_21/'  # contains the experiment dictionary (trials)
#sort_path =  exp_path # contains the sorted data (kilosort 2)
#stim_path = '/mnt/cube/lstanwic/reducted_data/B1240/block3/MET_prediction_expt/'    # contains the stimuli wav files

# NOTE(review): output_file_mne_full is not written by any visible cell.
output_file_mne_full = os.path.join(exp_path, 'MNE_full_res_b1253_20210718_all_no_silence.pkl')
output_file_mne_pfinal = os.path.join(exp_path, 'MNE_pfinal_res_b1253_20210718_all_no_silence.pkl')

# Bird Parameters
bird = sess_par['bird']
sess = sess_par['sess']
probe = sess_par['probe']
sort = sess_par['sort']

# Number of jackknives to use during fitting
# NOTE(review): not referenced by the fitting cell below.
n_jackknives = 2

### Stimulus segment params
# Number of spectrogram time bins per MNE stimulus segment.
nsegbins = 16

### Spectrogram Averaging params
# Block-averaging factors passed to MNE2.kozlov_avg_stim_resp.
avg_nrows = 2
avg_ncols = 3

### Normalization params
radius = 10  # Project stimulus segments to sphere of <radius> radius
# Saved with the results; not used directly by the visible cells.
stim_zscore_thresh = 30

# number of stimulus dimensions (16 x 16)
# Length of a flattened segment = spectrogram rows x nsegbins; the reconstruction
# code reshapes segments to (n_dim // nsegbins, nsegbins).
n_dim = 256

# FIT MNE Parameters
### MNE stims and responses, response predictions, and shuffled responses
### skip this if you just want to load the data

In [None]:
# Preprocess stimuli and extract response tensors
# Re-expose the per-bout data under the names the MNE pipeline below uses.
stim_names = stim_names #stim names = bout names 
stim_spectrograms = spectrograms_all 
stim_responses = spike_responses_all #binned to the spectrogram time bins; each is (n_clusters, n_spectrogram_bins, 1)
clusters = clusters #0 .. n_clusters-1
cluster_ratings = cluster_ratings #all 1s 
n_cells = len(clusters)

In [None]:
#####################################################
# compute shuffle response tensors
# stim_mne_segs; resp_mne

# Seed NumPy's global RNG so the shuffled control responses are reproducible across
# kernel restarts. NOTE(review): this only takes effect if the MNE2 shuffle helpers
# draw from numpy's global RNG — confirm in MNE2.
SHUFFLE_SEED = 0
np.random.seed(SHUFFLE_SEED)

# fully shuffle each stimulus response tensor
fs_stim_responses = [MNE2.full_shuffle_response(x) for x in stim_responses]

# mask shuffle each stimulus response tensor
# NOTE(review): ms_stim_responses / ps_stim_responses are not included in response_classes below.
ms_stim_responses = [MNE2.mask_shuffle_response(x) for x in stim_responses]

# permute time bins across all stims, same for each trial
ps_stim_responses = MNE2.permute_all_stim_time_bins(stim_responses)

#####################################################

# different variations on the response
# just do original and shuffled responses
response_classes = [stim_responses, fs_stim_responses]

# Open question: concatenate all stimuli and responses and then split into training and
# test sets, or hold out *whole stimuli* and concatenate each set? With the latter, the
# training and test sets are not guaranteed to have the same number of data points on
# every iteration.

# Canonical order of operations:
# extract_stim_resp
# do shuffling
# kozlov_avg_stim_resp
# segment_stim_resp
# reshape_stim_segments
# preprocess_MNE

# og = original
# fs = fullshuffle
# ms = mask shuffle
# ps = permute time bins across all stims, same permutation for each trial

kozlov_stim_responses = [MNE2.kozlov_avg_stim_resp(stim_spectrograms, x, avg_nrows, avg_ncols) for x in response_classes]

stim_response_segments = [MNE2.segment_stim_resp(x[0], x[1], nsegbins=nsegbins, mode='forward') for x in kozlov_stim_responses]

stim_response_segments_reshape = [(MNE2.reshape_stim_segments(x[0]), x[1]) for x in stim_response_segments]

stim_resp_mne = [MNE2.preprocess_MNE(x[0], x[1], radius) for x in stim_response_segments_reshape]

stim_resp_mne_train_test = [MNE2.prepare_MNE_train_test_set(x[0], x[1]) for x in stim_resp_mne]

# for each class, we have:
# (stim train, resp train, stim test, resp test)

In [None]:
#n_cells

In [None]:
# Fit a second-order MNE model per unit (20 parallel jobs), once per response class.
# mne_results[c][u] holds the fitted parameter vector for class c, unit u.
mne_results = []
full_mne_results = []  # NOTE(review): never filled or used.
for r_class in stim_resp_mne_train_test:
    print('Fitting MNE...')
    stims_train, resps_train, stims_test, resps_test = r_class
    # Column idx of the training responses is unit idx, kept 2-D as (n_samples, 1).
    mne_res = Parallel(n_jobs=20)(delayed(MNE2.fit_MNE)(stims_train, resps_train[:, idx][:, np.newaxis], order=2) for idx in range(n_cells))
    print('MNE fit done')
    mne_results.append(mne_res)

In [None]:
# Save MNE parameter fit results, the preprocessed stimuli, and the shuffled responses
results_dict = {'bird': bird, 
                'sess' : sess,
                'probe' : probe,
                'sort' : sort,
                'units': clusters, 
                'nsegbins': nsegbins, 
                'n_rows_avg':avg_nrows, 
                'n_cols_avg':avg_ncols, 
                'stim_resp': stim_resp_mne_train_test,
                'results': mne_results, 
                'stim_names':np.array(stim_names, dtype='object'), 
                'stim_spectrograms': stim_spectrograms,
                'stim_response_classes': response_classes,
                'stim_zscore_thresh': stim_zscore_thresh}

print('Saving results...')
with open(output_file_mne_pfinal, 'wb') as f:
    pickle.dump(results_dict, f)
print('Done')

# Plot MNES

In [None]:
# Fitted parameters for the original (unshuffled) responses, one entry per unit.
# When fit in this session they are in clu_list_ra_all order; for loaded results this
# assumes the same cluster list was used — confirm.
original_pfinals = mne_results[0]

# KiloSort label of each RA cluster, looked up by the 'cluster_id' column.
# The previous clu_df['KSLabel'][num] indexed by row label, which only matches the
# cluster id when clu_df's index happens to equal cluster_id.
ks_label_by_cluster = clu_df.drop_duplicates('cluster_id').set_index('cluster_id')['KSLabel']
tags = [ks_label_by_cluster.loc[num] for num in clu_list_ra_all]

def plot_MNE(pfinal, unit, bird, block, n_eigvec_to_display=6, sdim=256, rf_shape=(16,16), color_map='jet', figure_output_path=None):
    '''
    Produce the MNE summary figure for one unit.

    Layout:
    - linear feature h (top-left);
    - sorted eigenvalues of J (top-right);
    - eigenvectors of J with the most negative eigenvalues (bottom-left);
    - eigenvectors with the most positive eigenvalues (bottom-right).

    Parameters
    ----------
    pfinal : array-like
        Flat MNE parameter vector laid out as [a, h (sdim values), J (sdim*sdim values)].
    unit : int
        Unit index; also used to look up the unit's label in the global ``tags`` list.
    bird, block : str
        Shown in the figure title.
    n_eigvec_to_display : int
        Number of eigenvectors shown on each (negative / positive) side.
    sdim : int
        Stimulus dimensionality (length of h; J is sdim x sdim).
    rf_shape : tuple
        Display shape for h and each eigenvector; must have sdim elements.
    color_map : str
        Matplotlib colormap name.
    figure_output_path : str or None
        If given, the figure is also saved to this path (previously ignored).

    Returns
    -------
    matplotlib.figure.Figure
    '''

    interp = None

    pfinal = np.asarray(pfinal)
    h_avg = np.reshape(pfinal[1:sdim+1], rf_shape)
    j_avg = np.reshape(pfinal[-sdim*sdim:], (sdim, sdim))

    # Only the symmetric part of J contributes to the quadratic form s^T J s, and its
    # eigen-decomposition is always real. np.linalg.eig on a slightly asymmetric J can
    # return complex eigenvectors, which imshow cannot display. eigh returns the
    # eigenvalues in ascending order (most negative first).
    j_sym = 0.5 * (j_avg + j_avg.T)
    sorted_eigval, eigvec = np.linalg.eigh(j_sym)

    topn_negative = [np.reshape(eigvec[:, k], rf_shape) for k in range(n_eigvec_to_display)]
    topn_positive = [np.reshape(eigvec[:, -(k + 1)], rf_shape) for k in range(n_eigvec_to_display)]

    # Each side gets `per_side` columns x 2 rows of eigenvector panels. This gives
    # 3 columns for the default n_eigvec_to_display == 6, the original fixed layout.
    per_side = max(1, int(np.ceil(n_eigvec_to_display / 2)))

    fig = plt.figure(constrained_layout=True, figsize=(13.3, 10))
    gs = fig.add_gridspec(4, 2 * per_side)

    neg_axs = []
    pos_axs = []
    for idx, v in enumerate(topn_negative):
        ax = fig.add_subplot(gs[2 + idx // per_side, idx % per_side])
        neg_axs.append(ax)
        ax.imshow(v, cmap=color_map, interpolation=interp, origin='lower', aspect='equal')
        ax.set_title('{:.3f}'.format(sorted_eigval[idx]), fontsize=14)
        ax.tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    for idx, v in enumerate(topn_positive):
        ax = fig.add_subplot(gs[2 + idx // per_side, idx % per_side + per_side])
        pos_axs.append(ax)
        ax.imshow(v, cmap=color_map, interpolation=interp, origin='lower', aspect='equal')
        ax.set_title('{:.3f}'.format(sorted_eigval[-(idx + 1)]), fontsize=14)
        ax.tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    ax_eigs = fig.add_subplot(gs[0:2, per_side:])
    ax_eigs.plot(sorted_eigval, 'k.')
    ax_eigs.set_title('Sorted Eigenvalues of J Matrix', fontsize=18)
    ax_eigs.set_ylabel('Value', fontsize=16)
    ax_eigs.set_xlabel('Index', fontsize=16)
    ax_eigs.tick_params(labelsize=14)

    ax_a = fig.add_subplot(gs[0:2, :per_side])
    ax_a.imshow(h_avg, cmap=color_map, interpolation=interp, origin='lower', aspect='equal')
    ax_a.set_title('Linear Feature', fontsize=18)
    ax_a.tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)
    tag = tags[unit]
    fig.suptitle("{} Block: {} Unit: {} KSLabel: {}".format(bird, block, unit, tag), fontsize=20)

    if figure_output_path is not None:
        fig.savefig(figure_output_path)

    return fig 


In [None]:
# Number of units with fitted MNE parameters (original responses).
len(original_pfinals)

In [None]:
# One MNE summary figure per unit (original responses). Each figure is shown and then
# closed, so dozens of open figures do not accumulate in memory.
for unit, pfinal in enumerate(original_pfinals):
    fig = plot_MNE(pfinal, unit, bird, sess, n_eigvec_to_display=6, sdim=256, rf_shape=(16,16), color_map='jet', figure_output_path=None)
    plt.show()
    plt.close(fig)

In [None]:
# full shuffled
# Same summary figures for the fully-shuffled control fits; each figure is closed after showing.
for unit, pfinal in enumerate(mne_results[1]):
    fig = plot_MNE(pfinal, unit, bird, sess, n_eigvec_to_display=6, sdim=256, rf_shape=(16,16), color_map='jet', figure_output_path=None)
    plt.show()
    plt.close(fig)

-----
# Load MNE results
#### Run this instead of re-fitting the MNE above when results for the same data have already been saved

In [None]:
# load previously computed parameters / results

import os


# NOTE(review): the fitting cell saves to 'MNE_pfinal_res_b1253_20210718_all_no_silence.pkl',
# but this loads '..._all.pkl' — confirm which result set is intended.
output_file_mne_pfinal = os.path.join(exp_path, 'MNE_pfinal_res_b1253_20210718_all.pkl')


# contains the MNE fits plus the stim spectrograms and response class matrices
with open(output_file_mne_pfinal, 'rb') as f:
    results_dict_pfinal = pickle.load(f)


n_cells = len(results_dict_pfinal['units'])

# pfinals
# this is by class
mne_results = results_dict_pfinal['results']
print(len(mne_results))

# spectrograms and responses
stim_spectrograms = results_dict_pfinal['stim_spectrograms']
response_classes = results_dict_pfinal['stim_response_classes']
avg_nrows = results_dict_pfinal['n_rows_avg']
avg_ncols = results_dict_pfinal['n_cols_avg']
nsegbins = results_dict_pfinal['nsegbins']
# Use the projection radius / dimensionality saved with the fit when present. Older
# result files do not store them; then the values from the parameter cell are kept.
if 'radius' in results_dict_pfinal:
    radius = results_dict_pfinal['radius']
if 'n_dim' in results_dict_pfinal:
    n_dim = results_dict_pfinal['n_dim']

kozlov_stim_responses = [MNE2.kozlov_avg_stim_resp(stim_spectrograms, x, avg_nrows, avg_ncols) for x in response_classes]

stim_response_segments = [MNE2.segment_stim_resp(x[0], x[1], nsegbins=nsegbins, mode='forward') for x in kozlov_stim_responses]

stim_response_segments_reshape = [(MNE2.reshape_stim_segments(x[0]), x[1]) for x in stim_response_segments]

stim_resp_mne = [MNE2.preprocess_MNE(x[0], x[1], radius) for x in stim_response_segments_reshape]


----
# Reconstruction Code
#### Performing stimulus reconstructions using overlap-add

## first, load things


In [None]:
import scipy.spatial
from scipy.spatial import distance

def rec_compare(orgin_comp, recon_comp):
    """Return the cosine distance (1 - cosine similarity) between two arrays.

    Both inputs are flattened first, so any matching number of elements works.
    0 means same direction, 1 orthogonal, 2 opposite.
    """
    return distance.cosine(np.ravel(orgin_comp), np.ravel(recon_comp))

In [None]:
def plot_comparisons(class_comparison):
    """Line plot of per-time-bin cosine differences for one reconstruction."""
    fig, ax = plt.subplots(1, 1, sharex=True)

    ax.plot(class_comparison, linewidth=1)
    ax.set_yticks(np.arange(0, 1.5, step=0.2))
    ax.set_ylabel('Cosine Difference')
    ax.set_xlabel('Time Bin')

    plt.tight_layout()
    plt.show()

In [None]:
from statistics import mean 

def plot_true_recon2(true_stim, recon_stim, comparison, comparison_shuff=None, color_map='jet', figsize=(16,9)):
    """Plot a true spectrogram, its reconstruction, and per-time-bin cosine differences.

    Parameters
    ----------
    true_stim, recon_stim : 2-D array-like
        Images drawn in the top (true) and middle (reconstructed) panels.
    comparison : sequence of float
        Per-time-bin cosine differences of the reconstruction.
    comparison_shuff : sequence of float, optional
        Per-time-bin cosine differences for the shuffled-response control. When given
        and non-empty, it is drawn in black with a legend. This replaces a mutable
        ``[0]`` default plus a ``len > 2`` test that silently ignored short curves.
    color_map : str
        Matplotlib colormap for both images.
    figsize : tuple
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already displayed via ``plt.show()``.
    """
    def _rounded_mean(values):
        # statistics.mean raises on empty input; report NaN instead.
        return round(mean(values), 3) if len(values) > 0 else float('nan')

    fig, axs = plt.subplots(3, 1, figsize=figsize, sharex=True)

    # True Stimulus
    axs[0].imshow(true_stim, cmap=color_map, origin='lower', aspect='auto', interpolation=None)
    axs[0].set_title('True Stimulus', fontsize=15)
    axs[0].tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    # recon stim
    axs[1].imshow(recon_stim, cmap=color_map, origin='lower', aspect='auto', interpolation=None)
    axs[1].set_title('Reconstructed Stimulus', fontsize=15)
    axs[1].tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    mean_comp = _rounded_mean(comparison)
    has_shuff = comparison_shuff is not None and len(comparison_shuff) > 0
    if has_shuff:
        mean_comp_shuff = _rounded_mean(comparison_shuff)
        axs[2].plot(comparison_shuff, linewidth=0.5, color = 'black', label = f'full shuffle reconstruction - mean: {mean_comp_shuff}')
        axs[2].plot(comparison, linewidth=1, label = f'true reconstruction - mean: {mean_comp}')
        axs[2].legend(loc = 'lower left', fontsize = 'large')
        axs[2].set_title('Cosine Differences for True and Shuffled Responses', fontsize=15)
    else:
        axs[2].plot(comparison, linewidth=1, label = f'true reconstruction - mean: {mean_comp}')
        axs[2].set_title(f'Cosine Difference - Mean: {mean_comp}', fontsize=15)
    axs[2].set_yticks(np.arange(0, 1.5, step=0.2))
    axs[2].tick_params(labelbottom=False, direction='in', bottom=False)

    plt.tight_layout()
    plt.show()
    return fig

In [None]:
def plot_true_recon_shuff(true_stim, recon_stim, comparison, comparison_shuff=None, color_map='jet', figsize=(16,9)):
    """Plot a true spectrogram, its reconstruction, and true-vs-shuffled cosine differences.

    Same layout as ``plot_true_recon2``, but the comparison panel is titled just
    'Cosine Difference' when a shuffle curve is shown.

    Parameters
    ----------
    true_stim, recon_stim : 2-D array-like
        Images drawn in the top (true) and middle (reconstructed) panels.
    comparison : sequence of float
        Per-time-bin cosine differences of the reconstruction.
    comparison_shuff : sequence of float, optional
        Per-time-bin cosine differences for the shuffled-response control. When given
        and non-empty, it is drawn in black with a legend. This replaces a mutable
        ``[0]`` default plus a ``len > 2`` test that silently ignored short curves.
    color_map : str
        Matplotlib colormap for both images.
    figsize : tuple
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already displayed via ``plt.show()``.
    """
    def _rounded_mean(values):
        # statistics.mean raises on empty input; report NaN instead.
        return round(mean(values), 3) if len(values) > 0 else float('nan')

    fig, axs = plt.subplots(3, 1, figsize=figsize, sharex=True)

    # True Stimulus
    axs[0].imshow(true_stim, cmap=color_map, origin='lower', aspect='auto', interpolation=None)
    axs[0].set_title('True Stimulus', fontsize=15)
    axs[0].tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    # recon stim
    axs[1].imshow(recon_stim, cmap=color_map, origin='lower', aspect='auto', interpolation=None)
    axs[1].set_title('Reconstructed Stimulus', fontsize=15)
    axs[1].tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    mean_comp = _rounded_mean(comparison)
    has_shuff = comparison_shuff is not None and len(comparison_shuff) > 0
    if has_shuff:
        mean_comp_shuff = _rounded_mean(comparison_shuff)
        axs[2].plot(comparison_shuff, linewidth=0.5, color = 'black', label = f'full shuffle reconstruction - mean: {mean_comp_shuff}')
        axs[2].plot(comparison, linewidth=1, label = f'true reconstruction - mean: {mean_comp}')
        axs[2].legend(loc = 'lower left', fontsize = 'large')
        axs[2].set_title('Cosine Difference', fontsize=15)
    else:
        axs[2].plot(comparison, linewidth=1, label = f'true reconstruction - mean: {mean_comp}')
        axs[2].set_title(f'Cosine Difference - Mean: {mean_comp}', fontsize=15)
    axs[2].set_yticks(np.arange(0, 1.5, step=0.2))
    axs[2].tick_params(labelbottom=False, direction='in', bottom=False)

    plt.tight_layout()
    plt.show()
    return fig

### Compute Spectral Radius For Each Cell MNE & Pull Out Useful Cells

In [None]:
from numpy.linalg import eigvalsh as brad
# Get the J matrix for each unit (original responses) and compute its spectral radius,
# i.e. the largest |eigenvalue|. J is the n_dim*n_dim tail of each parameter vector,
# after a (1 value) and h (n_dim values). eigvalsh treats J as symmetric
# (it reads only one triangle).
spectra = []
J_np = [np.reshape(x[(n_dim+1):], (n_dim, n_dim)) for x in mne_results[0]]
for J in J_np:
    radium = max(abs(brad(J)))
    spectra.append(radium)
#print(spectra)
# Histogram of spectral radii; the lowest-radius bin is treated as noise units below.
fig, ax = plt.subplots(1,1)
bht,bins,x = ax.hist(spectra, bins=100)
print(len(spectra))
spectra_array = np.array(spectra)
# Boolean mask over units: keep units whose radius exceeds the upper edge of the first bin.
spectra_array_thresholded = spectra_array > bins[1]
#print(spectra_array_thresholded)
print(len(bins))

In [None]:
#Pull out sets of cells from histogram after removing the noise cells 

from itertools import compress
spectra_thresholded_list = list(compress(spectra, spectra_array_thresholded))
print(len(spectra_thresholded_list))
fig, ax = plt.subplots(1,1)
bht,bins,x = ax.hist(spectra_thresholded_list, bins=100)

spectra_chopped_array = np.array(spectra_thresholded_list)

spectra_thresholded_list_sorted = sorted(spectra_thresholded_list)
total_num = len(spectra_thresholded_list_sorted)
assert total_num > 0, 'no units survived the spectral-radius threshold'
quarter =int(total_num*.25)
half =int(total_num*.5)
seventyfive =int(total_num*.75)

# Masks over the thresholded units (original unit order) selecting units whose spectral
# radius is above the 25th / 50th / 75th percentile. The cut-offs must come from the
# *sorted* radii; indexing the unsorted array picked an arbitrary unit's radius.
spectra_array_quarter = spectra_chopped_array > spectra_thresholded_list_sorted[quarter]
spectra_array_half = spectra_chopped_array > spectra_thresholded_list_sorted[half]
spectra_array_seventyfive = spectra_chopped_array > spectra_thresholded_list_sorted[seventyfive]

## Run Reconstructions on Useful Cells
#### Change the output file name!

In [None]:
#finding the GPUs available
# to see usage, nvidia-smi in terminal
# The reconstruction cell below runs under tf.device('/GPU:<gpu_num>').
import tensorflow as tf
tf.config.list_physical_devices('GPU')

In [None]:
from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_nosilence.pkl'
#mne_results_limited = [mne_results[0][spectra_array_thresholded],mne_results[1][spectra_array_thresholded]]

gpu_num = 0

# Spectrogram rows per segment: each MNE stimulus vector is a flattened
# (n_freq x nsegbins) segment, so n_freq = n_dim // nsegbins (16 with the defaults;
# previously hard-coded).
n_freq = n_dim // nsegbins
assert n_freq * nsegbins == n_dim, 'n_dim must be a multiple of nsegbins'

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions = []
    class_comparisons = []
    # One pass per response class (original, full shuffle), each with its own MNE fits.
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # Stack per-unit MNE parameters: a -> (1, n_units), h -> (n_dim, n_units),
        # J -> (n_dim, n_dim, n_units). J_stack gets its own name so it does not
        # overwrite the per-unit J_np list from the spectral-radius cell.
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)
        h_np = np.concatenate([x[1:n_dim+1][:, None] for x in mne_res], axis=1).astype(np.float32)
        J_stack = np.concatenate([np.reshape(x[(n_dim+1):], (n_dim, n_dim, 1)) for x in mne_res], axis=-1).astype(np.float32)

        # Keep only units above the spectral-radius noise threshold.
        a_np = a_np[:, spectra_array_thresholded]
        h_np = h_np[:, spectra_array_thresholded]
        J_stack = J_stack[:, :, spectra_array_thresholded]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(x) for x in resp]
        resp_binary = [MNE2.resp_to_binary(x) for x in resp_trialavg]

        stim_comparisons = []
        stim_reconstructions = []
        for (stim_resp, single_stim) in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            # NOTE(review): half the preprocessing radius is passed here — confirm intended.
            reconstructions = rct.manopt_MLE(stim_resp[:, spectra_array_thresholded], a_np, h_np, J_stack, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # Overlap-add: segment idx covers columns idx .. idx+nsegbins-1, so the canvas
            # has nsegbins extra columns at the end (trimmed when plotting).
            # NOTE(review): overlapping contributions are summed, not averaged.
            res = np.zeros((n_freq, n_bins+nsegbins))
            comp = []

            for idx in range(n_bins):
                res[:, idx:idx+nsegbins] += np.reshape(reconstructions[idx, :], (n_freq, nsegbins))
                # Cosine difference between reconstructed and true segment idx.
                comp.append(rec_compare(reconstructions[idx, :],single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions.append(stim_reconstructions)
        class_comparisons.append(stim_comparisons)
        print('Reconstruction complete')

recon_results_dict = {'class_reconstructions': class_reconstructions, 
                'class_comparisons' : class_comparisons}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

## Load Previous Reconstruction Results 

In [None]:
#load reconstruction results 
from oe_acute import reconstruct as rct
output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_nosilence.pkl'
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)

# contains the per-class reconstructions and per-segment cosine comparisons
# written by the reconstruction cell (NOTE: pickle — load trusted files only)
with open(reconstruction_result_file, 'rb') as f:
    recon_results_dict = pickle.load(f)
class_reconstructions = recon_results_dict['class_reconstructions']
class_comparisons = recon_results_dict['class_comparisons']

print(class_comparisons[0])
print(class_reconstructions[0])

## Reconstruction Visualizations!!!!

In [None]:
# original responses with comparison 
import scipy.spatial
from scipy.spatial import distance
from statistics import mean 

recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
count = 0

# Reconstruction canvases carry nsegbins extra overlap-add columns at the end; trim
# them with nsegbins (not a hard-coded 16) so shapes match the true spectrogram.
for orig, recon, comparison in zip(stims_plot, recons_plot, comparisons):
    count += 1
    print(count)
    recon_trim = recon[:, :-nsegbins]
    flat_orig = np.ravel(orig)
    flat_recon = np.ravel(recon_trim)
    cosine_diff = round(scipy.spatial.distance.cosine(flat_orig, flat_recon),3)
    print(f'Overall difference: {cosine_diff}')
    fig = plot_true_recon2(orig, recon_trim, comparison)
    plt.show()
    plt.close(fig)

In [None]:
# plot first third of each stimulus (true vs. reconstructed, original responses)

recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
count = 0

for orig, recon, comparison in zip(stims_plot, recons_plot, comparisons):
    # Width of one third of the true spectrogram, in time bins.
    third = round((orig.shape[1])/3)
    count += 1
    print(count)
    flat_orig = np.ravel(orig[:, :third])
    flat_recon = np.ravel(recon[:, :third])
    cosine_diff = scipy.spatial.distance.cosine(flat_orig, flat_recon)
    print(f'Overall difference: {cosine_diff}')
    fig = plot_true_recon2(orig[:, :third], recon[:, :third], comparison[:third], figsize=(7,7))
    plt.show()

In [None]:
# plot second third of each stimulus (true vs. reconstructed, original responses)
recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
count = 0

for orig, recon, comparison in zip(stims_plot, recons_plot, comparisons):
    # Width of one third of the true spectrogram, in time bins.
    third = round((orig.shape[1])/3)
    count += 1
    print(count)
    flat_orig = np.ravel(orig[:, third:(third+third)])
    flat_recon = np.ravel(recon[:, third:(third+third)])
    cosine_diff = scipy.spatial.distance.cosine(flat_orig, flat_recon)
    print(f'Overall difference: {cosine_diff}')
    fig = plot_true_recon2(orig[:, third:(third+third)], recon[:, third:(third+third)], comparison[third:(third+third)], figsize=(7,7))
    plt.show()

In [None]:
# plot last third of each stimulus (true vs. reconstructed, original responses)
recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
count = 0

for orig, recon, comparison in zip(stims_plot, recons_plot, comparisons):
    # Width of one third of the true spectrogram, in time bins.
    third = round((orig.shape[1])/3)
    count += 1
    print(count)
    # Drop the nsegbins overlap-add padding at the end of the reconstruction
    # (previously a hard-coded 16).
    recon_last = recon[:, (third+third):-nsegbins]
    flat_orig = np.ravel(orig[:, (third+third):])
    flat_recon = np.ravel(recon_last)
    cosine_diff = scipy.spatial.distance.cosine(flat_orig, flat_recon)
    print(f'Overall difference: {cosine_diff}')
    fig = plot_true_recon2(orig[:, (third+third):], recon_last, comparison[(third+third):], figsize=(7,7))
    plt.show()
    plt.close(fig)

In [None]:
# full shuffle responses
# Reconstructions from the fully-shuffled control responses. The nsegbins overlap-add
# padding columns are trimmed from each canvas (previously a hard-coded 16).

recons_plot2 = class_reconstructions[1]
stims_plot2 = kozlov_stim_responses[1][0]
comparisons2 = class_comparisons[1]
count = 0
for orig, recon, comparison in zip(stims_plot2, recons_plot2, comparisons2):
    count += 1
    print(count)
    recon_trim = recon[:, :-nsegbins]
    flat_orig = np.ravel(orig)
    flat_recon = np.ravel(recon_trim)
    cosine_diff = round(scipy.spatial.distance.cosine(flat_orig, flat_recon),3)
    print(f'Overall difference: {cosine_diff}')
    fig = plot_true_recon2(orig, recon_trim, comparison)
    plt.show()
    plt.close(fig)

### Work in progress: adding shuffle comparisons

In [None]:
def plot_true_recon_shuff(true_stim, recon_stim, comparison, comparison_shuff=None, color_map='jet', figsize=(16,9)):
    """Plot a true spectrogram, its reconstruction, and true-vs-shuffled cosine differences.

    NOTE(review): this cell redefines plot_true_recon_shuff with the same behavior as
    the earlier definition; one of the two can be removed.

    Parameters
    ----------
    true_stim, recon_stim : 2-D array-like
        Images drawn in the top (true) and middle (reconstructed) panels.
    comparison : sequence of float
        Per-time-bin cosine differences of the reconstruction.
    comparison_shuff : sequence of float, optional
        Per-time-bin cosine differences for the shuffled-response control. When given
        and non-empty, it is drawn in black with a legend. This replaces a mutable
        ``[0]`` default plus a ``len > 2`` test that silently ignored short curves.
    color_map : str
        Matplotlib colormap for both images.
    figsize : tuple
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already displayed via ``plt.show()``.
    """
    def _rounded_mean(values):
        # statistics.mean raises on empty input; report NaN instead.
        return round(mean(values), 3) if len(values) > 0 else float('nan')

    fig, axs = plt.subplots(3, 1, figsize=figsize, sharex=True)

    # True Stimulus
    axs[0].imshow(true_stim, cmap=color_map, origin='lower', aspect='auto', interpolation=None)
    axs[0].set_title('True Stimulus', fontsize=15)
    axs[0].tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    # recon stim
    axs[1].imshow(recon_stim, cmap=color_map, origin='lower', aspect='auto', interpolation=None)
    axs[1].set_title('Reconstructed Stimulus', fontsize=15)
    axs[1].tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    mean_comp = _rounded_mean(comparison)
    has_shuff = comparison_shuff is not None and len(comparison_shuff) > 0
    if has_shuff:
        mean_comp_shuff = _rounded_mean(comparison_shuff)
        axs[2].plot(comparison_shuff, linewidth=0.5, color = 'black', label = f'full shuffle reconstruction - mean: {mean_comp_shuff}')
        axs[2].plot(comparison, linewidth=1, label = f'true reconstruction - mean: {mean_comp}')
        axs[2].legend(loc = 'lower left', fontsize = 'large')
        axs[2].set_title('Cosine Difference', fontsize=15)
    else:
        axs[2].plot(comparison, linewidth=1, label = f'true reconstruction - mean: {mean_comp}')
        axs[2].set_title(f'Cosine Difference - Mean: {mean_comp}', fontsize=15)
    axs[2].set_yticks(np.arange(0, 1.5, step=0.2))
    axs[2].tick_params(labelbottom=False, direction='in', bottom=False)

    plt.tight_layout()
    plt.show()
    return fig

In [None]:
#full spectrogram reconstructions vs. full shuffle comparisons (plot_true_recon2 layout)
# The nsegbins overlap-add padding columns are trimmed from each canvas (previously a hard-coded 16).

recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
recons_plot_shuff = class_reconstructions[1]
stims_plot_shuff = kozlov_stim_responses[1][0]
comparisons_shuff = class_comparisons[1]
count = 0 

for orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff in zip(stims_plot, recons_plot, comparisons, stims_plot_shuff, recons_plot_shuff, comparisons_shuff):
    count += 1
    print(count)
    recon_trim = recon[:, :-nsegbins]
    cosine_diff = round(rec_compare(orig, recon_trim),3)
    cosine_diff_shuff = round(rec_compare(orig_shuff, recon_shuff[:, :-nsegbins]),3)
    print(f'Overall true cosine difference: {cosine_diff}')
    print(f'Overall shuffled cosine difference: {cosine_diff_shuff}')
    fig = plot_true_recon2(orig, recon_trim, comparison, comparison_shuff)
    plt.show()
    plt.close(fig)

In [None]:
# Full-spectrogram reconstructions vs. full-shuffle comparisons, one figure per
# bout, drawn with plot_true_recon_shuff (which overlays the per-bin shuffled
# comparison trace on the true one).
# Index 0 of each container is used as the true condition, index 1 as shuffled.
# The last 16 columns of each reconstruction are dropped before comparing with
# the stimulus (presumably the nsegbins padding added by reconstruction -- confirm).

recons_plot, recons_plot_shuff = class_reconstructions[0], class_reconstructions[1]
stims_plot, stims_plot_shuff = kozlov_stim_responses[0][0], kozlov_stim_responses[1][0]
comparisons, comparisons_shuff = class_comparisons[0], class_comparisons[1]

per_bout = zip(stims_plot, recons_plot, comparisons,
               stims_plot_shuff, recons_plot_shuff, comparisons_shuff)
for count, (orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff) in enumerate(per_bout, start=1):
    print(count)
    recon_trim = recon[:, :-16]
    cosine_diff = round(rec_compare(orig, recon_trim), 3)
    cosine_diff_shuff = round(rec_compare(orig_shuff, recon_shuff[:, :-16]), 3)
    print(f'Overall true cosine difference: {cosine_diff}')
    print(f'Overall shuffled cosine difference: {cosine_diff_shuff}')
    fig = plot_true_recon_shuff(orig, recon_trim, comparison, comparison_shuff)
    plt.show()

## Reconstruction Quantifications and Visualizations

In [None]:
# Quantification of true vs. shuffled reconstructions, pooled over all bouts:
#   - "overall": rec_compare on each whole bout (reconstruction minus its last 16 columns)
#   - "binned":  mean of the per-time-bin comparison values in class_comparisons

recons_plot, recons_plot_shuff = class_reconstructions[0], class_reconstructions[1]
stims_plot, stims_plot_shuff = kozlov_stim_responses[0][0], kozlov_stim_responses[1][0]
comparisons, comparisons_shuff = class_comparisons[0], class_comparisons[1]

true_means, shuff_means = [], []
comparison_true_means, comparison_shuff_means = [], []
for orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff in zip(
        stims_plot, recons_plot, comparisons, stims_plot_shuff, recons_plot_shuff, comparisons_shuff):
    true_means.append(rec_compare(orig, recon[:, :-16]))
    shuff_means.append(rec_compare(orig_shuff, recon_shuff[:, :-16]))
    comparison_true_means.append(mean(comparison))
    comparison_shuff_means.append(mean(comparison_shuff))

means = [true_means, shuff_means, comparison_true_means, comparison_shuff_means]
labels = ['Overall True Comparisons', 'Overall Shuffled Comparisons', 'Binned True Means', 'Binned Shuffled Means']
colors = ['lightgreen', 'lightblue', 'mediumseagreen', 'dodgerblue']

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10, 7))
bplot = axs.boxplot(means, patch_artist=True)
for patch, color in zip(bplot['boxes'], colors):
    patch.set_facecolor(color)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticklabels(labels, rotation=20)
axs.set_ylabel('Cosine Difference')
axs.set_title('Cosine Difference Values Over All Bouts for True and Shuffled Responses', fontsize='large')
plt.show()

print(f'Mean & Standard Deviation of overall true cosine differences: {round(mean(true_means),3)}, {round(np.std(true_means),3)}')
print(f'Mean & Standard Deviation of overall shuffled cosine differences: {round(mean(shuff_means),3)}, {round(np.std(shuff_means),3)}')
print(f'Mean & Standard Deviation of time bin true cosine differences: {round(mean(comparison_true_means),3)}, {round(np.std(comparison_true_means),3)}')
print(f'Mean & Standard Deviation of time bin shuffled cosine differences: {round(mean(comparison_shuff_means),3)}, {round(np.std(comparison_shuff_means),3)}')
    

In [None]:
# Overall (whole-segment) comparison split into the beginning, middle and end
# thirds of each bout, for true and shuffled reconstructions.
# The thirds are cut at the true stimulus's length; the shuffled bout reuses them.

recons_plot, recons_plot_shuff = class_reconstructions[0], class_reconstructions[1]
stims_plot, stims_plot_shuff = kozlov_stim_responses[0][0], kozlov_stim_responses[1][0]
comparisons, comparisons_shuff = class_comparisons[0], class_comparisons[1]
true_means1, shuff_means1 = [], []
true_means2, shuff_means2 = [], []
true_means3, shuff_means3 = [], []

for orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff in zip(
        stims_plot, recons_plot, comparisons, stims_plot_shuff, recons_plot_shuff, comparisons_shuff):
    third = round((orig.shape[1])/3)
    begin, middle = slice(None, third), slice(third, 2 * third)

    true_means1.append(rec_compare(orig[:, begin], recon[:, begin]))
    shuff_means1.append(rec_compare(orig_shuff[:, begin], recon_shuff[:, begin]))

    true_means2.append(rec_compare(orig[:, middle], recon[:, middle]))
    shuff_means2.append(rec_compare(orig_shuff[:, middle], recon_shuff[:, middle]))

    # the end third runs to the last stimulus column; the reconstruction drops its
    # final 16 columns (same trimming as the whole-bout comparisons)
    true_means3.append(rec_compare(orig[:, 2 * third:], recon[:, 2 * third:-16]))
    shuff_means3.append(rec_compare(orig_shuff[:, 2 * third:], recon_shuff[:, 2 * third:-16]))

means = [true_means1, true_means2, true_means3, shuff_means1, shuff_means2, shuff_means3]
labels = ['Beginning True', 'Middle True', 'End True', 'Beginning Shuffled', 'Middle Shuffled', 'End Shuffled']
colors = ['plum', 'mediumorchid', 'darkorchid', 'lightskyblue', 'blue', 'navy']

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10, 7))
bplot = axs.boxplot(means, patch_artist=True)
for patch, color in zip(bplot['boxes'], colors):
    patch.set_facecolor(color)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticklabels(labels, rotation=20)
axs.set_xlabel('Overall True Comparisons                                         Overall Shuffled Comparisons')
axs.set_ylabel('Cosine Difference')
axs.set_title('Overall Cosine Difference Across Beginning, Middle, and Ends of Bouts for True and Shuffled Responses', fontsize='medium')
plt.show()
    

### Binned (per-time-bin) comparisons

In [None]:
# Binned comparisons only: for each bout, the mean of its per-time-bin
# comparison values, true vs. shuffled.

# (reconstruction/stimulus globals are refreshed here but not used in this cell)
recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons, comparisons_shuff = class_comparisons[0], class_comparisons[1]
recons_plot_shuff = class_reconstructions[1]
stims_plot_shuff = kozlov_stim_responses[1][0]

# pair bouts first so both lists have the same (zip-truncated) length
bout_pairs = list(zip(comparisons, comparisons_shuff))
comparison_true_means = [mean(true_bins) for true_bins, _ in bout_pairs]
comparison_shuff_means = [mean(shuff_bins) for _, shuff_bins in bout_pairs]

means = [comparison_true_means, comparison_shuff_means]
labels = ['Binned True Means', 'Binned Shuffled Means']
colors = ['mediumseagreen', 'dodgerblue']

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10, 7))
bplot = axs.boxplot(means, patch_artist=True)
for patch, color in zip(bplot['boxes'], colors):
    patch.set_facecolor(color)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticklabels(labels, rotation=20)
axs.set_ylabel('Cosine Similarity')
axs.set_title('Cosine Similarity Values For Each Time Bin Over All Bouts for True and Shuffled Responses', fontsize='large')
plt.show()

print(f'Mean & Standard Deviation of time bin true cosine differences: {round(mean(comparison_true_means),3)}, {round(np.std(comparison_true_means),3)}')
print(f'Mean & Standard Deviation of time bin shuffled cosine differences: {round(mean(comparison_shuff_means),3)}, {round(np.std(comparison_shuff_means),3)}')
    

In [None]:
# Binned (per-time-bin) comparisons split into the beginning, middle and end
# thirds of each bout, for true and shuffled responses. The shuffled bout is
# cut at the same bin boundaries as the true one.
comparisons = class_comparisons[0]
comparisons_shuff = class_comparisons[1]
comparison_true_means1, comparison_shuff_means1 = [], []
comparison_true_means2, comparison_shuff_means2 = [], []
comparison_true_means3, comparison_shuff_means3 = [], []
true_by_third = (comparison_true_means1, comparison_true_means2, comparison_true_means3)
shuff_by_third = (comparison_shuff_means1, comparison_shuff_means2, comparison_shuff_means3)

for comparison, comparison_shuff in zip(comparisons, comparisons_shuff):
    third = round((len(comparison))/3)
    # beginning, middle, end
    thirds = (slice(None, third), slice(third, 2 * third), slice(2 * third, None))
    for part, true_out, shuff_out in zip(thirds, true_by_third, shuff_by_third):
        true_out.append(mean(comparison[part]))
        shuff_out.append(mean(comparison_shuff[part]))

means = [comparison_true_means1, comparison_true_means2, comparison_true_means3, comparison_shuff_means1, comparison_shuff_means2, comparison_shuff_means3]
labels = ['Beginning True', 'Middle True', 'End True', 'Beginning Shuffled', 'Middle Shuffled', 'End Shuffled']
colors = ['plum', 'mediumorchid', 'darkorchid', 'lightskyblue', 'blue', 'navy']

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10, 7))
bplot = axs.boxplot(means, patch_artist=True)
for patch, color in zip(bplot['boxes'], colors):
    patch.set_facecolor(color)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticklabels(labels, rotation=20)
axs.set_xlabel('Binned True Comparisons                                         Binned Shuffled Comparisons')
axs.set_ylabel('Cosine Similarity')
axs.set_title('Binned Cosine Similarity Across Beginning, Middle, and Ends of Bouts for True and Shuffled Responses', fontsize='medium')
plt.show()
    

## Repeating the analysis on subsets of cells ranked by spectral radius

In [None]:
# Rank cells by the spectral radius of their MNE second-order matrix J (fits to
# the true responses, mne_results[0]). Build boolean cell masks at several
# radius percentiles for the subset reconstructions that follow.

from itertools import compress
from numpy.linalg import eigvalsh as brad

# each MNE parameter vector x is [a, h (n_dim), J (n_dim*n_dim)]; J is the tail
J_np = [np.reshape(x[(n_dim+1):], (n_dim, n_dim)) for x in mne_results[0]]
# spectral radius = largest |eigenvalue|. NOTE(review): eigvalsh reads only one
# triangle of J, so this is exact only if J is symmetric -- confirm.
spectra = [max(abs(brad(J))) for J in J_np]

fig, ax = plt.subplots(1,1)
bht, bins, x = ax.hist(spectra, bins=100)
spectra_array = np.array(spectra)
# cells in the lowest histogram bin are treated as noise and excluded
spectra_array_thresholded = spectra_array > bins[1]

spectra_thresholded_list = list(compress(spectra, spectra_array_thresholded))
print(len(spectra_thresholded_list))
fig, ax = plt.subplots(1,1)
bht, bins, x = ax.hist(spectra_thresholded_list, bins=100)

# percentile positions within the sorted, noise-thresholded radii
spectra_thresholded_list_sorted = sorted(spectra_thresholded_list)
total_num = len(spectra_thresholded_list_sorted)
quarter, half, seventyfive, ninety, ninetyfive, ninetynine = (
    int(total_num * frac) for frac in (.25, .5, .75, .90, .95, .99))
twentyfive, ten, five, one = (int(total_num * frac) for frac in (.25, .10, .05, .01))


def _radius_above(idx):
    """Mask over ALL cells: True where the radius exceeds the idx-th sorted thresholded radius."""
    return spectra_array > spectra_thresholded_list_sorted[idx]


# e.g. spectra_array_quarter keeps cells above the 25th percentile (~top 75% of the
# thresholded cells), spectra_array_ninetynine those above the 99th (~top 1%)
spectra_array_quarter = _radius_above(quarter)
spectra_array_half = _radius_above(half)
spectra_array_seventyfive = _radius_above(seventyfive)
spectra_array_ninety = _radius_above(ninety)
spectra_array_ninetyfive = _radius_above(ninetyfive)
spectra_array_ninetynine = _radius_above(ninetynine)

In [None]:
# Sanity check on the top-1% mask: total number of cells, number selected by the
# mask, and the 1% index computed from the thresholded radii.
true_count = sum(spectra_array_ninetynine)
print(len(spectra_array_ninetynine), true_count, one, sep='\n')

In [None]:
# Reconstruct the stimuli from the top 75% of cells ranked by spectral radius,
# then pickle the reconstructions and per-bin comparisons into exp_path.
# "Top 75%" = radius above the 25th percentile of the noise-thresholded radii,
# i.e. `spectra_array_quarter`. (This cell previously used
# `spectra_array_seventyfive` -- radius above the 75th percentile, the top 25% --
# so the "_75" results were actually computed from the top-25% subset.)

from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_75_nosilence.pkl'
cell_mask = spectra_array_quarter  # radius > 25th percentile -> top 75% of cells

gpu_num = 0

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions_75 = []
    class_comparisons_75 = []
    # one pass per response class (index 0 = true, index 1 = shuffled in the plots)
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # unpack per-cell MNE parameters; each x is [a, h (n_dim), J (n_dim*n_dim)]
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)
        h_np = np.concatenate([x[1:n_dim+1][:, None] for x in mne_res], axis=1).astype(np.float32)
        J_np = np.concatenate([np.reshape(x[(n_dim+1):], (n_dim, n_dim, 1)) for x in mne_res], axis=-1).astype(np.float32)

        # keep only the selected cells (cells lie along the last axis of each parameter)
        a_np = a_np[:, cell_mask]
        h_np = h_np[:, cell_mask]
        J_np = J_np[:, :, cell_mask]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(x) for x in resp]
        resp_binary = [MNE2.resp_to_binary(x) for x in resp_trialavg]

        stim_comparisons = []
        stim_reconstructions = []
        for (stim_resp, single_stim) in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            reconstructions = rct.manopt_MLE(stim_resp[:, cell_mask], a_np, h_np, J_np, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # overlap-add: bin idx contributes a (16 x nsegbins) block at columns idx..idx+nsegbins
            res = np.zeros((16, n_bins+nsegbins))
            comp = []

            for idx in range(n_bins):
                res[:, idx:idx+nsegbins] += np.reshape(reconstructions[idx, :], (16, nsegbins))
                comp.append(rec_compare(reconstructions[idx, :], single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions_75.append(stim_reconstructions)
        class_comparisons_75.append(stim_comparisons)
        print('Reconstruction complete')

recon_results_dict = {'class_reconstructions': class_reconstructions_75,
                      'class_comparisons': class_comparisons_75}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

In [None]:
# Top 50% of cells by spectral radius: radius above the median of the
# noise-thresholded radii (`spectra_array_half`). Reconstruct every test segment
# from just those cells, then pickle the results into exp_path.

from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_50_nosilence.pkl'
cell_mask = spectra_array_half

gpu_num = 0

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions_50, class_comparisons_50 = [], []
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # Each per-cell parameter vector x is [a, h (n_dim), J (n_dim*n_dim)].
        # Stack cells along the last axis, cast to float32, keep the masked cells.
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)[:, cell_mask]
        h_np = np.stack([x[1:n_dim+1] for x in mne_res], axis=1).astype(np.float32)[:, cell_mask]
        J_np = np.stack([np.reshape(x[(n_dim+1):], (n_dim, n_dim)) for x in mne_res],
                        axis=-1).astype(np.float32)[:, :, cell_mask]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(trials) for trials in resp]
        resp_binary = [MNE2.resp_to_binary(avg) for avg in resp_trialavg]

        stim_comparisons, stim_reconstructions = [], []
        for stim_resp, single_stim in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            reconstructions = rct.manopt_MLE(stim_resp[:, cell_mask], a_np, h_np, J_np, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # overlap-add: bin idx contributes a (16 x nsegbins) block at columns idx..idx+nsegbins
            res = np.zeros((16, n_bins + nsegbins))
            comp = []
            for idx in range(n_bins):
                segment = reconstructions[idx, :]
                res[:, idx:idx + nsegbins] += np.reshape(segment, (16, nsegbins))
                comp.append(rec_compare(segment, single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions_50.append(stim_reconstructions)
        class_comparisons_50.append(stim_comparisons)
        print('Reconstruction complete')

recon_results_dict = {
    'class_reconstructions': class_reconstructions_50,
    'class_comparisons': class_comparisons_50,
}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

In [None]:
# Reconstruct the stimuli from the top 25% of cells ranked by spectral radius,
# then pickle the reconstructions and per-bin comparisons into exp_path.
# "Top 25%" = radius above the 75th percentile of the noise-thresholded radii,
# i.e. `spectra_array_seventyfive`. (This cell previously used
# `spectra_array_quarter` -- radius above the 25th percentile, the top 75% --
# so the "_25" results were actually computed from the top-75% subset.)

from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_25_nosilence.pkl'
cell_mask = spectra_array_seventyfive  # radius > 75th percentile -> top 25% of cells

gpu_num = 0

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions_25 = []
    class_comparisons_25 = []
    # one pass per response class (index 0 = true, index 1 = shuffled in the plots)
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # unpack per-cell MNE parameters; each x is [a, h (n_dim), J (n_dim*n_dim)]
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)
        h_np = np.concatenate([x[1:n_dim+1][:, None] for x in mne_res], axis=1).astype(np.float32)
        J_np = np.concatenate([np.reshape(x[(n_dim+1):], (n_dim, n_dim, 1)) for x in mne_res], axis=-1).astype(np.float32)

        # keep only the selected cells (cells lie along the last axis of each parameter)
        a_np = a_np[:, cell_mask]
        h_np = h_np[:, cell_mask]
        J_np = J_np[:, :, cell_mask]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(x) for x in resp]
        resp_binary = [MNE2.resp_to_binary(x) for x in resp_trialavg]

        stim_comparisons = []
        stim_reconstructions = []
        for (stim_resp, single_stim) in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            reconstructions = rct.manopt_MLE(stim_resp[:, cell_mask], a_np, h_np, J_np, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # overlap-add: bin idx contributes a (16 x nsegbins) block at columns idx..idx+nsegbins
            res = np.zeros((16, n_bins+nsegbins))
            comp = []

            for idx in range(n_bins):
                res[:, idx:idx+nsegbins] += np.reshape(reconstructions[idx, :], (16, nsegbins))
                comp.append(rec_compare(reconstructions[idx, :], single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions_25.append(stim_reconstructions)
        class_comparisons_25.append(stim_comparisons)
        print('Reconstruction complete')

recon_results_dict = {'class_reconstructions': class_reconstructions_25,
                      'class_comparisons': class_comparisons_25}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

In [None]:
# Reconstruct the stimuli from the top 10% of cells ranked by spectral radius
# (radius above the 90th percentile of the noise-thresholded radii:
# `spectra_array_ninety`), then pickle the results into exp_path.
# (This cell previously pickled class_reconstructions_75 / class_comparisons_75,
# so the "_10" file on disk held the 75% results instead of these.)

from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_10_nosilence.pkl'
cell_mask = spectra_array_ninety  # radius > 90th percentile -> top 10% of cells

gpu_num = 0

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions_10 = []
    class_comparisons_10 = []
    # one pass per response class (index 0 = true, index 1 = shuffled in the plots)
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # unpack per-cell MNE parameters; each x is [a, h (n_dim), J (n_dim*n_dim)]
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)
        h_np = np.concatenate([x[1:n_dim+1][:, None] for x in mne_res], axis=1).astype(np.float32)
        J_np = np.concatenate([np.reshape(x[(n_dim+1):], (n_dim, n_dim, 1)) for x in mne_res], axis=-1).astype(np.float32)

        # keep only the selected cells (cells lie along the last axis of each parameter)
        a_np = a_np[:, cell_mask]
        h_np = h_np[:, cell_mask]
        J_np = J_np[:, :, cell_mask]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(x) for x in resp]
        resp_binary = [MNE2.resp_to_binary(x) for x in resp_trialavg]

        stim_comparisons = []
        stim_reconstructions = []
        for (stim_resp, single_stim) in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            reconstructions = rct.manopt_MLE(stim_resp[:, cell_mask], a_np, h_np, J_np, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # overlap-add: bin idx contributes a (16 x nsegbins) block at columns idx..idx+nsegbins
            res = np.zeros((16, n_bins+nsegbins))
            comp = []

            for idx in range(n_bins):
                res[:, idx:idx+nsegbins] += np.reshape(reconstructions[idx, :], (16, nsegbins))
                comp.append(rec_compare(reconstructions[idx, :], single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions_10.append(stim_reconstructions)
        class_comparisons_10.append(stim_comparisons)
        print('Reconstruction complete')

# save the results computed in THIS cell (the _10 subset)
recon_results_dict = {'class_reconstructions': class_reconstructions_10,
                      'class_comparisons': class_comparisons_10}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

In [None]:
# Top 5% of cells by spectral radius: radius above the 95th percentile of the
# noise-thresholded radii (`spectra_array_ninetyfive`). Reconstruct every test
# segment from just those cells, then pickle the results into exp_path.

from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_5_nosilence.pkl'
cell_mask = spectra_array_ninetyfive

gpu_num = 0

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions_5, class_comparisons_5 = [], []
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # Each per-cell parameter vector x is [a, h (n_dim), J (n_dim*n_dim)].
        # Stack cells along the last axis, cast to float32, keep the masked cells.
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)[:, cell_mask]
        h_np = np.stack([x[1:n_dim+1] for x in mne_res], axis=1).astype(np.float32)[:, cell_mask]
        J_np = np.stack([np.reshape(x[(n_dim+1):], (n_dim, n_dim)) for x in mne_res],
                        axis=-1).astype(np.float32)[:, :, cell_mask]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(trials) for trials in resp]
        resp_binary = [MNE2.resp_to_binary(avg) for avg in resp_trialavg]

        stim_comparisons, stim_reconstructions = [], []
        for stim_resp, single_stim in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            reconstructions = rct.manopt_MLE(stim_resp[:, cell_mask], a_np, h_np, J_np, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # overlap-add: bin idx contributes a (16 x nsegbins) block at columns idx..idx+nsegbins
            res = np.zeros((16, n_bins + nsegbins))
            comp = []
            for idx in range(n_bins):
                segment = reconstructions[idx, :]
                res[:, idx:idx + nsegbins] += np.reshape(segment, (16, nsegbins))
                comp.append(rec_compare(segment, single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions_5.append(stim_reconstructions)
        class_comparisons_5.append(stim_comparisons)
        print('Reconstruction complete')

recon_results_dict = {
    'class_reconstructions': class_reconstructions_5,
    'class_comparisons': class_comparisons_5,
}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

In [None]:
# Top 1% of cells by spectral radius: radius above the 99th percentile of the
# noise-thresholded radii (`spectra_array_ninetynine`). Reconstruct every test
# segment from just those cells, then pickle the results into exp_path.

from oe_acute import reconstruct as rct

output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_1_nosilence.pkl'
cell_mask = spectra_array_ninetynine

gpu_num = 0

with tf.device('/GPU:{}'.format(gpu_num)):

    class_reconstructions_1, class_comparisons_1 = [], []
    for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
        # Each per-cell parameter vector x is [a, h (n_dim), J (n_dim*n_dim)].
        # Stack cells along the last axis, cast to float32, keep the masked cells.
        a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)[:, cell_mask]
        h_np = np.stack([x[1:n_dim+1] for x in mne_res], axis=1).astype(np.float32)[:, cell_mask]
        J_np = np.stack([np.reshape(x[(n_dim+1):], (n_dim, n_dim)) for x in mne_res],
                        axis=-1).astype(np.float32)[:, :, cell_mask]

        print('Reconstructing test stimulus segments...')
        resp_trialavg = [MNE2.trial_average(trials) for trials in resp]
        resp_binary = [MNE2.resp_to_binary(avg) for avg in resp_trialavg]

        stim_comparisons, stim_reconstructions = [], []
        for stim_resp, single_stim in zip(resp_binary, stim):
            # reconstructions is n_bins x ndim; same as stims_test
            reconstructions = rct.manopt_MLE(stim_resp[:, cell_mask], a_np, h_np, J_np, radius/2, ndim=n_dim)

            n_bins = stim_resp.shape[0]
            # overlap-add: bin idx contributes a (16 x nsegbins) block at columns idx..idx+nsegbins
            res = np.zeros((16, n_bins + nsegbins))
            comp = []
            for idx in range(n_bins):
                segment = reconstructions[idx, :]
                res[:, idx:idx + nsegbins] += np.reshape(segment, (16, nsegbins))
                comp.append(rec_compare(segment, single_stim[idx, :]))
            stim_comparisons.append(comp)
            stim_reconstructions.append(res)

        class_reconstructions_1.append(stim_reconstructions)
        class_comparisons_1.append(stim_comparisons)
        print('Reconstruction complete')

recon_results_dict = {
    'class_reconstructions': class_reconstructions_1,
    'class_comparisons': class_comparisons_1,
}

print('Saving results...')
reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
with open(reconstruction_result_file, 'wb') as f:
    pickle.dump(recon_results_dict, f)
print('Done')

In [None]:
# Quantification of true vs. shuffled, binned comparisons only, for every
# spectral-radius cell subset (1%, 5%, 10%, 25%, 50%, 75%) and for all cells.


def _binned_bout_means(class_comps):
    """Per-bout mean of the per-time-bin comparison values.

    class_comps[0] holds the true bouts and class_comps[1] the shuffled ones;
    bouts are paired with zip. Returns (true_means, shuffled_means).
    """
    true_out, shuff_out = [], []
    for comparison, comparison_shuff in zip(class_comps[0], class_comps[1]):
        true_out.append(mean(comparison))
        shuff_out.append(mean(comparison_shuff))
    return true_out, shuff_out


comparison_true_means, comparison_shuff_means = _binned_bout_means(class_comparisons)
comparison_true_means_25, comparison_shuff_means_25 = _binned_bout_means(class_comparisons_25)
comparison_true_means_50, comparison_shuff_means_50 = _binned_bout_means(class_comparisons_50)
comparison_true_means_75, comparison_shuff_means_75 = _binned_bout_means(class_comparisons_75)
comparison_true_means_10, comparison_shuff_means_10 = _binned_bout_means(class_comparisons_10)
comparison_true_means_5, comparison_shuff_means_5 = _binned_bout_means(class_comparisons_5)
comparison_true_means_1, comparison_shuff_means_1 = _binned_bout_means(class_comparisons_1)

means = [comparison_true_means_1, comparison_true_means_5, comparison_true_means_10, comparison_true_means_25, comparison_true_means_50, comparison_true_means_75, comparison_true_means,  comparison_shuff_means_1, comparison_shuff_means_5, comparison_shuff_means_10, comparison_shuff_means_25, comparison_shuff_means_50, comparison_shuff_means_75, comparison_shuff_means]
labels = ['1% True Means', '5% True Means', '10% True Means', '25% True Means', '50% True Means', '75% True Means', '100% True Means',  '1% Shuffled Means', '5% Shuffled Means', '10% Shuffled Means', '25% Shuffled Means', '50% Shuffled Means', '75% Shuffled Means', '100% Shuffled Means']
# only the first seven (true) boxes are recoloured; shuffled boxes keep the default fill
colors = ['orchid'] * 7

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10, 7))
bplot = axs.boxplot(means, patch_artist=True)
for patch, color in zip(bplot['boxes'], colors):
    patch.set_facecolor(color)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticklabels(labels, rotation=80)
axs.set_ylabel('Cosine Similarity')
axs.set_title('Cosine Similarity Values For Each Time Bin Over Sets of Cells Based On Spectral Radius', fontsize='large')
plt.show()
    

In [None]:

# Binned cosine similarity (per-bout mean over time bins) for each spectral-radius
# cell subset, with true and shuffled shown side by side.
# The two series are drawn at x positions offset around each subset's tick so the
# boxes do not overlap, and there is exactly one tick label per subset.
# (Previously both boxplots shared positions 1..7, so the boxes were drawn on top
# of each other. Also, 14 labels were passed for 7 ticks; recent matplotlib raises
# a ValueError for that mismatch, and older versions mislabel silently.)

means = [comparison_true_means_1, comparison_true_means_5, comparison_true_means_10, comparison_true_means_25, comparison_true_means_50, comparison_true_means_75, comparison_true_means]
means2 = [comparison_shuff_means_1, comparison_shuff_means_5, comparison_shuff_means_10, comparison_shuff_means_25, comparison_shuff_means_50, comparison_shuff_means_75, comparison_shuff_means]
labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']

centers = np.arange(1, len(labels) + 1)
offset, box_width = 0.2, 0.35

fig, axs = plt.subplots(1, 1, figsize=(10, 7))
bplot = axs.boxplot(means, positions=centers - offset, widths=box_width,
                    patch_artist=True, boxprops=dict(facecolor="orchid"), manage_ticks=False)
bplot2 = axs.boxplot(means2, positions=centers + offset, widths=box_width,
                     patch_artist=True, manage_ticks=False)
axs.set_xticks(centers)
axs.set_xticklabels(labels, rotation=20)
axs.set_xlim(0.5, len(labels) + 0.5)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.legend([bplot["boxes"][0], bplot2["boxes"][0]], ['True', 'Shuffled'])
axs.set_xlabel('Percent of cells (top by spectral radius)')
axs.set_ylabel('Cosine Similarity')
axs.set_title('Cosine Similarity Values For Each Time Bin: True vs. Shuffled Responses', fontsize='large')
plt.show()

In [None]:

# True-response binned cosine similarity for each spectral-radius cell subset
# (1% ... 100% of cells), one orchid box per subset.

labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']
means = [comparison_true_means_1, comparison_true_means_5, comparison_true_means_10, comparison_true_means_25, comparison_true_means_50, comparison_true_means_75, comparison_true_means]

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10, 7))
bplot = axs.boxplot(means, patch_artist=True, boxprops=dict(facecolor='orchid'))
axs.set_xticklabels(labels, rotation=60)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_ylabel('Cosine Similarity')
axs.set_title('Cosine Similarity Values For Each Time Bin Over Sets of Cells Based On Spectral Radius', fontsize='large')
plt.show()
        

In [None]:
# Whole-bout cosine similarity between every original stimulus and its reconstruction,
# for the full cell set and for each cell subset (1% .. 75%).
# `recon[:, :-16]` drops the trailing 16 columns of overlap-add padding; presumably this
# matches nsegbins == 16 in selective_reconstruction -- TODO confirm.

stims_plot = kozlov_stim_responses[0][0]
stims_plot_shuff = kozlov_stim_responses[1][0]


def _overall_cosine(class_recons):
    """Return (true_scores, shuffled_scores): one rec_compare score per bout of a reconstruction set."""
    true_scores, shuff_scores = [], []
    for orig, recon, orig_shuff, recon_shuff in zip(stims_plot, class_recons[0], stims_plot_shuff, class_recons[1]):
        true_scores.append(rec_compare(orig, recon[:, :-16]))
        shuff_scores.append(rec_compare(orig_shuff, recon_shuff[:, :-16]))
    return true_scores, shuff_scores


true_means, shuff_means = _overall_cosine(class_reconstructions)
true_means_1, shuff_means_1 = _overall_cosine(class_reconstructions_1)
true_means_5, shuff_means_5 = _overall_cosine(class_reconstructions_5)
true_means_10, shuff_means_10 = _overall_cosine(class_reconstructions_10)
true_means_25, shuff_means_25 = _overall_cosine(class_reconstructions_25)
true_means_50, shuff_means_50 = _overall_cosine(class_reconstructions_50)
true_means_75, shuff_means_75 = _overall_cosine(class_reconstructions_75)

means = [true_means_1, true_means_5, true_means_10, true_means_25, true_means_50, true_means_75, true_means,  shuff_means_1, shuff_means_5, shuff_means_10, shuff_means_25, shuff_means_50, shuff_means_75, shuff_means]
labels = ['1% True Means', '5% True Means', '10% True Means', '25% True Means', '50% True Means', '75% True Means', '100% True Means',  '1% Shuffled Means', '5% Shuffled Means', '10% Shuffled Means', '25% Shuffled Means', '50% Shuffled Means', '75% Shuffled Means', '100% Shuffled Means']
fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot = axs.boxplot(means, patch_artist=True)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticklabels(labels, rotation=60)
# Only the seven true-response boxes are coloured; shuffled boxes keep the default face.
colors = ['lightgreen'] * 7
for patch, color in zip(bplot['boxes'], colors):
    patch.set_facecolor(color)
axs.set_ylabel('Cosine Similarity')
axs.set_title('Overall Cosine Difference Values', fontsize='large')
plt.show()


In [None]:
# Whole-bout cosine similarity per cell subset, true (x = 1..7) next to shuffled (x = 8..14).
means = [true_means_1, true_means_5, true_means_10, true_means_25, true_means_50, true_means_75, true_means]
means2 = [shuff_means_1, shuff_means_5, shuff_means_10, shuff_means_25, shuff_means_50, shuff_means_75, shuff_means]

subset_labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']
labels = subset_labels + subset_labels
true_positions = np.arange(1, len(means) + 1)
shuff_positions = true_positions + len(means)

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot = axs.boxplot(means, positions=true_positions, patch_artist=True, boxprops=dict(facecolor="lightgreen"))
bplot2 = axs.boxplot(means2, positions=shuff_positions, patch_artist=True)
axs.set_yticks(np.arange(0, 2, step=0.2))
# Explicit ticks: one per box, so the 14 labels line up regardless of matplotlib version.
axs.set_xticks(np.concatenate([true_positions, shuff_positions]))
axs.set_xticklabels(labels, rotation=20)
axs.legend([bplot["boxes"][0], bplot2["boxes"][0]], ['True', 'Shuffled'], loc='lower right')
axs.set_ylabel('Cosine Similarity')
axs.set_title('Overall Cosine Difference Values', fontsize='large')
plt.show()

In [None]:
# Whole-bout cosine similarity of the true responses, one light-green box per cell subset.
means = [
    true_means_1,
    true_means_5,
    true_means_10,
    true_means_25,
    true_means_50,
    true_means_75,
    true_means,
]
labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot = axs.boxplot(means, patch_artist=True)
axs.set(yticks=np.arange(0, 2, step=0.2), ylabel='Cosine Similarity')
axs.set_xticklabels(labels, rotation=20)
colors = ['lightgreen'] * 7
for box_patch, face in zip(bplot['boxes'], colors):
    box_patch.set_facecolor(face)
axs.set_title('Overall Cosine Difference Values', fontsize='large')
plt.show()

## Removing Sets of Cells at Random, Then Finding Reconstruction Quality

In [None]:
def selective_reconstruction(output_file_name_rec, spectra_array, stim_response_segments_reshape=stim_response_segments_reshape, mne_results=mne_results, gpu_num=0):
    """Reconstruct stimulus segments using only a selected subset of the fitted MNE cells.

    For every ``(stim, resp)`` class in ``stim_response_segments_reshape`` (zipped with the
    matching entry of ``mne_results``):

    * each fit vector ``x`` is unpacked as ``a = x[0]``, ``h = x[1:n_dim+1]`` and
      ``J = x[n_dim+1:]`` reshaped to ``(n_dim, n_dim)``, stacked along the last axis;
    * ``spectra_array`` selects entries along that last axis (one entry per element of
      ``mne_res``, presumably one per cell) and the matching columns of the responses;
    * responses are trial-averaged and binarized with ``MNE2`` helpers and passed to
      ``rct.manopt_MLE`` together with ``radius/2``;
    * per-bin reconstructions are overlap-added into a ``(16, n_bins + nsegbins)`` array and
      each bin is also scored against ``single_stim[idx, :]`` with ``rec_compare``.

    Args:
        output_file_name_rec: file name, joined onto the global ``exp_path``, of the pickle
            that receives ``{'class_reconstructions': ..., 'class_comparisons': ...}``.
        spectra_array: selector for the subset of fits/response columns. In this notebook the
            callers pass boolean masks of length ``len(spectra)``.
        stim_response_segments_reshape: iterable of ``(stim, resp)`` pairs, one per class.
            NOTE: the default is bound to the global value when this cell is executed.
        mne_results: per-class lists of MNE fit vectors, zipped with the classes above.
            The default is likewise bound at definition time.
        gpu_num: index passed to ``tf.device('/GPU:<n>')``.

    Returns:
        ``(class_reconstructions, class_comparisons)``: for each class, a list with one
        overlap-added reconstruction per stimulus and a list with one list of per-bin
        ``rec_compare`` scores per stimulus. Later cells treat class index 0 as the true
        responses and index 1 as the shuffled responses.

    Relies on notebook globals not passed in: ``tf``, ``n_dim``, ``radius``, ``nsegbins``,
    ``rec_compare``, ``MNE2`` and ``exp_path``. Writes the pickle as a side effect.
    """
    from oe_acute import reconstruct as rct

    #output_file_name_rec = 'stim_reconstructions_results_20210718_left_all_25.pkl'
    #mne_results_limited = [mne_results[0][spectra_array_thresholded],mne_results[1][spectra_array_thresholded]]

    # All reconstruction work runs under the requested GPU device.
    with tf.device('/GPU:{}'.format(gpu_num)):

        class_reconstructions_output_name = []
        class_comparisons_output_name = []
        for (stim, resp), mne_res in zip(stim_response_segments_reshape, mne_results):
            # reconstruct stimuli
            # Unpack each fit vector x = [a | h (n_dim values) | J (n_dim*n_dim values)] and
            # stack the per-fit parameters along the last axis.
            a_np = np.array([x[0] for x in mne_res])[None, ...].astype(np.float32)
            h_np = np.concatenate([x[1:n_dim+1][:, None] for x in mne_res], axis=1).astype(np.float32)
            J_np = np.concatenate([np.reshape(x[(n_dim+1):], (n_dim, n_dim, 1)) for x in mne_res], axis=-1).astype(np.float32)

            # Keep only the fits selected by spectra_array.
            a_np = a_np[:, spectra_array]
            h_np = h_np[:, spectra_array]
            J_np = J_np[:, :, spectra_array]

            print('Reconstructing test stimulus segments...')
            resp_trialavg = [MNE2.trial_average(x) for x in resp]
            resp_binary = [MNE2.resp_to_binary(x) for x in resp_trialavg]

            stim_comparisons = []
            stim_reconstructions = []
            for (stim_resp, single_stim) in zip(resp_binary, stim):
                # reconstructions is n_bins x ndim; same as stims_test
                # The response columns are restricted with the same selector as a/h/J.

                reconstructions = rct.manopt_MLE(stim_resp[:, spectra_array], a_np, h_np, J_np, radius/2, ndim=n_dim)

                n_bins = stim_resp.shape[0]
                res = np.zeros((16, n_bins+nsegbins))
                comp = []

                for idx in range(n_bins):
                    # Overlap-add: bin idx's reconstructed segment (reshaped to 16 x nsegbins)
                    # is summed into columns idx .. idx+nsegbins-1; no normalization is applied.
                    res[:, idx:idx+nsegbins] += np.reshape(reconstructions[idx, :], (16, nsegbins))
                    comp.append(rec_compare(reconstructions[idx, :],single_stim[idx, :]))
                stim_comparisons.append(comp)
                stim_reconstructions.append(res)

            class_reconstructions_output_name.append(stim_reconstructions)
            class_comparisons_output_name.append(stim_comparisons)
            print('Reconstruction complete')

    recon_results_dict = {'class_reconstructions': class_reconstructions_output_name, 
                    'class_comparisons' : class_comparisons_output_name}

#     print(len(stim_reconstructions))
#     print(len(class_reconstructions_output_name))
    # Persist the results so they can be reloaded later (see load_reconstructions).
    print('Saving results....')
    print(output_file_name_rec)
    reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
    with open(reconstruction_result_file, 'wb') as f:
        pickle.dump(recon_results_dict, f)
    print('Done')
    return (class_reconstructions_output_name, class_comparisons_output_name)
    #print(comp)

In [None]:
# Random cell subsets of the same sizes as the ordered subsets. The thresholded cell
# indices are shuffled once with a fixed global seed; each subset takes a prefix of that
# shuffled order, so smaller subsets are nested inside larger ones.
np.random.seed(9393)
B = np.nonzero(spectra_array_thresholded == True)[0]
random_thresholded_indexes = np.random.permutation(B)


def _random_subset_mask(n_keep):
    """Boolean mask over all cells, True for the first `n_keep` shuffled thresholded cells."""
    subset_mask = np.zeros(len(spectra), dtype=bool)
    subset_mask[random_thresholded_indexes[:n_keep]] = True
    return subset_mask


arr_rand_75 = _random_subset_mask(seventyfive)
arr_rand_50 = _random_subset_mask(half)
arr_rand_25 = _random_subset_mask(twentyfive)
arr_rand_10 = _random_subset_mask(ten)
arr_rand_5 = _random_subset_mask(five)
arr_rand_1 = _random_subset_mask(one)

In [None]:
# Reconstruct with each random cell subset, largest first; every run also writes its own pickle.
_rand_rec_file = 'stim_reconstructions_results_20210718_left_all_rand_{}_nosilence.pkl'.format

class_reconstructions_rand_75, class_comparisons_rand_75 = selective_reconstruction(_rand_rec_file(75), arr_rand_75)
class_reconstructions_rand_50, class_comparisons_rand_50 = selective_reconstruction(_rand_rec_file(50), arr_rand_50)
class_reconstructions_rand_25, class_comparisons_rand_25 = selective_reconstruction(_rand_rec_file(25), arr_rand_25)
class_reconstructions_rand_10, class_comparisons_rand_10 = selective_reconstruction(_rand_rec_file(10), arr_rand_10)
class_reconstructions_rand_5, class_comparisons_rand_5 = selective_reconstruction(_rand_rec_file(5), arr_rand_5)
class_reconstructions_rand_1, class_comparisons_rand_1 = selective_reconstruction(_rand_rec_file(1), arr_rand_1)

In [None]:
#find overall true and shuff means
def find_overall(class_reconstructions, stims_plot = kozlov_stim_responses[0][0], stims_plot_shuff = kozlov_stim_responses[1][0]):
    """Whole-bout cosine similarity between original stimuli and their reconstructions.

    Args:
        class_reconstructions: pair ``(true_recons, shuffled_recons)`` as returned by
            selective_reconstruction / load_reconstructions.
        stims_plot: original stimuli paired with ``class_reconstructions[0]``. Defaults to
            ``kozlov_stim_responses[0][0]`` as bound when this cell is executed.
        stims_plot_shuff: original stimuli paired with ``class_reconstructions[1]``. Defaults
            to ``kozlov_stim_responses[1][0]``.

    Returns:
        ``(true_means, shuff_means)``: one ``rec_compare`` score per bout.
    """
    # The stimuli passed by the caller are used as given (previously they were overwritten
    # with the globals, so explicit arguments were silently ignored).
    recons_plot = class_reconstructions[0]
    recons_plot_shuff = class_reconstructions[1]
    true_means = []
    shuff_means = []
    for orig, recon, orig_shuff, recon_shuff in zip(stims_plot, recons_plot, stims_plot_shuff, recons_plot_shuff):
        # Drop the trailing 16 overlap-add columns so recon lines up with orig
        # (presumably nsegbins == 16 -- TODO confirm).
        true_means.append(rec_compare(orig, recon[:, :-16]))
        shuff_means.append(rec_compare(orig_shuff, recon_shuff[:, :-16]))
    return true_means, shuff_means

#true_means_rand_1, shuff_means_rand_1 = find_overall(class_reconstructions_rand_1)
    

def find_binned(class_comparisons):
    """Collapse each bout's per-bin comparison scores to a single mean.

    ``class_comparisons[0]`` holds the true-response bouts and ``class_comparisons[1]`` the
    shuffled ones; bouts are paired with ``zip``, so the shorter list sets the count.

    Returns:
        ``(true_bout_means, shuff_bout_means)``: lists with one ``mean`` per bout.
    """
    bout_pairs = list(zip(class_comparisons[0], class_comparisons[1]))
    true_bout_means = [mean(true_bins) for true_bins, _ in bout_pairs]
    shuff_bout_means = [mean(shuff_bins) for _, shuff_bins in bout_pairs]
    return true_bout_means, shuff_bout_means
    
#comparison_true_means, comparison_shuff_means = find_binned(class_comparisons)


In [None]:
# Sanity check: number of classes in the rand_1 comparisons (later cells index [0] as true and [1] as shuffled).
print(len(class_comparisons_rand_1))

In [None]:
# Whole-bout cosine similarity for ordered and random cell subsets, each with its
# shuffled-response control. The four groups are laid out one after another
# (x = 1..7, 8..14, 15..21, 22..28) instead of all on top of each other at x = 1..7.
true_means_rand_75, shuff_means_rand_75 = find_overall(class_reconstructions_rand_75)
true_means_rand_50, shuff_means_rand_50 = find_overall(class_reconstructions_rand_50)
true_means_rand_25, shuff_means_rand_25 = find_overall(class_reconstructions_rand_25)
true_means_rand_10, shuff_means_rand_10 = find_overall(class_reconstructions_rand_10)
true_means_rand_5, shuff_means_rand_5 = find_overall(class_reconstructions_rand_5)
true_means_rand_1, shuff_means_rand_1 = find_overall(class_reconstructions_rand_1)

means = [true_means_1, true_means_5, true_means_10, true_means_25, true_means_50, true_means_75, true_means]
means2 = [shuff_means_1, shuff_means_5, shuff_means_10, shuff_means_25, shuff_means_50, shuff_means_75, shuff_means]

# The 100% entry reuses the full-set result: no random draw is made at 100%.
means_rand = [true_means_rand_1, true_means_rand_5, true_means_rand_10, true_means_rand_25, true_means_rand_50, true_means_rand_75, true_means]
means_rand_shuff = [shuff_means_rand_1, shuff_means_rand_5, shuff_means_rand_10, shuff_means_rand_25, shuff_means_rand_50, shuff_means_rand_75, shuff_means]

subset_labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']
labels = subset_labels * 4
n_subsets = len(means)
group_positions = [np.arange(1, n_subsets + 1) + k * n_subsets for k in range(4)]

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot = axs.boxplot(means, positions=group_positions[0], patch_artist=True, boxprops=dict(facecolor="lightgreen"))
bplot2 = axs.boxplot(means2, positions=group_positions[1], patch_artist=True, boxprops=dict(facecolor="darkgreen"))
bplot3 = axs.boxplot(means_rand, positions=group_positions[2], patch_artist=True, boxprops=dict(facecolor="orchid"))
bplot4 = axs.boxplot(means_rand_shuff, positions=group_positions[3], patch_artist=True, boxprops=dict(facecolor="rebeccapurple"))
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticks(np.concatenate(group_positions))
axs.set_xticklabels(labels, rotation=20)
axs.legend([bplot["boxes"][0], bplot2["boxes"][0], bplot3["boxes"][0], bplot4["boxes"][0]], ['Ordered', 'Shuffled Ordered', 'Random', 'Shuffled Random'], loc='lower right')
axs.set_ylabel('Cosine Similarity')
axs.set_title('Overall Cosine Difference Values - Left', fontsize='large')
plt.show()

In [None]:

# Ordered subsets (x = 1..7) next to random subsets of the same size (x = 8..14).
subset_labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']
labels = subset_labels + subset_labels
ordered_positions = np.arange(1, len(means) + 1)
random_positions = ordered_positions + len(means)

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot = axs.boxplot(means, positions=ordered_positions, patch_artist=True, boxprops=dict(facecolor="lightgreen"))
bplot3 = axs.boxplot(means_rand, positions=random_positions, patch_artist=True, boxprops=dict(facecolor="orchid"))
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticks(np.concatenate([ordered_positions, random_positions]))
axs.set_xticklabels(labels, rotation=20)
axs.legend([bplot["boxes"][0],bplot3["boxes"][0]], ['Ordered','Random'], loc='lower right')
axs.set_ylabel('Cosine Similarity')
axs.set_title('Overall Cosine Difference Values - Left', fontsize='large')
plt.show()

In [None]:
# Random subsets (x = 1..7) next to their shuffled-response controls (x = 8..14).
subset_labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']
labels = subset_labels + subset_labels
random_positions = np.arange(1, len(means_rand) + 1)
random_shuff_positions = random_positions + len(means_rand)

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot3 = axs.boxplot(means_rand, positions=random_positions, patch_artist=True, boxprops=dict(facecolor="orchid"))
bplot4 = axs.boxplot(means_rand_shuff, positions=random_shuff_positions, patch_artist=True, boxprops=dict(facecolor="rebeccapurple"))
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticks(np.concatenate([random_positions, random_shuff_positions]))
axs.set_xticklabels(labels, rotation=20)
axs.legend([bplot3["boxes"][0], bplot4["boxes"][0]], ['Random', 'Shuffled Random'], loc='lower right')
axs.set_ylabel('Cosine Similarity')
axs.set_title('Overall Cosine Difference Values - Left', fontsize='large')
plt.show()

In [None]:

# Random-subset results only: one orchid box per subset size.
labels = ['1%', '5%', '10%', '25%', '50%', '75%', '100%']
fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot3 = axs.boxplot(
    means_rand,
    patch_artist=True,
    boxprops=dict(facecolor="orchid"),
)
axs.set(yticks=np.arange(0, 2, step=0.2), ylabel='Cosine Similarity')
axs.set_xticklabels(labels, rotation=20)
axs.set_title('Overall Cosine Difference Values - Left', fontsize='large')
plt.show()

## Loading all of the sets of reconstructions above

In [None]:
def load_reconstructions(output_file_name_rec, exp_path=exp_path):
    """Load a reconstruction-result pickle such as those written by selective_reconstruction.

    Args:
        output_file_name_rec: file name relative to ``exp_path``.
        exp_path: directory holding the result files. Defaults to the notebook's ``exp_path``
            as bound when this cell is executed.

    Returns:
        ``(class_reconstructions, class_comparisons)`` taken from the pickled dict's
        ``'class_reconstructions'`` and ``'class_comparisons'`` keys.

    Raises:
        FileNotFoundError: if the file does not exist.
        KeyError: if the pickled dict lacks either key.
    """
    # No reconstruction code is needed here: pickle imports whatever modules the stored
    # objects require, so the former unused `oe_acute.reconstruct` import is not repeated.
    # pickle.load can execute arbitrary code -- only load result files produced by this pipeline.
    reconstruction_result_file = os.path.join(exp_path, output_file_name_rec)
    with open(reconstruction_result_file, 'rb') as f:
        recon_results_dict = pickle.load(f)
    class_reconstructions_load = recon_results_dict['class_reconstructions']
    class_comparisons_load = recon_results_dict['class_comparisons']
    print(f"Loaded {output_file_name_rec}")
    return class_reconstructions_load, class_comparisons_load

In [None]:
# Reload every reconstruction set from disk: full cell set, ordered subsets, random subsets.
_left_rec_file = 'stim_reconstructions_results_20210718_left_all{}_nosilence.pkl'.format

# Full cell set
class_reconstructions, class_comparisons = load_reconstructions(_left_rec_file(''))

# Ordered subsets
class_reconstructions_75, class_comparisons_75 = load_reconstructions(_left_rec_file('_75'))
class_reconstructions_50, class_comparisons_50 = load_reconstructions(_left_rec_file('_50'))
class_reconstructions_25, class_comparisons_25 = load_reconstructions(_left_rec_file('_25'))
class_reconstructions_10, class_comparisons_10 = load_reconstructions(_left_rec_file('_10'))
class_reconstructions_5, class_comparisons_5 = load_reconstructions(_left_rec_file('_5'))
class_reconstructions_1, class_comparisons_1 = load_reconstructions(_left_rec_file('_1'))

# Random subsets of the same sizes
class_reconstructions_rand_75, class_comparisons_rand_75 = load_reconstructions(_left_rec_file('_rand_75'))
class_reconstructions_rand_50, class_comparisons_rand_50 = load_reconstructions(_left_rec_file('_rand_50'))
class_reconstructions_rand_25, class_comparisons_rand_25 = load_reconstructions(_left_rec_file('_rand_25'))
class_reconstructions_rand_10, class_comparisons_rand_10 = load_reconstructions(_left_rec_file('_rand_10'))
class_reconstructions_rand_5, class_comparisons_rand_5 = load_reconstructions(_left_rec_file('_rand_5'))
class_reconstructions_rand_1, class_comparisons_rand_1 = load_reconstructions(_left_rec_file('_rand_1'))

In [None]:
# Result sets grouped for iteration, largest subset first (75% -> 1%).
comparisons_list = [
    class_comparisons_75, class_comparisons_50, class_comparisons_25,
    class_comparisons_10, class_comparisons_5, class_comparisons_1,
]
reconstructions_list = [
    class_reconstructions_75, class_reconstructions_50, class_reconstructions_25,
    class_reconstructions_10, class_reconstructions_5, class_reconstructions_1,
]
comparisons_list_rand = [
    class_comparisons_rand_75, class_comparisons_rand_50, class_comparisons_rand_25,
    class_comparisons_rand_10, class_comparisons_rand_5, class_comparisons_rand_1,
]
reconstructions_list_rand = [
    class_reconstructions_rand_75, class_reconstructions_rand_50, class_reconstructions_rand_25,
    class_reconstructions_rand_10, class_reconstructions_rand_5, class_reconstructions_rand_1,
]

## Finding Temporal Representations of Reconstruction Quality

In [None]:
import copy

# Per-bin cosine similarity over time for every bout (true and shuffled responses) plus the
# bin-wise average across bouts. Bouts differ in length, so the per-bout lists are stacked
# into a NaN-padded float array and averaged with np.nanmean (the padding is ignored).
# The source lists are only read, never padded in place.
comparisons = class_comparisons[0]
comparisons_shuff = class_comparisons[1]


def _nan_padded(rows):
    """Stack ragged per-bout score lists into a float64 array, right-padding short rows with NaN."""
    width = max(len(row) for row in rows)
    padded = np.full((len(rows), width), np.nan)
    for k, row in enumerate(rows):
        padded[k, :len(row)] = row
    return padded


# np.full(..., np.nan) yields float64 without the np.float alias (removed in NumPy 1.24).
balanced_array = _nan_padded(comparisons)
avg = np.nanmean(balanced_array, axis=0)
balanced_array_shuff = _nan_padded(comparisons_shuff)
avg_shuff = np.nanmean(balanced_array_shuff, axis=0)

# Marker positions on the average trace: the last bin of each bout (len(row) - 1). This keeps
# every index inside `avg`, even when several bouts share the maximum length.
row_lengths = [len(row) for row in comparisons]
max_length = max(row_lengths)
bout_end_bins = [length - 1 for length in row_lengths if length > 0]


fig, axs = plt.subplots(3, 1, figsize=(16,9), sharex=True, sharey=True)

count = 0
for comparison, comparison_shuff in zip(comparisons, comparisons_shuff):
    count += 1
    axs[1].plot(comparison_shuff, linewidth=0.5, color = 'black')
    axs[0].plot(comparison, linewidth=0.5)
axs[0].plot(avg, linewidth=1, color = 'k', markevery=bout_end_bins, marker="d", mfc='white', mec='white')
axs[0].set_yticks(np.arange(0, 1.5, step=0.2))
axs[0].set_title(f'This is a test of all the reconstructions over time', fontsize=15)
axs[0].tick_params(labelbottom=False, direction='in', bottom=False)
axs[1].plot(avg_shuff, linewidth=1, color = 'white')
axs[2].plot(avg, linewidth=1, color = 'orchid', markevery=bout_end_bins, marker="d", mfc='k', mec='k',  label = 'real')
axs[2].plot(avg_shuff, linewidth=1, color = 'lightblue', label = 'shuffled')
axs[2].legend(loc = 'lower left', fontsize = 'large')
axs[1].set_yticks(np.arange(0, 2, step=0.2))
axs[1].set_title(f'This is a test shuff', fontsize=15)
axs[1].tick_params(labelbottom=False, direction='in', bottom=False)

plt.tight_layout()
plt.show()
# Number of (true, shuffled) bout pairs plotted.
print(count)

## Now do the same as above for segments of song

In [None]:
# Cosine similarity of the beginning, middle and end thirds of every bout, true vs. shuffled.
recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
recons_plot_shuff = class_reconstructions[1]
stims_plot_shuff = kozlov_stim_responses[1][0]
comparisons_shuff = class_comparisons[1]
true_means1 = []
shuff_means1 = []
true_means2 = []
shuff_means2 = []
true_means3 = []
shuff_means3 = []

# The comparison lists are zipped only so the iteration count matches the other inputs.
for orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff in zip(stims_plot, recons_plot, comparisons, stims_plot_shuff, recons_plot_shuff, comparisons_shuff):
    third = round((orig.shape[1])/3)
    cut1, cut2 = third, third + third

    true_means1.append(rec_compare(orig[:, :cut1], recon[:, :cut1]))
    shuff_means1.append(rec_compare(orig_shuff[:, :cut1], recon_shuff[:, :cut1]))

    true_means2.append(rec_compare(orig[:, cut1:cut2], recon[:, cut1:cut2]))
    shuff_means2.append(rec_compare(orig_shuff[:, cut1:cut2], recon_shuff[:, cut1:cut2]))

    # The last third runs to the end of the stimulus; the reconstruction's trailing 16
    # overlap-add columns are dropped so both slices cover the same bins.
    true_means3.append(rec_compare(orig[:, cut2:], recon[:, cut2:-16]))
    shuff_means3.append(rec_compare(orig_shuff[:, cut2:], recon_shuff[:, cut2:-16]))

means = [true_means1, true_means2, true_means3]
means2 = [shuff_means1, shuff_means2, shuff_means3]
labels = ['Beginning', 'Middle', 'End','Beginning', 'Middle', 'End']
# True boxes at x = 1..3, shuffled at x = 4..6 (previously both were drawn at 1..3).
true_positions = [1, 2, 3]
shuff_positions = [4, 5, 6]
fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
bplot = axs.boxplot(means, positions=true_positions, patch_artist=True)
bplot2 = axs.boxplot(means2, positions=shuff_positions, patch_artist=True)
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticks(true_positions + shuff_positions)
axs.set_xticklabels(labels, rotation=20)
# Purples for the true boxes, blues for the shuffled boxes: all six colours are now applied.
colors = ['plum', 'mediumorchid', 'darkorchid', 'lightskyblue','blue','navy']
for patch, color in zip(bplot['boxes'] + bplot2['boxes'], colors):
    patch.set_facecolor(color)
axs.legend([bplot['boxes'][0], bplot2['boxes'][0]], ['True', 'Shuffled'])
axs.set_ylabel('Cosine Difference')
axs.set_title('Overall Cosine Difference Across Segements of Shuffled Responses', fontsize='medium')
plt.show()
        


In [None]:
# Cosine similarity of each tenth of every bout, true vs. shuffled.
recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
recons_plot_shuff = class_reconstructions[1]
stims_plot_shuff = kozlov_stim_responses[1][0]
comparisons_shuff = class_comparisons[1]

n_parts = 10
list_shuff_means = [[] for _ in range(n_parts)]
list_true_means = [[] for _ in range(n_parts)]

for orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff in zip(stims_plot, recons_plot, comparisons, stims_plot_shuff, recons_plot_shuff, comparisons_shuff):
    # n_parts + 1 evenly spaced boundaries that always span the whole bout. The previous
    # range(0, T, round(T/10)) variant could drop the tail (rounding down), run out of
    # boundaries (rounding up -> IndexError) or use a zero step for very short bouts.
    tenths = np.round(np.linspace(0, orig.shape[1], n_parts + 1)).astype(int)
    for i in range(n_parts):
        lo, hi = tenths[i], tenths[i + 1]
        list_true_means[i].append(rec_compare(orig[:, lo:hi], recon[:, lo:hi]))
        list_shuff_means[i].append(rec_compare(orig_shuff[:, lo:hi], recon_shuff[:, lo:hi]))

means = list(list_true_means)
means2 = list(list_shuff_means)
# True and shuffled boxes sit side by side around each tenth instead of on top of each other.
centers = np.arange(1, n_parts + 1)
fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
# NOTE(review): face colours `c` and `d` are globals from an earlier cell (not visible here).
bplot = axs.boxplot(means, positions=centers - 0.2, widths=0.35, patch_artist=True, boxprops=dict(facecolor=c, color='k'),medianprops=dict(color='k'))
bplot2 = axs.boxplot(means2, positions=centers + 0.2, widths=0.35, patch_artist=True, boxprops=dict(facecolor=d, color='k'),medianprops=dict(color='k'))
axs.set_yticks(np.arange(0, 2, step=0.2))
axs.set_xticks(centers)
axs.set_xticklabels([str(k) for k in centers])
axs.legend([bplot["boxes"][0], bplot2["boxes"][0]], ["True", "Shuffled"])
axs.set_ylabel('Cosine Difference')
axs.set_xlabel('Tenth of Song')
axs.set_title('Overall Cosine Difference Of Bouts When Spilt Into Tenths', fontsize='medium')
plt.show()
        

In [None]:
import matplotlib.patches as mpatches

# Split every bout into `share` contiguous segments and score each segment separately;
# avg_parts / avg_shuff_parts are the across-bout averages per segment (full cell set).
recons_plot = class_reconstructions[0]
stims_plot = kozlov_stim_responses[0][0]
comparisons = class_comparisons[0]
recons_plot_shuff = class_reconstructions[1]
stims_plot_shuff = kozlov_stim_responses[1][0]
comparisons_shuff = class_comparisons[1]

share = 20
list_shuff_means = [[] for _ in range(share)]
list_true_means = [[] for _ in range(share)]


for orig, recon, comparison, orig_shuff, recon_shuff, comparison_shuff in zip(stims_plot, recons_plot, comparisons, stims_plot_shuff, recons_plot_shuff, comparisons_shuff):
    # share + 1 evenly spaced boundaries covering the whole bout; a fixed
    # round(T/share) step could drop the tail, overrun the boundary list or be zero.
    parts = np.round(np.linspace(0, orig.shape[1], share + 1)).astype(int)
    for i in range(share):
        lo, hi = parts[i], parts[i + 1]
        list_true_means[i].append(rec_compare(orig[:, lo:hi], recon[:, lo:hi]))
        list_shuff_means[i].append(rec_compare(orig_shuff[:, lo:hi], recon_shuff[:, lo:hi]))

# Rows are segments, columns are bouts: average across bouts.
array_true_means = np.array(list_true_means)
avg_parts = np.average(array_true_means, axis=1)

array_shuff_means = np.array(list_shuff_means)
avg_shuff_parts = np.average(array_shuff_means, axis=1)

In [None]:
# One thin line per bout across song segments (true in orchid, shuffled in light blue), with
# the across-bout averages from the previous cell drawn in black on top.
fig, axs = plt.subplots(1, 1, sharex=True, figsize=(10,7))
line1 = axs.plot(list_true_means, linewidth=1, color='orchid', label='True')
line2 = axs.plot(list_shuff_means, linewidth=1, color='lightblue', label='Shuffled')
avg_line1, avg_line2 = [axs.plot(avg_trace, linewidth=1.5, color='k') for avg_trace in (avg_parts, avg_shuff_parts)]
axs.set(yticks=np.arange(0, 2.1, step=0.2), xticks=np.arange(0, share, step=2),
        ylabel='Cosine Difference', xlabel='Segment of Song')
# Proxy patches give one legend entry per group instead of one per bout line.
True_markers, Shuffled_markers = (
    mpatches.Patch(color=swatch, label=name)
    for swatch, name in (('orchid', 'True'), ('lightblue', 'Shuffled'))
)
axs.legend(handles=[True_markers, Shuffled_markers])
axs.set_title('Overall Cosine Similarity Of Bouts When Spilt', fontsize='medium')
plt.show()
        

In [None]:
# Average per-segment cosine similarity (true responses) for each ordered cell subset,
# overlaid on the full-set curve `avg_parts` from the cell that splits bouts into parts.
# The loop uses its own names, so re-running this cell (or the next one) no longer finds
# `avg_parts` overwritten with the 1% curve while still labelling it '100'.
reconstructions_list = [class_reconstructions_75, class_reconstructions_50,  class_reconstructions_25,  class_reconstructions_10,  class_reconstructions_5,  class_reconstructions_1]

stims_plot_shuff = kozlov_stim_responses[1][0]
stims_plot = kozlov_stim_responses[0][0]

values = [75,50,25,10,5,1]
values_str = ['75','50','25','10','5','1']
share = 20

fig, axs = plt.subplots(1, 1, sharex=True, figsize=(16,12))
plot1 = axs.plot(avg_parts, linewidth=1.5, color = 'orchid', label = '100')
colors = ('plum','purple', 'magenta','mediumpurple','navy','k')
for subset_recons, color, value in zip(reconstructions_list, colors, values_str):
    # Reconstructions are only read, so no copy is needed.
    subset_true_recons = subset_recons[0]
    subset_true = [[] for _ in range(share)]
    for orig, recon in zip(stims_plot, subset_true_recons):
        # share + 1 evenly spaced boundaries covering the whole bout.
        bounds = np.round(np.linspace(0, orig.shape[1], share + 1)).astype(int)
        for k in range(share):
            lo, hi = bounds[k], bounds[k + 1]
            subset_true[k].append(rec_compare(orig[:, lo:hi], recon[:, lo:hi]))
    # Rows are segments, columns are bouts: average across bouts.
    subset_avg = np.average(np.array(subset_true), axis=1)
    axs.plot(subset_avg, linewidth=1.5, color = color, label = value)

axs.set_yticks(np.arange(0, 2.1, step=0.2))
axs.set_xticks(np.arange(0, share, step=2))
axs.set_ylabel('Cosine Difference')
axs.set_xlabel('Segment of Song')
axs.legend()
plt.show()

In [None]:
reconstructions_list = [class_reconstructions_75, class_reconstructions_50,  class_reconstructions_25,  class_reconstructions_10,  class_reconstructions_5,  class_reconstructions_1]

# Per-segment cosine difference (averaged across songs) for each reduced-data
# model level, on the real stimuli and on the shuffled stimuli, overlaid on the
# 100% baseline.
#
# NOTE: `avg_parts` is expected to hold the 100% baseline curve from an earlier
# cell and is intentionally not reassigned here (reassigning it made later
# cells label the 1% curve as '100').
stims_plot_shuff = kozlov_stim_responses[1][0]
stims_plot = kozlov_stim_responses[0][0]

values = [75, 50, 25, 10, 5, 1]
values_str = ['75', '50', '25', '10', '5', '1']

# Number of equal-length segments each song is split into.
share = 20

fig, axs = plt.subplots(1, 1, figsize=(16, 12))
plot1 = axs.plot(avg_parts, linewidth=1.5, color='orchid', label='100')
colors = ('mediumpurple', 'rebeccapurple', 'darkviolet', 'magenta', 'deeppink', 'plum')
colors_shuff = ('darkolivegreen', 'green', 'lightgreen', 'lime', 'olivedrab', 'turquoise')
for recons_pair, color, color_shuff, value in zip(reconstructions_list, colors, colors_shuff, values_str):
    # Copies retained from the original; presumably they protect the stored
    # reconstructions in case rec_compare modifies its inputs.
    recons_plot = copy.deepcopy(recons_pair[0])
    recons_plot_shuff = copy.deepcopy(recons_pair[1])
    list_true_means = [[] for _ in range(share)]
    list_shuff_means = [[] for _ in range(share)]
    for orig, recon, orig_shuff, recon_shuff in zip(stims_plot, recons_plot, stims_plot_shuff, recons_plot_shuff):
        # Exactly `share` contiguous segments covering every column. Stepping by
        # round(n / share) could produce share - 1 segments (IndexError) or
        # share + 1 segments (the song's tail silently ignored).
        bounds = np.linspace(0, orig.shape[1], share + 1).round().astype(int)
        for seg, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
            list_true_means[seg].append(rec_compare(orig[:, start:stop], recon[:, start:stop]))
            list_shuff_means[seg].append(rec_compare(orig_shuff[:, start:stop], recon_shuff[:, start:stop]))

    # Rows are segments, columns are songs -> mean over songs per segment.
    array_true_means = np.array(list_true_means)
    avg_true_parts = np.average(array_true_means, axis=1)
    array_shuff_means = np.array(list_shuff_means)
    avg_shuff_parts = np.average(array_shuff_means, axis=1)
    axs.plot(avg_true_parts, linewidth=1.5, color=color, label=value)
    axs.plot(avg_shuff_parts, linewidth=1.5, color=color_shuff, label=f'{value} shuff')

axs.set_yticks(np.arange(0, 2.1, step=0.2))
axs.set_xticks(np.arange(0, share, step=2))
axs.set_ylabel('Cosine Difference')
axs.set_xlabel('Segment of Song')
axs.legend();

In [None]:

# Per-segment cosine difference (averaged across songs) for models trained on
# random data subsets vs. the corresponding non-random subsets, overlaid on the
# 100% baseline.
#
# NOTE: `avg_parts` is expected to hold the 100% baseline curve from an earlier
# cell and is intentionally not reassigned here.
# NOTE(review): `reconstructions_list_rand` is only assigned in commented-out
# lines in the cells above; it must be defined by an earlier cell — confirm.
stims_plot_shuff = kozlov_stim_responses[1][0]
stims_plot = kozlov_stim_responses[0][0]

values = [75, 50, 25, 10, 5, 1]
values_str = ['75', '50', '25', '10', '5', '1']

# Number of equal-length segments each song is split into.
share = 20

fig, axs = plt.subplots(1, 1, figsize=(16, 12))
plot1 = axs.plot(avg_parts, linewidth=1.5, color='orchid', label='100')
colors = ('mediumpurple', 'rebeccapurple', 'darkviolet', 'magenta', 'deeppink', 'plum')
colors_rand = ('k', 'dimgray', 'grey', 'darkgrey', 'silver', 'tan')

# Random-subset models first (greys), then the matching non-random models
# (purples). One shared loop replaces two copy-pasted ones, and this cell no
# longer depends on `colors_shuff` from another cell.
plot_groups = (
    (reconstructions_list_rand, colors_rand, '{} random'),
    (reconstructions_list, colors, '{}'),
)
for recon_sets, group_colors, label_fmt in plot_groups:
    for recons_pair, color, value in zip(recon_sets, group_colors, values_str):
        # Copies retained from the original; presumably they protect the stored
        # reconstructions in case rec_compare modifies its inputs.
        recons_plot = copy.deepcopy(recons_pair[0])
        recons_plot_shuff = copy.deepcopy(recons_pair[1])
        list_true_means = [[] for _ in range(share)]
        list_shuff_means = [[] for _ in range(share)]
        for orig, recon, orig_shuff, recon_shuff in zip(stims_plot, recons_plot, stims_plot_shuff, recons_plot_shuff):
            # Exactly `share` contiguous segments covering every column. Stepping
            # by round(n / share) could produce share - 1 segments (IndexError)
            # or share + 1 segments (the song's tail silently ignored).
            bounds = np.linspace(0, orig.shape[1], share + 1).round().astype(int)
            for seg, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
                list_true_means[seg].append(rec_compare(orig[:, start:stop], recon[:, start:stop]))
                list_shuff_means[seg].append(rec_compare(orig_shuff[:, start:stop], recon_shuff[:, start:stop]))

        # Rows are segments, columns are songs -> mean over songs per segment.
        array_true_means = np.array(list_true_means)
        avg_true_parts = np.average(array_true_means, axis=1)
        array_shuff_means = np.array(list_shuff_means)
        avg_shuff_parts = np.average(array_shuff_means, axis=1)
        axs.plot(avg_true_parts, linewidth=1.5, color=color, label=label_fmt.format(value))

axs.set_yticks(np.arange(0, 2.1, step=0.2))
axs.set_xticks(np.arange(0, share, step=2))
axs.set_ylabel('Cosine Difference')
axs.set_xlabel('Segment of Song')
axs.legend();