In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to the childes_mi package are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
from childes_mi.utils.paths import DATA_DIR, FIGURE_DIR
from childes_mi.utils.general import flatten,save_fig

In [3]:
from childes_mi.information_theory import model_fitting as mf

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from tqdm.autonotebook import tqdm



In [5]:
# Load the mutual-information table for the drosophila dataset from DATA_DIR.
# NOTE(review): unpickling can execute arbitrary code -- only load pickle files
# that were produced by this project or otherwise come from a trusted source.
MI_DF = pd.read_pickle(DATA_DIR/'mi/drosophila_mi_1000.pickle')

In [6]:
# Preview the first (up to) three rows of the loaded table.
MI_DF.head(3)

Unnamed: 0,MI,MI_var,shuff_MI,shuff_MI_var,distances
0,"[11.491092426440758, 11.113848299482346, 10.94...","[0.0032096905199801226, 0.0031487970998960533,...","[10.162464643140364, 10.161342744045395, 10.16...","[0.0030504540127070092, 0.00305061298348205, 0...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14..."


In [7]:
# Columns that the model-fitting loop fills in, one value per row.
fit_columns = ['exp_results', 'pow_results', 'concat_results',
     'R2_exp', 'R2_concat', 'R2_power', 'AICc_exp',
     'AICc_concat', 'AICc_power', 'bestfitmodel', 'curvature', 'min_peak']

# Subset of those columns that store Python objects rather than floats
# (fit-result objects, the model name string, the curvature array).
object_columns = ['exp_results', 'pow_results', 'concat_results',
     'bestfitmodel', 'curvature']

# Start every result column as NaN, then give the object-holding columns an
# object dtype up front. Leaving them as float64 forces pandas to upcast
# implicitly when a non-float value is written, which is deprecated behaviour.
MI_DF = MI_DF.assign(**{col: np.nan for col in fit_columns}).astype(
    {col: object for col in object_columns}
)

In [8]:
n = 100 # max distance for computation

# Loop-invariant inputs, built once instead of on every row:
# log-spaced grid (1 .. n, 1000 points) on which the best-fit model is evaluated
distances_mod = np.logspace(0,np.log10(n), base=10, num=1000)
# MI_DF columns that receive the per-row fit results (order matches the values below)
result_columns = np.array(['exp_results', 'pow_results', 'concat_results',
                           'R2_exp', 'R2_concat', 'R2_power', 'AICc_exp',
                           'AICc_concat', 'AICc_power', 'bestfitmodel', 'curvature', 'min_peak'])

for idx, row in tqdm(MI_DF.iterrows(), total=len(MI_DF)):
    # get signal: MI with the shuffled-baseline MI subtracted
    sig = np.array(row.MI-row.shuff_MI)
    distances = row.distances

    # fit models
    results_power, results_exp, results_pow_exp, best_fit_model = mf.fit_models(distances, sig)

    # get fit results
    R2_exp, R2_concat, R2_power, AICc_exp, \
        AICc_pow, AICc_concat = mf.fit_results(sig, distances,
                                              results_exp, results_power,
                                              results_pow_exp)

    # get model y for the best-fitting model
    model_by_name = {
        'pow_exp': (mf.pow_exp_decay, results_pow_exp),
        'exp': (mf.exp_decay, results_exp),
        'pow': (mf.powerlaw_decay, results_power),
    }
    if best_fit_model not in model_by_name:
        # Without this guard y_model would be undefined on the first row, or
        # silently carried over from the previous row on later ones.
        raise ValueError(
            "unexpected best_fit_model {!r} for row {!r}; expected one of {}".format(
                best_fit_model, idx, sorted(model_by_name)
            )
        )
    decay_fn, decay_fit = model_by_name[best_fit_model]
    y_model = mf.get_y(decay_fn, decay_fit, distances_mod)

    # get curvature of model_y (computed on the log of the model values)
    curvature_model = mf.curvature(np.log(y_model))

    # if the best fit model is pow_exp, then grab the min peak
    min_peak = np.nan
    if best_fit_model == 'pow_exp':
        # Interior points that are strictly below both neighbours, i.e. local
        # minima of the curvature. Mask position j describes curvature_model[j + 1].
        interior = curvature_model[1:-1]
        is_local_min = (interior < curvature_model[2:]) & (interior < curvature_model[:-2])
        local_minima = np.where(is_local_min)[0]
        # NOTE(review): the stored value is the mask position j (as before), which
        # is one less than the index into curvature_model -- confirm which one
        # downstream consumers expect.
        if local_minima.size:
            min_peak = local_minima[0]
        # else: no interior local minimum was found; keep NaN instead of raising IndexError

    # save model fit results to MI_DF
    MI_DF.loc[idx, result_columns] = [
        results_exp, results_power, results_pow_exp,
        R2_exp, R2_concat, R2_power, AICc_exp,
        AICc_concat, AICc_pow, best_fit_model,
        curvature_model, min_peak
    ]


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




In [9]:
# Inspect the table with the fit-result columns filled in. Only the leading rows
# are shown so the cell output stays bounded regardless of how many rows exist.
MI_DF.head()

Unnamed: 0,MI,MI_var,shuff_MI,shuff_MI_var,distances,exp_results,pow_results,concat_results,R2_exp,R2_concat,R2_power,AICc_exp,AICc_concat,AICc_power,bestfitmodel,curvature,min_peak
0,"[11.491092426440758, 11.113848299482346, 10.94...","[0.0032096905199801226, 0.0031487970998960533,...","[10.162464643140364, 10.161342744045395, 10.16...","[0.0030504540127070092, 0.00305061298348205, 0...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...",<lmfit.minimizer.MinimizerResult object at 0x7...,<lmfit.minimizer.MinimizerResult object at 0x7...,<lmfit.minimizer.MinimizerResult object at 0x7...,0.952085,0.999521,0.995571,-6513.666172,-11115.26092,-8894.92732,pow_exp,"[3.129027579848531e-07, 4.695250164558214e-07,...",881.0


In [10]:
# Select the first row of MI_DF (positionally); the parameter table built below
# reads its fitted concat_results from this row.
# NOTE(review): `row` is also the loop variable of the fitting loop above, so
# this cell must be re-run after that loop for the name to mean "first row".
row = MI_DF.iloc[0]

In [16]:
# Empty table: dataset/age identifiers followed by one (value, stderr) column
# pair for each of the five parameters of the combined model fit.
param_columns = [
    "dataset",
    "age_low",
    "age_high",
    "a_value",
    "a_stderr",
    "b_value",
    "b_stderr",
    "c_value",
    "c_stderr",
    "d_value",
    "d_stderr",
    "f_value",
    "f_stderr",
]
param_df = pd.DataFrame(columns=param_columns)

# Collect the fitted value and standard error of every parameter in the fit result.
fit_params = row.concat_results.params
param_vals = {name: fit_params[name].value for name in fit_params}
param_stderrs = {name: fit_params[name].stderr for name in fit_params}

# Fit-parameter names in the order of the a, b, c, d, f column pairs.
ordered_params = ["e_init", "e_decay_const", "p_init", "p_decay_const", "intercept"]

# No age range is recorded for this dataset, so both age columns are None.
new_row = ['drosophila', None, None]
for name in ordered_params:
    new_row.extend((param_vals[name], param_stderrs[name]))

# Append as the next row of the table and display it.
param_df.loc[len(param_df)] = new_row
param_df

Unnamed: 0,dataset,age_low,age_high,a_value,a_stderr,b_value,b_stderr,c_value,c_stderr,d_value,d_stderr,f_value,f_stderr
0,drosophila,,,0.155142,0.001774,0.014147,0.00016,1.099649,0.00385,-0.505501,0.002269,0.040487,0.00063


In [17]:
from childes_mi.utils.paths import DATA_DIR, FIGURE_DIR, ensure_dir


In [18]:
# Make sure the output folder exists, then store the parameter table in it.
param_dir = DATA_DIR / 'param_dfs'
ensure_dir(param_dir)
param_df.to_pickle(param_dir / 'drosophila.pickle')