In [None]:
%load_ext autoreload
%autoreload 2

# Plot MI decay of Markov, Hierarchical, and Hybrid models
1. load models
2. fit decay models and select the best fit (lowest AICc)
3. plot decay

In [None]:
from glob import glob
import pandas as pd
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
import parallelspaper.model_fitting as mf
from datetime import datetime
from tqdm.autonotebook import tqdm
import numpy as np
from parallelspaper.utils import save_fig

### Load models

In [None]:
# Find every pickled Markov-model MI dataframe and load the most recent one.
# The save time is read from the filename: markov_<YYYY-MM-DD>_<HH-MM-SS>.pickle
MI_DFs_markov = glob(str(DATA_DIR / 'MI_DF/models/markov_*.pickle'))
loc_table = pd.DataFrame(
    {
        # drop the directory and the 7-character '.pickle' suffix, then parse
        # everything after the first underscore as the timestamp
        'dt': [
            datetime.strptime(path.split('/')[-1][:-7].partition('_')[2], "%Y-%m-%d_%H-%M-%S")
            for path in MI_DFs_markov
        ],
        'locat': MI_DFs_markov,
    }
).sort_values(by='dt')
# rows are ordered oldest to newest, so the last location is the latest run
markov_MI_DF = pd.read_pickle(loc_table['locat'].values[-1])

In [None]:
# Find every pickled hierarchical-model MI dataframe and load the most recent one.
# The save time is read from the filename: hierarchical_<YYYY-MM-DD>_<HH-MM-SS>.pickle
MI_DFs_hierarchical = glob(str(DATA_DIR / 'MI_DF/models/hierarchical_*.pickle'))
loc_table = pd.DataFrame(
    {
        # drop the directory and the 7-character '.pickle' suffix, then parse
        # everything after the first underscore as the timestamp
        'dt': [
            datetime.strptime(path.split('/')[-1][:-7].partition('_')[2], "%Y-%m-%d_%H-%M-%S")
            for path in MI_DFs_hierarchical
        ],
        'locat': MI_DFs_hierarchical,
    }
).sort_values(by='dt')
# rows are ordered oldest to newest, so the last location is the latest run
hierarchical_MI_DF = pd.read_pickle(loc_table['locat'].values[-1])

In [None]:
# Find every pickled hybrid-model MI dataframe and load the most recent one.
# The save time is read from the filename: hybrid_<YYYY-MM-DD>_<HH-MM-SS>.pickle
MI_DFs_hybrid = glob(str(DATA_DIR / 'MI_DF/models/hybrid_*.pickle'))
loc_table = pd.DataFrame(
    {
        # drop the directory and the 7-character '.pickle' suffix, then parse
        # everything after the first underscore as the timestamp
        'dt': [
            datetime.strptime(path.split('/')[-1][:-7].partition('_')[2], "%Y-%m-%d_%H-%M-%S")
            for path in MI_DFs_hybrid
        ],
        'locat': MI_DFs_hybrid,
    }
).sort_values(by='dt')
# rows are ordered oldest to newest, so the last location is the latest run
hybrid_MI_DF = pd.read_pickle(loc_table['locat'].values[-1])

In [None]:
# Stack the three per-model dataframes into a single table. Resetting the index
# gives every row its own integer label (the previous labels are kept as a column),
# which the row-wise processing further down relies on.
MI_DF = pd.concat(
    objs=(markov_MI_DF, hierarchical_MI_DF, hybrid_MI_DF),
).reset_index(drop=False)

### Fit models

In [None]:
# Fit exponential, power-law and composite (power-law + exponential) decay models to
# the shuffle-corrected MI of every row, then attach the fits, their parameters, the
# AICc of each model and the name of the lowest-AICc model to MI_DF.
#
# Results are gathered per row and written as whole columns after the loop. Writing
# into MI_DF cell by cell while iterating over it is discouraged by pandas, and
# storing dicts / fit objects through `.loc[idx, col] = ...` depends on
# version-specific scalar-vs-list-like handling.
fit_columns = [
    'exp_results_params', 'pow_results_params', 'concat_results_params',
    'exp_results', 'pow_results', 'concat_results',
    'AICc_exp', 'AICc_concat', 'AICc_power',
    'bestfitmodel',
]
fit_records = []
for idx, row in tqdm(MI_DF.iterrows(), total=len(MI_DF)):
    print(row['name'])

    # get signal: MI minus the MI of the shuffled baseline
    sig = np.array(row.MI - row.MI_shuff)
    distances = row.distances

    # fit models (the best-fit pick returned here is not stored; the model
    # recorded in MI_DF is chosen by AICc below)
    results_power, results_exp, results_pow_exp, best_fit_model = mf.fit_models(distances, sig)
    results_concat = results_pow_exp

    # get model fit results from predictions and signal
    # NOTE(review): the unpacking order is taken as-is from the original call;
    # confirm it against the return order of mf.fit_results
    R2_exp, R2_concat, R2_power, AICc_exp, \
        AICc_pow, AICc_concat = mf.fit_results(sig, distances,
                                              results_exp, results_power,
                                              results_pow_exp, logscaled=True)

    # determine best fit model: lowest AICc wins, ties go to the first listed
    bestfitmodel = ['exp', 'concat', 'power'][np.argmin([AICc_exp, AICc_concat, AICc_pow])]

    fit_records.append({
        'exp_results_params': {i: results_exp.params[i].value for i in dict(results_exp.params).keys()},
        'pow_results_params': {i: results_power.params[i].value for i in dict(results_power.params).keys()},
        'concat_results_params': {i: results_concat.params[i].value for i in dict(results_concat.params).keys()},
        'exp_results': results_exp,
        'pow_results': results_power,
        'concat_results': results_concat,
        'AICc_exp': AICc_exp,
        'AICc_concat': AICc_concat,
        'AICc_power': AICc_pow,
        'bestfitmodel': bestfitmodel,
    })

# add results to MI_DF, one column at a time and aligned on MI_DF's own index
for col in fit_columns:
    col_values = [record[col] for record in fit_records]
    if col.startswith('AICc_'):
        # numeric scores: let pandas pick the float dtype
        MI_DF[col] = pd.Series(col_values, index=MI_DF.index)
    else:
        # dicts, fit-result objects and labels are stored as-is, one object per row
        MI_DF[col] = pd.Series(col_values, index=MI_DF.index, dtype=object)

In [None]:
# Display the model table, which at this point also carries the fit results,
# AICc values and best-fit label written by the fitting loop.
MI_DF

### Plot fitted models

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import seaborn as sns
# Colours per model family: the last three entries of a five-colour seaborn palette.
col_dict = {
    family: sns.color_palette(palette_name, 5)[2:]
    for family, palette_name in (
        ("hierarchical", "Reds"),
        ("markov", "Greens"),
        ("hybrid", "Blues"),
    )
}

In [None]:
# Draw one log-log panel per model family (hierarchical, Markov, hybrid): the fitted
# decay curve as a thick translucent line and the shuffle-corrected MI as points.
ncols = 3
nrows = 1
zoom = 5  # NOTE(review): not referenced in this cell; kept in case other cells read it
d = 100  # upper end of the distance range over which the fitted curves are drawn

# One panel per model family. Each entry gives the panel index, the decay function
# whose fit is overlaid, and the MI_DF column that stores that fit. A family is
# matched by substring against row['name'], checked in this order.
panel_spec = {
    'hierarchical': (0, mf.powerlaw_decay, 'pow_results'),
    'markov': (1, mf.exp_decay, 'exp_results'),
    'hybrid': (2, mf.pow_exp_decay, 'concat_results'),
}
# The Markov models share a panel, so each named model gets its own shade.
markov_colors = dict(zip(['Okada_markov', 'Bird2_markov', 'Bird1_markov'], col_dict['markov']))

# plot data
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols*6, nrows*5))
# the grid on which fitted curves are evaluated is the same for every row
model_distances = np.linspace(1, d, 1000)
for idx, row in tqdm(MI_DF.iterrows(), total=len(MI_DF)):
    family = next((key for key in panel_spec if key in row['name']), None)
    if family is None:
        # not one of the three families: skip it instead of drawing on whatever
        # axis and colour the previous row left behind
        continue

    panel, decay_fn, result_col = panel_spec[family]
    ax = axs[panel]
    if family == 'markov':
        # an unlisted Markov model falls back to the middle shade rather than
        # inheriting a stale colour (or failing if it is the first row)
        color = markov_colors.get(row['name'], col_dict['markov'][1])
    else:
        color = col_dict[family][1]

    # plot fitted model
    y_model = mf.get_y(decay_fn, row[result_col], model_distances)
    ax.plot(model_distances, y_model, alpha=0.5, lw=10, color=color)

    # plot real data
    sig = np.array(row.MI - row.MI_shuff)
    ax.scatter(row.distances, sig, alpha=1, s=80, color=color)

# labels, styling
for axi, ax in enumerate(axs):
    ax.tick_params(axis='both', labelsize=18, pad=15)
    ax.set_xlabel('Distance between elements', labelpad=5, fontsize=18)
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(3)
        ax.spines[axis].set_color('k')
    ax.grid(False)
    ax.tick_params(which='both', direction='in')
    ax.tick_params(which='major', length=10, width=3)
    ax.tick_params(which='minor', length=5, width=2)

    # Log scales default to base 10. The old basex/basey keywords were removed
    # from Matplotlib, so relying on the default works on old and new versions.
    ax.set_xscale("log")
    ax.set_yscale("log")

    ax.set_xticks([1, 10, 100])
    ax.set_xticklabels([1, 10, 100])
    if axi == 0:
        # only the left-most panel carries the shared y-axis label
        ax.set_ylabel('Mutual Information (bits)', labelpad=5, fontsize=18)
    ax.set_xlim([1, 100])
    ax.set_ylim([1e-4, 10])
plt.tight_layout()

save_fig(FIGURE_DIR/'modelfig')