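# Calculate MI for each unit/language
1. Load the datasets (precomputed MI with fitted decay models)
2. Calculate the shuffle-corrected MI and plot it with the best-fitting decay model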

In [None]:
import pandas as pd
import numpy as np
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
from parallelspaper.speech_datasets import LCOL_DICT

from parallelspaper import information_theory as it 
from parallelspaper.quickplots import plot_model_fits
from tqdm.autonotebook import tqdm
from parallelspaper import model_fitting as mf
from parallelspaper.utils import save_fig

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the model-fitted MI table (a pickled pandas DataFrame). Columns read
# by the plotting cell below include MI, MI_shuff, distances, AICc_* and the
# *_results fit objects. Only unpickle files you trust.
MI_DF = pd.read_pickle(
    DATA_DIR / 'MI_DF/language/language_MI_DF_fitted-utterance.pickle'
)

### Plot the within- vs. between-utterance shuffling analysis for Japanese and English

In [None]:
# Grid of shuffle-corrected MI-decay panels, one per row of MI_DF (sorted by
# analysis, language and unit). Each panel plots MI - MI_shuff against
# distance on log-log axes and overlays the decay model with the lowest AICc.
fontsize = 18
yoff = -.20  # x position (axes coordinates) of the y-axis label
ncol = 4
# Round up so every row of MI_DF gets a panel: floor division dropped the
# remainder (IndexError on axs) and gave nrows=0 for fewer than ncol rows.
nrow = max(1, int(np.ceil(len(MI_DF) / ncol)))
zoom = 5
fig, axs = plt.subplots(ncols=ncol, nrows=nrow, figsize=zoom * np.array([ncol, nrow]))
flat_axs = axs.flatten()

# Hide grid cells left over when len(MI_DF) is not a multiple of ncol.
for ax in flat_axs[len(MI_DF):]:
    ax.set_visible(False)

sorted_df = MI_DF.sort_values(by=['analysis', 'language', 'unit'])
for axi, (idx, row) in enumerate(sorted_df.iterrows()):
    ax = flat_axs[axi]
    color = LCOL_DICT[row.language]

    # Shuffle-corrected signal: MI minus the MI of the shuffled control.
    sig = np.array(row.MI - row.MI_shuff)
    distances = row.distances

    # Axis labels only on the left column and the bottom row.
    if axi % ncol == 0:
        ax.set_ylabel('Mutual Information (bits)', labelpad=5, fontsize=fontsize)
        ax.yaxis.set_label_coords(yoff, 0.5)
    if axi >= (nrow - 1) * ncol:
        # NOTE(review): label is hard-coded to phones regardless of row.unit.
        ax.set_xlabel('Distance (phones)', labelpad=5, fontsize=fontsize)

    # Raw data.
    ax.scatter(distances, sig, alpha=1, s=40, color=color)

    # Pick the model with the smallest AICc (ties go to the first listed).
    aicc = row[['AICc_exp', 'AICc_power', 'AICc_concat']].values
    best_fit_model = np.array(['exp', 'pow', 'pow_exp'])[np.argmin(aicc)]

    # Title: language | shuffle analysis | best model.
    analysis = 'within utt.' if row.analysis == 'shuffled_within_utterance' else 'between utt.'
    model_type = {'pow_exp': 'comp.', 'exp': 'exp.', 'pow': 'power law'}[best_fit_model]
    ax.set_title(' | '.join([row.language.capitalize(), analysis, model_type]), fontsize=16)

    # Dense log-spaced grid from 1 to the largest observed distance.
    distances_model = np.logspace(0, np.log10(distances[-1]), base=10, num=1000)

    if best_fit_model == 'pow_exp':
        # assumes row.min_peak indexes this 1000-point grid — TODO confirm
        ax.axvline(distances_model[int(row.min_peak)], lw=3, alpha=0.5, color=color, ls='dashed')

        # Composite model plus its power-law and exponential components.
        y_model = mf.get_y(mf.pow_exp_decay, row.concat_results, distances_model)
        y_pow = mf.get_y(mf.powerlaw_decay, row.concat_results, distances_model)
        y_exp = mf.get_y(mf.exp_decay, row.concat_results, distances_model)

        ax.plot(distances_model, y_pow, ls='dotted', color='k', lw=5, alpha=0.5)
        # The exponential component is drawn with the fitted intercept removed.
        ax.plot(distances_model, y_exp - row.concat_results.params['intercept'].value,
                ls='dashed', color='k', lw=5, alpha=0.5)
    elif best_fit_model == 'pow':
        y_model = mf.get_y(mf.powerlaw_decay, row.pow_results, distances_model)
    else:  # 'exp'
        y_model = mf.get_y(mf.exp_decay, row.exp_results, distances_model)

    # Best-fit model curve.
    ax.plot(distances_model, y_model, alpha=0.5, lw=10, color=color)

    # y-limits: fixed floor of 1e-5; ceiling is the max signal padded by 10%
    # of the log range of the positive signal. Skipped when nothing is
    # positive (np.min on an empty array would raise).
    positive = sig[sig > 0]
    if positive.size:
        log_lo = np.log(np.min(positive))
        log_hi = np.log(np.nanmax(sig))
        ax.set_ylim(1e-5, np.exp(log_hi + (log_hi - log_lo) / 10))

    ax.tick_params(which='both', direction='in', labelsize=14, pad=10)
    ax.tick_params(which='major', length=10, width=3)
    ax.tick_params(which='minor', length=5, width=2)
    # Base 10 is the default; the old basex/basey kwargs were removed in
    # matplotlib 3.5.
    ax.set_xscale("log")
    ax.set_yscale("log")
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(3)
        ax.spines[axis].set_color('k')

    # Fixed x-range shared by all panels (data beyond 100 is clipped).
    ax.set_xlim([1, 100])
    ax.set_xticks([1, 10, 100])
    ax.set_xticklabels(['1', '10', '100'])

save_fig(FIGURE_DIR / 'speech_shuffle_utterance')