TODO:

This notebook is for prototyping.

I will turn this into an example_action_head.py script that loads the StepsDataset, trains the MLP, and saves the weights.

In [None]:
from datasets import load_from_disk, Dataset

import stable_worldmodel as swm
from stable_worldmodel.data import StepsDataset
from stable_worldmodel.policy import AutoCostModel
from stable_worldmodel.wm.dinowm import DINOWM

import stable_pretraining as spt

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

import math

In [None]:
# Root directory used by stable_worldmodel for cached data.
# NOTE(review): presumably resolves to "~/.stable_worldmodel" by default — confirm.
cache_dir = swm.data.get_cache_dir()

# Names of the expert-demonstration splits, relative to the cache directory.
train_dir, val_dir = "pusht_expert_dataset_train", "pusht_expert_dataset_val"

In [None]:
# Resolve each split's location under the cache directory, then load it.
train_path, val_path = cache_dir / train_dir, cache_dir / val_dir

ds_train, ds_val = load_from_disk(train_path), load_from_disk(val_path)
# Alternative kept for reference: read the arrow shard directly, e.g.
#   Dataset.from_file(f"{train_dir}/data-00000-of-00001.arrow")
#   Dataset.from_file(f"{val_dir}/data-00000-of-00001.arrow")

In [None]:
# Quick look at the training split: its feature schema, then its first row.
print(ds_train.features, ds_train[0], sep="\n")

In [None]:
# Number of consecutive steps per sample; one image pipeline is built per step.
NUM_STEPS = 1


def step_transform():
    """Build the image preprocessing transform for every step key.

    For each ``t`` in ``range(NUM_STEPS)`` a pipeline is created that reads
    and writes the ``"pixels.{t}"`` entry: ``ToImage`` with mean/std of 0.5
    per channel, ``Resize(224)``, then ``CenterCrop(224)``.

    Returns:
        A ``Compose`` wrapping the per-step pipelines in step order.
    """

    def _pipeline_for(key):
        # All three stages operate in place on the same key.
        return spt.data.transforms.Compose(
            spt.data.transforms.ToImage(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
                source=key,
                target=key,
            ),
            spt.data.transforms.Resize(224, source=key, target=key),
            spt.data.transforms.CenterCrop(224, source=key, target=key),
        )

    per_step = [_pipeline_for(f"pixels.{t}") for t in range(NUM_STEPS)]
    return spt.data.transforms.Compose(*per_step)

transform = step_transform()

# Build the train and validation step datasets (in that order) with the
# shared image transform.
train_data, val_data = (
    StepsDataset(split_name, num_steps=NUM_STEPS, transform=transform)
    for split_name in ("pusht_expert_dataset_train", "pusht_expert_dataset_val")
)

# Re-point each dataset at the parent of its current data_dir.
# NOTE(review): presumably works around how StepsDataset resolves its
# directory for these datasets — confirm why this is needed.
for split in (train_data, val_data):
    split.data_dir = split.data_dir.parent

In [None]:
# Sanity check: shape of the "pixels" entry of the first training sample.
print(train_data[0]["pixels"].shape)

In [None]:
BATCH_SIZE = 256
NUM_WORKERS = 4

# Training batches are reshuffled every epoch; validation order is fixed.
# Workers are kept alive between epochs (persistent_workers=True).
# NOTE: pin_memory could optionally be enabled on CUDA (not on Mac).
train_loader = DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=True,
)

val_loader = DataLoader(
    dataset=val_data,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=True,
)

In [None]:
# Instantiate the model registered under the name "dinowm" (not the "_object"
# variant) via AutoCostModel.
# NOTE(review): unclear whether this returns the full world model or only a
# cost head — confirm.
# NOTE: `dinowm` is rebound in the next cell from a checkpoint file.
dinowm = AutoCostModel(model_name="dinowm") # no _object
# this is just cost_head?
# Switch to eval mode and turn off gradient tracking (assumes the usual
# nn.Module semantics for eval()/requires_grad_() — TODO confirm).
dinowm.eval().requires_grad_(False)

# print(type(model))

In [None]:
# Pick the compute device in order of preference: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Checkpoint file expected directly under the cache directory.
checkpoint_name = "dinowm_pusht_object.ckpt"
checkpoint = cache_dir / checkpoint_name

# NOTE: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints from a trusted source.
dinowm = torch.load(checkpoint, map_location=device, weights_only=False)
dinowm = dinowm.to(device).eval()

# Freeze the loaded model: none of its parameters should receive gradients.
for param in dinowm.parameters():
    param.requires_grad_(False)

In [None]:
def encode(batch):
    """Encode a batch with the frozen ``dinowm`` model into one vector per sample.

    The ``"pixels"`` and ``"proprio"`` entries of ``batch`` are moved to
    ``device`` and passed to ``dinowm.encode`` under ``torch.no_grad()``;
    the result is read from the ``"embed"`` key.

    Args:
        batch: Mapping with tensor entries ``"pixels"`` and ``"proprio"``.

    Returns:
        ``embed[:, -1].mean(dim=1)`` — the last index along dim 1, averaged
        over the following dimension.
        NOTE(review): presumably embed is (batch, steps, patches, dim), so
        this is the last step averaged across patches, giving
        (batch, d_pixels + d_proprio) — confirm.
    """
    inputs = {name: batch[name].to(device) for name in ("pixels", "proprio")}
    with torch.no_grad():
        encoded = dinowm.encode(
            inputs,
            target="embed",
            pixels_key="pixels",
            proprio_key="proprio",
        )
    last_step = encoded["embed"][:, -1]
    return last_step.mean(dim=1)

In [None]:
# Smoke test: encode only the first training batch and print the shapes of
# the raw pixels, the proprio input and the resulting latent.
for i, batch in enumerate(train_loader):
    if i > 0:
        break
    print(batch["pixels"].shape)
    latent = encode(batch)
    print(batch["proprio"].shape)
    print(latent.shape)

In [None]:
class MLP(nn.Module):
    """Feed-forward network: Linear+ReLU hidden layers, then a linear output.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_dims: Widths of the hidden layers, in order. The default
            ``(512, 256)`` reproduces the original fixed architecture and
            its ``state_dict`` keys (``layers.0``, ``layers.2``,
            ``layers.4``). An empty sequence yields a single linear layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dims=(512, 256)):
        super().__init__()
        modules = []
        width = in_dim
        for hidden in hidden_dims:
            modules.append(nn.Linear(width, hidden))
            modules.append(nn.ReLU())
            width = hidden
        # No activation after the final projection: outputs are unbounded.
        modules.append(nn.Linear(width, out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)

In [None]:
# Input width of the head: backbone hidden size plus proprio embedding size.
LATENT_DIM = (
    dinowm.backbone.config.hidden_size
    + dinowm.proprio_encoder.emb_dim
)
# Output width of the head: length of the first action in the training data.
# NOTE(review): presumably an (x, y) action — confirm.
ACTION_DIM = len(train_data.dataset['action'][0])

print(LATENT_DIM, ACTION_DIM)
action_head = MLP(LATENT_DIM, ACTION_DIM)
action_head = action_head.to(device)

In [None]:
# Only the action head is optimised; latents come from `encode`, which runs
# the frozen world model without gradients.
optimizer = torch.optim.AdamW(action_head.parameters(), lr=3e-4, weight_decay=1e-4)

EPOCHS = 25
for epoch in range(1, EPOCHS + 1):
    # train: regress the action at index 0 along dim 1 from the latent
    action_head.train()
    for batch in train_loader:
        latent = encode(batch)
        action = batch['action'][:, 0].to(device)
        pred = action_head(latent)
        loss = F.mse_loss(pred, action)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # eval: accumulate the *summed* squared error and the element count so
    # the RMSE is exact regardless of batch size or a smaller last batch.
    # (Summing per-batch means and dividing by the number of samples
    # under-reports the error by roughly sqrt(batch size).)
    action_head.eval()
    val_loss = 0.0
    val_count = 0
    with torch.no_grad():
        for batch in val_loader:
            latent = encode(batch)
            action = batch['action'][:, 0].to(device)
            pred = action_head(latent)
            val_loss += F.mse_loss(pred, action, reduction="sum").item()
            val_count += action.numel()
    # An empty validation set yields NaN rather than a division error.
    val_rmse = math.sqrt(val_loss / val_count) if val_count else float("nan")
    print(f'epoch {epoch}: RMSE: {val_rmse}')

In [None]:
# Single-step image pipeline operating in place on the "pixels.0" key:
# ToImage with 0.5 mean/std per channel, then resize and centre-crop to 224.
# NOTE(review): adapted from train/dinowm.py, where 'pixels.0' is the key for
# a step size of 1; the expert data is expected to be 224x224 already — confirm.
transform = spt.data.transforms.Compose(
    spt.data.transforms.ToImage(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        source="pixels.0",
        target="pixels.0",
    ),
    spt.data.transforms.Resize(224, source="pixels.0", target="pixels.0"),
    spt.data.transforms.CenterCrop(224, source="pixels.0", target="pixels.0"),
)

def batch_transform(batch):
    """Apply ``transform`` to each image in ``batch["pixels"]`` and stack the results.

    Each image is wrapped as ``{"pixels.0": img}``, passed through the
    module-level ``transform``, and the transformed ``"pixels.0"`` entries
    are stacked with ``torch.stack``.

    Args:
        batch: Mapping whose ``"pixels"`` entry is an iterable of images.

    Returns:
        The same ``batch`` object, with ``"pixels"`` replaced in place.
    """
    processed = []
    for img in batch["pixels"]:
        sample = transform({"pixels.0": img})
        processed.append(sample["pixels.0"])
    batch["pixels"] = torch.stack(processed)
    return batch