In [22]:
import stac_mjx 
from pathlib import Path
import h5py
from tqdm import tqdm
import numpy as np

In [9]:
# Paths to original and clipped files
original_file_path = "/root/vast/eric/CVAT_mouse_reach/triangulate_optimize/joshua_data/points3d_keewui_params.h5"
clipped_file_path = "/root/vast/eric/CVAT_mouse_reach/triangulate_optimize/joshua_data/points3d_keewui_params_clipped.h5"

# Define the frame range
start_frame = 7000
end_frame = 8000

# Open the original HDF5 file and read the dataset
with h5py.File(original_file_path, "r") as original_h5:
    original_dataset = original_h5["tracks"]
    total_frames = end_frame - start_frame

    # Allocate memory for the sliced frames
    truncated_data = np.empty(
        (total_frames,) + original_dataset.shape[1:],  # preserve all dims except the first
        dtype=original_dataset.dtype
    )

    # Use tqdm to show progress as frames are copied
    for i, frame_idx in enumerate(tqdm(range(start_frame, end_frame), desc="Slicing frames")):
        truncated_data[i] = original_dataset[frame_idx]

    truncated_data = truncated_data.squeeze()

# Write the truncated data to a new HDF5 file
with h5py.File(clipped_file_path, "w") as clipped_h5:
    clipped_h5.create_dataset("tracks", data=truncated_data)

print(f"Saved frames [{start_frame}:{end_frame}] to '{clipped_file_path}' under the dataset 'tracks'.")

Slicing frames:   0%|          | 0/1000 [00:00<?, ?it/s]

Slicing frames: 100%|██████████| 1000/1000 [00:00<00:00, 19200.47it/s]

Saved frames [7000:8000] to '/root/vast/eric/CVAT_mouse_reach/triangulate_optimize/joshua_data/points3d_keewui_params_clipped.h5' under the dataset 'tracks'.





In [19]:
# Paths to original and clipped files
clipped_file_path = "/root/vast/eric/CVAT_mouse_reach/triangulate_optimize/joshua_data/points3d_keewui_params_clipped.h5"

# Open the original HDF5 file and read the dataset
with h5py.File(clipped_file_path, "r") as clipped_h5:
    clipped_dataset = clipped_h5["tracks"]

    print(clipped_dataset.shape)

(1000, 3, 3)


In [25]:
import os
import numpy as np
from jax import numpy as jnp
import yaml
import scipy.io as spio
import pickle
from typing import Text, Union
from pynwb import NWBHDF5IO
from ndx_pose import PoseEstimationSeries, PoseEstimation
import h5py
from pathlib import Path
from omegaconf import DictConfig
import stac_mjx.io_dict_to_hdf5 as ioh5

In [26]:
def load_data(cfg: DictConfig, base_path: Union[Path, None] = None):
    """Load mocap data based on file type.

    Loads mocap file based on filetype, and returns the data flattened
    for immediate consumption by stac_mjx algorithm.

    Args:
        cfg (DictConfig): Configs.
        base_path (Union[Path, None], optional): Base path for file paths in configs. Defaults to None.

    Returns:
        Mocap data flattened into an np array of shape [#frames, keypointXYZ],
        where 'keypointXYZ' represents the flattened 3D keypoint components.
        The data is also scaled by multiplication with "MOCAP_SCALE_FACTOR", e.g.
        if the mocap data is in mm and the model is in meters, this should be
        0.001.

    Raises:
        ValueError if an unsupported filetype is encountered.
        ValueError if ordered list of keypoint names is missing or
        does not match number of keypoints.
    """
    if base_path is None:
        base_path = Path.cwd()

    file_path = base_path / cfg.stac.data_path
    # using pathlib
    if file_path.suffix == ".mat":
        label3d_path = cfg.model.get("KP_NAMES_LABEL3D_PATH", None)
        data, kp_names = load_dannce(str(file_path), names_filename=label3d_path)
    elif file_path.suffix == ".nwb":
        data, kp_names = load_nwb(file_path)
    elif file_path.suffix == ".h5":
        data, kp_names = load_h5(file_path)
    else:
        raise ValueError(
            "Unsupported file extension. Please provide a .nwb or .mat file."
        )

    kp_names = kp_names or cfg.model.KP_NAMES

    if kp_names is None:
        raise ValueError(
            "Keypoint names not provided. Please provide an ordered list of keypoint names \
            corresponding to the keypoint data order."
        )

    if len(kp_names) != data.shape[1]:
        raise ValueError(
            f"Number of keypoint names ({len(kp_names)}) is not the same as the number of keypoints in data ({data.shape[1]})"
        )

    model_inds = [
        kp_names.index(src) for src, dst in cfg.model.KEYPOINT_MODEL_PAIRS.items()
    ]

    sorted_kp_names = [kp_names[i] for i in model_inds]

    # Scale mocap data to match model
    data = data * cfg.model.MOCAP_SCALE_FACTOR
    # Sort in kp_names order
    data = jnp.array(data[:, :, model_inds])
    # Flatten data from [#num frames, #keypoints, xyz]
    # into [#num frames, #keypointsXYZ]
    data = jnp.transpose(data, (0, 2, 1))
    data = jnp.reshape(data, (data.shape[0], -1))

    return data, sorted_kp_names


def load_dannce(filename, names_filename=None):
    """Load mocap data from .mat file.

    .mat file is presumed to be constructed by dannce:
    (https://github.com/spoonsso/dannce). In particular this means it relies on
    the data being in millimeters [num frames, num keypoints, xyz], and that we
    use the data stored in the "pred" key.
    """
    node_names = None
    if names_filename is not None:
        mat = spio.loadmat(names_filename)
        node_names = [item[0] for sublist in mat["joint_names"] for item in sublist]

    data = _check_keys(spio.loadmat(filename, struct_as_record=False, squeeze_me=True))[
        "pred"
    ]
    return data, node_names


def load_nwb(filename):
    """Load mocap data from .nwb file.

    Data is presumed [num frames, num keypoints, xyz].
    """
    data = []
    with NWBHDF5IO(filename, mode="r", load_namespaces=True) as io:
        nwbfile = io.read()
        pose_est = nwbfile.processing["behavior"]["PoseEstimation"]
        node_names = pose_est.nodes[:].tolist()
        data = np.stack(
            [pose_est[node_name].data[:] for node_name in node_names], axis=-1
        )

    return data, node_names


def load_h5(filename):
    """Load .h5 file formatted as [frames, xyz, keypoints].

    Args:
        filename (str): Path to the .h5 file.

    Returns:
        dict: Dictionary containing the data from the .h5 file.
    """
    # TODO add track information
    data = {}
    with h5py.File(filename, "r") as f:
        for key in f.keys():
            data[key] = f[key][()]

    data = np.array(data["tracks"])
    data = np.squeeze(data, axis=1)
    data = np.transpose(data, (0, 2, 1))
    return data, None


def _check_keys(dict):
    """Check if entries in dictionary are mat-objects.

    Mat-objects are changed to nested dictionaries.
    """
    for key in dict:
        if isinstance(dict[key], spio.matlab.mat_struct):
            dict[key] = _todict(dict[key])
    return dict


def _todict(matobj):
    """A recursive function which constructs from matobjects nested dictionaries."""
    dict = {}
    for strg in matobj._fieldnames:
        elem = matobj.__dict__[strg]
        if isinstance(elem, spio.matlab.mat_struct):
            dict[strg] = _todict(elem)
        else:
            dict[strg] = elem
    return dict


def _load_params(param_path):
    """Load parameters for the animal.

    :param param_path: Path to .yaml file specifying animal parameters.
    """
    with open(param_path, "r") as infile:
        try:
            params = yaml.safe_load(infile)
        except yaml.YAMLError as exc:
            print(exc)
    return params


# FLY_MODEL: decide to keep or not!
# def load_stac_ik_only(save_path):
#     _, file_extension = os.path.splitext(save_path)
#     if file_extension == ".p":
#         with open(save_path, "rb") as file:
#             fit_data = pickle.load(file)
#     elif file_extension == ".h5":
#         fit_data = ioh5.load(save_path)
#     return fit_data


def save(fit_data, save_path: Text):
    """Save data.

    Save data as .p or .h5 file.

    Args:
        fit_data (numpy array): Data to write out.
        save_path (Text): Path to save data. Defaults to None.
    """
    if os.path.dirname(save_path) != "":
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    _, file_extension = os.path.splitext(save_path)
    if file_extension == ".p":
        with open(save_path, "wb") as output_file:
            pickle.dump(fit_data, output_file, protocol=2)
    elif file_extension == ".h5":
        ioh5.save(save_path, fit_data)
    else:
        with open(save_path + ".p", "wb") as output_file:
            pickle.dump(fit_data, output_file, protocol=2)


In [37]:
# Enable XLA flags if on GPU
stac_mjx.enable_xla_flags()

# Choose parent directory as base path for data files
base_path = Path("/root/vast/eric/stac-mjx/")

# Load configs
cfg = stac_mjx.load_configs(base_path / "configs")

# Load data
kp_data, sorted_kp_names = load_data(cfg, base_path)

# Run stac
fit_path, ik_only_path = stac_mjx.run_stac(
 cfg,
 kp_data, 
 sorted_kp_names, 
 base_path=base_path
)

Calibration iteration: 1/6
Pose Optimization:
Pose Optimization done in 29.298856258392334
Frame 1 done in 29.249473571777344 with a final error of 0.0
Mean: 0.0
Standard deviation: 0.0
starting offset optimization
Begining offset optimization:


  return jax.tree_map(update_fun, params, updates)


Final error of 0.0009889188222587109
offset optimization finished in 9.90920090675354
Calibration iteration: 2/6
Pose Optimization:
Pose Optimization done in 0.03760814666748047
Frame 1 done in 0.03358817100524902 with a final error of 0.0
Mean: 0.0
Standard deviation: 0.0
starting offset optimization
Begining offset optimization:
Final error of 0.0009778663516044617
offset optimization finished in 7.165462255477905
Calibration iteration: 3/6
Pose Optimization:
Pose Optimization done in 0.03855466842651367
Frame 1 done in 0.03409934043884277 with a final error of 0.0
Mean: 0.0
Standard deviation: 0.0
starting offset optimization
Begining offset optimization:
Final error of 0.0009818419348448515
offset optimization finished in 0.056619882583618164
Calibration iteration: 4/6
Pose Optimization:
Pose Optimization done in 0.03251075744628906
Frame 1 done in 0.02886819839477539 with a final error of 0.0
Mean: 0.0
Standard deviation: 0.0
starting offset optimization
Begining offset optimizati

In [40]:
print(fit_path)

/root/vast/eric/stac-mjx/demo_fit.p
