In [None]:
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt, ellip, detrend
import more_itertools
from scipy.stats import zscore
import glob
from autogluon.tabular import TabularDataset, TabularPredictor
from sklearn.metrics import accuracy_score
import os

In [None]:
def notch_filter(data, notch, width, fs):
    """Zero-phase Butterworth band-stop filter spanning ``notch ± width`` Hz.

    Parameters
    ----------
    data : array_like
        Signal to filter (``filtfilt`` works along the last axis by default).
    notch : float
        Centre frequency of the stop band, in Hz.
    width : float
        Half-width of the stop band, in Hz.
    fs : float
        Sampling rate, in Hz.

    Returns
    -------
    numpy.ndarray
        The signal run forward and backward through an order-4 Butterworth
        band-stop design.
    """
    nyquist = fs / 2
    stop_band = [(notch - width) / nyquist, (notch + width) / nyquist]
    numerator, denominator = butter(4, stop_band, btype="stop")
    return filtfilt(numerator, denominator, data)


def elliptic_filter(data, flow, fhigh, fs):
    """Zero-phase elliptic band-pass filter between ``flow`` and ``fhigh`` Hz.

    Order-4 elliptic design with 0.1 dB pass-band ripple and 40 dB minimum
    stop-band attenuation, applied forward and backward with ``filtfilt``.

    Parameters
    ----------
    data : array_like
        Signal to filter.
    flow, fhigh : float
        Lower and upper pass-band edges, in Hz.
    fs : float
        Sampling rate, in Hz.

    Returns
    -------
    numpy.ndarray
        The band-passed signal.
    """
    edges = np.array([flow, fhigh]) * (2 / fs)  # normalise so Nyquist == 1
    # scipy accepts "pass" as an alias for "bandpass"
    numerator, denominator = ellip(4, 0.1, 40, edges, btype="pass")
    return filtfilt(numerator, denominator, data)


def local_filter(data, notch, width, flow, fhigh, fs):
    """Notch out each listed frequency, band-pass, then remove a linear trend.

    Parameters
    ----------
    data : array_like
        Signal sampled at ``fs``.
    notch : float or array_like
        Centre frequency (Hz) of each band-stop notch, applied in order.
    width : float or array_like
        Half-width (Hz) of each notch. A scalar is applied to every notch;
        otherwise it must broadcast against ``notch``.
    flow, fhigh : float
        Band-pass edges, in Hz.
    fs : float
        Sampling rate, in Hz.

    Returns
    -------
    numpy.ndarray
        The filtered, linearly detrended signal.

    Raises
    ------
    ValueError
        If ``notch`` and ``width`` cannot be broadcast together.
    """
    # Broadcast instead of a bare zip. zip silently truncates to the shorter
    # input, so a scalar width would apply only the first notch.
    notches, widths = np.broadcast_arrays(np.atleast_1d(notch), np.atleast_1d(width))
    signal = data
    for n, w in zip(notches, widths):
        signal = notch_filter(signal, n, w, fs)
    signal = elliptic_filter(signal, flow, fhigh, fs)
    return detrend(signal)

In [None]:
# Worked example of the two-stage EMG onset detector on a single trial: the
# first calibration-stage trial ('S1_1') of experimental session EID=3,
# player `pid`, channel `cid`.
EID = 3
pid = 1
cid = 1
annotation = 'S1_1'

fs = 512                                   # sampling rate (Hz)
f_notch = np.array([60, 120, 180, 240])    # notch centre frequencies (Hz)
width = np.ones(len(f_notch)) * 3          # half-width of each notch (Hz)
flow, fhigh = 5, 250                       # band-pass edges (Hz)
window_size = 51                           # RMS window length (samples)
step_size = 1                              # RMS window hop (samples)


def _windowed_rms(x, size, step):
    """Return the RMS of every length-``size`` window of 1-D ``x``, hopping ``step`` samples.

    Window k starts at sample ``k * step``; a trailing partial window is dropped.
    """
    windows = np.lib.stride_tricks.sliding_window_view(np.asarray(x, dtype=float), size)[::step]
    return np.sqrt(np.square(windows).mean(axis=1))


data = pd.read_csv(f"./data/E{EID}_data.csv", index_col=0)

# Build the channel column from pid/cid instead of hard-coding "P1_CH1".
channel = f"P{pid}_CH{cid}"
raw = data[channel] - data[channel].iloc[0]  # shift so the recording starts at 0
EMG_signal = pd.DataFrame(
    zscore(local_filter(raw, f_notch, width, flow, fhigh, fs)),
    index=data.index,  # keep the CSV's index so rows line up with `data`
)
emg = EMG_signal[0].to_numpy()

# Work in positions rather than index labels. Timestamps are read with .iloc,
# so positional indexing stays correct even if the CSV index is not 0..N-1.
trial_pos = np.flatnonzero((data.annotation == annotation).to_numpy())
indices = data.index[trial_pos].tolist()  # index labels of the trial's samples
index_1s_before = trial_pos[0] - fs       # position one second before the trial starts

# RMS envelope of the trial and of the one-second pre-trial baseline.
# The clamp to 0 prevents a negative start from wrapping around.
rms_data_play = _windowed_rms(emg[trial_pos], window_size, step_size)
rms_data_1s_before = _windowed_rms(emg[max(index_1s_before, 0):trial_pos[0]], window_size, step_size)

peak_idx = np.argmax(rms_data_play)
peak = rms_data_play[peak_idx]

# Rough-onset threshold: 25% of the trial's peak RMS.
threshold1 = rms_data_play.max() * 0.25
# Precise-onset threshold: baseline mean + 2 SD of the pre-trial RMS.
baseline2 = rms_data_1s_before.mean()
threshold2 = baseline2 + 2 * rms_data_1s_before.std()

# Rough onset: first window whose RMS exceeds threshold1.
onset_index = np.where(rms_data_play > threshold1)[0][0]
rough_onset = data["timestamp"].iloc[trial_pos[onset_index]]

# Precise onset: walk back from the rough onset to the last window still below
# threshold2. The window right after it starts the final supra-threshold run
# (window 0 if no earlier window is below threshold2).
candidates = np.where(rms_data_play[:onset_index] < threshold2)[0]
closest_index = candidates[-1] + 1 if len(candidates) > 0 else 0

precise_onset = data["timestamp"].iloc[trial_pos[closest_index]]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.patches import Patch
import matplotlib.ticker as ticker

import matplotlib

# House style for publication figures:
# - dotted, fairly opaque grid
# - 7-pt text and 300-dpi rendering
# - upright (non-italic) mathtext
# - TrueType (Type 42) fonts embedded in PDF/PS output
# plt.rcParams and matplotlib.rcParams are the same object.
matplotlib.rcParams.update({
    'axes.grid': True,
    'grid.alpha': 0.8,
    'grid.linestyle': 'dotted',
    'figure.dpi': 300,
    'font.size': 7,
    'axes.labelsize': 7,
    'axes.titlesize': 7,
    'mathtext.default': 'regular',
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

plt.style.use("seaborn-v0_8-deep")
sns.set_palette("deep")

In [None]:
red = sns.color_palette("Reds", 4)[-1]
green = sns.color_palette("Greens", 4)[-1]
blue = sns.color_palette("Blues", 4)[-1]

# RMS envelope of the example trial with its 1-s baseline prepended. Points
# from the play window are shifted right by the baseline length.
n_before = len(rms_data_1s_before)
rms = np.concatenate([rms_data_1s_before, rms_data_play])


def _curved_note(text, xy, xytext):
    """Annotate `xy` with a 7-pt label at `xytext`, joined by a curved arrow."""
    plt.annotate(text, xy=xy, xytext=xytext,
                 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.2'), fontsize=7)


plt.figure(figsize=(3.54, 1.8))
plt.plot(rms, linewidth=0.5, color="k", zorder=0)

# Boundary between the baseline segment and the play window.
plt.vlines(n_before, ymin=0, ymax=np.max(rms), colors=red, linewidth=1, linestyle="--")
_curved_note("1s before", (n_before, 0.06), (-100, 0.08))

# Precise-onset threshold over the baseline segment, plus a double-headed span arrow.
plt.hlines(y=threshold2, xmin=0, xmax=n_before, colors=green, linewidth=1)
_curved_note("Precise onset baseline", (300, threshold2), (-100, 0.04))
plt.annotate("", xy=(-40, threshold2), xytext=(n_before + 40, threshold2),
             arrowprops=dict(arrowstyle='<->', color=green, linewidth=1), fontsize=7)

onset_xy = (onset_index + n_before, rms_data_play[onset_index])
precise_xy = (closest_index + n_before, rms_data_play[closest_index])
peak_xy = (peak_idx + n_before, peak)

_curved_note("Rough onset", onset_xy, (1000, 0.025))
_curved_note("Precise onset", precise_xy, (950, 0.005))
plt.hlines(y=threshold2, xmin=0, xmax=1000, colors=green, linewidth=1, linestyle="--", zorder=1)

for (x_pt, y_pt), dot_color in ((onset_xy, blue), (precise_xy, green), (peak_xy, red)):
    plt.scatter(x=x_pt, y=y_pt, color=dot_color)

_curved_note("Peak", peak_xy, (1200, 0.08))
plt.ylabel("RMS")
plt.xlabel("Time step")
plt.show()

In [None]:
# Onset table: one column per session/player (and channel) onset series.
df = pd.read_csv(filepath_or_buffer='./data/onset.csv')

In [None]:
# Latency of the EMG "precise" onset (CH1/CH2) relative to the visual onset,
# pooled over all session/player pairs. Values are presumably in ms (the
# violin plot labels this axis in ms).
# E2_P1 is excluded, consistent with the stage-split, peak and model cells.
# Entries with |diff| > 1e4 (and NaN) are discarded as outliers.
emg1_time_diff_all = []
emg2_time_diff_all = []
for i in range(1, 13):
    for pid in [1, 2]:
        if i == 2 and pid == 1:
            continue
        visual_onset = df[f'E{i}_P{pid}']
        for ch, out in ((1, emg1_time_diff_all), (2, emg2_time_diff_all)):
            diff = (df[f'E{i}_P{pid}_CH{ch}_precise'] - visual_onset).to_numpy()
            out.append(diff[np.abs(diff) <= 1e4])

emg1_time_diff_all = np.concatenate(emg1_time_diff_all)
emg2_time_diff_all = np.concatenate(emg2_time_diff_all)

In [None]:
# Same EMG-minus-visual latency, split by stage. The first 15 rows of each
# column are the calibration trials and the remaining rows are free play.
# E2_P1 is excluded; |diff| > 1e4 (and NaN) entries are discarded.
emg1_time_diff_stage1, emg2_time_diff_stage1 = [], []
emg1_time_diff_stage2, emg2_time_diff_stage2 = [], []

for i in range(1, 13):
    for pid in (1, 2):
        if (i, pid) == (2, 1):
            continue
        visual = df[f'E{i}_P{pid}']
        ch1 = df[f'E{i}_P{pid}_CH1_precise'] - visual
        ch2 = df[f'E{i}_P{pid}_CH2_precise'] - visual
        for rows, out1, out2 in (
            (slice(None, 15), emg1_time_diff_stage1, emg2_time_diff_stage1),
            (slice(15, None), emg1_time_diff_stage2, emg2_time_diff_stage2),
        ):
            d1, d2 = ch1[rows], ch2[rows]
            out1.append(d1[np.abs(d1) <= 1e4])
            out2.append(d2[np.abs(d2) <= 1e4])

emg1_time_diff_stage1 = np.concatenate(emg1_time_diff_stage1)
emg2_time_diff_stage1 = np.concatenate(emg2_time_diff_stage1)
emg1_time_diff_stage2 = np.concatenate(emg1_time_diff_stage2)
emg2_time_diff_stage2 = np.concatenate(emg2_time_diff_stage2)

In [None]:
# Latency of the visible EMG peak (CH1/CH2 "_peak" columns) relative to the
# visual onset, pooled over all session/player pairs except E2_P1.
# Entries with |diff| > 1e4 (and NaN) are discarded as outliers.
visible_peak_1_time_diff_all = []
visible_peak_2_time_diff_all = []

session_players = [(i, pid) for i in range(1, 13) for pid in [1, 2] if (i, pid) != (2, 1)]
for i, pid in session_players:
    visual_onset = df[f'E{i}_P{pid}']
    for ch, bucket in ((1, visible_peak_1_time_diff_all), (2, visible_peak_2_time_diff_all)):
        diff = np.array((df[f'E{i}_P{pid}_CH{ch}_peak'] - visual_onset).tolist())
        bucket.append(diff[np.abs(diff) <= 1e4])

visible_peak_1_time_diff_all = np.concatenate(visible_peak_1_time_diff_all)
visible_peak_2_time_diff_all = np.concatenate(visible_peak_2_time_diff_all)

In [None]:
colors = sns.color_palette("deep")[:2]
plt.figure(figsize=(4, 2.6))

# Split violins: channel 1 on the upper half ('high'), channel 2 on the lower
# half ('low'), nudged a few thousandths apart.
# Rows from bottom to top: free play, calibration, overall, overall (peak).
violin = plt.violinplot(
    [visible_peak_1_time_diff_all, emg1_time_diff_all, emg1_time_diff_stage1, emg1_time_diff_stage2],
    positions=[4.005, 3.005, 2.005, 1.005], vert=False, showmeans=True, side='high')
violin2 = plt.violinplot(
    [visible_peak_2_time_diff_all, emg2_time_diff_all, emg2_time_diff_stage1, emg2_time_diff_stage2],
    positions=[3.985, 2.985, 1.985, 0.985], vert=False, showmeans=True, side='low')

# One fill colour per channel; thin all the bar/extrema/mean line collections.
for parts, face in ((violin, colors[0]), (violin2, colors[1])):
    for body in parts['bodies']:
        body.set_facecolor(face)
        body.set_alpha(0.7)
    for key in ('cbars', 'cmaxes', 'cmins', 'cmeans'):
        parts[key].set_linewidth(1)

plt.yticks([1, 2, 3, 4], ["Free\nplay", "Calibration", "Overall", "Overall\n(peak)"], rotation=90, va='center')
plt.xlabel('Visual onset time [ms]')
plt.xticks(rotation=0)

plt.legend([violin['bodies'][0], violin2['bodies'][0]], ['EMG-ch1', 'EMG-ch2'])

# Mean of each distribution, printed next to its violin half.
distributions = [
    emg1_time_diff_all, emg2_time_diff_all,
    emg1_time_diff_stage1, emg2_time_diff_stage1,
    emg1_time_diff_stage2, emg2_time_diff_stage2,
    visible_peak_1_time_diff_all, visible_peak_2_time_diff_all,
]
means = [np.mean(d) for d in distributions]

# Label y-positions, and x-shifts that move the channel-1 labels left of the mean.
positions = [3.35, 2.75, 2.35, 1.75, 1.35, 0.7, 4.25, 3.75]
label_shift = [-250, 0, -250, 0, -250, 0, -50, 0]
xs = [m + s for m, s in zip(means, label_shift)]
for pos, mean, x in zip(positions, means, xs):
    plt.text(x, pos, f'{mean:.1f}', verticalalignment='center', horizontalalignment='left', fontsize=8)

plt.grid(True, zorder=0)
plt.gca().set_axisbelow(True)

plt.show()

In [None]:
# Evaluate every saved self-gesture model (trained on posed data) for each
# participant, on two test sets:
# - posed: the calibration trials named in the model directory;
# - spontaneous: all 20 free-play trials.
# Each results list holds one inner list per participant.
posed_results = []
spontaneous_results = []

label = 'class'
feature_and_class_cols = ["rms_ch1", "rms_ch2", "class"]

for EID in range(1, 13):
    for pid in [1, 2]:
        if EID == 2 and pid == 1:
            continue
        posed_data = pd.read_parquet(f"./self_gesture_recognition_posed/features/E{EID}_P{pid}.parquet")
        spontaneous_data = pd.read_parquet(f"./self_gesture_recognition_spontaneous/features/E{EID}_P{pid}_stage2.parquet")
        # glob() order is filesystem-dependent; sort so the result order is reproducible.
        existing_model = sorted(glob.glob(f"./self_gesture_recognition_posed/models/E{EID}_P{pid}_*"))

        posed_results.append([])
        spontaneous_results.append([])
        for path in existing_model:
            predictor = TabularPredictor.load(path)

            # Posed: model names look like E{EID}_P{pid}_{t1}_{t2}_{t3}...
            # Evaluate on calibration trials S1_t1..S1_t3 (presumably held out
            # from training -- TODO confirm). basename() also handles Windows separators.
            name_parts = os.path.basename(path).split('_')
            existing_comb = [int(name_parts[k]) for k in (2, 3, 4)]
            for trial in existing_comb:
                test_label = f"S1_{trial}"
                test_data = TabularDataset(posed_data[posed_data.annotation == test_label][feature_and_class_cols])
                pred = predictor.predict(test_data.drop(columns=[label]))
                posed_results[-1].append(accuracy_score(test_data[label], pred))

            # Spontaneous: all free-play trials. The trailing "_" keeps "S2_1_"
            # from also matching "S2_10_" ... "S2_19_".
            for trial in range(1, 21):
                test_label = f"S2_{trial}_"
                test_data = TabularDataset(spontaneous_data[spontaneous_data.annotation.str.startswith(test_label)][feature_and_class_cols])
                pred = predictor.predict(test_data.drop(columns=[label]))
                spontaneous_results[-1].append(accuracy_score(test_data[label], pred))

In [None]:
# Per-participant accuracy of the self-gesture models on posed vs. spontaneous
# trials. Each trial is a faint dot; each participant also gets a mean with a
# normal-approximation 95% CI.
n_participants = len(posed_results)  # 23 = 12 sessions x 2 players, minus E2_P1
trials_num = [15, 100]  # expected results per participant (posed, spontaneous)
conditions = ["Posed", "Spontaneous"]

# Distinct name so the earlier `data` DataFrame is not shadowed.
accuracy_by_condition = {
    "Posed": posed_results,
    "Spontaneous": spontaneous_results,
}

# Plot
plt.figure(figsize=(7.08, 1.8))
colors = sns.color_palette("deep")[:3]
x_base = np.arange(1, n_participants + 1)
offsets = [-0.25, 0.25]

for i, condition in enumerate(conditions):
    acc = accuracy_by_condition[condition]
    x_pos = x_base + offsets[i]
    for j in range(n_participants):
        acc_j = np.asarray(acc[j], dtype=float)
        # Scatter each trial
        plt.scatter([x_pos[j]] * trials_num[i], acc_j, color=colors[i], alpha=0.3, s=5)
        # Mean +/- CI. The sample size is the number of non-missing (non-NaN)
        # trials, matching nanmean/nanstd. This replaces per-participant
        # hard-coded corrections for participants 7 and 8.
        mean = np.nanmean(acc_j)
        n_valid = np.count_nonzero(~np.isnan(acc_j))
        ci95 = 1.96 * np.nanstd(acc_j) / np.sqrt(n_valid)

        plt.errorbar(x_pos[j], mean, yerr=ci95, fmt='o', color='black', capsize=2, elinewidth=1, markersize=3)

# Plot styling
plt.axhline(np.nanmean(np.concatenate(posed_results)), color=colors[0], linestyle='--', linewidth=1, label="mean accuracy (P)")
plt.axhline(np.nanmean(np.concatenate(spontaneous_results)), color=colors[1], linestyle='--', linewidth=1, label="mean accuracy (S)")

plt.xticks(x_base, [str(i) for i in x_base])
plt.xlim(0.5, n_participants + 0.5)
plt.ylim(0, 1)
plt.xlabel(f"Participant Index (n={n_participants})")
plt.ylabel("Accuracy")
plt.grid(axis='y', linestyle='--', alpha=0.5)

# Legend swatches use the same palette as the scatter points; the hard-coded
# "skyblue"/"orange" did not match them.
legend_patches = [
    mpatches.Patch(color=colors[0], label="Posed"),
    mpatches.Patch(color=colors[1], label="Spontaneous"),
]

mean_line_P = mlines.Line2D([], [], color=colors[0], linestyle='--', linewidth=1, label="Mean Accuracy (P)")
mean_line_S = mlines.Line2D([], [], color=colors[1], linestyle='--', linewidth=1, label="Mean Accuracy (S)")

plt.legend(handles=legend_patches + [mean_line_P, mean_line_S], loc='lower center', bbox_to_anchor=(0.5, 1.02),
           ncol=4, frameon=False)
plt.tight_layout()

print(np.nanmean(np.concatenate(posed_results)), np.nanmean(np.concatenate(spontaneous_results)))

plt.show()

In [None]:
# Self-gesture recognition on free-play trials. Compare the saved predictions
# with the ground-truth class sample by sample, pool over participants and
# trials, then average across trials at each time step.
correctness = []
for group_id in range(1, 13):
    for pid in [1, 2]:
        if group_id == 2 and pid == 1:
            continue
        gt = pd.read_parquet(f'self_gesture_recognition_spontaneous/features/E{group_id}_P{pid}_stage2.parquet')
        for trial in range(1, 21):
            pred_path = f'self_gesture_recognition_spontaneous/predictions/E{group_id}_P{pid}_{trial}.parquet'
            if not os.path.exists(pred_path):
                continue
            pred = pd.read_parquet(pred_path)
            n_first = (gt["annotation"] == f"S2_{trial}_1").sum()
            if n_first < 1400:
                print(f"Sample number: {n_first}, E{group_id}_P{pid}_stage2 trial {trial} has less than 1400 samples")
                continue
            # Ground truth covers both annotated segments of the trial.
            # NOTE(review): assumes `pred` has one row per selected gt sample, in order.
            ind = gt["annotation"].isin([f"S2_{trial}_1", f"S2_{trial}_2"])
            gt_labels = np.array(gt.loc[ind, "class"].tolist())
            predicted = np.array(pred["prediction"].tolist())
            correctness.append((gt_labels == predicted).astype(int))

# Truncate every trial to the shortest one so they stack into a 2-D array.
min_len = min(len(c) for c in correctness)
correctness = np.array([c[:min_len] for c in correctness])
xticks = np.arange(0, correctness.shape[1]) * (1/512) * 1000  # sample index -> ms at 512 Hz
mean_correctness = correctness.mean(axis=0)
# Running mean of mean_correctness up to each step. This is O(n), whereas
# re-averaging every prefix is O(n^2).
cumulative_mean_accuracy = np.cumsum(mean_correctness) / np.arange(1, len(mean_correctness) + 1)

In [None]:
# Free-play onset times (rows 15 onward) relative to each trial's start, for
# the visual onset and both EMG precise onsets. Values are presumably in ms;
# the figures below label this axis "Time since beep (ms)".
# E2_P1 is excluded; |value| > 1e4 (and NaN) entries are discarded.
#
# NOTE(review): this re-binds emg1_time_diff_stage2 / emg2_time_diff_stage2,
# which the earlier stage-split cell filled with EMG-minus-visual latencies.
# Re-running the earlier violin-plot cell after this one would plot the wrong
# quantity.
df = pd.read_csv('data/onset.csv')

visual_time_diff_stage2 = []
emg1_time_diff_stage2 = []
emg2_time_diff_stage2 = []

for i in range(1, 13):
    for pid in [1, 2]:
        if i == 2 and pid == 1:
            continue
        trial_start = df[f'E{i}_trial_start'][15:]
        for column, out in (
            (f'E{i}_P{pid}', visual_time_diff_stage2),
            (f'E{i}_P{pid}_CH1_precise', emg1_time_diff_stage2),
            (f'E{i}_P{pid}_CH2_precise', emg2_time_diff_stage2),
        ):
            diff = df[column][15:] - trial_start
            out.append(diff[np.abs(diff) <= 1e4])

visual_time_diff_stage2 = np.concatenate(visual_time_diff_stage2)
emg1_time_diff_stage2 = np.concatenate(emg1_time_diff_stage2)
emg2_time_diff_stage2 = np.concatenate(emg2_time_diff_stage2)

In [None]:
colors = sns.color_palette("deep")[:4]
grey = sns.color_palette("Greys", 4)[2]

mean_correctness = correctness.mean(axis=0)

plt.figure(figsize=(3.54, 1.8))

# Left axis: per-step accuracy (faint black), overlaid with horizontal onset
# violins at fixed heights. Right axis: running mean accuracy.
line1, = plt.plot(xticks, mean_correctness, linewidth=0.6, label='Accuracy', color="k", alpha=0.4)
onset_series = [visual_time_diff_stage2, emg1_time_diff_stage2, emg2_time_diff_stage2]
violin1, violin2, violin3 = [
    plt.violinplot(values, positions=[height], vert=False, showmeans=True, widths=0.07)
    for values, height in zip(onset_series, (0.60, 0.50, 0.40))
]
plt.xlabel('Time since beep (ms)')
plt.ylabel('Accuracy')

# Colour each violin (body and all line collections) uniformly.
for color, violin in zip(colors, (violin1, violin2, violin3)):
    for body in violin['bodies']:
        body.set_facecolor(color)
    for part in ('cmeans', 'cbars', 'cmaxes', 'cmins'):
        violin[part].set_color(color)
        violin[part].set_linewidth(1)

ax = plt.gca()
ax2 = ax.twinx()
line2, = ax2.plot(xticks, cumulative_mean_accuracy, linewidth=1.1, color=colors[-1])
ax2.set_ylabel('Mean cumulative accuracy')

ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
ax2.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))

ax2.grid(False)

onset_names = ['Visual onset', 'EMG-ch1 onset', 'EMG-ch2 onset']
violin_patch1, violin_patch2, violin_patch3 = [Patch(color=c, label=name) for c, name in zip(colors, onset_names)]

plt.legend([violin_patch1, violin_patch2, violin_patch3, line1, line2], onset_names + ['Accuracy', 'Mean cumulative accuracy'], loc='lower left', ncol=2, bbox_to_anchor=(-0.02, 1.0), frameon=False)

plt.show()

In [None]:
# Opponent-gesture prediction on free-play trials. Compare the predicted
# opponent gesture with the ground truth sample by sample, pool over every
# participant and trial, then average across trials at each time step.
correctness = []
for group_id in range(1, 13):
    for pid in [1, 2]:
        if group_id == 2 and pid == 1:
            continue
        for trial in range(1, 21):
            pred_path = f'opponent_gesture_prediction/predictions/E{group_id}_P{pid}_{trial}.parquet'
            gt_path = f'opponent_gesture_prediction/features/E{group_id}_P{pid}_test_{trial}.parquet'
            # Not every trial has files. Skip missing ones explicitly rather
            # than with a bare `except:`, which would also hide real read errors.
            if not (os.path.exists(pred_path) and os.path.exists(gt_path)):
                continue
            pred = pd.read_parquet(pred_path)
            gt = pd.read_parquet(gt_path)
            ind = gt["annotation"] == f"S2_{trial}_"
            if ind.sum() < 1400:
                print(f"E{group_id}_P{pid}_test_{trial} has less than 1400 samples")
                continue
            # Append every trial directly, so each participant contributes all
            # valid trials, not just the last one.
            # NOTE(review): assumes `pred` has one row per selected gt sample, in order.
            gt_labels = np.array(gt[ind]["opponent"].tolist())
            predicted = np.array(pred["prediction"].tolist())
            correctness.append((gt_labels == predicted).astype(int))

# Truncate every trial to the shortest one so they stack into a 2-D array.
min_len = min(len(c) for c in correctness)
correctness = np.array([c[:min_len] for c in correctness])

xticks = np.arange(0, correctness.shape[1]) * (1/512) * 1000  # sample index -> ms at 512 Hz
mean_correctness = correctness.mean(axis=0)
# Running mean of mean_correctness up to each step, computed in O(n).
cumulative_mean_accuracy = np.cumsum(mean_correctness) / np.arange(1, len(mean_correctness) + 1)

In [None]:
colors = sns.color_palette("deep")[:4]
grey = sns.color_palette("Greys", 4)[2]

mean_correctness = correctness.mean(axis=0)

plt.figure(figsize=(3.54, 1.8))

# Left axis: per-step opponent-prediction accuracy (faint black, drawn
# beneath), overlaid with horizontal onset violins at fixed heights.
# Right axis: running mean accuracy.
line1, = plt.plot(xticks, mean_correctness, linewidth=.6, label='Accuracy', color="k", alpha=0.4, zorder=0)
violin1, violin2, violin3 = (
    plt.violinplot(values, positions=[height], vert=False, showmeans=True, widths=0.15)
    for values, height in (
        (visual_time_diff_stage2, 0.5),
        (emg1_time_diff_stage2, 0.3),
        (emg2_time_diff_stage2, 0.1),
    )
)
plt.xlabel('Time since beep (ms)')
plt.ylabel('Accuracy')

for idx, violin in enumerate((violin1, violin2, violin3)):
    shade = colors[idx]
    for body in violin['bodies']:
        body.set_facecolor(shade)
    for part in ('cmeans', 'cbars', 'cmaxes', 'cmins'):
        violin[part].set_color(shade)
        violin[part].set_linewidth(1)

ax = plt.gca()
ax2 = ax.twinx()
line2, = ax2.plot(xticks, cumulative_mean_accuracy, linewidth=1.1, color=colors[-1])
ax2.set_ylabel('Mean cumulative accuracy')

ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
ax2.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))

ax2.grid(False)

violin_patch1 = Patch(color=colors[0], label='Visual onset')
violin_patch2 = Patch(color=colors[1], label='EMG-ch1 onset')
violin_patch3 = Patch(color=colors[2], label='EMG-ch2 onset')
handles = [violin_patch1, violin_patch2, violin_patch3, line1, line2]
labels = ['Visual onset', 'EMG-ch1 onset', 'EMG-ch2 onset', 'Accuracy', 'Mean cumulative accuracy']

plt.legend(handles, labels, loc='lower left', ncol=2, bbox_to_anchor=(-0.02, 1.0), frameon=False)

plt.show()