# Model component contributions by distance

In [None]:
import pandas as pd
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
from parallelspaper.birdsong_datasets import BCOL_DICT
import numpy as np
from parallelspaper import model_fitting as mf
from parallelspaper.utils import save_fig
from parallelspaper import information_theory as it 

In [None]:
from matplotlib import gridspec
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# stats: per-dataset song statistics, one pickle per dataset, read in a fixed
# order and stacked (CAVICATH, starling, BF) into a single frame.
_stats_paths = (
    'stats_df/CAVICATH_stats_df.pickle',
    'stats_df/starling_stats_df.pickle',
    'stats_df/BF_stats_df.pickle',
)
_stats_frames = [pd.read_pickle(DATA_DIR / rel) for rel in _stats_paths]
CAVICATH_stats_df, Starling_stats_df, BF_stats_df = _stats_frames

stats_df = pd.concat(_stats_frames)

In [None]:
 MI_DF = pd.read_pickle(DATA_DIR / 'MI_DF/birdsong/birdsong_MI_DF_fitted.pickle')

In [None]:
# Keep only the rows whose `type` column equals 'day'.
subset_MI_DF = MI_DF.loc[MI_DF['type'] == 'day']

In [None]:
# Show the filtered 'day' rows as this cell's notebook output.
subset_MI_DF

In [None]:
# 1000 log-spaced points from 10**0 = 1 up to and including 10**2 = 100.
# NOTE(review): the plotting loop indexes this array with np.argmax(row.curvature),
# so it is presumably the grid the curvature was evaluated on — confirm.
distances = np.logspace(start=0, stop=2, num=1000, base=10)

In [None]:
# One log-log panel per row of subset_MI_DF showing three curves, all
# evaluated at row.distances with the row's fitted concat_results:
#   solid  -> mf.pow_exp_decay (the composite model)
#   dotted -> mf.powerlaw_decay
#   dashed -> mf.exp_decay
# For each row, also print: species, the mean ratio of the power-law curve
# to the composite model below the curvature peak, and the peak distance
# scaled by (median syllable duration + median ISI).
#
# The panel count follows the data rather than a hard-coded 4, and
# squeeze=False keeps `axs` indexable even when there is only one row.
n_panels = len(subset_MI_DF)
fig, axs = plt.subplots(ncols=n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axs = axs[0]
for axi, (_, row) in enumerate(subset_MI_DF.iterrows()):
    color = BCOL_DICT[row.species]

    # Median syllable duration and inter-syllable interval for this species,
    # taken from the first matching stats_df row. Assumes each cell holds an
    # array of per-syllable values — TODO confirm.
    birdrow = stats_df.species.values == row.species
    median_syllable_len = np.median(stats_df[birdrow].syllable_duration_s.values[0])
    median_isi = np.median(stats_df[birdrow].isi.values[0])

    # Distance at which the curvature peaks.
    # NOTE(review): an index into row.curvature is used on the global
    # `distances` grid, so this assumes the curvature was evaluated on that
    # same 1000-point grid — confirm.
    max_peak_dist = distances[int(np.argmax(row.curvature))]
    lower_mask = row.distances < max_peak_dist

    y_model = mf.get_y(mf.pow_exp_decay, row.concat_results, row.distances)
    y_pow = mf.get_y(mf.powerlaw_decay, row.concat_results, row.distances)
    y_exp = mf.get_y(mf.exp_decay, row.concat_results, row.distances)

    # np.mean is equivalent to sum(ratio) / count(mask). It yields nan, with a
    # RuntimeWarning, if no distance falls below the peak.
    print(
        row.species,
        np.mean(y_pow[lower_mask] / y_model[lower_mask]),
        # Presumably converts the peak distance (in syllables) to seconds.
        (median_syllable_len + median_isi) * max_peak_dist,
    )

    ax = axs[axi]
    ax.loglog(row.distances, y_model, color=color)
    ax.loglog(row.distances, y_pow, color=color, ls='dotted')
    ax.loglog(row.distances, y_exp, color=color, ls='dashed')
