# Overview

Model-Based Temporal Abstraction involves simultaneuously learning
1) skill-conditioned low-level policy
2) skill-conditioned temporally abstract world model

Notation
- skill-conditioned low-level policy: $\pi_{\theta}(a_t|s_t, z)$
    - $\theta$ are parameters
    - $a_t \in A$ is current action selected by agent
    - $s_t \in S$ is current state
    - $z \in Z$ is abstract skill variable that encodes a skill

- skill-conditioned temporally abstract world model (TAWM): $p_{\psi}(s'|s,z)$ (models distribution of states agent is in after skill $z$)
    - $\psi$ parameters
    - $z$ is current skill

Note: low-level policy and TAWM not trained on rewards, reward function is provided later for planning with the learned skills 





In [None]:
# Import the necessary packages for the whole script
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal, Independent, kl_divergence
import numpy as np
import math
import matplotlib.pyplot as plt
import minari
from torch.utils.data import Dataset, DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

### Learning $\pi_{\theta}$ and $p_{\psi}$

Learning $\pi_{\theta}$ and $p_{\psi}$ requires treating skills as latent variables and optimizing the ELBO

$$
\mathcal{L}(\theta,\psi,\phi,\omega)
= \mathbb{E}_{\tau_T \sim \mathcal{D}}\!\left[
  \mathbb{E}_{q_\phi(z\,|\,\tau_T)}\!\left[
    \log \pi_\theta(\bar{a}\,|\,\bar{s}, z)
    + \log p_\psi(s_T \,|\, s_0, z)
  \right]
  - D_{\mathrm{KL}}\!\left(q_\phi(z\,|\,\tau_T)\,\|\,p_\omega(z\,|\,s_0)\right)
\right].
$$

where $\tau_T$ is a T-length subtrajectory sampled from the offline dataset $\mathcal{D}$, $\bar{s}$ and $\bar{a}$ are state and action sequences of $\tau_T$, $q_{\psi}$ is a posterior over $z$ given $\tau_T$, and $p_{\omega}$ is a prior of $z$ given $s_0$.

The first term is the log-likelihood of demonstrator actions. This ensures that the low-level policy can reproduce a demonstrator's action sequence given a skill. This forces $z$ to encode control-relevant information.

The second term is the log-likelihood of long-term state transitions. This term ensures that we learn relationships between $z$ to what possible $s_T$ could result from. the skill.

Finally, the last term is the KL divergence between skill posterior and prior (encourages compression of skills). Therefore, maximizing this ELBO makes skills $z$ explain the data and keeps the KL divergence small. This ensures that the skill is start-state predictable.

In [465]:
# According to the paper, each layer contains 256 neurons
NUM_NEURONS = 256
# The dimension of the abstract skill variable, z
Z_DIM = 256

# Skill Posterior, q_phi
class SkillPosterior(nn.Module):
    """
    Input: sequence of skills and actions
    Output: mean and std over z

    1. Linear layer w/ ReLU activation for the state sequence
    2. Single-layer bidirectional GRU for embedded states and action sequence (concatenated)
    3. Extract mean and std of layer 2's output
    """
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(in_features=self.state_dim, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        self.bi_gru = nn.GRU(input_size=NUM_NEURONS+self.action_dim, hidden_size= NUM_NEURONS//2, bidirectional=True, batch_first=True)
        self.mean = MeanNetwork(Z_DIM)
        self.std = StandardDeviationNetwork(Z_DIM)

    def forward(self, state_sequence, action_sequence):
        embedded_states = self.relu(self.fc1(state_sequence))
        concatenated = torch.cat([embedded_states, action_sequence], dim=-1)
        x, _ = self.bi_gru(concatenated) # [B, T, NUM_NEURONS]
        seq_emb = x.mean(dim=1) # [B, NUM_NEURONS]
        mean = self.mean.forward(seq_emb)
        std = self.std.forward(seq_emb)
        return mean, std

# Low-Level Skill-Conditioned Policy, pi_theta
class SkillPolicy(nn.Module):
    """
    Input: Current state and a skill, z
    Output: mean and std over a

    1. 2-layer shared network w/ ReLU activations for the state and abstract skill (concatenated)
    2. Extract mean and std of layer 1's output
    """
    def __init__(self, state_dim, action_dim):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(in_features=self.state_dim+Z_DIM, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(self.action_dim)
        self.std = StandardDeviationNetwork(self.action_dim)
    
    def forward(self, state, z):
        c = torch.cat([state, z], dim=-1)
        x = self.relu(self.fc1(c))
        x = self.relu(self.fc2(x))
        mean = self.mean(x)
        std = self.std(x)
        return mean, std
        

# Temporally-Abstract World Model, p_psi
class TAWM(nn.Module):
    """
    Input: initial state, along with the abstract skill
    Output: mean and std over terminal state

    1. 2-layer shared network w/ ReLU activations for initial state and abstract skill (concatenated)
    2. Extract mean and std of layer 1's output
    """
    def __init__(self, state_dim):
        super().__init__()

        self.state_dim = state_dim
        self.fc1 = nn.Linear(in_features=self.state_dim+Z_DIM, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(self.state_dim)
        self.std = StandardDeviationNetwork(self.state_dim)
    
    def forward(self, input_state, z):
        c = torch.cat([input_state, z], dim=-1)
        x = self.relu(self.fc1(c))
        x = self.relu(self.fc2(x))
        mean = self.mean(x)
        std = self.std(x)
        return mean, std

# Skill Prior, p_omega
class SkillPrior(nn.Module):
    """
    Input: Initial state, s0, in the trajectory
    Output: mean and std over the abstract skill, z

    1. 2-layer shared network w/ ReLU activations for the initial state
    2. Extract mean and std of layer 1's output
    """
    def __init__(self, state_dim):
        super().__init__()

        self.state_dim = state_dim
        self.fc1 = nn.Linear(in_features=self.state_dim, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(Z_DIM)
        self.std = StandardDeviationNetwork(Z_DIM)
    
    def forward(self, input_state):
        x = self.relu(self.fc1(input_state))
        x = self.relu(self.fc2(x))
        mean = self.mean(x)
        std = self.std(x)
        return mean, std

class MeanNetwork(nn.Module):
    """
    Input: tensor to calculate mean
    Output: mean of input w/ dimension out_dim

    1. 2-layer network w/ ReLU activation for the first layer
    """
    def __init__(self, out_dim):
        super().__init__()
        
        self.fc1 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=out_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
        
        
class StandardDeviationNetwork(nn.Module):
    """
    Input: tensor to calculate std
    Output: std of input w/ dimension out_dim

    Note: the standard deviation is lower and upper bounded at 0.05 and 2.0
    - if std is 0, then log(std) -> inf
    - if std is large, then can affect training

    1. 2-layer linear network with ReLU activation after first layer and softplus after second

    """
    def __init__(self, out_dim, min_std=0.05, max_std=2.0):
        super().__init__()
        self.fc1 = nn.Linear(NUM_NEURONS, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, out_dim)
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
        self.min_std = min_std
        self.max_std = max_std
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        std = self.softplus(x) + self.min_std  # lower bound
        std = torch.clamp(std, max=self.max_std)
        return std


#### The Expectation-Maximization (EM) Algorithm

Since calculating the true posterior of $z$ given $\tau_T$ is intractable, we infer $q_{\psi}(z|\tau_T)$.

1. E-Step:
- Update $\psi$ w/gradient descent so that KL divergence between $q_\psi$ and true posterior is minimized

2. M-Step:
- Fixing $q_{\psi}$, update ($\theta, \psi, \omega$) s.t. ELBO is maximized using gradient ascent.


##### E-Step (Update $\psi$)

In this step, we want to minimize 

$$
\mathbb{E}_{\mathcal{T}_T\sim\mathcal{D}}
\Bigg[
\mathbb{E}_{z\sim q_\phi}
\bigg[
\log \frac{q_\phi\!\left(z\mid \bar{s},\bar{a}\right)}
{\pi_\theta\!\left(\bar{a}\mid \bar{s}, z\right)\,p_\omega\!\left(z\mid s_0\right)}
\bigg]
\Bigg]
$$

Equivalently, we want to minimize $\mathcal{KL}(q_{\psi}||p)$ where $p$, the true posterior is 

$$
p(z \mid \bar{s}, \bar{a}) = \frac{1}{\eta}\,\pi_\theta(\bar{a}\mid \bar{s}, z)\,p_\omega(z\mid s_0).
$$


##### M-Step (Update $\theta$, $\psi$, $\omega$)

In this step, we want to update $\theta$, $\psi$, and $\omega$ using gradient ascent to maximize the ELBO from above.

Both steps are trained using an Adam optimizer.


##### Dataset 1: AntMaze Medium

In [466]:
# Loads the AntMaze dataset in Minari format
ant_maze_dataset = minari.load_dataset('D4RL/antmaze/medium-play-v1')

In [467]:
print(ant_maze_dataset[0].actions.shape)
print(ant_maze_dataset[0].observations.keys())
print(ant_maze_dataset[0].observations["observation"].shape)
print(ant_maze_dataset[0].observations["achieved_goal"].shape)

(1000, 8)
dict_keys(['achieved_goal', 'desired_goal', 'observation'])
(1001, 27)
(1001, 2)


In [468]:
# B, the number of subtrajectories per batch (from paper)
B = 100

# T, the length of each subtrajectory (from paper)
T = 9

# AntMaze state and action dims (from Minari)
state_dim = 29
action_dim = 8

# Initialize the models
q_phi = SkillPosterior(state_dim=state_dim, action_dim=action_dim).to(device)
pi_theta = SkillPolicy(state_dim=state_dim, action_dim=action_dim).to(device)
p_psi = TAWM(state_dim=state_dim).to(device)
p_omega = SkillPrior(state_dim=state_dim).to(device)

In [470]:
class TrainingDataset(Dataset):
    """
    Input: LoadedMinari dataset and the length of the number of actions in each subtrajectory (T)
    Output: Dictionary with keys "s0, state_sequence, action_sequence, and sT"

    Finds all episodes that have at least T actions. Then, for each of those episodes, creates a sliding window to create subtrajectories w/ T actions
    """
    def __init__(self, ant_maze_dataset, T):
        self.T = T
        self.subtrajectories = []

        for ep in ant_maze_dataset.iterate_episodes():
            s = ep.observations["observation"]
            a = ep.actions
            l = len(s)
            if l < T + 1:
                continue
            
            # Consider skipping timesteps so that subtrajectories don't overlap that much
            stride = 3
            for t in range(0, l - T, stride):
                obs = ep.observations["observation"][t:t+T+1] # (27,)
                ach = ep.observations["achieved_goal"][t:t+T+1]  # (2,)
                state_sequence_extended = np.concatenate([obs, ach], axis=-1)
                state_sequence = state_sequence_extended[:-1]
                s0 = state_sequence[0]
                action_sequence = a[t: t + T]
                sT = state_sequence_extended[-1]
                self.subtrajectories.append((s0, state_sequence, action_sequence, sT))

    def __len__(self):
        return len(self.subtrajectories)

    def __getitem__(self, idx):
        s0, state_sequence, action_sequence, sT = self.subtrajectories[idx]

        return {
            "s0": torch.as_tensor(s0, dtype=torch.float32),
            "state_sequence": torch.as_tensor(state_sequence, dtype=torch.float32),
            "action_sequence": torch.as_tensor(action_sequence, dtype=torch.float32),
            "sT": torch.as_tensor(sT, dtype=torch.float32)
        }
    
def collate(batch):
    # Vertically stacks each of the components of the subtrajectories such that the first dimension is the batch
    return {
        "s0": torch.stack([b["s0"] for b in batch], 0),
        "state_sequence": torch.stack([b["state_sequence"] for b in batch], 0),
        "action_sequence": torch.stack([b["action_sequence"] for b in batch], 0),
        "sT": torch.stack([b["sT"] for b in batch], 0)
    }

# Create the dictionary of subtrajectories
dataset = TrainingDataset(ant_maze_dataset, T)

# Iterator w/ groups of subtrajectories from the dataset of size B 
loader = DataLoader(dataset, batch_size=B, shuffle=True, collate_fn=collate)

In [471]:
def compute_e_loss(batch):
    s0  = batch["s0"]               
    S   = batch["state_sequence"]   
    A   = batch["action_sequence"]  
    Bsz, T, _ = S.shape

    # Sampling z using the reparameterization trick where z = mu(tau) + std(tau) * epsilon
    mu_q, std_q = q_phi(S, A) # [B, z_dim]
    eps = torch.randn_like(mu_q)
    z = (mu_q + std_q * eps)      

    # Freeze the weights of the low-level skill-conditioned policy, pi_theta
    with torch.no_grad():
        z_bt = z.unsqueeze(1).expand(Bsz, T, -1) # [B, T, z_dim]
        # Flatten time dimension since pi_theta does not the expect extra dim ([B * T, s_dim])
        Sf = S.reshape(Bsz*T, -1)
        Zf = z_bt.reshape(Bsz*T, -1)
        mu_pi, std_pi = pi_theta(Sf, Zf)     
        # Restore time dimension ([B, T, a_dim])
        mu_pi  = mu_pi.view(Bsz, T, -1)
        std_pi = std_pi.view(Bsz, T, -1)
        # Build a MVN over independent actions at each timestep of each batch and sum log_probs across action_dim
        pi_dist = Independent(Normal(mu_pi, std_pi), 1)
        # Compute the log probability of observed actions
        log_pi = pi_dist.log_prob(A) # ([B, T])    

    # Freeze the weights of the skill prior, p_omega
    with torch.no_grad():
        # Find the distribution of the abstract skill given start states (mu_pr & std_pr: [B, z_dim])
        mu_pr, std_pr = p_omega(s0)  
        # Build a MVN over independent skills at each timestep of each batch and sum log_probs across z_dim
        prior_dist = Independent(Normal(mu_pr, std_pr), 1)
        # Compute the log-probability over the sampled skills using the prior
        log_p_omega = prior_dist.log_prob(z)   

    post_dist = Independent(Normal(mu_q, std_q), 1)
    # Compute the log-probability over the sampled skills using the inferred posterior
    log_q = post_dist.log_prob(z)              

    # Calculate the E-objective
    e_obj = (log_pi.sum(dim=1) + log_p_omega - log_q)
    e_loss = -e_obj.mean() # minimize the negative objective
    return e_loss



def compute_m_loss(batch):
    s0  = batch["s0"]               
    S   = batch["state_sequence"]   
    A   = batch["action_sequence"] 
    sT  = batch["sT"]               
    Bsz, T, Sdim = S.shape

    # Freeze the weights of the inferred posterior, only update omega, psi, and theta
    with torch.no_grad():
        mu_q, std_q = q_phi(S, A)        # [B, z_dim]
        eps = torch.randn_like(mu_q)
        # same reparameterization trick as E-loss
        z = (mu_q + std_q * eps)  

    z_bt = z.unsqueeze(1).expand(Bsz, T, -1)
    # Flatten time dimension since pi_theta does not the expect extra dim ([B * T, s_dim])
    Sf = S.reshape(Bsz*T, -1)
    Zf = z_bt.reshape(Bsz*T, -1)
    mu_pi, std_pi = pi_theta(Sf, Zf)    
    # Restore time dimension ([B, T, a_dim])     
    mu_pi = mu_pi.reshape(Bsz, T, -1)
    std_pi = std_pi.reshape(Bsz, T, -1)
    # Build a MVN over independent actions at each timestep of each batch and sum log_probs across action_dim
    pi_dist = Independent(Normal(mu_pi, std_pi), 1)
    log_pi = pi_dist.log_prob(A)               

    # Zs: [B, T, z_dim]
    mu_T, std_T = p_psi(s0, z)
    # Build a MVN over independent terminal states at each timestep of each batch and sum log_probs across state_dim     
    ppsi_dist = Independent(Normal(mu_T, std_T), 1)
    # Compute the log-probability over the observed terminal states using the TAWM
    log_p_psi = ppsi_dist.log_prob(sT)          

    # Find the distribution of the abstract skill given start states (mu_pr & std_pr: [B, z_dim])
    mu_pr, std_pr = p_omega(s0)                
    # Build a MVN over independent skills at each timestep of each batch and sum log_probs across z_dim
    prior_dist = Independent(Normal(mu_pr, std_pr), 1)
    # Compute the log-probability over the sampled skills using the prior
    log_p_omega = prior_dist.log_prob(z)   

    # sum to find join log-likelihood of the whole sequence
    obj = log_pi.sum(dim=1) + log_p_psi + log_p_omega   
    m_loss = -obj.mean()
    return m_loss


In [None]:

def skill_model_training(
    loader,
    q_phi, pi_theta, p_psi, p_omega,
    e_lr=5e-5, m_lr=5e-5,
    epochs=50,
    e_steps=1, m_steps=1,
    grad_clip=1.0 # prevents runaway gradients
):
    # Skill model training setup
    q_phi.to(device)
    pi_theta.to(device)
    p_psi.to(device)
    p_omega.to(device)

    e_optimizer = torch.optim.Adam(q_phi.parameters(), lr=e_lr)
    m_optimizer = torch.optim.Adam(list(pi_theta.parameters()) + list(p_psi.parameters()) + list(p_omega.parameters()), lr=m_lr)

    e_curve = []   
    m_curve = []   

    for epoch in range(1, epochs+1):
        # Running e_loss, m_loss, and batches in current epoch
        e_running, m_running, nb = 0.0, 0.0, 0

        for batch in loader:
            # Rebuilds dictionary but moves tensors to the device
            batch = {k: v.to(device) for k, v in batch.items()}
            nb += 1

            # In E-step, train the posterior while freezing other parameters
            q_phi.train()
            pi_theta.eval()
            p_psi.eval()
            p_omega.eval()

            for p in q_phi.parameters(): 
                p.requires_grad_(True)
            for m in (pi_theta, p_psi, p_omega):
                for p in m.parameters(): 
                    p.requires_grad_(False)

            # For the e-step, 
            for _ in range(e_steps):
                # Resent gradients
                e_optimizer.zero_grad(set_to_none=True)
                e_loss = compute_e_loss(batch)
                e_loss.backward() 
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(q_phi.parameters(), grad_clip)
                e_optimizer.step() # Update the parameters of the posterior

            e_running += e_loss.item()

            # Freeze posterior weights, update all other weights
            q_phi.eval()
            pi_theta.train()
            p_psi.train()
            p_omega.train()

            for p in q_phi.parameters(): 
                p.requires_grad_(False)
            for m in (pi_theta, p_psi, p_omega):
                for p in m.parameters(): 
                    p.requires_grad_(True)

            for _ in range(m_steps):
                # Reset gradients
                m_optimizer.zero_grad(set_to_none=True)
                m_loss = compute_m_loss(batch)
                m_loss.backward()
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(list(pi_theta.parameters()) + list(p_psi.parameters()) + list(p_omega.parameters()), grad_clip)
                m_optimizer.step() # Update theta, psi, and omega

            m_running += m_loss.item()

        # Calculate the average losses over all the batches in the epoch
        e_epoch = e_running / max(1, nb)
        m_epoch = m_running / max(1, nb)
        e_curve.append(e_epoch)
        m_curve.append(m_epoch)
        print(f"[Epoch {epoch:03d}/{epochs}]  E: {e_epoch:.4f}   M: {m_epoch:.4f}")



    plt.figure(figsize=(6,4))
    plt.plot(e_curve, label="E-loss")
    plt.plot(m_curve, label="M-loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("EM training losses")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return e_curve, m_curve


In [473]:
skill_model_training(loader, q_phi, pi_theta, p_psi, p_omega, epochs=25)

[Epoch 001/25]  E: 47.1526   M: 612.1870
[Epoch 002/25]  E: 31.5192   M: 588.5160
[Epoch 003/25]  E: 24.8898   M: 576.8893
[Epoch 004/25]  E: 21.0094   M: 569.8205
[Epoch 005/25]  E: 18.4950   M: 564.6932
[Epoch 006/25]  E: 16.6712   M: 560.8886
[Epoch 007/25]  E: 15.2432   M: 558.1020
[Epoch 008/25]  E: 14.0861   M: 555.8779
[Epoch 009/25]  E: 13.1440   M: 554.0309
[Epoch 010/25]  E: 12.3362   M: 552.3174
[Epoch 011/25]  E: 11.6588   M: 551.0306
[Epoch 012/25]  E: 11.0686   M: 549.7999
[Epoch 013/25]  E: 10.5685   M: 548.7119
[Epoch 014/25]  E: 10.1259   M: 547.9745
[Epoch 015/25]  E: 9.7445   M: 546.9560
[Epoch 016/25]  E: 9.4035   M: 546.2930
[Epoch 017/25]  E: 9.0964   M: 545.6731
[Epoch 018/25]  E: 8.8150   M: 545.1026
[Epoch 019/25]  E: 8.5697   M: 544.4341
[Epoch 020/25]  E: 8.3305   M: 543.9675
[Epoch 021/25]  E: 8.0972   M: 543.5393
[Epoch 022/25]  E: 7.9058   M: 543.1313
[Epoch 023/25]  E: 7.7034   M: 542.7275
[Epoch 024/25]  E: 7.5208   M: 542.4650
[Epoch 025/25]  E: 7.3560 

  plt.tight_layout(); plt.show()


([47.15259306121089,
  31.51918206171687,
  24.889834441447185,
  21.00940923575548,
  18.495002500240176,
  16.67117750522233,
  15.243227473486586,
  14.086138733924336,
  13.14402588674309,
  12.336182354584201,
  11.658840118001955,
  11.068605773110404,
  10.568502830522657,
  10.125925442225263,
  9.74453379175094,
  9.403523720337185,
  9.096398457538145,
  8.814963859769874,
  8.569735810826913,
  8.3305361769632,
  8.097151646864377,
  7.905775060881143,
  7.703362999401784,
  7.520768780740727,
  7.356005560852069],
 [612.1869514580579,
  588.5159956917662,
  576.8893495127513,
  569.8204593531675,
  564.6932217416446,
  560.8886068753245,
  558.1020340933901,
  555.8779325087625,
  554.0308801011377,
  552.3174048847302,
  551.030571901258,
  549.7999219819499,
  548.7119467375142,
  547.9745004152603,
  546.9560301443602,
  546.2930223908669,
  545.6730867414676,
  545.1025967600842,
  544.4340908269508,
  543.9674537128552,
  543.5393139058369,
  543.1312514198511,
  542.7

In [489]:
@torch.no_grad()
def visualize_skills(p_omega, p_psi, s0_full, K=100):
    """
    Takes as input a start state, s0.
    Then, samples K skills randomly from the skill prior.
    Finally, uses the TAWM to predict the terminal state of each skill.

    Plots s0->sT_pred as vectors using the x-y coordinates.
    """
    s0_full = torch.as_tensor(s0_full, dtype=torch.float32)
    s0_b = s0_full.unsqueeze(0) # ensure [1, state_dim]

    mu_z, std_z = p_omega(s0_b) # compute the metrics for a sampled skill                            
    Z = mu_z.shape[-1]

    # Sample K epsilons and map to skills using reparameterization trick
    eps = torch.randn(K, Z)                  
    z = mu_z + std_z * eps                                  

    s0_expanded = s0_b.expand(K, -1) 
    # Preduct the terminal state using sampled skills and start state                      
    sT_mu, _ = p_psi(s0_expanded, z)                        

    xy0 = s0_b[..., -2:]                                      
    xyT = sT_mu[..., -2:]                                     

    x0, y0 = xy0[0].tolist()

    plt.figure(figsize=(6,6))
    for i in range(K):
        plt.plot([x0, xyT[i,0].item()], [y0, xyT[i,1].item()], alpha=0.25)

    plt.scatter([x0], [y0], c="k", s=60, label="start")
    plt.axis("equal")
    plt.legend()
    plt.title("Sampled skills from prior at s0 (TAWM endpoints)")
    plt.tight_layout()
    plt.savefig("skill_samples.png", dpi=150)
    print("Saved skill_samples.png")


s0_env, _ = env.reset(seed=0)                              
obs = torch.tensor(s0_env["observation"], dtype=torch.float32)
ach = torch.tensor(s0_env["achieved_goal"], dtype=torch.float32)

s0_full = torch.cat([obs, ach], dim=-1)                   
visualize_skills(p_omega, p_psi, s0_full, K=100)


Saved skill_samples.png


In [None]:
# CEM Planner Parameters
H = 40
K = 1000 # batch size at each iteration of planning
L = 3 # length of sequence of epsilons in the batch
N_keep = 200 # number of sequences of skills to keep
N_iters = 10 # number of iterations to update the diagonal gaussian
tau = 30 # number of timesteps before replanning

In [507]:
LOG2PI = math.log(2 * math.pi)

@torch.no_grad()
def diag_gaussian_log_p(x, mu, std):
    # Computes the log p(x) of a diagonal gaussian
    var = std * std
    return (-0.5 * ((x - mu)**2 / var + 2.0 * std.log() + LOG2PI)).sum(dim=-1)

class CEMPlanner:
    def __init__(self, pi_theta, p_psi, p_omega, z_dim=Z_DIM, plan_len=L,
                 batch_size=K, n_keep=N_keep, iters=N_iters):
        self.pi_theta = pi_theta.eval()
        self.p_psi = p_psi.eval()
        self.p_omega = p_omega.eval()
        self.z_dim = z_dim
        self.L = plan_len
        self.batch_size = batch_size
        self.k = n_keep
        self.iters = iters
        # Keep track of the current mean and std of the diagonal Gaussian
        self.eps_mean = torch.zeros(self.L, z_dim)
        self.eps_std  = torch.ones (self.L, z_dim)

    @torch.no_grad()
    def _eps_to_z_seq(self, s0, eps_seq):
        """
        Given a eps_seq, convert each epsilon into a skill using the skill prior and use the TAWM to predict the state after the skill
        Output: returns the sequence of skills and the final predicted state
        """
        s = s0 # [1, state_dim]
        z_seq = []
        for t in range(self.L):
            mu_w, std_w = self.p_omega(s)               
            z_t = mu_w + std_w * eps_seq[t:t+1, :] # convert the et to zt 
            z_seq.append(z_t)
            # roll abstract world model to get next state (use mean for planning)
            mu_T, _ = self.p_psi(s, z_t)                
            s = mu_T
        z_seq = torch.cat(z_seq, dim=0)                 
        return z_seq, s # final predicted state after L skills

    @torch.no_grad()
    def _cost_fn(self, s0, goal_xy, eps_batch):
        """
        Computes the cost of each sequence of skills in the batch
        """
        # eps_batch: [K, L]
        N = eps_batch.shape[0]
        costs = torch.empty(N)
        # For each batch, find the final predicted state and extract the xy coordinates from the state. Then, compute the l2 distance
        for i in range(N):
            _, sT = self._eps_to_z_seq(s0, eps_batch[i])    
            sT_xy = sT[..., -2:]                 
            costs[i] = torch.linalg.norm(sT_xy[0] - goal_xy[0], ord=2)
        return costs

    @torch.no_grad()
    def plan(self, s0, goal):
        """
        Given a start state, fit a diagonal Gaussian using the N_keep best costs for N_iters
        """
        mean = self.eps_mean.clone()
        std  = self.eps_std.clone()

        for _ in range(self.iters):
            eps = mean + std * torch.randn(self.batch_size, self.L, self.z_dim) # sample K L-length sequences of epsilons from a unit Gaussian
            costs = self._cost_fn(s0, goal, eps) # compute the costs of each sequence
            top_k = torch.topk(-costs, self.k).indices # extract the indices of the N_keep best
            top_k_eps = eps[top_k]                           
            mean = top_k_eps.mean(dim=0) # take the mean and std of the epsilon sequences across the N_keep best sequences
            std  = top_k_eps.std(dim=0) + 1e-6

        final_eps = mean + std * torch.randn(self.batch_size, self.L, self.z_dim)
        final_costs = self._cost_fn(s0, goal, final_eps)
        best_idx = torch.argmin(final_costs)
        best_eps = final_eps[best_idx]

        self.eps_mean = torch.zeros_like(self.eps_mean)
        self.eps_std  = torch.ones_like(self.eps_std)

        # find the sequence of skills in the best sequence
        z_seq, _ = self._eps_to_z_seq(s0, best_eps)
        return z_seq[0:1, :], mean, std # z_seq[0:1, :]: [1, z_dim] returns the first skill in the best sequence



In [None]:
def combine_state(obs, device):
    base = np.asarray(obs["observation"], dtype=np.float32)        # [27,]
    ach = np.asarray(obs["achieved_goal"], dtype=np.float32)      # [2,]
    s = np.concatenate([base, ach], axis=-1)                    # [29,]
    return torch.as_tensor(s, dtype=torch.float32, device=device).unsqueeze(0)  # [1,27]

def get_goal_xy(obs, device):
    assert isinstance(obs, dict)
    g = np.asarray(obs["desired_goal"], dtype=np.float32)          # (2,)
    return torch.as_tensor(g, dtype=torch.float32, device=device).unsqueeze(0)  # [1,2]




planner = CEMPlanner(
    pi_theta, p_psi, p_omega,
    z_dim=Z_DIM, plan_len=L,
    batch_size=K, n_keep=N_keep, iters=N_iters 
)

max_steps = 2000
obs = env.reset(seed=0)[0]

state   = combine_state(obs, device)      
goal_xy = get_goal_xy(obs, device) 

steps = 0
while steps < max_steps:
    z_1, eps_mean, eps_std = planner.plan(state, goal_xy) # z_1: [1, Z_DIM]

    # Execute this skill for τ low-level steps, then replan
    for _ in range(tau):
        with torch.no_grad():
            mu_a, _ = pi_theta(state, z_1) # [1, a_dim]
            a = mu_a.squeeze(0).cpu().numpy().astype(np.float32)

        next_obs, reward, terminated, truncated, info = env.step(a)
        done  = terminated or truncated
        state = combine_state(next_obs, device) # rebuild [1,29] by adding in the xy location
        steps += 1

        # Distance check using last two dims of the concatenated state
        curr_xy = state[..., -2:]      
        if torch.linalg.vector_norm(curr_xy - goal_xy[0]) < 1.0 or done:
            break

    curr_xy = state[..., -2:] 
    if torch.linalg.vector_norm(curr_xy - goal_xy[0]) < 1.0 or done:
        break
