In [11]:
# Sanity check: ask PyTorch whether a CUDA device is available in the current runtime.
import  torch

torch.cuda.is_available()

True

In [None]:
# Mount Google Drive at /content/drive; the paths used in the cells below live under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf as pdf
from matplotlib.patches import Circle

## Clean raw data by deleting the noise before reading; do not merge samples into associations etc.

In [53]:
# Input: every CSV in the per-reader folder. Output: one cleaned CSV per input, same file stem.
raw = Path(f'/content/drive/MyDrive/MoTR/provo_divided_by_reader').glob('*.csv')
out = Path(f'/content/drive/MyDrive/MoTR/provo_plots_test/provo_cleaned_raw_association_not_merged')
# Create the output folder up front so the to_csv call below cannot fail on a missing directory.
out.mkdir(parents=True, exist_ok=True)

# cleaned column name -> (raw column name, dtype the raw column is cast to).
# The dict order is the column order of the cleaned file.
column_spec = {
    'sbm_id': ('submission_id', str),
    'expr_id': ('Experiment', int),
    'cond_id': ('Condition', int),
    'trial_id': ('trial_id', int),
    'para_nr': ('ItemId', int),
    'word_nr': ('Index', int),
    'word': ('Word', str),
    't': ('responseTime', int),
    'x': ('mousePositionX', int),
    'y': ('mousePositionY', int),
    'wb': ('wordPositionBottom', str),
    'wt': ('wordPositionTop', str),
    'wl': ('wordPositionLeft', str),
    'wr': ('wordPositionRight', str),
    'response': ('response', str),
}

# A trial is kept from the first sample whose word_nr is one of these values.
reading_start_word_nrs = [0, 1, 2, 3]

for f in raw:
    df = pd.read_csv(f)

    # Build the working frame directly from the raw columns (renamed and cast in one step).
    dfw = pd.DataFrame({new: df[old].astype(dtype) for new, (old, dtype) in column_spec.items()})

    # # If needed, create a dictionary to map para_nr to trial sequence number
    # trial_sequence = {}
    # current_trial = 0
    # for para_nr in dfw['para_nr'].unique():
    #     trial_sequence[para_nr] = current_trial
    #     current_trial += 1

    # # Add a new column 'trial_nr' to dfw using the trial_sequence mapping
    # dfw['trial_id'] = dfw['para_nr'].map(trial_sequence)

    # Collect the trimmed trials in a list and concatenate once at the end;
    # growing a DataFrame with pd.concat inside the loop copies all rows every iteration.
    kept_groups = []
    for _, group in dfw.groupby(['cond_id', 'trial_id']):
        reading_rows = group[group['word_nr'].isin(reading_start_word_nrs)]
        if reading_rows.empty:
            # The trial never reaches one of the start words: it is dropped entirely.
            continue
        # Delete all rows before the first row with 'word_nr' in [0, 1, 2, 3]
        kept_groups.append(group.loc[reading_rows.index[0]:])

    # pd.concat raises on an empty list, so fall back to an empty frame when nothing was kept.
    filtered_df = pd.concat(kept_groups, ignore_index=True) if kept_groups else pd.DataFrame()

    # The default index is deliberately written as an unnamed first column.
    # NOTE(review): code that slices columns by position (ivt returns `gaze_data.iloc[:, 1:-1]`)
    # may rely on that leading column — confirm before switching to index=False.
    filtered_df.to_csv(f'{out}/{f.stem}.csv')

## Define a Velocity-based association detection function.

In [93]:
def most_frequent(series):
    """
    Return the most frequent value of a Pandas Series.

    When several values tie, the first entry of ``series.mode()`` is returned.
    When the series has no mode (e.g. it is empty or holds only missing values),
    or the mode itself is missing, the placeholder string "%2c%" is returned.

    :Parameters series: The Pandas series for which the mode is to be calculated.
    """
    modes = series.mode()
    if modes.empty:
        return "%2c%"
    top_value = modes[0]
    if pd.isna(top_value):
        return "%2c%"
    return top_value


In [94]:
def ivt(gaze_data, vel_thres, dur_thres_fix_low, dur_thres_fix_high, accer_thres, dur_thres_sac,
        sliding_margin=0.05):
    """
    Label gaze samples with velocity/acceleration thresholds (I-VT) and summarise the events.

    Each sample is labelled in a new 'type' column:
      * 'association': |acceleration| < accer_thres and velocity < vel_thres
      * 'sliding':     |acceleration| < accer_thres and velocity < vel_thres + sliding_margin
      * 'saccade':     everything else
    Consecutive samples with the same label form one run; runs are then filtered by duration
    and aggregated.

    :param gaze_data: DataFrame with the columns 'x', 'y', 't', 'word_nr', 'word', 'sbm_id',
        'expr_id', 'cond_id', 'trial_id' and 'para_nr'. The frame passed in is not modified.
    :param vel_thres: Velocity threshold (distance per unit of 't') below which a sample can be
        an association.
    :param dur_thres_fix_low: Minimum duration of an association run (same unit as 't').
    :param dur_thres_fix_high: Maximum duration of an association run (same unit as 't').
    :param accer_thres: Absolute acceleration threshold; samples at or above it are saccades.
    :param dur_thres_sac: Minimum duration of a saccade run (same unit as 't').
    :param sliding_margin: Width of the velocity band above vel_thres that is labelled
        'sliding'. Defaults to 0.05, the value that was previously hard-coded.
    :return: Tuple ``(association_centroid, saccade_stats, labelled_samples)``:
        * association_centroid: one row per kept association run (ids, most frequent word_nr
          and word, mean x/y, start_t, end_t, duration). Runs whose word_nr is -1 or whose
          word is the "%2c%" placeholder are removed.
        * saccade_stats: one row per kept saccade run (ids, mean/start/end x and y, start_t,
          end_t, duration).
        * labelled_samples: the samples with 'type', 'velocities' and 'acceleration' added and
          rows with word_nr == -100 removed; the first and last columns are dropped by position
          (the last one is the internal 'group' run id).
    """
    # Work on a private copy: callers typically pass a slice of a larger frame, and adding
    # columns to it in place raises SettingWithCopyWarning and changes the caller's data.
    gaze_data = gaze_data.copy()

    # Sample-to-sample differences; the prepended 0 keeps each array aligned with the rows.
    dx = np.insert(np.diff(gaze_data['x']), 0, 0)
    dy = np.insert(np.diff(gaze_data['y']), 0, 0)
    dt = np.insert(np.diff(gaze_data['t']), 0, 0).astype(float)

    # Avoid division by zero: a very large dt makes velocity and acceleration of samples
    # with no time step (including the first one) practically 0.
    dt[dt == 0] = 1e6

    distances = np.sqrt(dx**2 + dy**2)
    velocities = distances / dt
    acceleration = np.insert(np.diff(velocities), 0, 0) / dt

    # Classify points as associations, slidings or saccades.
    low_acceleration = np.absolute(acceleration) < accer_thres
    gaze_data['type'] = np.where(
        low_acceleration & (velocities < vel_thres),
        'association',
        np.where(
            low_acceleration & (velocities < vel_thres + sliding_margin),
            'sliding',
            'saccade'
        )
    )
    gaze_data['velocities'] = velocities
    gaze_data['acceleration'] = acceleration

    # Exclude rows where word_nr is -100 (for comprehension question).
    # .copy() so that the 'group' column below is added to an independent frame.
    gaze_data = gaze_data[gaze_data['word_nr'] != -100].copy()

    # A new run id starts whenever the label differs from the previous row's label.
    gaze_data['group'] = (gaze_data['type'] != gaze_data['type'].shift()).cumsum()

    def _run_duration(run):
        # Time between the first and the last sample of one run.
        return run['t'].iloc[-1] - run['t'].iloc[0]

    # Keep only runs whose duration passes the thresholds.
    associations = gaze_data[gaze_data['type'] == 'association'].groupby('group').filter(
        lambda run: dur_thres_fix_low <= _run_duration(run) <= dur_thres_fix_high)
    saccades = gaze_data[gaze_data['type'] == 'saccade'].groupby('group').filter(
        lambda run: _run_duration(run) >= dur_thres_sac)

    association_centroid = associations.groupby('group').agg(
        sbm_id=('sbm_id', 'first'),
        expr_id=('expr_id', 'first'),
        cond_id=('cond_id', 'first'),
        trial_id=('trial_id', 'first'),
        para_nr=('para_nr', 'first'),
        word_nr=('word_nr', most_frequent),
        word=('word', most_frequent),
        x_mean=('x', 'mean'),
        y_mean=('y', 'mean'),
        start_t=('t', 'first'),
        end_t=('t', 'last')
    )
    association_centroid['duration'] = association_centroid['end_t'] - association_centroid['start_t']
    # Filter out groups where word_nr is -1 or word is null (fix at blank area)
    association_centroid = association_centroid[
        (association_centroid['word_nr'] != -1) & (association_centroid['word'] != "%2c%")
    ]

    # Calculate statistics for each saccade group
    saccade_stats = saccades.groupby('group').agg(
        sbm_id=('sbm_id', 'first'),
        expr_id=('expr_id', 'first'),
        cond_id=('cond_id', 'first'),
        trial_id=('trial_id', 'first'),
        para_nr=('para_nr', 'first'),
        mean_x=('x', 'mean'),
        mean_y=('y', 'mean'),
        start_x=('x', 'first'),
        end_x=('x', 'last'),
        start_y=('y', 'first'),
        end_y=('y', 'last'),
        start_t=('t', 'first'),
        end_t=('t', 'last'),
    )
    saccade_stats['duration'] = saccade_stats['end_t'] - saccade_stats['start_t']

    # Positional slice: drops the first input column and the trailing 'group' helper column.
    return association_centroid, saccade_stats, gaze_data.iloc[:, 1:-1]

## Read in all the files and get associations, saccades

In [None]:
# Here, the setting of thres is different from association
vel_thres = 0.1            # velocity threshold passed to ivt()
dur_thres_fix_low = 100    # minimum association duration
dur_thres_fix_high = 3000  # maximum association duration
accer_thres = 0.01         # absolute acceleration threshold
dur_thres_sac = 60         # minimum saccade duration

# NOTE(review): this folder is not the one the cleaning cell writes to
# (that one is under .../MoTR/provo_plots_test/) — confirm which input is intended.
reading_data_path = Path(f'/content/drive/MyDrive/MoTR/provo_cleaned_raw_association_not_merged')

# Output folders; created up front so a missing folder does not abort the run after
# the (expensive) detection step has already finished.
output_root = Path('/content/drive/MyDrive/MoTR/provo_2024')
associations_dir = output_root / 'associations_2024'
saccades_dir = output_root / 'Saccades_2024'
labelled_gaze_dir = output_root / 'associations_Saccades_Slidings'
for output_dir in (associations_dir, saccades_dir, labelled_gaze_dir):
    output_dir.mkdir(parents=True, exist_ok=True)

# Raw column names -> the short names ivt() expects. Columns that are absent
# (e.g. because the file is already renamed) are simply left alone by rename().
raw_to_short_names = {
    'submission_id': 'sbm_id',
    'Experiment': 'expr_id',
    'Condition': 'cond_id',
    'ItemId': 'para_nr',
    'Index': 'word_nr',
    'Word': 'word',
    'responseTime': 't',
    'mousePositionX': 'x',
    'mousePositionY': 'y',
    # Uncomment and rename other columns if needed
    # 'wordPositionBottom': 'wb',
    # 'wordPositionTop': 'wt',
    # 'wordPositionLeft': 'wl',
    # 'wordPositionRight': 'wr',
    'response': 'response'
}

# Iterate over each file in the directory
for file_path in reading_data_path.iterdir():
    # Skip sub-directories
    if not file_path.is_file():
        continue

    print(f"Processing file: {file_path}")
    # File name without its extension (independent of the extension's length).
    reader = file_path.stem
    print(f"Reader: {reader}")
    reading_data = pd.read_csv(file_path).rename(columns=raw_to_short_names)

    if reading_data.empty:
        # Nothing to detect, and pd.concat below would raise on empty lists.
        print(f"No samples for {reader}. Skipping...")
        continue

    all_associations = []
    all_saccades = []
    all_gaze = []

    # Run the detection separately for every paragraph so that runs never span two items.
    for para_nr in reading_data['para_nr'].unique():
        item_data = reading_data[reading_data['para_nr'] == para_nr]
        associations, saccades, gaze_data = ivt(
            item_data, vel_thres, dur_thres_fix_low, dur_thres_fix_high, accer_thres, dur_thres_sac
        )

        all_associations.append(associations)
        all_saccades.append(saccades)
        all_gaze.append(gaze_data)

    # Combine all item results into single DataFrames
    all_associations_df = pd.concat(all_associations, ignore_index=True)
    all_saccades_df = pd.concat(all_saccades, ignore_index=True)
    all_gaze_df = pd.concat(all_gaze, ignore_index=True)

    # Write to CSV
    all_associations_df.to_csv(associations_dir / f'associations_{reader}.csv', index=False)
    all_saccades_df.to_csv(saccades_dir / f'Saccades_{reader}.csv', index=False)
    all_gaze_df.to_csv(labelled_gaze_dir / f'associations_Saccades_Slidings_unmerged_{reader}.csv', index=False)

## Define a function that takes two DataFrames, which have been grouped over condition, as arguments.

In [96]:
def generate_circle_points(center_x, center_y, radius, ax, num_points=100):
    """
    Return the outline of an ellipse that compensates for the axes' data ratio.

    :param center_x: x coordinate of the centre.
    :param center_y: y coordinate of the centre.
    :param radius: horizontal radius; the vertical radius is derived from it.
    :param ax: object providing ``get_data_ratio()`` (a matplotlib Axes in this notebook).
    :param num_points: number of outline points; the first and last point coincide.
    :return: tuple ``(x_points, y_points)`` of NumPy arrays with ``num_points`` entries each.
    """
    # The vertical radius is scaled by the axes' data ratio and a fixed factor of 1.3.
    vertical_radius = radius * 1.3 * ax.get_data_ratio()

    # Evenly spaced angles over one full turn (0 and 2*pi are both included).
    angles = np.linspace(0, 2 * np.pi, num_points)

    outline_x = center_x + radius * np.cos(angles)
    outline_y = center_y + vertical_radius * np.sin(angles)
    return outline_x, outline_y

In [97]:
import textwrap
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerPatch, HandlerLine2D
from matplotlib.legend_handler import HandlerBase

def _rgb255(red, green, blue):
    """Convert 0-255 RGB channel values into a tuple of 0-1 floats."""
    return (red / 255, green / 255, blue / 255)


# Define new colors (each one is an RGB tuple of floats in the range 0-1)
soft_blue = _rgb255(100, 149, 237)     # Cornflower Blue
dark_blue = _rgb255(70, 130, 180)      # Steel Blue
light_grey = _rgb255(220, 220, 220)    # Gainsboro
light_purple = _rgb255(237, 239, 248)
subtle_orange = _rgb255(255, 165, 0)   # Orange
subtle_grey = _rgb255(200, 200, 200)
bright_green = _rgb255(0, 255, 0)
bright_orange = _rgb255(255, 165, 0)   # same value as subtle_orange
Magenta = _rgb255(255, 0, 255)

class CustomAssociationHandler(HandlerBase):
    """Legend handler that draws a translucent rectangle above a dashed grey line."""

    def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
        """
        Build the legend artists for one entry.

        :return: list ``[patch, line]`` — a light purple rectangle covering the upper part of
            the legend cell and a dashed grey line along its bottom edge.
        """
        # Half of the cell height, rounded down by the floor division.
        half_height = height // 2

        # Dashed line along the bottom of the legend cell.
        baseline_y = ydescent
        dashed_baseline = mlines.Line2D(
            [xdescent, xdescent + width],
            [baseline_y, baseline_y],
            color=subtle_grey,
            linestyle='dashed',
            linewidth=1,
            transform=trans,
        )

        # Rectangle that starts half_height above the bottom and is half_height tall.
        shaded_box = mpatches.Rectangle(
            [xdescent, ydescent + half_height],
            width,
            half_height,
            color=light_purple,
            alpha=0.5,
            transform=trans,
        )

        return [shaded_box, dashed_baseline]

In [98]:
def visualization_multi(grouped_df_cond, grouped_dff_cond, grouped_association_data, grouped_saccade_data, start_index, fig, axes):
    """
    Plot one trial per axes row: x and y position over time, annotated with detected events.

    For every row ``i`` the group at position ``start_index + i`` of ``grouped_df_cond`` is
    drawn. Rows for which no group is left (last page of a multi-page plot) are switched off
    instead of being drawn with data from the previous row.

    :param grouped_df_cond: groupby object over the sample data; iterating it yields
        ``(key, frame)`` pairs. Each frame needs the columns 't', 'x', 'y', 'text',
        'correct' and 'trial_id'.
    :param grouped_dff_cond: groupby object (same keys) whose rows provide 'start_t', 'end_t',
        'word' and 'word_nr'; each row is drawn as a shaded time span labelled with the word.
    :param grouped_association_data: groupby object (same keys) whose rows provide 'start_t',
        'duration' and 'x_mean'; each row is drawn as a filled magenta ellipse.
    :param grouped_saccade_data: groupby object (same keys) whose rows provide 'start_t' and
        'end_t'; each row is drawn as a green band around the x trace.
    :param start_index: position of the first group to draw.
    :param fig: the figure the axes belong to (not used directly; pyplot acts on the
        current figure).
    :param axes: sequence of axes, one per row.
    :return: None. The figure is laid out and shown via ``plt.show()``.
    """
    title_fontsize = 8
    word_fontsize = 7
    tick_fontsize = 7
    band_width = 30  # total height of the band drawn around the x trace for a saccade

    # Convert the grouped object to a list for easier indexing
    grouped_list = list(grouped_df_cond)
    wrapper = textwrap.TextWrapper(width=180)

    # Legend proxy artists are identical for every row, so build them once.
    association_circle = mlines.Line2D([], [], color=Magenta, marker='o', markersize=5, label='association', linestyle='None')
    saccade_patch = mpatches.Patch(color=bright_green, alpha=0.5, label='Saccade')
    horizontal_line = mlines.Line2D([], [], color=dark_blue, label='Horizontal Movement')
    vertical_line = mlines.Line2D([], [], color=(222/255, 154/255, 96/255), label='Vertical Movement')
    association_patch = mpatches.Patch(color=light_purple, alpha=0.5)
    association_line = mlines.Line2D([], [], color=subtle_grey, linestyle='dashed', linewidth=1)
    legend_elements = [horizontal_line, vertical_line, (association_patch, association_line), association_circle, saccade_patch]
    legend_labels = ['Horizontal Movement', 'Vertical Movement', 'Association', 'association', 'Saccade']

    for i in range(len(axes)):
        ax = axes[i]
        group_index = start_index + i
        if group_index >= len(grouped_list):
            # No trial left for this row: hide the empty axes and do not reuse the
            # previous row's data.
            ax.set_axis_off()
            continue

        # The group key is whatever grouped_df_cond was grouped by (the callers group by
        # 'trial_id'); the same key is used to look up the other grouped inputs.
        para_nr, group_cond = grouped_list[group_index]

        # The last sample of the trial is dropped from the traces.
        time = group_cond['t'].tolist()[:-1]
        x = group_cond['x'].tolist()[:-1]
        y = group_cond['y'].tolist()[:-1]
        if not time:
            # Fewer than two samples: nothing to draw, and time[-1] below would fail.
            ax.set_axis_off()
            continue

        # Plot x and y on the subplot
        ax.plot(time, x)
        ax.plot(time, y)

        # Title: trial number and item text; red when any row of the trial is marked incorrect.
        wrapped_text = wrapper.fill(text=group_cond['text'].iloc[0])
        title_color = 'black' if (group_cond['correct'] == 1).all() else 'red'
        ax.set_title(f"Trial {group_cond['trial_id'].iloc[0]}: {wrapped_text}", color=title_color, fontsize=title_fontsize)

        # Word-level spans: shaded region, dashed end marker and rotated word label.
        if para_nr in grouped_dff_cond.groups:
            group_dff_cond = grouped_dff_cond.get_group(para_nr)
            for _, f in group_dff_cond.iterrows():
                # Clip spans that run past the end of the plotted trace.
                end_t = min(f['end_t'], time[-1])
                ax.axvspan(f['start_t'], end_t, color=light_purple)
                ax.axvline(end_t, color=subtle_grey, linestyle='dashed', linewidth=1)
                word = f['word']
                word_nr = f['word_nr'] + 1  # shown 1-based
                x_pos = (f['start_t'] + end_t) / 2
                y_pos = y[-1] + 250
                ax.text(x_pos, y_pos, f"{word_nr} {word}", rotation=90, ha='center', va='center', fontsize=word_fontsize)

        # Detected association runs: ellipse centred on the run, as wide as its duration.
        if para_nr in grouped_association_data.groups:
            association_df = grouped_association_data.get_group(para_nr)
            for _, association in association_df.iterrows():
                center_x = association['start_t'] + association['duration'] / 2
                center_y = association['x_mean']
                radius = association['duration'] / 2

                x_points, y_points = generate_circle_points(center_x, center_y, radius, ax)
                ax.plot(x_points, y_points, color=(245/255, 245/255, 245/255), linewidth=0.8)
                ax.fill(x_points, y_points, color=Magenta, alpha=0.5)

        # Detected saccades: a band of height band_width following the x trace.
        if para_nr in grouped_saccade_data.groups:
            saccade_df = grouped_saccade_data.get_group(para_nr)
            for _, saccade in saccade_df.iterrows():
                start_time = saccade['start_t']
                end_time = saccade['end_t']

                # Samples of this trial whose 't' lies within the saccade
                subset = group_cond[(group_cond['t'] >= start_time) & (group_cond['t'] <= end_time)]
                subselected_time = subset['t'].tolist()
                subselected_x = subset['x'].tolist()

                if subselected_time and subselected_x:
                    upper_bound = [x_val + band_width / 2 for x_val in subselected_x]
                    lower_bound = [x_val - band_width / 2 for x_val in subselected_x]
                    ax.fill_between(subselected_time, lower_bound, upper_bound, color=bright_green, alpha=0.5)

        # Set labels and add the legend
        ax.set_xlabel('time(ms)')
        ax.set_ylabel('position in pixels')
        ax.legend(handles=legend_elements, labels=legend_labels,
                  handler_map={association_line: CustomAssociationHandler()},
                  loc='upper left', fontsize=6)

        # One x tick every 1000 time units
        custom_ticks = np.arange(0, time[-1], 1000)
        ax.set_xticks(custom_ticks)
        ax.tick_params(axis='both', which='both', labelsize=tick_fontsize)

    # Adjust spacing between subplots
    plt.subplots_adjust(hspace=0.6, wspace=0.2)
    plt.tight_layout()

    # Show the plot
    plt.show()

## Plot Association-association-Saccade for multiple files


In [None]:
trial_data = Path('/content/drive/MyDrive/MoTR/trial_data/provo_items.tsv')

reading_data_path = Path(f'/content/drive/MyDrive/MoTR/provo_plots_test/provo_cleaned_raw_association_not_merged')

condition = 1  # only this condition is plotted
num_rows = 3   # trials per PDF page
num_cols = 1

# Iterate over each file in the directory
for file_path in reading_data_path.iterdir():
    # Skip sub-directories
    if not file_path.is_file():
        continue

    print(f"Processing file: {file_path}")
    reader = str(file_path.stem)
    print(f"Reader: {reader}")

    # NOTE(review): this existence check looks at the 160ms association file, but the data
    # actually read below comes from provo_2024 — confirm that skipping on this file is intended.
    association_data = Path(f'/content/drive/MyDrive/MoTR/provo_associations/2_provo_association_160ms/{reader}_clean.csv')
    if not association_data.exists():
        print(f"Association data for {reader} does not exist. Skipping...")
        continue

    association_data = Path(f'/content/drive/MyDrive/MoTR/provo_2024/associations_2024/associations_{reader}.csv')
    saccade_data = Path(f'/content/drive/MyDrive/MoTR/provo_2024/Saccades_2024/Saccades_{reader}.csv')

    dfw = pd.read_csv(file_path)
    dft = pd.read_csv(trial_data, sep='\t')
    df_associations = pd.read_csv(association_data)
    df_saccades = pd.read_csv(saccade_data)
    # The word-span frame comes from the same file as df_associations, so the file is read once.
    dff = df_associations

    # Group association and saccade data by 'trial_id' (all conditions)
    grouped_association_data = df_associations.groupby('trial_id')
    grouped_saccade_data = df_saccades.groupby('trial_id')

    # Restrict every input to the selected condition
    dfw_cond = dfw[dfw['cond_id'] == condition]
    # Not used by the plotting below; kept available for interactive inspection.
    df_associations_cond = df_associations[df_associations['cond_id'] == condition]
    df_saccades_cond = df_saccades[df_saccades['cond_id'] == condition]

    # Item table: keep the needed columns and rename them to match the sample data
    new_column_name = {'experiment_id': 'expr_id', 'condition_id': 'cond_id', 'item_id': 'para_nr'}
    dft_cond = dft.loc[
        dft['condition_id'] == condition,
        ['experiment_id', 'condition_id', 'item_id', 'text', 'response_true'],
    ].rename(columns=new_column_name)

    # Attach item text and expected response to every sample; mark correct answers with 1
    df_cond = pd.merge(dfw_cond, dft_cond, on=['expr_id', 'cond_id', 'para_nr'])
    df_cond = df_cond.assign(correct=(df_cond['response_true'] == df_cond['response']).astype(int))
    df_cond = df_cond[['cond_id', 'trial_id', 'para_nr', 'word_nr', 'word', 'text', 't', 'x', 'y', 'response', 'response_true', 'correct']]

    # para_nr -> trial_id, taken from the first sample of each paragraph
    trial_id_mapping = df_cond.drop_duplicates('para_nr').set_index('para_nr')['trial_id'].to_dict()

    # todo: add trial id to association files
    # Built with .assign so that no column is set on a slice (avoids SettingWithCopyWarning).
    dff_cond = dff.loc[
        dff['cond_id'] == condition,
        ['para_nr', 'word_nr', 'word', 'duration', 'start_t', 'end_t', 'x_mean', 'y_mean'],
    ].assign(trial_id=lambda frame: frame['para_nr'].map(trial_id_mapping))

    grouped_df_cond = df_cond.groupby('trial_id')
    grouped_dff_cond = dff_cond.groupby('trial_id')

    # Create a PDF file to save the plots; the context manager closes it even if plotting fails
    pdf_filename = f'/content/drive/MyDrive/MoTR/provo_2024/AFS_Plots/AFS_plot_{reader}.pdf'
    with pdf.PdfPages(pdf_filename) as pdf_pages:
        for i in range(0, grouped_df_cond.ngroups, num_rows):
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 16))
            plt.subplots_adjust(hspace=0.3)

            visualization_multi(grouped_df_cond, grouped_dff_cond, grouped_association_data, grouped_saccade_data, i, fig, axes)
            pdf_pages.savefig(fig)
            # Release the figure; otherwise every page stays in memory for the whole run.
            plt.close(fig)

    print(f'Plots saved to {pdf_filename}.')




In [89]:
readers = ['3']

condition = 1  # only this condition is plotted
num_rows = 3   # trials per PDF page
num_cols = 1

for reader in readers:
    trial_data = Path('/content/drive/MyDrive/MoTR/trial_data/provo_items.tsv')

    reading_data = Path(f'/content/drive/MyDrive/MoTR/provo_plots_test/provo_cleaned_raw_association_not_merged/reader_{reader}.csv')

    # (An earlier assignment pointing at .../2_provo_association_160ms/reader_{reader}_clean.csv
    # was overwritten immediately and never read, so it has been removed.)
    association_data = Path(f'/content/drive/MyDrive/MoTR/provo_plots_test/associations_reader_{reader}.csv')
    saccade_data = Path(f'/content/drive/MyDrive/MoTR/provo_plots_test/Saccades_reader_{reader}.csv')

    dfw = pd.read_csv(reading_data)
    dft = pd.read_csv(trial_data, sep='\t')
    df_associations = pd.read_csv(association_data)
    df_saccades = pd.read_csv(saccade_data)
    # The word-span frame comes from the same file as df_associations, so the file is read once.
    dff = df_associations

    # Group association and saccade data by 'trial_id' (all conditions)
    grouped_association_data = df_associations.groupby('trial_id')
    grouped_saccade_data = df_saccades.groupby('trial_id')

    # Restrict every input to the selected condition
    dfw_cond = dfw[dfw['cond_id'] == condition]
    # Not used by the plotting below; kept available for interactive inspection.
    df_associations_cond = df_associations[df_associations['cond_id'] == condition]
    df_saccades_cond = df_saccades[df_saccades['cond_id'] == condition]

    # Item table: keep the needed columns and rename them to match the sample data
    new_column_name = {'experiment_id': 'expr_id', 'condition_id': 'cond_id', 'item_id': 'para_nr'}
    dft_cond = dft.loc[
        dft['condition_id'] == condition,
        ['experiment_id', 'condition_id', 'item_id', 'text', 'response_true'],
    ].rename(columns=new_column_name)

    # Attach item text and expected response to every sample; mark correct answers with 1
    df_cond = pd.merge(dfw_cond, dft_cond, on=['expr_id', 'cond_id', 'para_nr'])
    df_cond = df_cond.assign(correct=(df_cond['response_true'] == df_cond['response']).astype(int))
    df_cond = df_cond[['cond_id', 'trial_id', 'para_nr', 'word_nr', 'word', 'text', 't', 'x', 'y', 'response', 'response_true', 'correct']]

    # para_nr -> trial_id, taken from the first sample of each paragraph
    trial_id_mapping = df_cond.drop_duplicates('para_nr').set_index('para_nr')['trial_id'].to_dict()

    # todo: add trial id to association files
    # Built with .assign so that no column is set on a slice (avoids SettingWithCopyWarning).
    dff_cond = dff.loc[
        dff['cond_id'] == condition,
        ['para_nr', 'word_nr', 'word', 'duration', 'start_t', 'end_t', 'x_mean', 'y_mean'],
    ].assign(trial_id=lambda frame: frame['para_nr'].map(trial_id_mapping))

    # df_associations_cond = df_associations_cond.sort_values(by='trial_id')
    # df_saccades_cond = df_saccades_cond.sort_values(by='trial_id')
    # df_cond = df_cond.sort_values(by='trial_id')
    # dff_cond = dff_cond.sort_values(by='trial_id')

    # Group sample and word-span data by 'trial_id'
    grouped_df_cond = df_cond.groupby('trial_id')
    grouped_dff_cond = dff_cond.groupby('trial_id')

    # Create a PDF file to save the plots; the context manager closes it even if plotting fails
    pdf_filename = f'/content/drive/MyDrive/MoTR/provo_plots_test/reader_{reader}_plots3.pdf'
    with pdf.PdfPages(pdf_filename) as pdf_pages:
        for i in range(0, grouped_df_cond.ngroups, num_rows):
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 16))
            plt.subplots_adjust(hspace=0.3)

            visualization_multi(grouped_df_cond, grouped_dff_cond, grouped_association_data, grouped_saccade_data, i, fig, axes)
            pdf_pages.savefig(fig)
            # Release the figure; otherwise every page stays in memory for the whole run.
            plt.close(fig)

    print(f'Plots saved to {pdf_filename}.')


Output hidden; open in https://colab.research.google.com to view.