# 5. Motion correction, followed by whitening
https://spikeinterface.readthedocs.io/en/stable/how_to/handle_drift.html

In [None]:
import spikeinterface.full as si
import matplotlib.pyplot as plt
import os
from pathlib import Path
import numpy as np

# --- File locations ---------------------------------------------------
# Project root and the folder of the recording session handled below.
base_folder = Path('D:/Ephys_C2DRG/')
data_folder = Path("D:/Ephys_C2DRG/2023_9_19/")

# NOTE(review): load_extractor() is called without any argument here;
# it presumably needs the location of a previously saved recording —
# confirm and fill in before running.
recording = si.load_extractor()

# --- Parallel-processing options --------------------------------------
# Use as many workers as os.cpu_count() reports.
n_cpus = os.cpu_count()
n_jobs = n_cpus
job_kwargs = {
    "n_jobs": n_jobs,
    "chunk_duration": "1s",
    "progress_bar": True,
}

In [None]:
# Name of the correct_motion preset used by the following cells.
preset = "rigid_fast"

# Library parameters for this preset, kept under their own name so they can be
# inspected next to the custom values (previously this result was overwritten
# immediately by the dict below).
preset_defaults_rf = si.get_motion_parameters_preset(preset)

# Custom settings handed to si.correct_motion as `estimate_motion_kwargs`.
preset_params_rf = {
    "direction": "y",
    "rigid": True,
    "win_shape": "rect",
    "win_step_um": 100.0,
    "win_scale_um": 150.0,
    "win_margin_um": None,
    "method": "dredge_ap",
    "bin_s": 5.0,
}
preset_params_rf

In [None]:
# Run motion correction with the chosen preset; `folder` is passed to
# correct_motion as its output location, and both the motion estimate and
# the motion_info dictionary are requested in addition to the recording.
print("Computing with", preset)
folder = data_folder.joinpath("motion_folder_dataset", preset)
recording_corrected, motion, motion_info = si.correct_motion(
    recording,
    preset=preset,
    folder=folder,
    output_motion=True,
    output_motion_info=True,
    estimate_motion_kwargs=preset_params_rf,
    **job_kwargs,
)

In [None]:
# Load the motion info stored in the preset's folder.
folder = data_folder.joinpath("motion_folder_dataset", preset)
motion_info = si.load_motion_info(folder)

# Draw the motion summary on a dedicated figure, colouring peaks by amplitude
# and decimating the scatter by a factor of 10.
fig = plt.figure(figsize=(14, 8))
si.plot_motion_info(
    motion_info,
    recording,
    figure=fig,
    depth_lim=(-50, 300),
    color_amplitude=True,
    amplitude_cmap="inferno",
    scatter_decimate=10,
)
fig.suptitle(f"{preset=}")

In [None]:
from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks

# Compare peak locations before (left panel) and after (right panel)
# applying the estimated motion to the peaks.
folder = data_folder / "motion_folder_dataset" / preset
motion_info = si.load_motion_info(folder)
motion = motion_info["motion"]
peaks = motion_info["peaks"]
peak_locations = motion_info["peak_locations"]

# Keep only peaks whose sample index lies strictly inside the window
# (time limits are multiplied by the sampling frequency to get sample
# indices), then plot every 5th of them.
sr = recording.get_sampling_frequency()
time_lim0 = 0.0
time_lim1 = 6000.0
mask = (peaks["sample_index"] > int(sr * time_lim0)) & (peaks["sample_index"] < int(sr * time_lim1))
sl = slice(None, None, 5)

# Colour by absolute amplitude, normalised by its 95th percentile.
amps = np.abs(peaks["amplitude"][mask][sl])
amps /= np.quantile(amps, 0.95)
c = plt.get_cmap("inferno")(amps)
color_kargs = dict(alpha=0.5, s=2, c=c)

# Use the original recording here (as in the SpikeInterface drift how-to), so
# this cell depends only on the reloaded motion_info and not on the in-memory
# result of the compute cell.
peak_locations2 = correct_motion_on_peaks(peaks, peak_locations, motion, recording)

# The panels share their y axis (sharey=True), so a single y-range applies to
# both; two different set_ylim calls would just let the last one win.
xlim = (-45, 80)
ylim = (-75, 320)

fig, axs = plt.subplots(ncols=2, figsize=(8, 6), sharey=True)
for ax, locations in zip(axs, (peak_locations, peak_locations2)):
    si.plot_probe_map(recording, ax=ax)
    ax.scatter(locations["x"][mask][sl], locations["y"][mask][sl], **color_kargs)
    ax.set_ylim(*ylim)
    ax.set_xlim(*xlim)
fig.suptitle(f"{preset=}")