# Visual stimulation synchronization

Notebook to quantify dropped frames and the closed-loop VR delay (display lag and software lag).

In [None]:
# IPython: reload imported modules automatically before executing code, so that
# edits to project packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports and global plotting settings
import matplotlib as mpl
import seaborn as sns
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Embed TrueType (Type 42) fonts in PDF output and keep SVG text as text
# rather than paths (presumably so figure labels stay editable downstream).
matplotlib.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'
from matplotlib import cm
from cottage_analysis.analysis import spheres
from cottage_analysis.preprocessing.synchronisation import find_monitor_frames
from cottage_analysis.preprocessing import find_frames
from v1_depth_map.figure_utils import get_session_list
import flexiznam as flz

In [None]:
# Figures are written to a versioned folder; bump VERSION for a fresh directory.
VERSION = 10
SAVE_ROOT = Path(f"/camp/lab/znamenskiyp/home/shared/presentations/v1_manuscript_2023/ver{VERSION}")
# Create the folder and any missing parents; an existing folder is fine.
SAVE_ROOT.mkdir(exist_ok=True, parents=True)

In [None]:
# Select the project's sessions. Neither open-loop nor pure closed-loop
# sessions are excluded here (see the exclude_* flags).
project = "hey2_3d-vision_foodres_20220101"
flexilims_session = flz.get_flexilims_session(project_id=project)
sessions = get_session_list.get_sessions(
    flexilims_session=flexilims_session,
    exclude_sessions=(),
    exclude_openloop=False,
    exclude_pure_closedloop=False,
    v1_only=True,
)
print(f"Found {len(sessions)} sessions")

# Session names start with "<mouse>_", so the prefix identifies the mouse.
mice = {sess.split('_')[0] for sess in sessions}
print(f"{len(mice)} mice")

# Mice recorded with the 2-colour photodiode square; all others used 5 colours.
valid_mice = ['PZAH6.4b', 'PZAG3.4f']
sessions_2 = [s for s in sessions if s.split('_')[0] in valid_mice]
sessions_5 = [s for s in sessions if s.split('_')[0] not in valid_mice]
print(f"Found {len(sessions_2)} sessions for 2 color square")
print(f"Found {len(sessions_5)} sessions for 5 color square")

In [None]:
# Concatenate all sessions from 2 color square
photodiode_protocol = 2
# Fetch the project's recording table once; it is filtered per session below.
project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)


def _get_sphere_recording(session_name):
    """Return the first 'SpheresPermTubeReward' recording of a session.

    The recording is taken from ``project_recordings`` rows whose ``origin_id``
    equals the session id, then resolved with ``spheres.get_str_or_recording``.
    """
    sess_entity = flz.get_entity(
        name=session_name, datatype='session', flexilims_session=flexilims_session
    )
    recs = project_recordings[project_recordings.origin_id == sess_entity['id']]
    rec = recs[recs.protocol == 'SpheresPermTubeReward'].iloc[0]
    return spheres.get_str_or_recording(rec, flexilims_session)


def _get_monitor_frames(rec, protocol):
    """Run ``find_monitor_frames`` on ``rec`` with the settings used in this notebook."""
    return find_monitor_frames(
        vis_stim_recording=rec,
        flexilims_session=flexilims_session,
        photodiode_protocol=protocol,
        harp_recording=None,
        onix_recording=None,
        conflicts='skip',
        sync_kwargs=None,
        verbose=False,
    )


all_sessions_2 = []
for session in sessions_2:
    monitor_frames_df = _get_monitor_frames(_get_sphere_recording(session), photodiode_protocol)
    monitor_frames_df['session'] = session
    all_sessions_2.append(monitor_frames_df)
all_sessions_2 = pd.concat(all_sessions_2)

# Per-session frame timing from the inter-frame intervals (ifi) of photodiode peaks
summary_df_2 = {}
for sess, df in all_sessions_2.groupby('session'):
    ifi = np.diff(df['peak_time'])
    frame_rate = 1 / np.nanmedian(ifi)
    avg_frame_rate = 1 / np.nanmean(ifi)
    # an interval longer than 1.5 nominal frames is counted as a skipped frame
    skip = ifi > 1.5 / frame_rate
    perc_skip = np.sum(skip) / len(skip) * 100
    summary_df_2[sess] = dict(frame_rate=frame_rate, avg_frame_rate=avg_frame_rate, perc_skip=perc_skip)
summary_df_2 = pd.DataFrame(summary_df_2).T
ok = 100 - summary_df_2.perc_skip
desc = ok.describe()
print(f"Percentage of frame displayed correctly: {desc['mean']:.2f} +/- {desc['std']:.2f} %")
desc_frame_rate = summary_df_2.avg_frame_rate.describe()
print(f"Average frame rate: {desc_frame_rate['mean']:.2f} +/- {desc_frame_rate['std']:.2f} Hz")

# pick an example session whose average frame rate is closest to the mean
# frame rate across sessions
mean_frame_rate = summary_df_2.avg_frame_rate.mean()
example_session = (summary_df_2.avg_frame_rate - mean_frame_rate).abs().idxmin()
print(f"Example session: {example_session}")


# Get the photodiode trace for the example session for 2 color square
from cottage_analysis.io_module.harp import load_harpmessage
from cottage_analysis.io_module.visstim import get_frame_log, get_param_log
recording = _get_sphere_recording(example_session)
harp_message, harp_ds = load_harpmessage(
    recording=recording,
    flexilims_session=flexilims_session,
    conflicts="skip",
)
photodiode = harp_message['photodiode']
ao_time = harp_message['analog_time']
monitor_frames_df = _get_monitor_frames(recording, photodiode_protocol)
frame_log = get_frame_log(
    flexilims_session, harp_recording=recording, vis_stim_recording=recording
)

# plot the photodiode trace for example session of 2 color square
start = len(photodiode) // 2 + 10000
samples = 300
t_part = 1000 * (ao_time[start:start+samples] - ao_time[start])
fig = plt.figure(figsize=(20, 5))
ax = fig.add_subplot(3, 2, 1)
# stack a constant row above the photodiode samples; the normalised trace is
# drawn over the constant row (y in [1, 2]) and the raw samples as an image below
data = photodiode[start:start+samples].reshape(1, -1)
data = np.vstack([np.ones_like(data), data])
ax.imshow(data, aspect='auto', cmap='Greys',
          extent=[t_part[0], t_part[-1], 0, 2], interpolation='None')
ax.plot(t_part, data[1] / data[1].max() + 1, 'k')
ax.set_yticks([0.5])
ax.set_xlabel('Time (ms)')


# instantaneous frame rate from consecutive photodiode peaks in a 300 ms window
ax1 = fig.add_subplot(3, 2, 3)
pt = monitor_frames_df['peak_time'].values
ifi = np.diff(pt)
t0 = ao_time[start]
valid = (pt > t0) & (pt < t0 + 0.3)
# ifi[i] starts at pt[i]; index both with valid[:-1] so x and y always match
valid_ifi = valid[:-1]
ax1.plot((pt[:-1][valid_ifi] - t0), 1 / ifi[valid_ifi], '.k')
# mark the detected peaks on the photodiode panel at alternating heights
ax.scatter((pt[valid] - t0) * 1000, np.arange(valid.sum()) % 2 * 0.8 + 1.1, c='k')

ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Frame rate (Hz)')

In [None]:
# Concatenate all sessions from 5 color square
photodiode_protocol = 5
# Fetch the project's recording table once, outside the session loop.
project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)


def _get_sphere_recording(session_name):
    """Return the first 'SpheresPermTubeReward' recording of a session.

    The recording is taken from ``project_recordings`` rows whose ``origin_id``
    equals the session id, then resolved with ``spheres.get_str_or_recording``.
    """
    sess_entity = flz.get_entity(
        name=session_name, datatype='session', flexilims_session=flexilims_session
    )
    recs = project_recordings[project_recordings.origin_id == sess_entity['id']]
    rec = recs[recs.protocol == 'SpheresPermTubeReward'].iloc[0]
    return spheres.get_str_or_recording(rec, flexilims_session)


def _get_monitor_frames(rec, protocol):
    """Run ``find_monitor_frames`` on ``rec`` with the settings used in this notebook."""
    return find_monitor_frames(
        vis_stim_recording=rec,
        flexilims_session=flexilims_session,
        photodiode_protocol=protocol,
        harp_recording=None,
        onix_recording=None,
        conflicts='skip',
        sync_kwargs=None,
        verbose=False,
    )


all_sessions_5 = []
for session in sessions_5:
    monitor_frames_df = _get_monitor_frames(_get_sphere_recording(session), photodiode_protocol)
    monitor_frames_df['session'] = session
    all_sessions_5.append(monitor_frames_df)
all_sessions_5 = pd.concat(all_sessions_5)

# Per-session frame timing from the inter-frame intervals (ifi) of photodiode
# peaks, plus the mean of the per-frame `lag` column
summary_df_5 = {}
for sess, df in all_sessions_5.groupby('session'):
    ifi = np.diff(df['peak_time'])
    frame_rate = 1 / np.nanmedian(ifi)
    avg_frame_rate = 1 / np.nanmean(ifi)
    # an interval longer than 1.5 nominal frames is counted as a skipped frame
    skip = ifi > 1.5 / frame_rate
    perc_skip = np.sum(skip) / len(skip) * 100
    summary_df_5[sess] = dict(frame_rate=frame_rate, avg_frame_rate=avg_frame_rate, perc_skip=perc_skip, lag=np.nanmean(df.lag))
summary_df_5 = pd.DataFrame(summary_df_5).T
ok = 100 - summary_df_5.perc_skip
desc = ok.describe()
print(f"Percentage of frame displayed correctly: {desc['mean']:.2f} +/- {desc['std']:.2f} %")
desc_frame_rate = summary_df_5.avg_frame_rate.describe()
print(f"Average frame rate: {desc_frame_rate['mean']:.2f} +/- {desc_frame_rate['std']:.2f} Hz")
desc_lag = (summary_df_5.lag * 1000).describe()
print(f"Average lag: {desc_lag['mean']:.2f} +/- {desc_lag['std']:.2f} ms")
# pick an example session whose average lag is closest to the mean lag across sessions
mean_lag = summary_df_5.lag.mean()
example_session = (summary_df_5.lag - mean_lag).abs().idxmin()
print(f"Example session: {example_session}")

# get the photodiode trace for example session of 5 color square
from cottage_analysis.io_module.harp import load_harpmessage
from cottage_analysis.io_module.visstim import get_frame_log, get_param_log
recording = _get_sphere_recording(example_session)
harp_message, harp_ds = load_harpmessage(
    recording=recording,
    flexilims_session=flexilims_session,
    conflicts="skip",
)
photodiode = harp_message['photodiode']
ao_time = harp_message['analog_time']
running = harp_message['rotary_meter']
monitor_frames_df = _get_monitor_frames(recording, photodiode_protocol)
frame_log = get_frame_log(
    flexilims_session, harp_recording=recording, vis_stim_recording=recording
)

In [None]:
# Find the software lag
# we get distance from rotary meter and find the frame where this distance is reached
# Analog/encoder samples are assumed to be at 1 kHz (TODO confirm against ao_time).
SAMPLING_RATE = 1000  # Hz
# cumulative distance; x100 presumably converts metres to cm (plots below use cm)
dst = harp_message["rotary_meter"].cumsum() * 100
running_index = ao_time.searchsorted(frame_log["HarpTime"].values)
# frames logged after the last analog sample have no encoder distance (left NaN)
in_range = running_index < len(dst)
running_at_frame = np.full(len(running_index), np.nan)
running_at_frame[in_range] = dst[running_index[in_range]]

window = 100  # samples, i.e. 100 ms at SAMPLING_RATE
min_travel = 5 * window / SAMPLING_RATE  # distance covered at 5 cm/s over the window
delays = np.full(len(running_at_frame), np.nan)
# positional iteration: delays/running_at_frame are indexed by frame position
for frame, (mouse_z, findex) in enumerate(zip(frame_log["MouseZ"].values, running_index)):
    # skip frames past the encoder data, or whose rendered position is still ahead
    if not in_range[frame] or running_at_frame[frame] < mouse_z:
        continue
    # look only at moments where the mouse is running > 5 cm/s; also skip
    # frames too early for a full window (a negative index would wrap around)
    if findex < window or (dst[findex] - dst[findex - window]) < min_travel:
        continue
    # walk back until the encoder distance is no longer above the rendered position
    delay = 0
    while delay <= findex and dst[findex - delay] > mouse_z:
        delay += 1
    if delay > findex:
        continue  # position not reached since the start of the recording
    # discard if the mouse moved backwards within the searched stretch
    if np.any(np.diff(dst[findex - delay:findex]) < 0):
        continue
    delays[frame] = delay / SAMPLING_RATE

# plot encoder position and frame log position for example session of 5 color square
start = len(photodiode) // 2 + 10000
t0 = ao_time[start]
p0 = dst[start]
samples = 60
plt.figure(figsize=(3, 3))
plt.subplot(1, 1, 1)
plt.plot(ao_time[start:start+samples] - t0, dst[start:start+samples] - p0, label='Encoder position')
log_part = frame_log[(frame_log.HarpTime > t0) & (frame_log.HarpTime < t0 + samples / 1000)]
plt.plot(log_part.HarpTime.values - t0, log_part.MouseZ.values - p0, label='Frame log position', drawstyle='steps-post')
plt.xlabel('Time (s)')
plt.ylabel('Position (cm)')
plt.legend()


In [None]:
# add software and display lag to the monitor frames
matched_frame_id = monitor_frames_df.closest_frame.values
# monitor frames matched to a logged frame and with a lag estimate
ok = ~np.isnan(matched_frame_id) & ~np.isnan(monitor_frames_df.lag.values)
matched_int = matched_frame_id[ok].astype(int)

# display lag: time from the logged frame (HarpTime) to its photodiode peak.
# Unmatched frames get NaN through match_frame_time, so no extra masking is needed.
match_frame_time = np.full(len(ok), np.nan)
match_frame_time[ok] = frame_log.loc[matched_int].HarpTime.values
display_lag = monitor_frames_df["peak_time"].values - match_frame_time
monitor_frames_df["display_lag"] = display_lag
# then add the software lag computed from the encoder (delays, in s)
software_lag = np.full(len(ok), np.nan)
software_lag[ok] = delays[matched_int]
monitor_frames_df["software_lag"] = software_lag
monitor_frames_df['total_lag'] = monitor_frames_df['display_lag'] + monitor_frames_df['software_lag']

# encoder speed averaged over w analog samples
w = 7
dt = np.nanmedian(np.diff(ao_time))
position_ao = np.cumsum(running)
rs_filt = (position_ao[w:] - position_ao[:-w]) / w / dt
# VR speed from frame-to-frame MouseZ differences, smoothed over 3 frames
frame_dt = np.nanmedian(frame_log['HarpTime'].diff())
frame_log['speed'] = (frame_log['MouseZ'].diff() / frame_dt).rolling(3).mean()

In [None]:
# plot main figure
fontsize_dict = {"title": 7, "label": 7, "tick": 5, "legend": 5}
import matplotlib
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
# Set to True to overlay the position error (encoder - VR) on a twin y-axis.
SHOW_POSITION_ERROR = False
# Median inter-frame interval (s) of the example session; nan-aware so missing
# peak times do not turn it (and every plot using it) into NaN.
med_ifi = np.nanmedian(np.diff(monitor_frames_df.peak_time))
start = len(photodiode) // 2 + 6200
samples = 500
t0 = ao_time[start]
t_part = 1000 * (ao_time[start:start+samples] - t0)
fig = plt.figure(figsize=(4, 4))

# Photodiode-square colour per analog sample: as logged (seq_val) and
# re-aligned to the monitor frames' offset times (shifted_seq, NaN if unmatched)
seq = frame_log['PhotoQuadColor'].values
seq_ind = frame_log.HarpTime.values.searchsorted(ao_time[start:start+samples])
seq_val = seq[seq_ind]
shifted_ind = monitor_frames_df.offset_time.values.searchsorted(ao_time[start:start+samples])
shifted_seq_ind = monitor_frames_df.closest_frame.values[shifted_ind]
has_match = ~np.isnan(shifted_seq_ind)
shifted_seq = np.full(len(shifted_seq_ind), np.nan)
shifted_seq[has_match] = (
    frame_log['PhotoQuadColor'].loc[shifted_seq_ind[has_match].astype(int)].values
)
# photodiode trace normalised to [0, 1]
pd_part = photodiode[start:start+samples].astype(float)
pd_part -= pd_part.min()
pd_part /= pd_part.max()

## plot the photodiode
ax = fig.add_subplot(2, 1, 1)
# image rows, top to bottom: logged sequence, gap, photodiode, gap, re-aligned sequence
nper_line = 1
data = np.zeros((nper_line * 3 + 2, len(t_part)))
data[:nper_line] += seq_val
data[nper_line] = np.nan
data[nper_line + 1:nper_line * 2 + 1] += pd_part
data[nper_line * 2 + 1] = np.nan
data[nper_line * 2 + 2:nper_line * 3 + 2] += shifted_seq
img = ax.imshow(data, aspect='auto', cmap='Greys_r',
                extent=[t_part[0], t_part[-1], 0, 8], interpolation='None')
valid_frame = monitor_frames_df[(monitor_frames_df.peak_time > t0) & (monitor_frames_df.peak_time < t0 + samples / 1000)]
valid_log = frame_log[(frame_log.HarpTime > t0 - med_ifi) & (frame_log.HarpTime < t0 + samples / 1000 + med_ifi)]
frame_log_time = (valid_log.HarpTime - t0 - med_ifi / 2) * 1000
real_frame_time = (valid_frame.peak_time - t0) * 1000
colors = cm.Greys_r
# dotted line from each logged frame to the photodiode peak it was matched to
for ind in valid_log.index:
    if ind in valid_frame.closest_frame.values:
        matching = valid_frame[valid_frame.closest_frame == ind]
        assert len(matching) == 1
        matching = matching.iloc[0]
        real_time = real_frame_time.loc[matching.name]
        ax.plot([frame_log_time.loc[ind], real_time], [6.4, 4.8], color='k', zorder=-2, lw=1, ls=':')

ax.set_yticks([4, 7.5])
ax.set_yticklabels(['Photodiode', 'Sequence'])
ax.set_xlim(0, 150)
ax.set_ylim(3, 8)
ax.set_xlabel('Time (ms)', fontsize=fontsize_dict["label"])
ax.tick_params(axis='both', which='major', labelsize=fontsize_dict["tick"])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)


## plot the running speed or position
ax1 = fig.add_subplot(2, 2, 3)
start = len(photodiode) // 2 + 63700
t0 = ao_time[start]
samples = 200
t_part = (ao_time[start:start+samples] - t0)
m_df_part = monitor_frames_df[(monitor_frames_df.onset_time > t0) & (monitor_frames_df.offset_time < t0 + samples / 1000)]
frame_i = m_df_part.closest_frame.values
frame_ok = ~np.isnan(frame_i)
pos_part = position_ao[start:start+samples].astype(float) * 100
log_pos = np.full(len(frame_i), np.nan)
log_pos[frame_ok] = frame_log.loc[frame_i[frame_ok].astype(int), 'MouseZ'].values
# nanmin: unmatched frames are NaN and would otherwise make the offset NaN
p0 = np.nanmin(log_pos)
ax1.plot(t_part * 1000, (pos_part - p0), 'k', lw=1, label='Real position')
ax1.plot((m_df_part.onset_time.values - t0) * 1000, (log_pos - p0), 'dodgerblue', lw=1, drawstyle='steps-post', label='VR position')
ax1.legend(loc='upper left', ncol=1, bbox_to_anchor=(0.0, 1.0, 0.9, 0.1), labelspacing=0.1,
           frameon=False, borderaxespad=0, fontsize=fontsize_dict["legend"])
ax1.set_ylabel('Position (cm)', fontsize=fontsize_dict["label"])
common_time = ao_time[start:start+samples].searchsorted(m_df_part.onset_time)
pos_common = pos_part[common_time]
if SHOW_POSITION_ERROR:
    ax1twin = ax1.twinx()
    ax1twin.plot((m_df_part.onset_time - t0) * 1000, pos_common - log_pos, 'indianred', lw=1, drawstyle='steps-post',)
    ax1twin.set_ylabel('Position error (cm)', color='indianred')
    ax1twin.spines["top"].set_visible(False)
    ax1twin.set_yticks([0, 1, 2], labels=['0', '1', '2'], color='indianred')
ax1.set_xlabel('Time (ms)', fontsize=fontsize_dict["label"])
ax1.spines["top"].set_visible(False)
ax1.spines["right"].set_visible(False)
ax1.set_xlim(0, 150)
ax1.tick_params(axis='both', which='major', labelsize=fontsize_dict["tick"])

# Histogram of the lags
ax2 = fig.add_subplot(2, 2, 4)
ax2.hist(
    1000 * (monitor_frames_df.total_lag),
    color="k",
    alpha=0.6,
    density=True,
    bins=np.arange(0, 70),
    label="Total",
)
# x ticks at multiples of the example session's median inter-frame interval
ax2.set_xticks(np.round(np.arange(10) * med_ifi * 1000).astype(int))
ax2.set_xlabel("Lag (ms)", fontsize=fontsize_dict["label"])
ax2.set_ylabel("Proportion of Frames", fontsize=fontsize_dict["label"])
ax2.set_yticks([0, 0.3])
ax2.set_xlim(14, 44)
ax2.tick_params(axis='both', which='major', labelsize=fontsize_dict["tick"])
ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
fig.tight_layout()

fig.savefig(SAVE_ROOT / 'lag_example.svg', bbox_inches='tight', dpi=300)