In [3]:
# === SAFE build: uniform-random 1k / 10k subsets (no discretization), train/val both from ORIGINAL train ===
# Output:
#   <DATASET_DIR>_1k/{train,val}/...
#   <DATASET_DIR>_10k/{train,val}/...

import gc, json, pickle, shutil, re
from pathlib import Path
from typing import Dict, Any, Tuple

import numpy as np
import torch

# This cell only moves data around; autograd is switched off and torch is
# limited to a single thread.
torch.set_grad_enabled(False)
torch.set_num_threads(1)

# -------------------- Config --------------------
# Name of the pickled per-episode length file expected inside a split directory.
LENGTHS_FNAME = "seq_lengths.pkl"

# Subset tag -> total number of episodes to sample for that subset.
# The tag becomes the suffix of the output directory (<DATASET_DIR>_<tag>).
SUBSET_SPECS = {
    "1k": 1_000,
    "10k": 10_000,
}

DATASET_DIR = Path("/Users/julianquast/Downloads/pusht_noise").expanduser()

# Notebook override hook: if SUBSET_SEED was already defined (e.g. in an
# earlier cell), keep it; otherwise fall back to 0.
try:
    SUBSET_SEED
except NameError:
    SUBSET_SEED = 0

RANDOM_SEED = int(SUBSET_SEED)

# set True if you want to rebuild an existing <DATASET_DIR>_10k etc.
OVERWRITE = False

# Preflight: both split directories must exist in the source dataset.
# NOTE(review): asserts are stripped under `python -O`; fine for a notebook.
assert (DATASET_DIR / "train").is_dir(), f"Missing train/: {DATASET_DIR / 'train'}"
assert (DATASET_DIR / "val").is_dir(), f"Missing val/: {DATASET_DIR / 'val'}"

print(f"Dataset dir  : {DATASET_DIR}")
print(f"Subset specs : {SUBSET_SPECS}")
print(f"Random seed  : {RANDOM_SEED}")

# -------------------- Helpers --------------------
def load_lengths_only(split_dir: Path) -> np.ndarray:
    """Read the pickled lengths file of ``split_dir`` as a flat int64 array.

    Args:
        split_dir: Directory expected to contain ``LENGTHS_FNAME``.

    Returns:
        One-dimensional ``np.int64`` array holding the unpickled values.

    Raises:
        FileNotFoundError: If the lengths file is absent.
    """
    lengths_path = split_dir / LENGTHS_FNAME
    if not lengths_path.exists():
        raise FileNotFoundError(f"Missing {LENGTHS_FNAME} in {split_dir}")

    with open(lengths_path, "rb") as fh:
        raw = pickle.load(fh)

    # ndarray input is converted without a copy when already int64; anything
    # else goes through np.asarray.
    if isinstance(raw, np.ndarray):
        lengths = raw.astype(np.int64, copy=False)
    else:
        lengths = np.asarray(raw, dtype=np.int64)

    return lengths if lengths.ndim == 1 else lengths.reshape(-1)

def _subset_obj(obj: Any, idxs: np.ndarray, orig_count: int) -> Tuple[Any, bool]:
    """Select the entries ``idxs`` from any container whose length is ``orig_count``.

    A tensor / ndarray is subset along axis 0 when its first dimension equals
    ``orig_count``; a list / tuple is subset when its length equals
    ``orig_count``; a dict is rebuilt with each value processed recursively.
    Everything else is returned untouched.

    Args:
        obj: Object to subset.
        idxs: Integer positions to keep.
        orig_count: Length that marks a container as eligible for subsetting.

    Returns:
        ``(result, changed)`` where ``changed`` is True if any subsetting was
        applied (for dicts: to at least one value).
    """
    if torch.is_tensor(obj):
        if obj.ndim < 1 or obj.shape[0] != orig_count:
            return obj, False
        rows = torch.as_tensor(idxs, dtype=torch.long)
        return obj.index_select(0, rows), True

    if isinstance(obj, np.ndarray):
        if obj.ndim < 1 or obj.shape[0] != orig_count:
            return obj, False
        return obj[idxs], True

    if isinstance(obj, (list, tuple)):
        if len(obj) != orig_count:
            return obj, False
        picked = [obj[int(i)] for i in idxs]
        # Lists come back as plain lists, tuples as plain tuples.
        return (picked if isinstance(obj, list) else tuple(picked)), True

    if isinstance(obj, dict):
        rebuilt = {}
        any_changed = False
        for key, value in obj.items():
            rebuilt[key], value_changed = _subset_obj(value, idxs, orig_count)
            any_changed = any_changed or value_changed
        return rebuilt, any_changed

    return obj, False

def torch_load_safe(path: Path) -> Any:
    """Load a torch file onto the CPU, preferring a memory-mapped read.

    The first attempt passes ``mmap=True``; if that raises for any reason the
    file is loaded again without it. An error from the second attempt
    propagates to the caller.
    """
    # The path is handed to torch as a plain string: mmap=True requires a
    # string filename in some torch versions.
    filename = str(path)
    try:
        loaded = torch.load(filename, map_location="cpu", mmap=True)
    except Exception:
        loaded = torch.load(filename, map_location="cpu")
    return loaded

def subset_tensor_to_path(src_path: Path, dst_path: Path, idxs: np.ndarray, orig_count: int) -> bool:
    """Load a torch file, subset it to ``idxs`` and save it to ``dst_path``.

    Nothing is written when the loaded object contains no container of length
    ``orig_count``; the caller is expected to handle that case.

    Returns:
        True if a subset was written to ``dst_path``, False otherwise.
    """
    loaded = torch_load_safe(src_path)
    picked, was_subset = _subset_obj(loaded, idxs, orig_count)

    if was_subset:
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(picked, str(dst_path))

    # Drop both references before returning so the memory can be reclaimed.
    del loaded, picked
    gc.collect()
    return was_subset

def subset_pickle_to_path(src_path: Path, dst_path: Path, idxs: np.ndarray, orig_count: int) -> bool:
    """Unpickle ``src_path``, subset it to ``idxs`` and pickle it to ``dst_path``.

    Nothing is written when the unpickled object contains no container of
    length ``orig_count``; the caller is expected to handle that case.

    Returns:
        True if a subset was written to ``dst_path``, False otherwise.
    """
    with open(src_path, "rb") as src:
        loaded = pickle.load(src)

    picked, was_subset = _subset_obj(loaded, idxs, orig_count)

    if was_subset:
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        with open(dst_path, "wb") as dst:
            pickle.dump(picked, dst)

    # Drop both references before returning so the memory can be reclaimed.
    del loaded, picked
    gc.collect()
    return was_subset

def build_episode_file_map(obs_dir: Path) -> Dict[int, Path]:
    """Map episode index -> path for every ``episode_<digits>`` entry in ``obs_dir``.

    Entries whose stem has no ``episode_<digits>`` part are ignored. If two
    entries carry the same index, the one visited last wins. A missing
    directory yields an empty mapping.
    """
    if not obs_dir.exists():
        return {}

    episode_re = re.compile(r"episode_(\d+)")
    found: Dict[int, Path] = {}
    for candidate in obs_dir.glob("episode_*"):
        hit = episode_re.search(candidate.stem)
        if hit is not None:
            found[int(hit.group(1))] = candidate
    return found

def copy_subset_observation_media(src_obs_dir: Path, dst_obs_dir: Path, idxs: np.ndarray):
    """Copy the selected ``episode_*`` files and renumber them 0..len(idxs)-1.

    ``dst_obs_dir`` is deleted first if it already exists, then recreated.
    The i-th entry of ``idxs`` becomes ``episode_<i>`` in the destination,
    zero-padded to at least three digits and keeping the source file suffix.
    Does nothing when ``src_obs_dir`` does not exist.

    Raises:
        FileNotFoundError: If an index in ``idxs`` has no matching source file.
    """
    if not src_obs_dir.exists():
        return

    idxs = np.asarray(idxs, dtype=np.int64)
    by_episode = build_episode_file_map(src_obs_dir)

    # Start from a clean destination so stale files never survive a rebuild.
    if dst_obs_dir.exists():
        shutil.rmtree(dst_obs_dir)
    dst_obs_dir.mkdir(parents=True, exist_ok=True)

    # Width of the largest new index, but never fewer than three digits.
    last_new_idx = len(idxs) - 1 if len(idxs) else 0
    pad = max(3, len(str(last_new_idx)))

    for new_idx, old_idx in enumerate(idxs):
        src_path = by_episode.get(int(old_idx))
        if src_path is None:
            raise FileNotFoundError(f"Missing obs file for episode {old_idx} in {src_obs_dir}")
        renamed = f"episode_{new_idx:0{pad}d}{src_path.suffix}"
        shutil.copy2(src_path, dst_obs_dir / renamed)

def copy_non_split_extras(src_root: Path, dst_root: Path):
    """Mirror every top-level entry of ``src_root`` except ``train``/``val`` into ``dst_root``.

    Directories are copied recursively (merging into an existing destination);
    files are copied with their metadata. ``dst_root`` is created if needed.
    """
    dst_root.mkdir(parents=True, exist_ok=True)

    split_names = ("train", "val")
    for entry in src_root.iterdir():
        if entry.name in split_names:
            continue

        target = dst_root / entry.name
        if not entry.is_dir():
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(entry, target)
            continue

        shutil.copytree(entry, target, dirs_exist_ok=True)

def write_train_val_from_train_split(
    src_train_dir: Path,
    out_train_dir: Path,
    out_val_dir: Path,
    train_idxs: np.ndarray,
    val_idxs: np.ndarray,
    orig_count: int,
):
    """
    Populate out_train_dir and out_val_dir from the contents of src_train_dir.

    Every top-level entry of src_train_dir is handled once, in sorted order:
      - the obses/ directory is subset-copied (and renumbered) per output
      - any other directory is copied whole into both outputs
      - *.pth/*.pt and *.pkl files are subset per output when they contain a
        container of length orig_count; otherwise the file is copied as-is
      - every other file is copied as-is into both outputs
    """
    # (output directory, episode indices) for the two splits, train first.
    outputs = ((out_train_dir, train_idxs), (out_val_dir, val_idxs))

    for out_dir, _ in outputs:
        out_dir.mkdir(parents=True, exist_ok=True)

    for entry in sorted(src_train_dir.iterdir()):
        if entry.is_dir():
            if entry.name == "obses":
                for out_dir, idxs in outputs:
                    copy_subset_observation_media(entry, out_dir / "obses", idxs)
            else:
                for out_dir, _ in outputs:
                    shutil.copytree(entry, out_dir / entry.name, dirs_exist_ok=True)
            gc.collect()
            continue

        # Pick the subsetting routine from the file extension (None = plain copy).
        ext = entry.suffix.lower()
        if ext in (".pth", ".pt"):
            subsetter = subset_tensor_to_path
        elif ext == ".pkl":
            subsetter = subset_pickle_to_path
        else:
            subsetter = None

        destinations = [out_dir / entry.name for out_dir, _ in outputs]

        if subsetter is None:
            for dst in destinations:
                dst.parent.mkdir(parents=True, exist_ok=True)
            for dst in destinations:
                shutil.copy2(entry, dst)
        else:
            # Attempt both subsets first, then fall back to a verbatim copy
            # for whichever output the subsetter left unwritten.
            written = [
                subsetter(entry, dst, idxs, orig_count)
                for dst, (_, idxs) in zip(destinations, outputs)
            ]
            for dst, was_written in zip(destinations, written):
                if not was_written:
                    dst.parent.mkdir(parents=True, exist_ok=True)
                    shutil.copy2(entry, dst)

        gc.collect()

# -------------------- Episode counts --------------------
# The number of entries in the train split's lengths file is taken as the
# number of available episodes.
src_train_dir = DATASET_DIR / "train"
train_avail = int(len(load_lengths_only(src_train_dir)))
print(f"Original train episodes: {train_avail:,}")
assert train_avail > 0, "No train episodes available."

# -------------------- Build subsets --------------------
# Subset tag -> output root; filled for every requested subset, including
# the ones that are skipped because they already exist.
subset_roots: Dict[str, Path] = {}

for subset_name, target_total in SUBSET_SPECS.items():
    if target_total > train_avail:
        raise ValueError(f"Requested {target_total} episodes but only {train_avail} train episodes available.")

    # Output lives next to the source dataset: <parent>/<dataset name>_<tag>.
    subset_root = DATASET_DIR.parent / f"{DATASET_DIR.name}_{subset_name}"
    subset_roots[subset_name] = subset_root

    # NOTE(review): existence of the directory is the only "already built"
    # signal. A build that crashed midway leaves a partial directory which
    # is skipped here on the next run unless OVERWRITE is True. The
    # subset_info.json written at the very end could serve as a completion
    # marker.
    if subset_root.exists():
        if OVERWRITE:
            print(f"[overwrite] removing existing {subset_root}")
            shutil.rmtree(subset_root)
        else:
            print(f"[skip] {subset_root} already exists")
            continue

    print(f"\n== Building subset {subset_name} ({target_total:,} trajectories) ==")

    # Copy extras (anything not train/val)
    subset_root.mkdir(parents=True, exist_ok=True)
    copy_non_split_extras(DATASET_DIR, subset_root)

    # Uniform sample from ORIGINAL train, then disjoint 90/10 split.
    # The generator is seeded with RANDOM_SEED + target_total, so every subset
    # size draws its own independent sample (the smaller subset is not
    # derived from the larger one).
    rng = np.random.default_rng(RANDOM_SEED + target_total)
    sampled = rng.choice(train_avail, size=target_total, replace=False)
    perm = rng.permutation(len(sampled))

    train_target = int(round(target_total * 0.9))
    val_target   = target_total - train_target

    # Indices are sorted so the episodes keep their original relative order
    # inside each output split.
    train_idxs = np.sort(sampled[perm[:train_target]])
    val_idxs   = np.sort(sampled[perm[train_target:]])

    # Sanity: the split sizes add up and the two splits share no episode.
    assert len(train_idxs) == train_target and len(val_idxs) == val_target
    assert len(np.intersect1d(train_idxs, val_idxs)) == 0

    print(f"[{subset_name}] new train={len(train_idxs):,}, new val={len(val_idxs):,} (both from original train)")

    out_train_dir = subset_root / "train"
    out_val_dir   = subset_root / "val"

    # Both output splits are cut from the ORIGINAL train split; the source
    # val/ directory is not read here.
    write_train_val_from_train_split(
        src_train_dir=src_train_dir,
        out_train_dir=out_train_dir,
        out_val_dir=out_val_dir,
        train_idxs=train_idxs,
        val_idxs=val_idxs,
        orig_count=train_avail,
    )

    # Provenance record, written last.
    # NOTE(review): "seed" stores RANDOM_SEED, not the effective generator seed
    # (RANDOM_SEED + target_total), and the chosen indices are not recorded.
    info = {
        "source_dataset": str(DATASET_DIR),
        "subset_name": subset_name,
        "target_total": int(target_total),
        "train_episodes": int(len(train_idxs)),
        "val_episodes": int(len(val_idxs)),
        "orig_train_episodes": int(train_avail),
        "seed": int(RANDOM_SEED),
        "sampling": "uniform sample target_total from original train (no replacement); permute and split 90/10 (disjoint).",
        "overwrite": bool(OVERWRITE),
    }
    with open(subset_root / "subset_info.json", "w") as f:
        json.dump(info, f, indent=2)

    gc.collect()
    print(f"Subset written to {subset_root}")

print("\nDone. Subset roots:")
for k, v in subset_roots.items():
    print(f"  {k}: {v}")


Dataset dir  : /Users/julianquast/Downloads/pusht_noise
Subset specs : {'1k': 1000, '10k': 10000}
Random seed  : 0
Original train episodes: 18,685

== Building subset 1k (1,000 trajectories) ==
[1k] new train=900, new val=100 (both from original train)
Subset written to /Users/julianquast/Downloads/pusht_noise_1k

== Building subset 10k (10,000 trajectories) ==
[10k] new train=9,000, new val=1,000 (both from original train)
Subset written to /Users/julianquast/Downloads/pusht_noise_10k

Done. Subset roots:
  1k: /Users/julianquast/Downloads/pusht_noise_1k
  10k: /Users/julianquast/Downloads/pusht_noise_10k


In [5]:
# === Sanity-check subsets: seq_lengths, obses, and episode-aligned action/state tensors ===
# Checks:
#  - seq_lengths.pkl exists, sane, and defines N episodes
#  - obses/episode_* count == N (if obses exists) and indices are contiguous
#  - rel_actions.pth (if present) has first dim == N
#  - states.pth / states_constant.pth (if present) has first dim == N
#  - any other *.pth/*.pt tensors at top-level with first dim == N are reported

import pickle, re, gc
from pathlib import Path
import numpy as np
import torch

# Read-only checks: autograd off, torch limited to a single thread.
torch.set_grad_enabled(False)
torch.set_num_threads(1)

# File names looked up inside every split directory.
LENGTHS_FNAME = "seq_lengths.pkl"
ACTIONS_FNAME = "rel_actions.pth"
STATES_FNAME = "states.pth"
STATES_CONST_FNAME = "states_constant.pth"

# Subset roots to verify: siblings of DATASET_DIR named <dataset name>_<tag>.
DATASET_DIR = Path("/Users/julianquast/Downloads/pusht_noise").expanduser()
SUBSET_DIRS = [
    DATASET_DIR.parent / f"{DATASET_DIR.name}_1k",
    DATASET_DIR.parent / f"{DATASET_DIR.name}_10k",
]

def load_lengths(split_dir: Path) -> np.ndarray:
    """Return the pickled lengths file of ``split_dir`` as a flat int64 array.

    Raises:
        FileNotFoundError: If ``LENGTHS_FNAME`` is not present in ``split_dir``.
    """
    lengths_path = split_dir / LENGTHS_FNAME
    if not lengths_path.exists():
        raise FileNotFoundError(f"Missing {lengths_path}")

    with open(lengths_path, "rb") as fh:
        raw = pickle.load(fh)

    if isinstance(raw, np.ndarray):
        lengths = raw.astype(np.int64, copy=False)
    else:
        lengths = np.asarray(raw, dtype=np.int64)
    return lengths.reshape(-1)

def count_obs_eps(obs_dir: Path) -> int:
    """Count ``episode_*`` entries in ``obs_dir`` and verify they are numbered 0..n-1.

    Args:
        obs_dir: Directory holding one ``episode_<digits>`` entry per episode.

    Returns:
        Number of matching entries; 0 if the directory is missing or holds
        no ``episode_<digits>`` entry.

    Raises:
        ValueError: If the episode indices are not exactly 0, 1, ..., n-1
            (a gap, a duplicate index, or a start other than 0).
    """
    if not obs_dir.exists():
        return 0

    pat = re.compile(r"episode_(\d+)")
    ids = []
    for p in obs_dir.glob("episode_*"):
        m = pat.search(p.stem)
        if m:
            ids.append(int(m.group(1)))
    if not ids:
        return 0

    ids.sort()
    # Compare against the full expected sequence. Checking only min, max and
    # count would accept e.g. [0, 1, 1, 3], where a duplicate hides a gap.
    if ids != list(range(len(ids))):
        raise ValueError(
            f"obses episode indices not contiguous in {obs_dir} "
            f"(min={ids[0]}, max={ids[-1]}, n={len(ids)}, unique={len(set(ids))})"
        )
    return len(ids)

def torch_load_safe(path: Path):
    """Load a torch file onto the CPU with a plain (non-mmap) read."""
    # Plain load on purpose: the checks do not need mmap and avoid its quirks.
    filename = str(path)
    return torch.load(filename, map_location="cpu")

def check_tensor_first_dim(path: Path, expected_n: int) -> bool:
    """Check that the tensor(s) stored at ``path`` have first dimension ``expected_n``.

    One line is printed per checked item. A missing file is reported as
    skipped and counts as OK; a file that cannot be loaded counts as a
    failure. For a dict, every tensor value with at least one dimension is
    checked. Objects of any other type are reported but not checked.

    Returns:
        False if the file failed to load or any checked tensor has the wrong
        first dimension, True otherwise.
    """
    if not path.exists():
        print(f"    - {path.name}: (missing) [SKIP]")
        return True

    try:
        loaded = torch_load_safe(path)
    except Exception as exc:
        print(f"    - {path.name}: could not load ({type(exc).__name__}: {exc})")
        return False

    passed = True
    if torch.is_tensor(loaded):
        if loaded.ndim < 1:
            print(f"    - {path.name}: tensor has ndim<1 [WARN]")
        elif loaded.shape[0] == expected_n:
            print(f"    - {path.name}: OK shape={tuple(loaded.shape)}")
        else:
            print(f"    - {path.name}: BAD first dim={loaded.shape[0]} expected={expected_n}")
            passed = False
    elif isinstance(loaded, dict):
        # Some files are dicts; every tensor value with >= 1 dim is checked.
        n_checked = 0
        for key, value in loaded.items():
            if not (torch.is_tensor(value) and value.ndim >= 1):
                continue
            n_checked += 1
            if value.shape[0] == expected_n:
                print(f"    - {path.name}[{key!r}]: OK shape={tuple(value.shape)}")
            else:
                print(f"    - {path.name}[{key!r}]: BAD first dim={value.shape[0]} expected={expected_n}")
                passed = False
        if n_checked == 0:
            print(f"    - {path.name}: dict (no tensor fields checked) [INFO]")
    else:
        print(f"    - {path.name}: type={type(loaded).__name__} (not checked) [INFO]")

    del loaded
    gc.collect()
    return passed

def check_split(split_dir: Path) -> bool:
    """Validate one split directory and print a per-item report.

    Checks performed:
      - the lengths file loads, is non-empty and holds only positive values;
        its length N is the expected number of episodes
      - if ``obses/`` exists, it holds exactly N contiguously numbered episodes
      - the action / state tensor files, when present, have first dim == N
      - every other top-level ``*.pth`` / ``*.pt`` file is loaded and reported
        (informational only, does not affect the verdict)

    Args:
        split_dir: Directory of a single split (e.g. ``<subset>/train``).

    Returns:
        True if no issue was found, False otherwise. The same verdict is
        printed as the final ``RESULT`` line.

    Raises:
        FileNotFoundError: If the lengths file is missing.
        ValueError: If the lengths are empty, non-positive or non-finite, or
            if the ``obses/`` episode numbering is not contiguous.
    """
    lengths = load_lengths(split_dir)
    n = len(lengths)

    print(f"  split: {split_dir.name}")
    if n == 0:
        # min()/max()/mean() below would fail on an empty array with an
        # opaque numpy reduction error; fail with a clear message instead.
        raise ValueError(f"{split_dir}: seq_lengths is empty")
    print(f"    episodes (seq_lengths): {n:,} | minT={int(lengths.min())} maxT={int(lengths.max())} meanT={float(lengths.mean()):.2f}")

    if np.any(lengths <= 0):
        raise ValueError(f"{split_dir}: seq_lengths contains non-positive values")
    if np.any(~np.isfinite(lengths)):
        raise ValueError(f"{split_dir}: seq_lengths contains non-finite values")

    ok = True

    obs_dir = split_dir / "obses"
    if obs_dir.exists():
        obs_n = count_obs_eps(obs_dir)
        print(f"    obses/: {obs_n:,} episode_* files")
        if obs_n != n:
            print(f"    WARNING: obses count ({obs_n}) != seq_lengths count ({n})")
            # A count mismatch is a real inconsistency; it must not end in
            # "RESULT: OK".
            ok = False
    else:
        print(f"    obses/: (not present)")

    # Explicit action/state checks. Every file is checked even after a
    # failure so the report stays complete.
    known = (ACTIONS_FNAME, STATES_FNAME, STATES_CONST_FNAME)
    for fname in known:
        ok &= check_tensor_first_dim(split_dir / fname, n)

    # Optional: scan other top-level tensor files and report their shape.
    extra = sorted([*split_dir.glob("*.pth"), *split_dir.glob("*.pt")])
    extra = [p for p in extra if p.name not in known]
    if extra:
        print(f"    other tensor files (top-level): {len(extra)}")
        for p in extra:
            try:
                obj = torch_load_safe(p)
                if torch.is_tensor(obj) and obj.ndim >= 1:
                    tag = "OK" if obj.shape[0] == n else "mismatch"
                    print(f"    - {p.name}: {tag} shape={tuple(obj.shape)}")
                else:
                    print(f"    - {p.name}: type={type(obj).__name__} [INFO]")
                del obj
            except Exception as e:
                print(f"    - {p.name}: could not load ({type(e).__name__}: {e})")
            gc.collect()

    print("    RESULT:", "OK\n" if ok else "HAS ISSUES (see above)\n")
    return ok

# Driver: check the train and val split of every configured subset root.
# A missing root or split is reported and skipped rather than raising.
for root in SUBSET_DIRS:
    print(f"\n=== Checking subset: {root} ===")
    if not root.exists():
        print("  MISSING (dir does not exist)")
        continue
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.exists():
            print(f"  MISSING split dir: {split_dir}")
            continue
        # Prints its own report; may raise on malformed lengths or obses/.
        check_split(split_dir)



=== Checking subset: /Users/julianquast/Downloads/pusht_noise_1k ===
  split: train
    episodes (seq_lengths): 900 | minT=49 maxT=246 meanT=125.83
    obses/: 900 episode_* files
    - rel_actions.pth: OK shape=(900, 246, 2)
    - states.pth: OK shape=(900, 246, 5)
    - states_constant.pth: (missing) [SKIP]
    other tensor files (top-level): 3
    - abs_actions.pth: OK shape=(900, 246, 2)
    - tokens.pth: type=list [INFO]
    - velocities.pth: OK shape=(900, 246, 2)
    RESULT: OK

  split: val
    episodes (seq_lengths): 100 | minT=49 maxT=229 meanT=126.56
    obses/: 100 episode_* files
    - rel_actions.pth: OK shape=(100, 246, 2)
    - states.pth: OK shape=(100, 246, 5)
    - states_constant.pth: (missing) [SKIP]
    other tensor files (top-level): 3
    - abs_actions.pth: OK shape=(100, 246, 2)
    - tokens.pth: type=list [INFO]
    - velocities.pth: OK shape=(100, 246, 2)
    RESULT: OK


=== Checking subset: /Users/julianquast/Downloads/pusht_noise_10k ===
  split: train
  

In [4]:
# === Compute mean/std for actions + states (streaming; safe for large files) ===
# Computes per-dimension mean/std over ALL timesteps of all episodes in a split.
# Supports:
#   actions: [N, T, A] or [N, A]
#   states : [N, T, D] or [N, D]
#
# Updates vs your version:
#   - Optional ACTION_SCALE (e.g. 100.0) applied BEFORE stats
#   - Optional TIME_STRIDE (e.g. 5) applied on time axis BEFORE stats
#   - Reports BOTH raw and scaled/strided stats for actions (so you can compare)
#   - Still streams to avoid RAM spikes

import gc
from pathlib import Path
import numpy as np
import torch
from typing import Any

# Statistics only: autograd off, torch limited to a single thread.
torch.set_grad_enabled(False)
torch.set_num_threads(1)

# File names looked up inside every split directory.
ACTIONS_FNAME = "rel_actions.pth"
STATES_FNAME = "states.pth"
STATES_CONST_FNAME = "states_constant.pth"

# (subset root, display tag) pairs to report on.
DATASET_DIR = Path("/Users/julianquast/Downloads/pusht_noise").expanduser()
TARGETS = [
    (DATASET_DIR.parent / f"{DATASET_DIR.name}_1k", "1k"),
    (DATASET_DIR.parent / f"{DATASET_DIR.name}_10k", "10k"),
]

# Number of leading-axis rows (episodes) folded into the running statistics
# per step. If memory is tight, reduce this. If it's fast, you can increase.
BATCH_EPISODES = 256

# --- YOUR NORMALIZATION / SAMPLING SETTINGS ---
ACTION_SCALE = 100.0     # set to 1.0 to disable; if you normalize actions by /100 in training, keep 100.0
TIME_STRIDE = 1          # set to 5 if you want stats computed on x[:, ::5, :] (training-time stride)

# -------------------- Core streaming stats --------------------
def _flatten_samples(x: torch.Tensor) -> torch.Tensor:
    """
    Collapse all leading axes so the result is [num_samples, feat_dim].

    A 1-D tensor is treated as num_samples scalars (feat_dim == 1); otherwise
    the last axis is kept as the feature axis.
    """
    feat_dim = 1 if x.ndim == 1 else x.shape[-1]
    return x.reshape(-1, feat_dim)

def _stream_mean_std_rows(x2: torch.Tensor, chunk_rows: int = 1_000_000) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Per-feature mean and population std of x2, accumulated chunk by chunk.

    x2: [num_samples, feat_dim] tensor; each chunk of at most chunk_rows rows
        is converted to float64 before it is reduced.
    returns (mean, std, count): mean/std as float32 tensors of length
        feat_dim, count as the number of rows seen. With zero rows both
        tensors are all zeros and count is 0.
    """
    feat_dim = x2.shape[1]
    count = 0
    running_mean = torch.zeros(feat_dim, dtype=torch.float64)
    running_m2 = torch.zeros(feat_dim, dtype=torch.float64)  # sum of squared deviations

    total_rows = x2.shape[0]
    for lo in range(0, total_rows, chunk_rows):
        chunk = x2[lo:lo + chunk_rows].to(torch.float64)
        rows = chunk.shape[0]
        if rows == 0:
            continue

        chunk_mean = chunk.mean(dim=0)
        chunk_m2 = chunk.var(dim=0, unbiased=False) * rows

        if count == 0:
            running_mean, running_m2, count = chunk_mean, chunk_m2, rows
            continue

        # Merge the chunk's (mean, M2) into the running pair.
        shift = chunk_mean - running_mean
        merged = count + rows
        running_mean = running_mean + shift * (rows / merged)
        running_m2 = running_m2 + chunk_m2 + (shift * shift) * (count * rows / merged)
        count = merged

    # Population variance; the max() guards the division when nothing was seen.
    variance = running_m2 / max(count, 1)
    return running_mean.to(torch.float32), torch.sqrt(variance).to(torch.float32), int(count)

def _load_tensor(path: Path) -> Any:
    """Load a torch file onto the CPU and return whatever object it holds."""
    filename = str(path)
    return torch.load(filename, map_location="cpu")

def _extract_tensor(obj: Any, fname: str) -> torch.Tensor | None:
    if torch.is_tensor(obj):
        return obj
    if isinstance(obj, dict):
        # pick first tensor value (heuristic)
        for v in obj.values():
            if torch.is_tensor(v):
                return v
    return None

def _apply_time_stride(x: torch.Tensor, stride: int) -> torch.Tensor:
    """Keep every ``stride``-th step along axis 1 of a tensor with >= 3 dims.

    The input is returned unchanged when ``stride`` is 1 or less, or when the
    tensor has fewer than three dimensions (no separate time axis under the
    [N, T, D] convention used here).
    """
    if stride > 1 and x.ndim >= 3:
        return x[:, ::stride, ...]
    return x

def compute_stats_for_file(
    split_dir: Path,
    fname: str,
    *,
    action_scale: float = 1.0,
    time_stride: int = 1,
) -> dict:
    """Compute per-feature mean/std for one tensor file of a split.

    The tensor is optionally strided along axis 1, then processed in batches
    of ``BATCH_EPISODES`` rows of its first axis. Each batch is divided by
    ``action_scale`` (when it is not 1.0), flattened to
    ``[num_samples, feat_dim]`` and merged into running statistics.

    NOTE(review): every position along the flattened leading axes counts as a
    sample. If shorter episodes are padded up to a common length, the padded
    steps are included in the statistics; confirm that this is intended.

    Args:
        split_dir: Directory containing the file.
        fname: File name inside ``split_dir``.
        action_scale: Divisor applied to the values before the statistics.
        time_stride: Step along axis 1; only applied to tensors with >= 3 dims.

    Returns:
        ``{"file", "present": False}`` if the file is missing;
        ``{"file", "present": True, "error"}`` if no usable tensor was found;
        otherwise a dict with ``samples``, ``feat_dim``, ``mean`` and ``std``
        (float32 numpy arrays), ``action_scale`` and ``time_stride``.
    """
    p = split_dir / fname
    if not p.exists():
        return {"file": fname, "present": False}

    obj = _load_tensor(p)
    x = _extract_tensor(obj, fname)
    if x is None:
        return {"file": fname, "present": True, "error": f"Not a tensor or dict-with-tensor: {type(obj).__name__}"}

    if x.ndim < 1:
        return {"file": fname, "present": True, "error": f"Tensor has ndim={x.ndim}"}

    # Apply time stride if applicable (basic slicing, so no data is copied).
    x = _apply_time_stride(x, time_stride)

    scale = float(action_scale)
    N = int(x.shape[0])

    # Running statistics, kept in float64.
    total_n = 0
    mean = torch.zeros(x.shape[-1] if x.ndim > 1 else 1, dtype=torch.float64)
    M2 = torch.zeros_like(mean)

    for s in range(0, N, BATCH_EPISODES):
        xb = x[s:s + BATCH_EPISODES]
        # Scale per batch: dividing the whole tensor up front would create a
        # second full-size copy and defeat the point of streaming.
        if scale != 1.0:
            xb = xb / scale
        x2 = _flatten_samples(xb)  # [num_samples, feat_dim]

        b_mean, b_std, nb = _stream_mean_std_rows(x2)
        b_mean64 = b_mean.to(torch.float64)
        b_var64 = (b_std.to(torch.float64) ** 2)

        if total_n == 0:
            mean = b_mean64
            M2 = b_var64 * nb
            total_n = nb
        else:
            # Merge the batch's (mean, M2) into the running pair.
            delta = b_mean64 - mean
            new_n = total_n + nb
            mean = mean + delta * (nb / new_n)
            M2 = M2 + b_var64 * nb + (delta * delta) * (total_n * nb / new_n)
            total_n = new_n

        del xb, x2
        gc.collect()

    # Population variance; the max() guards the division for an empty tensor.
    var = M2 / max(total_n, 1)
    std = torch.sqrt(var)

    # cleanup
    del obj, x
    gc.collect()

    return {
        "file": fname,
        "present": True,
        "samples": int(total_n),
        "feat_dim": int(mean.numel()),
        "mean": mean.to(torch.float32).numpy(),
        "std": std.to(torch.float32).numpy(),
        "action_scale": float(action_scale),
        "time_stride": int(time_stride),
    }

def pretty_print_stats(stats: dict, max_dims: int = 10, label: str = ""):
    """Print a stats dict (as built for one file) in a compact, aligned form.

    A missing file and an error entry are each reported on a single line.
    Otherwise a header line is followed by the mean and std vectors, cut to
    the first ``max_dims`` features with a note on how many were left out.
    ``label`` is appended to the file name in every message.
    """
    if not stats.get("present", False):
        print(f"    - {stats['file']}{label}: missing")
        return
    if "error" in stats:
        print(f"    - {stats['file']}{label}: ERROR: {stats['error']}")
        return

    mean_vec, std_vec, n_feat = stats["mean"], stats["std"], stats["feat_dim"]
    n_shown = min(n_feat, max_dims)

    if n_feat <= max_dims:
        suffix = ""
    else:
        suffix = f" ... (+{n_feat - max_dims} dims)"

    print(f"    - {stats['file']}{label}: samples={stats['samples']:,} feat_dim={n_feat} (scale={stats['action_scale']}, stride={stats['time_stride']})")
    # "std " carries a trailing space so both colons line up.
    for title, vec in (("mean", mean_vec), ("std ", std_vec)):
        rendered = np.array2string(vec[:n_shown], precision=6, separator=", ")
        print(f"      {title}: {rendered}{suffix}")

# Driver: print action/state statistics for the train and val split of every
# target subset. A missing root or split is reported and skipped.
for root, tag in TARGETS:
    print(f"\n=== Stats for subset {tag}: {root} ===")
    if not root.exists():
        print("  MISSING subset directory")
        continue

    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.exists():
            print(f"  MISSING split: {split_dir}")
            continue

        print(f"  split: {split}")

        # Actions: report RAW and (optionally) NORMALIZED/STRIDED
        a_raw = compute_stats_for_file(split_dir, ACTIONS_FNAME, action_scale=1.0, time_stride=1)
        pretty_print_stats(a_raw, label=" [RAW]")

        # NOTE(review): a_norm is computed (the file is loaded a second time)
        # even when the settings equal the raw ones and nothing is printed.
        a_norm = compute_stats_for_file(split_dir, ACTIONS_FNAME, action_scale=ACTION_SCALE, time_stride=TIME_STRIDE)
        # Only print normalized if it differs from raw settings
        if ACTION_SCALE != 1.0 or TIME_STRIDE != 1:
            pretty_print_stats(a_norm, label=f" [SCALED/STRIDED]")

        # States: never scaled here (action_scale=1.0); strided with TIME_STRIDE
        # so the stats match the training-time stride.
        s_stats = compute_stats_for_file(split_dir, STATES_FNAME, action_scale=1.0, time_stride=TIME_STRIDE)
        sc_stats = compute_stats_for_file(split_dir, STATES_CONST_FNAME, action_scale=1.0, time_stride=TIME_STRIDE)

        pretty_print_stats(s_stats)
        pretty_print_stats(sc_stats)




=== Stats for subset 1k: /Users/julianquast/Downloads/pusht_noise_1k ===
  split: train
    - rel_actions.pth [RAW]: samples=221,400 feat_dim=2 (scale=1.0, stride=1)
      mean: [-0.312646,  0.260833]
      std : [15.063884, 14.829579]
    - rel_actions.pth [SCALED/STRIDED]: samples=221,400 feat_dim=2 (scale=100.0, stride=1)
      mean: [-0.003126,  0.002608]
      std : [0.150639, 0.148296]
    - states.pth: samples=221,400 feat_dim=5 (scale=1.0, stride=1)
      mean: [118.281876, 148.32808 , 125.58703 , 139.52927 ,   1.085302]
      std : [137.72513, 161.16718, 133.32805, 145.71861,   1.73888]
    - states_constant.pth: missing
  split: val
    - rel_actions.pth [RAW]: samples=24,600 feat_dim=2 (scale=1.0, stride=1)
      mean: [-0.369958,  0.402305]
      std : [15.223104, 15.381005]
    - rel_actions.pth [SCALED/STRIDED]: samples=24,600 feat_dim=2 (scale=100.0, stride=1)
      mean: [-0.0037  ,  0.004023]
      std : [0.152231, 0.15381 ]
    - states.pth: samples=24,600 feat_dim=5