# Plot mutual information (MI) decay for language

In [None]:
import pandas as pd
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
from parallelspaper.speech_datasets import LCOL_DICT
import numpy as np
from parallelspaper import model_fitting as mf
from parallelspaper.utils import save_fig
from parallelspaper import information_theory as it 

In [None]:
# Display the language -> plot colour mapping (indexed below as LCOL_DICT[language]).
LCOL_DICT

In [None]:
from matplotlib import gridspec
import matplotlib.pyplot as plt
%matplotlib inline

### Load the sequence and model-fit DataFrames and determine the distance range over which to compute MI

In [None]:
# Sequence DataFrames, one pickle per speech corpus:
# GECO (German), AsiCA (Italian), BUCKEYE (English), CSJ (Japanese).
german_seqs, italian_seqs, english_seqs, japanese_seqs = (
    pd.read_pickle(DATA_DIR / corpus_file)
    for corpus_file in (
        'speech_seq_df/GECO_seq_df.pickle',
        'speech_seq_df/AsiCA_seq_df.pickle',
        'speech_seq_df/BUCKEYE_seq_df.pickle',
        'speech_seq_df/CSJ_seq_df.pickle',
    )
)
# stack all corpora into one frame (german, italian, english, japanese order)
seq_dfs = pd.concat([german_seqs, italian_seqs, english_seqs, japanese_seqs])

In [None]:
# Model-fit results; filtered below by `language` and distance `d`.
fit_df = pd.read_pickle(
    DATA_DIR / 'MI_DF/language/fit_df_long.pickle'
)

In [None]:
# Peek at the first three rows of the fit table.
fit_df.head(3)

In [None]:
# For each language, find the longest distance d whose concatenated-model r^2
# (R2_concat) is still within 99.9% of the r^2 obtained at d == 100.
# The result, language_d[language], is the maximum distance MI is computed to.
language_d = {}
for language in np.unique(fit_df.language):
    language_fit_df = fit_df[fit_df.language == language]
    # reference r^2 at d == 100 (IndexError if that row is missing)
    r2_100 = language_fit_df[language_fit_df.d == 100].R2_concat.values[0]
    within_tol = language_fit_df.R2_concat.values > r2_100 * .999
    # take the largest qualifying distance rather than the last row, so the
    # result does not depend on the row order of fit_df
    lang_d = language_fit_df.d.values[within_tol].max()
    language_d[language] = lang_d
    print(language, lang_d)

#### Compute MI out to the longest distance whose $r^2$ is within 99.9% of the $r^2$ at a distance of 100 phones

In [None]:
# Parallelism / logging settings passed to it.sequential_mutual_information.
verbosity = 0
n_jobs = 20

In [None]:
# (language, levels) pairs that make up the main analyses
subsets = [
    ['german', 'speaker/word/phoneme'],
    ['italian', 'speaker/word/phoneme'],
    ['english', 'speaker/utterance/word/phonetic'],
    ['japanese', 'speaker/word/phonemes'],
]
# subset only the main analyses, concatenated in the order listed above
subset_pieces = []
for subset_language, subset_levels in subsets:
    keep = (seq_dfs.language == subset_language) & (seq_dfs.levels == subset_levels)
    subset_pieces.append(seq_dfs[keep])
subset_seq_df = pd.concat(subset_pieces)

In [None]:
def flatlist(list_of_lists):
    """Flatten one level of nesting.

    Returns a new list containing the items of each inner iterable, in order,
    e.g. ``[[1, 2], [3]] -> [1, 2, 3]``.
    """
    flat = []
    for sublist in list_of_lists:
        flat.extend(sublist)
    return flat

In [None]:
# Re-assert the same parallelism / logging settings right before the MI loop.
n_jobs, verbosity = 20, 0

In [None]:
# Compute sequential MI (plus a shuffled baseline) for each main analysis,
# fit the decay models to the baseline-subtracted signal, and collect one row
# per analysis in MI_DF.
MI_DF = pd.DataFrame(columns=['language', 'unit', 'type', 'MI', 'MI_shuff', 'distances',
                              'MI_var', 'MI_shuff_var', 'results_power', 'results_exp', 'results_pow_exp'])

# NOTE(review): the unpacking assumes subset_seq_df has exactly these three
# columns, in this order.
for idx, (language, levels, data) in subset_seq_df.iterrows():
    levels = levels.split('/')

    distances = np.arange(1, language_d[language] + 1)

    # buckeye has an additional 'utterance' level (second position) to ignore:
    # flatten it out of the data and drop it from the level names
    if language == 'english':
        data = [flatlist(speaker) for speaker in data]
        if len(levels) in (3, 4):
            levels = [levels[0]] + levels[2:]

    if len(levels) == 2:
        # speakers is the highest level of organization so just compute MI
        units = data
    else:
        # concatenate across words within each speaker. Per-speaker sequences
        # have different lengths, so request an object array explicitly:
        # NumPy >= 1.24 raises ValueError for ragged input without dtype=object.
        units = np.array([flatlist(i) for i in data], dtype=object)

    (MI, var_MI), (MI_shuff, MI_shuff_var) = it.sequential_mutual_information(
        units, distances, n_jobs=n_jobs, verbosity=verbosity)

    # baseline-subtracted MI signal that the decay models are fit to
    sig = MI - MI_shuff
    results_power, results_exp, results_pow_exp, best_fit_model = mf.fit_models(
        distances, sig)

    # quick visual check of this analysis' decay
    plt.loglog(distances, sig)
    plt.show()

    MI_DF.loc[len(MI_DF)] = [language, levels[-1], 'session', MI, MI_shuff, distances,
                             var_MI, MI_shuff_var, results_power, results_exp, results_pow_exp]

### plot main decay results for language

In [None]:
# Alias (not a copy) of MI_DF: the column added below is also visible on MI_DF.
subset_MI_DF = MI_DF
# The plots below read the combined power-law/exponential fit from this column.
subset_MI_DF['concat_results'] = subset_MI_DF['results_pow_exp']

In [None]:
yoff = -.20
ncols = 4
zoom = 5
hr = [1, 0.5, 0.5, 0.5]
fig = plt.figure(figsize=(len(subset_MI_DF)*zoom, np.sum(hr)*zoom))
gs = gridspec.GridSpec(ncols=len(subset_MI_DF), nrows=4, height_ratios=hr)


def _style_axis(ax):
    """Apply the shared inward-tick and thick black spine styling to *ax*."""
    ax.tick_params(which='both', direction='in', labelsize=14, pad=10)
    ax.tick_params(which='major', length=10, width=3)
    ax.tick_params(which='minor', length=5, width=2)
    for side in ['top', 'bottom', 'left', 'right']:
        ax.spines[side].set_linewidth(3)
        ax.spines[side].set_color('k')


# One column per analysis. Rows: (0) MI decay with the fitted model and its
# components, (1) power-law component, (2) exponential component,
# (3) curvature of the fitted model.
# Log axes use set_*scale('log'), whose default base is 10 (the former
# basex/basey keywords were removed in Matplotlib 3.5).
for axi, (_, row) in enumerate(subset_MI_DF.sort_values(by=['unit', 'language']).iterrows()):
    color = LCOL_DICT[row.language]
    results = row.concat_results
    sig = np.array(row.MI - row.MI_shuff)
    distances = row.distances

    # y-limits for the MI panels: the upper limit is the signal max padded by
    # 10% of its log range; the lower limit is fixed at 10e-6 for every panel.
    log_lo, log_hi = np.log([np.min(sig[sig > 0]), np.nanmax(sig)])
    mi_ylim = np.exp([np.log(10e-6), log_hi + (log_hi - log_lo) / 10])

    # model curves: dense log-spaced grid for lines, data distances for residuals
    distances_model = np.logspace(0, np.log10(distances[-1]), base=10, num=1000)
    y_model = mf.get_y(mf.pow_exp_decay, results, distances)
    y_pow = mf.get_y(mf.powerlaw_decay, results, distances_model)
    y_exp = mf.get_y(mf.exp_decay, results, distances_model)
    y_pow_dat = mf.get_y(mf.powerlaw_decay, results, distances)
    y_exp_dat = mf.get_y(mf.exp_decay, results, distances)
    intercept = results.params['intercept'].value

    # row 0: data, fitted model, power-law (dotted) and exponential (dashed) parts
    ax0 = plt.subplot(gs[0, axi])
    if axi == 0:
        ax0.set_ylabel('Mutual Information (bits)', labelpad=5, fontsize=18)
        ax0.yaxis.set_label_coords(yoff, 0.5)
    ax0.scatter(distances, sig, alpha=1, s=40, color=color)
    ax0.plot(distances_model, y_pow, ls='dotted', color='k', lw=5, alpha=0.5)
    ax0.plot(distances_model, y_exp - intercept, ls='dashed', color='k', lw=5, alpha=0.5)
    ax0.plot(distances, y_model, alpha=0.5, lw=10, color=color)

    # row 1: power-law component vs. data with the exponential component removed
    ax1 = plt.subplot(gs[1, axi])
    ax1.plot(distances_model, y_pow - intercept, alpha=0.5, lw=10, color=color)
    ax1.scatter(distances, sig - y_exp_dat, alpha=1, s=40, color=color)

    # row 2: exponential component vs. data with the power-law component removed
    ax2 = plt.subplot(gs[2, axi])
    ax2.plot(distances_model, y_exp - intercept, alpha=0.5, lw=10, color=color)
    ax2.scatter(distances, sig - y_pow_dat, alpha=1, s=40, color=color)

    # row 3: curvature of log(model) on a dense log-spaced grid
    ax3 = plt.subplot(gs[3, axi])
    if axi == 0:
        ax3.set_ylabel('Curvature', labelpad=5, fontsize=18)
        ax3.yaxis.set_label_coords(yoff, 0.5)
    ax3.set_yticks([0.0])
    ax3.set_yticklabels(['0.0'])

    curve_distances = np.logspace(0, np.log10(language_d[row.language]), base=10, num=1000)
    y_curve = mf.get_y(mf.pow_exp_decay, results, curve_distances)
    curvature_model = mf.curvature(np.log(y_curve))
    # Interior points lower than both neighbours (local minima of the curvature).
    # The comparison runs over curvature_model[1:-1], so add 1 to turn its
    # positions into indices of curvature_model / curve_distances.
    interior = curvature_model[1:-1]
    is_local_min = (interior < curvature_model[:-2]) & (interior < curvature_model[2:])
    peaks = np.flatnonzero(is_local_min) + 1

    ax3.set_xscale('log')
    _style_axis(ax3)
    ax3.set_xlim([1, 100])
    ax3.plot(curve_distances[5:-5], curvature_model[5:-5], alpha=1, lw=5, color=color)
    ax3.set_ylim([-3e-4, 3e-4])
    ax3.set_xlabel('Distance (phones)', labelpad=5, fontsize=18)
    if peaks.size:
        # mark the first (shortest-distance) local minimum
        peak_distance = curve_distances[peaks[0]]
        ax3.axvline(peak_distance, lw=3, alpha=0.5, color=color, ls='dashed')
        print(row.language, peak_distance)
    else:
        print(row.language, 'no local curvature minimum found')

    if axi == 0:
        for ax in (ax1, ax2):
            ax.set_ylabel('MI (bits)', labelpad=5, fontsize=18)
            ax.yaxis.set_label_coords(yoff, 0.5)

    for ax in (ax0, ax1, ax2):
        ax.set_xlim([curve_distances[0], curve_distances[-1]])
        ax.set_ylim(mi_ylim)
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xticks([])
        _style_axis(ax)
    ax3.set_xticks([1, 10, 100])
    ax3.set_xticklabels(['1', '10', '100'])
    ax2.set_ylim([10e-4, ax2.get_ylim()[1]])

save_fig(FIGURE_DIR/'lang_fig')

### language dataset statistics

In [None]:
import seaborn as sns

In [None]:
def _load_language_stats(relative_path, language_name):
    """Read one corpus' pickled stats DataFrame and tag it with its language."""
    frame = pd.read_pickle(DATA_DIR / relative_path)
    frame['Language'] = language_name
    return frame


german_stats = _load_language_stats('stats_df/GECO_stats_df.pickle', 'German')
italian_stats = _load_language_stats('stats_df/AsiCA_stats_df.pickle', 'Italian')
english_stats = _load_language_stats('stats_df/BUCKEYE_stats_df.pickle', 'English')
japanese_stats = _load_language_stats('stats_df/CSJ_stats_df.pickle', 'Japanese')

# one combined table, distinguished by the 'Language' column
stats_df = pd.concat([german_stats, italian_stats, english_stats, japanese_stats])

In [None]:
# Quick look: normalized histogram of word lengths (in phones) per language.
fig, axs = plt.subplots(ncols=4, figsize=(20, 2))
for ax, lang in zip(axs.flatten(), ['Japanese', 'English', 'German', 'Italian']):
    wlp = stats_df[stats_df.Language == lang].word_length_phones.values[0]
    ax.hist(wlp, bins=np.arange(25), density=True, color=LCOL_DICT[lang.lower()])
    ax.set_xlim([1, 25])

In [None]:
from matplotlib.ticker import FixedLocator

In [None]:
# Word-length distributions per language, plotted on a log2(phones) axis, with
# the median log2 length marked by a dashed line.
# NOTE(review): sns.distplot and the 'bw' KDE keyword are deprecated in
# seaborn >= 0.11; migrate to histplot/kdeplot when upgrading seaborn.
bw = 0.5
kwk = {"lw": 6, 'bw': bw}
yoff = -.20
fig = plt.figure(figsize=(len(subset_MI_DF)*5, zoom/2.3))
gs = gridspec.GridSpec(ncols=len(subset_MI_DF), nrows=1)
bins = np.arange(-.5, 15, .85)
# minor ticks at 1..9, 10..90 and 100..900 phones, placed on the log2 axis
minor_ticks = np.log2(np.array(list(np.arange(1, 10)) + list(np.arange(10, 100, 10)) + list(np.arange(100, 1000, 100))))

for li, language in enumerate(['German', 'Italian', 'Japanese', 'English']):
    color = LCOL_DICT[language.lower()]
    ax = plt.subplot(gs[li])
    word_lens = stats_df[stats_df.Language == language].word_length_phones.values[0]
    log2_word_lens = np.log2(np.array(word_lens))
    sns.distplot(log2_word_lens[log2_word_lens < 15], color=color, ax=ax, bins=bins,
                 kde_kws=kwk)
    ax.axvline(np.median(log2_word_lens), lw=3, alpha=0.5, color=color, ls='dashed')
    ax.set_xlabel('Word length (phones)', labelpad=5, fontsize=18)

    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(3)
        ax.spines[axis].set_color('k')
    ax.grid(False)
    ax.tick_params(which='both', direction='in', labelsize=14, pad=10)
    ax.tick_params(which='major', length=10, width=3)
    ax.tick_params(which='minor', length=5, width=2)

    if li == 0:
        ax.set_ylabel('Prob. Density', labelpad=5, fontsize=18)
        ax.yaxis.set_label_coords(yoff, 0.5)
    else:
        ax.set_yticklabels([])
    # tick positions are log2 values, labelled with the original phone counts
    ax.set_xticks([np.log2(1), np.log2(10), np.log2(100)])
    ax.set_xticklabels(['1', '10', '100'])
    ax.set_xlim([np.log2(1), np.log2(language_d[language.lower()])])
    ax.set_ylim([0, 1])
    ax.xaxis.set_minor_locator(FixedLocator(minor_ticks))

save_fig(FIGURE_DIR/'word_len_dist')