# Multi depth experiments

Corridors with spheres at multiple depths presented simultaneously.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Project name, passed to flexiznam to open the Flexilims session used below.
project = "colasa_3d-vision_revisions"
# NOTE: the second assignment overrides the first one, so only
# "PZAH17.1e_S20250403" is analysed. Reorder or comment out a line to switch session.
session_name = "PZAG16.3b_S20250317"
session_name = "PZAH17.1e_S20250403"

In [None]:
# Yiran's workflow is to call sbatch_session on every session, like this.
# The block is disabled (`if False`); flip the condition to submit the job for
# `session_name`.

if False:
    from cottage_analysis.pipelines import pipeline_utils

    # Arguments forwarded to sbatch_session
    pipeline_filename = "run_analysis_pipeline.sh"
    conflicts = "overwrite"
    photodiode_protocol = 5
    pipeline_utils.sbatch_session(
        project=project,
        session_name=session_name,
        pipeline_filename=pipeline_filename,
        conflicts=conflicts,
        photodiode_protocol=photodiode_protocol,
    )

# Analyse the multidepth experiments

In theory the same pipeline should work if we change the protocol base.

In [None]:
from cottage_analysis.pipelines import analysis_pipeline
import flexiznam as flz

# We can just run the same pipeline. It will skip the depth and rsof fits and just
# run the rf fit.
# NOTE(review): the disabled branch below passes "SpheresPermTubeReward_multidepth"
# while the step-by-step branch uses this value, which makes `is_multidepth`
# (last line of the cell) False -- confirm which protocol base is intended.
protocol_base = "SpheresPermTubeReward"
if False:
    # One-shot version: run the whole pipeline with only the rf fit and plots enabled
    analysis_pipeline.main(
        project,
        session_name,
        conflicts="overwrite",
        photodiode_protocol=5,
        use_slurm=False,
        run_depth_fit=False,
        run_rf=True,
        run_rsof_fit=False,
        run_plot=True,
        protocol_base="SpheresPermTubeReward_multidepth",
    )
else:
    # here we do it step by step for debugging
    from cottage_analysis.analysis import spheres

    flexilims_session = flz.get_flexilims_session(project)
    print("---Start synchronisation...---")
    vs_df_all, trials_df_all = spheres.sync_all_recordings(
        session_name=session_name,
        flexilims_session=flexilims_session,
        project=project,
        filter_datasets={"anatomical_only": 3},
        recording_type="two_photon",
        protocol_base=protocol_base,
        photodiode_protocol=5,
        return_volumes=True,
        conflicts="skip",
    )

    # Add trial number to flexilims
    # (counts of closed-loop / open-loop rows and number of distinct depth values)
    trial_no_closedloop = len(trials_df_all[trials_df_all["closed_loop"] == 1])
    trial_no_openloop = len(trials_df_all[trials_df_all["closed_loop"] == 0])
    ndepths = len(trials_df_all["depth"].unique())
    # Side effect: updates the attributes of the session entity on Flexilims
    flz.update_entity(
        "session",
        name=session_name,
        mode="update",
        attributes={
            "closedloop_trials": trial_no_closedloop,
            "openloop_trials": trial_no_openloop,
            "ndepths": ndepths,
        },
        flexilims_session=flexilims_session,
    )

    # Read the "fs" attribute of the first matching suite2p dataset
    suite2p_datasets = flz.get_datasets(
        origin_name=session_name,
        dataset_type="suite2p_rois",
        project_id=project,
        flexilims_session=flexilims_session,
        return_dataseries=False,
        filter_datasets={"anatomical_only": 3},
    )
    suite2p_dataset = suite2p_datasets[0]
    frame_rate = suite2p_dataset.extra_attributes["fs"]

    # True only when the protocol base name contains "multidepth"
    is_multidepth = "multidepth" in protocol_base

In [None]:
import flexiznam as flz
from cottage_analysis.pipelines import analysis_pipeline

# Manual, step-by-step version of the synchronisation for the FIRST matching
# recording of the session only (useful for debugging).
exclude_datasets = None

harp_is_in_recording = True
use_onix = False
conflicts = "skip"
sync_kwargs = None
ephys_kwargs = None
# We can just run the same pipeline. It will skip the depth and rsof fits and just
# run the rf fit.
protocol_base = "SpheresPermTubeReward_multidepth"
flexilims_session = flz.get_flexilims_session(project_id=project)
filter_datasets = {"anatomical_only": 3}
recording_type = "two_photon"
photodiode_protocol = 5
return_volumes = True

exp_session = flz.get_entity(
    datatype="session", name=session_name, flexilims_session=flexilims_session
)
recordings = flz.get_entities(
    datatype="recording",
    origin_id=exp_session["id"],
    query_key="recording_type",
    query_value=recording_type,
    flexilims_session=flexilims_session,
)
# Keep the recordings of the requested protocol that are not flagged as excluded
recordings = recordings[recordings.name.str.contains(protocol_base)]
if "exclude_reason" in recordings.columns:
    recordings = recordings[recordings["exclude_reason"].isna()]
if recordings.empty:
    # Fail here with a clear message instead of a NameError on `recording` below
    raise ValueError(
        f"No '{recording_type}' recording matching '{protocol_base}' found for "
        f"session {session_name}"
    )

# The wildcard import is kept (and kept at this position): the rest of the notebook
# uses names it brings into scope (e.g. `get_relevant_recordings`,
# `synchronisation`, `format_imaging_df`).
from cottage_analysis.analysis.spheres import *

load_onix = recording_type != "two_photon"
# Only the first recording is processed in this debugging cell
recording_name = recordings.name.iloc[0]
print(f"Processing recording 1/{len(recordings)}")
recording, harp_recording, onix_rec = get_relevant_recordings(
    recording_name, flexilims_session, harp_is_in_recording, load_onix
)
vs_df = synchronisation.generate_vs_df(
    recording=recording,
    photodiode_protocol=photodiode_protocol,
    flexilims_session=flexilims_session,
    harp_recording=harp_recording,
    onix_recording=onix_rec if use_onix else None,
    project=project,
    conflicts=conflicts,
    sync_kwargs=sync_kwargs,
    protocol_base=protocol_base,
)
imaging_df = synchronisation.generate_imaging_df(
    vs_df=vs_df,
    recording=recording,
    flexilims_session=flexilims_session,
    filter_datasets=filter_datasets,
    exclude_datasets=exclude_datasets,
    return_volumes=return_volumes,
)
imaging_df = format_imaging_df(imaging_df=imaging_df, recording=recording)

In [None]:
from cottage_analysis.io_module.visstim import get_frame_log, get_param_log

# Single harp dataset attached to the current recording
harp_ds = flz.get_datasets(
    origin_name=recording.name,
    dataset_type="harp",
    flexilims_session=flexilims_session,
    allow_multiple=False,
    return_dataseries=False,
)

# Both loggers are read with the same session and recording arguments
log_session = harp_ds.flexilims_session
log_kwargs = dict(harp_recording=harp_recording, vis_stim_recording=recording)
frame_log = get_frame_log(log_session, **log_kwargs)
param_log = get_param_log(log_session, multidepth=True, **log_kwargs)

In [None]:
vs_df.columns

In [None]:
import numpy as np

# Stimulus onsets (stim goes 0 -> 1) and offsets (1 -> 0).
# Integer positions are used for every `.iloc` slice so the cell does not rely on
# `imaging_df` having a 0..n-1 index; `onsets` / `offsets` still hold index labels.
stim = imaging_df["stim"]
harptime = imaging_df["imaging_harptime"]
stim_diff = stim.diff().to_numpy()
onset_pos = np.flatnonzero(stim_diff == 1)
offset_pos = np.flatnonzero(stim_diff == -1)
onsets = imaging_df.index[onset_pos]
offsets = imaging_df.index[offset_pos]

# we want to clean fast alternation, keep only onsets where the previous 2s were not
# stim
onset_times = harptime.iloc[onset_pos]
start_windows = harptime.searchsorted(onset_times - 2)
is_valid_onset = np.zeros(len(onsets), dtype=bool)
for i, (start, stop) in enumerate(zip(start_windows, onset_pos)):
    # window = frames from 2 s before the onset up to (excluding) the onset frame
    is_valid_onset[i] = stim.iloc[start:stop].sum() == 0
clean_onsets = onsets[is_valid_onset]

# Similarly for offsets, we want to keep only offsets where the next 2s were not stim
offset_times = harptime.iloc[offset_pos]
# No clipping needed: `.iloc` slices tolerate an end equal to len(imaging_df), and
# clipping to len - 1 silently dropped the last frame from the window.
end_windows = harptime.searchsorted(offset_times + 2)
is_valid_offset = np.zeros(len(offsets), dtype=bool)
for i, (start, stop) in enumerate(zip(offset_pos, end_windows)):
    is_valid_offset[i] = stim.iloc[start:stop].sum() == 0
clean_offsets = offsets[is_valid_offset]

In [None]:
import re

# Load the reward logs into a dict keyed by depth (integer, as written in the csv id)
rewards = {}
for fid, fname in harp_ds.extra_attributes["csv_files"].items():
    if not fid.startswith("Reward"):
        continue
    # `str.strip("cm.csv")` removes a *set of characters*, not a suffix, and only
    # worked because digits are not in that set. Parse the token explicitly:
    # digits optionally followed by "cm" and/or ".csv".
    depth_match = re.fullmatch(r"(\d+)(?:cm)?(?:\.csv)?", fid.split("_")[1])
    if depth_match is None:
        raise ValueError(f"Cannot parse depth from reward log id '{fid}'")
    depth = int(depth_match.group(1))
    rewards[depth] = pd.read_csv(harp_ds.path_full / fname)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Three stacked panels sharing the time axis:
#   1. stimulus state per depth from param_log (+ reward times in black)
#   2. stimulus state from vs_df
#   3. stimulus state from imaging_df with raw (circles) and cleaned (stars)
#      onsets (green) / offsets (red)
t0 = imaging_df["imaging_harptime"].min()
plt.figure(figsize=(25, 6))
ax1 = plt.subplot(311)
# Row order of the first panel: sorted strictly positive depths.
# The values are multiplied by 100 to match the integers parsed from the logger
# names. They are rounded before the int conversion (a bare `astype(int)`
# truncates, e.g. 0.29 * 100 -> 28), and NaN / negative (blank) entries are dropped
# so they do not take up rows.
depth_values = imaging_df["depth"].dropna().unique()
depths = sorted(int(d) for d in np.round(depth_values[depth_values > 0] * 100))
for i, (log, df) in enumerate(param_log.groupby("logger_fname")):
    depth = int(log.split("_")[1].strip("cm.csv"))
    print(f"{log}: {len(df)}")
    stim = (df.Radius > 0).astype(float) * 0.5
    # stim[stim == 0] = np.nan
    x = (df.HarpTime - t0).values
    # x = df.Frameindex.values
    plt.plot(x, stim + depths.index(depth), "|", label=depth)
    rw = rewards[depth]
    plt.plot(
        rw["HarpTime"] - t0,
        rw["Reward"] * 0.5 + depths.index(depth),
        "|",
        color="k",
        alpha=1,
        ms=20,
    )
plt.legend()
plt.ylabel("Visual stim per depth - param_log")

plt.subplot(312, sharex=ax1)
plt.plot(vs_df.monitor_harptime - t0, vs_df.Radius > 0, "|")
plt.ylabel("vs_df")
plt.subplot(313, sharex=ax1)
# now using imaging_df
plt.plot(imaging_df["imaging_harptime"] - t0, imaging_df["depth"] > 0, "|")
plt.plot(imaging_df["imaging_harptime"] - t0, imaging_df["stim"] * 0.5)
plt.plot(
    imaging_df.loc[onsets, "imaging_harptime"] - t0,
    imaging_df.loc[onsets, "stim"] * 0.5,
    "o",
    color="g",
    alpha=0.5,
)
plt.plot(
    imaging_df.loc[offsets, "imaging_harptime"] - t0,
    imaging_df.loc[offsets, "stim"] * 0.5,
    "o",
    color="r",
    alpha=0.5,
)
plt.plot(
    imaging_df.loc[clean_onsets, "imaging_harptime"] - t0,
    imaging_df.loc[clean_onsets, "stim"] * 0.5 + 0.2,
    "*",
    color="g",
    alpha=0.5,
)
plt.plot(
    imaging_df.loc[clean_offsets, "imaging_harptime"] - t0,
    imaging_df.loc[clean_offsets, "stim"] * 0.5 + 0.2,
    "*",
    color="r",
    alpha=0.5,
)
plt.ylabel("imaging_df")
plt.xlabel("Time (s)")
ax1.set_xlim(1000, 2000)

In [None]:
# Scratch plot: stim state (Radius > 0, scaled to 0.5) of the whole param_log
# against its row index.
df = param_log
stim = (df.Radius > 0).astype(float) * 0.5
# stim[stim == 0] = np.nan
# NOTE: this time axis is immediately replaced by the row index on the next line
x = (df.HarpTime - t0).values
x = df.index.values
# NOTE(review): `i` is not set in this cell; it is whatever an earlier loop left
# in the notebook namespace.
plt.plot(x, stim + i, "o")

In [None]:
imaging_df.index

In [None]:
onsets

In [None]:
offsets[0]

In [None]:
# Zoom on a few onsets (red) and offsets (blue) on top of the stim trace
plt.scatter(onsets, imaging_df.stim[onsets], c="red")
plt.scatter(offsets, imaging_df.stim[offsets], c="blue")
plt.plot(imaging_df.index, imaging_df.stim)
plt.xlim(500, 550)


# True where the sum of stim over a 10-frame rolling window is 0
was_blank = imaging_df.stim.rolling(window=10).sum() == 0
# Shifted back by 10 frames and scaled by 0.5 for display
plt.plot(imaging_df.index, was_blank.shift(-10) * 0.5)

# plt.plot(imaging_df.index, was_blank.shift(1))

In [None]:
imaging_df.columns

In [None]:
np.diff(offsets).min()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# stim: NaN where depth is undefined, 1 where depth is defined, 0 where depth < 0
imaging_df["stim"] = np.nan
imaging_df.loc[imaging_df.depth.notnull(), "stim"] = 1
imaging_df.loc[imaging_df.depth < 0, "stim"] = 0

# Stim is somewhat bistable at onset and offset. Keep only the onsets where the
# previous 10 frames were blank and the offsets where the next 10 frames are blank.
stim_diff = imaging_df["stim"].diff()
was_blank = imaging_df["stim"].rolling(window=10).sum() == 0
# fill_value=False keeps the shifted masks boolean (no NaN / object dtype)
onset_mask = (stim_diff == 1) & was_blank.shift(1, fill_value=False)
offset_mask = (stim_diff == -1) & was_blank.shift(-10, fill_value=False)
# Integer positions, so the slicing below does not depend on the index labels
onset_pos = np.flatnonzero(onset_mask.to_numpy())
offset_pos = np.flatnonzero(offset_mask.to_numpy())

if len(onset_pos) and len(offset_pos) and offset_pos[0] < onset_pos[0]:
    print("Warning: offsets start before onsets! Double check!")
    offset_pos = offset_pos[1:]
    assert (
        len(offset_pos) == 0 or offset_pos[0] > onset_pos[0]
    ), "Warning: 2 offsets start before onsets! Double check!"
if len(onset_pos) != len(offset_pos):
    # zip() below silently ignores the unpaired events; make that visible
    print(
        f"Warning: {len(onset_pos)} onsets but {len(offset_pos)} offsets, "
        "unpaired events are ignored"
    )
onsets = imaging_df.index[onset_pos]
offsets = imaging_df.index[offset_pos]

# is_stim is True from each onset frame up to (excluding) the matching offset frame
is_stim_values = np.zeros(len(imaging_df), dtype=bool)
for on, off in zip(onset_pos, offset_pos):
    is_stim_values[on:off] = True
is_stim = pd.Series(is_stim_values, index=imaging_df.index)


ax = plt.subplot(1, 1, 1)
plt.plot(imaging_df["stim"], label="stim")
# `onsets` / `offsets` already hold index labels: use them directly
plt.scatter(onsets, imaging_df.loc[onsets, "stim"], marker="o", color="g")
plt.scatter(offsets, imaging_df.loc[offsets, "stim"], marker="o", color="r")
plt.plot(is_stim, label="is_stim", color="purple")

plt.xlim([3591 - 1000, 3591 + 5000])

In [None]:
# Print the harptime columns of 10 frames starting at position `b`, relative to
# the imaging time of the first of them, then show their depth.
b = 280
t0 = imaging_df["imaging_harptime"].iloc[b]
subset_df = imaging_df.iloc[b : b + 10]
harptime_columns = [
    "imaging_harptime",
    "monitor_harptime",
    "stimulus_harptime",
    "mouse_z_harptime",
]
print(subset_df[harptime_columns] - t0)
subset_df.depth

In [None]:
import numpy as np

# stim: NaN where depth is undefined, 0 where depth < 0, 1 otherwise
depth_col = imaging_df["depth"]
imaging_df["stim"] = np.where(
    depth_col < 0, 0.0, np.where(depth_col.notnull(), 1.0, np.nan)
)

# Keep only the rows where stim is defined and differs from the previous row
stim_changed = imaging_df["stim"].diff().ne(0)
stim_defined = imaging_df["stim"].notnull()
imaging_df_simple = imaging_df[stim_changed & stim_defined].copy()
imaging_df_simple["depth"] = imaging_df_simple["depth"].round(2)

In [None]:
plt.plot(imaging_df.stim)
plt.xlim(200, 400)

In [None]:
# Imaging time relative to the frame at position 284, then display rows 275-299
reference_time = imaging_df["imaging_harptime"].iloc[284]
imaging_df["temp_time"] = imaging_df["imaging_harptime"] - reference_time
imaging_df[["eye_z", "stim", "trial_idx", "depth", "temp_time"]].iloc[275:300]

In [None]:
imaging_df.columns

In [None]:
tr

In [None]:
# Regenerate the stimulus frames for the closed-loop recordings of the session.
# `is_multidepth` and `protocol_base` come from earlier cells.
# NOTE(review): make sure they are consistent (protocol_base containing
# "multidepth" <-> is_multidepth True) before running this cell.
is_closedloop = True
sfx = "_closedloop"  # suffix used in the log message of the fitting cell
frames_all, imaging_df_all = spheres.regenerate_frames_all_recordings(
    session_name=session_name,
    flexilims_session=flexilims_session,
    project=None,
    filter_datasets={"anatomical_only": 3},
    recording_type="two_photon",
    is_closedloop=is_closedloop,
    is_multidepth=is_multidepth,
    protocol_base=protocol_base,
    photodiode_protocol=5,
    return_volumes=True,
    verbose=False,
    resolution=5,
)

In [None]:
frames_all.shape

In [None]:
print(f"Fitting RF{sfx}...")
# Only the second half of the azimuth axis is fitted here. Azimuth is the LAST axis
# of `frames_all` (see the unpacking `ndepths, nframes, nelev, nazim = frames.shape`
# further down), so the half point must be computed from shape[-1]. Using shape[2]
# is only correct for 3D frames; for 4D frames it is the elevation size.
half_azimuth = frames_all.shape[-1] // 2
# The first step is to estimate hyperparameters
if False:
    (
        coef,
        r2,
        best_reg_xys,
        best_reg_depths,
    ) = spheres.fit_3d_rfs_hyperparam_tuning(
        imaging_df_all,
        frames_all[..., half_azimuth:],
        reg_xys=np.geomspace(2.5, 10240, 13),
        reg_depths=np.geomspace(2.5, 10240, 13),
        shift_stim=2,
        use_col="dffs",
        k_folds=5,
        tune_separately=True,
        validation=False,
    )
else:
    # Reproduce the arguments of fit_3d_rfs_hyperparam_tuning as notebook variables
    # so that its body can be run step by step in the following cells
    imaging_df = imaging_df_all
    frames = frames_all[..., half_azimuth:]
    reg_xys = [20, 40, 80, 160, 320]
    reg_depths = [20, 40, 80, 160, 320]
    shift_stim = 2
    use_col = "dffs"
    k_folds = 5
    tune_separately = True
    validation = True
    r2_threshold = 0.01

In [None]:
import numpy as np
from cottage_analysis.analysis.spheres import fit_3d_rfs, fit_3d_rfs_multidepth

# Sorted, strictly positive depths present in the data
depth_list = imaging_df.depth.dropna().unique()
depth_list = np.sort(depth_list[depth_list > 0])

# Sizes shared by the output arrays
nrois = imaging_df.loc[0, "dffs"].shape[1]
n_hyperparam_pairs = len(reg_xys) * len(reg_depths)
# one weight per (depth, elevation, azimuth) pixel plus one bias term
n_weights = frames.shape[-2] * frames.shape[-1] * len(depth_list) + 1

all_coef = np.zeros((n_hyperparam_pairs, k_folds, n_weights, nrois))
all_r2s = np.zeros((n_hyperparam_pairs, nrois, 2))
hyperparams = np.zeros((n_hyperparam_pairs, 2))
good_neuron_percs = np.zeros((len(reg_xys), len(reg_depths)))

# 4D frames use the multidepth fit, 3D frames the single-depth fit
fit_funcs_by_ndim = {4: fit_3d_rfs_multidepth, 3: fit_3d_rfs}
if frames.ndim not in fit_funcs_by_ndim:
    raise ValueError("frames must be 3D or 4D")
fit_func = fit_funcs_by_ndim[frames.ndim]
idx = 0

In [None]:
frames.ndim

In [None]:
all_r2s.shape

In [None]:
i = 0
j = 0
reg_xy = reg_xys[i]
reg_depth = reg_depths[j]

In [None]:
print(fit_func)

In [None]:
# Emulates the argument list of the fit function so its body can be pasted below.
# The self-assignments are no-ops; the only value actually set here is
# `choose_rois` (empty tuple, i.e. no ROI selection in the next cell).
reg_xy = reg_xy
reg_depth = reg_depth
shift_stim = shift_stim
use_col = use_col
k_folds = k_folds
choose_rois = ()
validation = validation

In [None]:
frames.shape

In [None]:
from scipy.stats import zscore

# frames is unpacked as 4D: (depth, frame, elevation, azimuth)
ndepths, nframes, nelev, nazim = frames.shape
# Stack the per-row response arrays and z-score along the first axis
# (presumably frames x ROIs -- TODO confirm)
resps = zscore(np.concatenate(imaging_df[use_col]), axis=0)
if len(choose_rois) > 0:
    resps = resps[:, choose_rois]
# Sorted, strictly positive, non-NaN depths
depths = imaging_df.depth.unique()
depths = depths[~np.isnan(depths)]
depths = depths[depths > 0]
depths = np.sort(depths)

# Trial index: incremented at every transition of `depth > 0` from False to True;
# frames with undefined or negative depth get NaN
is_stim = imaging_df.depth > 0
trial_start_stop = np.diff(is_stim.astype(int))
trial_idx = np.cumsum(np.hstack([0, trial_start_stop == 1])).astype(float)
trial_idx[imaging_df.depth.isna()] = np.nan
trial_idx[imaging_df.depth < 0] = np.nan
imaging_df["trial_idx"] = trial_idx

assert depths.shape[0] == frames.shape[0]
# Shift to account for response lag
# NOTE: np.roll wraps around, so the last `shift_stim` frames end up at the start
X = np.roll(frames, shift_stim, axis=1)
X = np.swapaxes(X, 0, 1)  # put back frame number as first axis
# (now we have frame, depth, ele, azi)
X = X.reshape(nframes, -1)  # flatten

# Single-plane penalty matrix, expanded to all depths in a later cell
L = spheres.laplace_matrix(nelev, nazim)

In [None]:
imaging_df.trial_idx.value_counts()

In [None]:
print(L.shape)
print(X.shape)
nazim * nelev

In [None]:
# Build the two penalty matrices used to regularise the fit:
#   L       : the single-plane matrix `L` repeated once per depth (block layout)
#   L_depth : for each depth, 2 * identity on that depth and -identity on its
#             neighbouring depths (second difference along the depth axis)
# A zero column is appended to both so the bias term is not penalised.
# NOTE: this cell rebinds `L` and `X`; it must be run exactly once after the cell
# that creates them.
n_pix = L.shape[1]  # number of columns of the single-plane matrix
n_features = X.shape[1]
eye_pix = np.identity(n_pix)

Ls = []
Ls_depth = []
for idepth, depth in enumerate(depths):
    this_plane = slice(idepth * n_pix, (idepth + 1) * n_pix)

    L_xy = np.zeros((L.shape[0], n_features))
    L_xy[:, this_plane] = L
    Ls.append(L_xy)

    # add regularization penalty on the second derivative of the coefficients
    # along the depth axis
    L_depth = np.zeros((n_pix, n_features))
    L_depth[:, this_plane] = 2 * eye_pix
    if idepth > 0:
        L_depth[:, (idepth - 1) * n_pix : idepth * n_pix] = -eye_pix
    if idepth < depths.shape[0] - 1:
        L_depth[:, (idepth + 1) * n_pix : (idepth + 2) * n_pix] = -eye_pix
    Ls_depth.append(L_depth)

L = np.vstack(Ls)
L = np.hstack([L, np.zeros((L.shape[0], 1))])
L_depth = np.vstack(Ls_depth)
L_depth = np.hstack([L_depth, np.zeros((L_depth.shape[0], 1))])
# add bias
X = np.hstack([X, np.ones((X.shape[0], 1))])
coefs = []

In [None]:
print(L_depth.shape)
print(L.shape)
print(X.shape)

In [None]:
from sklearn.model_selection import KFold

# Last axis of Y_pred: 0 for train and -1 for test, 1 for validation prediction
if validation:
    n_splits = 3
else:
    n_splits = 2
Y_pred = np.full((resps.shape[0], resps.shape[1], n_splits), np.nan)
# randomly split trials into training and test sets
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
# Use validation set to select the best regularization parameters (train, val, test),
# or use test set to evaluate performance (train, test)
trials = imaging_df.trial_idx.dropna().unique()

In [None]:
# Only the first fold is used in this debugging cell.
# KFold.split yields *positions* into `trials`, not trial ids: map them back to the
# ids before comparing with `imaging_df.trial_idx`. Otherwise, as soon as the ids
# are not exactly 0..n-1, some trials are never selected.
train_fold, test_fold = next(kfold.split(trials))
train_trials = trials[train_fold]
test_trials = trials[test_fold]
if validation:
    # carve a validation set out of the training trials
    train_trials, validation_trials = spheres.train_test_split(
        train_trials,
        test_size=(1 / (k_folds - 1)),
    )
    validation_idx = np.isin(imaging_df.trial_idx, validation_trials)
# Boolean masks over the rows of imaging_df
train_idx = np.isin(imaging_df.trial_idx, train_trials)
test_idx = np.isin(imaging_df.trial_idx, test_trials)

In [None]:
X_train = np.concatenate([X[train_idx, :], reg_xy * L, reg_depth * L_depth], axis=0)
print(X_train.shape)

In [None]:
# Q maps targets to least-squares coefficients: coef = Q @ Y_train.
# Solve the normal equations (X'X) Q = X' rather than forming the explicit inverse,
# which is slower and less accurate.
Q = np.linalg.solve(X_train.T @ X_train, X_train.T)
print(Q.shape)

In [None]:
print(resps.shape)

In [None]:
# Targets: responses of the training frames followed by zeros for every penalty
# row (same row order as X_train: data, L, L_depth)
penalty_targets = np.zeros((L.shape[0] + L_depth.shape[0], resps.shape[1]))
Y_train = np.concatenate([resps[train_idx, :], penalty_targets], axis=0)
coef = Q @ Y_train
coefs.append(coef)

# Predict each split into its own slice of Y_pred
idxs = [train_idx, validation_idx, test_idx] if validation else [train_idx, test_idx]
for isplit, idx in enumerate(idxs):
    Y_pred[idx, :, isplit] = X[idx, :] @ coef

In [None]:
print(Y_pred.shape)

In [None]:
# r2 for every response column and every split, using only the frames that have a
# prediction in that split
r2 = np.full((resps.shape[1], n_splits), np.nan)
for isplit in range(n_splits):
    use_idx = np.isfinite(Y_pred[:, 0, isplit])
    observed = resps[use_idx, :]
    residual_var = ((Y_pred[use_idx, :, isplit] - observed) ** 2).sum(axis=0)
    total_var = ((observed - observed.mean(axis=0)) ** 2).sum(axis=0)
    r2[:, isplit] = 1 - residual_var / total_var

In [None]:
plt.plot(imaging_df.depth)
plt.xlim(100, 1000)
plt.ylim(0, 2)

In [None]:
plt.plot(imaging_df.depth > 0)

In [None]:
# NOTE: the function is called through `spheres.` below, so this import is not
# strictly needed.
from cottage_analysis.analysis.spheres import fit_3d_rfs_hyperparam_tuning

# Full hyperparameter search on the second half of the last (azimuth) axis, without
# a validation split. Overwrites `coef` and `r2` from the step-by-step cells above.
(
    coef,
    r2,
    best_reg_xys,
    best_reg_depths,
) = spheres.fit_3d_rfs_hyperparam_tuning(
    imaging_df_all,
    frames_all[..., int(frames_all.shape[-1] // 2) :],
    reg_xys=np.geomspace(2.5, 10240, 13),
    reg_depths=np.geomspace(2.5, 10240, 13),
    shift_stim=2,
    use_col="dffs",
    k_folds=5,
    tune_separately=True,
    validation=False,
)

In [None]:
from cottage_analysis.analysis.spheres import fit_3d_rfs_ipsi

print("Fitting ipsi RF...")
# Ipsi side = FIRST half of the azimuth axis, i.e. the complement of the half used
# for the fit above. Azimuth is the last axis, so slice with `...` and shape[-1];
# `frames_all[:, :, : shape[2] // 2]` halves the elevation axis when frames are 4D
# (depth, frame, elevation, azimuth).
coef_ipsi, r2_ipsi = spheres.fit_3d_rfs_ipsi(
    imaging_df_all,
    frames_all[..., : frames_all.shape[-1] // 2],
    best_reg_xys,
    best_reg_depths,
    shift_stim=2,
    use_col="dffs",
    k_folds=5,
    validation=False,
)

In [None]:
# Swap the depth and frame axes of the frames array and make it 2D, with shape
# (n_frames, n_depths * nelev * nazi)
# NOTE: this rebinds `frames`; cells expecting the 4D array must be run before it.
frames = np.swapaxes(frames, 0, 1)
frames = frames.reshape(frames.shape[0], -1)
frames.shape

In [None]:
# Allocate an empty design matrix of shape (n_frames, nelev * nazim * n_depths).
# NOTE(review): this indexes frames.shape[2] and [3], so it needs the 4D `frames`
# (depth, frame, elevation, azimuth) and fails if `frames` was already flattened.
assert depths.shape[0] == frames.shape[0]
X = np.zeros((frames.shape[1], frames.shape[2] * frames.shape[3] * depths.shape[0]))
print(X.shape)

In [None]:
# Running counter that increases on every frame where the depth changes (non-zero
# difference with the previous frame) while the new depth is positive.
# (The previous `np.zeros_like` initialisation was dead code: it was overwritten
# immediately.)
trial_idx = np.cumsum(
    np.logical_and(np.abs(imaging_df.depth.diff()) > 0, imaging_df.depth > 0)
)
trial_idx

In [None]:
# Reconstruct the stimulus frames separately for every depth, from the "NewParams"
# csv of that depth. Results are stored in dicts keyed by depth.
n_spheres = {}
frame_reconstruction = {}
for csv_id, csv_file in harp_dataset.extra_attributes["csv_files"].items():
    if not csv_id.startswith("NewParams"):
        # not a parameter file
        continue
    # depth = last "_"-separated token without its last two characters
    # (presumably a "cm" suffix -- TODO confirm)
    depth = int(csv_id.split("_")[-1][:-2])
    param_log = pd.read_csv(harp_dataset.path_full / csv_file)
    # NOTE: the azimuth limits are hard-coded to (20, 120) here and differ from the
    # `azimuth_limits` variable defined elsewhere in the notebook
    frames, n_spheres_per_frame = regenerate_frames(
        frame_times=imaging_df.imaging_harptime,
        trials_df=trials_df,
        vs_df=vs_df,
        param_logger=param_log,
        time_column="HarpTime",
        resolution=resolution,
        sphere_size=sphere_size,
        azimuth_limits=(20, 120),
        elevation_limits=(-40, 40),
        verbose=False,
        output_datatype="int8",
        output=None,
        return_sphere_number=True,
        # flip_x=True,
    )
    # True for frames with at least one non-zero pixel (not used further in this cell)
    has_px = frames.sum(axis=(1, 2)) > 0
    frame_reconstruction[depth] = frames
    n_spheres[depth] = n_spheres_per_frame
    print(f"Depth {depth}: {len(frames)} frames, {n_spheres_per_frame.sum()} spheres")

In [None]:
# Read every RewardLog / NewParams csv of the harp dataset. For each NewParams log,
# also store the HarpTime of the rows just before Radius switches to -9999.
harp_csvs = harp_dataset.extra_attributes["csv_files"]
rewards_logs = {}
newparams_logs = {}
trial_ends = {}
for csvname, filename in harp_csvs.items():
    if "RewardLog" in csvname:
        rewards_logs[csvname] = pd.read_csv(harp_dataset.path_full / filename)
        continue
    if "NewParams" not in csvname:
        continue
    newparams_log = pd.read_csv(harp_dataset.path_full / filename)
    newparams_logs[csvname] = newparams_log
    depth = int(csvname.split("_")[-1][:-2])
    radius = newparams_log["Radius"].values.astype(float)
    r = (radius == -9999).astype(int)
    # True on the row preceding each 0 -> 1 transition of `r`
    corridor_end = np.diff(r) == 1
    trial_ends[depth] = newparams_log.iloc[:-1][corridor_end].HarpTime


print(
    f"Found {len(rewards_logs)} RewardLog files and {len(newparams_logs)} NewParams files"
)

In [None]:
import matplotlib.pyplot as plt

# One scatter series per depth: number of spheres in each reconstructed frame
depths = sorted(frame_reconstruction)
n_per_frame = np.vstack([n_spheres[depth] for depth in depths])
fig = plt.figure()
ax = plt.subplot(111)
frame_ids = np.arange(n_per_frame.shape[1])
for depth, sphere_counts in zip(depths, n_per_frame):
    ax.scatter(frame_ids, sphere_counts, label=f"Depth {depth}")
ax.legend(loc="upper right")
ax.set_title("Number of Spheres per Frame at Different Depths")
ax.set_xlabel("Frame Index")
ax.set_ylabel("Number of Spheres")

In [None]:
print(frame_reconstruction[depths[0]].shape)
# Stack the per-depth reconstructions along a new leading (depth) axis
all_frames = np.stack([frame_reconstruction[depth] for depth in depths], axis=0)

In [None]:
# Total number of spheres (summed over depths) against time, with the trial ends of
# the first depth overlaid as dark red dots at y = 4.
t0 = imaging_df.imaging_harptime.min()
frame_times = imaging_df.imaging_harptime
m = n_per_frame.sum(axis=0)
tend = trial_ends[depths[0]].values - t0
plt.scatter(tend, np.zeros_like(tend) + 4, color="darkred", label="Trial end")
plt.plot(frame_times - t0, m)
plt.ylim(0, 8)
plt.xlim(2500, 2700)
plt.xlabel("Time (s)")
plt.ylabel("Number of spheres")
plt.legend(loc="upper right")

In [None]:
frame_reconstruction[10][5].max()

In [None]:
plt.imshow(frame_reconstruction[10][500])

In [None]:
# Show one reconstructed frame (`frame_id`) for every depth, one row per depth
ndepths = len(depths)
# squeeze=False always returns a 2D array of axes, so the indexing below also
# works with a single depth (plt.subplots would otherwise return a bare Axes)
fig, axes = plt.subplots(ndepths, 1, figsize=(4, 10), squeeze=False)
axes = axes[:, 0]
frame_id = 5
for idepth, depth in enumerate(depths):
    ax = axes[idepth]
    ax.imshow(all_frames[idepth, frame_id], vmin=0, vmax=1, interpolation="none")
    ax.set_ylabel(f"{depth}cm")
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()

In [None]:
# Fetch the session entity and its recording whose protocol is
# "SpheresPermTubeReward_multidepth".
# NOTE(review): `flm_sess` is not defined in the visible cells (they use
# `flexilims_session`); it must exist in the notebook namespace to run this.
sess = flz.get_entity(name=session_name, datatype="session", flexilims_session=flm_sess)
recording = flz.get_entity(
    origin_id=sess.id,
    datatype="recording",
    query_key="protocol",
    query_value="SpheresPermTubeReward_multidepth",
    flexilims_session=flm_sess,
)
print("Recording:", recording.name)

In [None]:
# Fetch the single harp dataset attached to the recording
harp_dataset = flz.get_datasets(
    origin_id=recording.id,
    dataset_type="harp",
    flexilims_session=flm_sess,
    allow_multiple=False,
)
print("HARP dataset:", harp_dataset.full_name)
# Dump every extra attribute of the dataset for inspection
for k, v in harp_dataset.extra_attributes.items():
    print(f"{k}: {v}")

In [None]:
from cottage_analysis.preprocessing import synchronisation

photodiode_protocol = 5

# Build vs_df for the multidepth recording. The same recording is passed as both
# the visual stimulation recording and the harp recording; no onix recording.
vs_df = synchronisation.generate_vs_df(
    recording=recording,
    photodiode_protocol=photodiode_protocol,
    flexilims_session=flm_sess,
    harp_recording=recording,
    onix_recording=None,
    project=project,
    protocol_base="SpheresPermTubeReward_multidepth",
    sync_kwargs={},
)

In [None]:
vs_df.head()

# Regenerating frames

We want to regenerate the stimulus

In [None]:
# Parameters of the frame reconstruction, used by the following cells
# (presumably in degrees -- TODO confirm)
resolution = 10
sphere_size = 10
azimuth_limits = (-120, 120)
elevation_limits = (-40, 40)

from cottage_analysis.io_module.visstim import get_param_log

# Parameter log of the recording, read with multidepth=True; the last expression
# displays its shape
param_log = get_param_log(
    flexilims_session=flm_sess,
    harp_recording=recording,
    vis_stim_recording=recording,
    multidepth=True,
)
param_log.shape

In [None]:
param_log.columns

In [None]:
# Find the frame we want to reconstruct, i.e. those with something on the screen

# TODO: define `frame_times` here. The statement was left unfinished
# (`frame_times = ` with no right-hand side), which is a SyntaxError and prevents
# the cell from running; it is commented out until the selection is written.
# frame_times = ...

In [None]:
import numpy as np

# Distinct positive Radius values found in the parameter log, in increasing order
valid_depth = sorted(param_log.loc[param_log.Radius > 0, "Radius"].unique())

# Pixel coordinates: regular grid from the lower to the upper limit (inclusive)
azi_min, azi_max = azimuth_limits
ele_min, ele_max = elevation_limits
azi_pixels = np.arange(azi_min, azi_max + 1, resolution)
ele_pixels = np.arange(ele_min, ele_max + 1, resolution)
frame_shape = (len(valid_depth), len(azi_pixels), len(ele_pixels))
frame_shape

In [None]:
# Top: Frameindex of rows 300-499 of the parameter log.
# Bottom: the negative steps between consecutive rows (frames going backwards).
frame_index_chunk = param_log["Frameindex"].values[300:500]
frame_index_step = np.diff(frame_index_chunk)
ax = plt.subplot(2, 1, 1)
plt.plot(frame_index_chunk, "o")
plt.subplot(2, 1, 2, sharex=ax)
bad = frame_index_step < 0
plt.plot(frame_index_step[bad], "o")

In [None]:
# Emulate the arguments of synchronisation.generate_vs_df so that its body can be
# run step by step in the following cells.
flexilims_session = flm_sess
photodiode_protocol = 5
protocol_base = "SpheresPermTubeReward_multidepth"
harp_recording = recording
onix_recording = None
conflicts = "skip"
sync_kwargs = None

import warnings
import numpy as np
import pandas as pd


from cottage_analysis.utilities.misc import get_str_or_recording

from cottage_analysis.io_module.harp import load_harpmessage
from cottage_analysis.io_module.visstim import get_frame_log, get_param_log
from cottage_analysis.io_module.spikes import (
    load_kilosort_folder,
    get_smoothed_spike_rate,
)
from cottage_analysis.preprocessing import find_frames
from cottage_analysis.imaging.common.find_frames import find_imaging_frames
from cottage_analysis.imaging.common import imaging_loggers_formatting as format_loggers


# First step: monitor frames for the recording
monitor_frames_df = synchronisation.find_monitor_frames(
    vis_stim_recording=recording,
    flexilims_session=flexilims_session,
    photodiode_protocol=photodiode_protocol,
    harp_recording=harp_recording,
    onix_recording=onix_recording,
    conflicts=conflicts,
    sync_kwargs=sync_kwargs,
)

In [None]:
import ast

# Keep the monitor frames that were matched to a frame index, in increasing order
monitor_frames_df = monitor_frames_df[monitor_frames_df.closest_frame.notnull()].copy()
monitor_frames_df = find_frames.remove_frames_in_wrong_order(monitor_frames_df)
monitor_frames_df["closest_frame"] = monitor_frames_df["closest_frame"].astype("int")
harp_ds = flz.get_datasets(
    flexilims_session=flexilims_session,
    origin_name=harp_recording.name,
    dataset_type="harp",
    allow_multiple=False,
    return_dataseries=False,
)
# `csv_files` is sometimes stored as the string representation of a dict.
# ast.literal_eval only parses Python literals, unlike eval which would execute
# arbitrary code coming from the database.
csv_files = harp_ds.extra_attributes["csv_files"]
if isinstance(csv_files, str):
    harp_files = ast.literal_eval(csv_files)
else:
    harp_files = csv_files


# Merge MouseZ and EyeZ from FrameLog.csv to frame_df according to FrameIndex
frame_log = get_frame_log(
    harp_ds.flexilims_session,
    harp_recording=harp_recording,
    vis_stim_recording=recording,
)


# same for SpherePermTubeReward and SpherePermTubeReward_multidepth
frame_log_z = frame_log[["FrameIndex", "HarpTime", "MouseZ", "EyeZ"]].rename(
    columns={
        "FrameIndex": "closest_frame",
        "HarpTime": "harptime_framelog",
        "MouseZ": "mouse_z",
        "EyeZ": "eye_z",
    }
)

n_missing = int(frame_log_z["closest_frame"].isna().sum())
if n_missing:
    print(
        f"WARNING: {n_missing} frames are "
        + "missing from FrameLog.csv. This is likely due to bonsai crash at "
        + "the end."
    )
    # .copy() so the assignment below acts on an independent frame and does not
    # trigger a SettingWithCopyWarning
    frame_log_z = frame_log_z[frame_log_z["closest_frame"].notnull()].copy()
    frame_log_z["closest_frame"] = frame_log_z["closest_frame"].astype("int")

merge_on = "closest_frame"

frame_log_z["mouse_z"] = frame_log_z["mouse_z"] / 100  # convert cm to m
frame_log_z["eye_z"] = frame_log_z["eye_z"] / 100  # convert cm to m

if monitor_frames_df[merge_on].dtype != frame_log_z[merge_on].dtype:
    # print a warning if the merge_on column is not the same type in both dataframes
    warnings.warn(
        f"WARNING: merge_on column {merge_on} is not the same type in both "
        + f"dataframes. monitor_frame_df is {monitor_frames_df[merge_on].dtype} and "
        + f"frame_log_z is {frame_log_z[merge_on].dtype}. Converting to int64."
    )
    # convert both to int64
    monitor_frames_df[merge_on] = monitor_frames_df[merge_on].astype("int64")
    frame_log_z[merge_on] = frame_log_z[merge_on].astype("int64")

# For every monitor frame, take the last frame-log row whose frame index is <= the
# monitor frame index
vs_df = pd.merge_asof(
    left=monitor_frames_df[["closest_frame", "onset_time"]],
    right=frame_log_z,
    on=merge_on,
    direction="backward",
    allow_exact_matches=True,
)

In [None]:
vs_df.head()

In [None]:
# Read every RewardLog / NewParams csv of the harp dataset into two dicts keyed by
# the csv id; other csv files are ignored
harp_csvs = harp_dataset.extra_attributes["csv_files"]
rewards_logs, newparams_logs = {}, {}

for csvname, filename in harp_csvs.items():
    if "RewardLog" in csvname:
        target_logs = rewards_logs
    elif "NewParams" in csvname:
        target_logs = newparams_logs
    else:
        continue
    target_logs[csvname] = pd.read_csv(harp_dataset.path_full / filename)
print(
    f"Found {len(rewards_logs)} RewardLog files and {len(newparams_logs)} NewParams files"
)

In [None]:
vs_df.columns

In [None]:
import matplotlib.pyplot as plt

# Raster of the reward times, one row per reward log, relative to the first
# frame-log time of vs_df
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
t0 = vs_df.harptime_framelog.min()
for idepth, (csv_name, df) in enumerate(rewards_logs.items()):
    # depth = last "_"-separated token of the csv id without its last two characters
    depth = int(csv_name.split("_")[-1][:-2])
    rw_time = df["HarpTime"].astype(float) - t0
    ax.plot(rw_time, np.ones_like(rw_time) * idepth, "|", label=depth)

ax.legend()
ax.set_yticks([])
ax.set_xlabel("Time (s)")
ax.set_xlim(2900, 3200)
fig.tight_layout()
ax.set_title("Reward times")

In [None]:
# log(Radius) of every NewParams row against time, one series per depth. Negative
# radii are masked (NaN) for the plot.
fig, ax = plt.subplots(1, 1, figsize=(15, 5))

for idepth, (csv_name, df) in enumerate(newparams_logs.items()):
    depth = int(csv_name.split("_")[-1][:-2])
    radius = df["Radius"].values.astype(float)
    radius[radius < 0] = np.nan
    rw_time = df["HarpTime"].astype(float) - t0
    ax.plot(rw_time, np.log(radius), "|", label=depth)
    # count the transitions of Radius to -9999
    r = (df.Radius == -9999).astype(int)
    corridor_end = np.diff(r) == 1
    # NOTE(review): the division by 8 presumably reflects 8 corridor ends per
    # trial -- confirm
    print(
        f"Depth {depth}: {np.sum(corridor_end)} corridor ends, {np.sum(corridor_end)/8} trials"
    )
ax.set_xlim(2900, 3200)
ax.legend(loc="upper right")
ax.set_xlabel("Time (s)")
ax.set_ylabel("log(Radius)")
ax.set_title("Sphere creation times")
fig.tight_layout()

In [None]:
df.columns

In [None]:
# Z0 column of every NewParams log against time, one line per depth
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
# NOTE(review): needs a `vs_df` that has a `monitor_harptime` column; the vs_df
# built by the manual merge_asof cell only has `onset_time` at that point -- confirm
# which vs_df is expected here.
t0 = vs_df.monitor_harptime.min()
for idepth, (csv_name, df) in enumerate(newparams_logs.items()):
    depth = int(csv_name.split("_")[-1][:-2])
    rw_time = df["HarpTime"].astype(float) - t0
    ax.plot(rw_time, df.Z0, label=depth)
ax.legend()

In [None]:
# Build one vs_df per depth by merging the NewParams log of that depth onto the
# monitor frames. The result is a dict keyed by depth.
vs_df_by_depth = {}
for csv_name, depth_param_log in newparams_logs.items():
    # The depth must be parsed for every log (same convention as the plotting
    # cells). Without it, every result was stored under whatever `depth` an earlier
    # cell had left behind, so only one entry survived.
    depth = int(csv_name.split("_")[-1][:-2])
    # A dedicated loop variable is used so the global `param_log` is not clobbered
    depth_param_log = depth_param_log.rename(columns={"HarpTime": "stimulus_harptime"})
    if (
        "Frameindex" in depth_param_log.columns
        and depth_param_log.Frameindex.isna().any()
    ):
        print(
            f"WARNING: {np.sum(depth_param_log.Frameindex.isna())} frames are missing from ParamLog.csv. This is likely due to bonsai crash at the end."
        )
        # .copy() so the dtype conversion acts on an independent frame
        depth_param_log = depth_param_log[depth_param_log.Frameindex.notnull()].copy()
        depth_param_log["Frameindex"] = depth_param_log["Frameindex"].astype("int")

    # For every monitor frame, take the last parameter row whose Frameindex is <=
    # the monitor frame index
    vs_df_depth = pd.merge_asof(
        left=vs_df,
        right=depth_param_log,
        left_on="closest_frame",
        right_on="Frameindex",
        direction="backward",
        allow_exact_matches=True,
    )
    # Rename, then drop the intermediate time columns that may be present
    vs_df_depth = vs_df_depth.rename(
        columns={
            "closest_frame": "monitor_frame",
            "onset_time": "monitor_harptime",
        }
    ).drop(
        columns=[
            "harptime_framelog",
            "harptime_sphere",
            "harptime_imaging_trigger",
            "offset_time",
            "peak_time",
        ],
        errors="ignore",
    )
    vs_df_by_depth[depth] = vs_df_depth

In [None]:
# Direct call to regenerate_frames at resolution 1 over the full azimuth range.
# NOTE(review): `param_logger` is not defined in the visible cells and the return
# value is not stored -- this looks like a scratch call.
regenerate_frames(
    frame_times,
    trials_df,
    vs_df,
    param_logger,
    time_column="HarpTime",
    resolution=1,
    sphere_size=10,
    azimuth_limits=(-120, 120),
    elevation_limits=(-40, 40),
    verbose=True,
    output_datatype="int16",
    output=None,
)

In [None]:
# Rows where Radius is -9999 (`r`) and the 0 -> 1 transitions of `r`
# (`corridor_end`), for the last `df` loaded above; first 100 rows shown
r = (df.Radius == -9999).astype(int)
corridor_end = np.diff(r) == 1
n_corridor_ends = np.sum(corridor_end)
plt.plot(r)
plt.plot(corridor_end)
print(n_corridor_ends)
print(n_corridor_ends / 8)
plt.xlim(0, 100)

In [None]:
        # NOTE(review): this snippet is indented as if pasted from inside a function;
        # as a stand-alone cell it raises an IndentationError. It shows the
        # single-log counterpart of the per-depth merge above.
        param_log = get_param_log(
            flexilims_session=flexilims_session,
            harp_recording=harp_recording,
            vis_stim_recording=recording,
        )
        # TODO COPY FROM RAW AND READ FROM PROCESSED INSTEAD
        param_log = param_log.rename(columns={"HarpTime": "stimulus_harptime"})
        # Drop the rows without a frame index and make the column integer
        if "Frameindex" in param_log.columns:
            if param_log.Frameindex.isna().any():
                print(
                    f"WARNING: {np.sum(param_log.Frameindex.isna())} frames are missing from ParamLog.csv. This is likely due to bonsai crash at the end."
                )
                param_log = param_log[param_log.Frameindex.notnull()]
                param_log.Frameindex = param_log.Frameindex.astype("int")

        # TODO: check if that shouldn't also happen for protocol_base == "KellerTube"
        # For every monitor frame, take the last parameter row whose Frameindex is <=
        # the monitor frame index
        if photodiode_protocol == 5:
            vs_df = pd.merge_asof(
                left=vs_df,
                right=param_log,
                left_on="closest_frame",
                right_on="Frameindex",
                direction="backward",
                allow_exact_matches=True,
            )
