# Appendix 1
## Basic EM Visualizations Available Through pEYES

In [None]:
import warnings
import copy

import pandas as pd
import cv2
from plotly.subplots import make_subplots
import plotly.io as pio

from analysis._article_results.lund2013._helpers import *

# pio.renderers.default = "browser"

# Typography shared by every figure built in this notebook.
FONT_FAMILY = "Calibri"
FONT_COLOR = "black"

# Font specs (family / size / color) for subplot titles, axis titles and tick labels.
TITLE_FONT = {"family": FONT_FAMILY, "size": 40, "color": FONT_COLOR}
AXIS_LABEL_FONT = {"family": FONT_FAMILY, "size": 32, "color": FONT_COLOR}
AXIS_TICK_FONT = {"family": FONT_FAMILY, "size": 28, "color": FONT_COLOR}

## Load Data
### (1) Stimulus
_(We use the image's resolution to determine the figure's resolution.)_

In [None]:
STIMULUS_TO_SHOW = 'konijntjes'
PATH_TO_STIMULI = os.path.join(u.BASE_DIR, "stimuli", DATASET_NAME.capitalize(), STIMULUS_TYPE)

# `cv2.imread` does not raise when the file is missing or unreadable - it returns None.
# Fail here with the offending path instead of an AttributeError on `img.shape` below.
stimulus_path = os.path.join(PATH_TO_STIMULI, f"{STIMULUS_TO_SHOW}.png")
img = cv2.imread(stimulus_path)
if img is None:
    raise FileNotFoundError(f"Could not read stimulus image: {stimulus_path}")

# image arrays are indexed (row, column, ...), so (width, height) is (shape[1], shape[0])
resolution = (img.shape[1], img.shape[0])

### (2) Gaze Data

In [None]:
dataset = u.load_dataset(DATASET_NAME, verbose=False)
image_dataset = dataset[dataset[peyes.constants.STIMULUS_TYPE_STR] == peyes.constants.IMAGE_STR]

# extract single-trial data: keep trials whose 'RA' and 'MN' columns are not entirely null
# and whose stimulus name (taken from the trial's first row) is the stimulus to show
trials = image_dataset.groupby(peyes.constants.TRIAL_ID_STR)    # group once, reuse for all three tests
relevant_trial_ids = (
        ~trials['RA'].apply(lambda trl: trl.isnull().all()) &
        ~trials['MN'].apply(lambda trl: trl.isnull().all()) &
        trials[peyes.constants.STIMULUS_NAME_STR].apply(lambda trl: trl.iloc[0] == STIMULUS_TO_SHOW)
)
matching_trial_ids = relevant_trial_ids[relevant_trial_ids].index
if matching_trial_ids.empty:
    raise ValueError(f"No trial of stimulus '{STIMULUS_TO_SHOW}' has both 'RA' and 'MN' values")
trial_id = matching_trial_ids[0]    # use the first matching trial
trial_data = image_dataset[image_dataset[peyes.constants.TRIAL_ID_STR] == trial_id]

# extract gaze data from single-trial
pixel_size = trial_data[peyes.constants.PIXEL_SIZE_STR].values[0]
viewer_distance = trial_data[peyes.constants.VIEWER_DISTANCE_STR].values[0]
t = trial_data[peyes.constants.T].values

# x/y are edited in place below, so take private float copies: a copy keeps the edit out of
# `trial_data` (and works when pandas hands out read-only arrays), float can hold NaN
x = trial_data[peyes.constants.X].to_numpy(dtype=float, copy=True)
y = trial_data[peyes.constants.Y].to_numpy(dtype=float, copy=True)

# blank out samples that fall outside the stimulus image
x[(x < 0) | (x >= resolution[0])] = np.nan
y[(y < 0) | (y >= resolution[1])] = np.nan


### (3) Labels

In [None]:
# Load the stored labels and narrow them down to the trial selected above.
_labels_path = os.path.join(PROCESSED_DATA_DIR, DATASET_NAME, peyes.constants.LABELS_STR + ".pkl")
labels_df = (
    pd.read_pickle(_labels_path)
    .xs(1, level=peyes.constants.ITERATION_STR, axis=1)           # keep iteration 1 only
    .xs(trial_id, level=peyes.constants.TRIAL_ID_STR, axis=1)     # keep the relevant trial only
    .dropna(how="all", axis=0)                                    # discard rows that are entirely NaN
)

# labelers that have a column for this trial, in the project's display order
_labelers_in_trial = labels_df.columns.get_level_values(peyes.constants.LABELER_STR).unique()
labeler_names = u.sort_labelers(_labelers_in_trial)

### (4) Events

In [None]:
stim_trial_ids = u.get_trials_for_stimulus_type(DATASET_NAME, STIMULUS_TYPE)

# Load the stored events; keep iteration 1 and only the trials of the chosen stimulus type.
_events_path = os.path.join(PROCESSED_DATA_DIR, DATASET_NAME, "events.pkl")
all_events = pd.read_pickle(_events_path).xs(1, level=peyes.constants.ITERATION_STR, axis=1)
_is_stim_trial = all_events.columns.get_level_values(peyes.constants.TRIAL_ID_STR).isin(stim_trial_ids)
all_events = all_events.loc[:, _is_stim_trial].dropna(axis=0, how="all")   # drop rows left entirely empty

# One flat, NaN-free series of events per labeler.
all_labelers = all_events.columns.get_level_values(peyes.constants.LABELER_STR).unique()
events_by_labelers = {
    lblr: all_events.xs(lblr, level=peyes.constants.LABELER_STR, axis=1).stack().dropna()
    for lblr in all_labelers
}

## Appendix A1: Single-Trial Visualizations

In [None]:
# Gaze heatmap of the selected trial, with the stimulus image passed as its background.
# NOTE(review): `img` comes from `cv2.imread`, which yields BGR channel order, while
# `bg_image_format` is 'rgb' - confirm how peyes interprets this argument.
heatmap = peyes.visualize.gaze_heatmap(
    # gaze samples and canvas size
    x=x,
    y=y,
    resolution=resolution,
    title="Gaze Heatmap",
    # background image
    bg_image=img,
    bg_image_format='rgb',
    bg_alpha=0.5,
    # heatmap appearance
    sigma=10,
    scale=100,
    opacity=0.7,
    colorscale='Jet',
)

# heatmap.update_layout(
#     title=None,
#     width=heatmap.layout.width // 2,
#     height=heatmap.layout.height // 2,
#     margin=dict(l=0, r=0, b=0, t=0, pad=0),
# )
# heatmap.show()

In [None]:
# Velocities are computed from the pixel coordinates and passed along as `v`, labeled 'px/s'.
_gaze_velocity = peyes._utils.pixel_utils.calculate_velocities(x, y, t)

pixels_vs_time = peyes.visualize.gaze_over_time(
    x=x,
    y=y,
    t=t,
    resolution=resolution,
    title="Gaze Over Time",
    v=_gaze_velocity,
    v_measure='px/s',
)

# pixels_vs_time.update_layout(
#     paper_bgcolor='rgba(0, 0, 0, 0)',
#     plot_bgcolor='rgba(0, 0, 0, 0)',
#     height=heatmap.layout.height,
#     width=heatmap.layout.width * 2,
#     title=None,
#     legend=dict(
#         font=dict(size=10),
#         orientation='v',
#         yanchor='top', y=1,
#         xanchor='left', x=0,
#     ),
#     yaxis=dict(
#         title=dict(text='gaze position (px)', standoff=0),
#         showgrid=True, zeroline=False, showline=True, rangemode='tozero',
#     ),
#     yaxis2=dict(
#         title=dict(text='gaze velocity (px/s)', standoff=0),
#         showgrid=False, zeroline=False, showline=True, rangemode='tozero',
#     ),
#     xaxis=dict(
#         showgrid=False, zeroline=False, showline=True,
#         range=[-(t.max() - t.min()) / 200, t.max() + (t.max() - t.min()) / 200]
#     ),
#     margin=dict(l=0, r=0, b=0, t=0, pad=0),
# )
# pixels_vs_time.show()

In [None]:
# labeler_subset = labeler_names
labeler_subset = ["RA", "MN"]

# Resolve the subset once and use it for BOTH the label sequences and `names`.
# The sequences are taken in `labeler_names` order, so passing `labeler_subset` as the
# names could attach the wrong name to a row (different order) or pass more names than
# sequences (a requested labeler that has no labels for this trial).
shown_labelers = [lblr for lblr in labeler_names if lblr in labeler_subset]

scarfplot = peyes.visualize.scarfplot_comparison_figure(
    t,
    *[labels_df[lblr] for lblr in shown_labelers],
    names=shown_labelers,
)

# scarfplot.update_layout(
#     paper_bgcolor='rgba(0, 0, 0, 0)',
#     plot_bgcolor='rgba(0, 0, 0, 0)',
#     height=pixels_vs_time.layout.height // 2,
#     width=pixels_vs_time.layout.width,
#     title=None,
#     yaxis=dict(
#         title=dict(text='Labeler', standoff=0),
#         showgrid=False, zeroline=False, showline=True, rangemode='tozero',
#     ),
#     xaxis=dict(
#         showgrid=False, zeroline=False, showline=True, title=dict(text='time (sample)', standoff=0),
#     ),
#     margin=dict(l=0, r=0, b=0, t=0, pad=0),
# )
# scarfplot.show()

### Finalize Figure A1

In [None]:
NAME = "supp_fig_A1"
WIDTH, HEIGHT = 2000, 800

#######

# Grid: line plot (top-left), scarfplot (bottom-left), heatmap image (right column, spanning both rows).
fig1 = make_subplots(
    rows=2, cols=2, shared_xaxes=True, shared_yaxes=False,
    vertical_spacing=0.05, horizontal_spacing=0.025,
    specs=[
        [{"type": "scatter"}, {"type": "image", "rowspan": 2}],
        [{"type": "heatmap"}, None]
    ],
)

# copy line traces from pixels_vs_time to top-left subplot
for tr in pixels_vs_time.data:
    if tr["name"] == 'v':
        # skip velocity trace - hard to see
        continue
    # restyle a copy (as done for the scarfplot traces below) so that the
    # stand-alone `pixels_vs_time` figure keeps its own line widths
    new_tr = copy.deepcopy(tr)
    new_tr["line_width"] = 5     # explicitly set the line width
    fig1.add_trace(new_tr, row=1, col=1)

# copy heatmap traces from scarfplot to bottom-left subplot
for tr in scarfplot.data:
    new_tr = copy.deepcopy(tr)
    new_tr["showscale"] = False
    fig1.add_trace(new_tr, row=2, col=1)

# copy image traces from heatmap to right subplot
heatmap.for_each_trace(lambda trace: fig1.add_trace(trace, row=1, col=2))

# tick positions for the time axis, shared by `tickvals` and `ticktext` below
time_ticks = np.arange(0, 10001, 1000)

fig1.update_layout(
    width=WIDTH, height=HEIGHT,
    font_family=FONT_FAMILY,
    paper_bgcolor='rgba(0, 0, 0, 0)', plot_bgcolor='rgba(0, 0, 0, 0)',
    legend=dict(
        orientation="h",
        yanchor="top", y=1.01,
        xanchor="left", x=0.01,
        bgcolor='rgba(0, 0, 0, 0)',
        font=AXIS_TICK_FONT,
        itemwidth=75,
    ),
    margin=dict(l=0, r=0, t=0, b=0),

    # line plot axes
    xaxis=dict(showticklabels=False, showgrid=False,),
    yaxis=dict(
        title=dict(text="Gaze Location (px)", font=AXIS_LABEL_FONT, standoff=8),
        showgrid=True, gridcolor='lightgray', gridwidth=5,
        zeroline=True, zerolinecolor='lightgray', zerolinewidth=5,
        rangemode="tozero", showticklabels=True, tickfont=AXIS_TICK_FONT,
    ),

    # scarfplot axes
    xaxis3=dict(
        title=dict(text="Time (ms)", font=AXIS_LABEL_FONT, standoff=4),
        showgrid=False,
        tickangle=30, tickmode='array', tickfont=AXIS_TICK_FONT,
        tickvals=time_ticks, ticktext=[f"{val//1000}k" for val in time_ticks],
    ),
    yaxis3=dict(
        title=dict(text="Annotator", font=AXIS_LABEL_FONT, standoff=4), showgrid=False, tickangle=0,
        # reuse the scarfplot's own tick positions/labels; a tab is appended to each label
        tickmode='array', tickvals=scarfplot.layout['yaxis']['tickvals'],
        ticktext=[f"{txt}\t" for txt in scarfplot.layout['yaxis']['ticktext']],
        tickfont=AXIS_TICK_FONT,
    ),
    # heatmap axes - remove ticks and labels
    xaxis2=dict(showticklabels=False, showgrid=False), yaxis2=dict(showticklabels=False, showgrid=False),
)

fig1.write_image(os.path.join(FIGURES_DIR, f"{NAME}.png"), scale=3)
fig1.write_json(os.path.join(FIGURES_DIR, f"{NAME}.json"))
fig1.show()

## Appendix A2: Event Feature Distributions

In [None]:
NAME = "supp_fig_A2"
HEIGHT = 1200
WIDTH = int(2 * HEIGHT)

# build one event-summary figure per annotator (all warnings raised meanwhile are silenced)
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    fig_1b_ra = peyes.visualize.event_summary(events_by_labelers["RA"], show_outliers=False)
    fig_1b_mn = peyes.visualize.event_summary(events_by_labelers["MN"], show_outliers=False)

fig2 = make_subplots(
    rows=3, cols=2,
    shared_xaxes=True, shared_yaxes=True,
    vertical_spacing=0.05, horizontal_spacing=0.025,
    # titles follow the same [GT1, GT2] order that fills the columns below,
    # so a column's title always names the annotator whose data it shows
    column_titles=[GT1, GT2],
)

# y-axis of a trace in the source figure -> row it goes to in the combined figure
row_by_source_yaxis = {"y": 1, "y2": 2, "y3": 3}

# copy traces from RA's/MN's figures to the column of the matching annotator
for c, GT in enumerate([GT1, GT2]):
    orig_fig = fig_1b_ra if GT == "RA" else fig_1b_mn
    for trace in orig_fig.data:
        if trace["name"] == "SMOOTH_PURSUIT":
            trace["x0"] = trace["name"] = trace["legendgroup"] = "SP"
        if trace["name"] == "BLINK" and 'box' in trace:
            # `box` prop only exists in the violin plots
            trace["box"]["line"] = dict(color='lightgray')
        trace["name"] = f"{trace['name']} ({GT})"
        row = row_by_source_yaxis.get(trace.yaxis)
        if row is None:
            # ignore other features
            continue
        if row > 1:
            # rows 2-3: hide the individual sample points
            trace["points"] = False
        fig2.add_trace(trace, row=row, col=c+1)

fig2.for_each_yaxis(lambda yax: yax.update(
    showgrid=True, gridcolor='lightgray', gridwidth=2.5,
    zeroline=True, zerolinecolor='lightgray', zerolinewidth=5,
))
fig2.for_each_annotation(lambda ann: ann.update(font=TITLE_FONT, yanchor="top", y=1.015))

# styling shared by the x-axes of both bottom-row subplots
bottom_xaxis_style = dict(
    showgrid=False, zeroline=False, showline=True, tickangle=0, tickfont=AXIS_TICK_FONT,
)

# update layout and axes
fig2.update_layout(
    width=WIDTH, height=HEIGHT,
    font_family=FONT_FAMILY,
    paper_bgcolor='rgba(0, 0, 0, 0)', plot_bgcolor='rgba(0, 0, 0, 0)',
    margin=dict(l=0, r=0, t=10, b=0),
    legend=dict(
        orientation="h",
        yanchor="bottom", y=0.01,
        xanchor="left", x=0.1,
        bgcolor='rgba(0, 0, 0, 0)',
        font=AXIS_TICK_FONT,
        itemwidth=100,
    ),
    showlegend=False,

    # count plot y-axis
    yaxis=dict(
        title=dict(text="# Instances", font=AXIS_LABEL_FONT, standoff=18),
        showticklabels=True, tickfont=AXIS_TICK_FONT,
    ),
    # duration plot y-axis
    yaxis3=dict(
        title=dict(text="Duration (ms)", font=AXIS_LABEL_FONT, standoff=18),
        showticklabels=True, tickfont=AXIS_TICK_FONT,
    ),
    # amplitude plot y-axis
    yaxis5=dict(
        title=dict(text="Amplitude (DVA)", font=AXIS_LABEL_FONT, standoff=18),
        showticklabels=True, tickfont=AXIS_TICK_FONT,
    ),
    # x-axes of the bottom row
    xaxis5=bottom_xaxis_style,
    xaxis6=bottom_xaxis_style,
)

fig2.write_image(os.path.join(FIGURES_DIR, f"{NAME}.png"), scale=3)
# fig2.write_json(os.path.join(FIGURES_DIR, f"{NAME}.json"))
fig2.show()

## Appendix A3: Main-Sequence Example

In [None]:
def create_main_sequence_fig(all_labelers_events, labeler: str) -> "tuple[go.Figure, pd.DataFrame]":
    """
    Build a styled main-sequence figure from the saccades of a single labeler.

    :param all_labelers_events: mapping from labeler name to that labeler's events; each event
        must expose a `label` attribute.
    :param labeler: key into `all_labelers_events` selecting whose events to plot.
    :return: the `(figure, stats)` pair returned by `peyes.visualize.main_sequence`, with the
        figure restyled for export. `stats` is returned unchanged.
    :raises KeyError: if `labeler` is not a key of `all_labelers_events` (for a dict-like input).
    """
    labeler_events = all_labelers_events[labeler]

    # keep only the events labeled as saccades
    is_saccade = labeler_events.map(
        lambda evnt: evnt.label == peyes._DataModels.EventLabelEnum.EventLabelEnum.SACCADE
    )
    labeler_saccades = labeler_events[is_saccade].dropna()

    fig, stats = peyes.visualize.main_sequence(
        labeler_saccades, y_feature=peyes.constants.DURATION_STR, include_outliers=False
    )

    # thicken the overall trendline drawn on the first x-axis
    for trace in fig.data:
        if trace['name'] == "Overall Trendline" and trace['xaxis'] == "x":
            trace['line']['width'] = 5

    # x- and y-axes share the exact same styling
    axis_style = dict(
        tickfont=AXIS_TICK_FONT,
        title=dict(font=AXIS_LABEL_FONT),
        showgrid=True, gridcolor='lightgray', gridwidth=2.5,
        zeroline=True, zerolinecolor='lightgray', zerolinewidth=2.5,
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    fig.update_layout(
        title=None,
        width=800, height=500,      # default size; callers may override via `update_layout`
        font_family=FONT_FAMILY,
        paper_bgcolor='rgba(0, 0, 0, 0)', plot_bgcolor='rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
    )
    return fig, stats

#### (1) Annotator RA's Saccades

In [None]:
NAME = "supp_fig_A3_RA"
WIDTH = 1600
HEIGHT = 900

# Build RA's main-sequence figure, then resize it for export.
fig3_ra, stats_ra = create_main_sequence_fig(events_by_labelers, "RA")
fig3_ra.update_layout(width=WIDTH, height=HEIGHT)

# Show the summary of the first entry in the returned statistics table.
_ra_summary = stats_ra.iloc[0, 0].summary()
display(_ra_summary)

_ra_png_path = os.path.join(FIGURES_DIR, f"{NAME}.png")
fig3_ra.write_image(_ra_png_path, scale=3)
# fig3_ra.write_json(os.path.join(FIGURES_DIR, f"{NAME}.json"))
fig3_ra.show()

#### (2) Annotator MN's Saccades

In [None]:
NAME = "supp_fig_A3_MN"
WIDTH, HEIGHT = 1600, 900

fig3_mn, stats_mn = create_main_sequence_fig(events_by_labelers, "MN")
fig3_mn.update_layout(width=WIDTH, height=HEIGHT,)

display(stats_mn.iloc[0,0].summary())

# Export the FIGURE (`fig3_mn`), mirroring the RA cell: `stats_mn` is the statistics
# table and has no `write_image`, and `NAME` would otherwise go unused.
fig3_mn.write_image(os.path.join(FIGURES_DIR, f"{NAME}.png"), scale=3)
# fig3_mn.write_json(os.path.join(FIGURES_DIR, f"{NAME}.json"))
fig3_mn.show()