In [None]:
from v1_depth_analysis.config import PROJECT
import flexiznam as flz
import numpy as np 
import pickle
from pathlib import Path
import matplotlib.pyplot as plt

# Open a single Flexilims session for PROJECT; it is passed as
# `flexilims_session` to every flexiznam query made further down.
flm_sess = flz.get_flexilims_session(project_id=PROJECT)

In [None]:
# Select every recording entity whose `protocol` attribute is the playback
# (open loop) protocol.
playback_query = dict(
    datatype="recording",
    query_key="protocol",
    query_value="SpheresPermTubeRewardPlayback",
)
playback_rec = flz.get_entities(flexilims_session=flm_sess, **playback_query)
print(f'Found {len(playback_rec)} recordings with playback')

In [None]:
def get_rsos(param_logger, speed_thr=0.01, log=True):
    """Derive running speed, optic flow and virtual speed traces from a logger.

    Both speeds are the sample-to-sample change of a position column divided by
    the sample-to-sample change of ``HarpTime``, then floored at ``speed_thr``.
    Optic flow is ``degrees(virtual speed / Depth)``; samples whose ``Depth`` is
    negative are turned into NaN first. The first sample of every output is NaN
    because it has no predecessor to difference against.

    Args:
        param_logger: pandas-like table exposing ``EyeZ``, ``MouseZ``,
            ``HarpTime`` and ``Depth`` columns as attributes.
        speed_thr: lower bound applied to both speeds before anything else.
        log: when True return ``log10`` of each trace, otherwise the raw values.

    Returns:
        Tuple ``(lrs, lof, lvrs)``: running speed x 100, optic flow, and virtual
        speed x 100 (the x 100 is presumably a m/s -> cm/s conversion; confirm
        against the logger units).
    """
    elapsed = param_logger.HarpTime.diff()

    def _floored_speed(position):
        # Finite-difference speed, never allowed below `speed_thr`.
        speed = np.array(position.diff() / elapsed, dtype=float)
        return np.clip(speed, speed_thr, None)

    virtual_speed = _floored_speed(param_logger.EyeZ)
    running_speed = _floored_speed(param_logger.MouseZ)

    # Work on a copy so the logger's own Depth column is left untouched.
    depth = np.array(param_logger.Depth, copy=True, dtype=float)
    depth[depth < 0] = np.nan
    optic_flow = np.degrees(virtual_speed / depth)

    if log:
        scale = np.log10
    else:
        def scale(values):
            return values

    return scale(running_speed * 100), scale(optic_flow), scale(virtual_speed * 100)


In [None]:

def get_data(session):
    """Load the ``img_VS`` tables and the suite2p ``ops`` of one session.

    Args:
        session: Flexilims session entity; only its ``id`` is read.

    Returns:
        Tuple ``(loggers, ops)``. ``loggers`` is a dict with keys
        ``'closed_loop'`` and ``'open_loop'``, each a list holding one unpickled
        ``img_VS.pickle`` object per recording of the matching protocol.
        ``ops`` is the object stored in the session's suite2p ``ops.npy``.

    Raises:
        AssertionError: if the session does not have exactly one
            ``suite2p_rois`` dataset.
        FileNotFoundError: propagated from ``open`` when a recording has no
            ``img_VS.pickle`` under the processed data root.
    """
    def _children(datatype):
        # All children of this session of the requested Flexilims datatype.
        return flz.get_children(
            parent_id=session.id,
            flexilims_session=flm_sess,
            children_datatype=datatype,
        )

    recordings = _children('recording')
    datasets = _children('dataset')

    suite2p_rows = datasets[datasets.dataset_type == 'suite2p_rois']
    assert len(suite2p_rows) == 1
    suite2p_ds = flz.Dataset.from_flexilims(
        data_series=suite2p_rows.iloc[0], flexilims_session=flm_sess
    )
    ops_file = suite2p_ds.path_full / 'suite2p' / 'plane0' / 'ops.npy'
    ops = np.load(ops_file, allow_pickle=True).item()

    processed = Path(flz.PARAMETERS['data_root']['processed'])
    protocol_of = dict(
        closed_loop="SpheresPermTubeReward",
        open_loop="SpheresPermTubeRewardPlayback",
    )

    loggers = dict()
    for kind, protocol in protocol_of.items():
        loggers[kind] = []
        for _, rec in recordings[recordings.protocol == protocol].iterrows():
            print(f"Analysing {rec.name}")
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # files written by the lab's own preprocessing.
            with open(processed / rec.path / 'img_VS.pickle', 'rb') as handle:
                loggers[kind].append(pickle.load(handle))
    return loggers, ops


In [None]:
# Build `data`: one entry per (session, loop kind, recording index) holding the
# log10 running speed, optic flow and virtual speed traces.
data = dict()
import pandas as pd
errors = []
# Session whose recordings are used as the worked example by the plotting cells.
example_session = None
for rec_name, rec_playback in playback_rec.groupby('origin_id'):
    # rec_name is the groupby key, i.e. the origin_id shared by the whole group.
    sess = flz.get_entity(id=rec_playback.iloc[0].origin_id, flexilims_session=flm_sess)
    print(rec_name)
    try:
        loggers, ops = get_data(sess)
    except FileNotFoundError as err:
        # Best effort: remember the missing file and move on to the next session.
        print(f'Error: {err}')
        errors.append(err)
        continue

    # The plotting cells read `img_VS_original` / `img_VS_playback`, which used
    # to exist only as locals of get_data (NameError on a fresh kernel). Bind
    # them here, together with the matching `ops`, to the last recordings of the
    # last session that provided both a closed loop and an open loop recording.
    if loggers['closed_loop'] and loggers['open_loop']:
        img_VS_original = loggers['closed_loop'][-1]
        img_VS_playback = loggers['open_loop'][-1]
        example_session = rec_name

    for kind, logs in loggers.items():
        for il, logger in enumerate(logs):
            crs, cof, cvrs = get_rsos(logger, speed_thr=0.01)
            data[f"{rec_name}_{kind}_{il}"] = dict(
                rs=crs,
                of=cof,
                vrs=cvrs,
            )
# One error is recorded per session that failed to load (get_data stops at the
# first missing file of a session).
print(f'Could not load {len(errors)} recordings')
if example_session is None:
    print('No session with both closed loop and open loop data: example plots cannot be drawn')
else:
    print(f'Example recordings taken from {example_session}')

In [None]:
# List the files that could not be opened, relative to the processed data root.
processed_root = flz.PARAMETERS['data_root']['processed']
p = []
for err in errors:
    p.append(Path(err.filename).relative_to(processed_root))
for missing_file in p:
    print(missing_file)

In [None]:
# make bins
# Bin positions are expressed in log10 units, matching the log10 traces
# returned by get_rsos(..., log=True).

rs_bin_log_min= 0
rs_bin_log_max= 2.5
rs_bin_num= 6
of_bin_log_min= -1.5
of_bin_log_max= 3.5
of_bin_num= 11
log_base= 10


rs_bin = np.linspace(rs_bin_log_min, rs_bin_log_max, rs_bin_num)
of_bin = np.linspace(of_bin_log_min, of_bin_log_max, of_bin_num)
print(rs_bin)
print(of_bin)

# `lrs` and `lof` used to be read here although they were first assigned in a
# later cell, so this cell failed on a fresh kernel. Compute them here from the
# same example recording and threshold the following cells use.
lrs, lof, lvrs = get_rsos(img_VS_original, speed_thr=0.01)

# Overlay the bin positions (short ticks at the top) on the distribution of the
# log10 running speed (purple) and log10 optic flow (orange).
for v in rs_bin:
    plt.axvline(v, color='purple', ymin=0.7)
for v in of_bin:
    plt.axvline(v, color='orange', ymin=0.8)
_ = plt.hist(lrs, histtype='step', bins=np.arange(np.nanmin(lrs), np.nanmax(lrs), 0.5), color='purple')
_ = plt.hist(lof, histtype='step', bins=np.arange(np.nanmin(lof), np.nanmax(lof), 0.5), color='orange')

In [None]:
# Sample offsets, relative to each visit of the reference bin, at which the map
# occupancy is evaluated (steps of 2.5 truncated to integers, from -25 to 25).
delta_range = np.array(np.arange(-10, 11) * 2.5, dtype=int)


def timewalk_onthemap(lrs, lof, pos, delta_range=delta_range, rs_bin=rs_bin, of_bin=of_bin):
    """Occupancy of the (optic flow, running speed) map around visits to one bin.

    Every sample is assigned to the nearest entry of ``of_bin`` / ``rs_bin``.
    For each offset in ``delta_range`` the function looks at where the traces
    are ``delta`` samples away from every sample that falls in bin ``pos`` and
    returns the fraction of those samples found in each bin of the map.

    Args:
        lrs: 1-D array of running speed values, in the same units as ``rs_bin``.
        lof: 1-D array of optic flow values, in the same units as ``of_bin``;
            must have the same length as ``lrs``.
        pos: ``(of index, rs index)`` of the reference bin.
        delta_range: integer sample offsets to evaluate.
        rs_bin: evenly spaced running speed bin positions.
        of_bin: evenly spaced optic flow bin positions.

    Returns:
        Array of shape ``(len(of_bin), len(rs_bin), len(delta_range))``. Each
        ``[..., i]`` slice sums to 1, or is all NaN when no usable sample exists
        for that offset.

    Raises:
        AssertionError: if any sample falls outside the range of the bins.
    """
    # Use the actual bin spacing rather than a hard-coded factor of 2, which was
    # only correct for 0.5-wide bins (for those, x / 0.5 == x * 2 exactly).
    rs_step = rs_bin[1] - rs_bin[0]
    of_step = of_bin[1] - of_bin[0]
    blrs = np.array(np.round((lrs - rs_bin[0]) / rs_step))
    blof = np.array(np.round((lof - of_bin[0]) / of_step))

    assert np.nanmin(blrs) >= 0, "running speed below the first rs bin"
    assert np.nanmin(blof) >= 0, "optic flow below the first of bin"
    assert np.nanmax(blrs) < len(rs_bin), "running speed above the last rs bin"
    assert np.nanmax(blof) < len(of_bin), "optic flow above the last of bin"

    n_samples = len(blrs)
    out = np.zeros((len(of_bin), len(rs_bin), len(delta_range)))
    # One unit-wide histogram bin per bin index, for both axes.
    edges = [np.arange(len(of_bin) + 1), np.arange(len(rs_bin) + 1)]

    start_pts, = np.where((blof == pos[0]) & (blrs == pos[1]))

    for ide, delta in enumerate(delta_range):
        pts = start_pts + delta
        # Drop offsets that leave the recording: a negative index would silently
        # wrap around to the end of the trace and an index past the end raises
        # IndexError.
        pts = pts[(pts >= 0) & (pts < n_samples)]
        vals = np.vstack([blof[pts], blrs[pts]])
        # Ignore samples where either trace is NaN.
        bad = np.isnan(np.sum(vals, axis=0))
        vals = np.array(vals[:, ~bad], dtype=int)
        n_valid = vals.shape[1]
        if n_valid == 0:
            # Nothing to normalise by; keep the NaN the 0/0 used to produce,
            # without the runtime warning.
            out[..., ide] = np.nan
            continue
        h, _, _ = np.histogram2d(vals[0], vals[1], edges)
        out[..., ide] = h / n_valid
    return out


In [None]:
# coords are (of, rs): bin indices (not physical values) that
# timewalk_onthemap compares with the binned optic flow (pos[0]) and the
# binned running speed (pos[1]) to pick its reference samples.
pos = [7, 4]

In [None]:
# 2-D occupancy of the (optic flow, running speed) map for the closed loop
# example recording.
lrs, lof, lvrs = get_rsos(img_VS_original, speed_thr=0.01)
valid = ~np.isnan(lof)

# Nearest-bin index of every sample (the factor 2 matches the 0.5-wide bins).
blrs = np.round((lrs - rs_bin[0]) * 2)
blof = np.round((lof - of_bin[0]) * 2)

of_edges = np.arange(len(of_bin) + 1)
rs_edges = np.arange(len(rs_bin) + 1)
h, xe, ye = np.histogram2d(blof[valid], blrs[valid], bins=[of_edges, rs_edges])

fig, ax = plt.subplots()
image = ax.imshow(h, origin='lower')
# Tick labels are the bin positions converted back from log10 to linear units.
ax.set_xticks(np.arange(len(rs_bin)))
ax.set_xticklabels(np.round(10**rs_bin, 1), rotation=90)
ax.set_yticks(np.arange(len(of_bin)))
ax.set_yticklabels(np.round(10**of_bin, 1))
cb = fig.colorbar(image, ax=ax)
cb.set_label('# frames')
ax.set_xlabel('Running speed (cm/s)')
ax.set_ylabel('Optic flow (degree/s)')


In [None]:
# Closed loop recording: one map per time offset around visits to bin `pos`.
fig, axes = plt.subplots(3, 7)
fig.set_size_inches(10, 9)
panels = axes.flatten()

lrs, lof, lvrs = get_rsos(img_VS_original, speed_thr=0.01)
out = timewalk_onthemap(lrs, lof, pos, rs_bin=rs_bin, of_bin=of_bin)

for idelta, delta in enumerate(delta_range):
    ax = panels[idelta]
    # Offsets are in samples; dividing by the sampling rate gives seconds.
    ax.set_title(f"{np.round(delta/ops['fs'],1)}s")
    ax.imshow(out[..., idelta], vmin=0, vmax=0.5, origin='lower')
    ax.axis('off')
plt.subplots_adjust(wspace=0.02, hspace=0.001)

In [None]:
# Playback (open loop) recording: one map per time offset around visits to
# bin `pos`.
fig, axes = plt.subplots(3, 7)
fig.set_size_inches(10, 9)
panels = axes.flatten()

lrs, lof, lvrs = get_rsos(img_VS_playback, speed_thr=0.01)
out = timewalk_onthemap(lrs, lof, pos, rs_bin=rs_bin, of_bin=of_bin)

for idelta, delta in enumerate(delta_range):
    ax = panels[idelta]
    # Offsets are in samples; dividing by the sampling rate gives seconds.
    ax.set_title(f"{np.round(delta/ops['fs'],1)}s")
    ax.imshow(out[..., idelta], vmin=0, vmax=0.5, origin='lower')
    ax.axis('off')
plt.subplots_adjust(wspace=0.02, hspace=0.001)

In [None]:
from scipy import signal

def norm_corr(x, y):
    """Full cross-correlation of ``x`` and ``y`` after scaling each to unit L2 norm.

    The inputs are not mean-subtracted. The result has
    ``len(x) + len(y) - 1`` samples (``mode='full'``).
    """
    unit_x = x / np.linalg.norm(x)
    unit_y = y / np.linalg.norm(y)
    return signal.correlate(unit_x, unit_y, mode='full')

# Auto- and cross-correlations of the linear traces of the closed loop example
# recording, shown for lags between -10 and 10.
lrs, lof, lvrs = get_rsos(img_VS_original, speed_thr=0.01, log=False)
valid = ~np.isnan(lof)

n_valid = len(lrs[valid])
lags = signal.correlation_lags(n_valid, n_valid, mode='full') / ops['fs']
xl = np.searchsorted(lags, [-10, 10])
window = slice(*xl)

# Top to bottom: running speed autocorrelation, optic flow autocorrelation,
# optic flow vs running speed, virtual speed vs running speed.
trace_pairs = [(lrs, lrs), (lof, lof), (lof, lrs), (lvrs, lrs)]
for row, (first, second) in enumerate(trace_pairs, start=1):
    corr = norm_corr(first[valid], second[valid])
    plt.subplot(4, 1, row)
    plt.plot(lags[window], corr[window])

plt.gcf().set_size_inches(7, 6)

In [None]:
from scipy import signal

# NOTE: this repeats the norm_corr definition of the previous cell.
def norm_corr(x, y):
    """Full cross-correlation of ``x`` and ``y`` after scaling each to unit L2 norm.

    The inputs are not mean-subtracted. The result has
    ``len(x) + len(y) - 1`` samples (``mode='full'``).
    """
    x_unit = x / np.linalg.norm(x)
    y_unit = y / np.linalg.norm(y)
    return signal.correlate(x_unit, y_unit, mode='full')

# Same four correlations as the previous figure, overlaying the closed loop and
# the open loop (playback) example recordings.
fig, axes = plt.subplots(4, 1)
xlims = [-3000, 3000]
labels = ['Closed loop', 'Open loop']
titles = [
    'Running speed autocorrelation',
    'Optic flow autocorrelation',
    'Optic flow - Running speed correlation',
    'Virtual speed - Running speed correlation',
]
for iw, w in enumerate([img_VS_original, img_VS_playback]):
    lrs, lof, lvrs = get_rsos(w, speed_thr=0.01, log=False)
    valid = ~np.isnan(lof)

    n_valid = len(lrs[valid])
    lags = signal.correlation_lags(n_valid, n_valid, mode='full') / ops['fs']
    xl = np.searchsorted(lags, xlims)
    window = slice(*xl)

    trace_pairs = [(lrs, lrs), (lof, lof), (lof, lrs), (lvrs, lrs)]
    for irow, (first, second) in enumerate(trace_pairs):
        corr = norm_corr(first[valid], second[valid])
        # Only the top panel carries the legend entries.
        line_kwargs = dict(label=labels[iw]) if irow == 0 else dict()
        axes[irow].plot(lags[window], corr[window], **line_kwargs)
        axes[irow].set_title(titles[irow])

for x in axes:
    x.set_ylabel('Normalised correlation')
    x.set_xlabel('Lag (s)')
    x.set_xlim(*xlims)
axes[0].legend(loc="upper right")
fig.set_size_inches(5, 10)
fig.subplots_adjust(hspace=0.5)

In [None]:
# Linear (log=False) running speed, optic flow and virtual speed traces of the
# playback recording, plotted as raw traces in the next cell.
rs,of, vrs = get_rsos(img_VS_playback,  speed_thr = 0.01, log=False)

In [None]:
# Raw traces of the playback recording on a shared sample axis. The panels used
# to be unlabelled, so nothing told the reader which trace was which.
fig, ax = plt.subplots(3, 1, sharex=True)
ax[2].plot(rs, lw=1)
ax[1].plot(vrs, lw=1)
ax[0].plot(of, lw=1)
# Units follow the axis labels of the occupancy map above (get_rsos scales both
# speeds by 100 and returns the optic flow in degrees).
ax[0].set_ylabel('Optic flow (degree/s)')
ax[1].set_ylabel('Virtual speed (cm/s)')
ax[2].set_ylabel('Running speed (cm/s)')
ax[2].set_xlabel('Sample index')



In [None]:
# Plot the `Mouse` column of the playback logger.
# NOTE(review): get_rsos only reads `MouseZ`, `EyeZ`, `HarpTime` and `Depth`
# from this table; confirm that a `Mouse` column exists (was `MouseZ` meant?).
plt.plot(img_VS_playback.Mouse)

In [None]:
# Scatter of optic flow (left) and virtual speed (right) against running speed.
# `lrs`, `lof` and `lvrs` are whatever the most recently executed get_rsos call
# left in the namespace; this cell does not compute them itself.
plt.figure(figsize=(10, 5))
for panel, y_values in enumerate((lof, lvrs), start=1):
    plt.subplot(1, 2, panel, aspect='auto')
    plt.scatter(lrs, y_values, alpha=.1, s=1)