In [15]:
import zarr
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

from huggingface_hub import snapshot_download
import numpy as np

In [2]:
# Mission (recording) to fetch from the GrandTour dataset on the Hugging Face Hub.
mission = "2024-10-01-11-47-44"

# Download scope -- keep exactly one `allow_patterns` assignment active:
#   whole dataset:        allow_patterns = ["*"]
#   one specific topic:   topic = "alphasense_front_center"
#                         allow_patterns = [f"{mission}/*{topic}*", f"{mission}/*.yaml"]
# Default below: every file of the selected mission.
allow_patterns = [f"{mission}/*"]

# Resumable: if the download is interrupted, simply re-run this cell; huggingface_hub
# resumes without re-downloading the files that were already fetched.
hugging_face_data_cache_path = snapshot_download(
    repo_id="leggedrobotics/grand_tour_dataset",
    allow_patterns=allow_patterns,
    repo_type="dataset",
)

Fetching 202 files: 100%|██████████| 202/202 [04:56<00:00,  1.47s/it]


In [3]:
# Local snapshot directory returned by snapshot_download (lives in the Hugging Face cache).
hugging_face_data_cache_path

'/home/rohan/.cache/huggingface/hub/datasets--leggedrobotics--grand_tour_dataset/snapshots/eed9be0dba01495fcc4fe9fd737c9f767a23f8e9'

In [6]:
import os
from pathlib import Path

# Destination for the extracted missions. The default is this machine's layout; set the
# GRAND_TOUR_DATASET_DIR environment variable to relocate it without editing the notebook.
dataset_folder = Path(
    os.environ.get(
        "GRAND_TOUR_DATASET_DIR",
        "~/Desktop/grand_tour/grand_tour_dataset_exploration/missions",
    )
).expanduser()
dataset_folder.mkdir(parents=True, exist_ok=True)

# Print for confirmation
print(f"Data will be extracted to: {dataset_folder}")

Data will be extracted to: /home/rohan/Desktop/grand_tour/grand_tour_dataset_exploration/missions


In [7]:
import os
import shutil
import tarfile
import re

def move_dataset(cache, dataset_folder, allow_patterns=("*",)):
    """Copy the files selected by ``allow_patterns`` out of a Hugging Face snapshot.

    The directory layout relative to ``cache`` is preserved under ``dataset_folder``.
    ``.tar`` archives are extracted into the folder they would have been copied to (the
    archive itself is not copied); every other selected file is copied with
    ``shutil.copy2``. Archives are processed before plain files.

    Args:
        cache: Snapshot directory returned by ``snapshot_download`` (str or Path).
        dataset_folder: Destination root (str or Path); subfolders are created as needed.
        allow_patterns: Glob pattern(s) matched with ``fnmatch`` against each file's path
            relative to ``cache`` (POSIX separators), so the same list passed to
            ``snapshot_download`` selects the same files here. A trailing "/" selects
            everything under that folder. A single string, or ``None`` for "everything",
            is also accepted.

    Errors while opening/extracting an archive are printed and that archive is skipped.
    """
    import fnmatch

    cache = Path(cache)
    dataset_folder = Path(dataset_folder)

    if allow_patterns is None:
        allow_patterns = ["*"]
    elif isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    patterns = [p + "*" if p.endswith("/") else p for p in allow_patterns]

    def _is_selected(path):
        # Match the cache-relative path. Matching the absolute path let a pattern such as
        # "datasets*" hit the cache directory's own name and select every file.
        rel = path.relative_to(cache).as_posix()
        return any(fnmatch.fnmatch(rel, p) for p in patterns)

    # is_file() follows symlinks (huggingface_hub snapshots are typically symlink trees)
    # and keeps directories whose name ends in ".tar" away from tarfile.
    files = [f for f in cache.rglob("*") if f.is_file() and _is_selected(f)]
    tar_files = [f for f in files if f.suffix == ".tar"]
    other_files = [f for f in files if f.suffix != ".tar"]

    for source_path in tar_files:
        dest_path = dataset_folder / source_path.relative_to(cache)
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            with tarfile.open(source_path, "r") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Archives come from the network: the "data" filter refuses members that
                    # would land outside the destination (absolute paths, "..", escaping
                    # links) as well as special files.
                    tar.extractall(path=dest_path.parent, filter="data")
                else:
                    tar.extractall(path=dest_path.parent)
        except tarfile.ReadError as e:
            print(f"Error opening or extracting tar file '{source_path}': {e}")
        except Exception as e:
            print(f"An unexpected error occurred while processing {source_path}: {e}")

    for source_path in other_files:
        dest_path = dataset_folder / source_path.relative_to(cache)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source_path, dest_path)

    print(f"Moved data from {cache} to {dataset_folder} !")

print(dataset_folder)
# Reuse the download's allow_patterns so only the fetched files are extracted/copied.
move_dataset(
    hugging_face_data_cache_path,
    dataset_folder,
    allow_patterns=allow_patterns,
)

/home/rohan/Desktop/grand_tour/grand_tour_dataset_exploration/missions
Moved data from /home/rohan/.cache/huggingface/hub/datasets--leggedrobotics--grand_tour_dataset/snapshots/eed9be0dba01495fcc4fe9fd737c9f767a23f8e9 to /home/rohan/Desktop/grand_tour/grand_tour_dataset_exploration/missions !


In [10]:
# Open the mission's Zarr store (its "data" subfolder) read-only.
mission_folder = Path(dataset_folder, mission)
mission_root = zarr.open_group(store=mission_folder.joinpath("data"), mode='r')

In [13]:
# Proprioceptive / command topics to align onto the common timeline.
# The camera topics hdr_front, hdr_left and hdr_right are currently disabled.
sensors = [
    "anymal_state_odometry", "anymal_state_state_estimator", "anymal_imu",
    "anymal_state_actuator", "anymal_command_twist",
]

In [17]:
# Rate (Hz) of the common timeline every sensor is resampled onto, and its step size
# in timestamp units (presumably seconds -- confirm against the dataset's timestamps).
# NOTE(review): several names below say "50hz" but the grid itself follows TARGET_HZ.
TARGET_HZ = 50.0
DT = 1.0 / TARGET_HZ

def _to_np(x):
    return np.asarray(x[:]) if hasattr(x, '__getitem__') and not isinstance(x, np.ndarray) else np.asarray(x)

def _search_zoh_indices(src_ts, tgt_ts):
    """
    Vectorized zero order hold: for each target time, pick the last src index with src_ts <= tgt
    in english: find the index of the last source timestamp that happened at or before that time.
    Returns idx array (int) with -1 where no src sample exists yet
    """
    idx = np.searchsorted(src_ts, tgt_ts, side='right') - 1
    return idx

def _resample_group_zoh(group, tgt_ts, ts_key="timestamp", skip_keys=("timestamp","sequence_id")):
    """
    Resample every field of a Zarr group onto ``tgt_ts`` with a zero-order hold (ZOH).

    For each target time, the row of the most recent source sample taken at or before
    that time is used (see ``_search_zoh_indices``).

    Args:
        group: Mapping-like group (``group.keys()`` / ``group[key]``) whose arrays share
            their first axis with ``group[ts_key]``.
        tgt_ts: 1-D array of target timestamps, same units as ``group[ts_key]``.
        ts_key: Name of the source timestamp array.
        skip_keys: Fields that are not resampled (timestamps / sequence counters).

    Returns:
        dict mapping each non-skipped field to an array of shape ``(len(tgt_ts), ...)``,
        plus ``"timestamp_50hz"`` holding ``tgt_ts`` itself. Float/complex fields are NaN
        where ``tgt_ts`` precedes the first source sample; other dtypes hold the first
        source value there instead.

    Raises:
        ValueError: if ``group[ts_key]`` contains no samples.
    """
    src_ts = _to_np(group[ts_key])
    if src_ts.shape[0] == 0:
        raise ValueError(f"cannot resample: '{ts_key}' has no samples")

    # Sort by time when needed. A *stable* sort keeps samples that share a timestamp in
    # recording order, so the ZOH below deterministically holds the last-recorded one.
    order = None
    if not np.all(src_ts[:-1] <= src_ts[1:]):
        order = np.argsort(src_ts, kind="stable")
        src_ts = src_ts[order]

    # idx[i] = row of the last source sample with src_ts <= tgt_ts[i]; -1 if there is none yet.
    idx = _search_zoh_indices(src_ts, tgt_ts)
    before_first = idx < 0
    # searchsorted(side='right') - 1 never exceeds len(src_ts) - 1, so only clamp below;
    # the clamped rows are masked (float) or intentionally hold the first value (others).
    safe_idx = np.maximum(idx, 0)

    out = {}
    for key in group.keys():
        if key in skip_keys:
            continue

        arr = _to_np(group[key])
        if order is not None:
            arr = arr[order]  # keep the field aligned with the sorted timestamps
        res = arr[safe_idx]  # fancy indexing copies, so masking below never touches the source
        if res.dtype.kind in "fc":
            res[before_first] = np.nan
        out[key] = res

    # Always return the resampled timestamps too (the grid)
    out["timestamp_50hz"] = tgt_ts
    return out

def _overlap_window(mission_root, sensors, ts_key="timestamp"):
    """Compute the overlapping [start, end] across sensors to avoid extrapolation beyond the last sample.

    ``start`` is the latest first-sample time and ``end`` the earliest last-sample time
    over all ``sensors``. Per-sensor min/max (NaN-ignoring) are used instead of the first
    and last element so unsorted timestamp arrays -- which ``_resample_group_zoh``
    explicitly tolerates -- still yield the right window.

    Raises:
        ValueError: if ``sensors`` is empty, a sensor has no timestamps, or the sensors
            share no common time interval.
    """
    starts = []
    ends = []
    for s in sensors:
        ts = _to_np(mission_root[s][ts_key])
        starts.append(np.nanmin(ts))
        ends.append(np.nanmax(ts))
    if not starts:
        raise ValueError("no sensors given")

    start, end = max(starts), min(ends)
    if start > end:
        raise ValueError(
            f"sensors do not overlap in time (latest start {start} > earliest end {end})"
        )
    return start, end

def build_50hz_grid(t_start, t_end, hz=None):
    """Return a uniform time grid from ``t_start`` towards ``t_end`` at ``hz`` samples per unit time.

    ``t_start`` is always the first element. ``t_end`` is included when it lies on the
    grid, up to floating-point round-off (a 0.58 span at 50 Hz yields 30 points ending at
    ~0.58, not 29); otherwise the grid stops at the last point before ``t_end``.
    An empty array is returned when ``t_end < t_start``.

    Args:
        t_start, t_end: Interval bounds, in the sensor timestamp units (presumably seconds).
        hz: Grid rate; defaults to the module-level ``TARGET_HZ``.

    Returns:
        float64 array of shape (n,).
    """
    if hz is None:
        hz = TARGET_HZ
    step = 1.0 / hz
    # The small tolerance absorbs round-off in the product, e.g. 0.58 * 50 == 28.999999999999996.
    n_steps = int(np.floor((t_end - t_start) * hz + 1e-9)) + 1
    return (t_start + np.arange(max(n_steps, 0)) * step).astype(np.float64)

# main entrypoint
def align_mission_to_50hz(mission_root, sensors, ts_key="timestamp"):
    """
    Put several sensor groups of one mission onto a single common timeline.

    The timeline only spans the interval covered by every sensor (``_overlap_window``),
    is sampled by ``build_50hz_grid``, and each sensor is zero-order-hold resampled
    onto it (``_resample_group_zoh``).

    Returns:
      {
        "t": np.ndarray [T],  # 50 Hz grid
        "sensors": {
            sensor_name: { field: np.ndarray[T, ...], "timestamp_50hz": np.ndarray[T] }
        }
      }
    """
    window_start, window_end = _overlap_window(mission_root, sensors, ts_key=ts_key)
    grid = build_50hz_grid(window_start, window_end)
    per_sensor = {
        name: _resample_group_zoh(mission_root[name], grid, ts_key=ts_key)
        for name in sensors
    }
    return {"t": grid, "sensors": per_sensor}


aligned = align_mission_to_50hz(mission_root, sensors)

# Common timeline shared by every aligned sensor, plus a few fields of interest.
t = aligned["t"]
sensors_50hz = aligned["sensors"]
base_lin_vel = sensors_50hz["anymal_state_state_estimator"]["twist_lin"]
imu_ang_vel = sensors_50hz["anymal_imu"]["ang_vel"]
cmd_linear = sensors_50hz["anymal_command_twist"]["linear"]


In [18]:
# List every aligned field per sensor (sorted by name) with its shape and container type.
# The value is bound to a local first: reusing the f-string's own double quotes inside a
# replacement field (as before) is a SyntaxError on Python < 3.12 (PEP 701).
for sensor in sensors:
    print(f"{sensor}:")
    fields = aligned["sensors"][sensor]
    for key in sorted(fields):
        value = fields[key]
        print(f"-->    {key} {value.shape} {type(value)}")
    print("------------------------------\n")
        

anymal_state_odometry:
-->    pose_cov (18228, 6, 6) <class 'numpy.ndarray'>
-->    pose_orien (18228, 4) <class 'numpy.ndarray'>
-->    pose_pos (18228, 3) <class 'numpy.ndarray'>
-->    timestamp_50hz (18228,) <class 'numpy.ndarray'>
-->    twist_ang (18228, 3) <class 'numpy.ndarray'>
-->    twist_cov (18228, 6, 6) <class 'numpy.ndarray'>
-->    twist_lin (18228, 3) <class 'numpy.ndarray'>
------------------------------

anymal_state_state_estimator:
-->    LF_FOOT_contact (18228,) <class 'numpy.ndarray'>
-->    LF_FOOT_friction_coef (18228,) <class 'numpy.ndarray'>
-->    LF_FOOT_normal (18228, 3) <class 'numpy.ndarray'>
-->    LF_FOOT_restitution_coef (18228,) <class 'numpy.ndarray'>
-->    LF_FOOT_state (18228,) <class 'numpy.ndarray'>
-->    LF_FOOT_wrench_force (18228, 3) <class 'numpy.ndarray'>
-->    LF_FOOT_wrench_torque (18228, 3) <class 'numpy.ndarray'>
-->    LH_FOOT_contact (18228,) <class 'numpy.ndarray'>
-->    LH_FOOT_friction_coef (18228,) <class 'numpy.ndarray'>
--> 

In [19]:
def get_axis_params(value, axis_idx):
    """Return a length-3 float64 vector holding ``value`` at ``axis_idx`` and zeros elsewhere."""
    vec = np.zeros(3, dtype=np.float64)
    vec[axis_idx] = value
    return vec

def quat_rotate_inverse(q, v):
    """
    Apply the inverse (conjugate) rotation of quaternion(s) ``q`` to vector(s) ``v``.

    q: (..., 4) array in [x, y, z, w] order (assumed unit norm)
    v: (..., 3) array, broadcast against q's leading dimensions
    returns: (..., 3) array of rotated vectors

    Uses v' = v - w * t + q_xyz x t with t = 2 * (q_xyz x v).
    """
    quat = np.asarray(q)
    vec = np.asarray(v)

    xyz = quat[..., :3]
    w = quat[..., 3:4]  # keep a trailing axis so it broadcasts against (..., 3)
    twice_cross = 2.0 * np.cross(xyz, vec)
    return vec - w * twice_cross + np.cross(xyz, twice_cross)


In [20]:
from reward import compute_rewards_offline
def build_offline_dataset(data, episode_len_s=20, hz=50):
    """Convert aligned ANYmal sensor data into an offline RL dataset.

    Args:
        data: Output of ``align_mission_to_50hz``: ``data["sensors"][name][field]`` arrays
            that share their first (time) axis of length T.
        episode_len_s: Length, in seconds, of the fixed chunks the continuous recording is
            cut into; the last step of every full chunk is flagged in ``terminals``.
        hz: Sample rate of ``data`` (should match the alignment rate).

    Returns:
        dict with ``observations`` and ``next_observations`` (T-1, obs_dim), ``actions``
        (T-1, 12), ``rewards`` (T-1,) and boolean ``terminals`` (T-1,). Transition i pairs
        observation i with observation i + 1.

    Raises:
        ValueError: if ``episode_len_s * hz`` rounds to fewer than one step.
    """
    # Round rather than truncate: e.g. 0.58 * 50 == 28.999999999999996 must give 29 steps.
    episode_len = int(round(episode_len_s * hz))
    if episode_len < 1:
        raise ValueError(
            f"episode_len_s * hz must be at least one step, got {episode_len_s} * {hz}"
        )

    est = data["sensors"]["anymal_state_state_estimator"]
    act = data["sensors"]["anymal_state_actuator"]
    cmd = data["sensors"]["anymal_command_twist"]
    imu = data["sensors"]["anymal_imu"]

    up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly
    gravity_vec = get_axis_params(-1., up_axis_idx)
    # NOTE(review): assumes imu["orien"] is an [x, y, z, w] quaternion (the layout
    # quat_rotate_inverse expects) rotating body -> world -- confirm against the dataset docs.
    base_quat = imu["orien"]
    # Gravity direction in the body frame (under the quaternion assumption above).
    projected_gravity = quat_rotate_inverse(base_quat, gravity_vec)  # (T, 3)

    base_lin_vel = est["twist_lin"]          # (T, 3)
    base_ang_vel = est["twist_ang"]          # (T, 3)
    joint_pos = est["joint_positions"]       # (T, 12)
    joint_vel = est["joint_velocities"]      # (T, 12)
    cmd_lin = cmd["linear"]                  # (T, 3)
    cmd_ang = cmd["angular"]                 # (T, 3)

    # One field per actuator: "00_command_position" ... "11_command_position".
    act_keys = [f"{i:02d}_command_position" for i in range(12)]
    actions = np.stack([act[k] for k in act_keys], axis=-1)   # (T, 12)

    # Action of the previous step (zeros before the first sample).
    prev_actions = np.zeros_like(actions)
    prev_actions[1:] = actions[:-1]

    obs = np.concatenate([
        base_lin_vel,
        base_ang_vel,
        projected_gravity,
        joint_pos,
        joint_vel,
        prev_actions,
        cmd_lin,
        cmd_ang,
    ], axis=-1)  # (T, obs_dim)

    rews = compute_rewards_offline(
        base_ang_vel,
        base_lin_vel,
        prev_actions,
        actions,
        joint_vel,
        est["LF_FOOT_contact"],
        est["LH_FOOT_contact"],
        est["RF_FOOT_contact"],
        est["RH_FOOT_contact"],
        cmd_lin,
        cmd_ang,
        est["joint_efforts"],
        len(obs)
    )

    # Transition i: (obs[i], actions[i], rews[i]) -> obs[i + 1]; the last sample has no successor.
    observations = obs[:-1]
    next_observations = obs[1:]
    actions = actions[:-1]
    rewards = rews[:-1]

    # Flag the last step of every episode_len-step chunk (20 s * 50 Hz = 1000 steps by default).
    # NOTE(review): these are time-limit cuts of one continuous recording rather than true
    # terminal states; value-bootstrapping learners may want to treat them as timeouts.
    T = len(observations)
    terminals = np.zeros(T, dtype=bool)
    terminals[np.arange(episode_len - 1, T, episode_len)] = True

    # offline dataset
    dataset = dict(
        observations=observations,
        actions=actions,
        next_observations=next_observations,
        rewards=rewards,
        terminals=terminals,
    )

    return dataset


In [21]:
# 20 s episodes at 50 Hz -- the function's defaults, spelled out for the reader.
dataset = build_offline_dataset(aligned, episode_len_s=20, hz=50)

In [22]:
# Container type and shape of every dataset entry.
for name, array in dataset.items():
    print(f"{name}: {type(array)} {array.shape}")

observations: <class 'numpy.ndarray'> (18227, 51)
actions: <class 'numpy.ndarray'> (18227, 12)
next_observations: <class 'numpy.ndarray'> (18227, 51)
rewards: <class 'numpy.ndarray'> (18227,)
terminals: <class 'numpy.ndarray'> (18227,)


In [23]:
def episode_returns(rewards, terminals):
    """Sum the rewards of each episode.

    An episode ends at every step whose ``terminals`` flag is truthy. Steps after the
    last terminal form a final, truncated episode whose sum is reported as well.

    Args:
        rewards: 1-D sequence of per-step rewards.
        terminals: Sequence of the same length with per-step episode-end flags.

    Returns:
        1-D array with one return per episode (empty for empty input).

    Raises:
        ValueError: if ``rewards`` and ``terminals`` differ in length (``zip`` would
            otherwise silently drop the extra entries).
    """
    if len(rewards) != len(terminals):
        raise ValueError(
            f"rewards and terminals must have the same length, got {len(rewards)} and {len(terminals)}"
        )

    episode_sums = []
    current_sum = 0.0

    for r, done in zip(rewards, terminals):
        current_sum += r
        if done:
            episode_sums.append(current_sum)
            current_sum = 0.0

    # Trailing steps without a terminal still form an episode. The len() guard avoids
    # the IndexError terminals[-1] would raise on empty input.
    if len(terminals) > 0 and not terminals[-1]:
        episode_sums.append(current_sum)

    return np.array(episode_sums)

ep_ret = episode_returns(rewards=dataset["rewards"], terminals=dataset["terminals"])
print(len(ep_ret))        # number of episodes, including a truncated trailing one
print(np.median(ep_ret))  # the median is not dragged down by the short trailing episode
ep_ret

19
26.53067552820117


array([27.44218748, 27.05541286, 26.39754375, 26.51982278, 24.91035863,
       24.27332467, 26.3357035 , 26.53067553, 27.70199123, 26.7212877 ,
       26.60569158, 26.46818968, 26.61520911, 26.70020568, 26.95783483,
       26.43130929, 26.16244311, 26.72840684,  6.05985623])