# Sample Efficiency Dev

This notebook is the development of the sample efficiency experiment.

The rodent dataset can be downloaded via the [following google drive link](https://drive.google.com/file/d/1BHKe9agnupdPC8xExx4OaPtPJBKTTI2U/view).

## Create Dataset w/ Different Sample Sizes and Different Sample Distributions


In [1]:
%load_ext autoreload
%autoreload 2


from typing import Union
import h5py
from jax import numpy as jp
from flax import struct
import yaml
from omegaconf import DictConfig
from pathlib import Path
import hydra
import logging
from track_mjx.io.load import ReferenceClip


def make_multiclip_data(traj_data_path):
    """Creates ReferenceClip object with multiclip tracking data.
    Features have shape = (clips, frames, dims)
    """

    def reshape_frames(arr, clip_len):
        return jp.array(
            arr[()].reshape(arr.shape[0] // clip_len, clip_len, *arr.shape[1:])
        )

    with h5py.File(traj_data_path, "r") as data:
        # Read the config string as yaml in to dict
        yaml_str = data["config"][()]
        yaml_str = yaml_str.decode("utf-8")
        config = yaml.safe_load(yaml_str)
        clip_len = config["stac"]["n_frames_per_clip"]

        # Reshape the data to (clips, frames, dims)
        batch_qpos = reshape_frames(data["qpos"], clip_len)
        batch_xpos = reshape_frames(data["xpos"], clip_len)
        batch_qvel = reshape_frames(data["qvel"], clip_len)
        batch_xquat = reshape_frames(data["xquat"], clip_len)
        return ReferenceClip(
            position=batch_qpos[:, :, :3],
            quaternion=batch_qpos[:, :, 3:7],
            joints=batch_qpos[:, :, 7:],
            body_positions=batch_xpos,
            velocity=batch_qvel[:, :, :3],
            angular_velocity=batch_qvel[:, :, 3:6],
            joints_velocity=batch_qvel[:, :, 6:],
            body_quaternions=batch_xquat,
        )

In [2]:
stac_data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"

clips = make_multiclip_data(stac_data_path)

In [9]:
clips.position.shape

(842, 250, 3)

In [6]:
with h5py.File(stac_data_path, "r") as data:
    # Read the config string as yaml in to dict)
    yaml_str = data["config"][()]
    yaml_str = yaml_str.decode("utf-8")
    config = yaml.safe_load(yaml_str)
    clip_len = config["stac"]["n_frames_per_clip"]
    print(f"{clip_len=}")
    print(data["qpos"].shape)

clip_len=250
(210500, 74)


In [4]:
# here is how to get the snips class from data file.
config["model"]["snips_order"]

import re

pattern = re.compile(r"/([^/]+)_([0-9]+)\.p$")

for path in config["model"]["snips_order"]:
    match = pattern.search(path)
    if match:
        name, number = match.groups()
        print(name, number)

RGroom 4
Rear 149
Walk 174
FaceGroom 0
LGroom 48
FaceGroom 143
FaceGroom 82
LGroom 51
Walk 182
FaceGroom 142
FaceGroom 101
Rear 112
Rear 191
LGroom 177
LGroom 107
FastWalk 76
Rear 168
FaceGroom 53
Rear 87
Walk 4
Rear 171
LGroom 101
FastWalk 184
LGroom 161
RGroom 121
FastWalk 27
Rear 27
Rear 65
Rear 120
Rear 142
Walk 2
FaceGroom 149
FastWalk 6
Walk 173
FastWalk 64
FastWalk 115
FaceGroom 193
Rear 53
Walk 101
RGroom 116
FaceGroom 54
FastWalk 65
FastWalk 43
Walk 64
RGroom 48
Rear 8
Rear 70
RGroom 155
RGroom 154
FastWalk 107
Rear 28
FastWalk 108
FastWalk 154
RGroom 104
Walk 95
LGroom 147
Walk 73
Rear 133
FastWalk 148
Rear 60
FaceGroom 107
FastWalk 11
LGroom 32
FaceGroom 2
Rear 128
Rear 115
LGroom 106
FastWalk 41
FaceGroom 113
LGroom 135
FastWalk 96
RGroom 149
Rear 98
FastWalk 189
LGroom 167
Walk 154
FastWalk 122
RGroom 53
LGroom 12
RGroom 73
FaceGroom 119
Rear 89
FastWalk 143
FastWalk 79
FaceGroom 182
RGroom 199
RGroom 109
FaceGroom 167
Rear 156
RGroom 38
LGroom 45
Rear 189
FaceGroom 49
Rea