In [2]:
import IO
#from dff_helper import process_dff_trials
from pathlib import Path


In [3]:
def load_config(config_path: str | Path) -> dict:
    """
    Load a YAML configuration file into a dictionary.

    Parameters
    ----------
    config_path : str or Path
        Path to the YAML configuration file.

    Returns
    -------
    dict
        Parsed configuration. An empty file yields an empty dict.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    TypeError
        If the top-level YAML value is not a mapping.
    """
    # Imported here so the function does not rely on an `import yaml` that
    # happens to run later in the same cell.
    import yaml

    config_path = Path(config_path)
    # YAML files are UTF-8 by spec; don't rely on the platform default encoding.
    with config_path.open("r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if cfg is None:
        return {}
    if not isinstance(cfg, dict):
        raise TypeError(
            f"Config {config_path} must contain a mapping at top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg

import yaml

# Location of the analysis config.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
CONFIG_PATH = Path(r"C:\Users\tobiasleva\Work\WF_analysis\wf_analysis\config_example.yaml")

cfg = load_config(CONFIG_PATH)

# Session identity and protocols to analyse
dataset_id = cfg["dataset_id"]
session_id = cfg["session_id"]
protocol_list = cfg["protocols"]

# Data locations
raw_root = cfg["paths"]["raw_root"]
processed_root = cfg["paths"]["processed_root"]

# Per-stage settings sections
dff_cfg = cfg["dff"]
roi_cfg = cfg["roi"]
trace_cfg = cfg["trace"]

In [None]:
# Stimulus window bounds, read from the config's "stimulus" section.
start, end = cfg["stimulus"]["start"], cfg["stimulus"]["end"]

In [4]:
from pathlib import Path
import numpy as np
from typing import List, Dict, Any


def load_all_dffs_for_session(
    dataset_id: str,
    session_id: str,
    processed_root: str | Path,
    protocol_list: list[str] | None = None,
    load_arrays: bool = True,
) -> List[Dict[str, Any]]:
    """
    Load all DFF movies for a given animal + session.

    Matches exactly how the dF/F stage writes data:
        processed_root/dataset/session/protocol/trial_id/dff.npy

    Parameters
    ----------
    dataset_id : str
        Animal ID, e.g. "JPCM-08704"
    session_id : str
        Session ID, e.g. "250905_leica"
    processed_root : str or Path
        Root processed directory
    protocol_list : list[str] or None
        If None, auto-discover protocols
    load_arrays : bool
        If False, only paths + metadata are returned

    Returns
    -------
    entries : list of dict
        Each dict contains:
            {
                "protocol": str,
                "trial_id": str,
                "path": Path,
                "dff": np.ndarray | None
            }
    """

    processed_root = Path(processed_root)
    session_dir = processed_root / dataset_id / session_id

    if not session_dir.exists():
        raise FileNotFoundError(f"Session not found: {session_dir}")

    # Auto-discover protocols if not provided
    if protocol_list is None:
        protocol_list = sorted(
            p.name for p in session_dir.iterdir() if p.is_dir()
        )

    entries: List[Dict[str, Any]] = []

    for protocol in protocol_list:
        proto_dir = session_dir / protocol
        if not proto_dir.exists():
            continue

        for dff_path in sorted(proto_dir.glob("*/dff.npy")):
            trial_id = dff_path.parent.name

            entry = {
                "protocol": protocol,
                "trial_id": trial_id,
                "path": dff_path,
                "dff": None,
            }

            if load_arrays:
                entry["dff"] = np.load(dff_path)

            entries.append(entry)

    if not entries:
        print("⚠️ No dff.npy files found.")

    return entries


In [5]:
# Load every trial's dF/F movie for this session. protocol_list is not given,
# so all protocol directories found on disk are loaded.
# NOTE(review): dataset/session/root are repeated here as literals although the
# config cell defines dataset_id, session_id and processed_root — consider
# passing those so the notebook follows the config.
entries = load_all_dffs_for_session(
    dataset_id="JPCM-08704",
    session_id="250905_leica",
    processed_root="Z:/Individual_Folders/Tobi/WF_axonimaging/axonal_imaging_tobi/data/processed",
)

In [6]:
# Unpack the first loaded trial: its dF/F movie (T, H, W) plus identifiers.
movie, protocol, trial_id = (entries[0][key] for key in ("dff", "protocol", "trial_id"))


In [7]:
# Shape of the first trial's dF/F movie, (T, H, W).
movie.shape

(300, 1024, 1024)

In [8]:
# Keep only the movies of `protocol`, the protocol whose raw data is read
# below (read_trial_data(..., protocol_list=[protocol])). Otherwise movies[i]
# and res_list[i] can refer to different trials when the session holds
# several protocols.
movies = [entry["dff"] for entry in entries if entry["protocol"] == protocol]

In [9]:
# Number of dF/F movies selected for the downstream analysis.
len(movies)

10

In [12]:
from IO import read_trial_data
from utils import group_trials_by_stimulus

# 1) Read the raw trial data for the protocol of the loaded movies.
res_list = read_trial_data(
    dataset_id="JPCM-08704",
    session_id="250905_leica",
    base_dir=raw_root,
    protocol_list=[protocol],
)

# res_list and movies are passed below as two parallel lists (presumably paired
# by position), so they must describe the same number of trials.
assert len(res_list) == len(movies), (
    f"{len(res_list)} raw trials vs {len(movies)} dF/F movies: trial lists are misaligned"
)

# 2) Group trials by stimulus.
# NOTE(review): the window is hard-coded here, while the config also defines
# cfg["stimulus"]["start"] / ["end"] — confirm which values are intended.
stim_groups = group_trials_by_stimulus(
    res_list=res_list,
    dffs=movies,
    start=5000,
    end=7000,
)


--- Reading protocol: 22_42_interleaved ---
Reading 250905_133500 (Z:\Individual_Folders\Tobi\WF_axonimaging\axonal_imaging_tobi\data\raw/JPCM-08704/250905_leica/22_42_interleaved\250905_133500\data.h5)
Reading 250905_133530 (Z:\Individual_Folders\Tobi\WF_axonimaging\axonal_imaging_tobi\data\raw/JPCM-08704/250905_leica/22_42_interleaved\250905_133530\data.h5)
Reading 250905_133600 (Z:\Individual_Folders\Tobi\WF_axonimaging\axonal_imaging_tobi\data\raw/JPCM-08704/250905_leica/22_42_interleaved\250905_133600\data.h5)
Reading 250905_133630 (Z:\Individual_Folders\Tobi\WF_axonimaging\axonal_imaging_tobi\data\raw/JPCM-08704/250905_leica/22_42_interleaved\250905_133630\data.h5)
Reading 250905_133700 (Z:\Individual_Folders\Tobi\WF_axonimaging\axonal_imaging_tobi\data\raw/JPCM-08704/250905_leica/22_42_interleaved\250905_133700\data.h5)
Reading 250905_133730 (Z:\Individual_Folders\Tobi\WF_axonimaging\axonal_imaging_tobi\data\raw/JPCM-08704/250905_leica/22_42_interleaved\250905_133730\data.h5)
R

In [21]:
# Movie stored under 'dff_mean' for the 22.0 stimulus group
# (presumably the trial-averaged dF/F movie — confirm in group_trials_by_stimulus).
dff_movie = stim_groups[22.0]['dff_mean']

In [25]:
# Trial indices belonging to the 22.0 stimulus group; used below to select the
# matching per-trial stimulus traces.
stim_groups[22.0]['indices']

array([0, 2, 4, 7, 9])

In [22]:
# One stimulus trace per raw trial, in res_list order. Iterate res_list itself:
# its length need not equal len(entries) (entries may span other protocols).
temps = [res['stim'] for res in res_list]

In [28]:
# Average stimulus trace over the trials of the 22.0 group.
temperature = np.mean(np.array(temps)[stim_groups[22.0]['indices']], axis=0)

In [55]:
from pathlib import Path
import numpy as np
import imageio


def create_temp_movie(
    dat: dict,
    output_file: Path,
    frames_per_second: float,
    cmap: str = "inferno",
):
    """
    Create a movie with:
      - Top subplot (20% height): temperature trace with moving cursor
      - Bottom subplot (80% height): movie frames

    Parameters
    ----------
    dat : dict
        Must contain:
            dat["movie"] : ndarray (T, H, W)
            dat["temp"]  : ndarray (T,)
    output_file : Path
        Output video path (.mp4)
    frames_per_second : float
        Video frame rate
    cmap : str
        Colormap for movie frames
    """

    from io import BytesIO
    import matplotlib.pyplot as plt

    movie = dat["movie"]
    temp = dat["temp"]

    if movie.ndim != 3:
        raise ValueError("dat['movie'] must have shape (T, H, W)")
    if temp.ndim != 1:
        raise ValueError("dat['temp'] must be 1D (T,)")

    n_frames, height, width = movie.shape

    #if len(temp) != n_frames:
     #   raise ValueError("Temperature trace length must match number of frames")

    # -------------------------
    # Figure + layout
    # -------------------------
    fig, axes = plt.subplots(
        2,
        1,
        figsize=(10, 8),
        gridspec_kw={"height_ratios": [1, 4]},  # 20% / 80%
    )

    fig.patch.set_facecolor("black")

    ax_temp, ax_img = axes

    # -------------------------
    # Temperature plot (static)
    # -------------------------
    time = np.arange(temp.shape[0])

    ax_temp.plot(time, temp, color="white", linewidth=1)
    temp_cursor = ax_temp.axvline(0, color="red", linewidth=2)

    ax_temp.set_xlim(0, temp.shape[0])
    ax_temp.set_ylabel("Temp", color="white")
    ax_temp.tick_params(colors="white")
    ax_temp.spines[:].set_color("white")
    ax_temp.set_facecolor("black")

    # -------------------------
    # Movie plot
    # -------------------------
    vmin = 0#float(np.nanmin(movie))
    vmax = 0.03#float(np.nanmax(movie))

    img = ax_img.imshow(movie[0], cmap=cmap, vmin=vmin, vmax=vmax)
    ax_img.axis("off")

    frame_text = fig.suptitle(
        f"Frame 1 / {n_frames}", color="white", fontsize=12
    )

    plt.tight_layout()

    # -------------------------
    # Video writer
    # -------------------------
    writer = imageio.get_writer(
        str(output_file),
        format="FFMPEG",
        fps=frames_per_second,
        codec="libx264",
    )

    # -------------------------
    # Frame loop
    # -------------------------
    for i in range(n_frames):
        img.set_data(movie[i])
        temp_cursor.set_xdata([i/frames_per_second*samplingrate])
        frame_text.set_text(f"Frame {i + 1} / {n_frames}")

        buffer = BytesIO()
        fig.savefig(
            buffer,
            format="png",
            facecolor=fig.get_facecolor(),
            dpi=100,
            bbox_inches="tight",
        )
        buffer.seek(0)

        frame = imageio.v2.imread(buffer)
        buffer.close()

        # Drop alpha channel if present
        if frame.shape[2] == 4:
            frame = frame[:, :, :3]

        writer.append_data(frame)

    writer.close()
    plt.close(fig)


In [56]:
# Number of samples in the averaged 22.0-group trace. This is the trace length,
# not the number of movie frames.
temperature.shape

(15000,)

In [57]:
# Video frame rate (Hz) and sampling rate of the stimulus/temperature trace (Hz).
fps = 20
samplingrate = 1000

dat = {
    "movie": dff_movie,     # (T, H, W)
    "temp": temperature,    # (Nt,) trace at `samplingrate` Hz; Nt differs from T
}

create_temp_movie(
    dat=dat,
    output_file=Path("temp_movie.mp4"),
    frames_per_second=fps,
)




In [58]:
from pathlib import Path
import numpy as np
import imageio


def create_multi_temp_movie(
    dat: dict,
    output_file: Path,
    frames_per_second: float,
    cmap: str = "inferno",
    vmin: float | None = None,
    vmax: float | None = None,
):
    """
    Create a movie with N columns, each containing:
      - Top subplot (20%): stimulus/temperature trace with moving cursor
      - Bottom subplot (80%): movie frames

    Parameters
    ----------
    dat : dict
        Required keys:
            dat["movies"] : list of ndarray, each (T, H, W)
            dat["traces"] : list of ndarray, each (Nt,)
            dat["trace_sampling_rate"] : float (Hz)

        Optional:
            dat["labels"] : list of str (length N)

    output_file : Path
        Output video (.mp4)
    frames_per_second : float
        Video frame rate
    cmap : str
        Colormap for movies
    vmin, vmax : float or None
        Global color limits for all movies (None → auto per movie)
    """

    from io import BytesIO
    import matplotlib.pyplot as plt

    movies = dat["movies"]
    traces = dat["traces"]
    trace_sr = dat["trace_sampling_rate"]
    labels = dat.get("labels", None)

    n_cols = len(movies)

    if len(traces) != n_cols:
        raise ValueError("dat['movies'] and dat['traces'] must have same length")

    if labels is not None and len(labels) != n_cols:
        raise ValueError("dat['labels'] must match number of movies")

    # Basic validation
    n_frames = movies[0].shape[0]
    for i, mov in enumerate(movies):
        if mov.ndim != 3:
            raise ValueError(f"Movie {i} must have shape (T, H, W)")
        if mov.shape[0] != n_frames:
            raise ValueError("All movies must have same number of frames")

    # -------------------------
    # Figure layout
    # -------------------------
    fig = plt.figure(figsize=(6 * n_cols, 8))
    gs = fig.add_gridspec(
        2,
        n_cols,
        height_ratios=[1, 4],   # 20% / 80%
        hspace=0.1,
        wspace=0.05,
    )

    fig.patch.set_facecolor("black")

    axes_trace = []
    axes_img = []

    for col in range(n_cols):
        axes_trace.append(fig.add_subplot(gs[0, col]))
        axes_img.append(fig.add_subplot(gs[1, col]))

    # -------------------------
    # Plot traces (static)
    # -------------------------
    trace_cursors = []

    for i in range(n_cols):
        trace = traces[i]
        time = np.arange(trace.shape[0]) / trace_sr

        ax = axes_trace[i]
        ax.plot(time, trace, color="white", linewidth=1)
        cursor = ax.axvline(0, color="red", linewidth=2)
        trace_cursors.append(cursor)

        ax.set_xlim(time[0], time[-1])
        ax.set_ylabel("Stim", color="white")
        ax.tick_params(colors="white")
        ax.spines[:].set_color("white")
        ax.set_facecolor("black")

        if labels is not None:
            ax.set_title(labels[i], color="white")

    # -------------------------
    # Plot movies
    # -------------------------
    img_handles = []

    for i in range(n_cols):
        movie = movies[i]
        ax = axes_img[i]

        if vmin is None:
            vmin_i = float(np.nanmin(movie))
        else:
            vmin_i = vmin

        if vmax is None:
            vmax_i = float(np.nanmax(movie))
        else:
            vmax_i = vmax

        img = ax.imshow(movie[0], cmap=cmap, vmin=vmin_i, vmax=vmax_i)
        img_handles.append(img)
        ax.axis("off")

    frame_text = fig.suptitle(
        f"Frame 1 / {n_frames}",
        color="white",
        fontsize=12,
    )

    plt.tight_layout()

    # -------------------------
    # Video writer
    # -------------------------
    writer = imageio.get_writer(
        str(output_file),
        format="FFMPEG",
        fps=frames_per_second,
        codec="libx264",
    )

    # -------------------------
    # Frame loop
    # -------------------------
    for f in range(n_frames):
        t_sec = f / frames_per_second

        for i in range(n_cols):
            img_handles[i].set_data(movies[i][f])
            trace_cursors[i].set_xdata([t_sec])

        frame_text.set_text(f"Frame {f + 1} / {n_frames}")

        buffer = BytesIO()
        fig.savefig(
            buffer,
            format="png",
            facecolor=fig.get_facecolor(),
            dpi=100,
            bbox_inches="tight",
        )
        buffer.seek(0)

        frame = imageio.v2.imread(buffer)
        buffer.close()

        if frame.shape[2] == 4:
            frame = frame[:, :, :3]

        writer.append_data(frame)

    writer.close()
    plt.close(fig)


In [59]:
# One column per stimulus group: 22.0 ("Cold") and 42.0 ("Warm").
# Each column holds the group's 'dff_mean' movie and the mean stimulus trace
# over the group's trials.
dat = {
    "movies": [stim_groups[stim]['dff_mean'] for stim in (22.0, 42.0)],
    "traces": [
        np.array(temps)[stim_groups[stim]['indices']].mean(axis=0)
        for stim in (22.0, 42.0)
    ],
    "trace_sampling_rate": samplingrate,
    "labels": ["Cold", "Warm"],
}

In [60]:
# Side-by-side Cold/Warm comparison video. Both panels share one colour scale
# (0 to 0.03) so they are directly comparable; frame rate comes from `fps`.
create_multi_temp_movie(
    dat,
    Path("comparison.mp4"),
    frames_per_second=fps,
    vmin=0,
    vmax=0.03,
)

  plt.tight_layout()
