# Explore tracking for a sample session

Validate LED tracking data visually, in terms of both positional accuracy and synchronization quality. Note that synchronization quality is best checked on frames near the end of the session, where any timing lag has had the most time to accumulate.

In [1]:
%matplotlib widget

import os
from pathlib import Path
import sys
from warnings import warn

from dotenv import load_dotenv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# The repository root sits three directory levels above the working directory
repo_path = Path.cwd().parent.parent.parent

# Environment variables: default .env search first, then the repository-level .env file
load_dotenv()
load_dotenv(dotenv_path=repo_path.joinpath('.env'))

# Project-local modules are made importable by prepending their folders to sys.path
sys.path.insert(0, str(repo_path))
from lib import utils

sys.path.insert(0, str(Path.cwd().parent))      # video-tracking helpers (Methods.video_tracking...)
import loading as vload
import transform as vtran
import plotting as vplot

## Example session tracked with DeepLabCut

This section includes an optional debugging time window for inspecting the data at fine temporal resolution and diagnosing any misalignment between trials and tracking data.

In [2]:
# Settings for this example
# (Note: this is a block where every trial should be visual or audiovisual and so we should be able to see all stimuli)
ferret = 1613
fname = 'Ariel'
block = 'J4-13'

dlc_threshold = 0.6         # DeepLabCut likelihood threshold (passed to vtran.filter_for_low_likelihoods)
center_point = (335, 240)

# Paths to data. The root directory comes from the `local_home` environment variable;
# fail early with a clear message instead of the opaque TypeError from Path(None).
local_home = os.getenv("local_home")
if local_home is None:
    raise RuntimeError(
        "Environment variable 'local_home' is not set; define it (e.g. in the repository .env file)"
    )
data_dir = Path(local_home) / 'Task_Switching'

spike_path = data_dir / 'spike_times_220606_1808.hdf5'
sensor_path = data_dir / 'sensor_ev220801_1230.hdf5'

ongoing_fps_file = data_dir / 'head_tracking/ongoing_fps' / f"F{ferret}_Block_{block}.csv"
tracking_file = data_dir / 'head_tracking/LED_positions' / 'DLC_data_221123_1040.parquet'
video_file = data_dir / 'videos' / f"F{ferret}_{fname}_Block_{block}_Vid0.avi"

In [3]:
# Load the ongoing (per-trial) frame-rate record, keeping only the columns used below
fps_columns = ['starttimecorrected', 'start_frame', 'prev_start_frame', 'prev_start_time', 'fps', 'duration']
ongoing_fps = pd.read_csv(ongoing_fps_file, usecols=fps_columns)

ongoing_fps.head(3)

Unnamed: 0,starttimecorrected,start_frame,prev_start_frame,prev_start_time,duration,fps
0,63.732435,1910.973056,0.0,0.0,63.732435,29.984309
1,93.111963,2792.358886,1910.973056,63.732435,29.379528,30.0
2,162.71364,4880.4092,2792.358886,93.111963,69.601677,30.0


In [4]:
# Trial start times for this block, ordered by start time.
# NOTE: the join can return duplicate trials if multiple recording files were made in the block.
query = """ 
SELECT  
    mcs.starttime,
    mcs.starttimecorrected as starttime_corrected,
    mcs.mcs as mcs_time
FROM (
        SELECT * 
        FROM task_switch.sessions 
        WHERE ferret = %(fnum)s 
            AND block = %(block)s
    ) s
INNER JOIN task_switch.mcs_trials_20220314 mcs
    ON s.datetime = mcs.session_dt
ORDER BY
    mcs.starttime;
"""

# Values are bound as query parameters rather than formatted into the SQL text
query_params = dict(fnum=ferret, block=block)
trials = utils.query_postgres(query, params=query_params)
trials.head(3)


Unnamed: 0,starttime,starttime_corrected,mcs_time
0,64.004,63.732435,93.17895
1,92.948,93.111963,122.55975
2,162.473,162.71364,192.1646


In [5]:
# LED positions for every frame of the requested session; the index is moved into a 'frame' column
LEDs = (
    vload.load_parquet(tracking_file, fnum=ferret, block=block)
    .reset_index(names='frame')
)
LEDs.head()

FileNotFoundError: /home/stephen/Data/Task_Switching/head_tracking/LED_positions/DLC_data_221123_1040.parquet

In [None]:
# Sensor events for this ferret/block (times are relative to the TDT clock)
ferret_label = f"F{ferret}_{fname}"
block_label = f"Block_{block}"
sensor_ev = vload.load_sensor_data(sensor_path, ferret=ferret_label, block=block_label)

## Data transformation

First, add a final row to the FPS record so that frame times can also be computed for frames after the last synchronization marker.

In [None]:
# Drop any session-end row appended by a previous run of this cell (infinite start_frame),
# so re-running the cell is idempotent instead of stacking extra rows / producing a NaN mean.
fps_synced = ongoing_fps[ongoing_fps['start_frame'] != np.inf]

# Session fps: per-trial fps averaged with trial duration as the weight
mean_fps = np.average(fps_synced['fps'], weights=fps_synced['duration'])

# Extra pseudo-trial for frames after the last synchronization marker: it starts where the
# last trial started and runs to infinity at the session-average fps
session_end = {
    'starttimecorrected': np.inf,
    'start_frame': np.inf,
    'prev_start_frame': fps_synced['start_frame'].max(),
    'prev_start_time': fps_synced['starttimecorrected'].max(),
    'fps': mean_fps,
    'duration': np.inf,
}

# Combine (DataFrame.append is deprecated and removed in pandas 2.0, so concatenate instead)
ongoing_fps = pd.concat([fps_synced, pd.DataFrame([session_end])], ignore_index=True)
ongoing_fps.tail(3)

  ongoing_fps = ongoing_fps.append(session_end, ignore_index=True)


Unnamed: 0,starttimecorrected,start_frame,prev_start_frame,prev_start_time,duration,fps
107,2032.241182,60953.24,59954.476682,1998.949223,33.291959,30.0
108,2067.58996,62013.7,60953.235453,2032.241182,35.348778,30.0
109,inf,inf,62013.69879,2067.58996,inf,29.993229


In [None]:
# Label each tracked frame with the ongoing_fps row (trial) whose frame interval
# [prev_start_frame, start_frame) contains it; frames outside every interval stay NaN
LEDs['trial'] = np.nan

trial_bounds = zip(ongoing_fps.index, ongoing_fps['prev_start_frame'], ongoing_fps['start_frame'])
for trial_i, first_frame, end_frame in trial_bounds:
    in_trial = (LEDs['frame'] >= first_frame) & (LEDs['frame'] < end_frame)
    LEDs.loc[in_trial, 'trial'] = trial_i

In [None]:
# Frame time = time of the trial's reference marker + frames elapsed since it / that trial's fps.
# Frames without a trial label are skipped by groupby and keep a NaN time.
LEDs['time'] = np.nan

for trial_i, frames_in_trial in LEDs.groupby('trial')['frame']:
    sync = ongoing_fps.loc[trial_i]
    elapsed = (frames_in_trial - sync['prev_start_frame']) / sync['fps']
    LEDs.loc[frames_in_trial.index, 'time'] = sync['prev_start_time'] + elapsed

LEDs.tail(3)

Unnamed: 0,frame,red_LEDx,red_LEDy,red_LEDlikelihood,blue_LEDx,blue_LEDy,blue_LEDlikelihood,fnum,block,trial,time
63248,63248,157.647034,226.402023,1.0,154.511749,232.08696,0.999957,1613,J4-13,109.0,2108.742622
63249,63249,157.013367,225.828293,1.0,155.007843,229.102753,0.990664,1613,J4-13,109.0,2108.775963
63250,63250,156.277374,226.187515,1.0,155.263733,226.645248,0.157828,1613,J4-13,109.0,2108.809303


In [None]:
# Tracking clean-up and derived measures, applied in order; each step takes and returns the LED table.
# (An optional smoothing step, vtran.add_smoothing(..., width=5), is currently not applied.)
processing_steps = [
    (vtran.filter_for_low_likelihoods, {'threshold': dlc_threshold}),
    (vtran.interpolate_missing_frames, {'nframes': 20}),
    (vtran.compute_head_pose, {'method': 'unweighted'}),
    (vtran.compute_speed, {'window': -8}),
]
for step_fn, step_kwargs in processing_steps:
    LEDs = step_fn(LEDs, **step_kwargs)

In [None]:
# Attach frame-rate / synchronization info to each trial by matching corrected start times
trials = trials.merge(
    ongoing_fps,
    how='inner',
    left_on='starttime_corrected',
    right_on='starttimecorrected',
)
trials.tail(3)

Unnamed: 0,starttime,starttime_corrected,mcs_time,starttimecorrected,start_frame,prev_start_frame,prev_start_time,duration,fps
106,1999.549,1998.949223,2028.48455,1998.949223,59954.476682,56107.549165,1870.684972,128.264251,29.992204
107,2032.841,2032.241182,2061.77805,2032.241182,60953.235453,59954.476682,1998.949223,33.291959,30.0
108,2067.52,2067.58996,2097.12845,2067.58996,62013.69879,60953.235453,2032.241182,35.348778,30.0


In [None]:
def get_event_frame(event_timestamps: np.ndarray, trials: pd.DataFrame) -> np.ndarray:
    """
    Estimate the video frame in which each event occurs.

    For every timestamp, returns the highest frame number among frames whose time is
    strictly earlier than the timestamp (i.e. the last frame that began before the event).

    Args:
        event_timestamps: array-like of event timestamps (according to tdt clock)
        trials: dataframe with 'time' (frame time, tdt clock) and 'frame' columns,
            e.g. LEDs[['time', 'frame']]. The parameter name is kept for backward
            compatibility; it holds per-frame times, not trial data.

    Returns:
        Array with one frame per timestamp. Where no frame precedes a timestamp the
        entry is NaN (and the array is then float).
    """
    # Rows missing a time or frame cannot contribute (pandas .max() also skips NaN)
    frame_times = trials[['time', 'frame']].dropna().sort_values('time')
    times = frame_times['time'].to_numpy()
    frames = frame_times['frame'].to_numpy()

    timestamps = np.asarray(event_timestamps, dtype=float)

    if frames.size == 0:
        return np.full(timestamps.shape, np.nan)

    # running_max[k] is the largest frame among the k+1 earliest frame times, so it
    # handles frame times that are not monotonic in frame number
    running_max = np.maximum.accumulate(frames)

    # Number of frame times strictly less than each timestamp
    n_earlier = np.searchsorted(times, timestamps, side='left')
    # Nothing compares less than NaN, so NaN timestamps have no preceding frame
    n_earlier[np.isnan(timestamps)] = 0

    event_frames = running_max[np.maximum(n_earlier - 1, 0)]

    no_frame = n_earlier == 0
    if no_frame.any():
        event_frames = event_frames.astype(float)
        event_frames[no_frame] = np.nan

    return event_frames

In [None]:
# Convert sensor onset and offset times into video frame numbers
frame_times = LEDs[['time', 'frame']]
for ts_key, frame_key in (('onsets', 'onset_frame'), ('offsets', 'offset_frame')):
    sensor_ev[frame_key] = get_event_frame(sensor_ev[ts_key], frame_times)

sensor_ev = pd.DataFrame(sensor_ev)
sensor_ev.tail(3)

Unnamed: 0,onsets,offsets,chan,onset_frame,offset_frame
771,2069.911756,2070.510592,3,62083,62101
772,2070.551142,2072.629043,3,62102,62164
773,2074.307379,2075.444838,3,62215,62249


In [None]:
# Frame-by-channel indicator matrix: 1 while a sensor channel is high, 0 otherwise
n_frames = LEDs.shape[0]
sensor_chans = sensor_ev.chan.max() + 1
sensor_arr = np.zeros((n_frames, sensor_chans))

# Each event marks frames [onset_frame, offset_frame) on its channel
for onset, offset, chan in zip(sensor_ev['onset_frame'], sensor_ev['offset_frame'], sensor_ev['chan']):
    sensor_arr[int(onset):int(offset), int(chan)] = 1

# Wrap as a frame-indexed dataframe and attach frame times from the tracking data
sens_names = ['sens' + str(ch) for ch in range(sensor_chans)]
sensors = pd.DataFrame(sensor_arr, columns=sens_names, index=pd.RangeIndex(n_frames, name='frame'))
sensors = sensors.join(LEDs[['time', 'frame']], on='frame')
sensors.head(3)

Unnamed: 0_level_0,sens0,sens1,sens2,sens3,time,frame
frame,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,0.0,0.0,0.0,0.0,0.0,0
1,0.0,0.0,0.0,0.0,0.033351,1
2,0.0,0.0,0.0,0.0,0.066702,2


## Plotting

In [None]:
# Long format: one row per (frame, sensor) holding that sensor's 0/1 state
sensors_long = pd.melt(
    sensors.drop(columns=['time']),
    id_vars='frame',
    var_name='sensor'
)
sensors_long['sensor'] = sensors_long['sensor'].str.replace('sens', '').astype(int)

# Keep only the (frame, sensor) pairs where the sensor is active
active_sensors = sensors_long[sensors_long.value > 0.0]

# Count distinct frames (not frame/sensor pairs, which would double-count co-active frames)
n_active_frames = active_sensors['frame'].nunique()
print(f"Sensors active on {n_active_frames} of {LEDs.shape[0]} frames")

# Check for sensors that are co-active. After melting, rows are unique per (frame, sensor),
# so co-activity shows up as the same frame appearing more than once.
n_coactive_frames = active_sensors.loc[active_sensors['frame'].duplicated(), 'frame'].nunique()
if n_coactive_frames > 0:
    warn(message=f'Sensor co-activity detected on {n_coactive_frames} frames')

Sensors active on 37034 of 63251 frames


In [None]:
# Attach head position and frame time to every active (frame, sensor) pair
head_position_on_sensors = active_sensors.merge(
    LEDs[['frame', 'head_x', 'head_y', 'time']],
    on='frame',
)

head_position_on_sensors.head(3)

Unnamed: 0,frame,sensor,value,head_x,head_y,time
0,77,0,1.0,317.522339,240.212761,2.56801
1,78,0,1.0,317.447281,240.207039,2.601361
2,80,0,1.0,318.817795,238.727539,2.668062


In [None]:
# Head positions while each sensor is active, colored by time in session.
# Left: a separate colormap per sensor; right: one shared colormap for all sensors.
FRAME_WIDTH, FRAME_HEIGHT = 640, 480     # axis limits (presumably the video frame size in pixels)

fig, axs = plt.subplots(1, 2, figsize=(9, 3))

sensor_cmaps = ['Oranges', 'Reds', 'Blues', 'Greens']
scatter_kwargs = {'vmin': 0, 'vmax': LEDs['time'].max(), 's': 0.1, 'alpha': 0.3}

for sensor_idx, sensor_data in head_position_on_sensors.groupby('sensor'):
    axs[0].scatter(
        x=sensor_data['head_x'].to_numpy(),
        y=sensor_data['head_y'].to_numpy(),
        c=sensor_data['time'].to_numpy(),
        cmap=sensor_cmaps[sensor_idx % len(sensor_cmaps)],   # wrap if there are more channels than colormaps
        **scatter_kwargs
    )

# Same data with a single shared color scheme
scatobj = axs[1].scatter(
    x=head_position_on_sensors['head_x'].to_numpy(),
    y=head_position_on_sensors['head_y'].to_numpy(),
    c=head_position_on_sensors['time'].to_numpy(),
    cmap='cool',
    **scatter_kwargs
)

# Axis formatting
for ax, title in zip(axs, ['Per-sensor colormap', 'Shared colormap']):
    ax.set_xlim([0, FRAME_WIDTH])
    ax.set_ylim([0, FRAME_HEIGHT])
    ax.invert_yaxis()          # image convention: y increases downwards
    ax.set_facecolor('k')
    ax.set_xlabel('Head x')
    ax.set_ylabel('Head y')
    ax.set_title(title)

# Attach the colorbar explicitly to the panel whose colormap it describes
cbar = fig.colorbar(scatobj, ax=axs[1], label='Time (s)')
cbar.solids.set_edgecolor("face")   # avoid thin lines between colorbar segments in some renderers
plt.show()

In [None]:
# Interactive video viewer over a 100-unit window starting at time 2000
# (presumably seconds, matching the time-series window below), with tracking and sensor data attached
led_columns = ['frame', 'time', 'head_x', 'head_y', 'speed']

vf = vplot.video_figure(video_file=str(video_file), start_time=2000, duration=100)
vf.add_LEDs(LEDs[led_columns])
vf.add_Sensors(sensors)

vf.run()

In [None]:
# Head position and speed over a short window; frames on which each sensor is active
# are overlaid as colored points on the position traces
plot_time = (2000, 2100)
sensor_colors = ['k', 'r', 'g', 'b']
line_kwargs = {'zorder': 0, 'c': 'darkgrey', 'legend': False}

# Three stacked panels need more height than (8, 3) to stay readable
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(8, 6))

for ax, column in zip(axs, ['head_x', 'head_y', 'speed']):
    LEDs.plot(x='time', y=column, ax=ax, **line_kwargs)
    ax.set_ylabel(column)

for sensor, sens_data in head_position_on_sensors.groupby('sensor'):
    color = sensor_colors[sensor % len(sensor_colors)]   # wrap if there are more channels than colors
    for ax, column in zip(axs[:2], ['head_x', 'head_y']):
        sens_data.plot.scatter(x='time', y=column, ax=ax, s=1, c=color)

axs[2].set_xlim(plot_time)
axs[2].set_xlabel('Time (s)')

plt.show()