# Plot mutual information (MI) decay for birdsong

In [None]:
import pandas as pd
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
from parallelspaper.birdsong_datasets import BCOL_DICT
import numpy as np
from parallelspaper import model_fitting as mf
from parallelspaper.utils import save_fig


In [None]:
from matplotlib import gridspec
import matplotlib.pyplot as plt
%matplotlib inline

### load data

In [None]:
# Fitted MI-decay table for all birdsong datasets.
mi_path = DATA_DIR / 'MI_DF/birdsong/birdsong_MI_DF_fitted.pickle'
MI_DF = pd.read_pickle(mi_path)

In [None]:
# Distinct dataset types in the table (e.g. 'song', 'shuffled_within', 'shuffled_between').
np.unique(MI_DF['type'])

In [None]:
# Rows for the unshuffled song data.
MI_DF.loc[MI_DF['type'] == 'song']

In [None]:
from tqdm.autonotebook import tqdm

In [None]:
# Relative probability of each candidate model per dataset, computed from
# AICc differences via mf.relative_likelihood / mf.Prob_model_Given_data_and_models.
# Each probability is printed next to its model name, so the output does not
# depend on remembering the column order.
aicc_columns = {'exp': 'AICc_exp', 'pow_exp': 'AICc_concat', 'pow': 'AICc_power'}
for idx, row in tqdm(MI_DF.iterrows(), total=len(MI_DF)):
    # row comes from a mixed-dtype frame, so cast the AICc values to float explicitly
    aiccs = row[list(aicc_columns.values())].values.astype(float)
    delta_aiccs = aiccs - np.min(aiccs)
    relative_likelihoods = mf.relative_likelihood(delta_aiccs)
    prob_models = mf.Prob_model_Given_data_and_models(relative_likelihoods)
    labelled = {name: float(p) for name, p in zip(aicc_columns, prob_models.round(4))}
    print(row.species, row.type, labelled)

In [None]:
# Keep only the bout-shuffled datasets (within-bout and between-bout shuffles).
subset_MI_DF = MI_DF[MI_DF['type'].isin(['shuffled_within', 'shuffled_between'])]

In [None]:
# Panel letters used to label the subplots, in order a..z.
letters = list('abcdefghijklmnopqrstuvwxyz')

In [None]:
# Figure: MI decay (MI minus shuffled baseline) for each bout-shuffled birdsong
# dataset, one panel per (shuffle type, species), with the lowest-AICc model overlaid.
fontsize = 18
yoff = -.20
ncol = 4
# Round up so a partially filled last row still gets axes; floor division would
# raise IndexError (or request zero rows) when the count is not a multiple of ncol.
nrow = max(1, int(np.ceil(len(subset_MI_DF) / ncol)))
zoom = 5

# Candidate models: name -> (AICc column, decay function, fitted-results column).
# Dict order (exp, pow, pow_exp) decides which model wins an exact AICc tie.
models = {
    'exp': ('AICc_exp', mf.exp_decay, 'exp_results'),
    'pow': ('AICc_power', mf.powerlaw_decay, 'pow_results'),
    'pow_exp': ('AICc_concat', mf.pow_exp_decay, 'concat_results'),
}
aicc_cols = [aicc_col for aicc_col, _, _ in models.values()]
model_labels = {'pow_exp': 'comp.', 'exp': 'exp.', 'pow': 'power law'}
spc = {'Starling': 'starling', 'CAVI': 'vireo', 'CATH': 'thrasher', 'BF': 'finch'}
n_panels = len(subset_MI_DF)

# squeeze=False keeps axs 2-D even when there is a single row.
fig, axs = plt.subplots(ncols=ncol, nrows=nrow, figsize=zoom * np.array([ncol, nrow]), squeeze=False)
axs_flat = axs.flatten()
for axi, (idx, row) in enumerate(subset_MI_DF.sort_values(by=['type', 'species']).iterrows()):
    ax = axs_flat[axi]
    # panel lettering
    ax.annotate(letters[axi], xy=(-0.05, 1.1), xycoords="axes fraction", size=20, fontweight='bold', fontfamily='Arial')
    color = BCOL_DICT[row.species]
    # MI above the shuffled baseline
    sig = np.array(row.MI - row.MI_shuff)
    distances = row.distances

    # y-limits in log space: lower bound is the smallest positive value padded by
    # 10% of the log range (log axes cannot show values <= 0). The upper bound is
    # pinned at log-value 1, i.e. exp(1) ~= 2.72 bits on the axis.
    # NOTE(review): confirm the pinned upper limit is intended.
    log_lo, log_hi = np.log([np.min(sig[sig > 0]), np.nanmax(sig)])
    sig_lims = [log_lo - (log_hi - log_lo) / 10, 1]

    # axis labels only on the left column and on panels with no panel below them
    if axi % ncol == 0:
        ax.set_ylabel('Mutual Information (bits)', labelpad=5, fontsize=fontsize)
        ax.yaxis.set_label_coords(yoff, 0.5)
    if axi + ncol >= n_panels:
        ax.set_xlabel('Distance (syllables)', labelpad=5, fontsize=fontsize)

    # plot real data
    ax.scatter(distances, sig, alpha=1, s=40, color=color)

    best_fit_model = list(models)[np.argmin(row[aicc_cols].values.astype(float))]

    # set title
    analysis = 'within bout' if row.type == 'shuffled_within' else 'between bout'
    ax.set_title(' | '.join([spc[row.species], analysis, model_labels[best_fit_model]]), fontsize=16)

    # plot the best-fitting model on a dense log-spaced grid
    distances_model = np.logspace(0, np.log10(distances[-1]), base=10, num=1000)
    _, decay_fn, results_col = models[best_fit_model]
    results = row[results_col]
    if best_fit_model == 'pow_exp':
        # min_peak is used as an index into distances_model (presumably the
        # power-law / exponential crossover; confirm against the fitting code)
        ax.axvline(distances_model[int(row.min_peak)], lw=3, alpha=0.5, color=color, ls='dashed')
        # individual components of the composite fit
        y_pow = mf.get_y(mf.powerlaw_decay, results, distances_model)
        y_exp = mf.get_y(mf.exp_decay, results, distances_model)
        ax.plot(distances_model, y_pow, ls='dotted', color='k', lw=5, alpha=0.5)
        ax.plot(distances_model, y_exp - results.params['intercept'].value, ls='dashed', color='k', lw=5, alpha=0.5)
    y_model = mf.get_y(decay_fn, results, distances_model)
    ax.plot(distances_model, y_model, alpha=0.5, lw=10, color=color)

    # axis params; log base defaults to 10 (the basex/basey keywords were
    # removed in Matplotlib 3.5)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_ylim(np.exp(sig_lims))
    ax.set_xlim([1, 100])
    ax.set_xticks([1, 10, 100])
    ax.set_xticklabels(['1', '10', '100'])
    ax.tick_params(which='both', direction='in', labelsize=14, pad=10)
    ax.tick_params(which='major', length=10, width=3)
    ax.tick_params(which='minor', length=5, width=2)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(3)
        ax.spines[side].set_color('k')

# hide grid cells left empty when the panel count is not a multiple of ncol
for unused_ax in axs_flat[n_panels:]:
    unused_ax.axis('off')

plt.subplots_adjust(hspace=0.30)

save_fig(FIGURE_DIR / 'song_shuffle_bout')