# Preprocessing of revision experiments

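This notebook contains the preprocessing steps for the revision experiments. It lists
the expected sessions, checks that they exist on flexilims with the required recordings
and a suite2p dataset, and exports reference images (mean, max, percentile projections)
for cellpose.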

In [None]:
# Expected sessions grouped by mouse, as (session date, protocol) pairs in
# processing order. They are flattened below into a single
# {"<mouse>_<date>": protocol} mapping, keeping this order.
_sessions_by_mouse = {
    "PZAG17.3a": [
        ("S20250402", "motor"),
        ("S20250319", "multidepth"),
        ("S20250306", "spheretube_5"),
        ("S20250305", "spheretube_4"),
        ("S20250303", "spheretube_3"),
        ("S20250228", "spheretube_2"),
        ("S20250227", "spheretube_1"),
    ],
    "PZAG16.3c": [
        ("S20250401", "motor"),
        ("S20250317", "multidepth"),
        ("S20250313", "spheretube_5"),
        ("S20250310", "spheretube_4"),
        ("S20250221", "spheretube_3"),
        ("S20250220", "spheretube_2"),
        ("S20250219", "spheretube_1"),
    ],
    "PZAG16.3b": [
        ("S20250401", "motor"),
        ("S20250317", "multidepth"),
        ("S20250313", "spheretube_5"),
        ("S20250310", "spheretube_4"),
        ("S20250226", "spheretube_3"),
        ("S20250225", "spheretube_2"),
        ("S20250224", "spheretube_1"),
    ],
    "PZAH17.1e": [
        ("S20250403", "motor"),
        ("S20250318", "multidepth"),
        ("S20250313", "multidepth"),
        ("S20250311", "spheretube_5"),
        ("S20250307", "spheretube_4"),
        ("S20250306", "spheretube_3"),
        ("S20250305", "spheretube_2"),
        ("S20250304", "spheretube_1"),
    ],
}
sessions = {
    f"{mouse}_{date}": protocol
    for mouse, mouse_sessions in _sessions_by_mouse.items()
    for date, protocol in mouse_sessions
}

print(f"{len(sessions)} sessions to analyze")

In [None]:
import flexiznam as flz

project = "colasa_3d-vision_revisions"
# One flexilims session for the whole notebook. It is passed explicitly to every
# query so they all reuse it instead of each resolving its own from `project_id`.
flm_sess = flz.get_flexilims_session(project_id=project)

# Keep only the expected sessions that exist on flexilims, as
# {session name: [session entity, protocol]}.
valid_sessions = dict()
for session, protocol in sessions.items():
    sess = flz.get_entity(
        name=session,
        project_id=project,
        datatype="session",
        flexilims_session=flm_sess,
    )
    if sess is None:
        print(f"Session {session} doesn't exist")
        continue
    valid_sessions[session] = [sess, protocol]
print(f"{len(valid_sessions)}/{len(sessions)} valid sessions to analyze")

In [None]:
# Now check that we have the recordings we need. Every session needs a
# "SpheresPermTubeReward" recording. Some protocols need one extra recording
# protocol (see the table below). All problems are collected first so that one
# run reports every incomplete session. The cell then fails once if anything is
# missing.
EXTRA_RECORDING_PROTOCOL = {
    "motor": "SpheresTubeMotor",
    "multidepth": "SpheresPermTubeReward_multidepth",
}

missing_protocols = []
for session_name, (session, protocol) in valid_sessions.items():
    # Get recordings children of the session
    recordings = flz.get_children(
        session.id, children_datatype="recording", flexilims_session=flm_sess
    )
    if not len(recordings):
        print(f"No recordings for session {session_name}")
        continue
    recorded_protocols = set(recordings.protocol.values)
    if "SpheresPermTubeReward" not in recorded_protocols:
        missing_protocols.append(
            f"Session {session_name} doesn't have the Sphere protocol"
        )
    extra_protocol = EXTRA_RECORDING_PROTOCOL.get(protocol)
    if extra_protocol is not None and extra_protocol not in recorded_protocols:
        missing_protocols.append(
            f"Session {session_name} doesn't have the {protocol} protocol"
        )

for message in missing_protocols:
    print(message)
assert not missing_protocols, (
    f"{len(missing_protocols)} missing recording protocol(s), see messages above"
)

In [None]:
# There should be a suite2p dataset for each session. Sessions without one are
# reported and dropped from `valid_sessions`.
bad_sessions = []
for session_name, (session, protocol) in valid_sessions.items():
    # Look up the suite2p ROI dataset attached to this session
    suite2p_dataset = flz.get_entity(
        project_id=project,
        datatype="dataset",
        origin_id=session.id,
        query_key="dataset_type",
        query_value="suite2p_rois",
        flexilims_session=flm_sess,
    )
    if suite2p_dataset is None:
        print(f"Session {session_name} doesn't have a suite2p dataset")
        bad_sessions.append(session_name)

print(f"{len(bad_sessions)} sessions don't have a suite2p dataset")
# remove them from valid_sessions (safe: we iterate over bad_sessions, not the dict)
for session_name in bad_sessions:
    del valid_sessions[session_name]
print(f"{len(valid_sessions)} sessions to analyze after removing bad sessions")

In [None]:
# Save the mean and enhanced-mean images and, optionally, max / high-percentile
# projections of the registered movie. Each image is written both in the suite2p
# `plane0` folder and in a shared `cellpose_data` folder.
import numpy as np
import tifffile as tiff

SAVE_MAX = True
SAVE_QUANTILE = False
QUANTILES = [99.9]  # percentiles saved when SAVE_QUANTILE is True


def _write_image(img, file_name, plane_folder, cellpose_folder):
    """Write `img` as `file_name` in both the plane folder and the cellpose folder."""
    tiff.imwrite(plane_folder / file_name, img)
    tiff.imwrite(cellpose_folder / file_name, img)


def _load_registered_movie(plane_folder, ops):
    """Memory-map suite2p's registered `data.bin` as a (nframes, Ly, Lx) array.

    suite2p writes registered frames as int16. Reading them as uint16 would map
    negative pixels to values near 65535 and corrupt max/percentile projections.
    """
    data = np.memmap(plane_folder / "data.bin", dtype="int16", mode="r")
    # reshape raises if the file size does not match nframes * Ly * Lx exactly
    return data.reshape(ops["nframes"], ops["Ly"], ops["Lx"])


def _save_max_projection(session_name, plane_folder, cellpose_folder, ops):
    """Save the max projection over frames unless it already exists in plane_folder."""
    file_name = f"{session_name}_max.tif"
    if (plane_folder / file_name).exists():
        print(f"Max projection already exists for session {session_name}, skipping.")
        return
    max_data = np.max(_load_registered_movie(plane_folder, ops), axis=0)
    _write_image(max_data, file_name, plane_folder, cellpose_folder)


def _save_quantile_projections(
    session_name, plane_folder, cellpose_folder, ops, quantiles
):
    """Save one per-pixel percentile projection (over frames) per entry of `quantiles`.

    Percentiles whose file already exists in `plane_folder` are not recomputed.
    """
    file_names = {q: f"{session_name}_max_q{q:.1f}.tif" for q in quantiles}
    to_compute = [q for q in quantiles if not (plane_folder / file_names[q]).exists()]
    if not to_compute:
        print(
            f"Percentile projection(s) already exist for session {session_name}, "
            "skipping."
        )
        return
    # np.percentile works on a copy of the data: this needs RAM for the whole movie
    projections = np.percentile(
        _load_registered_movie(plane_folder, ops), q=to_compute, axis=0
    )
    for q, img in zip(to_compute, projections):
        _write_image(img, file_names[q], plane_folder, cellpose_folder)


for session_name, (session, protocol) in valid_sessions.items():
    suite2p_dataset = flz.get_entity(
        project_id=project,
        datatype="dataset",
        origin_id=session.id,
        query_key="dataset_type",
        query_value="suite2p_rois",
        flexilims_session=flm_sess,
    )
    suite2p_dataset = flz.Dataset.from_dataseries(suite2p_dataset)
    plane_folder = suite2p_dataset.path_full / "plane0"
    try:
        # ops.npy is a pickled dict written by our own suite2p run (trusted file)
        ops = np.load(plane_folder / "ops.npy", allow_pickle=True).item()
    except EOFError:
        print(f"Could not load ops.npy for session {session_name}, skipping.")
        continue
    cellpose_folder = suite2p_dataset.path_full.parent.parent.parent / "cellpose_data"
    cellpose_folder.mkdir(exist_ok=True)
    for img_name in ["meanImgE", "meanImg"]:
        _write_image(
            ops[img_name], f"{session_name}_{img_name}.tif", plane_folder, cellpose_folder
        )
    # The helpers return (instead of `continue`) when their output already
    # exists, so an existing max projection no longer skips the percentile export.
    if SAVE_MAX:
        print(f"Saving max projection for session {session_name}")
        _save_max_projection(session_name, plane_folder, cellpose_folder, ops)
    if SAVE_QUANTILE:
        quantile_names = ", ".join(f"{q}" for q in QUANTILES)
        print(f"Saving {quantile_names}th percentile projection for session {session_name}")
        _save_quantile_projections(
            session_name, plane_folder, cellpose_folder, ops, QUANTILES
        )
    print(f"Done with session {session_name}")

In [None]:
# To re-run suite2p, see `revisions/run_suite2p.py`.

# It lives in a separate file because suite2p needs its own conda environment.

In [None]:
# Run the full analysis pipeline on one session (mostly for debugging).
# Set RUN_SINGLE_SESSION to False to skip this cell's submission.
from cottage_analysis.pipelines import analysis_pipeline

RUN_SINGLE_SESSION = True
if RUN_SINGLE_SESSION:
    session_name = "PZAH17.1e_S20250313"
    single_session_options = dict(
        conflicts="overwrite",
        photodiode_protocol=5,
        use_slurm=True,
        run_depth_fit=True,
        run_rf=True,
        run_rsof_fit=True,
        run_plot=True,
        protocol_base="SpheresPermTubeReward",
    )
    analysis_pipeline.main(project, session_name, **single_session_options)

In [None]:
# Submit the non-spheretube valid sessions to the analysis pipeline via sbatch.
# Disabled by default: set SUBMIT_SESSIONS to True to submit.
from cottage_analysis.pipelines import pipeline_utils

running = []  # sessions already submitted; they are reported, not resubmitted
sess_to_do = ["PZAG16.3b_S20250313"]  # None to consider every valid session
SUBMIT_SESSIONS = False
if SUBMIT_SESSIONS:
    for session_name in valid_sessions:
        restricted = sess_to_do is not None
        if restricted and session_name not in sess_to_do:
            continue
        protocol = sessions[session_name]
        if protocol.startswith("spheretube"):
            continue
        if session_name in running:
            print(f"Session {session_name} is already running")
            continue
        print(f"Submitting session {session_name} to the pipeline ({protocol})")
        pipeline_utils.sbatch_session(
            project=project,
            session_name=session_name,
            pipeline_filename="run_analysis_pipeline.sh",
            conflicts="overwrite",
            photodiode_protocol=5,
        )

In [None]:
# Scratch: look up the flexilims entity of one problematic session
bad = "PZAH17.1e_S20250318"
bad_entry = valid_sessions[bad]
sess, protocol = bad_entry
sess

In [None]:
# Scratch: load a single dataset directly from its flexilims id
import tifffile

dataset_id = "67ed63e5b99b5006b4e789b5"
ds = flz.Dataset.from_flexilims(id=dataset_id, flexilims_session=flm_sess)
ds

In [None]:
# Read the first tif listed in the dataset's extra attributes and show its shape
tif_names = ds.extra_attributes["tif_files"]
first_tif = ds.path_full / tif_names[0]
img = tifffile.imread(first_tif)
img.shape

In [None]:
# Delete the dataset `ds` from flexilims. This is irreversible and would
# otherwise run on every "Run All" (and fail once the entry is gone), so it only
# runs when explicitly enabled.
CONFIRM_DELETE = False
if CONFIRM_DELETE:
    flm_sess.delete(ds.id)
else:
    print(f"Not deleting dataset {ds.id}; set CONFIRM_DELETE = True to delete it.")