# Plot mutual information (MI) decay for birdsong

In [None]:
import pandas as pd
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
from parallelspaper.birdsong_datasets import BCOL_DICT
import numpy as np
from parallelspaper import model_fitting as mf
from parallelspaper.utils import save_fig

In [None]:
from matplotlib import gridspec
import matplotlib.pyplot as plt
%matplotlib inline

### Load data

In [None]:
# read the pickled birdsong MI table from the data directory
mi_pickle_path = DATA_DIR / 'MI_DF/birdsong/birdsong_MI_DF_fitted.pickle'
MI_DF = pd.read_pickle(mi_pickle_path)

In [None]:
# keep only the rows whose `type` column equals 'compress'
is_compress = MI_DF['type'] == 'compress'
subset_MI_DF = MI_DF[is_compress]

In [None]:
# show the 'compress' subset as the cell output
subset_MI_DF

### Plot main decay results for birdsong

In [None]:
# Draw one column of four stacked panels per row of `subset_MI_DF`:
#   row 0: MI minus shuffled MI, with the composite model and its components
#   row 1: power-law component (data with the exponential model removed)
#   row 2: exponential component (data with the power-law model removed)
#   row 3: the stored curvature trace with the selected peak marked
# The finished figure is written out with `save_fig`.


def _style_axis(ax):
    """Apply the tick and spine styling shared by every panel."""
    ax.tick_params(which='both', direction='in', labelsize=14, pad=10)
    ax.tick_params(which='major', length=10, width=3)
    ax.tick_params(which='minor', length=5, width=2)
    for side in ['top', 'bottom', 'left', 'right']:
        ax.spines[side].set_linewidth(3)
        ax.spines[side].set_color('k')


yoff = -.20  # x position (axes coordinates) of the left-hand y-axis labels
ncols = 4  # not used by this figure; kept in case later cells rely on it
zoom = 5  # figure inches per column / per unit of height ratio
hr = [1, 0.5, 0.5, 0.5]  # relative heights of the four panel rows
nrows = np.ceil(len(subset_MI_DF)/ncols).astype(int)  # not used by this figure
fig = plt.figure(figsize=(len(subset_MI_DF)*zoom, np.sum(hr)*zoom))
gs = gridspec.GridSpec(ncols=len(subset_MI_DF), nrows=len(hr), height_ratios=hr)

# x grid for the curvature panel: 1000 log-spaced points from 1 to 100.
# assumes row.curvature and row.min_peak refer to this same grid -- TODO confirm
curv_distances = np.logspace(0, np.log10(100), base=10, num=1000)

ordered_rows = subset_MI_DF.sort_values(by=['type', 'species']).iterrows()
for axi, (idx, row) in enumerate(ordered_rows):
    color = BCOL_DICT[row.species]
    is_first_column = axi == 0
    fit = row.concat_results
    intercept = fit.params['intercept'].value

    # signal = MI minus its shuffled counterpart
    sig = np.array(row.MI - row.MI_shuff)
    distances = row.distances

    # y limits: log-range of the positive signal, padded by 10% on each side
    log_lo, log_hi = np.log([np.min(sig[sig > 0]), np.nanmax(sig)])
    log_pad = (log_hi - log_lo) / 10
    y_lims = np.exp([log_lo - log_pad, log_hi + log_pad])

    # model curves on a dense log-spaced grid (for lines) and on the
    # measured distances (for the composite fit and the residual scatters)
    distances_model = np.logspace(0, np.log10(distances[-1]), base=10, num=1000)
    y_model = mf.get_y(mf.pow_exp_decay, fit, distances)
    y_pow = mf.get_y(mf.powerlaw_decay, fit, distances_model)
    y_exp = mf.get_y(mf.exp_decay, fit, distances_model)
    y_pow_dat = mf.get_y(mf.powerlaw_decay, fit, distances)
    y_exp_dat = mf.get_y(mf.exp_decay, fit, distances)

    # --- row 0: data, both components and the composite model ----------
    ax0 = plt.subplot(gs[0, axi])
    if is_first_column:
        ax0.set_ylabel('Mutual Information (bits)', labelpad=5, fontsize=18)
        ax0.yaxis.set_label_coords(yoff, 0.5)
    ax0.scatter(distances, sig, alpha=1, s=40, color=color)
    # the power-law line is drawn as returned by get_y; the exponential
    # line has the intercept subtracted (as in the original figure)
    ax0.plot(distances_model, y_pow, ls='dotted', color='k', lw=5, alpha=0.5)
    ax0.plot(distances_model, y_exp - intercept, ls='dashed', color='k', lw=5, alpha=0.5)
    ax0.plot(distances, y_model, alpha=0.5, lw=10, color=color)

    # --- row 1: power-law component ------------------------------------
    ax1 = plt.subplot(gs[1, axi])
    ax1.plot(distances_model, y_pow - intercept, alpha=0.5, lw=10, color=color)
    ax1.scatter(distances, sig - y_exp_dat, alpha=1, s=40, color=color)

    # --- row 2: exponential component ----------------------------------
    ax2 = plt.subplot(gs[2, axi])
    ax2.plot(distances_model, y_exp - intercept, alpha=0.5, lw=10, color=color)
    ax2.scatter(distances, sig - y_pow_dat, alpha=1, s=40, color=color)

    # --- row 3: curvature ----------------------------------------------
    ax3 = plt.subplot(gs[3, axi])
    if is_first_column:
        ax3.set_ylabel('Curvature', labelpad=5, fontsize=18)
        ax3.yaxis.set_label_coords(yoff, 0.5)
    # a single tick at zero; only the first column shows its label
    ax3.set_yticks([0.0])
    ax3.set_yticklabels(['0.0'] if is_first_column else [])

    # the first and last five samples of the trace are not drawn
    curvature_model = row.curvature
    ax3.plot(curv_distances[5:-5], curvature_model[5:-5], alpha=1, lw=5, color=color)
    ax3.set_ylim([-1e-4, 1e-4])

    # mark the grid position indexed by the stored peak
    peak_of_interest = int(row.min_peak)
    ax3.axvline(curv_distances[peak_of_interest], lw=3, alpha=0.5, color=color, ls='dashed')
    ax3.set_xlabel('Distance (syllables)', labelpad=5, fontsize=18)
    print(row.species, curv_distances[peak_of_interest])

    _style_axis(ax3)
    # Base 10 is Matplotlib's default log base.  The old `basex`/`basey`
    # keywords were removed in Matplotlib 3.5, so no base is passed; this
    # form works on both older and newer releases.
    ax3.set_xscale("log")
    ax3.set_xlim([1, 100])

    # --- shared formatting of the three MI panels ----------------------
    if is_first_column:
        for ax in [ax1, ax2]:
            ax.set_ylabel('MI (bits)', labelpad=5, fontsize=18)
            ax.yaxis.set_label_coords(yoff, 0.5)

    for ax in [ax0, ax1, ax2]:
        # same 1-100 x range as the curvature panel (the original code
        # reached this by reusing the reassigned `distances` variable)
        ax.set_xlim([curv_distances[0], curv_distances[-1]])
        _style_axis(ax)
        ax.set_xscale("log")
        ax.set_yscale("log")
        # x ticks are hidden on the upper panels; only row 3 carries them
        ax.set_xticks([])
        ax.set_ylim(y_lims)

    # set after the log scale is applied so the explicit ticks are kept
    ax3.set_xticks([1, 10, 100])
    ax3.set_xticklabels(['1', '10', '100'])

save_fig(FIGURE_DIR/'song_fig_compressed')

### Dataset statistics

In [None]:
# load the three pickled per-dataset statistics tables, in this order:
# CAVICATH, starling, BF
CAVICATH_stats_df, Starling_stats_df, BF_stats_df = (
    pd.read_pickle(DATA_DIR / relative_path)
    for relative_path in (
        'stats_df/CAVICATH_stats_df.pickle',
        'stats_df/starling_stats_df.pickle',
        'stats_df/BF_stats_df.pickle',
    )
)

In [None]:
# stack the three statistics tables row-wise (original indices are kept)
stats_df = pd.concat(
    [
        CAVICATH_stats_df,
        Starling_stats_df,
        BF_stats_df,
    ]
)

In [None]:
# show the combined statistics table as the cell output
stats_df

### Plot distribution

In [None]:
import seaborn as sns
from matplotlib.ticker import FixedLocator

In [None]:
# one histogram panel per species label, laid out in a single row
fig, axs = plt.subplots(ncols=4, figsize=(20,2))
for i,l in enumerate(['CAVI', 'CATH', 'Starling', 'BF']):
    ax = axs.flatten()[i]
    # first `recording_duration_syllable` entry among the rows for this species
    # (presumably a sequence of recording lengths in syllables -- TODO confirm)
    wlp =stats_df[stats_df.species==l].recording_duration_syllable.values[0]
    # NOTE(review): fraction of entries equal to 1; the result is not stored
    # or displayed here -- confirm whether it is still needed
    np.sum(np.array(wlp) == 1)/len(wlp)
    # density-normalised histogram with bin edges at the integers 0-99,
    # drawn in the species colour
    ax.hist(wlp,bins=np.arange(100), density=True, color = BCOL_DICT[l])
    ax.set_xlim([1,100])