In [None]:
# OPTIONAL: Load the "autoreload" extension so that edited modules are picked up without restarting the kernel
%load_ext autoreload

# OPTIONAL: always reload modules so that as you change code in src, it gets loaded
%autoreload 2

In [None]:
""" %matplotlib inline
%config InlineBackend.figure_format = 'retina'"""

import warnings

from tqdm.notebook import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial

import numpy as np
import pandas as pd
import scipy.stats as ss
import os
import hdf5storage

from scipy.optimize import curve_fit

import statsmodels.formula.api as smf

from neurodsp.spectral import compute_spectrum
from fooof import FOOOF

""" import seaborn as sns
sns.set_style('whitegrid')"""

import matplotlib.pyplot as plt
#from matplotlib import cm, rc

""" font = {'family' : 'DejaVu Sans',
        'weight' : 'light',
        'size'   : 13}
figure = {'figsize' : (10,8)}
rc('font', **font)
rc('figure', **figure)"""

from timescales.psd import all_neurons_psd_and_rate, convert_knee_val
from timescales.acf import all_neurons_acf, exp_decay_func
from timescales.plts import plot_rsq, plot_fits
from timescales.utils import get_divs_list, load_spiketimes, compute_rsq

import traceback

In [None]:
def get_mean_std(taus, divs):
    """Group timescales by day and compute the per-day mean and standard deviation.

    Parameters
    ----------
    taus : array-like
        One value per unit; converted with ``np.asarray`` so plain lists work too.
    divs : sequence
        Day label for each entry of ``taus`` (same length and order).

    Returns
    -------
    list
        ``[days, all_indices, days_taus, mean_taus, std_taus]`` where ``days`` is the
        sorted list of unique day labels and the other four lists hold, per day, the
        positions in ``divs``, the matching values, their ``np.mean`` and their ``np.std``.
    """
    # Fancy indexing with a list of positions only works on ndarrays, not on lists.
    taus = np.asarray(taus)

    days = sorted(set(divs))
    all_indices = []
    days_taus = []
    mean_taus = []
    std_taus = []

    for d in days:
        # positions of every unit recorded on day d
        indices = [i for i, x in enumerate(divs) if x == d]
        all_indices.append(indices)

        # values recorded on that day
        day_taus = taus[indices]
        days_taus.append(day_taus)

        # summary statistics for that day
        mean_taus.append(np.mean(day_taus))
        std_taus.append(np.std(day_taus))

    return [days, all_indices, days_taus, mean_taus, std_taus]
    



In [None]:
def get_median_std(taus, divs):
    """Group timescales by day and compute the per-day median and standard deviation.

    Parameters
    ----------
    taus : array-like
        One value per unit; converted with ``np.asarray`` so plain lists work too.
    divs : sequence
        Day label for each entry of ``taus`` (same length and order).

    Returns
    -------
    list
        ``[days, all_indices, days_taus, median_taus, std_taus]`` where ``days`` is the
        sorted list of unique day labels and the other four lists hold, per day, the
        positions in ``divs``, the matching values, their ``np.median`` and their ``np.std``.
    """
    # Fancy indexing with a list of positions only works on ndarrays, not on lists.
    taus = np.asarray(taus)

    days = sorted(set(divs))
    all_indices = []
    days_taus = []
    median_taus = []
    std_taus = []

    for d in days:
        # positions of every unit recorded on day d
        indices = [i for i, x in enumerate(divs) if x == d]
        all_indices.append(indices)

        # values recorded on that day
        day_taus = taus[indices]
        days_taus.append(day_taus)

        # get medians and stds
        median_taus.append(np.median(day_taus))
        std_taus.append(np.std(day_taus))

    return [days, all_indices, days_taus, median_taus, std_taus]
    



In [None]:
# One-off conversion of the MATLAB .mat files to NumPy .npy files (np.save); kept disabled inside the string literal below after it has been run
"""
directory_mat = '/Users/blancamartin/Downloads/Data_hc8/'
directory_pck = os.getcwd() + '/hc8_data/'

for filename in os.listdir(directory_mat):
    data = hdf5storage.loadmat(directory_mat + filename)
    spikes = data['data']['spikes'][0][0]* .05 #convert to ms
    
    np.save(directory_pck+filename[:-4], spikes , allow_pickle=True)"""
    


In [None]:
## Unpack results
# NOTE(review): `results_ac` is not created in any cell shown above this one; it must
# already exist in the kernel for this cell to run — confirm where it is computed.

# each culture was recorded at different days in vitro, so this finds those days
all_culture_divs = [result['divs'] for result in results_ac]

# timescale estimates: column 1 of each culture's fitted exponential-decay parameters
all_taus = [result['exp_decay_params'][:, 1] for result in results_ac]
rsqs = np.array([value for result in results_ac for value in result['rsqs']])

# Bins and windows
windows = [win for result in results_ac for win in result['windows']]
bins = [b for result in results_ac for b in result['bins']]

# Average curves and fit parameters
avg_curves = [curve for result in results_ac for curve in result['average_curves']]

# Stacked once: the previous version built this array twice and the second
# assignment simply overwrote the first with the same rows.
exp_decay_params = np.vstack([result['exp_decay_params'] for result in results_ac])

all_flat_taus_ac = np.array([item for sublist in all_taus for item in sublist])
all_flat_divs_ac = np.array([item for sublist in all_culture_divs for item in sublist])

#remove extreme outliers: anything above 100x the median timescale
outliers = 100*np.median(all_flat_taus_ac)
# Both arrays are filtered from one tuple so the mask is always built from the
# unfiltered taus, regardless of assignment order.
all_flat_divs_ac, all_flat_taus_ac = (
    all_flat_divs_ac[all_flat_taus_ac <= outliers],
    all_flat_taus_ac[all_flat_taus_ac <= outliers],
)


In [None]:
# Quick look at the autocorrelation timescales against days in vitro after outlier removal.
# The taus are divided by 1000 before plotting (presumably ms -> s — TODO confirm units).
plt.plot(all_flat_divs_ac, all_flat_taus_ac/1000, '.')

In [None]:
# Unpack the spectral results: one list entry per element of `results_spec`
# NOTE(review): `results_spec` is not created in any cell shown above — confirm its origin.
all_culture_divs_spec = [res['divs'] for res in results_spec]
all_rates_spec = [res['rate_by_div'] for res in results_spec]
all_timescales_spec = [res['timescales_by_div'] for res in results_spec]
all_cfs_spec = [res['cf_by_div'] for res in results_spec]
all_pws_spec = [res['pw_by_div'] for res in results_spec]
all_fms = [res['fms'] for res in results_spec]

# single flat list of the 'fms' entries across all results
flat_all_fms = [fm for culture_fms in all_fms for fm in culture_fms]

In [None]:
# Flatten the per-culture lists into one array of days and one array of timescales
flat_divs_tau_spec = np.asarray([day for culture in all_culture_divs_spec for day in culture])
flat_taus_spec = np.asarray([tau for culture in all_timescales_spec for tau in culture])

# Drop entries whose timescale is NaN (both arrays filtered with the same mask;
# the tuple is evaluated before either name is rebound)
flat_divs_tau_spec, flat_taus_spec = (
    flat_divs_tau_spec[~np.isnan(flat_taus_spec)],
    flat_taus_spec[~np.isnan(flat_taus_spec)],
)

#remove extreme outliers: anything above 100x the median timescale
outliers = 100*np.median(flat_taus_spec)
flat_divs_tau_spec, flat_taus_spec = (
    flat_divs_tau_spec[flat_taus_spec <= outliers],
    flat_taus_spec[flat_taus_spec <= outliers],
)

In [None]:
# Quick look at the spectral timescales against recording day after NaN and outlier removal.
plt.plot(flat_divs_tau_spec, flat_taus_spec, '.')

# Figure 2 for timescale dev paper

## Autocorrelation and Specparam single-unit figures

In [None]:
def h_line(x_data, b):
    """Constant (horizontal-line) model: return a list holding ``b`` once per element of ``x_data``."""
    n_points = len(x_data)
    return [b for _ in range(n_points)]

In [None]:
def sig_func(x, a, b, k, c):
    """Sigmoid model ``a / (1 + b*exp(-k*x)) + c`` evaluated at ``x``."""
    denominator = 1+b*np.exp(-k*x)
    return (a/denominator)+c
   

In [None]:
def sigmoid(x, L ,x0, k, b):
    """Logistic model ``L / (1 + exp(-k*(x - x0))) + b`` evaluated at ``x``."""
    decay = np.exp(-k*(x-x0))
    return L / (1 + decay)+b

In [None]:
def gaus(x,a,x0,sigma,c):
    """Gaussian bump ``a*exp(-(x - x0)**2 / (2*sigma**2)) + c`` evaluated at ``x``."""
    squared_offset = (x-x0)**2
    width_term = 2*sigma**2
    return a*np.exp(-squared_offset/width_term) + c

In [None]:
def fig2_timescales(time, timescales, sig_method = None, sig_func_type=sig_func, guess_h1=None, guess_h2=None):
    """Scatter ``timescales`` against ``time`` and overlay three fitted models.

    A new 6x5 figure is created and the following fits (all via ``curve_fit``) are drawn:
    a constant line (``h_line``, yellow), a sigmoid (``sig_func_type``, green) and a
    gaussian (``gaus``, red).

    Parameters
    ----------
    time, timescales : array-like
        x and y data of equal length; both are sorted by ``time`` before fitting.
    sig_method : str or None
        Passed as ``method`` to ``curve_fit`` for the sigmoid fit only.
    sig_func_type : callable
        Sigmoid model to fit, called as ``sig_func_type(x, *params)``.
    guess_h1, guess_h2 : sequence or None
        Initial guesses (``p0``) for the sigmoid and the gaussian fit respectively.

    Returns
    -------
    list
        ``[params_h0, params_h1, params_h2]``: fitted parameters of the constant,
        sigmoid and gaussian model, in that order.
    """
    # Accept plain lists too: the index-array sorting below needs ndarrays.
    time = np.asarray(time)
    timescales = np.asarray(timescales)

    # Sort by time so the fitted curves are drawn left to right.
    sorted_index = np.argsort(time)
    time = time[sorted_index]
    timescales = timescales[sorted_index]

    #plot data points
    plt.figure(figsize =  (6,5))
    plt.plot(time, timescales, '.', color = "dimgrey")
    #plt.ylabel("Log(timescales)", size = 18)
    #plt.xlabel("Days ", size = 18)
    plt.xticks(size = 13)
    plt.yticks(size = 13)

    #fit straight line model - yellow
    guess_h0 = np.mean(timescales)
    params_h0, _ = curve_fit(h_line, time, timescales, p0 = guess_h0)
    # Unpack the fitted parameters so h_line receives the scalar offset rather
    # than the length-1 parameter array.
    y_h0 = h_line(time, *params_h0)
    plt.plot(time, y_h0,'-.', color='#FFC125', linewidth=5, dashes=(5, 2))

    #fit sigmoid model - green
    # The r-squared values computed here previously were never used, and the
    # sigmoid one was evaluated with the hard-coded sig_func instead of
    # sig_func_type; callers compute r-squared themselves from the returned parameters.
    params_h1, _ = curve_fit(sig_func_type, time, timescales, method = sig_method, p0=guess_h1)
    y_h1 = sig_func_type(time, *params_h1)
    plt.plot(time, y_h1,'-.', color='darkgreen', linewidth=5, dashes=(4, 2))

    #fit inverted u gaussian model - red
    # Bounds: amplitude within [0, 1], centre within the observed time range,
    # width and offset unconstrained.
    params_h2, _ = curve_fit(gaus, time, timescales, p0=guess_h2, bounds=((0, min(time), -np.inf, -np.inf), (1, max(time), np.inf, np.inf)))
    y_h2 = gaus(time, *params_h2)
    plt.plot(time, y_h2,'-.',color='#B22222', linewidth=5, dashes=(2, 2))

    return [params_h0, params_h1, params_h2]
    
  


In [None]:
def rsq(xdata, ydata, popt, f):
    """Coefficient of determination of model ``f`` with parameters ``popt`` on the data.

    Returns ``1 - SS_res / SS_tot`` where ``SS_res`` is the sum of squared residuals
    ``ydata - f(xdata, *popt)`` and ``SS_tot`` the sum of squared deviations of
    ``ydata`` from its mean.
    """
    predicted = f(xdata, *popt)
    ss_res = np.sum((ydata - predicted)**2)

    centered = ydata - np.mean(ydata)
    ss_tot = np.sum(centered**2)

    return 1 - (ss_res / ss_tot)
    

In [None]:
def f_test(xdata, ydata, model0, model1, popt0, popt1, p0, p1):
    """Compare two fitted models on the same data with an F-test.

    Parameters
    ----------
    xdata, ydata : array-like
        Data both models were fitted to.
    model0, model1 : callable
        Models called as ``model(xdata, *popt)``.
    popt0, popt1 : sequence
        Fitted parameters of ``model0`` and ``model1``.
    p0, p1 : int
        Number of parameters of ``model0`` and ``model1``.

    Returns
    -------
    (f_ratio, p) : tuple
        With ``p1 != p0``: ``f_ratio`` is the extra-sum-of-squares statistic
        ``((ssq0 - ssq1) / (p1 - p0)) / (ssq1 / (n - p1))``. With ``p1 == p0`` it is
        the plain ratio ``ssq1 / ssq0``. ``p`` is the upper-tail probability of the
        F distribution at ``f_ratio``.
    """
    n_points = len(xdata)

    yfit0 = model0(xdata, *popt0)
    yfit1 = model1(xdata, *popt1)

    # residual sums of squares of each model
    ssq0 = ((yfit0-ydata)**2).sum()
    ssq1 = ((yfit1-ydata)**2).sum()

    param_diff = p1 - p0

    # The survival function is used instead of 1 - cdf: for large F ratios the
    # cdf rounds to exactly 1.0 and the p-value would collapse to 0.
    if param_diff == 0:
        df_0 = n_points - p0
        df_1 = n_points - p1
        f_ratio = ssq1/ssq0
        p = ss.f.sf(f_ratio, df_0, df_1)
    else:
        df = n_points - p1
        f_ratio = ((ssq0 - ssq1) / param_diff) / (ssq1 / df)
        p = ss.f.sf(f_ratio, param_diff, df)

    return f_ratio, p

## Specparam single units

In [None]:
# Fits for the single-unit specparam timescale estimates (log10 of the timescales vs. day)

# initial guess for the sigmoid fit
guess_h1 = [
    max(flat_taus_spec),
    4*np.median(flat_divs_tau_spec)/2,
    1.15,
    min(flat_taus_spec),
]

# initial guess for the gaussian fit: centre at 2*mean, i.e. half of the largest day value
n = max(flat_divs_tau_spec)
mean = (n/4)
guess_h2 = [0.1, 2*mean, 10, 0]

model_params = fig2_timescales(
    flat_divs_tau_spec,
    np.log10(flat_taus_spec),
    sig_method='dogbox',
    guess_h1=guess_h1,
    guess_h2=guess_h2,
)


In [None]:
#rerun on mean taus
days, all_indices, days_taus, mean_taus, std_taus = get_mean_std(np.log10(flat_taus_spec), flat_divs_tau_spec)
days = np.asarray(days)
mean_taus = np.asarray(mean_taus)
# days_taus stays a list of per-day arrays: the days can hold different numbers of
# units, and np.asarray on such a ragged list raises ValueError on NumPy >= 1.24.
# Nothing below needs it as an array.

guess_h1_mean = [max(mean_taus), 4*np.mean(days)/2,1.15,min(mean_taus)]
n_mean = max(flat_divs_tau_spec)
# use the value computed in this cell instead of `n` left over from an earlier cell
mean_mean = (n_mean/4)

guess_h2 = [0.1, 2*mean_mean, 10,0]

# NOTE(review): the sigmoid fit is seeded with `guess_h1` from the previous cell (built
# from the raw, non-log timescales); `guess_h1_mean` above is computed but not used.
# Left unchanged so the fitted parameters stay as they were — confirm which is intended.
model_params_mean = fig2_timescales(days,mean_taus, sig_method='dogbox', 
                guess_h1=guess_h1, guess_h2=guess_h2)

#plot with all datapoints
plt.figure(figsize =  (6,5))
plt.plot(days, mean_taus, '.', color='k', linewidth=5)
plt.plot(flat_divs_tau_spec, np.log10(flat_taus_spec), '.', color = "dimgrey")
plt.plot(days, h_line(days, *model_params_mean[0]),'-.', color='#FFC125', linewidth=5, dashes=(5, 2))
plt.plot(days, sig_func(days, *model_params_mean[1]),'-.', color='darkgreen', linewidth=5, dashes=(4, 2))
plt.plot(days, gaus(days, *model_params_mean[2]),'-.',color='#B22222', linewidth=5, dashes=(2, 2))

#plot with all datapoints
fig, ax = plt.subplots(figsize =  (6,5))
ax.plot(flat_divs_tau_spec, np.log10(flat_taus_spec), '.', color = "dimgrey")
ax.tick_params(axis='both', which='major', labelsize=13)
# drop the first automatic tick, then derive the top-axis labels as (tick - 4)
ax.set_xticks(list(plt.xticks()[0][1:]))
x_ticks = list(plt.xticks()[0])
new_top_labels = [int(i-4) for i in x_ticks]

#top axis
ax2 = ax.secondary_xaxis("top")
ax2.tick_params(axis='both', which='major', labelsize=13)
ax2.set_xticklabels([0] +new_top_labels )

ax.plot(days, h_line(days, *model_params_mean[0]),'-.', color='#FFC125', linewidth=5, dashes=(5, 2))
ax.plot(days, sig_func(days, *model_params_mean[1]),'-.', color='darkgreen', linewidth=5, dashes=(4, 2))
ax.plot(days, gaus(days, *model_params_mean[2]),'-.',color='#B22222', linewidth=5, dashes=(2, 2))





In [None]:
# r-squared of each fitted model (constant, sigmoid, gaussian) on the per-day means
rsqs_rat_h0, rsqs_rat_h1, rsqs_rat_h2 = (
    rsq(days, mean_taus, popt, model)
    for popt, model in zip(model_params_mean, (h_line, sig_func, gaus))
)

print("rsq rat h0:", rsqs_rat_h0)
print("rsq rat h1:", rsqs_rat_h1)
print("rsq rat h2:", rsqs_rat_h2)

In [None]:
#f-tests on the per-day means: each call is (model0, model1) with their parameter counts
f_test_h0_h1 = f_test(days, mean_taus, h_line, sig_func,
                      model_params_mean[0], model_params_mean[1], 1, 4)
f_test_h0_h2 = f_test(days, mean_taus, h_line, gaus, 
                      model_params_mean[0], model_params_mean[2],1,4)
# gaussian (h2) is model0 and sigmoid (h1) is model1 here
f_test_h2_h1 = f_test(days, mean_taus, gaus, sig_func, 
                      model_params_mean[2], model_params_mean[1],4,4)

print("f_test rat h0 vs h1:", f_test_h0_h1)
print("f_test rat h0 vs h2:", f_test_h0_h2)
# label now matches the order of the models actually passed (was "h1 vs h2")
print("f_test rat h2 vs h1:", f_test_h2_h1)

## Autocorrelation single units

In [None]:
# Single-unit autocorrelation timescales: scatter of log10(tau/1000) against day
plt.figure(figsize=(6, 5))
plt.plot(
    all_flat_divs_ac,
    np.log10(all_flat_taus_ac/1000),
    '.',
    color="dimgrey",
)
# axis labels are currently disabled:
#plt.ylabel("Log(timescales)", size = 18)
#plt.xlabel("Days ", size = 18)
plt.xticks(size=13)
plt.yticks(size=13)

## Organoid autocorrelation timescales

In [None]:
#read data from pickles
# NOTE(review): absolute, user-specific paths — this cell only runs on the original
# author's machine. pd.read_pickle can execute arbitrary code, so only load trusted files.
org_divs = pd.read_pickle(r'/Users/blancamartin/Downloads/development_plot_x_days.pickle')
org_ac_taus = pd.read_pickle(r'/Users/blancamartin/Downloads/development_plot_y_tau.pickle')

In [None]:
# Repeat each day label once per timescale value recorded for that day, so that the
# flattened days line up one-to-one with the flattened timescales. The repeat count
# comes from the day's own row (previously always row 0), which keeps the two in
# step even if rows differ in length.
all_org_divs = [[org_divs[i]]*len(org_ac_taus[i]) for i in range(len(org_divs))]
    

In [None]:
# Flatten the per-day lists into one array of days and one array of timescales
flat_org_ac_divs = np.asarray([day for day_block in all_org_divs for day in day_block])
flat_org_ac_taus = np.asarray([tau for tau_block in org_ac_taus for tau in tau_block])

# Drop entries whose timescale is NaN (the tuple is evaluated before either name is rebound)
flat_org_ac_divs, flat_org_ac_taus = (
    flat_org_ac_divs[~np.isnan(flat_org_ac_taus)],
    flat_org_ac_taus[~np.isnan(flat_org_ac_taus)],
)


In [None]:
# Fit and plot the constant, sigmoid and gaussian models on all organoid autocorrelation
# timescales, using fig2_timescales' defaults (no initial guesses, default sigmoid model).
fig2_timescales(flat_org_ac_divs,flat_org_ac_taus)


In [None]:
# Rerun fits on mean taus
days, all_indices, days_taus, mean_taus, std_taus = get_mean_std(flat_org_ac_taus, flat_org_ac_divs)
days = np.asarray(days)
mean_taus = np.asarray(mean_taus)

# The fit runs with fig2_timescales' defaults (no initial guesses). Initial-guess
# variables that used to be built here were never passed to the fit and depended on
# rat-data variables from earlier cells, so they have been removed.
# NOTE(review): confirm that fitting without initial guesses is intended here.
model_params_mean = fig2_timescales(days,mean_taus)

#plot with all datapoints
plt.figure(figsize =  (6,5))
plt.plot(days, mean_taus, '.', color='k', linewidth=5)
plt.plot( flat_org_ac_divs, flat_org_ac_taus, '.', color = "dimgrey")
plt.plot(days, h_line(days, *model_params_mean[0]),'-.', color='#FFC125', linewidth=5, dashes=(5, 2))
plt.plot(days, sig_func(days, *model_params_mean[1]),'-.', color='darkgreen', linewidth=5, dashes=(4, 2))
plt.plot(days, gaus(days, *model_params_mean[2]),'-.',color='#B22222', linewidth=5, dashes=(2, 2))


In [None]:
# r-squared of each fitted model (constant, sigmoid, gaussian) on the per-day means
rsqs_rat_h0, rsqs_rat_h1, rsqs_rat_h2 = (
    rsq(days, mean_taus, popt, model)
    for popt, model in zip(model_params_mean, (h_line, sig_func, gaus))
)

print(rsqs_rat_h0)
print(rsqs_rat_h1)
print(rsqs_rat_h2)

In [None]:
# Scatter of the organoid autocorrelation timescales, drawn twice: first with the pyplot
# interface (used to derive the shifted tick labels), then on an explicit Axes with a
# secondary top axis. The statements rely on pyplot's "current axes", so order matters.
plt.figure(figsize =  (6,5))
plt.plot(flat_org_ac_divs,flat_org_ac_taus, '.', color = "dimgrey")
#plt.ylabel("Log(timescales)", size = 18)
#plt.xlabel("Days ", size = 18)
plt.xticks(size = 13)
plt.yticks(size = 13)
# Take the automatic tick positions (minus the first) and relabel them shifted by +56
# (presumably the age in days at the first recording — TODO confirm).
x_ticks = list(plt.xticks()[0][1:])
new_x_labels = [int(i+56) for i in x_ticks]
plt.xticks(list(plt.xticks()[0][1:]), labels= new_x_labels)

fig, ax = plt.subplots(figsize =  (6,5))
ax.plot(flat_org_ac_divs,flat_org_ac_taus, '.', color = "dimgrey")
ax.tick_params(axis='both', which='major', labelsize=13)
# plt.xticks() now reads the ticks of this new Axes, since it is the current one
ax.set_xticks(list(plt.xticks()[0][1:]))
ax.set_xticklabels(new_x_labels)

#top axis
ax2 = ax.secondary_xaxis("top")
ax2.tick_params(axis='both', which='major', labelsize=13)
# shifted day labels divided by 30.4167 (presumably days -> months, 365/12 — TODO confirm)
top_x_labels = [round(i/30.4167,2) for i in new_x_labels]

# NOTE(review): labels are set without fixing the top axis' tick positions, so they
# only line up if that axis has exactly len(top_x_labels) + 1 ticks — verify visually.
ax2.set_xticklabels([0] +top_x_labels)


## Organoid specparam timescales

In [None]:
#read data from pickles
# NOTE(review): absolute, user-specific path — this cell only runs on the original
# author's machine. pd.read_pickle can execute arbitrary code, so only load trusted files.
org_spec_taus = pd.read_pickle(r'/Users/blancamartin/Downloads/development_plot_y_kneeTau.pickle')

In [None]:
# Flatten the per-day lists into one array of days and one array of spectral timescales
# (day labels come from all_org_divs, built for the autocorrelation data)
flat_org_spec_divs = np.asarray([day for day_block in all_org_divs for day in day_block])
flat_org_spec_taus = np.asarray([tau for tau_block in org_spec_taus for tau in tau_block])

# Drop entries whose timescale is NaN (the tuple is evaluated before either name is rebound)
flat_org_spec_divs, flat_org_spec_taus = (
    flat_org_spec_divs[~np.isnan(flat_org_spec_taus)],
    flat_org_spec_taus[~np.isnan(flat_org_spec_taus)],
)

In [None]:
# Fit the three models on all organoid spectral timescales.
# NOTE(review): the initial guesses below are built from the rat variables
# (flat_taus_spec / flat_divs_tau_spec), not from the organoid data — confirm intent.

# initial guess for the sigmoid fit
guess_h1 = [
    max(flat_taus_spec),
    4*np.median(flat_divs_tau_spec)/2,
    0.9,
    min(flat_taus_spec),
]

# initial guess for the gaussian fit
n = max(flat_divs_tau_spec)
mean = (n/2)
guess_h2 = [0.15, 2/3*mean, 20, 0]

fig2_timescales(
    flat_org_spec_divs,
    flat_org_spec_taus,
    sig_method='dogbox',
    sig_func_type=sigmoid,
    guess_h1=guess_h1,
    guess_h2=guess_h2,
)



In [None]:
#rerun fits on mean taus

# per-day mean of the organoid spectral timescales
days, all_indices, days_taus, mean_taus, std_taus = get_mean_std(flat_org_spec_taus, flat_org_spec_divs)
days = np.asarray(days)

mean_taus = np.asarray(mean_taus)


# Initial guesses for the sigmoid (guess_h1) and gaussian (guess_h2) fits.
# NOTE(review): the first sigmoid guess is max(days) and `n` comes from the rat variable
# flat_divs_tau_spec — both look copied from other cells; confirm they are intended.
guess_h1 = [max(days), 4*np.median(days)/2,0.9,min(mean_taus)]
n = max(flat_divs_tau_spec)                          
mean = (n/2) 

guess_h2 = [0.15, 2/3*mean, 20,0]

model_params_mean = fig2_timescales(days,mean_taus, sig_method='dogbox', sig_func_type=sigmoid,
                guess_h1 =guess_h1, guess_h2=guess_h2 )

#plot with all datapoints
fig, ax = plt.subplots(figsize =  (6,5))
ax.plot( flat_org_spec_divs, flat_org_spec_taus, '.', color = "dimgrey")
ax.tick_params(axis='both', which='major', labelsize=13)
# plt.xticks() reads the ticks of the current Axes (the one just created)
ax.set_xticks(list(plt.xticks()[0][1:]))
# new_x_labels is defined in an earlier plotting cell, which must have been run first
ax.set_xticklabels(new_x_labels)

#top axis
ax2 = ax.secondary_xaxis("top")
ax2.tick_params(axis='both', which='major', labelsize=13)
# shifted day labels divided by 30.4167 (presumably days -> months, 365/12 — TODO confirm)
top_x_labels = [round(i/30.4167,1) for i in new_x_labels]

ax2.set_xticklabels([0] +top_x_labels)

# overlay the three models fitted to the per-day means
ax.plot(days, h_line(days, *model_params_mean[0]),'-.', color='#FFC125', linewidth=5, dashes=(5, 2))
ax.plot(days, sigmoid(days, *model_params_mean[1]),'-.', color='darkgreen', linewidth=5, dashes=(4, 2))
ax.plot(days, gaus(days, *model_params_mean[2]),'-.',color='#B22222', linewidth=5, dashes=(2, 2))




In [None]:
# r-squared of each fitted model (constant, sigmoid, gaussian) on the per-day means
rsqs_rat_h0, rsqs_rat_h1, rsqs_rat_h2 = (
    rsq(days, mean_taus, popt, model)
    for popt, model in zip(model_params_mean, (h_line, sigmoid, gaus))
)

print(rsqs_rat_h0)
print(rsqs_rat_h1)
print(rsqs_rat_h2)

In [None]:

#f-tests on the per-day means: each call is (model0, model1) with their parameter counts
f_test_h0_h1 = f_test(days, mean_taus, h_line, sigmoid,
                      model_params_mean[0], model_params_mean[1], 1, 4)
f_test_h0_h2 = f_test(days, mean_taus, h_line, gaus, 
                      model_params_mean[0], model_params_mean[2],1,4)
# gaussian (h2) is model0 and sigmoid (h1) is model1 here
f_test_h2_h1 = f_test(days, mean_taus, gaus, sigmoid, 
                      model_params_mean[2], model_params_mean[1],4,4)

print("f_test org h0 vs h1:", f_test_h0_h1)
print("f_test org h0 vs h2:", f_test_h0_h2)
# label now matches the order of the models actually passed (was "h1 vs h2")
print("f_test org h2 vs h1:", f_test_h2_h1)