In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import tadpose as tp

import skimage as ski
from tqdm.auto import tqdm
from scipy import ndimage as ndi
from moviepy import VideoFileClip, clips_array


## Analysis of Straight-Line Movement Episodes

### 1. Find Straight-Line Movement Episodes

To identify episodes in which the frogs moved approximately in a straight line, we first selected the video frames in which the frog's body-axis vector was aligned with its movement vector, indicating active forward motion. This was achieved by projecting the frog's unit displacement vectors onto its body axis, defined by the unit vector from `Tail_Stem` to `Heart_Center`. The resulting forward correlation was denoised using a median filter with a window size of 3 frames (1/20 sec) and thresholded at a **minimum correlation of 0.8**. Following thresholding, the candidate episodes were further refined by requiring a **minimum duration of 90 frames** (1.5 sec), a **minimum average speed of 1.2 cm/sec**, and a **minimum trajectory confinement ratio of 0.95**. The confinement ratio is defined as the ratio of the net (Euclidean) distance between the start and end points of an episode to the total distance traveled along the trajectory.

### 2. Analysis of Straight-Line Movement Episodes

For each straight-line episode found, we computed the features described in the [analysis](../analysis/) section:

* [Basic Locomotion](../doc/locomotion.md)
* [Area explored](../doc/area_explored.md)
* [Angle ranges](../doc/angle_range.md)
* [Angle correlation](../doc/angle_correlation.md)
* [Frequency](../doc/frequency.md)

All statistics for features (such as mean, std, p95, etc) are then computed over each respective episode.



### Functions

In [None]:
def mask2episode(mask, min_duration=60):
    """
    Turn a 1D boolean frame mask into a labeled episode array.

    Runs of True frames are treated as candidate episodes; a run is kept
    only if its length (region area, in frames) is >= `min_duration`.

    Returns a 1D integer label array in which 0 marks background and each
    surviving run carries its own positive label.
    """
    # regionprops needs at least 2D input, so label the mask as a column.
    labeled_column = ski.measure.label(mask[:, None])

    kept = np.zeros_like(mask, dtype=bool)
    for region in ski.measure.regionprops(labeled_column):
        if region.area >= min_duration:
            frame_slice = region.slice[0]  # extent along the frame axis
            kept[frame_slice] = True

    # Relabel so the surviving episodes are numbered consecutively.
    return ski.measure.label(kept)


def episode2bounds(epi):
    """
    List the frame bounds of every episode in a labeled episode array.

    epi: 1D integer label array (0 = background).

    Returns a list of (start, stop) tuples, one per label, taken from the
    region's slice along the frame axis (so `stop` is exclusive).
    """
    bounds = []
    # regionprops needs at least 2D input, hence the added column axis.
    for region in ski.measure.regionprops(epi[:, None]):
        frames = region.slice[0]
        bounds.append((frames.start, frames.stop))
    return bounds


def filter_episodes_by_speed(tad, episodes, min_speed_cm_s=1.2, fps=60):
    """
    Drop episodes whose mean `Heart_Center` speed is below `min_speed_cm_s`.

    tad: object providing `speed(...)`; also passed to
        `tp.utils.calibrate_by_dish`.
    episodes: 1D integer label array (0 = background); modified in-place,
        frames of rejected episodes are reset to 0.
    min_speed_cm_s: threshold the per-episode mean speed is compared to.
    fps: factor used to convert the per-frame speed to a per-second speed.

    Returns None.
    """
    per_frame_speed = tad.speed(
        part="Heart_Center", track_idx=0, pre_sigma=1, sigma=5
    )
    # presumably the pixel -> cm scale for a 14 cm dish; verify against tadpose
    scale = tp.utils.calibrate_by_dish(tad, 14)
    speed_per_sec = per_frame_speed * scale * fps

    episode_ids = np.unique(episodes[episodes > 0])
    for episode_id in episode_ids:
        in_episode = episodes == episode_id
        mean_speed = speed_per_sec[in_episode].mean()
        # Plain "<" on purpose: a NaN mean compares False and is kept.
        if mean_speed < min_speed_cm_s:
            print(f" - Remove episode {episode_id}: Speed={mean_speed:0.3f}")
            episodes[in_episode] = 0


def filter_episodes_by_confinement_ratio(tad, episodes, min_confinment_ratio):
    """
    Remove episodes whose confinement ratio is below `min_confinment_ratio`.

    The confinement ratio of an episode is net_dist / total_dist, where
    net_dist is the straight-line distance between the `Heart_Center`
    locations of the first and last frame of the episode, and total_dist is
    the summed frame-to-frame distance over the same frames. Episodes with
    total_dist == 0 (no movement, or a single frame) get a ratio of 0.0.

    tad: object providing `locs(parts=...)`; the squeezed result is indexed
        as (frame, coordinate).
    episodes: 1D integer label array (0 = background); modified in-place,
        frames of rejected episodes are reset to 0.
    min_confinment_ratio: episodes with a ratio below this value are removed.

    Returns None.
    """
    c_locs = tad.locs(parts=("Heart_Center",)).squeeze()

    for e_id in np.unique(episodes[episodes > 0]):
        in_episode = episodes == e_id
        episode_indices = np.nonzero(in_episode)[0]
        first, last = np.min(episode_indices), np.max(episode_indices)

        net_dist = np.linalg.norm(c_locs[last] - c_locs[first])
        # `last` is inclusive, so slice up to last + 1. Otherwise the final
        # step is missing from the path length while net_dist still reaches
        # the last frame, which overestimates the ratio (it could exceed 1).
        steps = np.diff(c_locs[first : last + 1], axis=0)
        total_dist = np.linalg.norm(steps, axis=1).sum()

        cr = net_dist / total_dist if total_dist > 0 else 0.0
        if cr < min_confinment_ratio:
            print(f" - Remove episode {e_id}: CR={cr:0.3f}")
            episodes[in_episode] = 0


def _unit_rows(vecs):
    """Scale each row of `vecs` to unit length; zero-length rows are left as-is."""
    lengths = np.linalg.norm(vecs, axis=1)
    lengths[lengths == 0] = 1.0  # avoid divide-by-zero
    return vecs / lengths[:, None]


def straight_line_episodes(
    tad, min_corr, min_duration, min_speed_cm_s, min_cofinment_ratio
):
    """
    Find straight-line movement episodes for a Tadpole object.

    Per frame, the unit heading vector from `tad.aligner` is projected onto
    the unit displacement vector of the aligner's central body part. The
    projection is median-filtered (window of 3 frames) and thresholded at
    `min_corr`; runs of at least `min_duration` frames become candidate
    episodes, which are then filtered by mean speed and confinement ratio.

    Returns a labeled episode array (1D integers, 0 = background).
    """
    central_part, _ = tad.aligner.bodyparts_to_align
    center_track = tad.locs(parts=(central_part,), fill_missing=False).squeeze()

    heading_unit = _unit_rows(tad.aligner.heading_vectors)
    displacement_unit = _unit_rows(np.gradient(center_track, axis=0))

    forward_corr = np.vecdot(heading_unit, displacement_unit)
    forward_corr_smooth = ndi.median_filter(forward_corr, 3)

    candidates = mask2episode(
        forward_corr_smooth > min_corr, min_duration=min_duration
    )

    # Both filters zero out rejected episodes in-place.
    filter_episodes_by_speed(tad, candidates, min_speed_cm_s=min_speed_cm_s)
    filter_episodes_by_confinement_ratio(tad, candidates, min_cofinment_ratio)

    # Relabel so the remaining episodes are numbered consecutively.
    return ski.measure.label(candidates)


def generate_montage_movie(file_list, shape):
    """
    Build a grid montage from a list of video files.

    file_list: video file paths, one per grid cell.
    shape: (rows, cols); rows * cols must equal len(file_list).

    The clips are ordered by frame count (longest first) before being laid
    out row by row.

    Returns the clip produced by moviepy's `clips_array`.
    """
    assert len(shape) == 2
    assert int(np.prod(shape)) == len(file_list)

    opened = [VideoFileClip(path) for path in file_list]
    opened.sort(key=lambda clip: clip.n_frames, reverse=True)

    grid = np.reshape(opened, shape)
    return clips_array(grid)


### Find and Export Straight-Line Episodes

In [None]:
# Directory searched (non-recursively) for input movies (*.mp4), and the
# directory that receives the exported episode clips and the episode table.
WT_ROOT = Path("B:/group/sweengrp/Behavior/Behavior_MNV1/MNV1_dataset/Juv/WT")
OUT_DIR = Path("./Juv/")

# Thresholds forwarded to straight_line_episodes().
MIN_CORR = 0.8
MIN_DURATION = 90  # frames
MIN_SPEED_CM_S = 1.2
MIN_CONFINMENT_RATIO = 0.95

# NOTE: no parents=True, so the parent directory of OUT_DIR must exist.
OUT_DIR.mkdir(exist_ok=True)

# One dict per exported episode; turned into a DataFrame after the loop.
df_result = []

# The glob is materialized into a list before it is handed to tqdm
# (presumably so the progress bar knows the total number of movies).
for mov_path in (pbar := tqdm(list(WT_ROOT.glob("*.mp4")))):
    print("Finding SL-Episodes for", mov_path.name)
    pbar.set_description(mov_path.name)
    tad = tp.Tadpole.from_sleap(str(mov_path))
    # Body axis used for the heading vectors: Tail_Stem -> Heart_Center.
    tad.aligner = tp.alignment.RotationalAligner(
        central_part="Tail_Stem", aligned_part="Heart_Center"
    )
    sl_epis = straight_line_episodes(
        tad,
        min_corr=MIN_CORR,
        min_duration=MIN_DURATION,
        min_speed_cm_s=MIN_SPEED_CM_S,
        min_cofinment_ratio=MIN_CONFINMENT_RATIO,
    )

    # ei is a 0-based running index per movie; ei_bounds is (start, stop)
    # as returned by episode2bounds(), i.e. stop is an exclusive slice end.
    for ei, ei_bounds in enumerate(episode2bounds(sl_epis)):
        print(f" - {ei}")
        # NOTE(review): assumes bbox_image_gen treats `frames` as a
        # (start, stop) range rather than a list of two frame indices.
        img_gen = tad.bbox_image_gen(
            frames=ei_bounds,
            center_part="Heart_Center",
            dest_height=300,
            dest_width=300,
        )
        # Clip name: <movie stem>__e<episode index>_<start>-<stop>.mp4
        out_fn = mov_path.stem + f"__e{ei}_{ei_bounds[0]}-{ei_bounds[1]}.mp4"
        out_fn_path = OUT_DIR / out_fn

        tp.utils.write_gen_video(
            fn=out_fn_path,
            gen=img_gen,
        )

        # track_idx and stage are fixed values for this data set.
        df_result.append(
            {
                "video_fn": str(mov_path),
                "track_idx": 0,
                "stage": "Juv",
                "episode_id": ei,
                "episode_start": ei_bounds[0],
                "episode_stop": ei_bounds[1],
            }
        )
# Tab-separated episode table; the DataFrame index is written as the first
# (unnamed) column because index=False is not passed.
tab = pd.DataFrame(df_result)
tab.to_csv(OUT_DIR / "straight_line_episodes.tab", sep="\t")

### Create Overview Movies of all Episodes

In [None]:
def get_square_shape(n, s_max=10):
    """
    Return a roughly square (rows, cols) grid holding at most `n` items.

    rows is floor(sqrt(n)) and cols is n // rows, each capped at `s_max`,
    so rows * cols <= n. Items that do not fit the grid are left for the
    caller to drop.

    n: number of items available.
    s_max: upper bound for both rows and cols.

    Returns (0, 0) when `n` or `s_max` is not positive, instead of failing
    with a ZeroDivisionError on the `n // rows` step.
    """
    if n <= 0 or s_max <= 0:
        return (0, 0)

    s1 = min(int(np.sqrt(n)), s_max)
    s2 = min(n // s1, s_max)
    return (s1, s2)


# Build one montage movie per sub-directory of `root_dir` whose name starts
# with "WT". The montage is written into that same sub-directory.
root_dir = Path(".")

for wt_dir in root_dir.iterdir():
    if not (wt_dir.is_dir() and wt_dir.name.startswith("WT")):
        continue

    print("Generating montage for", wt_dir)
    # The montage is written into wt_dir as an .mp4 as well, so exclude the
    # output of a previous run from the input clips. Largest files first.
    file_list = sorted(
        (
            p
            for p in wt_dir.glob("*.mp4")
            if not p.name.endswith("_montage.mp4")
        ),
        key=lambda p: p.stat().st_size,
        reverse=True,
    )
    n = len(file_list)
    if n == 0:
        # Nothing to arrange; an empty grid cannot be rendered.
        print(" - No episode clips found, skipping")
        continue

    shape = get_square_shape(n, s_max=12)
    # The grid may hold fewer than n clips; keep only as many as fit.
    n = int(np.prod(shape))
    out_clip = generate_montage_movie(file_list[:n], shape=shape)
    # Use the directory *name* for the file name: formatting the full path
    # would duplicate the parent directories when root_dir is not ".".
    out_clip.write_videofile(wt_dir / f"{wt_dir.name}_montage.mp4")


### Analyze Straight-Line Episodes

In [None]:
from main import analyze_episodes
import yaml

# Settings file (YAML) and episode table passed to analyze_episodes().
# Both paths are relative to the current working directory.
cfg_fn = "../MNV1_SETTINGS_final.yaml"
# NOTE(review): the export cell writes this table to
# OUT_DIR / "straight_line_episodes.tab" - confirm the working directory
# (or this path) matches where the table was actually written.
epi_fn = "straight_line_episodes.tab"

# safe_load parses plain YAML only (no arbitrary Python object construction).
with open(cfg_fn, "r") as ymlfile:
    cfg = yaml.safe_load(ymlfile)

analyze_episodes(epi_fn, cfg)