# PushT discretization verification


This notebook inspects the discretized PushT datasets, confirms that every new artifact was generated, and compares
them against the original continuous-action data. It is intended to be rerunnable whenever the discretization pipeline
is updated or a new dataset copy is produced.


In [18]:
# === Imports & configuration ===
import os, gc, math, pickle
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

torch.set_grad_enabled(False)

SPLITS = ("train", "val")
ACTIONS_FNAME = "rel_actions.pth"
ABS_ACTIONS_FNAME = "abs_actions.pth"
VELOCITIES_FNAME = "velocities.pth"
STATES_FNAME = "states.pth"
STATE_CONSTANT_FNAME = "states_constant.pth"
LENGTHS_FNAME = "seq_lengths.pkl"
TOKENS_FNAME = "tokens.pth"
VARIANTS = ("kmeans_k9", "kmeans_k15", "fixed_compass_r24")
SUBSET_SPECS = {"1k": 1_000, "10k": 10_000}

def _resolve_dataset_dir(path_str: str, default_leaf: str = "pusht_noise") -> Path:
    base = Path(path_str).expanduser()
    if (base / "train").exists() and (base / "val").exists():
        return base
    candidate = base / default_leaf
    if (candidate / "train").exists() and (candidate / "val").exists():
        return candidate
    raise FileNotFoundError(f"Could not resolve dataset directory from {path_str!r}. Provide SRC_DATASET_DIR or DATASET_DIR.")

src_root_env = os.getenv("SRC_DATASET_DIR", os.getenv("DATASET_DIR", "data"))
SRC_DATASET_DIR = _resolve_dataset_dir(src_root_env)
dst_root_env = os.getenv("DST_DATASET_DIR", str(SRC_DATASET_DIR.parent / "pusht_noise_discretized"))
DST_DATASET_DIR = Path(dst_root_env).expanduser()

subset_roots = {name: DST_DATASET_DIR.parent / f"{DST_DATASET_DIR.name}_{name}" for name in SUBSET_SPECS}

print(f"Source dataset       : {SRC_DATASET_DIR}")
print(f"Discretized dataset  : {DST_DATASET_DIR}")
print(f"Subset roots         : {subset_roots}")
print(f"Variants             : {VARIANTS}")


Source dataset       : /Users/julianquast/Documents/Documents - pythonPedro (Mac)/Bachelor Thesis/Datasets/pusht_noise
Discretized dataset  : /Users/julianquast/Documents/Documents - pythonPedro (Mac)/Bachelor Thesis/Datasets/pusht_noise_discretized
Subset roots         : {'1k': PosixPath('/Users/julianquast/Documents/Documents - pythonPedro (Mac)/Bachelor Thesis/Datasets/pusht_noise_discretized_1k'), '10k': PosixPath('/Users/julianquast/Documents/Documents - pythonPedro (Mac)/Bachelor Thesis/Datasets/pusht_noise_discretized_10k')}
Variants             : ('kmeans_k9', 'kmeans_k15', 'fixed_compass_r24')


In [19]:
# === Helper utilities ===
def load_lengths(split_dir: Path) -> np.ndarray:
    with open(split_dir / LENGTHS_FNAME, "rb") as f:
        lengths = pickle.load(f)
    return np.asarray(lengths, dtype=np.int64)

def list_split_files(split_dir: Path) -> List[str]:
    return sorted(p.name for p in split_dir.iterdir() if p.is_file())

def describe_tensor(tensor: torch.Tensor) -> Dict[str, object]:
    return {"shape": tuple(int(x) for x in tensor.shape),
            "dtype": str(tensor.dtype),
            "min": float(tensor.min().item()),
            "max": float(tensor.max().item())}

def gather_episode_meta(dataset_dir: Path) -> Dict[str, Dict[str, object]]:
    meta: Dict[str, Dict[str, object]] = {}
    for split in SPLITS:
        split_dir = dataset_dir / split
        lengths = load_lengths(split_dir)
        if lengths.size == 0:
            min_len = max_len = total_steps = 0
        else:
            min_len = int(lengths.min())
            max_len = int(lengths.max())
            total_steps = int(lengths.sum())
        meta[split] = {
            "num_episodes": int(len(lengths)),
            "num_steps": total_steps,
            "min_len": min_len,
            "max_len": max_len,
            "lengths": lengths,
        }
    return meta

def allocate_subset_counts(total_target: int, split_meta: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    total_available = sum(meta["num_episodes"] for meta in split_meta.values())
    if total_target >= total_available:
        return {split: split_meta[split]["num_episodes"] for split in SPLITS}
    counts: Dict[str, int] = {}
    remaining = total_target
    for idx, split in enumerate(SPLITS):
        meta = split_meta[split]
        if idx == len(SPLITS) - 1:
            take = remaining
        else:
            prop = meta["num_episodes"] / total_available
            take = min(meta["num_episodes"], int(round(total_target * prop)))
        counts[split] = max(0, min(meta["num_episodes"], take))
        remaining -= counts[split]
    while remaining > 0:
        for split in SPLITS:
            if counts[split] < split_meta[split]["num_episodes"]:
                counts[split] += 1
                remaining -= 1
                if remaining == 0:
                    break
    return counts

EPISODE_ALIGNED_FILES = [
    ABS_ACTIONS_FNAME,
    VELOCITIES_FNAME,
    ACTIONS_FNAME,
    STATES_FNAME,
    STATE_CONSTANT_FNAME,
    TOKENS_FNAME,
] + [f"rel_actions_discretized_{variant}.pth" for variant in VARIANTS]

def _episode_count_from_obj(obj) -> int | None:
    if torch.is_tensor(obj):
        return int(obj.shape[0])
    if isinstance(obj, (list, tuple)):
        return len(obj)
    return None

def verify_episode_aligned_files(split_dir: Path, expected_count: int):
    for fname in EPISODE_ALIGNED_FILES:
        path = split_dir / fname
        if not path.exists():
            raise FileNotFoundError(f"Missing {path}")
        obj = torch.load(path, map_location="cpu")
        count = _episode_count_from_obj(obj)
        if count is None:
            continue
        if count != expected_count:
            raise ValueError(f"{path} has {count} episodes, expected {expected_count}")

def count_observation_files(obs_dir: Path) -> int:
    if not obs_dir.exists():
        return 0
    return sum(1 for _ in obs_dir.glob("episode_*"))


def flatten_valid_tensor(tensor: torch.Tensor, lengths: np.ndarray) -> torch.Tensor:
    lengths = np.asarray(lengths, dtype=np.int64)
    if tensor.shape[0] != len(lengths):
        raise ValueError(f"Mismatch between tensor batch {tensor.shape[0]} and lengths {len(lengths)}")
    chunks: List[torch.Tensor] = []
    for epi, L in enumerate(lengths):
        L_int = int(L)
        if L_int > 0:
            chunks.append(tensor[epi, :L_int])
    if not chunks:
        return torch.empty((0, tensor.shape[-1]), dtype=tensor.dtype)
    return torch.cat(chunks, dim=0)


def sample_valid_positions(lengths: np.ndarray, num_samples: int, rng=None) -> List[Tuple[int, int]]:
    lengths = np.asarray(lengths, dtype=np.int64)
    total_steps = int(lengths.sum())
    if num_samples <= 0 or total_steps == 0:
        return []
    if rng is None:
        rng = np.random.default_rng()
    count = min(num_samples, total_steps)
    flat_indices = np.asarray(rng.choice(total_steps, size=count, replace=False), dtype=np.int64)
    flat_indices.sort()
    cumulative = lengths.cumsum()
    samples: List[Tuple[int, int]] = []
    for idx in flat_indices:
        epi = int(np.searchsorted(cumulative, idx, side='right'))
        prev = int(cumulative[epi - 1]) if epi > 0 else 0
        step = int(idx - prev)
        samples.append((epi, step))
    return samples


## Dataset normalization stats
Compute per-dimension mean and standard deviation for a chosen dataset to help with normalization. Update the dataset root or splits as needed before running.


In [None]:
# === Dataset normalization stats ===
NORMALIZATION_DATASET_DIR = SRC_DATASET_DIR  # change to DST_DATASET_DIR or a subset root if desired
NORMALIZATION_SPLITS = ("train",)  # typically normalization is computed on the train split
TARGET_FILES = (STATES_FNAME, VELOCITIES_FNAME, ACTIONS_FNAME, ABS_ACTIONS_FNAME)

def compute_mean_std(split_dir: Path, lengths: np.ndarray, fname: str) -> tuple[torch.Tensor, torch.Tensor]:
    path = split_dir / fname
    tensor = torch.load(path, map_location="cpu")
    flat = flatten_valid_tensor(tensor, lengths)
    if flat.numel() == 0:
        raise ValueError(f"No valid samples found in {path}")
    mean = flat.mean(dim=0)
    std = flat.std(dim=0, unbiased=False)
    return mean, std

print("== Normalization statistics ==")
for split in NORMALIZATION_SPLITS:
    split_dir = NORMALIZATION_DATASET_DIR / split
    lengths = load_lengths(split_dir)
    if lengths.size == 0:
        print(f"[skip] No episodes found for {split_dir}")
        continue
    for fname in TARGET_FILES:
        path = split_dir / fname
        if not path.exists():
            print(f"[skip] Missing {path}")
            continue
        mean, std = compute_mean_std(split_dir, lengths, fname)
        print(f"[{split}] {fname}: mean shape={tuple(mean.shape)}, std shape={tuple(std.shape)}")
        print(f"    mean={mean.tolist()}")
        print(f"    std ={std.tolist()}")


In [None]:
# === File inventory comparison ===
print("== File inventory comparison between source and discretized roots ==")
for split in SPLITS:
    src_files = set(list_split_files(SRC_DATASET_DIR / split))
    dst_files = set(list_split_files(DST_DATASET_DIR / split))
    missing = sorted(src_files - dst_files)
    extras = sorted(dst_files - src_files)
    print(f"[{split}] missing in discretized: {missing if missing else 'None'}")
    print(f"[{split}] new files          : {extras if extras else 'None'}")


In [20]:
# === Detailed comparison: source vs discretized base dataset ===
variant_stats: List[Dict[str, object]] = []
source_meta = gather_episode_meta(SRC_DATASET_DIR)
discretized_meta = gather_episode_meta(DST_DATASET_DIR)
for split in SPLITS:
    print(f"\n[{split}] episode counts")
    print(f"source episodes     : {source_meta[split]['num_episodes']}")
    print(f"discretized episodes: {discretized_meta[split]['num_episodes']}")
    if source_meta[split]['num_episodes'] != discretized_meta[split]['num_episodes']:
        raise ValueError(f"Episode counts disagree for {split}")
    src_lengths = source_meta[split]['lengths']
    dst_lengths = discretized_meta[split]['lengths']
    if not np.array_equal(src_lengths, dst_lengths):
        raise ValueError(f"Sequence lengths diverged for {split}")
    split_dir_src = SRC_DATASET_DIR / split
    split_dir_dst = DST_DATASET_DIR / split
    # Compare rel_actions tensors
    src_actions = torch.load(split_dir_src / ACTIONS_FNAME, map_location="cpu")
    dst_actions = torch.load(split_dir_dst / ACTIONS_FNAME, map_location="cpu")
    if src_actions.shape != dst_actions.shape:
        raise ValueError(f"rel_actions shape mismatch for {split}")
    if src_actions.dtype != dst_actions.dtype:
        raise ValueError(f"rel_actions dtype mismatch for {split}")
    actions_equal = bool(torch.equal(src_actions, dst_actions))
    max_abs_diff = float((dst_actions - src_actions).abs().max().item())
    print(f"rel_actions identical: {actions_equal} (max diff={max_abs_diff:.3e})")
    # Compare auxiliary continuous files
    for fname in (ABS_ACTIONS_FNAME, VELOCITIES_FNAME):
        src_tensor = torch.load(split_dir_src / fname, map_location="cpu")
        dst_tensor = torch.load(split_dir_dst / fname, map_location="cpu")
        if not torch.equal(src_tensor, dst_tensor):
            raise ValueError(f"{fname} mismatches for {split}")
    tokens_src = torch.load(split_dir_src / TOKENS_FNAME, map_location="cpu")
    tokens_dst = torch.load(split_dir_dst / TOKENS_FNAME, map_location="cpu")
    if len(tokens_src) != len(tokens_dst):
        raise ValueError(f"tokens length mismatch for {split}")
    # Validate discretized variants
    for variant in VARIANTS:
        var_path = split_dir_dst / f"rel_actions_discretized_{variant}.pth"
        if not var_path.exists():
            raise FileNotFoundError(f"Missing {var_path}")
        var_tensor = torch.load(var_path, map_location="cpu")
        if var_tensor.shape != src_actions.shape:
            raise ValueError(f"Variant {variant} shape mismatch for {split}")
        if var_tensor.dtype != src_actions.dtype:
            raise ValueError(f"Variant {variant} dtype mismatch for {split}")
        delta = var_tensor - src_actions
        mean_abs = float(delta.abs().mean().item())
        max_abs = float(delta.abs().max().item())
        flat = var_tensor.reshape(-1, var_tensor.shape[-1])
        unique_centroids = int(torch.unique(flat, dim=0).shape[0])
        variant_stats.append({"split": split, "variant": variant, "mean_abs_diff": mean_abs, "max_abs_diff": max_abs, "unique_centroids": unique_centroids})
        print(f"  {variant}: mean|Δ|={mean_abs:.4f}, max|Δ|={max_abs:.4f}, unique centroids={unique_centroids}")
        del var_tensor, delta, flat
    states = torch.load(split_dir_dst / STATES_FNAME, map_location="cpu")
    states_const = torch.load(split_dir_dst / STATE_CONSTANT_FNAME, map_location="cpu")
    if states_const.shape != states.shape:
        raise ValueError(f"states_constant shape mismatch for {split}")
    zero_max = float(states_const.abs().max().item())
    print(f"states_constant zeros check: max|value|={zero_max:.3e}")
    del src_actions, dst_actions, states, states_const, tokens_src, tokens_dst
    gc.collect()



[train] episode counts
source episodes     : 18685
discretized episodes: 18685
rel_actions identical: True (max diff=0.000e+00)
  kmeans_k9: mean|Δ|=3.7636, max|Δ|=153.7511, unique centroids=10
  kmeans_k15: mean|Δ|=3.0071, max|Δ|=144.6930, unique centroids=16
  fixed_compass_r24: mean|Δ|=4.6653, max|Δ|=180.2670, unique centroids=5
states_constant zeros check: max|value|=0.000e+00

[val] episode counts
source episodes     : 21
discretized episodes: 21
rel_actions identical: True (max diff=0.000e+00)
  kmeans_k9: mean|Δ|=3.5948, max|Δ|=71.3540, unique centroids=10
  kmeans_k15: mean|Δ|=2.8930, max|Δ|=66.2523, unique centroids=16
  fixed_compass_r24: mean|Δ|=4.4253, max|Δ|=89.8806, unique centroids=5
states_constant zeros check: max|value|=0.000e+00


In [21]:
# === Subset dataset verification (1k / 10k trajectories) ===
base_counts_for_alloc = {split: {"num_episodes": discretized_meta[split]["num_episodes"]} for split in SPLITS}
for subset_name, target_total in SUBSET_SPECS.items():
    subset_root = subset_roots[subset_name]
    print(f"\n== Checking subset {subset_name} ({target_total} requested trajectories) ==")
    if not subset_root.exists():
        print(f"[skip] {subset_root} does not exist")
        continue
    subset_meta = gather_episode_meta(subset_root)
    expected_counts = allocate_subset_counts(target_total, base_counts_for_alloc)
    total_expected = sum(expected_counts.values())
    total_actual = sum(meta["num_episodes"] for meta in subset_meta.values())
    print(f"expected total: {total_expected}, actual total: {total_actual}")
    if total_actual != total_expected:
        raise ValueError(f"Subset {subset_name} has {total_actual} episodes, expected {total_expected}")
    for split in SPLITS:
        expected = expected_counts.get(split, 0)
        actual = subset_meta[split]["num_episodes"]
        print(f"[{subset_name}/{split}] episodes: {actual} (expected {expected}), steps: {subset_meta[split]['num_steps']}")
        if actual != expected:
            raise ValueError(f"Subset {subset_name} split {split} mismatch: {actual} vs {expected}")
        verify_episode_aligned_files(subset_root / split, actual)
        lengths = subset_meta[split]["lengths"]
        if lengths.size:
            print(f"    lengths range: [{lengths.min()}, {lengths.max()}]")
        obs_count = count_observation_files((subset_root / split) / "obses")
        if obs_count and obs_count != actual:
            raise ValueError(f"Observation file count mismatch for {subset_name}/{split}: {obs_count} vs {actual}")



== Checking subset 1k (1000 requested trajectories) ==
expected total: 1000, actual total: 1000
[1k/train] episodes: 999 (expected 999), steps: 125676
    lengths range: [49, 246]
[1k/val] episodes: 1 (expected 1), steps: 114
    lengths range: [114, 114]

== Checking subset 10k (10000 requested trajectories) ==
expected total: 10000, actual total: 10000
[10k/train] episodes: 9989 (expected 9989), steps: 1243713
    lengths range: [49, 246]
[10k/val] episodes: 11 (expected 11), steps: 1317
    lengths range: [76, 167]


In [22]:
# === Aggregate discretization stats ===
if not variant_stats:
    print("No variant statistics collected. Run the previous cell first.")
else:
    header = f"{'Split':<8}{'Variant':<20}{'mean|Δ|':>12}{'max|Δ|':>12}{'Centroids':>12}"
    print(header)
    print('-' * len(header))
    for entry in variant_stats:
        print(f"{entry['split']:<8}{entry['variant']:<20}{entry['mean_abs_diff']:>12.6f}{entry['max_abs_diff']:>12.6f}{entry['unique_centroids']:>12}")


Split   Variant                  mean|Δ|      max|Δ|   Centroids
----------------------------------------------------------------
train   kmeans_k9               3.763560  153.751147          10
train   kmeans_k15              3.007122  144.692996          16
train   fixed_compass_r24       4.665251  180.266978           5
val     kmeans_k9               3.594826   71.354050          10
val     kmeans_k15              2.893014   66.252327          16
val     fixed_compass_r24       4.425332   89.880615           5


In [None]:
# === Codebook inspection ===
CODEBOOK_DIR = DST_DATASET_DIR / "codebooks"
centroid_cache: Dict[str, torch.Tensor] = {}
print("== Loading discretization centroids ==")
for variant in VARIANTS:
    npy_path = CODEBOOK_DIR / f"{variant}_centroids.npy"
    pt_path = CODEBOOK_DIR / f"{variant}_centroids.pt"
    if npy_path.exists():
        centers = np.load(npy_path)
    elif pt_path.exists():
        centers = torch.load(pt_path, map_location="cpu")
        if torch.is_tensor(centers):
            centers = centers.cpu().numpy()
    else:
        print(f"[warn] Missing centroids for {variant} in {CODEBOOK_DIR}")
        continue
    centers_tensor = torch.as_tensor(centers, dtype=torch.float32)
    centroid_cache[variant] = centers_tensor
    print(f"{variant}: shape={tuple(centers_tensor.shape)}")
    print(centers_tensor)


In [None]:
# === Unique centroid diagnostics (valid vs padded) ===
print("== Unique centroid counts including/excluding padding ==")
for split in SPLITS:
    lengths = load_lengths(DST_DATASET_DIR / split)
    for variant in VARIANTS:
        var_path = DST_DATASET_DIR / split / f"rel_actions_discretized_{variant}.pth"
        if not var_path.exists():
            print(f"[skip] Missing {var_path}")
            continue
        disc_tensor = torch.load(var_path, map_location="cpu")
        flat_all = disc_tensor.reshape(-1, disc_tensor.shape[-1])
        unique_all = torch.unique(flat_all, dim=0)
        flat_valid = flatten_valid_tensor(disc_tensor, lengths)
        unique_valid = torch.unique(flat_valid, dim=0) if flat_valid.numel() else torch.empty((0, disc_tensor.shape[-1]), dtype=disc_tensor.dtype)
        padding_vecs = []
        if unique_valid.shape[0] < unique_all.shape[0]:
            if unique_valid.numel():
                dists = torch.cdist(unique_all, unique_valid)
                unmatched = dists.min(dim=1).values > 1e-6
                padding_vecs = unique_all[unmatched].tolist()
            else:
                padding_vecs = unique_all.tolist()
        padding_str = f", padding vectors={padding_vecs}" if padding_vecs else ""
        print(f"[{split}/{variant}] unique(all)={unique_all.shape[0]}, unique(valid)={unique_valid.shape[0]}{padding_str}")
        del disc_tensor, flat_all, unique_all, flat_valid, unique_valid
    gc.collect()


In [None]:
# === Spot-check discretized samples ===
SAMPLES_PER_SPLIT: Dict[str, List[Tuple[int, int]]] = {}
rng = np.random.default_rng(0)
NUM_SAMPLES = 5
for split in SPLITS:
    lengths = load_lengths(SRC_DATASET_DIR / split)
    SAMPLES_PER_SPLIT[split] = sample_valid_positions(lengths, NUM_SAMPLES, rng=rng)
    print(f"Sampled positions for {split}: {SAMPLES_PER_SPLIT[split]}")

print("\n== Discretization spot-checks ==")
for split in SPLITS:
    src_actions = torch.load(SRC_DATASET_DIR / split / ACTIONS_FNAME, map_location="cpu")
    lengths = load_lengths(SRC_DATASET_DIR / split)
    for variant in VARIANTS:
        var_path = DST_DATASET_DIR / split / f"rel_actions_discretized_{variant}.pth"
        if not var_path.exists():
            print(f"[skip] Missing {var_path}")
            continue
        disc_tensor = torch.load(var_path, map_location="cpu")
        centroids = centroid_cache.get(variant)
        if centroids is None:
            print(f"[skip] No centroids loaded for {variant}")
            continue
        print(f"\n[{split}/{variant}] Checking {len(SAMPLES_PER_SPLIT[split])} samples")
        for epi, step in SAMPLES_PER_SPLIT[split]:
            if step >= lengths[epi]:
                print(f"  [skip] episode {epi} shorter than step {step}")
                continue
            orig = src_actions[epi, step]
            disc = disc_tensor[epi, step]
            dists = torch.norm(centroids - disc, dim=1)
            centroid_idx = int(torch.argmin(dists).item())
            centroid = centroids[centroid_idx]
            disc_centroid_max = float((disc - centroid).abs().max().item())
            orig_centroid_dist = float(torch.norm(orig - centroid).item())
            orig_disc_dist = float(torch.norm(orig - disc).item())
            print(f"  epi={epi:05d}, step={step:03d}, centroid={centroid_idx:02d}, |orig-centroid|={orig_centroid_dist:.4f}, |orig-disc|={orig_disc_dist:.4f}, max|disc-centroid|={disc_centroid_max:.2e}")
            print(f"    original    : {orig.tolist()}")
            print(f"    centroid    : {centroid.tolist()}")
            print(f"    discretized : {disc.tolist()}")
        del disc_tensor
    del src_actions
    gc.collect()
