# Overview

Model-Based Temporal Abstraction involves simultaneously learning
1) skill-conditioned low-level policy
2) skill-conditioned temporally abstract world model

Notation
- skill-conditioned low-level policy: $\pi_{\theta}(a_t|s_t, z)$
    - $\theta$ are parameters
    - $a_t \in A$ is current action selected by agent
    - $s_t \in S$ is current state
    - $z \in Z$ is abstract skill variable that encodes a skill

- skill-conditioned temporally abstract world model (TAWM): $p_{\psi}(s'|s,z)$ (models the distribution over the state $s'$ the agent ends up in after executing skill $z$ starting from state $s$)
    - $\psi$ parameters
    - $z$ is current skill

Note: the low-level policy and the TAWM are not trained on rewards; a reward function is provided only later, when planning with the learned skills.





In [183]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchrl.modules import CEMPlanner
from torch.distributions import Normal, Independent, kl_divergence
import numpy as np
import minari
import math
from torch.utils.data import Dataset, DataLoader


# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

### Learning $\pi_{\theta}$ and $p_{\psi}$

Learning $\pi_{\theta}$ and $p_{\psi}$ requires treating skills as latent variables and optimizing the ELBO

$$
\mathcal{L}(\theta,\psi,\phi,\omega)
= \mathbb{E}_{\tau_T \sim \mathcal{D}}\!\left[
  \mathbb{E}_{q_\phi(z\,|\,\tau_T)}\!\left[
    \log \pi_\theta(\bar{a}\,|\,\bar{s}, z)
    + \log p_\psi(s_T \,|\, s_0, z)
  \right]
  - D_{\mathrm{KL}}\!\left(q_\phi(z\,|\,\tau_T)\,\|\,p_\omega(z\,|\,s_0)\right)
\right].
$$

where $\tau_T$ is a $T$-length subtrajectory sampled from the offline dataset $\mathcal{D}$, $\bar{s}$ and $\bar{a}$ are the state and action sequences of $\tau_T$, $q_{\phi}$ is a posterior over $z$ given $\tau_T$, and $p_{\omega}$ is a prior over $z$ given $s_0$.

The first term is the log-likelihood of demonstrator actions. This ensures that the low-level policy can reproduce a demonstrator's action sequence given a skill. This forces $z$ to encode control-relevant information.

The second term is the log-likelihood of long-term state transitions. This term ensures that the model learns which end states $s_T$ can result from executing skill $z$ starting at $s_0$.

Finally, the last term is the KL divergence between the skill posterior and the skill prior, which encourages compression of skills. Maximizing this ELBO therefore makes the skills $z$ explain the data while keeping the KL divergence small, which ensures that each skill is predictable from the start state $s_0$.

In [184]:
NUM_NEURONS = 256
Z_DIM = 256

# Skill Posterior, q_phi
class SkillPosterior(nn.Module):
    """Skill posterior q_phi(z | s_bar, a_bar) over a state/action subtrajectory.

    States are embedded with a linear layer + ReLU, concatenated with the raw
    actions, and run through a bidirectional GRU. The mean and standard
    deviation heads are applied to the GRU output at every timestep, so the
    returned tensors describe one Z_DIM-dimensional Gaussian per timestep.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(in_features=self.state_dim, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        # Two directions of NUM_NEURONS // 2 units are concatenated by the GRU,
        # producing the NUM_NEURONS-wide features the mean/std heads expect.
        self.bi_gru = nn.GRU(
            input_size=NUM_NEURONS + self.action_dim,
            hidden_size=NUM_NEURONS // 2,
            bidirectional=True,
            batch_first=True,
        )
        self.mean = MeanNetwork(Z_DIM)
        self.std = StandardDeviationNetwork(Z_DIM)

    def forward(self, state_sequence, action_sequence):
        """Return ``(mean, std)`` of the per-timestep skill posterior.

        Args:
            state_sequence: states, either batched ``(batch, T, state_dim)``
                (the GRU is ``batch_first``) or unbatched ``(T, state_dim)``.
            action_sequence: actions with the same leading dimensions and a
                trailing ``action_dim``.

        Returns:
            Tuple of tensors with the leading dimensions of the input and a
            trailing ``Z_DIM``: the Gaussian mean and (positive) std.
        """
        embedded_states = self.relu(self.fc1(state_sequence))
        concatenated = torch.cat([embedded_states, action_sequence], dim=-1)
        x, _ = self.bi_gru(concatenated)
        # Invoke the heads as modules (not via .forward) so registered hooks run.
        return self.mean(x), self.std(x)


# Low-Level Skill-Conditioned Policy, pi_theta
class SkillPolicy(nn.Module):
    """Skill-conditioned low-level policy pi_theta(a_t | s_t, z).

    The state and skill vector are concatenated and passed through two ReLU
    layers; separate heads then produce the mean and the (softplus-positive)
    standard deviation of a Gaussian over actions.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(state_dim + Z_DIM, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(action_dim)
        self.std = StandardDeviationNetwork(action_dim)

    def forward(self, state, z):
        """Return ``(mean, std)`` of the action distribution for ``state`` under skill ``z``."""
        hidden = torch.cat((state, z), dim=-1)
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.mean(hidden), self.std(hidden)
        

# Temporally-Abstract World Model, p_psi
class TAWM(nn.Module):
    """Temporally abstract world model p_psi(s' | s, z).

    A start state and a skill vector are concatenated, passed through two ReLU
    layers, and mapped to the mean and std of a Gaussian over the state reached
    after executing the skill.
    """

    def __init__(self, state_dim):
        super().__init__()
        self.state_dim = state_dim
        joint_width = state_dim + Z_DIM
        self.fc1 = nn.Linear(joint_width, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(state_dim)
        self.std = StandardDeviationNetwork(state_dim)

    def forward(self, input_state, z):
        """Return ``(mean, std)`` of the predicted post-skill state distribution."""
        features = torch.cat([input_state, z], dim=-1)
        features = self.relu(self.fc1(features))
        features = self.relu(self.fc2(features))
        mu = self.mean(features)
        sigma = self.std(features)
        return mu, sigma

# Skill Prior, p_omega
class SkillPrior(nn.Module):
    """Skill prior p_omega(z | s_0): a Gaussian over skills given only the start state."""

    def __init__(self, state_dim):
        super().__init__()
        self.state_dim = state_dim
        self.fc1 = nn.Linear(state_dim, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(Z_DIM)
        self.std = StandardDeviationNetwork(Z_DIM)

    def forward(self, input_state):
        """Return ``(mean, std)`` of the prior skill distribution for ``input_state``."""
        h = input_state
        for layer in (self.fc1, self.fc2):
            h = self.relu(layer(h))
        return self.mean(h), self.std(h)

class MeanNetwork(nn.Module):
    """Two-layer MLP head: NUM_NEURONS -> NUM_NEURONS (ReLU) -> out_dim, with an unconstrained output."""

    def __init__(self, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(NUM_NEURONS, NUM_NEURONS)
        self.fc2 = nn.Linear(NUM_NEURONS, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map NUM_NEURONS-wide features to an ``out_dim`` mean vector."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
        
class StandardDeviationNetwork(nn.Module):
    """Two-layer MLP head mapping NUM_NEURONS features to a strictly positive std vector.

    The output is ``softplus(fc2(relu(fc1(x)))) + min_std``. Softplus alone can
    underflow to exactly 0 in float32 for very negative pre-activations. A zero
    scale can make ``torch.distributions.Normal`` reject the parameter or yield
    infinite log-probabilities, so a small floor keeps every std above zero.

    Args:
        out_dim: size of the std vector.
        min_std: additive lower bound on the std (default 1e-4).
    """

    def __init__(self, out_dim, min_std=1e-4):
        super().__init__()
        self.fc1 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=out_dim)
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
        self.min_std = min_std

    def forward(self, x):
        """Return a std vector of size ``out_dim``, every entry >= ``min_std``."""
        x = self.relu(self.fc1(x))
        return self.softplus(self.fc2(x)) + self.min_std

#### The Expectation-Maximization (EM) Algorithm

Since calculating the true posterior of $z$ given $\tau_T$ is intractable, we approximate it with the variational posterior $q_{\phi}(z|\tau_T)$.

1. E-Step:
- Update $\phi$ with gradient descent so that the KL divergence between $q_\phi$ and the true posterior is minimized

2. M-Step:
- Fixing $q_{\phi}$, update $(\theta, \psi, \omega)$ with gradient ascent so that the ELBO is maximized.


##### E-Step (Update $\phi$)

In this step, we want to minimize 

$$
\mathbb{E}_{\tau_T\sim\mathcal{D}}
\Bigg[
\mathbb{E}_{z\sim q_\phi}
\bigg[
\log \frac{q_\phi\!\left(z\mid \bar{s},\bar{a}\right)}
{\pi_\theta\!\left(\bar{a}\mid \bar{s}, z\right)\,p_\omega\!\left(z\mid s_0\right)}
\bigg]
\Bigg]
$$

Equivalently, we want to minimize $D_{\mathrm{KL}}(q_{\phi}\,\|\,p)$, where $p$, the true posterior, is

$$
p(z \mid \bar{s}, \bar{a}) = \frac{1}{\eta}\,\pi_\theta(\bar{a}\mid \bar{s}, z)\,p_\omega(z\mid s_0).
$$


##### M-Step (Update $\theta$, $\psi$, $\omega$)

In this step, we want to update $\theta$, $\psi$, and $\omega$ using gradient ascent to maximize the ELBO from above.

Both steps use the Adam optimizer.


##### Dataset 1: AntMaze

In [185]:
# Load the offline D4RL AntMaze (medium-play) dataset through Minari.
ant_maze_dataset = minari.load_dataset('D4RL/antmaze/medium-play-v1')

In [186]:
# Inspect the first episode: action array shape, observation dict keys, and the
# shape of the raw "observation" array. The state/action sizes used below are
# taken from this output.
print(ant_maze_dataset[0].actions.shape)
print(ant_maze_dataset[0].observations.keys())
print(ant_maze_dataset[0].observations["observation"].shape)

(1000, 8)
dict_keys(['achieved_goal', 'desired_goal', 'observation'])
(1001, 27)


In [187]:
# Minibatch size: number of subtrajectories per batch.
B = 100

# Length of each subtrajectory.
T = 10

# AntMaze observation and action sizes, matching the shapes printed above:
# (1001, 27) observations and (1000, 8) actions in the first episode.
state_dim, action_dim = 27, 8

# One instance of each ELBO component, placed on the training device.
q_phi = SkillPosterior(state_dim, action_dim).to(device)
pi_theta = SkillPolicy(state_dim, action_dim).to(device)
p_psi = TAWM(state_dim).to(device)
p_omega = SkillPrior(state_dim).to(device)

In [188]:
# E loss
def compute_e_loss(subtrajectory_batch):
    batch_s0, batch_states, batch_actions, batch_sT = subtrajectory_batch["s0"], subtrajectory_batch["state_sequence"], subtrajectory_batch["action_sequence"], subtrajectory_batch["sT"]
    
    zs = []
    for i in range(len(batch_s0)):
        mu, std = q_phi(batch_states[i], batch_actions[i])
        epsilon = torch.randn_like(mu) # epsilon ~ N(0, 1)
        z = mu + std * epsilon
        zs.append(z)

    e_loss = 0
    for i in range(len(batch_s0)):
        s0, states, actions, sT = batch_s0[i], batch_states[i], batch_actions[i], batch_sT[i]
        mu_pi, std_pi = pi_theta(states, zs[i])
        pi_dist = Independent(Normal(mu_pi, std_pi), 1)
        log_pi_t = pi_dist.log_prob(actions)
        log_pi_seq = log_pi_t.sum(dim=0)

        mu_p_omega, std_p_omega = p_omega(s0)
        p_omega_dist = Independent(Normal(mu_p_omega, std_p_omega), 1)
        log_p_omega_t = p_omega_dist.log_prob(zs[i])
        log_p_omega_seq = log_p_omega_t.sum(dim=0)

        mu_q, std_q = q_phi(states, actions)
        q_dist = Independent(Normal(mu_q, std_q), 1)
        log_q_t = q_dist.log_prob(zs[i])
        log_q_seq = log_q_t.sum(dim=0)  

        e_loss += log_pi_seq + log_p_omega_seq - log_q_seq

    return e_loss * -(1/B)

    
# M loss
def compute_m_loss(subtrajectory_batch):
    
    zs = []
    batch_s0, batch_states, batch_actions, batch_sT = (subtrajectory_batch["s0"], subtrajectory_batch["state_sequence"], subtrajectory_batch["action_sequence"], subtrajectory_batch["sT"]) 
    
    for i in range(len(batch_s0)):
        mu, std = q_phi(batch_states[i], batch_actions[i])
        epsilon = torch.randn_like(mu) # epsilon ~ N(0, 1)
        z = mu + std * epsilon
        zs.append(z)

    m_loss = 0
    for i in range(len(batch_s0)):
        s0, states, actions, sT = batch_s0[i], batch_states[i], batch_actions[i], batch_sT[i]
        mu_pi, std_pi = pi_theta(states, zs[i])
        pi_dist = Independent(Normal(mu_pi, std_pi), 1)
        log_pi_t = pi_dist.log_prob(actions)
        log_pi_seq = log_pi_t.sum(dim=0)

        z_seq = zs[i].mean(dim=0)
        mu_p_psi, std_p_psi = p_psi(s0, z_seq)
        p_psi_dist = Independent(Normal(mu_p_psi, std_p_psi), 1)
        log_p_psi_t = p_psi_dist.log_prob(sT)
        log_p_psi_seq = log_p_psi_t

        mu_p_omega, std_p_omega = p_omega(s0)
        p_omega_dist = Independent(Normal(mu_p_omega, std_p_omega), 1)
        log_p_omega_seq = p_omega_dist.log_prob(z_seq)

        m_loss += log_pi_seq + log_p_psi_seq + log_p_omega_seq

    return m_loss * -(1/B)


In [151]:
class TrainingDataset(Dataset):
    """All length-T windows from every episode of an offline Minari dataset.

    For each episode with observations ``s`` and actions ``a`` and each start
    index ``t``, one item is stored as a dict of float32 tensors:

        "s0":              s[t]
        "state_sequence":  s[t : t + T]
        "action_sequence": a[t : t + T]
        "sT":              s[t + T]

    Episodes with fewer than ``T + 1`` observations are skipped. An empty
    source dataset yields an empty ``TrainingDataset``.

    Args:
        ant_maze_dataset: object exposing ``iterate_episodes()``, whose episodes
            have ``observations["observation"]`` and ``actions`` arrays.
        T: subtrajectory length.
    """

    def __init__(self, ant_maze_dataset, T):
        self.T = T
        self.subtrajectories = []

        for ep in ant_maze_dataset.iterate_episodes():
            s = ep.observations["observation"]
            a = ep.actions
            num_obs = len(s)
            if num_obs < T + 1:
                continue

            # Windowing is done per episode, inside the episode loop, so every
            # episode contributes and no window straddles two episodes.
            for t in range(num_obs - T):
                s0 = s[t]
                state_sequence = s[t: t + T]
                action_sequence = a[t: t + T]
                sT = s[t + T]
                self.subtrajectories.append((s0, state_sequence, action_sequence, sT))

    def __len__(self):
        return len(self.subtrajectories)

    def __getitem__(self, idx):
        s0, state_sequence, action_sequence, sT = self.subtrajectories[idx]
        return {
            "s0": torch.as_tensor(s0, dtype=torch.float32),
            "state_sequence": torch.as_tensor(state_sequence, dtype=torch.float32),
            "action_sequence": torch.as_tensor(action_sequence, dtype=torch.float32),
            "sT": torch.as_tensor(sT, dtype=torch.float32)
        }
    
def collate(batch):
    """Merge a list of per-subtrajectory dicts into one dict of batched tensors.

    Each of the four fields is stacked along a new leading dimension 0; the
    output keys are, in order, "s0", "state_sequence", "action_sequence", "sT".
    """
    fields = ("s0", "state_sequence", "action_sequence", "sT")
    return {field: torch.stack([item[field] for item in batch], 0) for field in fields}

# Build the subtrajectory dataset and a shuffled loader that yields batches of
# B subtrajectories via `collate`. drop_last is left at its default (False),
# so the final batch may hold fewer than B items.
dataset = TrainingDataset(ant_maze_dataset, T=T)
loader = DataLoader(dataset=dataset, batch_size=B, shuffle=True, collate_fn=collate)

In [189]:
def _set_requires_grad(modules, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of every module in ``modules``."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad_(flag)


def skill_model_training(num_epochs=1, lr=5e-5):
    """Train the skill model with alternating E- and M-steps over ``loader``.

    For every minibatch:
      * E-step: only q_phi is trainable; take one Adam step on ``compute_e_loss``.
      * M-step: q_phi is frozen and pi_theta, p_psi, p_omega are trainable;
        take one Adam step on ``compute_m_loss``.

    All four models are left trainable (``requires_grad=True``) when training
    finishes, even if an exception interrupts it.

    Args:
        num_epochs: number of passes over ``loader`` (default 1).
        lr: Adam learning rate for both optimizers (default 5e-5).

    Returns:
        List of ``(e_loss, m_loss)`` Python floats, one pair per minibatch.
    """
    all_models = (q_phi, pi_theta, p_psi, p_omega)
    m_models = (pi_theta, p_psi, p_omega)
    for model in all_models:
        model.to(device).train()

    e_optimizer = optim.Adam(q_phi.parameters(), lr=lr)
    m_optimizer = optim.Adam(
        [param for model in m_models for param in model.parameters()], lr=lr
    )

    history = []
    try:
        for _ in range(num_epochs):
            for batch in loader:
                # The DataLoader yields CPU tensors; the models live on `device`.
                batch = {key: value.to(device) for key, value in batch.items()}

                # E-step: update q_phi only.
                _set_requires_grad((q_phi,), True)
                _set_requires_grad(m_models, False)
                e_optimizer.zero_grad(set_to_none=True)
                e_loss = compute_e_loss(batch)
                e_loss.backward()
                e_optimizer.step()

                # M-step: freeze q_phi, update pi_theta, p_psi and p_omega.
                _set_requires_grad((q_phi,), False)
                _set_requires_grad(m_models, True)
                m_optimizer.zero_grad(set_to_none=True)
                m_loss = compute_m_loss(batch)
                m_loss.backward()
                m_optimizer.step()

                history.append((e_loss.item(), m_loss.item()))
    finally:
        # Do not leave q_phi frozen for whatever runs after training.
        _set_requires_grad(all_models, True)

    return history
    
    

In [190]:
# Train q_phi, pi_theta, p_psi and p_omega with the alternating E/M loop.
skill_model_training()

In [154]:
# Planning hyperparameters. They are not referenced by the code visible here.
# NOTE(review): the meanings below are inferred from the names and the
# CEMPlanner import only; confirm them once the planning cell is written.
H = 40          # presumably the planning horizon
K = 1_000       # presumably candidate samples per CEM iteration
L = 3           # NOTE(review): meaning unclear from here
N_keep = 200    # presumably elite candidates kept per CEM iteration
N_iters = 10    # presumably number of CEM iterations
tau = 30        # NOTE(review): meaning unclear from here