# PushT Noise discretization & dataset variants

This notebook automates the PushT action discretization pipeline:

1. Copy the original `pusht_noise` dataset into a new folder that will hold all discretized assets.
2. Fit MiniBatchKMeans codebooks (`K=9` and `K=15`) on the train split and build a deterministic `r=24` fixed-compass codebook (center + 8 compass directions).
3. Discretize `rel_actions.pth` for every split into `rel_actions_discretized_<variant>.pth` tensors that keep the original `(E, T_max, 2)` shape but replace each valid step's continuous action with its nearest centroid (padding steps beyond an episode's length are set to zero).
4. Add a `states_constant.pth` companion file (zeros with the same shape/dtype as `states.pth`) for non-proprio training tricks.
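5. Materialize three full datasets on disk: the base `pusht_noise_discretized` copy, plus two subsampled variants (`pusht_noise_discretized_1k` and `_10k`). The subsampled variants keep 1,000 and 10,000 trajectories in total. These are drawn uniformly at random without replacement and allocated across `train`/`val` in proportion to the split sizes.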
6. Save all fitted codebooks under `codebooks/` to make downstream reuse trivial.

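Configure the source dataset via the environment variable `SRC_DATASET_DIR` (falling back to `DATASET_DIR`, then `./data`). It may point either at the dataset itself (a folder with `train/` and `val/`) or at its parent, in which case `<dir>/pusht_noise` is used. Set `DST_DATASET_DIR` to choose the output location. Otherwise it defaults to a `pusht_noise_discretized` folder next to the resolved source dataset.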


In [None]:
import os

# Machine-specific default dataset location. `setdefault` only fills it in when
# DATASET_DIR is not already exported, so an environment-provided value is kept
# instead of being silently overwritten by this personal path.
os.environ.setdefault(
    "DATASET_DIR",
    "/Users/julianquast/Documents/Documents - pythonPedro (Mac)/Bachelor Thesis/Datasets/pusht_noise",
)


In [None]:
# === Imports & global config ===
import os, gc, math, json, pickle, shutil, re
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
import psutil

# Nothing in this notebook needs autograd; disable it globally.
torch.set_grad_enabled(False)

SPLITS = ("train", "val")
# Per-split file names inside the dataset tree.
ACTIONS_FNAME = "rel_actions.pth"
LENGTHS_FNAME = "seq_lengths.pkl"
STATES_FNAME = "states.pth"
STATE_CONSTANT_FNAME = "states_constant.pth"

# Discretization variants. Each key becomes the `<variant>` suffix of
# `rel_actions_discretized_<variant>.pth` and `codebooks/<variant>_centroids.npy`.
VARIANTS = {
    "kmeans_k9": {"type": "kmeans", "K": 9},
    "kmeans_k15": {"type": "kmeans", "K": 15},
    # Center + 8 evenly spaced directions = 9 centroids, as described in the
    # notebook overview and matching `build_fixed_compass`'s default n_dirs.
    "fixed_compass_r24": {"type": "fixed_compass", "r": 24.0, "n_dirs": 8, "include_center": True},
}

# Total number of trajectories (summed over splits) kept in each subset dataset.
SUBSET_SPECS = {
    "1k": 1_000,
    "10k": 10_000,
}

# Seeds MiniBatchKMeans and the subset episode sampling.
RANDOM_SEED = int(os.getenv("DISCRETIZATION_SEED", "0"))


def _resolve_dataset_dir(path_str: str, default_leaf: str = "pusht_noise") -> Path:
    base = Path(path_str).expanduser()
    if (base / "train").exists() and (base / "val").exists():
        return base
    candidate = base / default_leaf
    if (candidate / "train").exists() and (candidate / "val").exists():
        return candidate
    raise FileNotFoundError(f"Could not resolve dataset directory from {path_str!r}. Provide SRC_DATASET_DIR or DATASET_DIR.")

# Source: SRC_DATASET_DIR wins, then DATASET_DIR, then a local "data" folder.
_fallback_src = os.getenv("DATASET_DIR", "data")
src_root_env = os.getenv("SRC_DATASET_DIR", _fallback_src)
SRC_DATASET_DIR = _resolve_dataset_dir(src_root_env)

# Destination: DST_DATASET_DIR, else a folder next to the resolved source.
_fallback_dst = str(SRC_DATASET_DIR.parent / "pusht_noise_discretized")
dst_root_env = os.getenv("DST_DATASET_DIR", _fallback_dst)
DST_DATASET_DIR = Path(dst_root_env).expanduser()

print(f"Source dataset : {SRC_DATASET_DIR}")
print(f"Discretized dir: {DST_DATASET_DIR}")
print(f"Variants       : {list(VARIANTS)}")
print(f"Subset specs   : {SUBSET_SPECS}")
print(f"Random seed    : {RANDOM_SEED}")


In [None]:
# === Discretization helpers ===
def load_actions_and_lengths(split_dir: Path):
    """Load a split's padded action tensor together with its episode lengths.

    Args:
        split_dir: split folder containing ``ACTIONS_FNAME`` and ``LENGTHS_FNAME``.

    Returns:
        ``(actions, lengths)``. ``actions`` is the tensor stored in
        ``ACTIONS_FNAME``, loaded onto the CPU (``(E, T_max, 2)`` per the
        notebook overview). ``lengths`` is an int64 array of per-episode
        valid lengths.

    Raises:
        ValueError: if the two files disagree on the number of episodes.
    """
    actions = torch.load(split_dir / ACTIONS_FNAME, map_location="cpu")
    lengths = load_lengths_only(split_dir)
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if actions.shape[0] != len(lengths):
        raise ValueError(f"Mismatch: {actions.shape[0]} tensors vs {len(lengths)} lengths")
    return actions, lengths


def load_lengths_only(split_dir: Path) -> np.ndarray:
    """Read just the per-episode lengths pickle of a split as an int64 array.

    NOTE: ``pickle.load`` runs arbitrary code; only use on trusted dataset files.
    """
    lengths_path = split_dir / LENGTHS_FNAME
    with lengths_path.open("rb") as handle:
        raw = pickle.load(handle)
    return np.asarray(raw, dtype=np.int64)


def flatten_valid(actions_3d: torch.Tensor, lengths: np.ndarray) -> np.ndarray:
    """Gather the valid (non-padded) steps of every episode into one 2-D array.

    Args:
        actions_3d: padded tensor of shape ``(E, T_max, D)`` (D == 2 for the
            actions in this notebook).
        lengths: per-episode valid lengths; only the first ``E`` entries are
            used. Lengths above ``T_max`` are clamped to ``T_max``, and
            non-positive lengths contribute no rows.

    Returns:
        float32 array of shape ``(n_valid, D)`` in episode-major, time-minor
        order. This is the order in which per-step labels are consumed when
        mapped back to episodes. Returns an empty ``(0, 2)`` array when no
        step is valid.
    """
    E, T_max, _ = actions_3d.shape
    lens = np.asarray(lengths, dtype=np.int64)[:E]
    # (E, T_max) mask of valid steps. Boolean indexing walks it row-major,
    # i.e. episode by episode. That is the same order as concatenating
    # per-episode slices, but without a Python loop over E.
    valid = np.arange(T_max)[None, :] < lens[:, None]
    if not valid.any():
        return np.empty((0, 2), dtype=np.float32)
    flat = actions_3d.detach().cpu().numpy()[valid]
    return flat.astype(np.float32, copy=False)


def fit_codebook_kmeans(xy_flat: np.ndarray, K: int, seed: int) -> np.ndarray:
    """Fit a ``K``-centroid codebook to 2-D actions with MiniBatchKMeans.

    Args:
        xy_flat: ``(N, 2)`` array of valid (non-padded) actions.
        K: number of centroids.
        seed: ``random_state`` passed to MiniBatchKMeans for a reproducible fit.

    Returns:
        float32 array of shape ``(K, 2)`` with the fitted cluster centers.

    Raises:
        ValueError: if ``xy_flat`` is not ``(N, 2)`` or has fewer than ``K`` rows.
    """
    # Explicit raises rather than `assert`, so validation survives `python -O`.
    if xy_flat.ndim != 2 or xy_flat.shape[1] != 2:
        raise ValueError(f"Expected an (N, 2) array of actions, got shape {xy_flat.shape}")
    n_points = xy_flat.shape[0]
    if n_points == 0 or n_points < K:
        raise ValueError(f"Need at least K={K} non-padded actions to fit {K} centroids, got {n_points}")
    # Roughly 1/32 of the data per mini-batch, clamped to [2048, 8192].
    batch_size = min(8192, max(2048, n_points // 32))
    kmeans = MiniBatchKMeans(
        n_clusters=K,
        random_state=seed,
        batch_size=batch_size,
        n_init="auto",
    )
    kmeans.fit(xy_flat)
    return kmeans.cluster_centers_.astype(np.float32, copy=False)


def predict_labels_chunked(xy_flat: np.ndarray, centers: np.ndarray, chunk_size: int | None = None) -> np.ndarray:
    N = xy_flat.shape[0]
    if N == 0:
        return np.empty((0,), dtype=np.int64)
    K = centers.shape[0]
    if chunk_size is None:
        avail_bytes = psutil.virtual_memory().available
        bytes_per_row = K * 4  # float32 distances
        target = max(100_000, min(N, (avail_bytes // 8) // max(bytes_per_row, 1)))
        chunk_size = int(target)
    labels = np.empty(N, dtype=np.int64)
    center_norm = (centers ** 2).sum(axis=1)
    for start in range(0, N, chunk_size):
        end = min(N, start + chunk_size)
        block = xy_flat[start:end]
        block_norm = (block ** 2).sum(axis=1, keepdims=True)
        dots = block @ centers.T
        d2 = block_norm + center_norm[None, :] - 2.0 * dots
        labels[start:end] = np.argmin(d2, axis=1)
    return labels


def map_labels_to_centroids(actions_3d: torch.Tensor, lengths: np.ndarray, labels: np.ndarray, centers: np.ndarray) -> torch.Tensor:
    """Scatter per-step centroid vectors back into a padded action tensor.

    Args:
        actions_3d: padded tensor ``(E, T_max, D)``; only its shape, dtype and
            device are used.
        lengths: per-episode valid lengths (first ``E`` entries used; values
            above ``T_max`` are clamped, non-positive ones mean no valid step).
        labels: one centroid index per valid step, in episode-major,
            time-minor order.
        centers: ``(K, D)`` centroid table indexed by ``labels``.

    Returns:
        Tensor like ``actions_3d`` in which valid steps hold
        ``centers[label]`` (cast to ``actions_3d``'s dtype) and padded steps
        are zero.

    Raises:
        ValueError: if ``labels`` does not have exactly one entry per valid step.
    """
    out = torch.zeros_like(actions_3d)
    E, T_max = actions_3d.shape[0], actions_3d.shape[1]
    lens = np.asarray(lengths, dtype=np.int64)[:E]
    # The same valid-step mask used for flattening, walked row-major, so
    # labels[i] lands on the i-th valid step. Clamping to T_max here keeps
    # this consistent with the flattening side.
    valid = np.arange(T_max)[None, :] < lens[:, None]
    n_valid = int(valid.sum())
    labels = np.asarray(labels)
    if len(labels) != n_valid:
        raise ValueError(f"Expected {n_valid} labels (one per valid step), got {len(labels)}")
    if n_valid:
        values = torch.from_numpy(np.asarray(centers)[labels]).to(device=out.device, dtype=out.dtype)
        out[torch.from_numpy(valid).to(out.device)] = values
    return out


def build_fixed_compass(r: float, n_dirs: int = 8, include_center: bool = True) -> np.ndarray:
    """Build a deterministic codebook of ``n_dirs`` evenly spaced vectors of length ``r``.

    Directions start at angle 0 (the +x axis) and advance in steps of
    ``2*pi / n_dirs``.

    Args:
        r: length of every non-center vector.
        n_dirs: number of directions; must be positive.
        include_center: if True, prepend the zero vector as centroid 0.

    Returns:
        float32 array of shape ``(n_dirs + include_center, 2)``.

    Raises:
        ValueError: if ``n_dirs`` is not positive.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if n_dirs <= 0:
        raise ValueError("n_dirs must be positive")
    angles = np.linspace(0.0, 2.0 * math.pi, n_dirs, endpoint=False)
    ring = np.stack([np.cos(angles), np.sin(angles)], axis=1).astype(np.float32) * float(r)
    if include_center:
        ring = np.vstack([np.zeros((1, 2), dtype=np.float32), ring])
    return ring.astype(np.float32)


def summarize_labels(labels: np.ndarray, K: int, name: str):
    """Print min/max per-centroid usage counts and the usage entropy in bits.

    ``K`` is used as the bincount ``minlength``, so unused centroids count as 0
    towards the minimum.
    """
    if labels.size > 0:
        usage = np.bincount(labels, minlength=K)
        freq = usage / usage.sum()
        nonzero = freq[freq > 0]
        bits = -(nonzero * np.log2(nonzero)).sum()
        print(f"[{name}] usage -> min={usage.min()} max={usage.max()} entropy={bits:.3f} bits")
    else:
        print(f"[{name}] No labels to summarize.")


In [None]:
# === Copy dataset & gather metadata ===
def ensure_dataset_copy(src: Path, dst: Path):
    """Copy the dataset tree ``src`` to ``dst`` unless ``dst`` already exists.

    The copy is staged in a sibling ``<dst>.partial`` directory and renamed
    into place only once it is complete. An interrupted run therefore never
    leaves a half-copied ``dst`` that later runs would mistake for a finished
    copy and skip. A stale staging directory from an earlier interrupted run
    is discarded first.
    """
    dst = Path(dst)
    if dst.exists():
        print(f"[skip] {dst} already exists")
        return
    staging = dst.with_name(dst.name + ".partial")
    if staging.exists():
        shutil.rmtree(staging)
    print(f"Copying dataset tree from {src} -> {dst} ...")
    shutil.copytree(src, staging)
    staging.rename(dst)
    print("Copy complete.")


ensure_dataset_copy(SRC_DATASET_DIR, DST_DATASET_DIR)

# Per-split episode metadata read from the source dataset; used later to size subsets.
split_meta: Dict[str, Dict[str, np.ndarray | int]] = {}
for split in SPLITS:
    ep_lengths = load_lengths_only(SRC_DATASET_DIR / split)
    n_eps = int(len(ep_lengths))
    longest = int(ep_lengths.max()) if ep_lengths.size else 0
    split_meta[split] = {"lengths": ep_lengths, "num_episodes": n_eps, "max_T": longest}
    print(f"[{split}] episodes={n_eps:,} | max_T={split_meta[split]['max_T']}")

total_eps = sum(meta["num_episodes"] for meta in split_meta.values())
print(f"Total episodes across splits: {total_eps:,}")


In [None]:
# === Fit codebooks (MiniBatchKMeans + fixed compass) ===
CODEBOOK_DIR = DST_DATASET_DIR / "codebooks"
CODEBOOK_DIR.mkdir(parents=True, exist_ok=True)

# Codebooks are fitted on the train split only.
train_actions, train_lengths = load_actions_and_lengths(SRC_DATASET_DIR / "train")
flat_train = flatten_valid(train_actions, train_lengths)
print(f"Train valid steps: {flat_train.shape[0]:,}")

discretization_centers: Dict[str, np.ndarray] = {}
codebook_paths: Dict[str, Path] = {}

for name, cfg in VARIANTS.items():
    kind = cfg["type"]
    if kind == "fixed_compass":
        centers = build_fixed_compass(cfg["r"], n_dirs=cfg.get("n_dirs", 8), include_center=cfg.get("include_center", True))
    elif kind == "kmeans":
        centers = fit_codebook_kmeans(flat_train, cfg["K"], seed=RANDOM_SEED)
        # Quick usage check on (at most) the first 200k training steps.
        preview = flat_train[: min(200_000, len(flat_train))]
        summarize_labels(predict_labels_chunked(preview, centers), centers.shape[0], f"{name}_sample")
    else:
        raise ValueError(f"Unknown variant type: {cfg['type']}")

    out_path = CODEBOOK_DIR / f"{name}_centroids.npy"
    np.save(out_path, centers)
    discretization_centers[name], codebook_paths[name] = centers, out_path
    print(f"Saved {name} codebook -> {out_path} | shape={centers.shape}")

# Drop the large train arrays before the per-split discretization pass.
train_actions = flat_train = None
gc.collect()


In [None]:
# === Discretize rel_actions & add states_constant ===
def ensure_states_constant(split_dir: Path):
    """Write ``STATE_CONSTANT_FNAME``: zeros with the shape and dtype of ``STATES_FNAME``.

    Does nothing (besides a log line) when the output file already exists.
    """
    target = split_dir / STATE_CONSTANT_FNAME
    if target.exists():
        print(f"[skip] {target.name} already exists")
        return
    states = torch.load(split_dir / STATES_FNAME, map_location="cpu")
    zeros = torch.zeros_like(states)
    torch.save(zeros, target)
    print(f"Saved {target} (zeros with shape={tuple(states.shape)})")
    # Release the (possibly large) tensors before the next split is processed.
    del states, zeros
    gc.collect()


def discretize_split(split: str, variant_name: str, centers: np.ndarray):
    """Quantize one split's relative actions to ``centers`` and save the result.

    Reads actions/lengths from ``SRC_DATASET_DIR/<split>``, prints label usage
    stats, and writes
    ``DST_DATASET_DIR/<split>/rel_actions_discretized_<variant_name>.pth``.
    """
    actions, lengths = load_actions_and_lengths(SRC_DATASET_DIR / split)
    valid_steps = flatten_valid(actions, lengths)
    assignment = predict_labels_chunked(valid_steps, centers)
    summarize_labels(assignment, centers.shape[0], name=f"{split}_{variant_name}")
    quantized = map_labels_to_centroids(actions, lengths, assignment, centers)
    target = DST_DATASET_DIR / split / f"rel_actions_discretized_{variant_name}.pth"
    torch.save(quantized, target)
    print(f"Saved {target} | shape={tuple(quantized.shape)}")
    del actions, valid_steps, assignment, quantized
    gc.collect()


# One discretized actions file per (variant, split) pair.
for variant_name, centers in discretization_centers.items():
    print(f"\n== Discretizing variant: {variant_name} ==")
    for split in SPLITS:
        discretize_split(split, variant_name, centers)

# The zero-state companion does not depend on the variant: one per split.
for split_dir in (DST_DATASET_DIR / split_name for split_name in SPLITS):
    ensure_states_constant(split_dir)


In [None]:
# === Verify discretized rel_actions files ===
_BYTES_PER_MB = 1024 ** 2

print("\n== Checking discretized rel_actions artifacts ==")
for split in SPLITS:
    print(f"[{split}]")
    for variant in VARIANTS:
        rel_path = Path(split) / f"rel_actions_discretized_{variant}.pth"
        abs_path = DST_DATASET_DIR / rel_path
        if abs_path.exists():
            size_mb = abs_path.stat().st_size / _BYTES_PER_MB
            print(f"  - {rel_path.name:>32} | size={size_mb:8.2f} MB")
        else:
            raise FileNotFoundError(f"Missing {rel_path} in {DST_DATASET_DIR}")
    print()


In [None]:
# === Subset helpers (1k / 10k datasets) ===
def allocate_subset_counts(total_target: int, split_meta: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    """Split a total episode budget across splits in proportion to their sizes.

    Splits are visited in ``split_meta``'s key order. Every split except the
    last receives ``round(total_target * share)`` episodes, capped at its size.
    The last split receives whatever remains (capped). Any shortfall left by
    capping is then handed out one episode at a time to splits with spare
    episodes.

    Args:
        total_target: total number of episodes to keep across all splits.
        split_meta: ``{split: {"num_episodes": n, ...}}``.

    Returns:
        ``{split: count}``. When ``total_target`` meets or exceeds the total
        available, every split keeps all of its episodes.
    """
    # Iterate the splits actually described by `split_meta` (not a global list),
    # so the sizing and the allocation always cover the same set of splits.
    splits = list(split_meta)
    capacity = {split: split_meta[split]["num_episodes"] for split in splits}
    total_available = sum(capacity.values())
    if total_target >= total_available:
        return dict(capacity)
    counts: Dict[str, int] = {}
    remaining = total_target
    for idx, split in enumerate(splits):
        if idx == len(splits) - 1:
            take = remaining
        else:
            prop = capacity[split] / total_available
            take = min(capacity[split], int(round(total_target * prop)))
        counts[split] = max(0, min(capacity[split], take))
        remaining -= counts[split]
    # Terminates: total_target < total_available guarantees spare capacity
    # while remaining > 0.
    while remaining > 0:
        for split in splits:
            if remaining == 0:
                break
            if counts[split] < capacity[split]:
                counts[split] += 1
                remaining -= 1
    return counts


def subset_tensor_file(path: Path, idxs: np.ndarray, orig_count: int):
    """Keep only episodes ``idxs`` of a per-episode ``torch.save`` file, in place.

    A tensor is treated as per-episode when its first dimension equals
    ``orig_count``. A list/tuple qualifies when its length does (it is
    re-saved as a list). Anything else is left untouched.

    Returns:
        True if the file was rewritten, False if it is missing or not per-episode.
    """
    if not path.exists():
        return False
    obj = torch.load(path, map_location="cpu")
    subset = None
    if torch.is_tensor(obj):
        if obj.shape[0] == orig_count:
            subset = obj[idxs]
    elif isinstance(obj, (list, tuple)) and len(obj) == orig_count:
        subset = [obj[int(i)] for i in idxs]
    if subset is None:
        return False
    torch.save(subset, path)
    return True


def subset_pickle_file(path: Path, idxs: np.ndarray, orig_count: int):
    """Keep only entries ``idxs`` of a per-episode pickled list/array, in place.

    NumPy arrays are converted to lists first (and saved back as lists). The
    file is rewritten only when the payload has exactly ``orig_count``
    entries. NOTE: ``pickle.load`` runs arbitrary code; trusted files only.

    Returns:
        True if the file was rewritten, False otherwise.
    """
    if not path.exists():
        return False
    with open(path, "rb") as fh:
        payload = pickle.load(fh)
    if isinstance(payload, np.ndarray):
        payload = payload.tolist()
    if not (isinstance(payload, list) and len(payload) == orig_count):
        return False
    picked = [payload[int(i)] for i in idxs]
    with open(path, "wb") as fh:
        pickle.dump(picked, fh)
    return True


# Matches the numeric episode index in names like "episode_012.pth".
_EPISODE_RE = re.compile(r"episode_(\d+)")


def build_episode_file_map(obs_dir: Path) -> Dict[int, Path]:
    """Map episode index -> path for every ``episode_*`` entry directly in ``obs_dir``.

    The index is the first ``episode_<digits>`` group in the entry's stem.
    Returns an empty dict when ``obs_dir`` does not exist. If several entries
    share an index, the one globbed last wins.
    """
    mapping: Dict[int, Path] = {}
    if not obs_dir.exists():
        return mapping
    for path in obs_dir.glob("episode_*"):
        m = _EPISODE_RE.search(path.stem)
        if m:
            mapping[int(m.group(1))] = path
    return mapping


def _episode_pad_width(paths, default: int = 3) -> int:
    """Return the shortest digit run among ``episode_<digits>`` names in ``paths``.

    For zero-padded names that include small indices (e.g. ``episode_000``),
    this is the padding width. ``default`` is used when no name matches.
    """
    widths = [len(m.group(1)) for p in paths if (m := _EPISODE_RE.search(p.stem))]
    return min(widths) if widths else default


def subset_observation_media(obs_dir: Path, idxs: np.ndarray):
    """Replace ``obs_dir``'s episode files by those of ``idxs``, renumbered 0..len(idxs)-1.

    New names reuse the zero-padding width of the source files, not a width
    derived from the subset size. For example, ``episode_000.pth`` stays
    3-digit and larger indices simply grow like ``f"{i:03d}"``, so the subset
    follows the same naming scheme as the full dataset and a loader that
    formats file names from the index keeps working. The new directory is
    built in ``<obs_dir>_subset_tmp`` and swapped in at the end.

    Raises:
        FileNotFoundError: if an index in ``idxs`` has no matching file.
    """
    if not obs_dir.exists():
        return
    idxs = np.asarray(idxs, dtype=np.int64)
    mapping = build_episode_file_map(obs_dir)
    pad = _episode_pad_width(mapping.values())
    tmp_dir = obs_dir.parent / (obs_dir.name + "_subset_tmp")
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    for new_idx, old_idx in enumerate(idxs):
        src_path = mapping.get(int(old_idx))
        if src_path is None:
            raise FileNotFoundError(f"Missing obs file for episode {old_idx} in {obs_dir}")
        dst_path = tmp_dir / f"episode_{new_idx:0{pad}d}{src_path.suffix}"
        shutil.copy2(src_path, dst_path)
    shutil.rmtree(obs_dir)
    tmp_dir.rename(obs_dir)


def subset_split_dir(split_dir: Path, idxs: np.ndarray, orig_count: int):
    """Subset every per-episode artifact inside ``split_dir`` to ``idxs``, in place.

    Order: ``*.pth`` then ``*.pt`` tensor/list files, then ``*.pkl`` pickles,
    then the ``obses/`` media folder. Files that are not per-episode
    (leading size != ``orig_count``) are left as they are.
    """
    idxs = np.asarray(idxs, dtype=np.int64)
    for pattern in ("*.pth", "*.pt"):
        for tensor_path in split_dir.glob(pattern):
            subset_tensor_file(tensor_path, idxs, orig_count)
    for pickle_path in split_dir.glob("*.pkl"):
        subset_pickle_file(pickle_path, idxs, orig_count)
    subset_observation_media(split_dir / "obses", idxs)


In [None]:
# === Build subset datasets (1k / 10k trajectories) ===
subset_roots: Dict[str, Path] = {}
for subset_name, target_total in SUBSET_SPECS.items():
    subset_root = DST_DATASET_DIR.parent / f"{DST_DATASET_DIR.name}_{subset_name}"
    subset_roots[subset_name] = subset_root
    if subset_root.exists():
        print(f"[skip] {subset_root} already exists")
        continue
    print(f"\n== Building subset {subset_name} ({target_total} trajectories) ==")
    # Build in a staging directory and rename only after every split has been
    # subset. Otherwise an interrupted run would leave a full-size (or
    # half-subset) copy at `subset_root`, which the existence check above
    # would then silently skip on every rerun.
    staging_root = subset_root.with_name(subset_root.name + ".partial")
    if staging_root.exists():
        shutil.rmtree(staging_root)
    shutil.copytree(DST_DATASET_DIR, staging_root)
    counts = allocate_subset_counts(target_total, split_meta)
    # Seed depends on the subset size so the 1k and 10k draws are independent
    # but reproducible.
    subset_rng = np.random.default_rng(RANDOM_SEED + target_total)
    for split in SPLITS:
        orig_count = split_meta[split]["num_episodes"]
        take = min(orig_count, counts.get(split, 0))
        if take <= 0:
            idxs = np.empty((0,), dtype=np.int64)
        elif take == orig_count:
            idxs = np.arange(orig_count, dtype=np.int64)
        else:
            idxs = np.sort(subset_rng.choice(orig_count, size=take, replace=False))
        print(f"[{subset_name}/{split}] keeping {len(idxs)} / {orig_count} episodes")
        subset_split_dir(staging_root / split, idxs, orig_count)
    staging_root.rename(subset_root)
    print(f"Subset {subset_name} written to {subset_root}")
