In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import zscore
from scipy.stats import trim_mean
# Global figure style for every plot in this notebook: seaborn "white" style, fonts scaled by 1.25.
sns.set(style="white", font_scale=1.25)

In [None]:
# Load the ERPAC table, then report its dimensions and preview the first rows.
erpac_file = "output/csv/df_cfs_ERGCPAC_inverted.parquet"
df_erpac = pd.read_parquet(erpac_file)
print(df_erpac.shape)
df_erpac.head()

In [None]:
# Re-express time as (time - max(time) / 2). Assuming time starts at 0, this
# makes 0 the epoch midpoint, which the figure's x-axis labels as the negative
# peak of the SO (TODO confirm). Then keep only -1 s to +1 s; label-based
# slicing includes both ends.
# NOTE(review): not idempotent — re-running this cell without reloading the
# data shifts the time axis a second time.
df_erpac = df_erpac.reset_index()
half_span = df_erpac['time'].max() / 2
df_erpac['time'] = df_erpac['time'].rsub(half_span).mul(-1)
df_erpac = (
    df_erpac
    .set_index(['subj', 'time'])
    .sort_index()
    .loc[(slice(None), slice(-1, 1)), :]
)

In [None]:
# Choose which ERPAC estimate to keep ("circ" or "gc"). Its columns are named
# "<method>_<freq>"; keep them with the prefix removed, next to "avg_sw".
# The prefix is removed by slicing: `str.strip(method + "_")` would strip any of
# those *characters* from both ends of a label, not the prefix itself.
method = "circ"
prefix = f"{method}_"
method_cols = [col for col in df_erpac.columns if str(col).startswith(prefix)]
assert method_cols, f"no columns starting with {prefix!r} (already selected, or wrong method?)"
df_method = df_erpac.loc[:, method_cols].rename(columns=lambda col: str(col)[len(prefix):])
df_erpac = df_erpac.loc[:, ["avg_sw"]].join(df_method)
df_erpac.head().round(3)

In [None]:
# Remove the 5.0–9.5 Hz and 22.5–25.0 Hz frequency bins. np.arange excludes its
# stop value, so 10 Hz and 25.5 Hz are not dropped. Column labels must match
# these strings exactly (e.g. "5.0"), otherwise drop raises a KeyError.
freqs_to_drop = np.concatenate([np.arange(5, 10, 0.5), np.arange(22.5, 25.5, 0.5)])
df_erpac = df_erpac.drop(columns=freqs_to_drop.astype(str))

In [None]:
# Dimensions of the arrays built below: subjects, frequency columns (every
# column except the leading "avg_sw"), and distinct time points.
unique_subj = df_erpac.index.get_level_values(0).unique()
n_subj = len(unique_subj)
n_freqs = len(df_erpac.columns[1:])
n_times = df_erpac.index.get_level_values(1).nunique()
print(n_subj, n_freqs, n_times)

In [None]:
# Build subject x frequency x time arrays and z-score each subject separately:
# the coupling map over all of its frequency/time values at once (axis=None),
# and the slow-wave waveform over time.
data = np.empty((n_subj, n_freqs, n_times))
data_sw = np.empty((n_subj, n_times))

for i, sub in enumerate(unique_subj):
    df_sub = df_erpac.xs(sub, level=0)  # select the subject once and reuse it
    # A subject with a different number of time points would otherwise raise an
    # obscure broadcast error (or, with a single row, be silently broadcast).
    assert len(df_sub) == n_times, (
        f"subject {sub!r} has {len(df_sub)} time points, expected {n_times}"
    )
    # Cast to float64 so z-scoring runs at the precision of the output arrays.
    data[i] = zscore(df_sub.iloc[:, 1:].to_numpy(dtype=float).T, axis=None)
    data_sw[i] = zscore(df_sub.loc[:, "avg_sw"].to_numpy(dtype=float))

print(data.shape, data_sw.shape)

In [None]:
# Group average across subjects (axis 0). scipy's trim_mean cuts
# int(proportiontocut * n_subjects) subjects from EACH end, so with fewer than
# 20 subjects the 5 % trim removes nobody. The slow-wave waveform is untrimmed.
data_avg = trim_mean(data, 0.05, axis=0)
data_sw_avg = trim_mean(data_sw, 0, axis=0)
(data_avg.shape, data_sw_avg.shape)

In [None]:
# Axis coordinates for the figure: time points (sec) and frequency bins (Hz).
xvec = df_erpac.index.get_level_values(1).unique().to_numpy()
yvec = df_erpac.iloc[:, 1:].columns.to_numpy().astype(float)

# imshow(extent=...) spreads rows/columns evenly between the first and last
# values, so the axes are only labelled correctly if both grids are evenly
# spaced and ordered (e.g. no gap left by dropped frequency bins).
for grid_name, grid in (("time", xvec), ("frequency", yvec)):
    steps = np.diff(grid)
    assert steps.size == 0 or np.allclose(steps, steps[0]), (
        f"{grid_name} grid is not evenly spaced; imshow extent would mislabel the axis"
    )

In [None]:
fig, ax = plt.subplots(figsize=(6, 5), dpi=100)

# Group-average coupling map. origin='upper' with a top-to-bottom extent
# (yvec[-1] at the bottom edge, yvec[0] at the top) followed by invert_yaxis()
# places the lowest frequency at the bottom of the panel.
im = ax.imshow(
    data_avg,
    aspect='auto',
    cmap="Spectral_r",
    origin='upper',
    interpolation="gaussian",
    vmin=-0.2,
    vmax=1,
    extent=[xvec[0], xvec[-1], yvec[-1], yvec[0]],
)
ax.invert_yaxis()

ax.set_xlabel("Time from negative peak of SO (sec)")
ax.set_ylabel("Frequency (Hz)")
ax.axvline(0, ls=":", lw=1.5, color="k")  # mark t = 0

cb = fig.colorbar(im, ax=ax, shrink=0.7, pad=0.05, aspect=20)
cb.set_label("Coupling (z-score)")
cb.outline.set_visible(False)

# Overlay the average slow-wave waveform on a twin y-axis with no tick labels.
ax_sw = ax.twinx()
ax_sw.plot(xvec, data_sw_avg, color="k", lw=3)
ax_sw.set_yticks([]);