# Preprocessing trial acquisitions (calcium imaging)
* Creating summed images
* Rigid motion correction (phase correlation)
* Deformable motion correction (optical flow)



In [5]:
import numpy as np
import os

from matplotlib import pyplot as plt
import napari

from math import sqrt
from scipy.ndimage import shift
from skimage.registration import phase_cross_correlation, optical_flow_ilk
from skimage.transform import warp
from skimage.exposure import match_histograms

from scripts.sample_db import SampleDB
from scripts.config_model import update_experiment_config, save_experiment_config, tree
from scripts.utils.image_utils import load_tiff_as_hyperstack, save_array_as_hyperstack_tiff
from tifffile import imwrite, imread

# Step 1: Load the sample database
db_path = r'\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\sample_db.csv'
sample_db = SampleDB()
sample_db.load(db_path)
print(sample_db)

# Step 2: Load experiment configuration
sample_id = '20220426_RM0008_130hpf_fP1_f3'
exp = sample_db.get_sample(sample_id)
tree(exp)

# Step 3: Making shortcuts of sample parameters/information
sample = exp.sample
root_path = exp.paths.root_path
trials_path = exp.paths.trials_path
anatomy_path = exp.paths.anatomy_path
em_path = exp.paths.em_path
n_planes = exp.params_lm.n_planes
n_frames =  exp.params_lm.n_frames
n_slices = exp.params_lm.lm_stack_range
doubling = 2 if exp.params_lm.doubling else 1

# Calculating number of frames per trial (TODO: add it to config file)
n_frames_trial = n_frames // n_planes
exp.params_lm["n_frames_trial"]= n_frames_trial

# Getting paths of the trial acquisitions
raw_trial_paths = os.listdir(os.path.join(trials_path,"raw"))
print(raw_trial_paths)
n_trials = len(raw_trial_paths)
exp.params_lm["n_trials"]= n_trials

# Step 4: Load or skip loading trial acquisitions and computing sum (from ignore frame)
ignore_until_frame = exp.params_lm.shutter_delay_frames # edit if you want to avoid summing out motor movements at the beginning of acquisition

# Define the path for the preprocessed folder
processed_folder = os.path.join(trials_path, "processed")
os.makedirs(processed_folder, exist_ok=True)

ref_images_path = os.path.join(processed_folder, f"sum_raw_trials_{sample.id}.tif")

if os.path.exists(ref_images_path):
    ref_images = imread(ref_images_path)
    print("Reference trial images loaded")
else:
    print("Starting sum of trials")
    ref_images = np.stack([load_tiff_as_hyperstack(os.path.join(trials_path, "raw", trial_path), n_channels=1, n_slices=n_planes, doubling=True)[:,ignore_until_frame:,:,:].sum(axis=1) for trial_path in raw_trial_paths], axis=1)
    save_array_as_hyperstack_tiff(ref_images_path, ref_images)

print(ref_images.shape)

# Step 5: Compute phase correlation for each frame against the reference or load existing parameters
rigid_params_path = os.path.join(processed_folder, "rigid_params.npy")

if os.path.exists(rigid_params_path):
    rigid_params = np.load(rigid_params_path)
    print("Rigid parameters loaded")
else:
    print("Starting rigid parameters computation")
    rigid_params = np.zeros((n_planes*doubling, n_trials, 2))
    total_motion = np.zeros((n_planes*doubling, n_trials))

    for plane in range(n_planes*doubling):
        print(f"Processing plane {plane}")
        for ii in range(n_trials):
            X = phase_cross_correlation(ref_images[plane, 0, :, :], ref_images[plane, ii, :, :], upsample_factor=5, space='real')
            rigid_params[plane, ii, 0] = X[0][0]  # x-displacement
            rigid_params[plane, ii, 1] = X[0][1]  # y-displacement
            total_motion[plane, ii] = sqrt(X[0][0] ** 2 + X[0][1] ** 2)

    # Plot total motion and shifts
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    for plane in range(n_planes*doubling):
        plt.plot(total_motion[plane,:], label=f'Plane {plane}')
    plt.title("Total Motion Over Time")
    plt.xlabel("Frame Index")
    plt.ylabel("Total Motion (Euclidean Distance)")
    plt.legend()
    plt.show()

    np.save(rigid_params_path, rigid_params)


# Step 6: Align trial images to reference or load existing aligned frames
aligned_frames_path = os.path.join(processed_folder, f"sum_rigid_corrected_trials_{exp.sample.id}.tiff")

if os.path.exists(aligned_frames_path):
    aligned_frames = imread(aligned_frames_path)
    print("Aligned frames loaded")
else:
    print("Starting frame alignment")
    aligned_frames = np.zeros_like(ref_images)

    for plane in range(n_planes*doubling):
        print(f"Aligning plane {plane}")
        for trial in range(n_trials):
            current_frame = ref_images[plane, trial, :, :]
            shift_values = (rigid_params[plane, trial])
            if len(shift_values) != current_frame.ndim:
                raise ValueError("shift_values length must match the number of dimensions of current_frame")

            shifted_frame = shift(current_frame, shift_values, order=3, prefilter=True)
            aligned_frames[plane, trial, :, :] = match_histograms(shifted_frame, ref_images[plane, trial, :, :])

    save_array_as_hyperstack_tiff(aligned_frames_path, aligned_frames)

# Visualize the alignment
viewer = napari.Viewer()
viewer.add_image(ref_images, name='ref_frames')
viewer.add_image(aligned_frames, name='aligned_images')

# Step 7: Compute the optical flow with reference to the first image or load existing parameters
elastic_params_path = os.path.join(processed_folder, "elastic_params.npy")

if os.path.exists(elastic_params_path):
    elastic_params = np.load(elastic_params_path)
    print("Elastic parameters loaded")
else:
    print("Starting elastic parameters computation")
    elastic_params = np.zeros((n_planes*doubling, n_trials, 2, ref_images.shape[-2], ref_images.shape[-1]))
    warped_movie = np.zeros_like(aligned_frames)

    for plane in range(n_planes*doubling):
        ref_image = aligned_frames[plane, 0]
        nr, nc = ref_image.shape
        print(f"Starting plane {plane}")

        for trial in range(n_trials):
            print(f"Warping frame {trial+1}/{n_trials}")
            frame = aligned_frames[plane, trial, :, :]
            v, u = optical_flow_ilk(ref_image, frame, radius=15)
            row_coords, col_coords = np.meshgrid(np.arange(nr), np.arange(nc), indexing='ij')
            warped_frame = warp(frame, np.array([row_coords + v, col_coords + u]), mode='edge')
            warped_frame = match_histograms(warped_frame, ref_image)
            warped_movie[plane, trial, :, :] = warped_frame
            elastic_params[plane, trial, 0] = v
            elastic_params[plane, trial, 1] = u

    np.save(elastic_params_path, elastic_params)
    elastic_corrected_path = os.path.join(processed_folder, f"sum_elastic_corrected_trials_{exp.sample.id}.tiff")
    save_array_as_hyperstack_tiff(elastic_corrected_path, warped_movie)



SampleDB(sample_ids=['20220426_RM0008_130hpf_fP1_f3'])
sample: <class 'scripts.config_model.Sample'>
    id: <class 'str'>
    parents_id: typing.Optional[str]
    genotype: typing.Optional[str]
    phenotype: typing.Optional[str]
    dof: <class 'str'>
    hpf: <class 'int'>
    body_length_mm: typing.Optional[int]
params_odor: typing.Optional[scripts.config_model.ParamsOdor]
    odor_list: typing.List[str]
    odor_sequence: typing.List[str]
    odor_concentration_uM: typing.List[scripts.config_model.OdorConcentration]
        name: <class 'str'>
        concentration_mM: <class 'float'>
    n_trials: <class 'int'>
    pulse_delay_s: <class 'int'>
    pulse_duration_s: <class 'int'>
    trial_interval_s: <class 'int'>
    missed_trials: typing.List
    events: typing.List[typing.Tuple[str, datetime.datetime]]
params_lm: typing.Optional[scripts.config_model.ParamsLM]
    start_time: <class 'datetime.datetime'>
    end_time: <class 'datetime.datetime'>
    date: typing.Optional[datetim

NameError: name 'warped_movie' is not defined

In [None]:
# Visualize the warped movie
import concurrent.futures

# Step 8: Parallelize the processing of each trial
def process_trial(trial_idx, trial_path):
    print(f"  Processing trial {trial_idx + 1}/{n_trials}")
    raw_movie = load_tiff_as_hyperstack(os.path.join(trials_path, "raw", trial_path), n_slices=exp.params_lm.n_planes, doubling=True)

    n_planes, n_frames, height, width = raw_movie.shape
    print(raw_movie.shape)
    transformed_movie = np.zeros_like(raw_movie, dtype=np.float32)

    for plane in range(n_planes):
        print(f"  Processing plane {plane + 1}/{n_planes}")
        Xs, Ys = rigid_params[plane, trial_idx, 0], rigid_params[plane, trial_idx, 1]
        v, u = elastic_params[plane, trial_idx, 0], elastic_params[plane, trial_idx, 1]

        for frame in range(n_frames):
            shifted_frame = shift(raw_movie[plane, frame, :, :].astype(np.float32), (Xs, Ys), order=3, prefilter=True)
            row_coords, col_coords = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
            warped_frame = warp(shifted_frame, np.array([row_coords + v, col_coords + u]), mode='edge')
            matched_frame = match_histograms(warped_frame, raw_movie[plane, frame, :, :].astype(np.float32))
            transformed_movie[plane, frame, :, :] = matched_frame

    #transformed_movie_uint16 = np.clip(transformed_movie, 0, 65535).astype(np.uint16)
    transformed_trial_path = os.path.join(processed_folder, f"motion_corrected_{trial_path}")
    save_array_as_hyperstack_tiff(transformed_trial_path, transformed_movie)
    print(f"Transformed trial {trial_path} saved at {transformed_trial_path}")

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(process_trial, trial_idx, trial_path) for trial_idx, trial_path in enumerate(raw_trial_paths)]
    for future in concurrent.futures.as_completed(futures):
        future.result()
"""
# Helper function to interleave planes
def interleave_planes(array):
    planes, time, height, width = array.shape
    interleaved_array = array.transpose(1, 0, 2, 3).reshape(planes * time, height, width)
    return interleaved_array

transformed_movie_interleaved = interleave_planes(transformed_movie)
"""

# Visualize the transformed movie
#viewer.add_image(transformed_movie, name='transformed_movie')
#viewer.add_image(raw_movie, name='raw_movie')

# Save experiment configuration
save_experiment_config(exp, exp.paths.config_path)
tree(exp)


  Processing trial 1/24
  Processing trial 2/24
  Processing trial 3/24
  Processing trial 4/24
  Processing trial 5/24
  Processing trial 6/24
  Processing trial 7/24
  Processing trial 8/24
  Processing trial 9/24
  Processing trial 10/24
  Processing trial 11/24
  Processing trial 12/24
\\tungsten-nas.fmi.ch\\tungsten\\scratch\\gfriedri\\montruth\\2P_RawData\\2022-04-26\\f3\trials\raw\20220426_RM0008_130hpf_fP1_f3_t1_o3Ctrl_001_.tif loaded.
(8, 375, 256, 512)
  Processing plane 1/8
  Processing plane 2/8


In [None]:
tree(exp)