In [None]:
# Import the necessary packages for the whole script
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal, Independent, TransformedDistribution
from torch.distributions.transforms import TanhTransform
import gymnasium as gym39
import mujoco
from matplotlib.collections import LineCollection
import matplotlib.cm as cm
import numpy as np
import math
import random
import matplotlib.pyplot as plt
import minari
from torch.utils.data import Dataset, DataLoader
import wandb
import os

# Device string used by the later cells: "cuda" when torch reports a usable GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# According to the paper, each layer contains 256 neurons.
# Used as the default hidden width (h_dim) of every network defined in this cell.
NUM_NEURONS = 256
# The dimension of the abstract skill variable, z.
# Read as a module-level constant by the posterior, prior, policy and world model below.
Z_DIM = 256

# Skill Posterior, q_phi
class SkillPosterior(nn.Module):
    """
    Input: a state sequence and the matching action sequence
    Output: mean and std over the abstract skill, z

    1. 2-layer state embedding w/ ReLU activations
    2. Multi-layer bidirectional GRU over [state embedding, action] at every time step
    3. Mean and std heads applied to a summary of the whole sequence

    Args:
        state_dim: size of the last dimension of the state sequence.
        action_dim: size of the last dimension of the action sequence.
        h_dim: hidden width of the state embedding and of each GRU direction.
        n_gru_layers: number of stacked GRU layers.
    """
    def __init__(self, state_dim, action_dim, h_dim=NUM_NEURONS, n_gru_layers=4):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        # Kept so forward() can split the GRU output into its two directions.
        self.h_dim = h_dim

        self.state_emb = nn.Sequential(
            nn.Linear(state_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU()
        )

        self.bi_gru = nn.GRU(
            input_size=h_dim + action_dim,
            hidden_size=h_dim,
            batch_first=True,
            bidirectional=True,
            num_layers=n_gru_layers
        )

        self.mean = MeanNetwork(in_dim=2*h_dim, out_dim=Z_DIM)
        self.std  = StandardDeviationNetwork(in_dim=2*h_dim, out_dim=Z_DIM)


    def forward(self, state_sequence, action_sequence):
        """
        Encode one batch of sub-trajectories into skill-posterior parameters.

        Args:
            state_sequence: [B, T, state_dim]
            action_sequence: concatenated with the state embedding on the last
                dimension, so it must share the leading [B, T] dimensions.

        Returns:
            (mean, std) produced by the mean and std heads from the sequence summary.
        """
        s_emb = self.state_emb(state_sequence)
        x_in  = torch.cat([s_emb, action_sequence], dim=-1)
        feats, _ = self.bi_gru(x_in)  # [B, T, 2*h_dim]: forward features first, backward features second

        # Use the *final* state of each direction (still "last step, not mean").
        # The forward direction finishes at t = T-1, but the backward direction
        # finishes at t = 0; its output at t = T-1 has only seen the last input.
        # Taking feats[:, -1, :] for both halves therefore discarded almost all
        # of the backward pass.
        fwd_final = feats[:, -1, :self.h_dim]
        bwd_final = feats[:, 0, self.h_dim:]
        seq_emb = torch.cat([fwd_final, bwd_final], dim=-1)

        mean = self.mean(seq_emb)
        std  = self.std(seq_emb)
        return mean, std


# Low-Level Skill-Conditioned Policy, pi_theta
class SkillPolicy(nn.Module):
    """
    Input: a state and an abstract skill, z (concatenated on the last dimension)
    Output: mean and std over the action

    1. 2-layer shared trunk w/ ReLU activations
    2. Separate 2-layer heads for the mean and for the raw (pre-activation) std

    Std options:
        max_sig:   when given, std = max_sig * sigmoid(raw); otherwise std = softplus(raw)
        fixed_sig: when given, overrides the std with this constant everywhere
    a_dist is stored on the module but is not read by forward().
    """
    def __init__(self, state_dim, action_dim, h_dim=NUM_NEURONS, a_dist='normal', max_sig=None, fixed_sig=None):
        super().__init__()
        self.state_dim, self.action_dim = state_dim, action_dim
        self.a_dist = a_dist
        self.max_sig, self.fixed_sig = max_sig, fixed_sig

        # Shared trunk over the concatenated [state, z] input
        self.layers = nn.Sequential(
            nn.Linear(state_dim + Z_DIM, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
        )
        # Action-mean head
        self.mean_head = nn.Sequential(
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, action_dim),
        )
        # Raw std head; the positivity transform is applied in forward()
        self.sig_head = nn.Sequential(
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, action_dim),
        )

    def forward(self, state, z):
        # state: [B*T, state_dim], z: [B*T, Z_DIM]
        trunk_out = self.layers(torch.cat([state, z], dim=-1))
        mean = self.mean_head(trunk_out)
        raw_sig = self.sig_head(trunk_out)

        if self.max_sig is not None:
            # Bounded std in (0, max_sig)
            sig = self.max_sig * torch.sigmoid(raw_sig)
        else:
            sig = F.softplus(raw_sig)

        if self.fixed_sig is not None:
            # A constant std replaces whatever the head predicted
            sig = self.fixed_sig * torch.ones_like(sig)
        return mean, sig

        

# Temporally-Abstract World Model, p_psi
class TAWM(nn.Module):
    """
    Input: initial state, along with the abstract skill
    Output: mean and std over terminal state

    1. 2-layer shared network w/ ReLU activations on the concatenated [s0, z]
    2. A 2-layer mean head and a 2-layer std head (softplus output) on the shared features

    per_element_sigma=True predicts one std per state dimension; otherwise a
    single std is predicted and expanded across all state dimensions.
    """
    def __init__(self, state_dim, h_dim=NUM_NEURONS, per_element_sigma=True):
        super().__init__()
        self.state_dim = state_dim
        self.per_element_sigma = per_element_sigma

        self.layers = nn.Sequential(
            nn.Linear(state_dim + Z_DIM, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
        )
        self.mean_head = nn.Sequential(
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, state_dim),
        )
        # The two std variants differ only in the width of the last linear layer
        sig_out_dim = state_dim if per_element_sigma else 1
        self.sig_head = nn.Sequential(
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, sig_out_dim),
            nn.Softplus(),
        )

    def forward(self, s0, z):
        # s0: [B, state_dim], z: [B, Z_DIM]
        shared = self.layers(torch.cat([s0, z], dim=-1))
        mean = self.mean_head(shared)
        sig = self.sig_head(shared)
        if self.per_element_sigma:
            return mean, sig
        # Single predicted std: broadcast it over every state dimension
        return mean, sig.expand(-1, self.state_dim)


# Skill Prior, p_omega
class SkillPrior(nn.Module):
    """
    Input: Initial state, s0, in the trajectory
    Output: mean and std over the abstract skill, z

    1. 2-layer shared network w/ ReLU activations for the initial state
    2. A 2-layer mean head and a 2-layer std head (softplus output) on the shared features
    """
    def __init__(self, state_dim, h_dim=NUM_NEURONS):
        super().__init__()
        self.state_dim = state_dim

        self.layers = nn.Sequential(
            nn.Linear(state_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
        )
        self.mean_head = nn.Sequential(
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, Z_DIM),
        )
        self.sig_head = nn.Sequential(
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, Z_DIM),
            nn.Softplus(),
        )

    def forward(self, s0):
        shared = self.layers(s0)
        # Both heads read the same shared features
        return self.mean_head(shared), self.sig_head(shared)


class MeanNetwork(nn.Module):
    """
    Input: feature tensor whose last dimension is in_dim
    Output: predicted mean w/ last dimension out_dim

    1. Linear(in_dim -> NUM_NEURONS), ReLU, Linear(NUM_NEURONS -> out_dim); no output activation
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
        
        
class StandardDeviationNetwork(nn.Module):
    """
    Input: feature tensor whose last dimension is in_dim
    Output: predicted std w/ last dimension out_dim

    1. Linear(in_dim -> NUM_NEURONS), ReLU, Linear(NUM_NEURONS -> out_dim), softplus

    Note: min_std and max_std are stored on the module but are NOT applied in
    forward(); the returned std is the plain softplus output. Bounding matters
    because std -> 0 makes log(std) diverge and a very large std can hurt training.
    """
    def __init__(self, in_dim, out_dim, min_std=0.05, max_std=5.0):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, out_dim)
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
        self.min_std, self.max_std = min_std, max_std

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        # Softplus keeps the std positive; no lower/upper bound is applied here.
        return self.softplus(self.fc2(hidden))


In [None]:
# Loads the AntMaze dataset in Minari format
ant_maze_dataset = minari.load_dataset('D4RL/antmaze/medium-diverse-v1')

# Inspect the first episode to see the array shapes and observation keys
first_episode = ant_maze_dataset[0]
print(first_episode.actions.shape)
print(first_episode.observations.keys())
print(first_episode.observations["observation"].shape)
print(first_episode.observations["achieved_goal"].shape)

# B, the number of subtrajectories per batch (from paper)
B = 100

# T, the length of each subtrajectory (from paper)
T = 40

# AntMaze state and action dims, derived from the data instead of hard-coded.
# The model state is `observation` concatenated with `achieved_goal` (this is
# what SubtrajDataset builds), so its width is the sum of both last dimensions.
# A hard-coded 31 does not match the 27 + 2 = 29 features in the data and makes
# the first Linear layer fail with a shape mismatch.
state_dim = (
    first_episode.observations["observation"].shape[-1]
    + first_episode.observations["achieved_goal"].shape[-1]
)
action_dim = first_episode.actions.shape[-1]

# Initialize the models
q_phi = SkillPosterior(state_dim=state_dim, action_dim=action_dim).to(device)
pi_theta = SkillPolicy(state_dim=state_dim, action_dim=action_dim).to(device)
p_psi = TAWM(state_dim=state_dim).to(device)
p_omega = SkillPrior(state_dim=state_dim).to(device)

(1000, 8)
dict_keys(['achieved_goal', 'desired_goal', 'observation'])
(1001, 27)
(1001, 2)


In [6]:

def make_episode_splits(minari_dataset, train=0.8, val=0.1, test=0.1, seed=0):
    """Return three lists of episode indices (train_ids, val_ids, test_ids).

    Indices are positions in the order produced by
    ``minari_dataset.iterate_episodes()``.

    Args:
        minari_dataset: object exposing ``iterate_episodes()``.
        train: fraction of episodes assigned to the training split.
        val: fraction of episodes assigned to the validation split.
        test: accepted for readability at the call site; the test split is
            whatever remains after the train and val splits are taken.
        seed: seed of the private RNG used for shuffling, so the split is
            reproducible and does not touch the global ``random`` state.
    """
    # Count episodes while streaming over them; materialising every episode in
    # a list just to take its length keeps the whole dataset in memory at once.
    n = sum(1 for _ in minari_dataset.iterate_episodes())
    idxs = list(range(n))
    # Shuffle the indices
    random.Random(seed).shuffle(idxs)
    n_train = int(round(train * n))
    n_val = int(round(val * n))
    train_ids = idxs[:n_train]
    val_ids = idxs[n_train:n_train + n_val]
    test_ids = idxs[n_train + n_val:]
    return train_ids, val_ids, test_ids

class SubtrajDataset(Dataset):
    """
    Sliding-window sub-trajectory dataset.

    Loops over minari_dataset.iterate_episodes(), but keeps only episodes whose
    index is in episode_ids. Each item is a window of T steps:
    (s0, state_sequence, action_sequence, sT), where the state is the
    concatenation of "observation" and "achieved_goal", s0 is the first state
    of the window and sT is the state right after the window.

    Attributes:
        T: window length.
        items: list of (s0, state_sequence, action_sequence, sT) numpy tuples.
        stats: optional (mean, std) used by __getitem__ to standardize states;
            None (the default) disables standardization.
    """
    def __init__(self, minari_dataset, T, episode_ids, stride=3):
        self.T = T
        self.items = []
        # Assigned by the caller after construction; None means "no standardization".
        self.stats = None

        # Build the lookup set once instead of once per episode.
        wanted = set(episode_ids)

        # Iterate all episodes but only process those whose global index is in episode_ids
        for ep_idx, ep in enumerate(minari_dataset.iterate_episodes()):
            if ep_idx not in wanted:
                continue
            obs = ep.observations["observation"]
            ach = ep.observations["achieved_goal"]
            Ltot = len(obs)
            # Need T states for the window plus one more for sT
            if Ltot < T + 1:
                continue

            state_ext = np.concatenate([obs, ach], axis=-1).astype(np.float32)
            # Cast the actions once per episode; windows below are slices of
            # this array rather than a fresh float32 copy for every window.
            act = ep.actions.astype(np.float32)
            for t in range(0, Ltot - T, stride):
                state_seq = state_ext[t:t+T]
                self.items.append((state_seq[0], state_seq, act[t:t+T], state_ext[t+T]))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        """Return window i as a dict of float32 tensors.

        When stats is set, s0, state_sequence and sT are standardized by
        (x - mean) / std; actions are returned unchanged.
        """
        s0, S, A, sT = self.items[i]
        # getattr keeps instances created without the `stats` attribute working.
        stats = getattr(self, "stats", None)
        if stats is not None:
            S_mean, S_std = stats
            S  = (S  - S_mean) / S_std
            s0 = (s0 - S_mean) / S_std
            sT = (sT - S_mean) / S_std
        return {
            "s0": torch.as_tensor(s0, dtype=torch.float32),
            "state_sequence": torch.as_tensor(S, dtype=torch.float32),
            "action_sequence": torch.as_tensor(A, dtype=torch.float32),
            "sT": torch.as_tensor(sT, dtype=torch.float32),
        }

def collate(batch):
    """Stack a list of per-item dicts into one dict of batched tensors.

    Every entry of `batch` must provide the keys "s0", "state_sequence",
    "action_sequence" and "sT"; each is stacked along a new leading dimension.
    """
    fields = ("s0", "state_sequence", "action_sequence", "sT")
    return {
        field: torch.stack([sample[field] for sample in batch], 0)
        for field in fields
    }


In [7]:
# Pick episode indices for the train/val/test split.
# val=0.0 here, so val_ids is empty and everything outside the train split goes to test.
train_ids, val_ids, test_ids = make_episode_splits(ant_maze_dataset, train=0.8, val=0.0, test=0.2, seed=0)
print(f"train:{len(train_ids)}  val:{len(val_ids)}  test:{len(test_ids)}")

# Datasets from episode subsets (stride=1: a new window starts at every time step)
train_ds = SubtrajDataset(ant_maze_dataset, T=T, episode_ids=train_ids, stride=1)
val_ds = SubtrajDataset(ant_maze_dataset, T=T, episode_ids=val_ids,   stride=1)
test_ds = SubtrajDataset(ant_maze_dataset, T=T, episode_ids=test_ids,  stride=1)  

# Number of sub-trajectory windows in each split
print(f"train:{len(train_ds)}  val:{len(val_ds)}  test:{len(test_ds)}")

# Per-feature mean and std over every state_sequence timestep stored in a dataset
def compute_stats(ds):
    """Return (mean, std) per state feature, pooled over all windows in ds.items.

    Each item's state_sequence (index 1 of the tuple) is flattened to
    [timesteps, features] and all of them are concatenated before reducing.
    1e-6 is added to the std so it can be divided by safely.
    """
    flat_sequences = [
        seq.reshape(-1, seq.shape[-1])
        for (_, seq, _, _) in ds.items
    ]
    S = np.concatenate(flat_sequences, axis=0)
    return (S.mean(0), S.std(0) + 1e-6)

# Identity statistics: (x - 0) / 1 leaves every state unchanged, so the
# standardization in SubtrajDataset.__getitem__ is effectively a no-op.
# NOTE(review): compute_stats is defined in this cell but never called — confirm
# whether training on unstandardized states is intended.
S_mean, S_std = 0, 1

# pass stats into datasets
train_ds.stats = (S_mean, S_std)
val_ds.stats = (S_mean, S_std)

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = DataLoader(train_ds, batch_size=B, shuffle=True,  collate_fn=collate, drop_last=False)
val_loader = DataLoader(val_ds, batch_size=B, shuffle=False, collate_fn=collate, drop_last=False)

# The test split uses the same statistics as the train/val splits.
test_ds.stats = (S_mean, S_std)
test_loader = DataLoader(test_ds, batch_size=B, shuffle=False, collate_fn=collate, drop_last=False)

train:800  val:0  test:200
train:768800  val:0  test:192200


In [None]:
# Loss weights: alpha scales the terminal-state term, beta scales the KL term
alpha, beta = 1.0, 1.0

def compute_loss(batch):
    """Compute the skill-model loss terms for one batch.

    Uses the module-level networks q_phi, pi_theta, p_psi and p_omega and the
    module-level weights alpha and beta.

    Args:
        batch: dict with "s0", "state_sequence", "action_sequence" and "sT";
            "state_sequence" must be 3-dimensional ([B, T, features]).

    Returns:
        dict with "loss" (alpha * state_decoder_loss + policy_loss + beta * kl_loss),
        "policy_loss", "kl_loss" and "state_decoder_loss". Every term is a
        summed log-probability divided by B * T.
    """
    s0 = batch["s0"]
    S = batch["state_sequence"]
    A = batch["action_sequence"]
    sT = batch["sT"]
    n_traj, horizon, _ = S.shape
    denom = n_traj * horizon

    # Skill posterior q_phi(z | S, A) and a reparameterized sample of z
    post_mean, post_std = q_phi(S, A)
    z = post_mean + post_std * torch.randn_like(post_mean)

    # Low-level policy pi_theta(a | s, z): the same z is repeated for every step
    z_per_step = z.unsqueeze(1).expand(n_traj, horizon, -1).reshape(n_traj * horizon, -1)
    act_mean, act_std = pi_theta(S.reshape(n_traj * horizon, -1), z_per_step)
    act_mean = act_mean.view(n_traj, horizon, -1)
    act_std = act_std.view(n_traj, horizon, -1)
    action_dist = Independent(Normal(act_mean, act_std), 1)

    # Policy loss: negative log-likelihood of the dataset actions
    a_loss = -action_dist.log_prob(A).sum() / denom

    # Single-sample KL estimate: log q(z) - log p(z | s0), evaluated at the sampled z
    prior_mean, prior_std = p_omega(s0)
    log_prior = Independent(Normal(prior_mean, prior_std), 1).log_prob(z).sum() / denom
    log_post = Independent(Normal(post_mean, post_std), 1).log_prob(z).sum() / denom
    kl_loss = - log_prior + log_post

    # The world model sees z without gradient, so its loss does not update q_phi
    term_mean, term_std = p_psi(s0, z.detach())
    terminal_dist = Independent(Normal(term_mean, term_std), 1)

    # State decoder loss: negative log-likelihood of the terminal state
    sT_loss = -terminal_dist.log_prob(sT).sum() / denom

    # Overall loss (naive VI loss from paper)
    return {
        "loss": alpha * sT_loss + a_loss + beta * kl_loss,
        "policy_loss": a_loss,
        "kl_loss": kl_loss,
        "state_decoder_loss": sT_loss
    }

@torch.no_grad()
def eval_epoch(val_loader, q_phi, pi_theta, p_psi, p_omega, device):
    """Compute validation losses averaged over the batches of val_loader.

    Puts the four given networks in eval mode, moves every batch to `device`
    and calls compute_loss on it.

    Returns:
        (loss, policy_loss, kl_loss, state_decoder_loss) as per-batch averages,
        or (None, None, None, None) when the loader yields no batches.
    """
    for net in (q_phi, pi_theta, p_psi, p_omega):
        net.eval()

    term_names = ("loss", "policy_loss", "kl_loss", "state_decoder_loss")
    totals = [0.0 for _ in term_names]
    num_batches = 0

    for batch in val_loader:
        on_device = {key: tensor.to(device) for key, tensor in batch.items()}
        terms = compute_loss(on_device)
        for slot, term_name in enumerate(term_names):
            totals[slot] += float(terms[term_name].item())
        num_batches += 1

    # Nothing to average: signal "no validation data" to the caller
    if num_batches == 0:
        return None, None, None, None
    return tuple(total / num_batches for total in totals)

def skill_model_training_with_val(
    train_loader, val_loader,
    q_phi, pi_theta, p_psi, p_omega,
    lr=5e-5,
    epochs=50, steps=1, grad_clip=1.0
):
    """Jointly train the four skill-model networks and track validation loss.

    One Adam optimizer updates all four networks on the "loss" term returned by
    compute_loss. After every epoch the validation losses are computed with
    eval_epoch and all metrics are logged to wandb; a loss-curve figure is
    logged at the end.

    NOTE: compute_loss(batch) reads the module-level q_phi / pi_theta / p_psi /
    p_omega, so the networks passed here must be those same instances.

    Args:
        train_loader: iterable of batch dicts used for optimization.
        val_loader: iterable of batch dicts used for validation; may be empty.
        q_phi, pi_theta, p_psi, p_omega: the networks to optimize.
        lr: Adam learning rate.
        epochs: number of passes over train_loader.
        steps: optimizer steps taken on each batch.
        grad_clip: max global gradient norm, or None to disable clipping.

    Returns:
        {"train_loss": [...], "val_E": [...]} with one entry per epoch; the
        validation entries are None when val_loader yields no batches.
    """
    models = (q_phi, pi_theta, p_psi, p_omega)
    for model in models:
        model.to(device)

    # Collected once and reused by both the optimizer and gradient clipping
    params = [p for model in models for p in model.parameters()]
    opt = torch.optim.Adam(params, lr=lr)

    tr, va = [], []

    for epoch in range(1, epochs+1):
        for model in models:
            model.train()
        loss_run, policy_loss_run, kl_loss_run, state_decoder_loss_run = 0.0, 0.0, 0.0, 0.0 # Running loss in current epoch

        nb = 0

        for batch in train_loader:
            # Rebuilds dictionary but moves tensors to the device
            batch = {k: v.to(device) for k, v in batch.items()}
            nb += 1

            for _ in range(steps):
                opt.zero_grad(set_to_none=True)
                terms = compute_loss(batch)
                loss = terms["loss"]
                policy_loss = terms["policy_loss"]
                kl_loss = terms["kl_loss"]
                state_decoder_loss = terms["state_decoder_loss"]
                loss.backward()
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(params, grad_clip)
                opt.step()
            # The losses of the last step on this batch are the ones recorded
            loss_run += float(loss.item())
            policy_loss_run += float(policy_loss.item())
            kl_loss_run += float(kl_loss.item())
            state_decoder_loss_run += float(state_decoder_loss.item())


        # Calculate the average losses over all the batches in the epoch
        loss_epoch = loss_run / max(1, nb)
        policy_loss_epoch = policy_loss_run / max(1, nb)
        kl_loss_epoch = kl_loss_run / max(1, nb)
        state_decoder_loss_epoch = state_decoder_loss_run / max(1, nb)

        tr.append(loss_epoch)

        # validation
        v_loss, v_policy_loss, v_kl_loss, v_state_decoder_loss = eval_epoch(val_loader, q_phi, pi_theta, p_psi, p_omega, device)
        va.append(v_loss)

        # eval_epoch returns None for an empty loader; formatting None with
        # ":.4f" would raise TypeError, so print a placeholder instead.
        val_text = f"{v_loss:.4f}" if v_loss is not None else "n/a"
        print(f"[Epoch {epoch:03d}/{epochs}] "
              f"train loss:{loss_epoch:.4f} "
              f"| val loss:{val_text}")

        metrics = {
            "train/loss": loss_epoch,
            "train/policy_loss": policy_loss_epoch,
            "train/kl_loss": kl_loss_epoch,
            "train/state_decoder_loss": state_decoder_loss_epoch,
            "val/loss": v_loss,
            "val/policy_loss": v_policy_loss,
            "val/kl_loss": v_kl_loss,
            "val/state_decoder_loss": v_state_decoder_loss,
            "epoch": epoch
        }
        # Skip validation metrics that are unavailable (None)
        wandb.log({k: v for k, v in metrics.items() if v is not None}, step=epoch)

    # Keep an explicit handle on the figure: after plt.show() the current figure
    # may already be closed, in which case plt.gcf() returns a new blank one.
    fig = plt.figure(figsize=(7.5,4.5))
    plt.plot(tr, label="Train loss")
    if all(v is not None for v in va):
        plt.plot(va, label="Val loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.title("EM training: train vs. val losses")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    # Log the rendered figure before showing it. `epochs` equals the last value
    # of `epoch` whenever the loop ran, and is still defined when epochs == 0.
    wandb.log({"plots/loss_curves": wandb.Image(fig)}, step=epochs)
    plt.show()
    plt.close(fig)

    return {"train_loss": tr, "val_E": va}



In [None]:
# Start a Weights & Biases run and record the hyperparameters as run metadata.
# NOTE(review): e_lr / m_lr / e_steps / m_steps are only stored in the run
# config; skill_model_training_with_val itself takes `lr` and `steps` — confirm
# these config names still describe the run.
wandb.init(
    project="tawm-skill-learning",
    name="antmaze-medium_em",
    config=dict(
        B=B, T=T, Z_DIM=Z_DIM, NUM_NEURONS=NUM_NEURONS,
        e_lr=5e-5, m_lr=5e-5, e_steps=1, m_steps=1,
        dataset="D4RL/antmaze/medium-diverse-v1",
        device=device
    )
)

# Ask wandb to log gradients of all four networks (log_freq=200)
wandb.watch([q_phi, pi_theta, p_psi, p_omega], log="gradients", log_freq=200)

[34m[1mwandb[0m: Currently logged in as: [33mwilliam-huang-08[0m ([33mwilliam-huang-08-yale-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Train the skill model. skill_model_training_with_val accepts `lr` and `steps`;
# it has no e_lr / m_lr / e_steps / m_steps parameters, so passing those raised
# TypeError before training could start.
# NOTE: test_loader is passed as the validation loader because the val split is empty (val=0.0).
curves = skill_model_training_with_val(train_loader, test_loader, q_phi, pi_theta, p_psi, p_omega, epochs=100, lr=5e-5, steps=1)