# Visual stimulation monitor synchronisation

Bonsai controls the visual display and saves a log entry every time it asks for a frame to be rendered. However, there is an unknown delay (~2 frames) between this render-frame event and the actual display time. Furthermore, some frames are skipped.

To figure out which frame is displayed when, we put a photodiode in front of the monitor and display a pseudo-random sequence of alternating grey values.

This notebook shows how we use the frame logger and the photodiode signal to determine the exact frame identity at each point in time. It has 3 main steps:

**1. Detect frames on photodiode signal**

**2. Cross-correlate frame with sequence to find expected lag**

**3. Match cross correlation results to frame logger**


### Load example data

This next section just loads one example recording to be used for the rest of the notebook.

Define the session we want:

In [None]:
PROJECT = 'hey2_3d-vision_foodres_20220101'
MOUSE = 'PZAH8.2f'
SESSION = 'S20230126'
RECORDING = 'R144331_SpheresPermTubeReward'
PROTOCOL = 'SpheresPermTubeReward'
MESSAGES = 'harpmessage.bin'

Load it

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

import flexiznam as flm
from cottage_analysis.io_module import harp

data_root = flm.PARAMETERS["data_root"]
msg = Path(data_root["raw"]) / PROJECT / MOUSE / SESSION / RECORDING / MESSAGES
p_msg = Path(data_root["processed"]) / PROJECT / MOUSE / SESSION / RECORDING / (PROTOCOL+'_suite2p_traces_0')
p_msg = p_msg / (msg.stem + ".npz")
if p_msg.is_file():
    harp_message = np.load(p_msg)
else:
    harp_message = harp.load_harp(
        msg, di_names=('frame_triggers','lick_detection','di2_encoder_initial_state')
    )
    p_msg.parent.mkdir(parents=True, exist_ok=True)
    np.savez(p_msg, **harp_message)

frame_log = pd.read_csv(msg.parent / "FrameLog.csv")
expected_sequence = (
    pd.read_csv(msg.parent / 'random_sequence_5values_alternate.csv', header=None).loc[:, 0].values
)
step_values = frame_log.PhotoQuadColor.unique()
ao_time = harp_message["analog_time"]
photodiode = harp_message["photodiode"]
ao_sampling = 1 / np.mean(np.diff(ao_time))

print("Data loaded.")
print(
    "Recording is %d s long."
    % (frame_log.HarpTime.values[-1] - frame_log.HarpTime.values[0])
)


# Normal usage

Normal usage goes through the main function, `find_frames.sync_by_correlation`:

In [None]:
from cottage_analysis.preprocessing import find_frames

frame_rate = 144
frames_df, db_dict = find_frames.sync_by_correlation(
    frame_log,
    ao_time,
    photodiode,
    time_column="HarpTime",
    sequence_column="PhotoQuadColor",
    num_frame_to_corr=6,
    maxlag=3.0 / frame_rate,
    expected_lag=2.0 / frame_rate,
    frame_rate=frame_rate,
    correlation_threshold=0.8,
    relative_corr_thres=0.02,
    minimum_lag=1.0 / frame_rate,
    do_plot=False,
    verbose=True,
    debug=True,
)


In [None]:
import pickle
save_folder = Path(data_root["processed"]) / PROJECT / MOUSE / SESSION / RECORDING / (PROTOCOL+'_suite2p_traces_0')

frames_df.to_pickle(save_folder/'frames_df.pickle')  
with open(save_folder/'db_dict.pickle', 'wb') as handle:
    pickle.dump(db_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
with open(save_folder/'db_dict.pickle', 'rb') as handle:
    db_dict = pickle.load(handle)
with open(save_folder/'frames_df.pickle', 'rb') as handle:
    frames_df = pickle.load(handle)

In [None]:
frames_df.columns

In [None]:
frames_df[50:100]

In [None]:
frames_df.sync_reason.value_counts()

# Detailed description

How does it work? The alignment is made in 3 steps:

- detect frames
- cross-correlate with the expected sequence
- align results

## Detect frames

The frame detection is simple: filter a bit to smooth local extrema, `diff` to find fast changes, and detect peaks on that `diff` trace. This should detect all frame borders. In between these borders, look for the `diff` minimum to find the frame peak (be it a maximum or a minimum).
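
As a rough illustration of the idea (not the actual implementation of `detect_frame_onset`), the sketch below runs the same three operations on a synthetic trace. The moving-average smoothing, the `threshold` value and the helper name `detect_borders` are assumptions made for this example; the library may filter and pick peaks differently.

```python
import numpy as np


def detect_borders(signal, n_smooth=3, threshold=0.05):
    """Hypothetical helper: find fast transitions in a photodiode-like trace."""
    # Smooth with a short moving average; pad with edge values so the ends
    # of the trace do not look like transitions
    kernel = np.ones(n_smooth) / n_smooth
    padded = np.pad(np.asarray(signal, dtype=float), n_smooth, mode="edge")
    smoothed = np.convolve(padded, kernel, mode="same")[n_smooth:-n_smooth]
    # Absolute derivative: large only where the grey value changes
    speed = np.abs(np.diff(smoothed))
    # Group consecutive supra-threshold samples into runs, one run per border
    above = np.concatenate(([False], speed > threshold, [False]))
    edges = np.diff(above.astype(int))
    starts = np.where(edges == 1)[0]
    stops = np.where(edges == -1)[0]
    borders = np.array([s + np.argmax(speed[s:e]) for s, e in zip(starts, stops)])
    # Between two borders, the frame "peak" is where the trace changes least
    peaks = np.array(
        [b0 + np.argmin(speed[b0:b1]) for b0, b1 in zip(borders[:-1], borders[1:])]
    )
    return borders, peaks


# Synthetic trace: 8 frames of 10 samples each, plus a little noise
rng = np.random.default_rng(0)
levels = np.array([0, 0.5, 1, 0.25, 0.75, 0, 1, 0.5])
trace = np.repeat(levels, 10) + rng.normal(0, 0.005, 80)
borders, peaks = detect_borders(trace)
print(borders)  # one index close to each multiple of 10
print(peaks)    # one index inside each frame
```

Grouping supra-threshold samples into runs gives one border per transition, even when the smoothed step spreads over several samples.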

Detection can be done independently using `detect_frame_onset`.


In [None]:
pd_sampling = 1 / np.mean(np.diff(ao_time))
out = find_frames.detect_frame_onset(
    photodiode=photodiode,
    frame_rate=frame_rate,
    photodiode_sampling=pd_sampling,
    highcut=frame_rate * 3,
    debug=True,
)
frame_borders, peak_index, db_dict = out

You can get an example of detection using `plot_frame_detection_report`. This will give you `num_examples * 2` figures. Half of them are selected on random frames, half are centered around a frame drop.

In [None]:
plot_window = np.array([-7.5, 7.5]) / frame_rate * pd_sampling
figs = find_frames.plot_frame_detection_report(
    border_index=frame_borders,
    peak_index=peak_index,
    debug_dict=db_dict,
    num_examples=1,
    plot_window=plot_window,
    photodiode=photodiode,
    frame_rate=frame_rate,
    photodiode_sampling=pd_sampling,
    highcut=frame_rate * 3,
)


## Crosscorrelation

After having detected frames, we will try to find where each of them falls in the photodiode sequence. To do that, we start by normalising the photodiode signal between 0 and 1 (using its 1% and 99% quantiles):

In [None]:
normed_pd = np.array(photodiode, dtype=float)
normed_pd -= np.quantile(normed_pd, 0.01)
normed_pd /= np.quantile(normed_pd, 0.99)
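
The same two steps can be wrapped in a small function and checked on synthetic data. This is only a sketch: the name `normalise_by_quantiles` and the toy trace are made up for illustration. Because the 99% quantile is computed after the shift, the 1% and 99% quantiles of the result land on 0 and 1, while outliers stay outside that range rather than being clipped.

```python
import numpy as np


def normalise_by_quantiles(trace, low=0.01, high=0.99):
    """Hypothetical helper: map the `low` quantile to 0 and the `high` one to 1."""
    normed = np.array(trace, dtype=float)  # copy, the input is left untouched
    normed -= np.quantile(normed, low)     # low quantile becomes 0
    normed /= np.quantile(normed, high)    # high quantile (after shift) becomes 1
    return normed


# Toy trace: uniform values with an arbitrary offset and a few large outliers
rng = np.random.default_rng(0)
raw = 200 + 50 * rng.random(10_000)
raw[::1000] += 500
normed = normalise_by_quantiles(raw)
print(np.quantile(normed, [0.01, 0.99]))  # close to 0 and 1
print(normed.max())                       # outliers remain above 1
```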


### Idealised photodiode

Then we generate an idealised version of what the photodiode signal should be (had there been no frame drops).

In [None]:
seq_trace, ideal_pd = find_frames.ideal_photodiode(time_base=ao_time,
                                switch_time=frame_log['HarpTime'].values,
                                sequence=frame_log['PhotoQuadColor'].values)

fig = plt.figure()
w = np.array([10000, 10100])
t0 = ao_time[w[0]]
plt.plot(ao_time[slice(*w)] - t0, normed_pd[slice(*w)], label='Normed photodiode')
plt.plot(ao_time[slice(*w)] - t0, seq_trace[slice(*w)], label='Sequence',
         color='grey', alpha=0.5)
plt.plot(ao_time[slice(*w)] - t0, ideal_pd[slice(*w)], label='Filtered sequence')
l = plt.legend()
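
For intuition, here is a minimal sketch of how such traces could be built. It is not the code of `find_frames.ideal_photodiode`: the helper names are hypothetical, and the moving average only stands in for whatever filter the library applies to mimic the photodiode response.

```python
import numpy as np


def step_trace(time_base, switch_time, sequence):
    """Hypothetical helper: hold each sequence value until the next switch."""
    # Index of the last switch at or before each sample of the time base
    idx = np.searchsorted(switch_time, time_base, side="right") - 1
    # Samples before the first switch reuse the first value
    idx = np.clip(idx, 0, len(sequence) - 1)
    return np.asarray(sequence, dtype=float)[idx]


def moving_average(trace, n=3):
    """Crude low-pass filter standing in for the photodiode response."""
    padded = np.pad(trace, n, mode="edge")
    return np.convolve(padded, np.ones(n) / n, mode="same")[n:-n]


time_ms = np.arange(50, dtype=float)     # toy time base, 1 sample per ms
switches = np.array([10.0, 20.0, 30.0])  # render times from a toy log
values = np.array([0.0, 1.0, 0.5])       # grey value requested at each switch
toy_seq = step_trace(time_ms, switches, values)
toy_ideal = moving_average(toy_seq)
```

`toy_seq` is a pure staircase, while `toy_ideal` ramps between levels, which is closer to what a photodiode with a finite response time would record.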

### Data chunking

Now we want to run the crosscorrelation around each frame.

We need to take a chunk of data that is long enough but not too long. Five or six frames seem to give a good unique match with the sequence. Use `num_frame_to_corr` to set that.

Then we need to shift the photodiode by a given lag and cut the same chunk of data to correlate. There is no point in testing all the shifts: we know the lag will be about 2 frames. So we have `expected_lag ~= int(2/frame_rate*ao_sampling)` (in samples).

To make things reasonably fast, we also limit the search to 3 frames of lag (+/- around `expected_lag`), with `maxlag ~= int(3/frame_rate*ao_sampling)`.

In [None]:
num_frame_to_corr = 6
maxlag_samples = int(np.round(3 / frame_rate * ao_sampling))  # make it into samples
expected_lag_samples = int(np.round(2 /frame_rate * ao_sampling))  # make it into samples
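
The restricted lag search can be sketched as follows. This is an illustration under assumptions, not the library's implementation: `best_lag` is a hypothetical helper that uses a plain Pearson correlation (`np.corrcoef`), and the toy signal has 10 samples per frame with a true lag of 2 frames.

```python
import numpy as np


def best_lag(signal, template, start, length, expected_lag, maxlag):
    """Hypothetical helper: lag (in samples) of `signal` relative to `template`.

    Only lags within `maxlag` samples of `expected_lag` are tested.
    """
    chunk = template[start:start + length]
    lags = np.arange(expected_lag - maxlag, expected_lag + maxlag + 1)
    corr = np.array(
        [np.corrcoef(chunk, signal[start + lag:start + lag + length])[0, 1]
         for lag in lags]
    )
    best = np.argmax(corr)
    return lags[best], corr[best]


samples_per_frame = 10
# Toy 5-level sequence, one value per frame
frames = np.array([0, 3, 1, 4, 2, 0, 2, 4, 1, 3, 0, 4, 3, 2, 1,
                   0, 1, 2, 3, 4, 1, 3, 0, 2, 4, 2, 0, 3, 1, 4]) / 4
template = np.repeat(frames, samples_per_frame)
rng = np.random.default_rng(1)
# The "photodiode" is the template delayed by 2 frames, plus a little noise
signal = np.roll(template, 2 * samples_per_frame) + rng.normal(0, 0.01, template.size)

lag, corr = best_lag(signal, template, start=100, length=6 * samples_per_frame,
                     expected_lag=2 * samples_per_frame, maxlag=3 * samples_per_frame)
print(lag, round(corr, 3))
```

Testing only `2 * maxlag + 1` lags per frame keeps the cost proportional to the number of frames, instead of cross-correlating the whole recording.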

Finally, we need to decide whether to take the chunk of data before the frame, centered on the frame, or after the frame. The best choice depends on whether there was a recent frame drop. So let's just do all three.

In [None]:
window = [np.array([-1, 1]) * maxlag_samples + 
          np.array(w * num_frame_to_corr / frame_rate
                    * ao_sampling, dtype='int')
              for w in [np.array([-1, 0]), np.array([-0.5, 0.5]), np.array([0, 1])]]
# for bef window, we add 1 frame to have the current frame included
window[0] += int(1/frame_rate * ao_sampling)
# for center window, we shift by 0.5 frame to center
window[1] += int(0.5/frame_rate * ao_sampling)

example_frame = 5234
frame_sample = frame_borders[example_frame]
lab = ['bef', 'center', 'aft']
t0 = ao_time[frame_sample]
for iw, w in enumerate(window):
    part = slice(*w + frame_sample)
    plt.plot(ao_time[part][maxlag_samples:-maxlag_samples+1] - t0, 
             normed_pd[part][maxlag_samples:-maxlag_samples+1] + iw * 0.5, 
             label=lab[iw])
plt.axvspan(ao_time[frame_sample] - t0, ao_time[frame_borders[example_frame + 1]] - t0, 
            alpha=0.5)
_ = plt.legend()

In [None]:
maxlag_samples

In [None]:
# Distribution of the measured lags, scaled by 1000 (s to ms if lags are in seconds)
_ = plt.hist(frames_df.lag.values * 1000,
             bins=np.arange(frames_df.lag.min() * 1000, frames_df.lag.max() * 1000))

In [None]:
# Frames whose matched log entry does not advance relative to the previous frame
bad = np.diff(frames_df.closest_frame.values) < 1
frames_df.iloc[1:][bad]

Add that to the dataframe

## Match cross correlation results to frame logger

Ideally, if there is no frame drop, it does not matter whether we look at the frames preceding or following the frame we want to sync. That should be the case most of the time.
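
One quick way to check this is to compare the log entries matched by the three windows. The sketch below uses a toy table whose column names follow the `closest_frame_bef`, `closest_frame_center` and `closest_frame_aft` columns used in the plotting code further down. The values are invented, and this is not meant to describe the criterion used inside `sync_by_correlation`.

```python
import pandas as pd

# Toy results: the log entry matched by each window, for four frames
toy = pd.DataFrame({
    "closest_frame_bef": [100, 101, 102, 103],
    "closest_frame_center": [100, 101, 103, 103],
    "closest_frame_aft": [100, 101, 104, 103],
})
cols = ["closest_frame_%s" % lab for lab in ("bef", "center", "aft")]
# A frame is unambiguous when the three windows point to the same log entry
all_agree = toy[cols].nunique(axis=1) == 1
print(all_agree.mean())  # fraction of frames where the window choice is irrelevant
```

Frames where the windows disagree (the third row here) are the ones that need closer inspection.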

In [None]:
db = db_dict['debug_info']
normed_pd = np.array(photodiode, dtype=float)
normed_pd -= np.quantile(normed_pd, 0.01)
normed_pd /= np.quantile(normed_pd, 0.99)
db.keys()

In [None]:
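rng = np.random.default_rng(102)
# Option 1: 10 random frames that were synced by photodiode matching
w = frames_df[frames_df.sync_reason == 'photodiode matching'].index
random_select = [w[i] for i in rng.integers(len(w), size=10)]
# Option 2 (overrides option 1): three consecutive frames around the 21st
# non-increasing step in `closest_frame`
bad = np.diff(frames_df.closest_frame.values) < 1
badi = np.where(bad)[0]
random_select = frames_df.iloc[badi[20]+np.array([0, 1, 2], dtype=int)].index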
labels = ['bef', 'center', 'aft']
num_frame_to_corr=5
maxlag=int(5./frame_rate*ao_sampling)
expected_lag=int(2./frame_rate*ao_sampling)
window = [np.array([-1, 1]) * maxlag + 
          np.array(w * num_frame_to_corr/frame_rate * ao_sampling, dtype='int')
          for w in [np.array([-1,0]), np.array([-0.5, 0.5]), np.array([0, 1])]]
seq_trace = db['seq_trace']

for frame in random_select:
    #frame = frames_df[~good].index[num]
    #frame = frames_df.index[num]
    fseries = frames_df.loc[frame]
    on_s = fseries.onset_sample
    off_s = fseries.offset_sample
    on_t = fseries.onset_time
    off_t = fseries.offset_time
    w = np.array([-50, 50])
    vfdf = frames_df[(frames_df.onset_sample > w[0] + on_s) &
                     (frames_df.offset_sample < w[1] + off_s)]
    qc = np.array([fseries[['quadcolor_%s' % w for w in labels]]])
    best = fseries.crosscorr_picked
    fig = plt.figure(figsize=(7, 7))
    plt.gca().get_yaxis().set_visible(False)

    col = dict(bef='r', center='g', aft='b')
    for i in range(3):
        label = 'Photodiode' if i == 1 else None
        plt.plot(ao_time[slice(*w+on_s)]-on_t, normed_pd[slice(*w+on_s)] + i,
                 label=label, color='purple')
        label = 'Frame #%d' % frame if i == 1 else None
        plt.axvspan(0, off_t-on_t, color='purple', alpha=0.2, label=label)
        plt.plot(fseries.peak_time - on_t, fseries.photodiode, 'o', color='purple')


    vlog = frame_log[(frame_log.HarpTime > w[0] / ao_sampling + on_t - fseries.lag_bef) &
                     (frame_log.HarpTime < w[1] / ao_sampling + off_t)]
    plt.plot(vlog.HarpTime.values - on_t, vlog.PhotoQuadColor - 1.5, drawstyle='steps-post', 
             label='Render frame')

    i = 0
    for win, lab in zip(window, ['bef', 'center', 'aft']):
        cut_win = win + maxlag * np.array([1,-1], dtype=int)
        l = fseries['lag_%s' % lab]
        part = seq_trace[slice(*win + on_s)]
        cut_part = seq_trace[slice(*cut_win + on_s)]
        x = normed_pd[slice(*win+on_s)][maxlag :-maxlag + 1]
        
        plt.plot(ao_time[slice(*win+on_s)]-on_t + l, part + i, alpha=0.75, lw=2,
                color=col[lab])
        plt.plot(ao_time[slice(*win+on_s)][maxlag :-maxlag + 1] -on_t,
                 x + i, alpha=0.5, lw=4, ls='--', color=col[lab])
        
        cl = fseries['closest_frame_%s'% lab]
        plt.plot(frame_log.iloc[cl].HarpTime - on_t, frame_log.iloc[cl].PhotoQuadColor-1.5 + i/6, 'o',
                color=col[lab])
        if lab == best:
            plt.plot(frame_log.iloc[cl].HarpTime - on_t, frame_log.iloc[cl].PhotoQuadColor-1.5 + i/6, 'o',
                    mfc='None', mec='k', ms=10, mew=2)
            plt.plot(frame_log.iloc[cl].HarpTime - on_t + l, frame_log.iloc[cl].PhotoQuadColor + i, 'o',
                    color='k')
            plt.plot(ao_time[slice(*cut_win+on_s)]-on_t + l, cut_part + i, alpha=1, lw=1,
                     color='k', label='Selected match')

        i+=1
        plt.title('%s' % fseries.onset_sample)

    plt.legend(loc='lower right')



In [None]:
# find what is the actual photodiode value and how does it depend on previous value

df = pd.DataFrame(frames_df.iloc[1:][['quadcolor', 'photodiode']]).reset_index()
bef = pd.DataFrame(frames_df.iloc[:-1][['quadcolor', 'photodiode']]).reset_index()
df['quadcolor_before'] = bef['quadcolor']
df['photodiode_before'] = bef['photodiode']
df.head()

In [None]:
mat_data = df.groupby(['quadcolor_before', 'quadcolor']).aggregate(np.nanmean).photodiode
n_data = df.groupby(['quadcolor_before', 'quadcolor']).aggregate(len).photodiode
m = mat_data.values.reshape((5, 5))
n = n_data.values.reshape((5, 5))
plt.figure(figsize=(12,3))
plt.subplot(1,3,1)
plt.imshow(m.T, origin='lower')
cm = plt.colorbar()
plt.xlabel('quad n-1')
plt.ylabel('quad n')
plt.title('Photodiode')
plt.subplot(1,3,2)
plt.title('Difference')
plt.imshow((m-np.linspace(0,1,5)).T, origin='lower', cmap='RdBu_r')
plt.xlabel('quad n-1')
plt.ylabel('quad n')
cm = plt.colorbar()
plt.subplot(1,3,3)
plt.title('N transitions')
plt.imshow(n.T, origin='lower')
plt.xlabel('quad n-1')
plt.ylabel('quad n')
cm = plt.colorbar()
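
Note that `reshape((5, 5))` only works when all 25 transitions occur at least once. If that might not hold, `unstack` followed by `reindex` is one possible alternative that leaves missing transitions as NaN. The toy values below are invented, and the five grey levels are assumed to be `np.linspace(0, 1, 5)`, as in the 'Difference' panel.

```python
import numpy as np
import pandas as pd

levels = np.linspace(0, 1, 5)  # assumed grey levels
# Toy table: only 4 of the 25 possible transitions are present
toy = pd.DataFrame({
    "quadcolor_before": [0.0, 0.0, 0.5, 1.0, 1.0],
    "quadcolor": [0.5, 0.5, 1.0, 0.0, 0.25],
    "photodiode": [0.4, 0.6, 0.9, 0.1, 0.3],
})
mean_pd = toy.groupby(["quadcolor_before", "quadcolor"]).photodiode.mean()
# Rows: previous value, columns: current value, NaN where no data exists
matrix = mean_pd.unstack().reindex(index=levels, columns=levels)
print(matrix)
```

`matrix.values` can then be passed to `plt.imshow` in place of `m`; NaN cells are simply left blank.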


In [None]:
fseries = frames_df.loc[4713]
fseries

In [None]:
t0 = fseries.offset_time
frame_log['HarpTime'][1940:1955] - fseries.lag_aft - t0

In [None]:
frames_df.sync_reason.value_counts()

# Miscellaneous

Figures to explain things for my lab meeting (09/11/2022)

## Sequence principle

In [None]:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
b = 1000
w = 100
shift = 1  # offset (in samples) applied to the "bad" sequence
seq = seq_trace[b:b+w]
bad_seq = np.array(seq_trace[b+shift: b+w+shift])
bad_seq[int(w/3):int(w/3 + w/3 * 0.6)] = bad_seq[int(w/3)]
ax.plot((ao_time[b:b+w] - ao_time[b])*1e3, bad_seq)
ax.plot((ao_time[b:b+w] - ao_time[b])*1e3, seq + 1)
ax.set_xlabel('Time (ms)')
ax.yaxis.set_visible(False)
for w in ['top', 'left', 'right']:
    ax.spines[w].set_visible(False)