# Figure 1
## Comparing Event Features between Labelers/Detectors

In [1]:
import copy

import pandas as pd
from plotly.subplots import make_subplots
import plotly.io as pio

from analysis._article_results.lund2013._helpers import *

# Route figure display (e.g. `final_fig.show()` at the end of this notebook) to a web browser.
pio.renderers.default = "browser"

  import scipy.linalg


## Load Data
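Load `events.pkl`, keep only iteration 1 of the trials that used the selected stimulus type, and split each labeler's events into fixations and saccades.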

In [2]:
stim_trial_ids = u.get_trials_for_stimulus_type(DATASET_NAME, STIMULUS_TYPE)

# Read the pickled events table and keep only the columns of iteration 1.
_events_path = os.path.join(PROCESSED_DATA_DIR, DATASET_NAME, "events.pkl")
all_events = pd.read_pickle(_events_path).xs(1, level=peyes.constants.ITERATION_STR, axis=1)

# Restrict to trials of the selected stimulus type, then drop rows that are empty in every column.
_trial_id_per_column = all_events.columns.get_level_values(peyes.constants.TRIAL_ID_STR)
all_events = all_events.loc[:, _trial_id_per_column.isin(stim_trial_ids)].dropna(axis=0, how="all")

all_labelers = all_events.columns.get_level_values(peyes.constants.LABELER_STR).unique()


def _select_label(events, label):
    """Return the entries of the ``events`` Series whose ``.label`` attribute equals ``label``."""
    return events[events.map(lambda event: event.label == label)]


# Flatten each labeler's sub-table into a single Series of events (via stack), dropping missing entries.
events_by_labelers = {
    labeler: all_events.xs(labeler, level=peyes.constants.LABELER_STR, axis=1).stack().dropna()
    for labeler in all_labelers
}

# NOTE(review): this reaches into peyes' private `_DataModels` package; prefer a public alias if one exists.
_LABEL_ENUM = peyes._DataModels.EventLabelEnum.EventLabelEnum
fixations_by_labelers = {
    labeler: _select_label(labeler_events, _LABEL_ENUM.FIXATION)
    for labeler, labeler_events in events_by_labelers.items()
}
saccades_by_labelers = {
    labeler: _select_label(labeler_events, _LABEL_ENUM.SACCADE)
    for labeler, labeler_events in events_by_labelers.items()
}

## Create Figure
#### TOP: Fixation Features
#### BOTTOM: Saccade Features

In [3]:
NAME = "fig_1"  # basename of the exported image
W = 1600  # figure width passed to update_layout
H = 900  # figure height passed to update_layout

# One subplot column per feature, one subplot row per event type.
FEATURES = [
    peyes.constants.COUNT_STR,
    peyes.constants.DURATION_STR,
    peyes.constants.AMPLITUDE_STR,
]
ROW_TITLES = ["Fixations", "Saccades"]
COLUMN_TITLES = ["# Instances", "Duration (ms)", "Amplitude (°)"]
# Per-subplot titles in row-major order: the row title without its trailing "s",
# followed by the title-cased feature name.
PLOT_TITLES = [
    f"{row_title[:-1]} {feature.title()}"
    for row_title in ROW_TITLES
    for feature in FEATURES
]

FONT_FAMILY = "Calibri"
FONT_COLOR = "black"


def _font(size):
    """Return a plotly font dict in the shared family and colour at the given size."""
    return dict(family=FONT_FAMILY, size=size, color=FONT_COLOR)


TITLE_FONT = _font(28)
SUBTITLE_FONT = _font(26)
AXIS_LABEL_FONT = _font(22)
AXIS_TICK_FONT = _font(18)
GRID_WIDTH = 1.0
GRID_COLOR = "lightgray"

In [4]:
def _labeler_colors():
    """Build a fresh {labeler: colour} mapping from entry [1] of each LABELER_PLOTTING_CONFIG value."""
    return {labeler: config[1] for labeler, config in LABELER_PLOTTING_CONFIG.items()}


# One peyes feature-comparison figure per event type: fixations first, then saccades.
peyes_fixations_fig, peyes_saccades_fig = (
    peyes.visualize.feature_comparison(
        FEATURES,
        *events_per_labeler.values(),
        labels=events_per_labeler.keys(),
        colors=_labeler_colors(),
    )
    for events_per_labeler in (fixations_by_labelers, saccades_by_labelers)
)

In [5]:
def _display_name(name):
    """Return the legend/axis label for a raw labeler or detector name.

    Ground-truth annotators (GT1, GT2) become "Ann. <name>"; names starting with "i"
    (e.g. "ivt") have that leading "i" turned into "I-" and the rest upper-cased ("I-VT");
    "remodnav" becomes "REMoDNaV"; anything else is upper-cased.
    """
    if name in (GT1, GT2):
        return f"Ann. {name}"
    if name.startswith("i"):
        # Only the leading "i" is the prefix; str.replace("i", "I-") would also rewrite later "i"s.
        return "I-" + name[1:].upper()
    if name == "remodnav":
        return "REMoDNaV"
    return name.upper()


final_fig = make_subplots(
    rows=2, cols=len(FEATURES),
    shared_yaxes='rows', shared_xaxes=False,
    vertical_spacing=0.1, horizontal_spacing=0.02,
    subplot_titles=PLOT_TITLES, column_titles=COLUMN_TITLES,
)

# Copy each peyes figure's traces into its row of the combined figure (row 0: fixations, row 1: saccades).
# In the peyes figures, feature column c lives on y-axis "y" (c == 0) or f"y{c+1}".
# Column 0 (counts) gets bar-style settings (offset/width/pattern); the others get violin settings.
for r, existing_fig in enumerate([peyes_fixations_fig, peyes_saccades_fig]):
    for c in range(len(FEATURES)):
        yaxis = "y" if c == 0 else f"y{c + 1}"
        for tr in existing_fig['data']:
            if tr['yaxis'] != yaxis:
                continue
            new_tr = copy.deepcopy(tr)
            new_tr['showlegend'] = r == 0 and c == 0
            new_tr['opacity'] = 0.95
            if c == 0:
                new_tr['offset'] = 0
                new_tr['width'] = 0.8
                # add "cross" pattern to GT annotators (checked on the raw name, before renaming)
                new_tr['marker_pattern_shape'] = 'x' if new_tr['name'] in (GT1, GT2) else ''
            else:
                new_tr['width'] = 1.8
                new_tr['meanline'] = new_tr['box'] = None
                new_tr['points'] = False

            # Rename here, where the column is known, instead of guessing the trace kind afterwards.
            display_name = _display_name(new_tr['name'])
            new_tr['name'] = new_tr['legendgroup'] = display_name
            if c == 0:
                # count traces carry an explicit y category equal to their name; keep it in sync
                new_tr['y'] = [display_name]
            final_fig.add_trace(new_tr, row=r + 1, col=c + 1)

# Shared tick/grid styling for every axis.
shared_axis_style = dict(
    tickfont=AXIS_TICK_FONT,
    showgrid=True, gridcolor=GRID_COLOR, gridwidth=GRID_WIDTH,
    zeroline=True, zerolinecolor=GRID_COLOR, zerolinewidth=GRID_WIDTH,
)
final_fig.update_yaxes(**shared_axis_style)
final_fig.update_xaxes(**shared_axis_style)
# y-axis title on the left-most column of both rows
final_fig.update_yaxes(
    title=dict(text="Detector", font=AXIS_LABEL_FONT, standoff=4),
    col=1,
)

# Restyle and reposition the annotations created by make_subplots.
subplot_title_x = (0.14, 0.5, 0.825)  # horizontal centre of each column's subplot title (paper coords)
for ann in final_fig.layout.annotations:
    if ann.text in COLUMN_TITLES:
        # column titles are moved below the plotting area and act as x-axis labels
        ann.update(font=AXIS_LABEL_FONT, yref='paper', yanchor='top', y=-0.03)
    elif ann.text in ROW_TITLES:
        # NOTE: make_subplots above is not given row_titles, so this branch only takes effect
        # if row_titles=ROW_TITLES is added there.
        ann.update(
            font=TITLE_FONT, textangle=0,
            xref='paper', xanchor='center', x=0.5,
            yref='paper', yanchor='top', y=1.075 if ann.text == ROW_TITLES[0] else 0.52,
        )
    elif ann.text in PLOT_TITLES:
        ann.update(
            font=SUBTITLE_FONT, textangle=0,
            xref='paper', xanchor='center',
            x=subplot_title_x[PLOT_TITLES.index(ann.text) % len(subplot_title_x)],
            yref='paper', yanchor='top', y=1.04 if ann.text.startswith("Fixation") else 0.48,
        )

final_fig.update_layout(
    width=W, height=H,
    paper_bgcolor='rgba(0, 0, 0, 0)', plot_bgcolor='rgba(0, 0, 0, 0)',
    margin=dict(l=0, r=0, b=50, t=50, pad=0),

    # move legend to bottom
    legend=dict(orientation="h", yanchor="top", xanchor="center", xref='container', yref='container', x=0.5, y=0.05),
    showlegend=False,   # hide legend

    # x-axis ranges for the violin columns (axes are numbered row-major:
    # xaxis2/xaxis3 = top row columns 2/3, xaxis5/xaxis6 = bottom row columns 2/3)
    xaxis2=dict(range=[50, 850], tickmode='linear', tick0=50, dtick=200),
    xaxis3=dict(range=[0, 5], tickmode='linear', tick0=0, dtick=1),
    xaxis5=dict(range=[10, 80], tickmode='linear', tick0=10, dtick=17.5),
    xaxis6=dict(range=[0, 15], tickmode='linear', tick0=0, dtick=3),
)

final_fig.write_image(os.path.join(FIGURES_DIR, f"{NAME}.png"), scale=3)
# final_fig.write_json(os.path.join(FIGURES_DIR, f"{NAME}.json"))
final_fig.show()