In [30]:
#!/usr/bin/env python3
from datasets import load_from_disk, Dataset

import stable_worldmodel as swm
from stable_worldmodel.data import StepsDataset, dataset_info
from stable_worldmodel.policy import AutoCostModel
from stable_worldmodel.wm.dinowm import DINOWM

import stable_pretraining as spt

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

import math, time, contextlib

In [None]:
# --- run configuration -------------------------------------------------

# dataset / checkpoint names
CHECKPOINT_NAME = "dinowm_pusht_object.ckpt"
TRAIN_DIR = "pusht_expert_dataset_train"
VAL_DIR = "pusht_expert_dataset_val"

# sampling
NUM_STEPS = 2       # T: steps per sample
FRAMESKIP = 5       # S: frameskip handed to the dataset
USE_ACTIONS = True  # feed action embeddings to the action head as well

# data loading
BATCH_SIZE = 256    # B
NUM_WORKERS = 6

# optimisation
EPOCHS = 25
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4

In [32]:
# Simple MLP
class MLP(nn.Module):
    """Fully connected ReLU network mapping `in_dim` features to `out_dim` outputs.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        hidden_dims: widths of the hidden layers, in order. The default
            (512, 256) reproduces the original fixed architecture, so the
            `layers.0/2/4` state-dict keys are unchanged. An empty sequence
            yields a single linear layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dims=(512, 256)):
        super().__init__()
        widths = [in_dim, *hidden_dims, out_dim]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        modules.pop()  # no activation after the output layer
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to `x` (features on the last dimension)."""
        return self.layers(x)

In [33]:
# Transform function: per-frame image conversion, resize and center crop
def step_transform(num_steps):
    """Build the transform applied to the `pixels.<t>` entries of a sample.

    For every step index t in [0, num_steps) a pipeline of
    ToImage(mean=0.5, std=0.5) -> Resize(224) -> CenterCrop(224) is created
    that reads from and writes back to the key ``pixels.<t>``. The per-step
    pipelines are then chained into a single Compose.
    """
    tfm = spt.data.transforms

    def frame_pipeline(key):
        # each stage reads and overwrites the same key
        return tfm.Compose(
            tfm.ToImage(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
                source=key,
                target=key,
            ),
            tfm.Resize(224, source=key, target=key),
            tfm.CenterCrop(224, source=key, target=key),
        )

    per_step = [frame_pipeline(f"pixels.{idx}") for idx in range(num_steps)]
    return tfm.Compose(*per_step)

In [None]:
def attach_goal(steps: "StepsDataset"):
    """Add per-row goal columns to `steps.dataset`, in place.

    For every row, ``goal`` is the ``pixels`` value and ``goal_proprio`` the
    ``proprio`` value of the last row that shares the row's ``episode_idx``.
    Columns that already exist are left untouched, so calling this twice is a
    no-op. (The original guard, ``"goal" in cols in cols``, was a chained
    comparison that is always False, so a second call crashed on the
    duplicate ``add_column``.)

    When at least one column is added, ``steps.dataset`` is replaced by the
    extended dataset with its format set to "torch".

    NOTE(review): the full ``pixels`` column is materialized to build the
    goals; this may be memory-heavy on large datasets.
    """
    data = steps.dataset
    present = set(data.column_names)
    missing = [
        (goal_col, source_col)
        for goal_col, source_col in (("goal", "pixels"), ("goal_proprio", "proprio"))
        if goal_col not in present
    ]
    if not missing:
        return

    data = data.with_format("python")
    ep_ids = list(data["episode_idx"])

    # index of the last row seen for each episode (later rows overwrite earlier)
    last = {ep_id: idx for idx, ep_id in enumerate(ep_ids)}

    for goal_col, source_col in missing:
        values = data[source_col]
        data = data.add_column(goal_col, [values[last[ep_id]] for ep_id in ep_ids])

    steps.dataset = data.with_format("torch")

In [None]:
# Preprocessing function
def load_dataset(dir, num_steps=1, frameskip=5):
    """Create a StepsDataset for `dir` with frame transforms and goal columns.

    Args:
        dir: dataset name/path handed to StepsDataset.
        num_steps: steps per sample; also the number of `pixels.<t>` keys
            the transform covers.
        frameskip: forwarded to StepsDataset unchanged.

    Returns:
        The StepsDataset, after `attach_goal` has been applied to it.
    """
    steps = StepsDataset(
        dir,
        num_steps=num_steps,
        transform=step_transform(num_steps),
        frameskip=frameskip,
    )

    # hacky: re-point data_dir at its parent directory
    steps.data_dir = steps.data_dir.parent

    # attach goal columns
    attach_goal(steps)
    return steps

In [36]:
# Use the config-cell constants rather than a literal 2 so this cell cannot
# drift from the settings the training loop uses.
train_data = load_dataset(TRAIN_DIR, num_steps=NUM_STEPS, frameskip=FRAMESKIP)
val_data = load_dataset(VAL_DIR, num_steps=NUM_STEPS, frameskip=FRAMESKIP)

[32m22:08:57.706[0m | [1mINFO   [0m ([36m24996, stable_pretraining.data.datasets[0m) | [1mLoading dataset with load_from_disk /Users/ashton/.stable_worldmodel/pusht_expert_dataset_train[0m
[32m22:09:07.486[0m | [1mINFO   [0m ([36m24996, stable_pretraining.data.datasets[0m) | [1mLoading dataset with load_from_disk /Users/ashton/.stable_worldmodel/pusht_expert_dataset_val[0m


In [47]:
# Sanity check: dataset sizes, then the shape of every column of sample 0.
print(len(train_data), len(val_data))

# Fetch each sample once instead of once per column; both lookups are
# loop-invariant. The printed output is unchanged.
sample = train_data[0]
actions = train_data[8]['action'][:-1]  # all but the last step's action block

for col in train_data.column_names:
    print(col)
    print(sample[col].shape)
    print(actions.shape)

2168571 2325
episode_idx
torch.Size([2])
torch.Size([1, 10])
step_idx
torch.Size([2])
torch.Size([1, 10])
action
torch.Size([2, 10])
torch.Size([1, 10])
state
torch.Size([2, 5])
torch.Size([1, 10])
proprio
torch.Size([2, 4])
torch.Size([1, 10])
pixels
torch.Size([2, 3, 224, 224])
torch.Size([1, 10])
sample_idx
torch.Size([2])
torch.Size([1, 10])


In [39]:
def get_loaders(train_data, val_data, batch_size, device, num_workers):
    """Build the training and validation DataLoaders.

    Args:
        train_data: dataset for training; batches are shuffled.
        val_data: dataset for validation; order is preserved.
        batch_size: samples per batch for both loaders.
        device: torch.device batches will be moved to; pinned memory is
            enabled only for CUDA (not on Mac / CPU).
        num_workers: worker processes per loader. 0 loads in the main process.

    Returns:
        (train_loader, val_loader)
    """
    shared = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when there are no workers
        persistent_workers=num_workers > 0,
        pin_memory=(device.type == 'cuda'),
    )
    train_loader = DataLoader(train_data, shuffle=True, **shared)
    val_loader = DataLoader(val_data, shuffle=False, **shared)
    return train_loader, val_loader

In [52]:
# load a pickled model from the cache dir, frozen and in eval mode
def load_checkpoint(checkpoint_name, device):
    """Load `checkpoint_name` from the stable_worldmodel cache directory.

    The object returned by torch.load is moved to `device`, switched to eval
    mode, and has gradients disabled on all of its parameters.
    """
    checkpoint_path = swm.data.get_cache_dir() / checkpoint_name

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    loaded = torch.load(checkpoint_path, map_location=device, weights_only=False)

    frozen = loaded.to(device).eval()
    frozen.requires_grad_(False)  # freezes every parameter of the module
    return frozen

In [41]:
# pad 0s in front of action for action_embedding
def pad_actions(actions, frame_skip=5):
    """Left-pad the last dimension of `actions` with zeros.

    Args:
        actions: tensor of shape (B, T, d_a).
        frame_skip: target multiple; the result has last dimension
            d_a * frame_skip, with the original values in the final d_a slots.

    Returns:
        `actions` itself when frame_skip == 1, otherwise a new
        (B, T, d_a * frame_skip) tensor on the same device and with the same
        dtype as `actions`.
    """
    if frame_skip == 1:
        return actions
    B, T, d_a = actions.shape
    # Allocate the padding like the input: default-device float32 zeros would
    # fail to concatenate with a CUDA/MPS tensor and would upcast half inputs.
    zeros = torch.zeros(
        (B, T, d_a * (frame_skip - 1)),
        dtype=actions.dtype,
        device=actions.device,
    )
    return torch.cat((zeros, actions), dim=2)

In [42]:
def repeat_actions(actions: torch.Tensor, frame_skip=5):
    """Tile `actions` `frame_skip` times along its last dimension.

    Args:
        actions: 3-D tensor (B, T, d_a).
        frame_skip: number of copies to place side by side.

    Returns:
        Tensor of shape (B, T, d_a * frame_skip). The original computed this
        value but only printed shapes and returned None.
    """
    return actions.repeat(1, 1, frame_skip)

In [None]:
# Encoder func: call dinowm encoders
@torch.inference_mode()
def encode(batch, dinowm, device, use_actions=False):
    """Run the DINO-WM encoders on one batch and return float32 embeddings.

    Args:
        batch: mapping with "pixels" and "proprio" tensors, plus "action"
            when `use_actions` is set.
        dinowm: model exposing `encode(data, target=..., pixels_key=...,
            proprio_key=..., action_key=...)`.
        device: torch.device the inputs are moved to. On cuda/mps the encoder
            call runs under float16 autocast; elsewhere no autocast is used.
        use_actions: also encode actions. The last entry along dim 1 of
            batch["action"] is dropped before encoding.

    Returns:
        (z_pixels, z_proprio, z_actions): `z_pixels` is "pixels_embed"
        averaged over dim 2, `z_proprio` is "proprio_embed", and `z_actions`
        is "action_embed" or None when `use_actions` is false.
    """
    on_accelerator = device.type in ("cuda", "mps")
    amp = (
        torch.autocast(device_type=device.type, dtype=torch.float16)
        if on_accelerator
        else contextlib.nullcontext()
    )

    inputs = {
        key: batch[key].to(device, non_blocking=True)
        for key in ("pixels", "proprio")
    }
    action_key = None
    if use_actions:
        action_key = "action"
        # drop the last block of actions along the step dimension
        inputs["action"] = batch['action'][:, :-1].to(device, non_blocking=True)

    with amp:
        embeds = dinowm.encode(
            inputs,
            target="embed",
            pixels_key="pixels",
            proprio_key="proprio",
            action_key=action_key,
        )

    pooled_pixels = embeds["pixels_embed"].mean(dim=2).float()  # pooled over dim 2 (patches)
    proprio_embed = embeds["proprio_embed"].float()
    action_embed = embeds["action_embed"].float() if use_actions else None
    return pooled_pixels, proprio_embed, action_embed

In [None]:
def to_feature(z_pixels, z_proprio, z_actions):
    """Flatten per-step embeddings into one feature vector per sample.

    Args:
        z_pixels: (B, T, d_pixels) pixel embeddings.
        z_proprio: (B, T, d_proprio) proprio embeddings.
        z_actions: (B, T-1, d_actions) action embeddings, or None to omit them.

    Returns:
        (B, (T-1) * d_embed + d_pixels + d_proprio) tensor, where
        d_embed = d_pixels + d_proprio (+ d_actions). The history steps come
        first, step by step, followed by the pixel and proprio embeddings of
        the last step.
    """
    parts = [z_pixels[:, :-1], z_proprio[:, :-1]]
    if z_actions is not None:
        parts.append(z_actions)

    # history latents: state (pixel + proprio) and, optionally, action per step
    z_hist = torch.cat(parts, dim=2)           # B x (T-1) x d_embed
    z_hist = torch.flatten(z_hist, start_dim=1)  # B x ((T-1) * d_embed)

    # current latents (pixel + proprio only). `[:, -1]` is already 2-D, so
    # concatenate on dim 1; the original's dim=2 raised IndexError.
    z_cur = torch.cat((z_pixels[:, -1], z_proprio[:, -1]), dim=1)  # B x (d_pixels + d_proprio)

    return torch.cat((z_hist, z_cur), dim=1)

In [54]:
def _pick_device():
    """Return the preferred torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _sync(device):
    """Wait for queued CUDA/MPS work so wall-clock timings are meaningful."""
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'mps':
        torch.mps.synchronize()


def _report_stats(epoch, step, num_batches, samples, loss, t0, device):
    """Print loss, throughput and ETA after `step` completed batches."""
    _sync(device)
    elapsed = max(1e-9, time.perf_counter() - t0)
    sps = step / elapsed
    bps = samples / elapsed
    eta = (num_batches - step) / max(1e-9, sps)
    print(
        f"Epoch: {epoch} Step: {step}/{num_batches} "
        f"Loss = {loss.item():.4f} "
        f"steps / sec = {sps:.1f}, samples / sec = {bps:.1f} "
        f"ETA = {eta/60:.1f} min"
    )


def run():
    """Train an MLP action head on frozen DINO-WM embeddings and report val RMSE.

    Reads the module-level configuration constants (TRAIN_DIR, VAL_DIR,
    NUM_STEPS, FRAMESKIP, BATCH_SIZE, NUM_WORKERS, CHECKPOINT_NAME,
    USE_ACTIONS, EPOCHS, LEARNING_RATE, WEIGHT_DECAY). Each epoch trains on
    the training loader, then prints the RMSE over every predicted action
    component of the validation set.
    """
    device = _pick_device()
    print("device:", device)

    # load datasets
    train_data = load_dataset(TRAIN_DIR, NUM_STEPS, FRAMESKIP)
    val_data = load_dataset(VAL_DIR, NUM_STEPS, FRAMESKIP)

    # build loaders
    train_loader, val_loader = get_loaders(train_data, val_data, BATCH_SIZE, device, NUM_WORKERS)

    # load DINO-WM
    dinowm = load_checkpoint(CHECKPOINT_NAME, device)
    assert isinstance(dinowm, DINOWM)

    # calculate dims: must match the vector to_feature builds, i.e. (T-1)
    # history steps of [pixels + proprio (+ actions)] followed by the current
    # step's [pixels + proprio].
    d_pixel = dinowm.backbone.config.hidden_size
    d_proprio = dinowm.proprio_encoder.emb_dim
    d_action = dinowm.action_encoder.emb_dim if USE_ACTIONS else 0
    d_state = d_pixel + d_proprio
    LATENT_DIM = (NUM_STEPS - 1) * (d_state + d_action) + d_state
    ACTION_DIM = 2  # predict actual action, not latent
    print(f'latent_dim={LATENT_DIM}, action_dim={ACTION_DIM}')

    # train action head
    action_head = MLP(LATENT_DIM, ACTION_DIM).to(device)
    optimizer = torch.optim.AdamW(action_head.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    def features_and_targets(batch):
        """Encode a batch into head inputs and ground-truth actions."""
        z = to_feature(*encode(batch, dinowm, device, USE_ACTIONS))
        # first action from the last (current) step
        action = batch['action'][:, -1, :ACTION_DIM].to(device)
        return z, action

    num_batches = len(train_loader)

    for epoch in range(1, EPOCHS + 1):
        action_head.train()

        t0 = time.perf_counter()
        n = 0

        for i, batch in enumerate(train_loader, start=1):
            z, action = features_and_targets(batch)

            pred = action_head(z)

            loss = F.mse_loss(pred, action)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # count real samples (the last batch may be short) before
            # reporting, so throughput reflects the work actually done
            n += action.shape[0]
            if i == 1 or i % 50 == 0:
                _report_stats(epoch, i, num_batches, n, loss, t0, device)

        # eval: accumulate the summed squared error and the element count so
        # RMSE is an average over all predicted components, not over batches
        action_head.eval()
        sq_err = 0.0
        count = 0
        with torch.no_grad():
            for batch in val_loader:
                z, action = features_and_targets(batch)
                pred = action_head(z)
                sq_err += F.mse_loss(pred, action, reduction="sum").item()
                count += action.numel()
        val_rmse = math.sqrt(sq_err / max(1, count))
        print(f'epoch {epoch}: RMSE: {val_rmse:.6f}')

# Entry point: start training when this code executes as the main module
# (the recorded "device: mps" output below shows the guard passes in this
# notebook session).
if __name__ == "__main__":
    run()


device: mps
[32m01:01:09.410[0m | [1mINFO   [0m ([36m24996, stable_pretraining.data.datasets[0m) | [1mLoading dataset with load_from_disk /Users/ashton/.stable_worldmodel/pusht_expert_dataset_train[0m
[32m01:01:19.694[0m | [1mINFO   [0m ([36m24996, stable_pretraining.data.datasets[0m) | [1mLoading dataset with load_from_disk /Users/ashton/.stable_worldmodel/pusht_expert_dataset_val[0m


NameError: name 'USE_ACTIONS' is not defined