In [1]:
#!/usr/bin/env python3
from datasets import load_from_disk, Dataset

import stable_worldmodel as swm
from stable_worldmodel.data import StepsDataset, dataset_info
from stable_worldmodel.policy import AutoCostModel
from stable_worldmodel.wm.dinowm import DINOWM

from PIL import Image
import stable_pretraining as spt

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

import math, time, contextlib

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# params
NUM_WORKERS = 6
NUM_STEPS = 2 # T
BATCH_SIZE = 256 # B
FRAMESKIP = 5 # S
EPOCHS = 25
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
USE_ACTIONS = True

# file paths
TRAIN_DIR = "pusht_expert_dataset_train"
VAL_DIR = "pusht_expert_dataset_val"
CHECKPOINT_NAME = "dinowm_pusht_object.ckpt"

In [3]:
# Simple MLP
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim)
        )
    
    def forward(self, x):
        return self.layers(x)

In [None]:
# Gaussian MLP - outputs distribution over actions
class GaussianMLP(nn.Module):
    def __init__(self, in_dim, out_dim, dropout_prob=0.1, hidden_dim=512, 
                 feature_dim=256, head_hidden_dim=128):
        super().__init__()
        self.out_dim = out_dim
        
        # Shared feature extractor backbone
        self.feature_extractor = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, feature_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
        )
        
        # Separate MLP head for mean
        self.mean_head = nn.Sequential(
            nn.Linear(feature_dim, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(head_hidden_dim, out_dim)
        )
        
        # Separate MLP head for log_std
        self.log_std_head = nn.Sequential(
            nn.Linear(feature_dim, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(head_hidden_dim, out_dim)
        )
        
        # Initialize log_std head output to reasonable values
        nn.init.constant_(self.log_std_head[-1].bias, -0.5)
    
    def forward(self, x):
        features = self.feature_extractor(x)
        mean = self.mean_head(features)
        log_std = self.log_std_head(features)
        log_std = torch.clamp(log_std, min=-10, max=2)
        return mean, log_std
    
    def log_prob(self, mean, log_std, action):
        std = torch.exp(log_std)
        var = std ** 2
        log_prob = -0.5 * (
            ((action - mean) ** 2) / var +
            2 * log_std +
            math.log(2 * math.pi)
        )
        return log_prob.sum(dim=-1)
    
    def sample(self, mean, log_std):
        std = torch.exp(log_std)
        eps = torch.randn_like(mean)
        return mean + eps * std


In [None]:
# Gaussian MLP - outputs distribution over actions
class GaussianMLP(nn.Module):
    def __init__(self, in_dim, out_dim, dropout_prob=0.1, hidden_dim=512, feature_dim=256, head_hidden_dim=128):
        super().__init__()
        self.out_dim = out_dim
        
        # Shared feature extractor backbone
        self.feature_extractor = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, feature_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
        )
        
        # Separate MLP head for mean
        self.mean_head = nn.Sequential(
            nn.Linear(feature_dim, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(head_hidden_dim, out_dim)
        )
        
        # Separate MLP head for log_std
        self.log_std_head = nn.Sequential(
            nn.Linear(feature_dim, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(head_hidden_dim, out_dim)
        )
        
        # Initialize log_std head output to reasonable values (e.g., log(0.5) ≈ -0.69)
        nn.init.constant_(self.log_std_head[-1].bias, -0.5)
    
    def forward(self, x):
        """
        Returns:
            mean: predicted mean of action distribution [B, out_dim]
            log_std: predicted log standard deviation [B, out_dim]
        """
        # Extract shared features
        features = self.feature_extractor(x)
        
        # Pass through separate MLP heads
        mean = self.mean_head(features)
        log_std = self.log_std_head(features)
        
        # Clamp log_std for numerical stability
        log_std = torch.clamp(log_std, min=-10, max=2)
        
        return mean, log_std
    
    def log_prob(self, mean, log_std, action):
        """
        Compute log probability of action under Gaussian distribution.
        
        Args:
            mean: predicted mean [B, out_dim]
            log_std: predicted log std [B, out_dim]
            action: ground truth action [B, out_dim]
            
        Returns:
            log_prob: log probability [B]
        """
        std = torch.exp(log_std)
        var = std ** 2
        
        # Gaussian log probability: -0.5 * [(x-μ)²/σ² + log(2πσ²)]
        log_prob = -0.5 * (
            ((action - mean) ** 2) / var +
            2 * log_std +
            math.log(2 * math.pi)
        )
        
        # Sum over action dimensions
        return log_prob.sum(dim=-1)
    
    def sample(self, mean, log_std):
        """
        Sample action from the predicted distribution.
        
        Args:
            mean: predicted mean [B, out_dim]
            log_std: predicted log std [B, out_dim]
            
        Returns:
            action: sampled action [B, out_dim]
        """
        std = torch.exp(log_std)
        eps = torch.randn_like(mean)
        return mean + eps * std


In [4]:
# Transform function: normalize + reshape
def make_transform(keys):
    transforms = []
    for key in keys:
        transforms.append(
            spt.data.transforms.Compose(
                spt.data.transforms.ToImage(
                    mean=[0.5, 0.5, 0.5],
                    std=[0.5, 0.5, 0.5],
                    source=key,
                    target=key,
                ),
                spt.data.transforms.Resize(224, source=key, target=key),
                spt.data.transforms.CenterCrop(224, source=key, target=key),
            )
        )
    return spt.data.transforms.Compose(*transforms)

In [5]:
# Preprocessing function
def load_dataset(dir, num_steps=1, frameskip=FRAMESKIP):
    transform = make_transform([f'pixels.{i}' for i in range(num_steps)])

    dataset = StepsDataset(dir, num_steps=num_steps, transform=transform, frameskip=frameskip) # frameskip?

    # hacky
    dataset.data_dir = dataset.data_dir.parent

    # add goal column if not there
    goals = cache_goals(dataset)

    return dataset, goals

def cache_goals(steps: StepsDataset):
    data = steps.dataset.with_format("python")
    goals = {} # {episode -> {goal pixel, goal_proprio}}
    transform = make_transform(['goal_pixels'])

    for ep, indices in steps.episode_slices.items():
        goal_idx = indices[-1]
        goal_px_path = steps.data_dir / data["pixels"][goal_idx]
        # transform the goal here
        with Image.open(goal_px_path) as img:
            pixels = {'goal_pixels': img.convert('RGB')}
            transform(pixels)
            goal_pixels = pixels['goal_pixels']
        goals[ep] = {"goal_pixels": goal_pixels,
                     "goal_proprio": torch.as_tensor(data["proprio"][goal_idx]),
                     }

    steps.dataset = data.with_format("torch")
    return goals

def attach_goals(batch, goals):
    goal_pixels = []
    goal_proprios = []

    for ep in batch["episode_idx"].tolist():
        goal = goals[ep[0]]
        goal_pixels.append(goal["goal_pixels"])
        goal_proprios.append(goal["goal_proprio"])

    batch["goal_pixels"] = torch.stack(goal_pixels).unsqueeze(1)
    batch["goal_proprio"] = torch.stack(goal_proprios).unsqueeze(1)
    return batch

In [6]:
def get_loaders(train_data, val_data, batch_size, device, num_workers):
    # optionally pin_memory on CUDA, not Mac
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=(device.type=='cuda')
    )

    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=(device.type=='cuda')
    )
    return train_loader, val_loader

In [7]:
# load model checkpoint from cache_dir in inference mode
def load_checkpoint(checkpoint_name, device):
    cache_dir = swm.data.get_cache_dir()
    checkpoint_path = cache_dir / checkpoint_name
    model = torch.load(checkpoint_path, map_location=device, weights_only=False) # weights_only false
    model = model.to(device).eval()

    for param in model.parameters():
        param.requires_grad_(False)
    return model

In [8]:
# Encoder func: call dinowm encoders
@torch.inference_mode()
def encode(batch, dinowm, device, use_actions=False):
    # this is hacky
    pixels = torch.cat((batch["pixels"],
                        batch['goal_pixels']), dim=1).to(device, non_blocking=True)
    proprio = torch.cat((batch["proprio"],
                         batch["goal_proprio"]), dim=1).to(device, non_blocking=True)
    actions = (
        torch.cat(
            (batch["action"], torch.zeros_like(batch['action'][:, :1])),
            dim=1,
        ).to(device, non_blocking=True) if use_actions else None
    ) # pad with a single step of 0s

    data = {
        "pixels": pixels,
        "proprio": proprio,
        "action": actions,
    }

    context = torch.autocast(device_type=device.type, dtype=torch.float16) if device.type in ("cuda", "mps") else contextlib.nullcontext()
    with context:
        out = dinowm.encode(
            data,
            target="embed",
            pixels_key="pixels",
            proprio_key="proprio",
            action_key=("action" if use_actions else None),
        )

    # attach attention pooler here
    pix_out = out["pixels_embed"].mean(dim=2).float()
    prp_out = out["proprio_embed"].float()

    # detach goal pixels + proprio
    z_pix, z_gpix = pix_out[:,:-1], pix_out[:,-1] # B x T x d_pixels (pooled by patch), B x d_pixels
    z_prp, z_gprp = prp_out[:,:-1], prp_out[:,-1] # B x T x d_proprio, B x d_proprio
    
    z_act = None
    if use_actions:
        z_act = out["action_embed"][:,:-2].float() # B x (T - 1) * d_actions_effective := (d_actions * frame_skip)
    
    return z_pix, z_prp, z_act, z_gpix, z_gprp

In [9]:

def to_feature(z_pix, z_prp, z_act, z_gpix, z_gprp):
    parts = [z_pix[:,:-1], z_prp[:,:-1]]
    if z_act is not None:
        parts.append(z_act)

    # history latents (block of S actions + state := (pixel + proprio embeddings))
    z_hist = torch.cat(parts, dim=2) # B x (T-1) x d_embed := [d_pixels + d_proprio (+ d_actions_effective)]
    z_hist = torch.flatten(z_hist, start_dim=1, end_dim=2) # B x ((T-1) * d_embed)
    
    # current + goal latents (just pixel + proprio embeedings)
    z_cur = torch.cat((z_pix[:,-1], z_prp[:,-1], z_gpix, z_gprp), dim=1) # B x 2 * (d_pixels + d_proprio)

    # concat
    z = torch.cat((z_hist, z_cur), dim=1)
    return z


In [10]:
# find device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def sync():
    if device.type=='cuda':
        torch.cuda.synchronize()
    if device.type=='mps':
        torch.mps.synchronize()

print("device:", device)

# load datasets
train_data, train_goals  = load_dataset(TRAIN_DIR, NUM_STEPS)
val_data, val_goals = load_dataset(VAL_DIR, NUM_STEPS)
print(f"Loaded datasets: train size={len(train_data)}, val size={len(val_data)}")

# build loaders
train_loader, val_loader = get_loaders(train_data, val_data, BATCH_SIZE, device, NUM_WORKERS)

# load DINO-WM
dinowm = load_checkpoint(CHECKPOINT_NAME, device)
assert isinstance(dinowm, DINOWM)
print(f"Loaded DINO-WM from checkpoint: '{CHECKPOINT_NAME}'")
    
# calculate dims
d_pixel = dinowm.backbone.config.hidden_size
d_proprio = dinowm.proprio_encoder.emb_dim
d_action = dinowm.action_encoder.emb_dim

LATENT_DIM = (d_pixel + d_proprio) * (NUM_STEPS + 1) + (d_action if USE_ACTIONS else 0) * (NUM_STEPS - 1)
ACTION_DIM = 2 # predict actual action, not latent
print(f'latent_dim={LATENT_DIM}, action_dim={ACTION_DIM}')

device: mps
[32m17:56:22.114[0m | [1mINFO   [0m ([36m74564, stable_pretraining.data.datasets[0m) | [1mLoading dataset with load_from_disk /Users/ashton/.stable_worldmodel/pusht_expert_dataset_train[0m
[32m17:58:05.021[0m | [1mINFO   [0m ([36m74564, stable_pretraining.data.datasets[0m) | [1mLoading dataset with load_from_disk /Users/ashton/.stable_worldmodel/pusht_expert_dataset_val[0m
Loaded datasets: train size=2168571, val size=2325
Loaded DINO-WM from checkpoint: 'dinowm_pusht_object.ckpt'
latent_dim=1192, action_dim=2


In [11]:
print(train_data.column_names)

['episode_idx', 'step_idx', 'action', 'state', 'proprio', 'pixels', 'sample_idx']


In [None]:
# Train Gaussian MLP with negative log likelihood loss
print("\n" + "="*70)
print("Training Gaussian MLP")
print("="*70 + "\n")

gaussian_action_head = GaussianMLP(LATENT_DIM, ACTION_DIM, dropout_prob=0.15).to(device)
optimizer_gaussian = torch.optim.AdamW(
    gaussian_action_head.parameters(), 
    lr=LEARNING_RATE, 
    weight_decay=WEIGHT_DECAY
)

for epoch in range(1, EPOCHS + 1):
    gaussian_action_head.train()
    epoch_nll = 0.0
    num_samples = 0
    
    t0 = time.perf_counter()
    n = 0
    num_batches = len(train_loader)
    
    for i, batch in enumerate(train_loader):
        attach_goals(batch, train_goals)
        z_pix, z_prp, z_act, z_gpix, z_gprp = encode(batch, dinowm, device, USE_ACTIONS)
        z = to_feature(z_pix, z_prp, z_act, z_gpix, z_gprp)
        action = batch['action'][:,-1,:2].to(device)
        
        # Forward pass - get mean and log_std
        mean, log_std = gaussian_action_head(z)
        
        # Compute negative log likelihood loss
        log_prob = gaussian_action_head.log_prob(mean, log_std, action)
        nll_loss = -log_prob.mean()
        
        # Backward pass
        optimizer_gaussian.zero_grad()
        nll_loss.backward()
        optimizer_gaussian.step()
        
        # Track stats
        batch_size = action.shape[0]
        epoch_nll += nll_loss.item() * batch_size
        num_samples += batch_size
        n += batch_size
        
        if i % 100 == 0:
            sync()
            elapsed = time.perf_counter() - t0
            sps = i / max(1e-9, elapsed)
            eta = (num_batches - i) / max(1e-9, sps)
            print(f"Epoch {epoch}: step {i}/{num_batches}, NLL={nll_loss.item():.4f}, ETA={eta/60:.1f}min")
    
    # Validation
    gaussian_action_head.eval()
    val_nll = 0.0
    val_mse = 0.0
    val_samples = 0
    
    with torch.no_grad():
        for batch in val_loader:
            attach_goals(batch, val_goals)
            z_pix, z_prp, z_act, z_gpix, z_gprp = encode(batch, dinowm, device, USE_ACTIONS)
            z = to_feature(z_pix, z_prp, z_act, z_gpix, z_gprp)
            action = batch['action'][:,-1,:2].to(device)
            
            mean, log_std = gaussian_action_head(z)
            log_prob = gaussian_action_head.log_prob(mean, log_std, action)
            nll = -log_prob.mean()
            mse = F.mse_loss(mean, action, reduction='sum')
            
            batch_size = action.shape[0]
            val_nll += nll.item() * batch_size
            val_mse += mse.item()
            val_samples += batch_size
    
    avg_train_nll = epoch_nll / num_samples
    avg_val_nll = val_nll / val_samples
    val_rmse = math.sqrt(val_mse / val_samples)
    
    print(f'\nEpoch {epoch} Summary:')
    print(f'  Train NLL: {avg_train_nll:.6f}')
    print(f'  Val NLL:   {avg_val_nll:.6f}')
    print(f'  Val RMSE (mean only): {val_rmse:.6f}\n')


In [None]:
# Save trained models
print("\n" + "="*70)
print("Saving Models")
print("="*70)

torch.save(action_head.state_dict(), "mlp_action_head.pt")
print("✓ Saved MLP model to: mlp_action_head.pt")

torch.save(gaussian_action_head.state_dict(), "gaussian_mlp_action_head.pt")
print("✓ Saved Gaussian MLP model to: gaussian_mlp_action_head.pt")

print("\nModels saved! Ready to test with test_mlp_policies.py")


In [None]:
# Example: Using the Gaussian MLP for inference
# This shows how to get both deterministic (mean) and stochastic (sampled) actions

gaussian_action_head.eval()

with torch.no_grad():
    # Get a single batch for demonstration
    batch = next(iter(val_loader))
    attach_goals(batch, val_goals)
    
    # Encode
    z_pix, z_prp, z_act, z_gpix, z_gprp = encode(batch, dinowm, device, USE_ACTIONS)
    z = to_feature(z_pix, z_prp, z_act, z_gpix, z_gprp)
    action_gt = batch['action'][:,-1,:2].to(device)
    
    # Get distribution parameters
    mean, log_std = gaussian_action_head(z)
    std = torch.exp(log_std)
    
    # Sample multiple actions from the distribution
    num_samples = 5
    sampled_actions = [gaussian_action_head.sample(mean, log_std) for _ in range(num_samples)]
    
    # Print results for first example in batch
    print("Example Predictions for First State in Batch:")
    print(f"Ground Truth Action: {action_gt[0].cpu().numpy()}")
    print(f"Predicted Mean:      {mean[0].cpu().numpy()}")
    print(f"Predicted Std:       {std[0].cpu().numpy()}")
    print(f"\nSampled Actions:")
    for i, sampled in enumerate(sampled_actions):
        print(f"  Sample {i+1}: {sampled[0].cpu().numpy()}")
    
    # Compute log probability of ground truth under learned distribution
    log_prob = gaussian_action_head.log_prob(mean, log_std, action_gt)
    print(f"\nLog Probability of Ground Truth: {log_prob[0].item():.4f}")
    
    # Use mean for deterministic prediction (best estimate)
    deterministic_action = mean
    print(f"\nFor deployment, you can use:")
    print(f"  - Deterministic (mean): {deterministic_action[0].cpu().numpy()}")
    print(f"  - Stochastic (sample):  {sampled_actions[0][0].cpu().numpy()}")


In [12]:

# train action head
action_head = MLP(LATENT_DIM, ACTION_DIM).to(device)
optimizer = torch.optim.AdamW(action_head.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

def report_stats():
    sync()
    elapsed = time.perf_counter() - t0
    sps = i / max(1e-9, elapsed)
    bps = n / max(1e-9, elapsed)
    eta = (num_batches - i) / max(1e-9, sps)
    print(
        f"Epoch {epoch}: step {i}/{len(train_loader)} "
        f"Loss = {loss.item():.4f} "
        f"steps / sec = {sps:.1f}, batches / sec = {bps:.1f} "
        f"ETA = {eta / 60.0:.1f} min"
    )

for epoch in range(1, EPOCHS + 1):
    action_head.train()

    t0 = time.perf_counter()
    n = 0
    num_batches = len(train_loader) # could be outside loop
    for i, batch in enumerate(train_loader):

        attach_goals(batch, train_goals)
        z_pix, z_prp, z_act, z_gpix, z_gprp = encode(batch, dinowm, device, USE_ACTIONS)
        z = to_feature(z_pix, z_prp, z_act, z_gpix, z_gprp)
        action = batch['action'][:,-1,:2].to(device) # first action from the last (current) step

        pred = action_head(z)

        loss = F.mse_loss(pred, action)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n += BATCH_SIZE
        if i % 100 == 0:
            report_stats()
    
    # eval
    action_head.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            attach_goals(batch, val_goals)
            z_pix, z_prp, z_act, z_gpix, z_gprp = encode(batch, dinowm, device, USE_ACTIONS)
            z = to_feature(z_pix, z_prp, z_act, z_gpix, z_gprp)
            action = batch['action'][:,-1,:2].to(device)

            pred = action_head(z)
            loss = F.mse_loss(pred, action)
            val_loss += loss.item() * BATCH_SIZE

    val_rmse = math.sqrt(val_loss / len(val_data))
    print(f'epoch {epoch}: RMSE: {val_rmse:.6f}')

Epoch 1: step 0/8471 Loss = 9.8604 steps / sec = 0.0, batches / sec = 8.5 ETA = 141183333333.3 min


KeyboardInterrupt: 