In [None]:
import os
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import glob
import time
import pickle
from scipy import signal
from scipy.spatial.distance import pdist,squareform
from sklearn.metrics import explained_variance_score
from sklearn.linear_model import Ridge, RidgeCV

import os
# Walk the Kaggle input directory and print the full path of every file found
for folder_path, _, files_in_folder in os.walk('/kaggle/input'):
    for file_name in files_in_folder:
        full_path = os.path.join(folder_path, file_name)
        print(full_path)

In [None]:
# General options (presumably for figure export and model naming - not used in the cells shown here)
save_figures = False
file_ending = '.png'
model_string = 'NMDA'

# Dataset root and its two sub-folders
dataset_folder = '/kaggle/input/single-neurons-as-deep-nets-nmda-test-data'
morphology_folder = os.path.join(dataset_folder, 'Morphology')
test_data_folder = os.path.join(dataset_folder, 'Data_test')

# Morphology pickle and the sorted list of simulation files matching the pattern
morphology_filename = os.path.join(morphology_folder, 'morphology_dict.pickle')
test_files_pattern = os.path.join(test_data_folder, '*_128_simulationRuns*_6_secDuration_*')
test_files = sorted(glob.glob(test_files_pattern))

# Short report of what was located
separator_line = '-----------------------------------------------'
print(separator_line)
print('finding files: morphology and test data')
print(separator_line)
print('morphology found     : "%s"' % (morphology_filename.split('/')[-1]))
print('number of test files is %d' % (len(test_files)))
print(separator_line)

# Helper functions

In [None]:


def dict2bin(row_inds_spike_times_map, num_segments, sim_duration_ms):
    """Convert a {row index: spike times} mapping into a dense binary spike raster.

    Parameters
    ----------
    row_inds_spike_times_map : dict
        Maps a row index to an iterable of spike times, each used directly
        as a column index of the raster.
    num_segments : int
        Number of rows of the output matrix.
    sim_duration_ms : int
        Number of columns of the output matrix.

    Returns
    -------
    ndarray of bool, shape (num_segments, sim_duration_ms)
        True wherever a spike is listed for that row and time, False elsewhere.
    """
    spike_raster = np.zeros((num_segments, sim_duration_ms), dtype='bool')

    for segment_row, segment_spike_times in row_inds_spike_times_map.items():
        for time_bin in segment_spike_times:
            spike_raster[segment_row, time_bin] = True

    return spike_raster


def parse_sim_experiment_file(sim_experiment_file):
    """Load one pickled simulation experiment and return its inputs and somatic outputs.

    Parameters
    ----------
    sim_experiment_file : str
        Path to a pickled experiment dictionary holding 'Params' and 'Results' entries.

    Returns
    -------
    X : ndarray of bool, shape (2 * num_segments, sim_duration_ms, num_simulations)
        Binary input spike rasters; excitatory rows first, inhibitory rows after them.
    y_spike : ndarray of float, shape (sim_duration_ms, num_simulations)
        1.0 at the time bins of output spikes, 0.0 elsewhere.
    y_soma : ndarray of float, shape (sim_duration_ms, num_simulations)
        Low resolution somatic voltage trace of every simulation.
    """
    print('-----------------------------------------------------------------')
    # basename handles the native path separator (splitting on a backslash only worked on Windows)
    print("loading file: '" + os.path.basename(sim_experiment_file) + "'")
    loading_start_time = time.time()

    # NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
    # The context manager makes sure the file handle is closed after loading.
    with open(sim_experiment_file, "rb") as experiment_file:
        experiment_dict = pickle.load(experiment_file, encoding='latin1')

    simulation_dicts = experiment_dict['Results']['listOfSingleSimulationDicts']

    # gather params
    num_simulations = len(simulation_dicts)
    num_segments    = len(experiment_dict['Params']['allSegmentsType'])
    # cast to int so the value is a valid array dimension even if the duration is stored as a float
    sim_duration_ms = int(round(experiment_dict['Params']['totalSimDurationInSec'] * 1000))
    # one excitatory and one inhibitory synapse per segment
    num_ex_synapses  = num_segments
    num_inh_synapses = num_segments
    num_synapses = num_ex_synapses + num_inh_synapses

    # collect X, y_spike, y_soma
    X = np.zeros((num_synapses,sim_duration_ms,num_simulations), dtype='bool')
    y_spike = np.zeros((sim_duration_ms,num_simulations))
    y_soma  = np.zeros((sim_duration_ms,num_simulations))
    for k, sim_dict in enumerate(simulation_dicts):
        X_ex  = dict2bin(sim_dict['exInputSpikeTimes'] , num_segments, sim_duration_ms)
        X_inh = dict2bin(sim_dict['inhInputSpikeTimes'], num_segments, sim_duration_ms)
        X[:,:,k] = np.vstack((X_ex,X_inh))
        # shift spike times by half a unit before truncating them to integer bin indices
        spike_times = (sim_dict['outputSpikeTimes'].astype(float) - 0.5).astype(int)
        y_spike[spike_times,k] = 1.0
        y_soma[:,k] = sim_dict['somaVoltageLowRes']

    loading_duration_sec = time.time() - loading_start_time
    print('loading took %.3f seconds' %(loading_duration_sec))
    print('-----------------------------------------------------------------')

    return X, y_spike, y_soma


def parse_multiple_sim_experiment_files(sim_experiment_files):
    """Parse several experiment files and concatenate them along the simulation axis.

    Parameters
    ----------
    sim_experiment_files : iterable of str
        Paths accepted by `parse_sim_experiment_file`. Must not be empty.

    Returns
    -------
    X : ndarray
        Input spike rasters of all files, joined along the last (third) axis.
    y_spike : ndarray
        Output spike indicators of all files, joined along the second axis.
    y_soma : ndarray
        Somatic voltage traces of all files, joined along the second axis.

    Raises
    ------
    ValueError
        If `sim_experiment_files` is empty.
    """
    sim_experiment_files = list(sim_experiment_files)
    if len(sim_experiment_files) == 0:
        raise ValueError('sim_experiment_files must contain at least one file')

    # collect the per-file arrays first and concatenate once at the end,
    # instead of re-copying the accumulated arrays after every file
    X_parts, y_spike_parts, y_soma_parts = [], [], []
    for sim_experiment_file in sim_experiment_files:
        X_curr, y_spike_curr, y_soma_curr = parse_sim_experiment_file(sim_experiment_file)
        X_parts.append(X_curr)
        y_spike_parts.append(y_spike_curr)
        y_soma_parts.append(y_soma_curr)

    X       = np.concatenate(X_parts, axis=2)
    y_spike = np.concatenate(y_spike_parts, axis=1)
    y_soma  = np.concatenate(y_soma_parts, axis=1)

    return X, y_spike, y_soma


def parse_sim_experiment_file_with_DVTs(sim_experiment_file, return_high_res=False):
    """Load one experiment file together with its dendritic voltage traces (DVTs).

    Parameters
    ----------
    sim_experiment_file : str
        Path to a pickled experiment dictionary holding 'Params' and 'Results' entries.
    return_high_res : bool, optional
        When True, the high resolution time vector and soma/nexus voltages are
        appended to the returned tuple.

    Returns
    -------
    X_spikes : ndarray
        Input spike rasters as returned by `parse_sim_experiment_file`.
    y_DVTs : ndarray of float16, shape (num_segments, sim_duration_ms, num_simulations)
        Low resolution dendritic voltage of every segment.
    t_LR : low resolution recording time, taken from the first simulation.
    y_soma_LR, y_nexus_LR : ndarray, shape (sim_duration_ms, num_simulations)
        Low resolution soma / nexus voltages. The soma trace is overwritten
        with 30 at the output spike time bins.
    t_HR, y_soma_HR, y_nexus_HR : only when `return_high_res` is True
        High resolution recording time (first simulation) and voltages.
    """
    # NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
    # The context manager makes sure the file handle is closed after loading.
    with open(sim_experiment_file, "rb") as experiment_file:
        experiment_dict = pickle.load(experiment_file, encoding='latin1')

    # NOTE: this call reads and unpickles the same file a second time; it is kept
    # because it also prints the loading progress messages
    X_spikes, _, _ = parse_sim_experiment_file(sim_experiment_file)

    simulation_dicts = experiment_dict['Results']['listOfSingleSimulationDicts']

    # gather params
    num_simulations = len(simulation_dicts)
    num_segments    = len(experiment_dict['Params']['allSegmentsType'])
    # cast to int so the value is a valid array dimension even if the duration is stored as a float
    sim_duration_ms = int(round(1000 * experiment_dict['Params']['totalSimDurationInSec']))

    # the first simulation supplies the time vectors and the high resolution trace lengths
    sim_dict = simulation_dicts[0]

    t_LR = sim_dict['recordingTimeLowRes']
    t_HR = sim_dict['recordingTimeHighRes']
    y_soma_LR  = np.zeros((sim_duration_ms,num_simulations))
    y_nexus_LR = np.zeros((sim_duration_ms,num_simulations))
    y_soma_HR  = np.zeros((sim_dict['somaVoltageHighRes'].shape[0],num_simulations))
    y_nexus_HR = np.zeros((sim_dict['nexusVoltageHighRes'].shape[0],num_simulations))

    # stored as float16 (half the size of float32)
    y_DVTs  = np.zeros((num_segments,sim_duration_ms,num_simulations), dtype=np.float16)

    # go over all simulations in the experiment and collect their results
    for k, sim_dict in enumerate(simulation_dicts):
        y_nexus_LR[:,k] = sim_dict['nexusVoltageLowRes']
        y_soma_LR[:,k] = sim_dict['somaVoltageLowRes']
        y_nexus_HR[:,k] = sim_dict['nexusVoltageHighRes']
        y_soma_HR[:,k] = sim_dict['somaVoltageHighRes']
        y_DVTs[:,:,k] = sim_dict['dendriticVoltagesLowRes']

        output_spike_times = np.int32(sim_dict['outputSpikeTimes'])
        # fix "voltage spikes" in low res
        y_soma_LR[output_spike_times,k] = 30

    if return_high_res:
        return X_spikes, y_DVTs, t_LR, y_soma_LR, y_nexus_LR, t_HR, y_soma_HR, y_nexus_HR
    else:
        return X_spikes, y_DVTs, t_LR, y_soma_LR, y_nexus_LR


In [None]:
##%% load morphology

# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
# The context manager makes sure the file handle is closed after loading.
with open(morphology_filename, "rb") as morphology_file:
    morphology_dict = pickle.load(morphology_file, encoding='latin1')

# per-section and per-segment properties stored in the morphology dictionary
allSectionsLength                  = morphology_dict['all_sections_length']
allSections_DistFromSoma           = morphology_dict['all_sections_distance_from_soma']
allSegmentsLength                  = morphology_dict['all_segments_length']
allSegmentsType                    = morphology_dict['all_segments_type']
allSegments_DistFromSoma           = morphology_dict['all_segments_distance_from_soma']
allSegments_SectionDistFromSoma    = morphology_dict['all_segments_section_distance_from_soma']
allSegments_SectionInd             = morphology_dict['all_segments_section_index']
allSegments_seg_ind_within_sec_ind = morphology_dict['all_segments_segment_index_within_section_index']

# coordinates of the basal and apical sections / segments
all_basal_section_coords  = morphology_dict['all_basal_section_coords']
all_basal_segment_coords  = morphology_dict['all_basal_segment_coords']
all_apical_section_coords = morphology_dict['all_apical_section_coords']
all_apical_segment_coords = morphology_dict['all_apical_segment_coords']

# map every global segment index to its coordinates and to its (type, section index) pair
seg_ind_to_xyz_coords_map = {}
seg_ind_to_sec_ind_map = {}
for k in range(len(allSegmentsType)):
    curr_segment_ind = allSegments_seg_ind_within_sec_ind[k]
    if allSegmentsType[k] == 'basal':
        curr_section_ind = allSegments_SectionInd[k]
        seg_ind_to_xyz_coords_map[k] = all_basal_segment_coords[(curr_section_ind,curr_segment_ind)]
        seg_ind_to_sec_ind_map[k] = ('basal', curr_section_ind)
    elif allSegmentsType[k] == 'apical':
        # the apical coordinate dict is indexed from zero, so subtract the number of basal sections
        curr_section_ind = allSegments_SectionInd[k] - len(all_basal_section_coords)
        seg_ind_to_xyz_coords_map[k] = all_apical_segment_coords[(curr_section_ind,curr_segment_ind)]
        seg_ind_to_sec_ind_map[k] = ('apical', curr_section_ind)
    else:
        print('error!')

# read the per-segment parameters stored in the first test file
sim_experiment_file = test_files[0]
with open(sim_experiment_file, "rb") as experiment_file:
    experiment_dict = pickle.load(experiment_file, encoding='latin1')
section_index      = np.array(experiment_dict['Params']['allSegments_SectionInd'])
distance_from_soma = np.array(experiment_dict['Params']['allSegments_SectionDistFromSoma'])
is_basal           = np.array([x == 'basal' for x in experiment_dict['Params']['allSegmentsType']])

In [None]:
# Parse the first test file: input spikes, dendritic voltage traces and the
# low resolution time / soma / nexus voltages (high resolution traces are not requested)
sim_experiment_file = test_files[0]
low_res_outputs = parse_sim_experiment_file_with_DVTs(sim_experiment_file, return_high_res=False)
X_spikes, y_DVTs, t_LR, y_soma_LR, y_nexus_LR = low_res_outputs

In [None]:
# Print the shape of every loaded array, one per line
print(*(loaded_array.shape for loaded_array in (X_spikes, y_DVTs, t_LR, y_soma_LR, y_nexus_LR)), sep='\n')

# Calculate dendritic voltages correlation matrix

In [None]:
# Flatten the time and simulation axes so every segment becomes one long row,
# then compute the correlation between all pairs of rows
num_dvt_rows = y_DVTs.shape[0]
dvts_flattened = y_DVTs.reshape([num_dvt_rows, -1])
corr_matrix_DVTs = np.corrcoef(dvts_flattened)

plt.figure(figsize=(13,10))
plt.imshow(corr_matrix_DVTs)
plt.colorbar()
plt.title('dendritic voltages correlation matrix', fontsize=20)
plt.xlabel('dendritic segment index', fontsize=20)
plt.ylabel('dendritic segment index', fontsize=20);

# Select a single dendritic segment

In [None]:
selected_segment_index = 456
selected_simulation = 0

num_segments = corr_matrix_DVTs.shape[0]

# excitatory rows come first in X_spikes, the inhibitory rows are offset by num_segments
X_exc_segment = X_spikes[selected_segment_index,:,selected_simulation]
X_inh_segment = X_spikes[num_segments + selected_segment_index,:,selected_simulation]
y_DVT_segment = y_DVTs[selected_segment_index,:,selected_simulation]

min_voltage = y_DVT_segment.min()

# input spikes are drawn as 5-unit-high ticks whose baseline sits 2 units below the lowest voltage
exc_raster_trace = 5 * X_exc_segment + min_voltage - 2
inh_raster_trace = 5 * X_inh_segment + min_voltage - 2

# top panel shows the whole simulation, bottom panel zooms in on 250-1750 ms
time_ranges_to_show = [(0, X_exc_segment.shape[0]), (250, 1750)]

plt.figure(figsize=(20,12))
for panel_number, (range_start, range_end) in enumerate(time_ranges_to_show, start=1):
    plt.subplot(2,1,panel_number)
    if panel_number == 1:
        plt.title('segment %d' %(selected_segment_index), fontsize=24)
    plt.plot(exc_raster_trace, color='r')
    plt.plot(inh_raster_trace, color='b')
    plt.plot(y_DVT_segment, color='k')
    plt.xlim(range_start, range_end)
    plt.xlabel('time [ms]', fontsize=20)
    plt.ylabel('voltage [mV]', fontsize=20)

# the legend is attached to the last (zoomed in) panel only
plt.legend(['exc input directly onto segment','inh input directly onto segment','segment voltage'], fontsize=20);

# Find the 30 most similar segments
Similarity is measured between the rows of the correlation matrix

In [None]:
num_nearby_segments = 30

# dissimilarity between segments = correlation distance between their rows of the correlation matrix
pairwise_row_distances = pdist(corr_matrix_DVTs, 'correlation')
segment_distance_matrix = squareform(pairwise_row_distances)

# rank all segments by their distance to the selected one and keep the closest ones
distances_to_selected = segment_distance_matrix[selected_segment_index,:]
sorted_segments_by_distance = np.argsort(distances_to_selected)
selected_nearby_segment_inds = sorted_segments_by_distance[:num_nearby_segments]

plt.figure(figsize=(13,10))
plt.imshow(segment_distance_matrix)
plt.colorbar()
plt.title('segment dissimilarity matrix', fontsize=20)
plt.xlabel('dendritic segment index', fontsize=20)
plt.ylabel('dendritic segment index', fontsize=20);

# Display the selected segment and its nearby segments on top of the morphology

In [None]:

def plot_morphology(ax, segment_colors, width_mult_factors=None):
    """Draw the cell morphology on `ax` with a color and line width per dendritic segment.

    The segment coordinates are read from the module-level
    `seg_ind_to_xyz_coords_map` (segment index -> dict with 'x', 'y' and 'd' entries).

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    segment_colors : ndarray
        One value per segment index. Values are divided by their maximum and
        mapped through the jet colormap.
    width_mult_factors : ndarray, optional
        One multiplier per segment index; the line width of a segment is the
        multiplier times the mean of its 'd' values. Defaults to 1.2 everywhere.
    """
    segment_colors = np.asarray(segment_colors, dtype=float)

    if width_mult_factors is None:
        width_mult_factor = 1.2
        width_mult_factors = width_mult_factor * np.ones((segment_colors.shape))

    # scale so the largest value maps to the top of the colormap; skip the division
    # when the maximum is zero, otherwise 0/0 turns every color into NaN
    max_color_value = segment_colors.max()
    if max_color_value != 0:
        segment_colors = segment_colors / max_color_value
    colors = plt.cm.jet(segment_colors)

    # plot the cell morphology, one line per dendritic segment
    for seg_ind, seg_coords in seg_ind_to_xyz_coords_map.items():
        seg_line_width = width_mult_factors[seg_ind] * np.array(seg_coords['d']).mean()
        ax.plot(seg_coords['x'], seg_coords['y'], lw=seg_line_width, color=colors[seg_ind])

    # add black soma (hard-coded position; the axis limits below are hard-coded as well)
    ax.scatter(x=45.5,y=19.8,s=120,c='k')
    ax.set_xlim(-180,235)
    ax.set_ylim(-210,1200);


# all segments in sequential order colors and widths
segment_colors_in_order = np.arange(num_segments)

# nearby (similar) segments colors and widths
segment_colors_nearby = np.zeros(segment_colors_in_order.shape)
segment_colors_nearby[selected_nearby_segment_inds] = 1

segment_widths_nearby = 1.2 * np.ones(segment_colors_nearby.shape)
segment_widths_nearby[selected_nearby_segment_inds] = 3.0

# selected index segment colors and widths
segment_colors_selected = np.zeros(segment_colors_in_order.shape)
segment_colors_selected[selected_segment_index] = 1.0

segment_widths_selected = 1.2 * np.ones(segment_colors_nearby.shape)
segment_widths_selected[selected_segment_index] = 5.0


fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20,12))
fig.subplots_adjust(left=0.01,right=0.99,top=0.99,bottom=0.01,hspace=0.01, wspace=0.2)

plot_morphology(ax[0], segment_colors_selected, width_mult_factors=segment_widths_selected)
plot_morphology(ax[1], segment_colors_nearby, width_mult_factors=segment_widths_nearby)
plot_morphology(ax[2], segment_colors_in_order)

ax[0].set_title('selected segment', fontsize=24);
ax[1].set_title('"nearby" segments', fontsize=24);
ax[2].set_title('all segments (by seq order)', fontsize=24);

ax[0].set_axis_off()
ax[1].set_axis_off()
ax[2].set_axis_off()

# Collect (X,y) dataset for model fitting
X - only the inputs of "nearby" segments  
y - dendritic voltage of the selected segment

In [None]:
time_window_size_ms = 60
num_samples_per_simulation = 500

# earliest timepoint that may be sampled from each simulation
first_sample_time_ms = 500
# a full input window must fit before the earliest sample, otherwise the window start goes negative
assert first_sample_time_ms >= time_window_size_ms

# fixed seed so that the sampled dataset (and every result derived from it) is reproducible
random_seed = 1234
rng = np.random.RandomState(random_seed)

# inputs: excitatory and inhibitory spikes of the nearby segments (inhibitory rows are offset by num_segments)
X_exc_nearby_segs = X_spikes[selected_nearby_segment_inds,:,:]
X_inh_nearby_segs = X_spikes[num_segments + selected_nearby_segment_inds,:,:]

X_nearby_segs = np.concatenate((X_exc_nearby_segs, X_inh_nearby_segs), axis=0)
# target: dendritic voltage of the selected segment, indexed as [time, simulation]
y_selected_seg = y_DVTs[selected_segment_index]

num_simulations = y_selected_seg.shape[1]

X = np.zeros((num_simulations * num_samples_per_simulation, 2 * num_nearby_segments, time_window_size_ms))
y = np.zeros((num_simulations * num_samples_per_simulation, 1))

sample_ind = 0
for sim_ind in range(num_simulations):
    selected_timepoints = np.sort(rng.randint(low=first_sample_time_ms, high=y_selected_seg.shape[0], size=num_samples_per_simulation))
    for time_ind in selected_timepoints:
        # the input window covers the time_window_size_ms bins that precede (and exclude) time_ind
        X[sample_ind] = X_nearby_segs[:, time_ind - time_window_size_ms:time_ind, sim_ind]
        y[sample_ind] = y_selected_seg[time_ind, sim_ind]

        sample_ind += 1

print('finished collecting dataset')
print('X.shape = %s' %(str(X.shape)))
print('y.shape = %s' %(str(y.shape)))

# Train regularized linear regression model to predict y from X

In [None]:
valid_simulations = 28
valid_samples = num_samples_per_simulation * valid_simulations

# samples are stored simulation after simulation, so the last `valid_samples`
# rows belong to the last `valid_simulations` simulations; they are held out for validation
X_train, X_valid = X[:-valid_samples], X[-valid_samples:]
y_train, y_valid = y[:-valid_samples], y[-valid_samples:]

use_cv = False

# the cross validated search generalizes better, but is 5 * 9 times slower
linear_regression_model = (
    RidgeCV(alphas=[0.01, 0.1, 1.0, 10.0, 100.0], cv=9) if use_cv else Ridge(alpha=10.0)
)

# flatten every (synapse, time) window into a single feature vector before fitting
X_train_flat = X_train.reshape([X_train.shape[0],-1])
linear_regression_model.fit(X_train_flat, y_train)

print('finished training linear regression model')

# Evaluate model performance on validation set

In [None]:
# predict on the held out samples (each window flattened into one feature vector)
y_valid_hat = linear_regression_model.predict(X_valid.reshape([X_valid.shape[0],-1]))

percent_explained_variance = 100 * explained_variance_score(y_valid, y_valid_hat)

plt.figure(figsize=(10,10));
plt.scatter(x=y_valid_hat.ravel(), y=y_valid.ravel())
# '%%' prints a literal percent sign
plt.title('Ground Truth vs Prediction\n (%.2f%% variance explained)' %(percent_explained_variance), fontsize=24)
plt.xlabel('predicted voltage [mV]', fontsize=20)
plt.ylabel('ground truth voltage [mV]', fontsize=20);
# identity line (perfect prediction), drawn over a hard-coded range of -75 to -50
plt.plot([-75,-50],[-75,-50],color='k');

# Display learned weights

In [None]:
# one row label per input: excitatory inputs of the nearby segments first, then the inhibitory ones
exc_tick_labels = ['segment index %d (exc)' %(seg_ind) for seg_ind in selected_nearby_segment_inds]
inh_tick_labels = ['segment index %d (inh)' %(seg_ind) for seg_ind in selected_nearby_segment_inds]
y_tick_labels = exc_tick_labels + inh_tick_labels

# fold the flat coefficient vector back into an (input, time within window) matrix
num_weight_rows = 2 * num_nearby_segments
learned_weights = linear_regression_model.coef_.reshape([num_weight_rows, time_window_size_ms])

plt.figure(figsize=(18,14))
plt.subplot(2,1,1)
plt.imshow(learned_weights)
plt.yticks(np.arange(num_weight_rows), y_tick_labels)
plt.colorbar()
plt.xlabel('time [ms]', fontsize=20)
plt.title('learned weights', fontsize=24);

# Display temporal cross section of learned weights

In [None]:
# one curve per input over the time window: excitatory weights in red, inhibitory weights in blue
plt.figure(figsize=(14,10));
plt.plot(learned_weights[:num_nearby_segments,:].T, color='r', alpha=0.8)
plt.plot(learned_weights[num_nearby_segments:,:].T, color='b', alpha=0.8)
plt.title('temporal cross section of learned weights', fontsize=24)
# same font size as the other axis labels in this notebook
plt.ylabel('weight magnitude (A.U)', fontsize=20)
plt.xlabel('time [ms]', fontsize=20);