
# PushT Discretization & Codebook Notebook (K=9, K=16)

This notebook:
1. Loads **relative actions** (`rel_actions.pth`, shape `(E, T_max, 2)`) and **sequence lengths** (`seq_lengths.pkl`).
2. Fits **MiniBatchKMeans** codebooks on the **train** split for **K = 9** and **K = 16**.
3. Saves codebooks as `action_codebook_k9.npy` and `action_codebook_k16.npy` in the train directory.
4. Discretizes **train** and **val** splits to `rel_actions_discrete_9.pth` and `rel_actions_discrete_16.pth` using **float labels in [0..K-1]** (both channels set to the label, padding set to **0.0**).
5. Provides a **filtering/analysis** cell to measure the effect of removing episodes with outlier actions (does **not** change training data unless you explicitly use its masks).
6. Performs **sanity checks** on outputs (shape, label range, entropy, and a quick quantization-error (QE) check on a sample of actions).


In [None]:

# === Config & Imports ===
import os, math, gc, pickle
from pathlib import Path
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
import psutil

# ---- EDIT THESE to your dataset layout ----
TRAIN_DIR = Path(os.getenv("TRAIN_DIR", "data/pusht/train"))  # must contain rel_actions.pth and seq_lengths.pkl
VAL_DIR   = Path(os.getenv("VAL_DIR",   "data/pusht/val"))

# Codebook sizes to fit. Output file names (action_codebook_k{K}.npy,
# rel_actions_discrete_{K}.pth) are derived from these values, so they must
# match the sizes this notebook is meant to produce (K=9 and K=16).
K_LIST = [9, 16]
SEED = 0

# File names (expected existing inputs)
ACTIONS_FNAME = "rel_actions.pth"     # shape (E, T_max, 2), float
LENGTHS_FNAME = "seq_lengths.pkl"     # list/array of ints length E

# Output names
DISCRETE_TEMPLATE = "rel_actions_discrete_{K}.pth"  # same shape as actions, stores float labels in both channels (padding=0.0)
CODEBOOK_TEMPLATE = "action_codebook_k{K}.npy"      # (K, 2) float32

# Safety: nothing here needs gradients, so disable autograd globally;
# seed NumPy's legacy global RNG for reproducibility.
torch.set_grad_enabled(False)
np.random.seed(SEED)

print(f"TRAIN_DIR = {TRAIN_DIR.resolve()}")
print(f"VAL_DIR   = {VAL_DIR.resolve()}")
print(f"K_LIST    = {K_LIST}")


In [None]:

# === Utilities ===
def load_split(split_dir: Path):
    """Load one split's relative actions and per-episode lengths.

    Args:
        split_dir: Directory containing ``ACTIONS_FNAME`` and ``LENGTHS_FNAME``.

    Returns:
        ``(actions, lengths)``: ``actions`` is the object stored in ``ACTIONS_FNAME``
        (checked to be shaped ``(E, T_max, 2)``), ``lengths`` is an int64 array of
        shape ``(E,)`` with every entry in ``[0, T_max]``.

    Raises:
        AssertionError: if the shapes disagree or any length is out of range.
    """
    actions = torch.load(split_dir / ACTIONS_FNAME, map_location="cpu")  # (E, T_max, 2)
    # NOTE: pickle can execute code on load; only use trusted dataset files.
    with open(split_dir / LENGTHS_FNAME, "rb") as f:
        lengths = np.asarray(pickle.load(f), dtype=np.int64)            # (E,)
    assert actions.ndim == 3 and actions.shape[-1] == 2, "rel_actions must be (E, T_max, 2)"
    E, T_max, _ = actions.shape
    assert lengths.ndim == 1, f"seq_lengths must be 1-D, got shape {lengths.shape}"
    assert lengths.shape[0] == E, f"seq_lengths length ({lengths.shape[0]}) != num episodes ({E})"
    # Out-of-range lengths would make flatten_valid silently skip/truncate steps
    # while labels_to_tensor_like_floatonly expects exactly lengths[i] labels each.
    bad = np.flatnonzero((lengths < 0) | (lengths > T_max))
    assert bad.size == 0, f"seq_lengths out of range [0, {T_max}] for episodes {bad[:10].tolist()}"
    return actions, lengths

def flatten_valid(actions_3d: torch.Tensor, lengths: np.ndarray) -> np.ndarray:
    """Concatenate only valid (unpadded) steps into an (N, 2) float32 array.

    Episode ``i`` contributes its first ``lengths[i]`` steps (at most ``T_max``;
    non-positive lengths contribute nothing). Rows are in episode-major order,
    the same order ``labels_to_tensor_like_floatonly`` uses to scatter labels back.

    Args:
        actions_3d: (E, T_max, 2) tensor (a NumPy array is also accepted).
        lengths: Per-episode valid lengths, at least E entries.

    Returns:
        (N, 2) float32 array of valid steps; (0, 2) when there are none.
    """
    E, T_max, _ = actions_3d.shape
    lens = np.asarray(lengths, dtype=np.int64)[:E]
    # (E, T_max) mask: step t of episode i is valid iff t < lengths[i].
    valid = np.arange(T_max)[None, :] < lens[:, None]
    arr = actions_3d.cpu().numpy() if isinstance(actions_3d, torch.Tensor) else np.asarray(actions_3d)
    # Boolean indexing walks the mask in C order, i.e. episode by episode.
    return arr[valid].astype(np.float32, copy=False)

def fit_codebook_kmeans(xy_flat: np.ndarray, K: int, seed: int = 0, batch_size: int = 2048) -> np.ndarray:
    """Fit a MiniBatchKMeans codebook on flattened valid actions.

    Args:
        xy_flat: (N, 2) array of valid actions.
        K: Number of codebook entries (clusters).
        seed: ``random_state`` for MiniBatchKMeans.
        batch_size: Minibatch size for MiniBatchKMeans.

    Returns:
        (K, 2) float32 array of cluster centers.

    Raises:
        AssertionError: if ``xy_flat`` is not a non-empty (N, 2) array.
        ValueError: if there are fewer samples than clusters.
    """
    assert xy_flat.ndim == 2 and xy_flat.shape[1] == 2 and xy_flat.size > 0
    if xy_flat.shape[0] < K:
        raise ValueError(f"Need at least K={K} samples to fit the codebook, got {xy_flat.shape[0]}")
    kmeans = MiniBatchKMeans(
        n_clusters=K, random_state=seed, batch_size=batch_size, n_init="auto"
    )
    kmeans.fit(xy_flat)
    return kmeans.cluster_centers_.astype(np.float32, copy=False)

def predict_labels_chunked(xy_flat: np.ndarray, centers: np.ndarray, chunk_size: int = None) -> np.ndarray:
    """Assign each row of ``xy_flat`` to its nearest center (squared Euclidean distance).

    Distances are computed chunk by chunk so the (chunk, K) distance buffer stays
    bounded in memory.

    Args:
        xy_flat: (N, D) points.
        centers: (K, D) codebook.
        chunk_size: Rows per chunk. ``None`` derives it from available RAM: a budget
            of half the currently available memory, clamped to 100-400 MiB, for the
            float32 distance buffer (never fewer than 100k rows per chunk).

    Returns:
        int64 labels of shape (N,).
    """
    N = xy_flat.shape[0]
    K = centers.shape[0]
    if chunk_size is None:
        # psutil reports bytes; work in MiB so the clamp bounds are in consistent units.
        avail_mib = psutil.virtual_memory().available / 1024**2
        target_bytes = int(min(400.0, max(100.0, avail_mib * 0.5))) * 1024**2
        bytes_per_row = max(K, 1) * 4  # one float32 distance per center
        chunk_size = max(100_000, min(N, target_bytes // bytes_per_row))
    chunk_size = max(1, int(chunk_size))  # range() rejects a zero step
    labels = np.empty(N, dtype=np.int64)
    c2 = (centers ** 2).sum(axis=1)  # (K,)
    for s in range(0, N, chunk_size):
        X = xy_flat[s:s + chunk_size]  # (M, D)
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c; ||x||^2 is constant per row, so it
        # cannot change the argmin and is omitted (saves an (M, K) temporary).
        d2 = c2[None, :] - 2.0 * (X @ centers.T)
        labels[s:s + X.shape[0]] = np.argmin(d2, axis=1)
    return labels

def labels_to_tensor_like_floatonly(actions_3d: torch.Tensor, lengths: np.ndarray, labels: np.ndarray) -> torch.Tensor:
    """Scatter flat labels back into an (E, T_max, 2) tensor of float labels.

    ``labels`` must be in the episode-major order produced by ``flatten_valid``:
    episode ``i`` consumes the next ``lengths[i]`` labels (episodes with length <= 0
    consume none). The result has the dtype/device of ``actions_3d``; channels 0 and
    1 of every valid step hold the label as a float, and every padded step is 0.0.

    Raises:
        AssertionError: if ``labels`` does not hold exactly one label per valid step.
    """
    E, T_max, _ = actions_3d.shape
    step_counts = [max(int(lengths[i]), 0) for i in range(E)]
    # Check up front so a mismatch is reported clearly instead of as a shape error
    # halfway through the fill.
    assert sum(step_counts) == labels.shape[0], "Label count mismatch when reshaping"
    out = torch.zeros_like(actions_3d)  # default padding = 0.0 in both channels
    flat_labels = torch.from_numpy(np.ascontiguousarray(labels, dtype=np.float32))
    offset = 0
    for i, L in enumerate(step_counts):
        if L == 0:
            continue
        lab = flat_labels[offset:offset + L]
        out[i, :L, 0] = lab
        out[i, :L, 1] = lab
        offset += L
    return out

def summarize_labels(labels: np.ndarray, K: int, name: str):
    """Print label-usage statistics for one split.

    Reports the number of labels, the Shannon entropy of the label distribution
    (bits), how many of the K codebook entries are used, and the 10 most frequent
    labels with their counts.
    """
    from collections import Counter
    cnt = Counter(np.asarray(labels).ravel().tolist())
    total = labels.size
    # sum(p * log2(1/p)) is never negative, so a single used label prints 0.00
    # rather than -0.00.
    entropy = sum((n / total) * math.log2(total / n) for n in cnt.values()) if total > 0 else 0.0
    top = cnt.most_common(10)
    print(f"[{name}] K={K} | N={total:,} | Entropy={entropy:.2f} bits | Unique labels={len(cnt)}/{K}")
    print(" Top-10 label counts:", top)


### Optional: Episode filtering analysis (does not affect training unless you use the masks)

In [None]:

# === Optional: Filter episodes using EPISODES_LENGTHS_PATH and report stats ===
import pandas as pd

# Config for this analysis cell only (does not modify training data)
USE_PERCENTILE = True
PERCENTILE     = 99.0
ABS_THRESHOLD  = 60.0

# Pick which split to analyze (train by default)
SPLIT_DIR = TRAIN_DIR

# Allow overriding with EPISODES_LENGTHS_PATH, otherwise use split lengths
if 'EPISODES_LENGTHS_PATH' in globals():
    # NOTE: pickle can execute code on load; only point this at trusted files.
    with open(EPISODES_LENGTHS_PATH, "rb") as f:
        ep_lengths = np.asarray(pickle.load(f), dtype=np.int64)
    actions_3d = torch.load(SPLIT_DIR / ACTIONS_FNAME, map_location="cpu")
    # The override bypasses load_split's checks; a length list that does not match
    # the episode count would misalign every per-episode statistic below.
    if ep_lengths.ndim != 1 or ep_lengths.shape[0] != actions_3d.shape[0]:
        raise ValueError(
            f"EPISODES_LENGTHS_PATH lengths have shape {ep_lengths.shape}, but "
            f"{SPLIT_DIR / ACTIONS_FNAME} has {actions_3d.shape[0]} episodes"
        )
else:
    actions_3d, ep_lengths = load_split(SPLIT_DIR)

E, T_max, _ = actions_3d.shape
flat = flatten_valid(actions_3d, ep_lengths)
mag = np.linalg.norm(flat, axis=1)  # per-step action magnitude
N_actions = flat.shape[0]
if N_actions == 0:
    raise ValueError(f"No valid actions found in {SPLIT_DIR}; nothing to analyze.")

if USE_PERCENTILE:
    threshold = float(np.percentile(mag, PERCENTILE))
    th_desc = f">{PERCENTILE:.2f}p ({threshold:.2f}px)"
else:
    threshold = float(ABS_THRESHOLD)
    th_desc = f">{threshold:.2f}px"

# Rows each episode contributes to `flat`: flatten_valid takes nothing for
# non-positive lengths and at most T_max steps, so clip the same way.
valid_lengths = np.clip(ep_lengths, 0, T_max)
ends   = np.cumsum(valid_lengths)
starts = ends - valid_lengths
assert ends[-1] == N_actions

# Largest action magnitude per episode (0.0 for empty episodes, which are always kept).
ep_max = np.zeros(E, dtype=np.float32)
for i, (s, e) in enumerate(zip(starts, ends)):
    if e > s:
        ep_max[i] = mag[s:e].max()

keep_episode_mask = ep_max <= threshold
kept_E = int(keep_episode_mask.sum())

# Action-level mask (same row order as `flat`): every step of a kept episode.
keep_action_mask = np.repeat(keep_episode_mask, valid_lengths)
kept_N = int(keep_action_mask.sum())

outlier_action_mask = mag > threshold
outlier_N = int(outlier_action_mask.sum())

summary = {
    "Split": SPLIT_DIR.name,
    "Threshold": th_desc,
    "Episodes (total)": f"{E:,}",
    "Episodes kept": f"{kept_E:,} ({kept_E/E*100:.2f}%)",
    "Actions (total)": f"{N_actions:,}",
    "Actions kept": f"{kept_N:,} ({kept_N/N_actions*100:.2f}%)",
    "Outlier actions": f"{outlier_N:,} ({outlier_N/N_actions*100:.2f}%)",
}
display(pd.DataFrame([summary]))


### Fit codebooks on **train** and save centroids

In [None]:

# === Fit codebooks on TRAIN valid steps and save centroids ===
train_actions, train_lengths = load_split(TRAIN_DIR)
flat_train = flatten_valid(train_actions, train_lengths)
print(f"[Train] Valid actions: {flat_train.shape[0]:,}")

# One fixed reconstruction-error sample (up to 100k rows) shared by every K, so the
# printed errors are comparable across codebook sizes and reproducible regardless
# of how the global RNG was used earlier.
n_sample = min(100_000, len(flat_train))
if n_sample > 0:
    qe_rng = np.random.default_rng(SEED)
    qe_sample = flat_train[qe_rng.choice(len(flat_train), size=n_sample, replace=False)]

CODEBOOK_PATHS = {}

for K in K_LIST:
    print(f"\n-- Fitting MiniBatchKMeans K={K} --")
    centers = fit_codebook_kmeans(flat_train, K, seed=SEED)
    codebook_path = TRAIN_DIR / CODEBOOK_TEMPLATE.format(K=K)
    np.save(codebook_path, centers)
    CODEBOOK_PATHS[K] = codebook_path
    print(f" Saved codebook to: {CODEBOOK_PATHS[K]}")
    # quick reconstruction error on the shared sample for visibility
    if n_sample > 0:
        labels_sample = predict_labels_chunked(qe_sample, centers, chunk_size=200_000)
        err = np.linalg.norm(qe_sample - centers[labels_sample], axis=1)
        print(f" Sample mean error: {err.mean():.2f}px | p95: {np.percentile(err,95):.2f}px")
    gc.collect()


### Discretize **train** and **val** using the fitted codebooks

In [None]:

# === Discretize train and val with fitted codebooks; save rel_actions_discrete_{K}.pth ===
def discretize_split_floatlabels(split_dir: Path, K: int, centers: np.ndarray):
    """Label every valid step of one split with its nearest codebook entry and save it.

    The saved tensor (``split_dir / DISCRETE_TEMPLATE.format(K=K)``) matches the
    shape of the split's actions: both channels carry the float label and padded
    steps are 0.0. Label statistics are printed along the way.

    Returns:
        Path of the written file.
    """
    split_name = split_dir.name
    actions, lengths = load_split(split_dir)
    valid_xy = flatten_valid(actions, lengths)
    print(f"[{split_name}] Discretizing with K={K} | valid steps={valid_xy.shape[0]:,}")

    if valid_xy.size == 0:
        labels = np.empty((0,), dtype=np.int64)
    else:
        labels = predict_labels_chunked(valid_xy, centers)
    summarize_labels(labels, K, name=f"{split_name}")

    discrete = labels_to_tensor_like_floatonly(actions, lengths, labels)
    out_path = split_dir / DISCRETE_TEMPLATE.format(K=K)
    torch.save(discrete, out_path)
    print(f" Saved: {out_path} | shape={tuple(discrete.shape)} | dtype={discrete.dtype}")
    return out_path

SAVED_DISCRETE = {}
for K in K_LIST:
    # Prefer the path recorded when the codebooks were fitted; fall back to the
    # conventional file in TRAIN_DIR so this cell also runs after a kernel restart.
    codebook_path = globals().get("CODEBOOK_PATHS", {}).get(K, TRAIN_DIR / CODEBOOK_TEMPLATE.format(K=K))
    centers = np.load(codebook_path)
    # Guard against a stale/corrupt codebook file: it must hold exactly K 2-D centers.
    if centers.shape != (K, 2):
        raise ValueError(f"Codebook {codebook_path} has shape {centers.shape}, expected ({K}, 2)")
    print(f"\n== K={K} ==")
    SAVED_DISCRETE[(K, "train")] = discretize_split_floatlabels(TRAIN_DIR, K, centers)
    SAVED_DISCRETE[(K, "val")]   = discretize_split_floatlabels(VAL_DIR,   K, centers)
    gc.collect()


### Post-checks: shapes, label ranges, integer-like labels, quick QE

In [None]:

def quick_check_floatlabels(split_dir: Path, K: int, centers: np.ndarray, name: str, seed: int = 0):
    """Sanity-check a saved ``rel_actions_discrete_{K}.pth`` file for one split.

    Printed checks (only the shape check raises):
      * discrete tensor has the same (E, T_max, 2) shape as the continuous actions;
      * valid-step labels lie in [0, K-1] and are integer-valued;
      * both channels carry the same label and every padded step is 0.0;
      * quantization error (QE) of ``centers`` on a sample of continuous actions,
        plus how often the stored labels agree with a fresh nearest-center assignment.

    Args:
        split_dir: Split directory with actions, lengths and the discrete file.
        K: Codebook size (selects the discrete file name).
        centers: (K, 2) codebook used for the QE / agreement check.
        name: Tag used in printed output.
        seed: Seed for the QE sample so repeated runs check the same rows.
    """
    disc_t = torch.load(split_dir / DISCRETE_TEMPLATE.format(K=K), map_location="cpu")  # (E,T,2)
    # load_split applies the same shape/length validation as the discretization step.
    acts_t, lengths = load_split(split_dir)

    E, T_max, _ = acts_t.shape
    assert disc_t.shape == acts_t.shape == (E, T_max, 2)

    # (E, T_max) valid-step mask; boolean indexing yields rows in episode-major
    # order, the same order flatten_valid uses.
    valid = torch.from_numpy(np.arange(T_max)[None, :] < lengths[:, None])
    disc_valid = disc_t[valid]  # (N_valid, 2)
    labels = disc_valid[:, 0].numpy().astype(np.float32)
    channels_equal = bool(torch.equal(disc_valid[:, 0], disc_valid[:, 1]))
    padding_zero = bool((disc_t[~valid] == 0).all())

    if labels.size == 0:
        print(f"[{name}] K={K} | No valid labels found.")
        return

    vmin, vmax = float(labels.min()), float(labels.max())
    in_range = (vmin >= 0.0) and (vmax <= (K - 1 + 1e-6))
    labels_rounded = np.rint(labels).astype(np.int64)
    integer_ok = np.allclose(labels, labels_rounded, atol=1e-6)
    summarize_labels(labels_rounded, K, name=f"{name}")

    print(f"[{name}] K={K} | shape OK: {disc_t.shape} | values in [0,{K-1}]: {in_range} "
          f"| integer-like: {integer_ok} | N_valid={labels.size:,} | min={vmin:.1f} max={vmax:.1f}")
    print(f"[{name}] K={K} | channels equal: {channels_equal} | padding all 0.0: {padding_zero}")

    # Quick QE on a sample of original continuous actions; the same rows of the
    # stored labels are compared with a fresh assignment to confirm the saved file
    # matches this codebook.
    flat = flatten_valid(acts_t, lengths)
    assert len(flat) == labels.size, "valid-step count differs between actions and labels"
    n_sample = min(100_000, len(flat))
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(flat), size=n_sample, replace=False)
    sample = flat[idx]
    pred = predict_labels_chunked(sample, centers, chunk_size=200_000)
    err = np.linalg.norm(sample - centers[pred], axis=1)
    agreement = float(np.mean(pred == labels_rounded[idx]))
    print(f" QE(sample): mean={err.mean():.2f}px | p95={np.percentile(err,95):.2f}px "
          f"| label agreement={agreement:.2%}")

for K in K_LIST:
    # Prefer the path recorded when the codebooks were fitted; fall back to the
    # conventional file in TRAIN_DIR so the checks also run after a kernel restart.
    codebook_path = globals().get("CODEBOOK_PATHS", {}).get(K, TRAIN_DIR / CODEBOOK_TEMPLATE.format(K=K))
    centers = np.load(codebook_path)
    if centers.shape != (K, 2):
        raise ValueError(f"Codebook {codebook_path} has shape {centers.shape}, expected ({K}, 2)")
    print(f"\n== Post-checks for K={K} with float-only labels ==")
    quick_check_floatlabels(TRAIN_DIR, K, centers, name="train")
    quick_check_floatlabels(VAL_DIR,   K, centers, name="val")
