In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy import stats
import seaborn as sns
from altair import *
import itertools
from pipeline import Pipeline

# Seaborn "white" theme with fully transparent axes backgrounds (RGBA alpha 0),
# which lets overlapping facets (e.g. a negative-hspace ridge plot) show through.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
# Tick-visibility overrides, applied after set_theme: no bottom x ticks, keep left y ticks.
plt.rcParams.update({'xtick.bottom': False, 'ytick.left': True})

#### A -> V

Distribution of cross-correlation peak lags, one row per listener AU, coloured by `condition`.

In [None]:
# Pipeline for the A -> V analysis ('va' modality).
# NOTE(review): the trf_* / regularization arguments are presumably TRF model
# settings; their units and semantics live in pipeline.py — confirm there.
va_pipeline = Pipeline(
    trf_direction=1,
    trf_min_lag=0,
    trf_max_lag=3,
    regularization=1,
    modality='va',
    audio_type='auditory_nerve',
    similarity_measure='r2',
)

In [None]:
# Build the main A -> V data frame, then the cross-correlation table derived from it
# (the ridge plot below reads its listener_au, condition and peak_lag columns).
df = va_pipeline.make_main_df()
df_crossCorr_va = va_pipeline.make_crosscorr_df(df)

In [None]:
# Ridge plot of A -> V cross-correlation peak lags: one overlapping row per
# listener AU, with one density per condition (hue).
g = sns.FacetGrid(df_crossCorr_va, row="listener_au", hue="condition", aspect=10, height=0.5, palette='Set2', legend_out=True)
# Two passes: filled densities first, then their outlines on top.
g.map(sns.kdeplot, "peak_lag", bw_adjust=.5, clip_on=False, fill=True, lw=1)
g.map(sns.kdeplot, "peak_lag", bw_adjust=.5, clip_on=False, lw=1)
g.refline(x=0, lw=0.5, alpha=1, linestyle="--", clip_on=False)
g.refline(y=0, lw=1, alpha=1, linestyle="-", clip_on=False)
# Label each row with the AU it actually shows. FacetGrid takes the row order
# from the data (order of appearance / categorical order), not from
# va_pipeline.aus, so indexing that list by axis position could mislabel rows
# or run past its end. axes_dict maps each row level to its Axes.
for au, ax in g.axes_dict.items():
    # NOTE(review): ax.lines[-1] is the last line drawn on this axis, i.e. the
    # y=0 refline above, so the label takes the refline colour, not a hue colour.
    ax.text(0, .2, au, fontsize=8, fontweight="bold", ha="left", va="center",
            color=ax.lines[-1].get_color(), transform=ax.transAxes)
# Negative spacing makes neighbouring rows overlap (ridge-plot look).
g.figure.subplots_adjust(hspace=-.5)
g.set_titles("")
g.set(yticks=[], ylabel="")
g.despine(left=True, bottom=True)
g.set_axis_labels('Lag [s]', '')
# map() with hue does not draw a legend on its own; without one the conditions
# cannot be told apart. legend_out=True places it outside the grid.
g.add_legend();

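#### V -> V

Speaker-AU × listener-AU peak lags for the `'true'` condition vs. all other conditions, followed by a per-pair two-sample t-test.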

In [None]:
# Pipeline for the V -> V analysis ('vv' modality); all other settings match
# the A -> V pipeline.
# NOTE(review): the trf_* / regularization arguments are presumably TRF model
# settings; their units and semantics live in pipeline.py — confirm there.
vv_pipeline = Pipeline(
    trf_direction=1,
    trf_min_lag=0,
    trf_max_lag=3,
    regularization=1,
    modality='vv',
    audio_type='auditory_nerve',
    similarity_measure='r2',
)

In [None]:
# Build the V -> V frames, then split the cross-correlation table into rows
# whose condition is 'true' and rows whose condition is anything else.
df_vv = vv_pipeline.make_main_df()
df_crossCorr_vv = vv_pipeline.make_crosscorr_df(df_vv)
_vv_condition = df_crossCorr_vv['condition']
df_cc_true, df_cc_fake = (
    df_crossCorr_vv[_vv_condition == 'true'],
    df_crossCorr_vv[_vv_condition != 'true'],
)

In [None]:
# Mean peak lag (pivot_table's default aggfunc) for each (speaker AU, listener AU)
# pair, restricted to the 'true' condition.
# NOTE(review): the colour limits are symmetric around 0, so the lags are signed;
# a diverging colormap centred on 0 (e.g. 'vlag') would show the sign more
# clearly than the sequential 'Reds'. Kept to stay comparable with the other heatmap.
fig_true, ax_true = plt.subplots()
sns.heatmap(
    df_cc_true.pivot_table(index='speaker_au', columns='listener_au', values='peak_lag'),
    cmap='Reds', vmin=-0.5, vmax=0.5, square=True, ax=ax_true,
)
ax_true.set(title="Peak lag: condition == 'true'", xlabel='Listener AU', ylabel='Speaker AU')
ax_true.tick_params(axis='y', labelrotation=0)
ax_true.tick_params(axis='x', labelrotation=90)

In [None]:
# Mean peak lag (pivot_table's default aggfunc) for each (speaker AU, listener AU)
# pair, over every condition other than 'true'.
# NOTE(review): the colour limits are symmetric around 0, so the lags are signed;
# a diverging colormap centred on 0 (e.g. 'vlag') would show the sign more
# clearly than the sequential 'Reds'. Kept to stay comparable with the other heatmap.
fig_fake, ax_fake = plt.subplots()
sns.heatmap(
    df_cc_fake.pivot_table(index='speaker_au', columns='listener_au', values='peak_lag'),
    cmap='Reds', vmin=-0.5, vmax=0.5, square=True, ax=ax_fake,
)
ax_fake.set(title="Peak lag: condition != 'true'", xlabel='Listener AU', ylabel='Speaker AU')
ax_fake.tick_params(axis='y', labelrotation=0)
ax_fake.tick_params(axis='x', labelrotation=90)

In [None]:
# Two-sample t-test of peak lag, 'true' condition vs. all other conditions,
# for every ordered (speaker AU, listener AU) pair; results shown as a heatmap
# of t with significant cells outlined.
# NOTE(review): stats.ttest_ind defaults to equal_var=True (Student's t); if the
# two groups differ in size or variance, Welch's test (equal_var=False) is safer.
# NOTE(review): len(aus)**2 tests are run at ALPHA with no multiple-comparison
# correction, so some outlined cells may be false positives.
ALPHA = 0.05  # significance threshold for the outlined cells


def _pair_lags(frame, speaker_au, listener_au):
    """Return the ``peak_lag`` values of ``frame`` for one (speaker AU, listener AU) pair."""
    mask = (frame['speaker_au'] == speaker_au) & (frame['listener_au'] == listener_au)
    return frame.loc[mask, 'peak_lag'].to_numpy()


au1s, au2s, t_statistics, p_vals = [], [], [], []
for speaker_au, listener_au in itertools.product(vv_pipeline.aus, repeat=2):
    # A pair missing from either frame yields an empty sample; scipy then
    # returns NaN t/p, which never passes the p < ALPHA filter below.
    res = stats.ttest_ind(_pair_lags(df_cc_true, speaker_au, listener_au),
                          _pair_lags(df_cc_fake, speaker_au, listener_au))
    au1s.append(speaker_au)
    au2s.append(listener_au)
    t_statistics.append(res.statistic)
    p_vals.append(res.pvalue)

df_ttest = pd.DataFrame({
    'speaker_au': au1s,
    'listener_au': au2s,
    't': t_statistics,
    'p': p_vals,
})

heatmap_data = df_ttest.pivot_table(index='speaker_au', columns='listener_au', values='t')
t_heatmap, ax = plt.subplots()
# NOTE(review): vmin/vmax = ±1 saturate the colour scale for any |t| > 1.
sns.heatmap(heatmap_data, cmap='crest', vmin=-1, vmax=1, square=True, ax=ax)
ax.set(title=f"t statistic, 'true' vs. other conditions (dashed: p < {ALPHA})",
       xlabel='Listener AU', ylabel='Speaker AU')
ax.tick_params(axis='y', labelrotation=0)
ax.tick_params(axis='x', labelrotation=90)
# Outline significant cells: heatmap cell (row i, column j) is the unit square
# whose corner sits at (j, i) in data coordinates.
for pair in df_ttest[df_ttest['p'] < ALPHA].itertuples(index=False):
    row_pos = heatmap_data.index.get_loc(pair.speaker_au)
    col_pos = heatmap_data.columns.get_loc(pair.listener_au)
    ax.add_patch(Rectangle((col_pos, row_pos), 1, 1, ec='r', fc='none', lw=1, linestyle='--'))