# PushT discretization verification


This notebook inspects the discretized PushT datasets, confirms that every expected artifact was generated (the
discretized action variants and the trajectory subsets), and compares them against the original continuous-action data.
It is intended to be rerun whenever the discretization pipeline is updated or a new copy of the dataset is produced.


In [None]:
# === Imports & configuration ===
import os, gc, math, pickle
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

# Verification never trains anything, so autograd is switched off for the whole notebook.
torch.set_grad_enabled(False)

# Splits expected under every dataset root.
SPLITS = ("train", "val")

# Per-split artifact file names.
ABS_ACTIONS_FNAME = "abs_actions.pth"
VELOCITIES_FNAME = "velocities.pth"
STATES_FNAME = "states.pth"
STATE_CONSTANT_FNAME = "states_constant.pth"
LENGTHS_FNAME = "seq_lengths.pkl"
TOKENS_FNAME = "tokens.pth"
ACTIONS_FNAME = "rel_actions.pth"

# Discretization variants; each one adds rel_actions_discretized_<variant>.pth to a split.
VARIANTS = ("kmeans_k9", "kmeans_k15", "fixed_compass_r24")

# Subset name -> requested total number of trajectories (allocated across SPLITS).
SUBSET_SPECS = dict([("1k", 1_000), ("10k", 10_000)])

def _resolve_dataset_dir(
    path_str: str,
    default_leaf: str = "pusht_noise",
    splits: Tuple[str, ...] = ("train", "val"),
) -> Path:
    """Locate a dataset root that holds one sub-directory per split.

    ``path_str`` may point either at the dataset root itself or at a parent
    directory that contains it under ``default_leaf``; the two candidates are
    tried in that order.

    Args:
        path_str: User-supplied path; ``~`` is expanded.
        default_leaf: Sub-directory name tried when ``path_str`` is not a root.
        splits: Split names that must all exist as directories under the root.

    Returns:
        The first candidate in which every split directory exists.

    Raises:
        FileNotFoundError: If neither candidate contains all split directories.
    """
    base = Path(path_str).expanduser()
    candidates = (base, base / default_leaf)
    for candidate in candidates:
        # is_dir (not exists) so a stray regular file named e.g. "train" is rejected here
        # rather than failing later when the split is listed.
        if all((candidate / split).is_dir() for split in splits):
            return candidate
    tried = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(
        f"Could not resolve dataset directory from {path_str!r} (tried: {tried}). "
        "Provide SRC_DATASET_DIR or DATASET_DIR."
    )

# Source root: SRC_DATASET_DIR takes precedence over DATASET_DIR; "data" is the fallback.
src_root_env = os.environ.get("SRC_DATASET_DIR", os.environ.get("DATASET_DIR", "data"))
SRC_DATASET_DIR = _resolve_dataset_dir(src_root_env)
# Discretized root: DST_DATASET_DIR, else a "pusht_noise_discretized" sibling of the source root.
dst_root_env = os.environ.get("DST_DATASET_DIR", str(SRC_DATASET_DIR.parent / "pusht_noise_discretized"))
DST_DATASET_DIR = Path(dst_root_env).expanduser()

# Subset roots live next to the discretized root and are named "<discretized name>_<subset>".
subset_roots = dict(
    (tag, DST_DATASET_DIR.parent.joinpath(f"{DST_DATASET_DIR.name}_{tag}"))
    for tag in SUBSET_SPECS
)

print(f"Source dataset       : {SRC_DATASET_DIR}")
print(f"Discretized dataset  : {DST_DATASET_DIR}")
print(f"Subset roots         : {subset_roots}")
print(f"Variants             : {VARIANTS}")


In [None]:
# === Helper utilities ===
def load_lengths(split_dir: Path) -> np.ndarray:
    """Read a split's pickled sequence lengths and return them as an int64 array."""
    lengths_path = split_dir / LENGTHS_FNAME
    # NOTE: pickle can execute arbitrary code on load; only point this at trusted dataset dirs.
    with lengths_path.open("rb") as handle:
        raw_lengths = pickle.load(handle)
    return np.asarray(raw_lengths, dtype=np.int64)

def list_split_files(split_dir: Path) -> List[str]:
    """Return the alphabetically sorted names of regular files directly inside *split_dir*."""
    names = [entry.name for entry in split_dir.iterdir() if entry.is_file()]
    names.sort()
    return names

def describe_tensor(tensor: torch.Tensor) -> Dict[str, object]:
    """Summarise a tensor's shape, dtype and value range.

    Returns a dict with keys ``shape`` (tuple of ints), ``dtype`` (string form
    of ``tensor.dtype``), and ``min`` / ``max`` (Python floats). For a
    zero-element tensor, ``min`` and ``max`` are ``nan``, because
    ``torch.min`` / ``torch.max`` raise on empty inputs.
    """
    if tensor.numel() == 0:
        lo = hi = float("nan")
    else:
        lo = float(tensor.min().item())
        hi = float(tensor.max().item())
    return {"shape": tuple(int(x) for x in tensor.shape),
            "dtype": str(tensor.dtype),
            "min": lo,
            "max": hi}

def gather_episode_meta(dataset_dir: Path) -> Dict[str, Dict[str, object]]:
    """Summarise the episode lengths of every split under *dataset_dir*.

    For each split, the returned dict holds ``num_episodes``, ``num_steps``
    (sum of lengths), ``min_len``, ``max_len`` and the raw ``lengths`` array.
    An empty split reports zeros for the step/length statistics.
    """
    def _summarise(lengths: np.ndarray) -> Dict[str, object]:
        has_episodes = lengths.size > 0
        return {
            "num_episodes": int(len(lengths)),
            "num_steps": int(lengths.sum()) if has_episodes else 0,
            "min_len": int(lengths.min()) if has_episodes else 0,
            "max_len": int(lengths.max()) if has_episodes else 0,
            "lengths": lengths,
        }

    return {split: _summarise(load_lengths(dataset_dir / split)) for split in SPLITS}

def allocate_subset_counts(total_target: int, split_meta: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    total_available = sum(meta["num_episodes"] for meta in split_meta.values())
    if total_target >= total_available:
        return {split: split_meta[split]["num_episodes"] for split in SPLITS}
    counts: Dict[str, int] = {}
    remaining = total_target
    for idx, split in enumerate(SPLITS):
        meta = split_meta[split]
        if idx == len(SPLITS) - 1:
            take = remaining
        else:
            prop = meta["num_episodes"] / total_available
            take = min(meta["num_episodes"], int(round(total_target * prop)))
        counts[split] = max(0, min(meta["num_episodes"], take))
        remaining -= counts[split]
    while remaining > 0:
        for split in SPLITS:
            if counts[split] < split_meta[split]["num_episodes"]:
                counts[split] += 1
                remaining -= 1
                if remaining == 0:
                    break
    return counts

# Per-split files whose episode count (leading dimension or list length) must match
# the split's seq_lengths; checked by verify_episode_aligned_files.
EPISODE_ALIGNED_FILES = [ABS_ACTIONS_FNAME, VELOCITIES_FNAME, ACTIONS_FNAME,
                         STATES_FNAME, STATE_CONSTANT_FNAME, TOKENS_FNAME]
EPISODE_ALIGNED_FILES.extend(f"rel_actions_discretized_{variant}.pth" for variant in VARIANTS)

def _episode_count_from_obj(obj) -> int | None:
    if torch.is_tensor(obj):
        return int(obj.shape[0])
    if isinstance(obj, (list, tuple)):
        return len(obj)
    return None

def verify_episode_aligned_files(split_dir: Path, expected_count: int):
    """Check that every episode-aligned artifact in *split_dir* exists and has *expected_count* episodes.

    Raises:
        FileNotFoundError: If any file from EPISODE_ALIGNED_FILES is missing.
        ValueError: If a file's episode count (as reported by
            _episode_count_from_obj) differs from *expected_count*.
    Objects for which no count can be derived are skipped.
    """
    for path in (split_dir / name for name in EPISODE_ALIGNED_FILES):
        if not path.exists():
            raise FileNotFoundError(f"Missing {path}")
        count = _episode_count_from_obj(torch.load(path, map_location="cpu"))
        if count is not None and count != expected_count:
            raise ValueError(f"{path} has {count} episodes, expected {expected_count}")

def count_observation_files(obs_dir: Path) -> int:
    """Return how many ``episode_*`` entries exist in *obs_dir*, or 0 if the directory is absent."""
    if obs_dir.exists():
        return len(list(obs_dir.glob("episode_*")))
    return 0


def flatten_valid_entries(tensor: torch.Tensor, lengths: np.ndarray) -> torch.Tensor:
    """Concatenate the first ``lengths[i]`` steps of every episode ``i``.

    ``tensor`` is indexed as ``tensor[episode, step, ...]``. Only the first
    ``tensor.shape[0]`` entries of ``lengths`` are used; lengths are clipped to
    ``tensor.shape[1]``, and non-positive lengths contribute nothing.

    Returns:
        A tensor of shape ``(total_valid_steps, *tensor.shape[2:])`` on the same
        device and with the same dtype as ``tensor``, with episodes kept in order.
        This holds even when no step is valid.
    """
    num_episodes, num_steps = tensor.shape[0], tensor.shape[1]
    episode_lengths = torch.as_tensor(
        np.asarray(lengths, dtype=np.int64)[:num_episodes], device=tensor.device
    )
    steps = torch.arange(num_steps, device=tensor.device)
    # mask[e, t] is True for the valid steps of episode e; boolean indexing walks it
    # row-major, which preserves episode order and step order.
    mask = steps.unsqueeze(0) < episode_lengths.unsqueeze(1)
    return tensor[mask]


In [None]:
# === File inventory comparison ===
# For each split, list the files present in the source but not in the discretized
# root ("missing") and the files only the discretized root has ("new").
print("== File inventory comparison between source and discretized roots ==")
for split in SPLITS:
    src_files, dst_files = (set(list_split_files(root / split)) for root in (SRC_DATASET_DIR, DST_DATASET_DIR))
    missing = sorted(src_files - dst_files)
    extras = sorted(dst_files - src_files)
    print(f"[{split}] missing in discretized: {missing or 'None'}")
    print(f"[{split}] new files          : {extras or 'None'}")


In [None]:
# === Detailed comparison: source vs discretized base dataset ===
# For every split:
# - The discretized root must mirror the source: episode counts, sequence lengths,
#   and continuous action files.
# - Each discretized variant is summarised by how far it moves rel_actions.
# Deltas are measured only on steps inside each episode's recorded length, the same
# steps used for the centroid count. Positions past seq_lengths (presumably padding)
# would otherwise dilute mean|Δ| and could skew max|Δ|.
variant_stats: List[Dict[str, object]] = []
source_meta = gather_episode_meta(SRC_DATASET_DIR)
discretized_meta = gather_episode_meta(DST_DATASET_DIR)
for split in SPLITS:
    print(f"\n[{split}] episode counts")
    print(f"source episodes     : {source_meta[split]['num_episodes']}")
    print(f"discretized episodes: {discretized_meta[split]['num_episodes']}")
    if source_meta[split]['num_episodes'] != discretized_meta[split]['num_episodes']:
        raise ValueError(f"Episode counts disagree for {split}")
    src_lengths = source_meta[split]['lengths']
    dst_lengths = discretized_meta[split]['lengths']
    if not np.array_equal(src_lengths, dst_lengths):
        raise ValueError(f"Sequence lengths diverged for {split}")
    split_dir_src = SRC_DATASET_DIR / split
    split_dir_dst = DST_DATASET_DIR / split
    # Compare rel_actions tensors
    src_actions = torch.load(split_dir_src / ACTIONS_FNAME, map_location="cpu")
    dst_actions = torch.load(split_dir_dst / ACTIONS_FNAME, map_location="cpu")
    if src_actions.shape != dst_actions.shape:
        raise ValueError(f"rel_actions shape mismatch for {split}")
    if src_actions.dtype != dst_actions.dtype:
        raise ValueError(f"rel_actions dtype mismatch for {split}")
    actions_equal = bool(torch.equal(src_actions, dst_actions))
    # torch.max rejects zero-element tensors, so an empty split reports 0.
    max_abs_diff = float((dst_actions - src_actions).abs().max().item()) if src_actions.numel() else 0.0
    print(f"rel_actions identical: {actions_equal} (max diff={max_abs_diff:.3e})")
    # Auxiliary continuous files must be carried over unchanged.
    for fname in (ABS_ACTIONS_FNAME, VELOCITIES_FNAME):
        src_tensor = torch.load(split_dir_src / fname, map_location="cpu")
        dst_tensor = torch.load(split_dir_dst / fname, map_location="cpu")
        if not torch.equal(src_tensor, dst_tensor):
            raise ValueError(f"{fname} mismatches for {split}")
        del src_tensor, dst_tensor
    tokens_src = torch.load(split_dir_src / TOKENS_FNAME, map_location="cpu")
    tokens_dst = torch.load(split_dir_dst / TOKENS_FNAME, map_location="cpu")
    if len(tokens_src) != len(tokens_dst):
        raise ValueError(f"tokens length mismatch for {split}")
    # Validate discretized variants
    for variant in VARIANTS:
        var_path = split_dir_dst / f"rel_actions_discretized_{variant}.pth"
        if not var_path.exists():
            raise FileNotFoundError(f"Missing {var_path}")
        var_tensor = torch.load(var_path, map_location="cpu")
        if var_tensor.shape != src_actions.shape:
            raise ValueError(f"Variant {variant} shape mismatch for {split}")
        if var_tensor.dtype != src_actions.dtype:
            raise ValueError(f"Variant {variant} dtype mismatch for {split}")
        valid_delta = flatten_valid_entries(var_tensor - src_actions, dst_lengths)
        if valid_delta.numel() == 0:
            mean_abs = max_abs = 0.0
        else:
            abs_delta = valid_delta.abs()
            mean_abs = float(abs_delta.mean().item())
            max_abs = float(abs_delta.max().item())
            del abs_delta
        valid_centroids = flatten_valid_entries(var_tensor, dst_lengths)
        if valid_centroids.numel() == 0:
            unique_centroids = 0
        else:
            unique_centroids = int(torch.unique(valid_centroids, dim=0).shape[0])
        variant_stats.append({"split": split, "variant": variant, "mean_abs_diff": mean_abs, "max_abs_diff": max_abs, "unique_centroids": unique_centroids})
        print(f"  {variant}: mean|Δ|={mean_abs:.4f}, max|Δ|={max_abs:.4f}, unique centroids={unique_centroids}")
        del var_tensor, valid_delta, valid_centroids
    states = torch.load(split_dir_dst / STATES_FNAME, map_location="cpu")
    states_const = torch.load(split_dir_dst / STATE_CONSTANT_FNAME, map_location="cpu")
    if states_const.shape != states.shape:
        raise ValueError(f"states_constant shape mismatch for {split}")
    zero_max = float(states_const.abs().max().item()) if states_const.numel() else 0.0
    print(f"states_constant zeros check: max|value|={zero_max:.3e}")
    del src_actions, dst_actions, states, states_const, tokens_src, tokens_dst
    gc.collect()


In [None]:
# === Subset dataset verification (1k / 10k trajectories) ===
# Each subset root must hold the per-split episode counts that allocate_subset_counts
# derives from the full discretized dataset. Every episode-aligned file, and the
# obses/ directory when present, must agree on those counts.
base_counts_for_alloc = {split: {"num_episodes": discretized_meta[split]["num_episodes"]} for split in SPLITS}
for subset_name, target_total in SUBSET_SPECS.items():
    subset_root = subset_roots[subset_name]
    print(f"\n== Checking subset {subset_name} ({target_total} requested trajectories) ==")
    if not subset_root.exists():
        print(f"[skip] {subset_root} does not exist")
        continue
    subset_meta = gather_episode_meta(subset_root)
    expected_counts = allocate_subset_counts(target_total, base_counts_for_alloc)
    total_expected = sum(expected_counts.values())
    total_actual = sum(meta["num_episodes"] for meta in subset_meta.values())
    print(f"expected total: {total_expected}, actual total: {total_actual}")
    if total_actual != total_expected:
        raise ValueError(f"Subset {subset_name} has {total_actual} episodes, expected {total_expected}")
    for split in SPLITS:
        expected = expected_counts.get(split, 0)
        actual = subset_meta[split]["num_episodes"]
        print(f"[{subset_name}/{split}] episodes: {actual} (expected {expected}), steps: {subset_meta[split]['num_steps']}")
        if actual != expected:
            raise ValueError(f"Subset {subset_name} split {split} mismatch: {actual} vs {expected}")
        verify_episode_aligned_files(subset_root / split, actual)
        lengths = subset_meta[split]["lengths"]
        if lengths.size:
            print(f"    lengths range: [{lengths.min()}, {lengths.max()}]")
        obs_dir = (subset_root / split) / "obses"
        obs_count = count_observation_files(obs_dir)
        # obses/ is optional. Once the directory exists, though, its episode_* count must
        # equal the episode count; an empty obses/ is a broken copy, not an absent one.
        if obs_dir.exists() and obs_count != actual:
            raise ValueError(f"Observation file count mismatch for {subset_name}/{split}: {obs_count} vs {actual}")


In [None]:
# === Aggregate discretization stats ===
# Tabulate the per-split, per-variant statistics collected into variant_stats by the
# detailed comparison cell. The subset-verification cell sits in between, so the hint
# below names the detailed comparison cell rather than "the previous cell".
if variant_stats:
    header = f"{'Split':<8}{'Variant':<20}{'mean|Δ|':>12}{'max|Δ|':>12}{'Centroids':>12}"
    print(header)
    print('-' * len(header))
    for entry in variant_stats:
        print(f"{entry['split']:<8}{entry['variant']:<20}{entry['mean_abs_diff']:>12.6f}{entry['max_abs_diff']:>12.6f}{entry['unique_centroids']:>12}")
else:
    print("No variant statistics collected. Run the detailed comparison cell first.")
