# Visual stimulation monitor synchronisation

Bonsai controls the visual display and saves a log entry every time it asks for a frame to be rendered. However, there is an unknown delay (~2 frames) between this render-frame event and the actual display time. Furthermore, some frames are skipped.

To figure out which frame is displayed when, we put a photodiode in front of the monitor and display a pseudo-random sequence of alternating grey values.

This notebook shows how we use the frame logger and the photodiode signal to determine the exact frame identity at each point in time. It has 3 main steps:

**1. Detect frames on photodiode signal**

**2. Cross-correlate frame with sequence to find expected lag**

**3. Match cross correlation results to frame logger**


### Load example data

This next section just loads one example recording to be used for the rest of the notebook.

Define the session we want:

In [None]:
# First candidate session. NOTE(review): PROJECT, MOUSE, SESSION and PROTOCOL
# are all reassigned by the next cell, so when the notebook is run top to
# bottom only RECORDING keeps the value set here.
PROJECT = "hey2_3d-vision_foodres_20220101"
MOUSE = "PZAH8.2i"
SESSION = "S20230209"
RECORDING = "R174123_SpheresPermTubeRewardPlayback"
PROTOCOL = "SpheresPermTubeReward"

In [None]:
# Second candidate session; these values replace the ones from the previous
# cell and are the ones used to query Flexilims below.
PROJECT = "blota_onix_pilote"
MOUSE = "BRYA142.5d"
SESSION = "S20231005"
PROTOCOL = "SpheresPermTubeReward"

In [None]:
# When True, the "onix" recording entity is fetched in the entity-loading cell.
use_onix = True
# When True, the harp recording is also used as the visual-stimulation recording.
visstim_in_harp = False
# NOTE(review): not read by any cell visible here; a later comment mentions
# photodiode protocols 2 and 5.
photodiode_protocol = 5
# Overrides merged into the synchronisation parameters in the "Normal usage" cell.
sync_kwargs = dict(frame_detection_height=0.05)
# Conflict policy handed to `flz.Dataset.from_origin`; the loading cell asserts
# it is one of "skip", "overwrite" or "abort".
conflicts = "overwrite"

In [None]:
import flexiznam as flz
from cottage_analysis.preprocessing.synchronisation import find_monitor_frames

# Open the Flexilims session once and resolve the session entity by name.
flm_sess = flz.get_flexilims_session(PROJECT)
sess = flz.get_entity(flexilims_session=flm_sess, name=f"{MOUSE}_{SESSION}")

# Every lookup below is a "protocol" query among the recordings of `sess`;
# only the protocol value differs.
_recording_query = dict(
    origin_id=sess.id,
    datatype="recording",
    query_key="protocol",
    flexilims_session=flm_sess,
)

harp_recording = flz.get_entity(query_value="harpdata", **_recording_query)

# The visual-stimulation recording is either the harp recording itself or a
# dedicated recording identified by PROTOCOL.
vis_stim_recording = (
    harp_recording
    if visstim_in_harp
    else flz.get_entity(query_value=PROTOCOL, **_recording_query)
)

# The onix recording is only looked up when requested.
if use_onix:
    onix_recording = flz.get_entity(query_value="onix", **_recording_query)

Load it

In [None]:
import scipy.signal
from cottage_analysis.utilities.misc import get_str_or_recording
from cottage_analysis.io_module.harp import load_harpmessage
from cottage_analysis.io_module import onix as onix_io
from cottage_analysis.io_module.visstim import get_frame_log
from cottage_analysis.io_module.spikes import load_kilosort_folder
from cottage_analysis.preprocessing import onix as onix_prepro


# This cell is the body of a loading function pasted into the notebook. It
# refers to `flexilims_session`, which no earlier cell defines, so bind it to
# the session opened above; without this the first check raises a NameError.
import pandas as pd  # used by the "skip" branch below, before the later import cell

flexilims_session = flm_sess
project = None
assert conflicts in ["skip", "overwrite", "abort"]
if flexilims_session is None:
    assert project is not None, "project must be provided if flexilims_session is None"
    flexilims_session = flz.get_flexilims_session(project_id=project)

# Pass every recording through `get_str_or_recording` (presumably resolves a
# name to its entity and leaves an entity untouched -- TODO confirm).
vis_stim_recording = get_str_or_recording(vis_stim_recording, flexilims_session)
if harp_recording is None:
    harp_recording = vis_stim_recording
else:
    harp_recording = get_str_or_recording(harp_recording, flexilims_session)
    onix_recording = get_str_or_recording(onix_recording, flexilims_session)

# NOTE(review): `onix_recording` is required to be set here, yet the photodiode
# section below still has a harp-only branch for `onix_recording is None`;
# with these asserts that branch cannot be reached.
assert harp_recording is not None, "harp_recording must be provided"
assert onix_recording is not None, "onix_recording must be provided"
assert vis_stim_recording is not None, "vis_stim_recording must be provided"
# Get frame log
frame_log = get_frame_log(
    flexilims_session,
    harp_recording=harp_recording,
    vis_stim_recording=vis_stim_recording,
)

# Create output and reload
monitor_frames_ds = flz.Dataset.from_origin(
    origin_id=vis_stim_recording["id"],
    dataset_type="monitor_frames",
    flexilims_session=flexilims_session,
    conflicts=conflicts,
)
# Previously saved result, if any. It is stored under its own name so that the
# dataset handle stays usable: its `path` / `path_full` are still needed below.
monitor_frames_df = None
if monitor_frames_ds.flexilims_status() != "not online" and conflicts == "skip":
    print("Loading existing monitor frames...")
    monitor_frames_df = pd.read_pickle(monitor_frames_ds.path_full)
    print("Done.")

monitor_frames_ds.path = monitor_frames_ds.path.parent / "monitor_frames_df.pickle"

# Get photodiode
raw = flz.get_data_root("raw", flexilims_session=flexilims_session)
harp_message, harp_ds = load_harpmessage(
    recording=harp_recording,
    flexilims_session=flexilims_session,
    conflicts="skip",
)
if onix_recording is None:
    # get the photodiode from harp directly
    photodiode = harp_message["photodiode"]
    analog_time = harp_message["analog_time"]
else:
    onix_ds = flz.get_datasets(
        flexilims_session=flexilims_session,
        origin_name=onix_recording.name,
        dataset_type="onix",
        allow_multiple=False,
    )
    breakout = onix_io.load_breakout(raw / onix_recording.path)
    onix_data = onix_prepro.preprocess_onix_recording(
        dict(breakout_data=breakout), harp_message=harp_message
    )
    # The photodiode channel comes from the dataset attributes when a mapping
    # is stored there, otherwise from the default analog-input order.
    if "aio_mapping" in onix_ds.extra_attributes:
        ch_pd = onix_ds.extra_attributes["aio_mapping"]["photodiode"]
    else:
        ch_pd = onix_prepro.ANALOG_INPUTS.index("photodiode")
    photodiode = onix_data["breakout_data"]["aio"][ch_pd, :]
    # Convert the analog-input clock with the `onix2harp` callable returned by
    # the preprocessing step.
    analog_time = onix_data["onix2harp"](onix_data["breakout_data"]["aio-clock"])
    # to make it faster, decimate the photodiode signal
    if False:  ### X TEMP TO DEBUG
        photodiode = scipy.signal.decimate(photodiode, 5)
        analog_time = analog_time[::5]


# Duration and frame rate are both derived from the logged render times.
recording_duration = frame_log.HarpTime.values[-1] - frame_log.HarpTime.values[0]
frame_rate = 1 / frame_log.HarpTime.diff().median()
print(f"Recording is {recording_duration:.0f} s long.")
print(f"Frame rate is {frame_rate:.0f} Hz.")

# Get frames from photodiode trace, depending on the photodiode protocol is 2 or 5
diagnostics_folder = monitor_frames_ds.path_full.parent / "diagnostics" / "frame_sync"
diagnostics_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Fetch the single "onix" dataset attached to the onix recording.
# NOTE(review): this rebinds `onix_ds` and `breakout`, which the loading cell
# above also assigns when an onix recording is used.
onix_ds = flz.get_datasets(
    origin_id=onix_recording["id"],
    dataset_type="onix",
    flexilims_session=flexilims_session,
    allow_multiple=False,
)

from cottage_analysis.io_module.onix import load_breakout

# Reload the breakout data from the dataset folder with `num_ai_chan=3`
# (presumably the number of analog-input channels -- TODO confirm).
breakout = load_breakout(onix_ds.path_full, num_ai_chan=3)

In [None]:
# Display the attributes stored with the visual-stimulation dataset.
# NOTE(review): `visstim_ds` is only assigned in the next cell; run that first.
visstim_ds.extra_attributes

In [None]:
import pandas as pd

# The "visstim" dataset lists its CSV files in its extra attributes.
visstim_ds = flz.get_datasets(
    origin_id=vis_stim_recording["id"],
    dataset_type="visstim",
    flexilims_session=flexilims_session,
    allow_multiple=False,
)
csvs = visstim_ds.extra_attributes["csv_files"]
# Read the three logs straight from disk.
# NOTE(review): this replaces the `frame_log` built by `get_frame_log` in the
# loading cell with the raw CSV content.
param_log, frame_log, reward_log = (
    pd.read_csv(visstim_ds.path_full / csvs[log_name])
    for log_name in ("ParamLog", "FrameLog", "RewardLog")
)

In [None]:
# Peek at the first rows of the frame log.
frame_log.head()

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Compare, over a short hand-picked window (b, e), the photodiode stored in the
# harp messages (top) with the selected photodiode trace, normalised, and the
# logged quad colour (bottom).
plt.figure(figsize=(20, 5))
ax = plt.subplot(2, 1, 1)
b, e = 2.012e7 + np.array([2618.7, 2619.25])
v = (harp_message["analog_time"] < e) & (harp_message["analog_time"] > b)
ax.plot(harp_message["analog_time"][v], harp_message["photodiode"][v])
v = (analog_time < e) & (analog_time > b)
ax1 = plt.subplot(2, 1, 2, sharex=ax)
# Scale the trace so that the window mean maps to 0 and the window max to 1.
m = photodiode[v].mean()
M = photodiode[v].max()
normed = (photodiode[v] - m) / (M - m)
ax1.plot(analog_time[v], normed)
# Restrict the frame log to the same window as the two traces above. Without
# the lower bound every logged frame since the start of the recording is drawn
# and the shared x-axis no longer shows the zoomed window.
v = (frame_log.HarpTime < e) & (frame_log.HarpTime > b)
# The log is drawn 30 ms earlier than logged (hand-picked offset, presumably to
# line it up by eye with the photodiode -- TODO confirm).
ax1.plot(
    frame_log.HarpTime[v] - 30e-3,
    frame_log["PhotoQuadColor"][v],
    drawstyle="steps-post",
)

In [None]:
from cottage_analysis.utilities import continuous_data_analysis as cda

# Sampling rate estimated from the mean step of the analog time base.
sampling = 1 / np.diff(analog_time).mean()
# Smooth the photodiode: no low cut, high cut at 700, Butterworth design
# (cut-off presumably in Hz if `analog_time` is in seconds -- TODO confirm).
filt_pd = cda.filter(
    photodiode, sampling, lowcut=None, highcut=700, design="butter", axis=-1
)
# Show the filtered trace on the window (b, e) defined by the previous figure.
v = (analog_time < e) & (analog_time > b)
plt.figure(figsize=(20, 2))
plt.plot(analog_time[v], filt_pd[v])

In [None]:
# Rescale the filtered photodiode so that its 1st percentile maps to 0 and its
# 99th percentile (after the shift) maps to 1.
normed_pd = np.array(filt_pd, dtype=float)
normed_pd -= np.quantile(normed_pd, 0.01)
normed_pd /= np.quantile(normed_pd, 0.99)

# Only keep photodiode samples acquired before the last logged render time.
valid = analog_time < frame_log.HarpTime.values[-1]
from cottage_analysis.preprocessing.find_frames import *

# First step: Frame detection
# The settings are kept as notebook-level names because later cells reuse them.
time_column = "HarpTime"
frame_detection_height = 0.05
do_plot = True
save_folder = None
verbose = True
ignore_errors = False
last_frame_delay = 100
# `create_frame_df` comes from the star import above; it returns the detected
# frames, a debug dictionary and the diagnostic figures.
frames_df, db_dict, figs = create_frame_df(
    frame_log=frame_log,
    photodiode_time=analog_time[valid],
    photodiode_signal=normed_pd[valid],
    time_column=time_column,
    frame_rate=frame_rate,
    height=frame_detection_height,
    do_plot=do_plot,
    verbose=verbose,
    debug=True,
    save_folder=save_folder,
    ignore_errors=ignore_errors,
    last_frame_delay=last_frame_delay,
)

In [None]:
# Photodiode amplitude and time at the peak of every detected frame.
ampl = normed_pd[valid][frames_df.peak_sample.values]
amp_t = frames_df.peak_time.values
fig = plt.figure(figsize=(20, 5))
ax = fig.add_subplot(111)
_ = ax.hist(ampl, bins=1000)
# Hand-picked edges splitting the amplitudes into five bins; the two outer
# edges are only there to close the first and last bin and are not drawn.
borders = [-1, 0.04, 0.1, 0.19, 0.4, 1.1]
# The loop variable is deliberately not called `b`: that name holds the start
# of the plotting window used by the figure cells above.
for border in borders[1:-1]:
    ax.axvline(border, color="grey", linestyle="--")

In [None]:
# Replace every peak amplitude by the index of its bin, then rescale so that
# the highest occupied bin is 1.
amp_bined = np.digitize(ampl, borders).astype(float) - 1
amp_bined /= amp_bined.max()
# Fraction of frames falling in each bin, rounded to 3 decimals for display.
v, c = np.unique(amp_bined, return_counts=True)
c = c / len(amp_bined)
dict(zip(v, np.round(c, 3)))

In [None]:
# Load the reference grey-value sequence from the shared drive and display it.
seq = np.loadtxt(
    "/nemo/lab/znamenskiyp/home/shared/transfer/random_sequence_5values_alternate.csv"
)
seq

In [None]:
# Repeat the reference sequence until it covers every detected frame.
long_seq = np.tile(seq, 1 + len(amp_bined) // len(seq))[: len(amp_bined)]


def shift_diff(lag):
    """Return the squared mismatch between the frames and the shifted sequence.

    Args:
        lag: Shift, in frames, applied circularly to ``long_seq``. Scalars and
            length-1 arrays (the form used by ``scipy.optimize.minimize``) are
            both accepted. The value is rounded to the nearest integer because
            ``np.roll`` can only shift by whole elements.

    Returns:
        Sum over frames of ``(amp_bined - shifted_sequence) ** 2``.
    """
    whole_lag = int(np.round(np.ravel(lag)[0]))
    shifted = np.roll(long_seq, whole_lag)
    return np.sum((amp_bined - shifted) ** 2)


# minimize the sum of square difference
# NOTE(review): the cost only changes when `lag` crosses a half-integer, so a
# simplex search started at 0 sees a flat cost and is unlikely to move away
# from it; the cross-correlation in the next cell is the reliable estimate.
from scipy.optimize import minimize

res = minimize(shift_diff, (0,), method="Nelder-Mead")

In [None]:
from scipy import signal

# Cross-correlate the reference sequence with the binned frame amplitudes,
# keeping only the lags where the shorter input fully overlaps the longer one.
corr = signal.correlate(seq, amp_bined, mode="valid")
# The lag axis has to be built with the inputs in the same order as in the
# call above. With the lengths swapped the array has the right size but every
# lag is offset, so the printed best lag is wrong.
lags = signal.correlation_lags(len(seq), len(amp_bined), mode="valid")

plt.plot(lags, corr)
# Report the lag with the highest correlation and zoom around it.
max_lag = lags[np.argmax(corr)]
print(max_lag)
plt.xlim(max_lag - 100, max_lag + 100)

In [None]:
# Overview of the whole photodiode trace stored in the harp messages.
print(harp_recording.name)
plt.plot(harp_message["analog_time"], harp_message["photodiode"])

In [None]:
# Plot the first `n` samples of every analog-input channel, keeping one sample
# in ten and stacking the channels 10000 units apart.
n = int(10e6)
_sample_index = np.arange(0, n, 10)
_n_channels = breakout["aio"].shape[0]
for i in range(_n_channels):
    _vertical_offset = i * 10000
    plt.plot(
        _sample_index,
        breakout["aio"][i, :n:10] + _vertical_offset,
        label=str(i),
    )
_ = plt.legend()

In [None]:
# Start of the selected photodiode trace (x from 0 to 1 after subtracting the
# first time stamp), plotting every 10th sample.
print(onix_recording.name)
plt.plot(analog_time[::10] - analog_time[0], photodiode[::10])
plt.xlim(0, 1)

# Normal usage

This is using the main master function:

In [None]:
from cottage_analysis.preprocessing import find_frames

# Root folder of the processed data for this project.
processed = flz.get_data_root("processed", project=PROJECT)

# Default synchronisation parameters. The three lag settings are expressed as
# a number of frame periods divided by the frame rate.
params = {
    "time_column": "HarpTime",
    "sequence_column": "PhotoQuadColor",
    "num_frame_to_corr": 6,
    "maxlag": 3.0 / frame_rate,
    "expected_lag": 2.0 / frame_rate,
    "frame_rate": frame_rate,
    "correlation_threshold": 0.8,
    "relative_corr_thres": 0.02,
    "frame_detection_height": 0.1,
    "minimum_lag": 1.0 / frame_rate,
    "do_plot": True,
    "save_folder": diagnostics_folder,
    "verbose": True,
    "ignore_errors": False,
}
# Notebook-level overrides take precedence over the defaults.
params.update(sync_kwargs if sync_kwargs is not None else {})

# Kept for reference only: the full synchronisation is not executed here.
if False:
    frames_df, db_dict = find_frames.sync_by_correlation(
        frame_log, analog_time, photodiode, **params
    )

# Detailed description

How does it work? The alignment is made in 3 steps:

- detect frames
- cross-correlate with the expected sequence
- align results

## Detect frames

The frame detection is simple: filter a bit to smooth local extrema, `diff` to find fast changes, and detect peaks on that `diff` trace. This should detect all frame borders. In between these borders, look for the `diff` minimum to find the frame peak (be it a maximum or a minimum).

Detection can be done independently using `detect_frame_onset`.


## Create frame df

In [None]:
# Print the parameters as `name=value` lines so they can be pasted into the
# next cell. NOTE(review): this rebinds the notebook-level names `k` and `v`.
for k, v in params.items():
    print(f"{k}={v}")

In [None]:
# Parameters unpacked into notebook-level names, apparently pasted from the
# printout of the previous cell (the lags here are frame periods in seconds
# for the frame rate given below).
# NOTE(review): `save_folder` points to session S20231010 whereas SESSION is
# set to "S20231005" at the top of the notebook -- confirm which one is meant.
time_column = "HarpTime"
sequence_column = "PhotoQuadColor"
num_frame_to_corr = 6
maxlag = 0.02102399244904518
expected_lag = 0.014015994966030123
frame_rate = 142.6941151767892
correlation_threshold = 0.8
relative_corr_thres = 0.02
frame_detection_height = 0.05
minimum_lag = 0.007007997483015061
do_plot = True
save_folder = "/camp/lab/znamenskiyp/home/shared/projects/blota_onix_pilote/BRYA142.5d/S20231010/R142857_SpheresPermTubeReward/diagnostics/frame_sync"
verbose = True
ignore_errors = False
debug = True
last_frame_delay = 100

In [None]:
# Use the names expected by the function body that the following cells replay.
photodiode_time = analog_time
photodiode_signal = photodiode

from cottage_analysis.preprocessing.find_frames import *

# Photodiode sampling rate estimated from the mean step of its time base.
pd_sampling = 1 / np.mean(np.diff(photodiode_time))

# Normalise photodiode signal
# (1st percentile mapped to 0, then the 99th percentile of the result to 1)
normed_pd = np.array(photodiode_signal, dtype=float)
normed_pd -= np.quantile(normed_pd, 0.01)
normed_pd /= np.quantile(normed_pd, 0.99)

# First step: Frame detection
frames_df, db_dict, figs = create_frame_df(
    frame_log=frame_log,
    photodiode_time=photodiode_time,
    photodiode_signal=normed_pd,
    time_column=time_column,
    frame_rate=frame_rate,
    height=frame_detection_height,
    do_plot=do_plot,
    verbose=verbose,
    debug=debug,
    save_folder=save_folder,
    ignore_errors=ignore_errors,
    last_frame_delay=last_frame_delay,
)
# Sanity check: there cannot be more detected frames than rendered ones, and
# losing more than half of the rendered frames points to a detection problem.
ndetected = len(frames_df)
npresented = len(frame_log)
if npresented < ndetected:
    msg = (
        f"Detected more frames ({ndetected}) than presented ({npresented})"
        "\n Check create_frame_df parameters"
    )
elif npresented > ndetected * 2:
    msg = (
        f"Dropped more than half of the frames ({npresented - ndetected} dropped)"
        "\n Check create_frame_df parameters"
    )
else:
    msg = None
# Prints "None" when both checks pass.
print(msg)

In [None]:
# Keep the normalised trace next to the other debug outputs when available.
if db_dict is not None:
    db_dict["normed_pd"] = normed_pd

# Start the figure registry with the detection figures, or empty without them.
fig_dict = dict(frame_dection=figs) if figs is not None else dict()

## Do the correlation

In [None]:
# Peek at the frame log again before building the idealised trace.
frame_log.head()

In [None]:
# Fall back to estimating the sampling rate when none was provided.
if pd_sampling is None:
    pd_sampling = 1 / np.mean(np.diff(photodiode_time))
out_dict = {}
# Sample index at which every detected frame starts.
frame_onsets = frames_df["onset_sample"].values
# make lags into samples
maxlag_samples = int(np.round(maxlag * pd_sampling))
expected_lag_samples = int(np.round(expected_lag * pd_sampling))

# make an idealised photodiode signal
# NOTE(review): `pad_frames` adds `maxlag` to a number of frames; `maxlag` is
# in seconds when set by the parameter cell above -- confirm the intended unit.
ideal_time, ideal_seqi_trace, ideal_pd = ideal_photodiode(
    frame_log,
    sampling_rate=pd_sampling,
    sequence_column="PhotoQuadColor",
    time_column="HarpTime",
    pad_frames=(maxlag + num_frame_to_corr) * 2,
    highcut=150,
)
if debug:
    out_dict["ideal_photodiode_trace"] = ideal_pd
    out_dict["ideal_time"] = ideal_time
    out_dict["ideal_seqi_trace"] = ideal_seqi_trace

# find the closest switch time for each frame according to computer time
real_switch_times = frame_log[time_column].values
# `searchsorted` can return len(real_switch_times); clip to a valid row index.
closest_switch = np.clip(
    real_switch_times.searchsorted(photodiode_time[frame_onsets]),
    0,
    len(real_switch_times) - 1,
)
frames_df["closest_frame_log_index"] = closest_switch
# and the corresponding ideal photodiode sample
# (the "ideal_switch_samples" column is not created in this notebook;
# presumably `ideal_photodiode` adds it to `frame_log` -- TODO confirm)
ideal_onset = frame_log["ideal_switch_samples"].iloc[closest_switch].values

In [None]:
from scipy import signal

# Cross-correlate the idealised and the measured photodiode between time 10
# and time 20 (relative to `t0` for the measured trace) and plot the result
# against lag, converted to time with the ideal trace's mean sample period.
# NOTE(review): `t0` is only assigned in later cells; run one of them first.
dt = np.diff(ideal_time).mean()
p1 = np.searchsorted(ideal_time, [10, 20])
p2 = np.searchsorted(photodiode_time - t0, [10, 20])
out = signal.correlate(ideal_pd[p1[0] : p1[1]], normed_pd[p2[0] : p2[1]], mode="same")
lags = signal.correlation_lags(
    ideal_pd[p1[0] : p1[1]].size, normed_pd[p2[0] : p2[1]].size, mode="same"
)
plt.plot(lags * dt, out)
plt.xlim(-0.2, 0.2)

In [None]:
# Start of the idealised photodiode trace (x from 0 to 1).
plt.plot(ideal_time, ideal_pd)
plt.xlim(0, 1)

In [None]:
# Rename the variables to match the names used by the correlation code below.
# NOTE(review): `expected_lag` and `maxlag` are rebound from seconds to sample
# counts here, so re-running the cell that multiplies them by `pd_sampling`
# after this one would convert them a second time.
ideal_onset_samples = ideal_onset
ideal_photodiode_trace = ideal_pd
ideal_frame_index = ideal_seqi_trace
expected_lag = expected_lag_samples
maxlag = maxlag_samples

In [None]:
pd_sampling = 1 / np.mean(np.diff(photodiode_time))

# define the 3 correlation windows, bef, center and aft
# Each window is `num_frame_to_corr` frames long (in samples), placed before,
# around or after the frame onset, and widened by `maxlag` on both sides.
window = [
    np.array([-1, 1]) * maxlag
    + np.array(w * num_frame_to_corr / frame_rate * pd_sampling, dtype="int")
    for w in [np.array([-1, 0]), np.array([-0.5, 0.5]), np.array([0, 1])]
]
# for bef window, we add 1.5 frame to have half of the current frame included
window[0] += int(1.5 / frame_rate * pd_sampling)
# for center window, we shift by 0.5 frame to center
window[1] += int(0.5 / frame_rate * pd_sampling)

# NOTE(review): `time` is not imported by any visible cell; it is presumably
# brought in by the star import of `find_frames` -- confirm.
if verbose:
    start = time.time()
    print("Starting crosscorrelation", flush=True)
# Outputs, one row per (window, frame): correlation for each tested lag (NaN
# until filled), matching frame index (-1 until filled) and residuals.
cc_mat = np.zeros((len(window), len(frame_onsets), maxlag * 2)) + np.nan
eq_ind = np.zeros((len(window), len(frame_onsets), maxlag * 2), dtype="int") - 1
residuals = np.zeros((len(window), len(frame_onsets))) + np.nan

In [None]:
# Check, for one example frame, that each correlation window fits inside both
# the measured and the idealised photodiode traces. A window that does not fit
# aborts the cell with an explicit error instead of relying on an undefined
# name to stop execution.
iframe = 50000
foi = frame_onsets[iframe]
for iw, win in enumerate(window):
    # Window limits once placed on the idealised trace.
    ideal_start = win[0] + ideal_onset_samples[iframe] - expected_lag
    ideal_stop = win[1] + ideal_onset_samples[iframe] - expected_lag
    if (win[0] + foi) < 0:
        problem = "start of recording"
    elif (win[1] + foi) > (len(photodiode_signal) - expected_lag):
        problem = "end of recording"
    elif ideal_start < 0:
        problem = "start of ideal pd"
    elif ideal_stop > len(ideal_photodiode_trace):
        problem = "end of ideal pd"
    else:
        problem = None
    if problem is not None:
        raise ValueError(
            "Frame %d at sample %d is too close from %s" % (iframe, foi, problem)
        )
    # ideal_pd is drifting, so we need to look for the closest computer time
# `iw` and `win` keep the values of the last window; the next cells use them.

In [None]:
# Frame indices of the idealised trace over the window left by the loop above
# (`win`, `iw`), placed at the example frame and shifted by the expected lag.
id_t = ideal_frame_index[slice(*win + ideal_onset_samples[iframe] - expected_lag)]
# we want the middle "maxlag * 2" samples, which is where correlation can
# be done
eq_ind[iw, iframe] = id_t[int(len(id_t) / 2 - maxlag) : int(len(id_t) / 2 + maxlag)]

In [None]:
# Overlay the measured (every 5th sample) and idealised (every 10th sample)
# photodiode around the first logged frame; red lines mark the first and last
# logged render times. Time is relative to the first logged frame.
t0 = frame_log.HarpTime.values[0]
plt.plot(analog_time[::5] - t0, normed_pd[::5])
plt.axvline(frame_log.HarpTime.values[0] - t0, color="r")
plt.axvline(frame_log.HarpTime.values[-1] - t0, color="r")
plt.plot(ideal_time[::10], ideal_pd[::10])
plt.xlim(-0.05, 0.2)

In [None]:
# Measured vs idealised photodiode around the example frame, over a window 15
# times wider than the correlation window left in `win`.
plt.figure(figsize=(15, 5))
plt.plot(normed_pd[slice(*win * 15 + foi)])
plt.plot(
    ideal_photodiode_trace[
        slice(*win * 15 + ideal_onset_samples[iframe] - expected_lag)
    ]
)

In [None]:
# Rebuild the frame table from the detected borders and peaks.
# NOTE(review): `frame_borders` and `peak_index` are not assigned by any
# visible cell; they presumably come from inside `create_frame_df` -- confirm.
# A frame lasting more than 1.5 frame periods is flagged as containing a skip.
frame_skip = np.diff(frame_borders) > pd_sampling / frame_rate * 1.5
frames_df = pd.DataFrame(
    dict(
        onset_sample=frame_borders[:-1],
        offset_sample=frame_borders[1:],
        peak_sample=peak_index,
        include_skip=frame_skip,
    )
)
# Convert the sample indices to times on the photodiode time base.
frames_df["onset_time"] = photodiode_time[frames_df.onset_sample]
frames_df["offset_time"] = photodiode_time[frames_df.offset_sample]
frames_df["peak_time"] = photodiode_time[frames_df.peak_sample]

# check if frames are detected after the presentation is over
after_last = frames_df["onset_time"] >= frame_log[time_column].iloc[-1]
print(f"Framed detected after the presentation is over: {after_last.sum()}")

In [None]:
# Display the frame log.
frame_log

In [None]:
# List the columns of the late-frame table.
# NOTE(review): `late_frames` is only assigned in the next cell.
late_frames.columns

In [None]:
# Compare the last logged render time with the onset of the last detected frame.
end_of_presentation = frame_log[time_column].iloc[-1]
last_frame = frames_df["onset_time"].iloc[-1]
print(f"Presentation ends at {end_of_presentation:.2f} s")
print(f"Last frame detected at {last_frame:.2f} s")

# `delay` is computed but not used in this cell.
delay = frames_df["onset_time"].iloc[-1] - frame_log[time_column].iloc[-1]
print(f"{after_last.sum()} frames detected after the last render time.")

# Plot the photodiode from 1 before the end of the presentation to 20 after
# the last detected frame (same unit as the time base), with time relative to
# the end of the presentation. NOTE(review): this rebinds `b` and `e`.
b, e = photodiode_time.searchsorted([end_of_presentation - 1, last_frame + 20])
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(1, 1, 1)
ax.set_title("Recording ends before last frame detected")
ax.plot(
    photodiode_time[b:e] - end_of_presentation,
    photodiode_signal[b:e],
    label="photodiode",
)
ax.axvline(0, color="k", label="end of presentation")
# Mark the peak of every frame whose onset is after the last render time.
late_frames = frames_df[frames_df.onset_time > end_of_presentation]
ax.scatter(
    late_frames.peak_time - end_of_presentation,
    photodiode_signal[late_frames.peak_sample],
    label="late frame",
    color="purple",
)
ax.legend(loc="best")
ax.set_xlabel("Time relative to end of presentation (s)")
ax.set_ylabel("Photodiode signal (a.u.)")
# Saving is switched off here. NOTE(review): this leaves `save_folder` set to
# None for the following cells, and `Path` is not imported by any visible cell.
save_folder = None
if save_folder is not None:
    fig.savefig(Path(save_folder) / f"presentation_end_issue.png")

You can get an example of detection using `plot_frame_detection_report`. This will give you `num_examples * 2` figures. Half of them are selected on random frames, half are centered around a frame drop.

In [None]:
# Same sanity check as after frame detection, but here a failure is either
# downgraded to a warning or raised, depending on `ignore_errors`.
ignore_errors = True
ndetected, npresented = len(frames_df), len(frame_log)

msg = None
if npresented < ndetected:
    # More detections than rendered frames cannot be right.
    msg = (
        f"Detected more frames ({ndetected}) than presented ({npresented})"
        "\n Check create_frame_df parameters"
    )
elif npresented > ndetected * 2:
    # Losing over half of the rendered frames points to a detection problem.
    msg = (
        f"Dropped more than half of the frames ({npresented - ndetected} dropped)"
        "\n Check create_frame_df parameters"
    )

if msg is not None and ignore_errors:
    warnings.warn(msg)
elif msg is not None:
    raise ValueError(msg)

In [None]:
# Plot 7.5 frame periods on each side of the selected frames, in samples.
plot_window = np.array([-7.5, 7.5]) / frame_rate * pd_sampling
# NOTE(review): `frame_borders` and `peak_index` are not assigned by any
# visible cell -- confirm where they come from.
figs = find_frames.plot_frame_detection_report(
    border_index=frame_borders,
    peak_index=peak_index,
    debug_dict=db_dict,
    num_examples=1,
    plot_window=plot_window,
    photodiode=photodiode,
    frame_rate=frame_rate,
    photodiode_sampling=pd_sampling,
    highcut=frame_rate * 3,
)

# Crosscorrelation

After having detected frames, we will try to find where each of them falls in the photodiode sequence. To do that, we start by normalising the photodiode signal between 0 and 1.

In [None]:
# Run the cross-correlation for every detected frame with the library function.
# NOTE(review): `maxlag` and `expected_lag` hold seconds after the parameter
# cell but sample counts after the renaming cell above; confirm which unit
# `run_cross_correlation` expects before running this.
frames_df, db_di = find_frames.run_cross_correlation(
    frames_df,
    frame_log,
    photodiode_time,
    normed_pd,
    time_column,
    sequence_column,
    num_frame_to_corr,
    maxlag,
    expected_lag,
    frame_rate,
    verbose=True,
    debug=True,
    pd_sampling=pd_sampling,
)

In [None]:
# Match the correlation results to the frame log using the thresholds set in
# the parameter cell.
frames_df = find_frames._match_fit_to_logger(
    frames_df,
    correlation_threshold=correlation_threshold,
    relative_corr_thres=relative_corr_thres,
    minimum_lag=minimum_lag,
    verbose=True,
)

# Then interpolate the missing frames
# (the return value is not used here, so this presumably updates `frames_df`
# in place -- TODO confirm)
find_frames.interpolate_sync(frames_df, verbose=True)
# and remove the last double detected frames
frames_df = find_frames._remove_double_frames(frames_df, verbose=True)

In [None]:
from pathlib import Path

# Merge the debug output of the cross-correlation into the main debug dict.
db_dict.update(db_di)
# Same cross-correlation matrix twice: full view on top, 200-unit-wide zoom
# around the middle of the x-axis below.
fig = plt.figure()
ax = fig.add_subplot(2, 1, 1)
find_frames.plot_crosscorr_matrix(
    ax, db_dict["cc_dict"], db_dict["lags_sample"], frames_df
)
ax = fig.add_subplot(2, 1, 2)
find_frames.plot_crosscorr_matrix(
    ax, db_dict["cc_dict"], db_dict["lags_sample"], frames_df
)
xl = ax.get_xlim()
mid = (xl[0] + xl[1]) / 2
ax.set_xlim(mid - 100, mid + 100)
# `save_folder` is None in some cells and a plain string in others; neither
# supports the `/` operator, so skip saving when unset and wrap it in a Path.
if save_folder is not None:
    fig.savefig(Path(save_folder) / "crosscorr_matrix.png")

In [None]:
# plt.plot(frames_df.closest_frame_log_index.values)

# Flag the frames whose closest frame-log index is within one of the value
# reached by the last detected frame.
ok = frames_df.closest_frame_log_index.values >= (
    frames_df.closest_frame_log_index.iloc[-1] - 1
)

# plt.plot(frames_df.closest_frame_log_index.values)
# Position of the first flagged frame.
ok.argmax()

In [None]:
from pathlib import Path

# Inspect one hand-picked frame (position hard-coded for this recording).
frame = frames_df.index[919623 - 150]
fig = find_frames.plot_one_frame_check(
    frame,
    frames_df,
    frame_log,
    real_time=photodiode_time,
    normed_pd=normed_pd,
    ideal_time=db_dict["ideal_time"],
    ideal_pd=db_dict["ideal_photodiode_trace"],
    ideal_seqi=db_dict["ideal_seqi_trace"],
    num_frame_to_corr=None,
)
fig.suptitle(
    f"Frame {frame} matching frame log "
    f"{frames_df.loc[frame, 'closest_frame']}\n"
    f"Is interpolated: {not frames_df.loc[frame, 'interpolation_seeds']}"
)
# `save_folder` is None in some cells and a plain string in others; neither
# supports the `/` operator, so skip saving when unset and wrap it in a Path.
if save_folder is not None:
    fig.savefig(Path(save_folder) / f"frame_{frame}_check.png")

### Idealised photodiode

Then we generate an idealised version of what the photodiode signal should be (had there been no frame drops).

In [None]:
# NOTE(review): this call uses different arguments and unpacks two return
# values, whereas the `ideal_photodiode` call earlier in the notebook takes
# the frame log and unpacks three; `ao_time` is also not assigned by any
# visible cell. This cell looks like it targets an older API -- confirm.
seq_trace, ideal_pd = find_frames.ideal_photodiode(
    time_base=ao_time,
    switch_time=frame_log["HarpTime"].values,
    sequence=frame_log["PhotoQuadColor"].values,
)

# Overlay 100 samples of the measured photodiode, the raw sequence and the
# filtered sequence, with time relative to the first plotted sample.
fig = plt.figure()
w = np.array([10000, 10100])
t0 = ao_time[w[0]]
plt.plot(ao_time[slice(*w)] - t0, normed_pd[slice(*w)], label="Normed photodiode")
plt.plot(
    ao_time[slice(*w)] - t0,
    seq_trace[slice(*w)],
    label="Sequence",
    color="grey",
    alpha=0.5,
)
plt.plot(ao_time[slice(*w)] - t0, ideal_pd[slice(*w)], label="Filtered sequence")
l = plt.legend()

### Data chunking

Now we want to run the crosscorrelation around each frame.

We need to take a chunk of data that is big enough but short enough. Five or 6 frames seems to get good unique match with the sequence. Use `num_frame_to_corr` to set that.

Then we need to shift the photodiode by a given lag and cut the same chunk of data to correlate. There is no point in testing all the shifts: we know it will be about 2 frames. So we have `expected_lag ~= int(2/frame_rate*ao_sampling)` (in samples).

To make things reasonably fast, we also limit the search to 3 frames of lag (+/- around `expected_lag`), with `maxlag ~= int(3/frame_rate*ao_sampling)`.

In [None]:
# Number of frames per correlation chunk, and the 3-frame search range and
# 2-frame expected lag converted to samples.
# NOTE(review): `ao_sampling` is not assigned by any visible cell.
num_frame_to_corr = 6
maxlag_samples = int(np.round(3 / frame_rate * ao_sampling))  # make it into samples
expected_lag_samples = int(
    np.round(2 / frame_rate * ao_sampling)
)  # make it into samples

Finally we need to decide if we take the chunk of data before the frame, centered on the frame or after the frame. The best choice depends on if there was a frame drop recently or not. So let's just do the 3.

In [None]:
# Build the three windows (before, centred, after), each `num_frame_to_corr`
# frames long in samples and widened by `maxlag_samples` on both sides.
window = [
    np.array([-1, 1]) * maxlag_samples
    + np.array(w * num_frame_to_corr / frame_rate * ao_sampling, dtype="int")
    for w in [np.array([-1, 0]), np.array([-0.5, 0.5]), np.array([0, 1])]
]
# for bef window, we add 1 frame to have the current frame included
# NOTE(review): the equivalent cell earlier in the notebook shifts by 1.5 frames.
window[0] += int(1 / frame_rate * ao_sampling)
# for center window, we shift by 0.5 frame to center
window[1] += int(0.5 / frame_rate * ao_sampling)

# Illustrate the three windows on one example frame; traces are stacked 0.5
# apart and time is relative to the frame onset.
example_frame = 5234
frame_sample = frame_borders[example_frame]
lab = ["bef", "center", "aft"]
t0 = ao_time[frame_sample]
for iw, w in enumerate(window):
    part = slice(*w + frame_sample)
    # Drop the `maxlag_samples` margin so only the correlated part is drawn.
    plt.plot(
        ao_time[part][maxlag_samples : -maxlag_samples + 1] - t0,
        normed_pd[part][maxlag_samples : -maxlag_samples + 1] + iw * 0.5,
        label=lab[iw],
    )
# Shade the example frame, from its border to the next one.
plt.axvspan(
    ao_time[frame_sample] - t0,
    ao_time[frame_borders[example_frame + 1]] - t0,
    alpha=0.5,
)
_ = plt.legend()

In [None]:
# Display the search range in samples.
maxlag_samples

In [None]:
import matplotlib.pyplot as plt

# Histogram of the fitted lags multiplied by 1000, with unit-wide bins
# (i.e. 1 ms bins if `lag` is in seconds -- TODO confirm).
_ = plt.hist(
    frames_df.lag.values * 1000,
    bins=np.arange(frames_df.lag.min() * 1000, frames_df.lag.max() * 1000),
)

In [None]:
# Frames whose matched frame-log index does not increase relative to the
# previous detected frame; show those rows.
bad = np.diff(frames_df.closest_frame.values) < 1
frames_df.iloc[1:][bad]

Add that to the dataframe

## Match cross correlation results to frame logger

Ideally, if there is no frame drop, it does not matter whether we look at the frames preceding or following the frame we want to sync. That should be the case most of the time.

In [None]:
# Grab the debug information and re-normalise the raw photodiode (1st
# percentile to 0, then the 99th percentile of the result to 1), then list the
# available debug entries.
db = db_dict["debug_info"]
normed_pd = np.array(photodiode, dtype=float)
normed_pd -= np.quantile(normed_pd, 0.01)
normed_pd /= np.quantile(normed_pd, 0.99)
db.keys()

In [None]:
# Draw, for a few frames, the photodiode together with the sequence chunk
# matched by each of the three windows.
rng = np.random.default_rng(102)
w = frames_df[frames_df.sync_reason == "photodiode matching"].index
# Random selection of matched frames...
random_select = [w[i] for i in rng.integers(len(w), size=10)]
bad = np.diff(frames_df.closest_frame.values) < 1
badi = np.where(bad)[0]
# ...which is immediately replaced by three consecutive frames starting at the
# 21st frame whose matched index does not increase.
random_select = frames_df.iloc[badi[20] + np.array([0, 1, 2], dtype=int)].index
labels = ["bef", "center", "aft"]
# NOTE(review): these settings (5 frames, 5-frame search range) differ from the
# ones used in the cells above, and `ao_sampling` is not assigned by any
# visible cell.
num_frame_to_corr = 5
maxlag = int(5.0 / frame_rate * ao_sampling)
expected_lag = int(2.0 / frame_rate * ao_sampling)
window = [
    np.array([-1, 1]) * maxlag
    + np.array(w * num_frame_to_corr / frame_rate * ao_sampling, dtype="int")
    for w in [np.array([-1, 0]), np.array([-0.5, 0.5]), np.array([0, 1])]
]
seq_trace = db["seq_trace"]

for frame in random_select:
    # frame = frames_df[~good].index[num]
    # frame = frames_df.index[num]
    fseries = frames_df.loc[frame]
    on_s = fseries.onset_sample
    off_s = fseries.offset_sample
    on_t = fseries.onset_time
    off_t = fseries.offset_time
    # Context of 50 samples on each side of the frame.
    w = np.array([-50, 50])
    # `vfdf` and `qc` are computed but not used below.
    vfdf = frames_df[
        (frames_df.onset_sample > w[0] + on_s)
        & (frames_df.offset_sample < w[1] + off_s)
    ]
    qc = np.array([fseries[["quadcolor_%s" % w for w in labels]]])
    best = fseries.crosscorr_picked
    fig = plt.figure(figsize=(7, 7))
    plt.gca().get_yaxis().set_visible(False)

    col = dict(bef="r", center="g", aft="b")
    # Draw the photodiode three times, one unit apart, one copy per window;
    # only the middle copy carries legend labels.
    for i in range(3):
        label = "Photodiode" if i == 1 else None
        plt.plot(
            ao_time[slice(*w + on_s)] - on_t,
            normed_pd[slice(*w + on_s)] + i,
            label=label,
            color="purple",
        )
        label = "Frame #%d" % frame if i == 1 else None
        plt.axvspan(0, off_t - on_t, color="purple", alpha=0.2, label=label)
        plt.plot(fseries.peak_time - on_t, fseries.photodiode, "o", color="purple")

    # Logged render events around the frame, drawn below the photodiode copies.
    vlog = frame_log[
        (frame_log.HarpTime > w[0] / ao_sampling + on_t - fseries.lag_bef)
        & (frame_log.HarpTime < w[1] / ao_sampling + off_t)
    ]
    plt.plot(
        vlog.HarpTime.values - on_t,
        vlog.PhotoQuadColor - 1.5,
        drawstyle="steps-post",
        label="Render frame",
    )

    i = 0
    for win, lab in zip(window, ["bef", "center", "aft"]):
        # Window without its `maxlag` margins.
        cut_win = win + maxlag * np.array([1, -1], dtype=int)
        l = fseries["lag_%s" % lab]
        part = seq_trace[slice(*win + on_s)]
        cut_part = seq_trace[slice(*cut_win + on_s)]
        x = normed_pd[slice(*win + on_s)][maxlag : -maxlag + 1]

        # Sequence chunk shifted by the lag fitted for this window.
        plt.plot(
            ao_time[slice(*win + on_s)] - on_t + l,
            part + i,
            alpha=0.75,
            lw=2,
            color=col[lab],
        )
        # Photodiode chunk that was correlated (margins removed).
        plt.plot(
            ao_time[slice(*win + on_s)][maxlag : -maxlag + 1] - on_t,
            x + i,
            alpha=0.5,
            lw=4,
            ls="--",
            color=col[lab],
        )

        # Frame-log entry matched by this window.
        cl = fseries["closest_frame_%s" % lab]
        plt.plot(
            frame_log.iloc[cl].HarpTime - on_t,
            frame_log.iloc[cl].PhotoQuadColor - 1.5 + i / 6,
            "o",
            color=col[lab],
        )
        # Highlight the window that was finally picked for this frame.
        if lab == best:
            plt.plot(
                frame_log.iloc[cl].HarpTime - on_t,
                frame_log.iloc[cl].PhotoQuadColor - 1.5 + i / 6,
                "o",
                mfc="None",
                mec="k",
                ms=10,
                mew=2,
            )
            plt.plot(
                frame_log.iloc[cl].HarpTime - on_t + l,
                frame_log.iloc[cl].PhotoQuadColor + i,
                "o",
                color="k",
            )
            plt.plot(
                ao_time[slice(*cut_win + on_s)] - on_t + l,
                cut_part + i,
                alpha=1,
                lw=1,
                color="k",
                label="Selected match",
            )

        i += 1
        plt.title("%s" % fseries.onset_sample)

    plt.legend(loc="lower right")

In [None]:
# find what is the actual photodiode value and how does it depend on previous value

# Pair every frame (all but the first) with the frame detected just before it
# (all but the last). Both tables are re-indexed from 0, so row i of `df` and
# row i of `bef` describe consecutive frames.
df = pd.DataFrame(frames_df.iloc[1:][["quadcolor", "photodiode"]]).reset_index()
bef = pd.DataFrame(frames_df.iloc[:-1][["quadcolor", "photodiode"]]).reset_index()
df["quadcolor_before"] = bef["quadcolor"]
df["photodiode_before"] = bef["photodiode"]
df.head()

In [None]:
# Mean photodiode value and number of frames for every (previous, current)
# quad-colour pair.
_transition_keys = ["quadcolor_before", "quadcolor"]
mat_data = df.groupby(_transition_keys).aggregate(np.nanmean).photodiode
n_data = df.groupby(_transition_keys).aggregate(len).photodiode
# Arrange both as 5 x 5 tables indexed by [previous colour, current colour].
m, n = (_table.values.reshape((5, 5)) for _table in (mat_data, n_data))

# Three panels sharing the same axes layout: mean photodiode, its difference
# to five evenly spaced values between 0 and 1 (one per current colour), and
# the number of transitions.
_panels = [
    ("Photodiode", m.T, {}),
    ("Difference", (m - np.linspace(0, 1, 5)).T, dict(cmap="RdBu_r")),
    ("N transitions", n.T, {}),
]
plt.figure(figsize=(12, 3))
for _position, (_panel_title, _image, _imshow_kwargs) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _position)
    plt.imshow(_image, origin="lower", **_imshow_kwargs)
    plt.xlabel("quad n-1")
    plt.ylabel("quad n")
    plt.title(_panel_title)
    cm = plt.colorbar()

In [None]:
# Inspect the row of one hand-picked frame.
fseries = frames_df.loc[4713]
fseries

In [None]:
# Logged render times of a hand-picked range of log rows, shifted by the "aft"
# lag and expressed relative to the end of the inspected frame.
# NOTE(review): this rebinds the notebook-level `t0`.
t0 = fseries.offset_time
frame_log["HarpTime"][1940:1955] - fseries.lag_aft - t0

In [None]:
# Number of frames per synchronisation reason.
frames_df.sync_reason.value_counts()

# Miscellaneous

Figures to explain things for my lab meeting (09/11/2022)

## Sequence principle

In [None]:
# Illustration of the sequence principle: the reference sequence (top trace)
# and a copy shifted by `shift` samples in which a stretch is held constant
# (bottom trace).
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
b = 1000
w = 100
# NOTE(review): the original had a stray `0` on the line after this
# assignment, which did nothing; it may be a wrapped `shift = 10` -- confirm.
shift = 1
# NOTE(review): this rebinds `seq`, which earlier holds the loaded reference
# sequence.
seq = seq_trace[b : b + w]
bad_seq = np.array(seq_trace[b + shift : b + w + shift])
# Hold the value found at one third of the window for the next 20% of the
# window.
bad_seq[int(w / 3) : int(w / 3 + w / 3 * 0.6)] = bad_seq[int(w / 3)]
ax.plot((ao_time[b : b + w] - ao_time[b]) * 1e3, bad_seq)
ax.plot((ao_time[b : b + w] - ao_time[b]) * 1e3, seq + 1)
ax.set_xlabel("Time (ms)")
ax.yaxis.set_visible(False)
# Use a dedicated loop name so the window length `w` is not overwritten.
for spine_name in ["top", "left", "right"]:
    ax.spines[spine_name].set_visible(False)