## Generate figures to test camera and behavioral data synchronization

### Setup
Here we import the libraries used in the analysis and define helper functions that load MATLAB (`.mat`) files into Python dictionaries.

In [1]:
# import libraries
import cv2
import imageio
import glob
import os
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import math

from scipy.io import loadmat, matlab
from pathlib import Path

In [2]:
# define general functions for loading MATLAB .mat files into Python objects

# SciPy >= 1.8 exposes the struct class publicly as `scipy.io.matlab.mat_struct`; the
# `mio5_params` module path is deprecated, so only fall back to it on older SciPy releases.
MAT_STRUCT_TYPE = getattr(matlab, "mat_struct", None) or matlab.mio5_params.mat_struct


def mat_obj_to_dict(mat_struct):
    """Recursively convert a MATLAB struct object into a dictionary.

    Parameters
    ----------
    mat_struct : MAT_STRUCT_TYPE
        A struct as returned by ``loadmat(..., struct_as_record=False)``.

    Returns
    -------
    dict
        Mapping from field name to value. Nested structs become nested
        dictionaries, and arrays containing structs are converted with
        ``mat_obj_to_array``; all other values are kept unchanged.
    """
    dict_from_struct = {}
    for field_name in mat_struct._fieldnames:
        value = getattr(mat_struct, field_name)
        if isinstance(value, MAT_STRUCT_TYPE):
            value = mat_obj_to_dict(value)
        elif isinstance(value, np.ndarray):
            value = mat_obj_to_array(value)
        dict_from_struct[field_name] = value
    return dict_from_struct


def mat_obj_to_array(mat_struct_array):
    """Convert struct elements of a MATLAB cell/struct array into dictionaries.

    Parameters
    ----------
    mat_struct_array : numpy.ndarray
        Array loaded from a .mat file, possibly holding struct objects.

    Returns
    -------
    numpy.ndarray
        The input unchanged if it is 0-d or contains no structs; otherwise a
        1-D object array in which every struct element has been converted to a
        dictionary and every other element is kept as-is.
    """
    # 0-d arrays cannot be iterated, so there is nothing to convert
    if np.ndim(mat_struct_array) == 0 or not has_struct(mat_struct_array):
        return mat_struct_array

    # Fill an object array element by element so that mixed contents (dicts next to
    # numbers or arrays) never trigger numpy's automatic shape inference.
    array_from_cell = np.empty(len(mat_struct_array), dtype=object)
    for idx, element in enumerate(mat_struct_array):
        if isinstance(element, MAT_STRUCT_TYPE):
            array_from_cell[idx] = mat_obj_to_dict(element)
        else:
            array_from_cell[idx] = element
    return array_from_cell


def has_struct(mat_struct_array):
    """Return True if any element of a MATLAB cell/struct array is a struct object."""
    return any(isinstance(element, MAT_STRUCT_TYPE) for element in mat_struct_array)


def convert_mat_file_to_dict(mat_file_name):
    """Load a .mat file into a dictionary, converting MATLAB structs recursively.

    Parameters
    ----------
    mat_file_name : str or path-like
        Path to the .mat file.

    Returns
    -------
    dict
        The ``loadmat`` output in which top-level structs (and arrays of
        structs) have been converted to dictionaries.
    """
    data = loadmat(mat_file_name, struct_as_record=False, squeeze_me=True)
    for key, value in data.items():
        # reassigning existing keys while iterating is safe: the dict size is unchanged
        if isinstance(value, MAT_STRUCT_TYPE):
            data[key] = mat_obj_to_dict(value)
        elif isinstance(value, np.ndarray):
            data[key] = mat_obj_to_array(value)
    return data


### Set data inputs
Enter the animal ID, recording date (YYMMDD), and session number of the recording you want to analyze.

In [3]:
# set inputs: choose the recording to analyze
animal, date, session = 26, 211006, 1  # animal ID, recording date (YYMMDD), session number within that day

# 1 = extract the example-trial video frames again (overwriting existing temp PNGs);
# 0 = reuse the frame images saved by a previous run
save_new_video_images = 1

### Import the data
Here we load the ViRMEn behavioral `.mat` file and the camera `.mp4` video, and convert them into Python objects that we can use to test the synchronization.

In [4]:
# import virmen data
# The log root defaults to the lab server mount; set the VIRMEN_LOG_DIR environment
# variable to point at a local copy instead of editing this cell.
virmen_root = Path(os.environ.get("VIRMEN_LOG_DIR", "/Volumes/labs/singer/Virmen Logs/UpdateTask"))
virmen_filename = virmen_root / f"S{animal}_{date}_{session}" / "virmenDataRaw.mat"
virmen_data = convert_mat_file_to_dict(virmen_filename)
# one row per virmen sample, with columns named by the logged data headers
virmen_df = pd.DataFrame(virmen_data['virmenData']['data'], columns=virmen_data['virmenData']['dataHeaders'])
virmen_df

Unnamed: 0,time,xPos,yPos,viewAngle,transVeloc,rotVeloc,taskState,currentZone,choice,trialType,...,protocolChanged,trialTypeUpdate,updateCue,updateOccurred,delayUpdateOccurred,teleportOccurred,syncVoltage,syncTrigger,syncPulse,cameraTrigger
0,738435.424993,0.000000,5.000000,0.0,0.000000,0.525131,2.0,0.0,,1.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
1,738435.424995,0.000000,5.000000,0.0,0.704664,0.607533,2.0,0.0,,1.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
2,738435.424995,0.000000,5.000000,0.0,19.191762,0.607533,2.0,0.0,,1.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
3,738435.424995,0.000000,5.000000,0.0,27.358958,0.607533,2.0,0.0,,1.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
4,738435.424996,0.000000,5.000000,0.0,25.030635,0.623616,2.0,0.0,,1.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
380750,738435.466502,-10.112429,15.539565,0.0,-21.555825,0.348832,8.0,0.0,,2.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
380751,738435.466502,-10.112304,15.451952,0.0,-20.431581,0.346029,8.0,0.0,,2.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
380752,738435.466502,-10.112175,15.364635,0.0,-21.485862,0.346029,8.0,0.0,,2.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0
380753,738435.466502,-10.112042,15.277319,0.0,-20.306141,0.346029,8.0,0.0,,2.0,...,0.0,1.0,,0.0,0.0,1.0,,0.0,0.0,0.0


In [5]:
# import video frames
# The camera root defaults to the lab server mount; set CAMERA_DATA_DIR to override it.
camera_root = Path(os.environ.get("CAMERA_DATA_DIR", "/Volumes/labs/singer/CameraData/UpdateTask"))
video_dir = camera_root / f"S{animal}_{date}"

# glob() returns files in an arbitrary, filesystem-dependent order, so sort the names
# before picking the session's recording (assumes names sort in recording order — confirm)
video_files = sorted(glob.glob(str(video_dir / "*.mp4")))
if not 1 <= session <= len(video_files):
    raise FileNotFoundError(
        f"Session {session} requested but found {len(video_files)} .mp4 file(s) in {video_dir}")
video_file = video_files[session - 1]

cap = cv2.VideoCapture(video_file)
if not cap.isOpened():
    raise OSError(f"Could not open video file {video_file}")

In [6]:
# check that the virmen data and the camera data have the same number of frames so that we can properly align them
# OpenCV reads the frame count from the container metadata, which may be approximate for some encodings.
video_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# count triggers exactly as the trial-selection cell selects them (cameraTrigger == 1),
# so this check validates the same frame indices used for alignment
virmen_frame_count = int((virmen_df.cameraTrigger == 1).sum())

assert video_frame_count == virmen_frame_count, (
    f"Frame count mismatch: video has {video_frame_count} frames, "
    f"but virmen logged {virmen_frame_count} camera triggers")

### Prepare example trial data for plots
Here, we select a few example trials that were recorded while the camera was on, and for which we will make animated figures.

In [7]:
# get trial start and end indices
# virmen encodes the task state as an integer code 1-8, in this order
task_state_names = ('trial_start', 'initial_cue', 'update_cue', 'delay_cue', 'choice_made', 'reward', 'trial_end', 'inter_trial')
task_state_dict = {name: code for code, name in enumerate(task_state_names, start=1)}

trial_starts = virmen_df.index[virmen_df.taskState == task_state_dict['trial_start']].to_numpy()
trial_ends = virmen_df.index[virmen_df.taskState == task_state_dict['trial_end']].to_numpy()

# only use trials during which camera was on
# camera_trig[k] is the virmen row at which video frame k was triggered; keep it as a numpy
# array (like trial_starts/trial_ends) so the comparisons here and below are vectorized
camera_trig = virmen_df.index[virmen_df.cameraTrigger == 1].to_numpy()
assert camera_trig.size > 0, "No camera triggers found in the virmen log"
camera_on, camera_off = camera_trig[0], camera_trig[-1]
trial_starts = trial_starts[(trial_starts >= camera_on) & (trial_starts <= camera_off)]
trial_ends = trial_ends[(trial_ends >= camera_on) & (trial_ends <= camera_off)]

# select example trials: the first n_trials trials, plus the n_trials trials preceding the
# final complete trial (the `-1` end of the last two slices skips the final trial)
n_trials = 1
ts_first = trial_starts[:n_trials]
te_first = trial_ends[trial_ends > trial_starts[0]][:n_trials]
ts_last = trial_starts[trial_starts < trial_ends[-1]][-n_trials-1:-1]
te_last = trial_ends[-n_trials-1:-1]

example_starts = np.concatenate((ts_first, ts_last))
example_ends = np.concatenate((te_first, te_last))
# every start must pair with a later end; zip() alone would silently drop unmatched entries
assert len(example_starts) == len(example_ends) and np.all(example_starts < example_ends), \
    "Example trial starts and ends are not paired correctly"
trial_intervals = list(zip(example_starts, example_ends))

In [8]:
# Extract the video frames of each example trial to PNG files in `<video_dir>/temp`.
# Video frame k is matched to the k-th camera trigger in the virmen log.
frame_dir = Path(video_dir) / "temp"
try:
    if save_new_video_images:
        # cv2.imwrite does not create directories and silently returns False if one is missing
        frame_dir.mkdir(parents=True, exist_ok=True)
        camera_trig_arr = np.asarray(camera_trig)
        for i, (start_ind, end_ind) in enumerate(trial_intervals):
            camera_frames = np.flatnonzero((camera_trig_arr > start_ind) & (camera_trig_arr < end_ind))
            if camera_frames.size == 0:
                print(f"No camera frames found for trial {i}; skipping")
                continue

            print(f"Saving frames for trial {i}")
            first_frame, last_frame = int(camera_frames[0]), int(camera_frames[-1])
            cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame)
            # frames of one trial are contiguous, so read them sequentially after one seek
            for frame_no in range(first_frame, last_frame + 1):
                ret, frame = cap.read()
                if not ret:
                    raise OSError(f"Could not read video frame {frame_no}")
                frame_path = frame_dir / f"frame{frame_no}.png"
                if not cv2.imwrite(str(frame_path), frame):
                    raise OSError(f"Could not write {frame_path}")
finally:
    # release the video handle even if extraction is skipped, fails, or is interrupted;
    # no OpenCV windows are opened here, so destroyAllWindows() is not needed
    cap.release()

Saving frames for trial 0


KeyboardInterrupt: 

### Generate animated figures
Here, we make one figure per camera frame that combines the video image with the concurrent behavioral data, and we assemble each trial's figures into an animated GIF.

In [23]:
# Plot geometry (plot units) shared by every frame.
SECONDS_PER_DAY = 60 * 60 * 24  # virmen `time` is a MATLAB datenum, i.e. measured in days
TRACK_SEGMENTS = [  # track outline drawn in the position panel
    ([2, 2], [1, 250]), ([-2, -2], [1, 250]),
    ([2, 10], [250, 260]), ([-10, -2], [260, 250]),
    ([-10, -10], [1, 280]), ([10, 10], [1, 280]),
    ([-10, 10], [1, 1]), ([-10, 10], [280, 280]),
]
VIEW_ANGLE_BOX = [  # frame around the view-angle arrow
    ([-1.1, -1.1], [0, 1.1]), ([1.1, 1.1], [0, 1.1]),
    ([-1.1, 1.1], [0, 0]), ([-1.1, 1.1], [1.1, 1.1]),
]
LICK_BOX = [  # frame around the lick indicator
    ([-0.6, -0.6], [-0.6, 0.6]), ([0.6, 0.6], [-0.6, 0.6]),
    ([-0.6, 0.6], [-0.6, -0.6]), ([-0.6, 0.6], [0.6, 0.6]),
]
FIGURE_MOSAIC = """
BBAAAACC
BBAAAADD
BBEEFFNN
"""


def _plot_segments(ax, segments):
    """Draw each ``(xs, ys)`` line segment on ``ax`` in white."""
    for xs, ys in segments:
        ax.plot(xs, ys, color='w')


def _make_sync_figure(df, img, start_ind, current_ind, end_ind):
    """Build the multi-panel figure for a single camera frame.

    Parameters
    ----------
    df : pandas.DataFrame
        The virmen log, one row per virmen sample (rows addressed by position).
    img : numpy.ndarray
        The camera frame shown in the central panel.
    start_ind, end_ind : int
        Row positions of the trial's start and end samples.
    current_ind : int
        Row position of the virmen sample at which this camera frame was triggered.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure; the caller saves and closes it.
    """
    # Traces include the current sample so they line up with the displayed frame.
    # Position/view-angle traces start one sample after the trial start (as originally
    # written); time/velocity traces start at the trial start sample.
    pos_rows = slice(start_ind + 1, current_ind + 1)
    time_rows = slice(start_ind, current_ind + 1)

    x_pos = df.xPos.iloc[pos_rows]
    y_pos = df.yPos.iloc[pos_rows]
    view_angle_history = df.viewAngle.iloc[pos_rows]

    trial_start_time = df.time.iloc[start_ind]
    trial_duration = (df.time.iloc[end_ind] - trial_start_time) * SECONDS_PER_DAY
    time_elapsed = (df.time.iloc[time_rows] - trial_start_time) * SECONDS_PER_DAY
    rot_veloc = df.rotVeloc.iloc[time_rows]
    trans_veloc = df.transVeloc.iloc[time_rows]

    # unit vector for the current view angle (an angle of 0 points along +y)
    view_angle = df.viewAngle.iloc[current_ind]
    dx_angle = -math.cos(view_angle + math.pi / 2)
    dy_angle = math.sin(view_angle + math.pi / 2)

    # numLicks appears to be a running count, so an increase since the previous
    # sample marks a lick at this frame
    lick_detected = (df.numLicks.iloc[current_ind] - df.numLicks.iloc[current_ind - 1]) > 0

    fig = plt.figure(constrained_layout=True, figsize=(20, 10))
    ax_dict = fig.subplot_mosaic(FIGURE_MOSAIC)

    ax_video = ax_dict["A"]
    ax_video.imshow(img)
    ax_video.set_xticks([])
    ax_video.set_yticks([])

    ax_pos = ax_dict["B"]
    ax_pos.plot(x_pos, y_pos)
    _plot_segments(ax_pos, TRACK_SEGMENTS)
    ax_pos.set_xlabel('X position (au)')
    ax_pos.set_ylabel('Y position (au)')
    ax_pos.set_title('Position')
    ax_pos.set_ylim(1, 300)
    ax_pos.set_xlim(10, -10)  # x axis is mirrored, as in the original figure
    ax_pos.axis('off')

    for key, values, title, ylim in (("C", rot_veloc, 'Rotational velocity', (-2.75, 2.75)),
                                     ("D", trans_veloc, 'Translational velocity', (0, 60))):
        ax_veloc = ax_dict[key]
        ax_veloc.plot(time_elapsed, values)
        ax_veloc.set_xlabel('Time (s)')
        ax_veloc.set_ylabel('Speed (au)')
        ax_veloc.set_title(title)
        ax_veloc.set_xlim(0, trial_duration)
        ax_veloc.set_ylim(*ylim)

    ax_view = ax_dict["E"]
    _plot_segments(ax_view, VIEW_ANGLE_BOX)
    ax_view.arrow(0, 0, dx_angle, dy_angle, head_width=0.05, head_length=0.1)
    ax_view.set_title('View Angle')
    ax_view.set_xlim(-1.1, 1.1)
    ax_view.set_ylim(0, 1.1)
    ax_view.axis('off')

    ax_view_pos = ax_dict["F"]
    ax_view_pos.plot(y_pos, view_angle_history)
    ax_view_pos.set_title('View Angle')
    ax_view_pos.set_xlabel('Y Position (au)')
    ax_view_pos.set_ylabel('View Angle')
    ax_view_pos.set_xlim(0, 300)
    ax_view_pos.set_ylim(-1.5, 1.5)

    # lick indicator: a filled circle inside the box when a lick was detected at this frame
    ax_lick = ax_dict["N"]
    _plot_segments(ax_lick, LICK_BOX)
    if lick_detected:
        ax_lick.add_patch(plt.Circle((0, 0), 0.5, color='w'))
    ax_lick.set_title('Lick')
    ax_lick.set_xlim(-0.7, 0.7)
    ax_lick.set_ylim(-0.7, 0.7)
    ax_lick.set_aspect('equal')
    ax_lick.axis('off')

    return fig


frame_dir = Path(video_dir) / "temp"
results_dir = Path(video_dir) / "results" / "camera"
results_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
camera_trig_arr = np.asarray(camera_trig)

# style.context keeps the dark theme local to these figures instead of changing it globally
with plt.style.context('dark_background'):
    for i, (start_ind, end_ind) in enumerate(trial_intervals):
        trial_mask = (camera_trig_arr > start_ind) & (camera_trig_arr < end_ind)
        # video frame numbers (positions in the trigger list) and the matching virmen rows
        camera_frames = np.flatnonzero(trial_mask)
        camera_inds_in_df = camera_trig_arr[trial_mask]
        if camera_frames.size == 0:
            print(f"No camera frames found for trial {i}; skipping")
            continue

        print(f"Making gif for trial {i}")
        plot_filenames = []
        for frame_no, current_ind in zip(camera_frames, camera_inds_in_df):
            img = mpimg.imread(frame_dir / f"frame{frame_no}.png")
            fig = _make_sync_figure(virmen_df, img, start_ind, current_ind, end_ind)
            plot_filename = results_dir / f"trial{i}_frame{frame_no}.png"
            fig.savefig(plot_filename)
            plt.close(fig)  # one figure is created per camera frame, so free each one
            plot_filenames.append(plot_filename)

        # build gif
        gif_filename = results_dir / f"trial{i}.gif"
        with imageio.get_writer(gif_filename, mode='I') as writer:
            for filename in plot_filenames:
                writer.append_data(imageio.imread(filename))

Making gif for trial 0
Making gif for trial 1
