In [1]:
import itertools
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
from altair import *
from altair import selection_single
from pipeline import Pipeline
from io import *
import base64
from PIL import Image
from sklearn.preprocessing import StandardScaler

# Global plot styling: seaborn's "white" theme with the Set2 palette, plus
# visible tick marks on the bottom x-axis and the left y-axis.
sns.set_theme(style='white', palette='Set2')
plt.rcParams.update({'xtick.bottom': True, 'ytick.left': True})

In [2]:
# Build the analysis pipeline with explicit keyword settings.
# NOTE(review): the meaning and units of these settings live in the project's
# `pipeline` module and are not visible here — confirm before changing them.
pipeline = Pipeline(
    trf_direction=1,  # presumably 1 = forward (encoding) model; TODO confirm
    trf_min_lag=0,
    trf_max_lag=3,
    regularization=0.1,
    modality='vv',
    audio_type='auditory_nerve',
)

In [3]:
# Build the main DataFrame from the pipeline. The next cell splits it on its
# 'Condition' column.
df = pipeline.make_main_df()

In [4]:
# Response table produced by the pipeline (not referenced again in the cells shown here).
df_responses = pipeline.make_response_df()

# Partition the main frame by condition, re-indexing each part from 0:
#   df1 -> rows whose 'Condition' equals 'TRUE'
#   df2 -> all remaining rows
df1 = df.loc[df['Condition'] == 'TRUE'].reset_index(drop=True)
df2 = df.loc[df['Condition'] != 'TRUE'].reset_index(drop=True)

In [5]:
# Fit the models on the 'TRUE'-condition rows only (df1). `trfs` is indexed by
# position further down, one entry per (AU, AU) pair.
trfs = pipeline.train_model(df1)

Cross-validating[##################################################] 100/100..] 0/5

Cross-validating[##################################################] 100/100..] 1/5

Cross-validating[##################################################] 100/100..] 2/5

Cross-validating[##################################################] 100/100..] 3/5

Cross-validating[##################################################] 100/100..] 4/5

Hyperparameter optimization[##################################################] 5/5

Cross-validating[##################################################] 100/100..] 0/5

Cross-validating[##################################################] 100/100..] 1/5

Cross-validating[##################################################] 100/100..] 2/5

Cross-validating[##################################################] 100/100..] 3/5

Cross-validating[##################################################] 100/100..] 4/5

Hyperparameter optimization[#####################################

In [6]:
def plot(
    direction,
    trf,
    channel=None,
    feature=None,
    axes=None,
    show=True,
    kind="line",
):
    """
    Plot the weights of a TRF model across time lags for a selected channel and/or feature.

    The weights are indexed here as (feature, time lag, channel): features are
    selected on the first axis and channels on the last one.

    Arguments:
        direction (int): Direction of the model. If -1 (backward model / decoder), the weight array is transposed before selection so that it is indexed like a forward model, and a warning is printed.
        trf: Fitted model object exposing a `weights` array and a `times` array of time lags.
        channel (None | int | str): Channel selection. If None, all channels will be used. If an integer, the channel at that index will be used. If 'avg' or 'gfp', the average or standard deviation across channels will be computed. Forced to 0 when there is only one channel.
        feature (None | int | str): Feature selection. If None, all features will be used. If an integer, the feature at that index will be used. If 'avg', the average across features will be computed. Forced to 0 when there is only one feature.
        axes (matplotlib.axes.Axes): Axis to plot to. If None is provided (default) generate a new plot.
        show (bool): If True (default), show the plot after drawing.
        kind (str): Type of plot to draw. If 'line' (default), draw the selected weights, z-scored jointly with `StandardScaler`, as line(s) over the time lags. If 'image', draw a color-coded plot of the raw selected weights with time lag on the x-axis.

    Returns:
        fig (matplotlib.figure.Figure): If no axes was provided and a new figure is created, it is returned. Otherwise None.

    Raises:
        ModuleNotFoundError: If matplotlib is not available.
        ValueError: If neither `channel` nor `feature` is given, if either has an unsupported value, or if `kind` is not 'line' or 'image'.
    """
    if plt is None:
        raise ModuleNotFoundError("Need matplotlib to plot TRF!")
    if kind not in ("line", "image"):
        raise ValueError('Argument `kind` must be "line" or "image"!')

    weights = trf.weights
    if direction == -1:
        # Reverse the axes so features are first and channels last, as the
        # selection logic below expects. (Previously this transpose was
        # overwritten by a later `weights = trf.weights` and had no effect.)
        weights = weights.T
        print(
            "WARNING: decoder weights are hard to interpret, consider using the `to_forward()` method"
        )

    # select channel and or feature
    if weights.shape[0] == 1:
        feature = 0
    if weights.shape[-1] == 1:
        channel = 0
    if channel is None and feature is None:
        raise ValueError("You must specify a subset of channels or features!")
    if feature is not None:
        image_ylabel = "channel"
        if isinstance(feature, int):
            weights = weights[feature, :, :]
        elif feature == "avg":
            weights = weights.mean(axis=0)
        else:
            raise ValueError('Argument `feature` must be an integer or "avg"!')
    if channel is not None:
        image_ylabel = "feature"
        if isinstance(channel, int):
            weights = weights.T[channel].T
        elif channel == "avg":
            weights = weights.mean(axis=-1)
        elif channel == "gfp":
            weights = weights.std(axis=-1)
        else:
            raise ValueError(
                'Argument `channel` must be an integer, "avg" or "gfp"'
            )
        weights = weights.T  # transpose so first dimension is time

    # Create the figure only after the arguments are validated, so a bad call
    # does not leave an empty figure behind.
    if axes is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    else:
        fig, ax = None, axes  # dont create a new figure

    # plot the result
    if kind == "line":
        # z-score all selected weights together (one mean / one std), then
        # restore the original shape so that 2-D weights still line up with
        # `trf.times` (one line per column).
        zscored = (
            StandardScaler()
            .fit_transform(weights.reshape(-1, 1))
            .reshape(weights.shape)
        )
        ax.plot(
            trf.times.flatten(),
            zscored,
            linewidth=2 - 0.01 * weights.shape[-1],
        )
        ax.set(
            xlabel="Time lag[s]",
            ylabel="Amplitude [a.u.]",
            xlim=(trf.times.min(), trf.times.max()),
        )
    else:  # kind == "image"
        # Rescale the x-extent from lag index to seconds. The extent is
        # computed directly, so no numpy import is required.
        scale = trf.times.max() / len(trf.times)
        ax.imshow(
            weights.T,
            origin="lower",
            aspect="auto",
            extent=[0, weights.shape[0] * scale, 0, weights.shape[1]],
        )
        ax.set(
            xlabel="Time lag [s]",
            ylabel=image_ylabel,
            xlim=(trf.times.min(), trf.times.max()),
        )
    if show is True:
        plt.show()
    return fig

In [7]:
# Render one TRF figure per ordered (speaker AU, listener AU) pair, save it as
# a PNG and keep a base64 data-URI copy in `img_list` for the Altair tooltips.
# NOTE(review): assumes `trfs` is ordered like
# itertools.product(pipeline.aus, repeat=2) — confirm against Pipeline.train_model.
# NOTE: the ./trf_imgs directory must already exist; savefig does not create it.
img_list = []
for i, (speaker_au, listener_au) in enumerate(itertools.product(pipeline.aus, repeat=2)):
    # `plot` returns the new figure because no axes are passed in.
    fig = plot(direction=1, trf=trfs[i], channel='avg', feature='avg', show=False)
    ax = fig.axes[0]
    ax.set_title(f'TRF for {speaker_au} -> {listener_au}')
    trf_line = ax.get_lines()[0]  # the TRF curve is the first line drawn
    trf_line.set_color("r")
    trf_line.set_linewidth(2)
    ax.axhline(y=0, color='k', linestyle='--')

    img_path = f'./trf_imgs/{speaker_au}_{listener_au}.png'
    fig.savefig(img_path)
    plt.close(fig)  # close this specific figure so they do not pile up in memory

    # Re-encode the saved PNG into memory; the context manager releases the
    # file handle, which was previously left open on every iteration.
    output = BytesIO()
    with Image.open(img_path) as image:
        image.save(output, format='png')
    encoded_string = "data:image/png;base64," + base64.b64encode(output.getvalue()).decode()
    img_list.append(encoded_string)

In [None]:
# Apply the models trained on the 'TRUE' rows to those same rows (df1) and
# turn the result into a table. Later cells read its 'speaker_au',
# 'listener_au' and 'r' columns.
true_data = pipeline.predict_response(df1, trfs)
df_trueCorrs = pipeline.make_trf_df(true_data)

# Same models applied to the non-'TRUE' rows (df2), for comparison.
fake_data = pipeline.predict_response(df2, trfs)
df_fakeCorrs = pipeline.make_trf_df(fake_data)

In [None]:
# Stack the true and fake correlation tables row-wise. Original indices are
# kept, so index labels may repeat.
df_correlations = pd.concat([df_trueCorrs, df_fakeCorrs])

In [None]:
# Heatmap of the mean correlation `r` per (speaker AU, listener AU) pair for
# the true condition, with the matching TRF image shown in the tooltip.
#
# The images are joined on the AU pair instead of being assigned by position.
# `img_list` follows the order of itertools.product(pipeline.aus, repeat=2),
# while groupby sorts its keys, so positional assignment is only correct when
# both orders happen to coincide.
trf_images = pd.DataFrame(
    [
        (speaker_au, listener_au, encoded_image)
        for (speaker_au, listener_au), encoded_image in zip(
            itertools.product(pipeline.aus, repeat=2), img_list
        )
    ],
    columns=['speaker_au', 'listener_au', 'image'],
)
df_altair_true = (
    df_trueCorrs
    .groupby(['speaker_au', 'listener_au'], as_index=False)['r']
    .mean()
    .merge(trf_images, on=['speaker_au', 'listener_au'], how='left')
)
Chart(df_altair_true).mark_rect().encode(
    x=X('listener_au').title('Listener AU'),
    y=Y('speaker_au').title('Speaker AU'),
    color=Color('r', scale=Scale(domain=[-0.2, 0.2], domainMid=0, scheme='redblue')),
    tooltip=['image']).properties(height=600, width=600).interactive()

In [None]:
# Heatmap of the mean correlation `r` per (speaker AU, listener AU) pair for
# the fake condition, with the matching TRF image shown in the tooltip.
#
# The images are joined on the AU pair instead of being assigned by position.
# `img_list` follows the order of itertools.product(pipeline.aus, repeat=2),
# while groupby sorts its keys, so positional assignment is only correct when
# both orders happen to coincide.
trf_images = pd.DataFrame(
    [
        (speaker_au, listener_au, encoded_image)
        for (speaker_au, listener_au), encoded_image in zip(
            itertools.product(pipeline.aus, repeat=2), img_list
        )
    ],
    columns=['speaker_au', 'listener_au', 'image'],
)
df_altair_fake = (
    df_fakeCorrs
    .groupby(['speaker_au', 'listener_au'], as_index=False)['r']
    .mean()
    .merge(trf_images, on=['speaker_au', 'listener_au'], how='left')
)
Chart(df_altair_fake).mark_rect().encode(
    x=X('listener_au').title('Listener AU'),
    y=Y('speaker_au').title('Speaker AU'),
    color=Color('r', scale=Scale(domain=[-0.2, 0.2], domainMid=0, scheme='redblue')),
    tooltip=['image']).properties(height=600, width=600).interactive()

In [None]:
# For every ordered (speaker AU, listener AU) pair, compare the true and fake
# correlations `r` with an independent two-sample t-test.
# NOTE(review): the p-values are not corrected for multiple comparisons —
# confirm this is intended.
ttest_records = []
for speaker_au, listener_au in itertools.product(pipeline.aus, repeat=2):
    true_corrs = df_trueCorrs.loc[
        (df_trueCorrs['speaker_au'] == speaker_au) & (df_trueCorrs['listener_au'] == listener_au),
        'r',
    ].to_numpy()
    fake_corrs = df_fakeCorrs.loc[
        (df_fakeCorrs['speaker_au'] == speaker_au) & (df_fakeCorrs['listener_au'] == listener_au),
        'r',
    ].to_numpy()
    res = stats.ttest_ind(true_corrs, fake_corrs)
    ttest_records.append({
        'speaker_au': speaker_au,
        'listener_au': listener_au,
        't': res.statistic,
        'p': res.pvalue,
    })

# Explicit columns keep the frame's schema stable even if there are no AU pairs.
df_ttest = pd.DataFrame(ttest_records, columns=['speaker_au', 'listener_au', 't', 'p'])

# Heatmap of the t statistic, with the p-value in the tooltip.
t_heatmap = Chart(df_ttest).mark_rect().encode(
    x=X('listener_au').title('Listener AU'),
    y=Y('speaker_au').title('Speaker AU'),
    color=Color('t', scale=Scale(domainMid=0, scheme='yellowgreenblue')),
    tooltip='p').properties(height=600, width=600).interactive()
# Text layer derived from the heatmap: prints p to two decimals and is only
# visible (opacity 1) where p < 0.05.
text = t_heatmap.mark_text().encode(
    Text('p', format=".2f"),
    opacity=condition(
        datum.p < 0.05,
        value(1),
        value(0)
    ),
    color=value('black'))
t_heatmap+text