# Overview

Model-Based Temporal Abstraction involves simultaneously learning two components:
1) a skill-conditioned low-level policy
2) a skill-conditioned temporally abstract world model

Notation
- skill-conditioned low-level policy: $\pi_{\theta}(a_t|s_t, z)$
    - $\theta$ are its parameters
    - $a_t \in A$ is the current action selected by the agent
    - $s_t \in S$ is the current state
    - $z \in Z$ is an abstract skill variable that encodes a skill

- skill-conditioned temporally abstract world model (TAWM): $p_{\psi}(s'|s,z)$, which models the distribution over the state $s'$ the agent reaches after executing skill $z$ from state $s$
    - $\psi$ are its parameters
    - $z$ is the skill being executed

Note: the low-level policy and the TAWM are not trained on rewards. A reward function is provided only later, when planning with the learned skills.
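As a rough illustration of why the reward can be deferred until planning, the sketch below scores candidate skills by the reward of the end states a TAWM predicts and keeps the best one (random shooting). The untrained linear stand-ins for a skill prior and the TAWM, the `reward` function, and `select_skill` are all hypothetical; a full planner may instead search over sequences of skills.

```python
import torch
from torch.distributions import Normal, Independent

torch.manual_seed(0)
STATE_DIM, Z_DIM = 4, 8

# Hypothetical, untrained stand-ins for the skill-prior mean and the TAWM mean.
prior_mean = torch.nn.Linear(STATE_DIM, Z_DIM)
tawm_mean = torch.nn.Linear(STATE_DIM + Z_DIM, STATE_DIM)


def reward(s):
    # Hypothetical task reward, supplied only at planning time: stay near the origin.
    return -s.pow(2).sum(dim=-1)


@torch.no_grad()
def select_skill(s0, num_candidates=64, num_outcomes=16):
    # Sample candidate skills from a state-conditioned prior over z.
    prior = Independent(Normal(prior_mean(s0), torch.ones(Z_DIM)), 1)
    z = prior.sample((num_candidates,))                          # (C, Z_DIM)
    # Predict a Gaussian over end states s_T for every candidate skill.
    inputs = torch.cat([s0.expand(num_candidates, -1), z], dim=-1)
    tawm = Independent(Normal(tawm_mean(inputs), 0.1 * torch.ones(STATE_DIM)), 1)
    s_T = tawm.sample((num_outcomes,))                           # (K, C, STATE_DIM)
    # Estimate each skill's expected reward and keep the best one.
    scores = reward(s_T).mean(dim=0)                             # (C,)
    best = scores.argmax()
    return z[best], scores[best]


z_star, score = select_skill(torch.zeros(STATE_DIM))
print(z_star.shape)  # torch.Size([8])
```

Neither stand-in network ever sees the reward; it enters only when candidate skills are ranked.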





### Learning $\pi_{\theta}$ and $p_{\psi}$

Learning $\pi_{\theta}$ and $p_{\psi}$ requires treating skills as latent variables and maximizing the evidence lower bound (ELBO):

$$
\mathcal{L}(\theta,\psi,\phi,\omega)
= \mathbb{E}_{\tau_T \sim \mathcal{D}}\!\left[
  \mathbb{E}_{q_\phi(z\,|\,\tau_T)}\!\left[
    \log \pi_\theta(\bar{a}\,|\,\bar{s}, z)
    + \log p_\psi(s_T \,|\, s_0, z)
  \right]
  - D_{\mathrm{KL}}\!\left(q_\phi(z\,|\,\tau_T)\,\|\,p_\omega(z\,|\,s_0)\right)
\right].
$$

where $\tau_T$ is a length-$T$ subtrajectory sampled from the offline dataset $\mathcal{D}$, $\bar{s}$ and $\bar{a}$ are the state and action sequences of $\tau_T$, $q_{\phi}$ is a posterior over $z$ given $\tau_T$, and $p_{\omega}$ is a prior over $z$ given the start state $s_0$.
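In practice the inner expectation is usually estimated with a single reparameterized sample $z \sim q_\phi$. The sketch below assumes all four distributions are diagonal Gaussians and uses hypothetical linear layers (`policy_net`, `tawm_net`), a standard-normal prior, and random tensors in place of real networks and data; it only shows how the three ELBO terms combine for a batch.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent, kl_divergence

torch.manual_seed(0)
B, T, S_DIM, A_DIM, Z_DIM = 32, 10, 4, 2, 8


def diag_gaussian(mean, std):
    # Treat the last dimension as the event, so log_prob sums over it.
    return Independent(Normal(mean, std), 1)


# Hypothetical batch of B subtrajectories tau_T.
states = torch.randn(B, T, S_DIM)    # s_bar
actions = torch.randn(B, T, A_DIM)   # a_bar
s_T = torch.randn(B, S_DIM)          # final state of each subtrajectory

# Hypothetical posterior parameters (would come from q_phi) and a standard-normal prior.
post_mean = torch.randn(B, Z_DIM, requires_grad=True)
post_std = torch.full((B, Z_DIM), 0.5, requires_grad=True)
q = diag_gaussian(post_mean, post_std)                               # q_phi(z | tau_T)
prior = diag_gaussian(torch.zeros(B, Z_DIM), torch.ones(B, Z_DIM))   # p_omega(z | s0)

# Hypothetical linear stand-ins for the means of pi_theta and p_psi.
policy_net = nn.Linear(S_DIM + Z_DIM, A_DIM)
tawm_net = nn.Linear(S_DIM + Z_DIM, S_DIM)

z = q.rsample()                               # reparameterized: gradients reach q's parameters
z_seq = z.unsqueeze(1).expand(-1, T, -1)      # same skill at every time step
policy = diag_gaussian(policy_net(torch.cat([states, z_seq], dim=-1)), torch.ones(B, T, A_DIM))
tawm = diag_gaussian(tawm_net(torch.cat([states[:, 0], z], dim=-1)), torch.ones(B, S_DIM))

action_ll = policy.log_prob(actions).sum(dim=1)   # log pi_theta(a_bar | s_bar, z), shape (B,)
state_ll = tawm.log_prob(s_T)                     # log p_psi(s_T | s0, z), shape (B,)
kl = kl_divergence(q, prior)                      # closed-form KL, shape (B,)

elbo = (action_ll + state_ll - kl).mean()         # batch estimate of the ELBO
(-elbo).backward()                                # optimizers minimize the negative ELBO
```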

The first term is the log-likelihood of the demonstrator's actions. It ensures that, given a skill, the low-level policy can reproduce the demonstrator's action sequence, which forces $z$ to encode control-relevant information.

The second term is the log-likelihood of the long-term state transition. It trains the TAWM to predict which end states $s_T$ can result from executing skill $z$ from $s_0$, tying each skill to its outcome.

Finally, the last term is the KL divergence between the skill posterior and the prior, which encourages compressed skill representations. Maximizing the ELBO therefore makes the skills $z$ explain the data while keeping this KL divergence small. Because the prior conditions only on $s_0$, keeping the posterior close to it also makes skills predictable from the start state.
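When both the posterior and the prior are diagonal Gaussians (an assumption, though a common one), the KL term has a closed form, with $\mu$ and $\sigma$ denoting means and standard deviations and $i$ indexing the dimensions of $z$:

$$
D_{\mathrm{KL}}\!\left(q_\phi \,\|\, p_\omega\right)
= \sum_{i}\left[\log\frac{\sigma_{p,i}}{\sigma_{q,i}} + \frac{\sigma_{q,i}^2 + (\mu_{q,i}-\mu_{p,i})^2}{2\sigma_{p,i}^2} - \frac{1}{2}\right].
$$

It is zero only when the posterior matches the prior exactly, so information in $z$ that cannot be predicted from $s_0$ alone is penalized. The check below compares this formula against `torch.distributions.kl_divergence` on random parameters.

```python
import torch
from torch.distributions import Normal, Independent, kl_divergence

torch.manual_seed(0)
Z_DIM = 8

mu_q, mu_p = torch.randn(Z_DIM, dtype=torch.float64), torch.randn(Z_DIM, dtype=torch.float64)
sigma_q = torch.rand(Z_DIM, dtype=torch.float64) + 0.1
sigma_p = torch.rand(Z_DIM, dtype=torch.float64) + 0.1


def diag_gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    # Closed-form KL(q || p) for diagonal Gaussians, summed over dimensions.
    per_dim = (torch.log(sigma_p / sigma_q)
               + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)
               - 0.5)
    return per_dim.sum()


closed_form = diag_gaussian_kl(mu_q, sigma_q, mu_p, sigma_p)
library = kl_divergence(Independent(Normal(mu_q, sigma_q), 1),
                        Independent(Normal(mu_p, sigma_p), 1))
print(torch.allclose(closed_form, library))  # True

# A posterior identical to the prior pays no KL cost.
print(float(diag_gaussian_kl(mu_p, sigma_p, mu_p, sigma_p)))  # 0.0
```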

#### The Expectation-Maximization (EM) Algorithm

Since computing the true posterior of $z$ given $\tau_T$ is intractable, we approximate it with the learned posterior $q_{\phi}(z|\tau_T)$ and alternate between two steps:

1. E-step:
    - Update $\phi$ with gradient descent so that the KL divergence between $q_\phi$ and the true posterior is minimized.

2. M-step:
    - Holding $q_{\phi}$ fixed, update $(\theta, \psi, \omega)$ with gradient ascent so that the ELBO is maximized.
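The alternation can be sketched with two optimizers over disjoint parameter sets. This is a minimal sketch under several assumptions: every distribution is a diagonal Gaussian, each network is a small hypothetical MLP (`gaussian_mlp`), the posterior reads a flattened trajectory instead of using a recurrent encoder, and each step takes one gradient update on a random batch. The E-step loss uses only the policy and prior terms, matching the E-step objective for $\phi$ given in the E-step subsection.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent, kl_divergence

torch.manual_seed(0)
B, T, S_DIM, A_DIM, Z_DIM, H = 16, 5, 4, 2, 3, 32


def gaussian_mlp(in_dim, out_dim):
    # Hypothetical network that outputs a mean and a log-std per output dimension.
    return nn.Sequential(nn.Linear(in_dim, H), nn.ReLU(), nn.Linear(H, 2 * out_dim))


def as_dist(params):
    mean, log_std = params.chunk(2, dim=-1)
    return Independent(Normal(mean, log_std.exp()), 1)


posterior = gaussian_mlp(T * (S_DIM + A_DIM), Z_DIM)   # q_phi
policy = gaussian_mlp(S_DIM + Z_DIM, A_DIM)             # pi_theta
tawm = gaussian_mlp(S_DIM + Z_DIM, S_DIM)               # p_psi
prior = gaussian_mlp(S_DIM, Z_DIM)                      # p_omega

e_opt = torch.optim.Adam(posterior.parameters(), lr=1e-3)
m_opt = torch.optim.Adam([*policy.parameters(), *tawm.parameters(), *prior.parameters()], lr=1e-3)


def infer(states, actions):
    return as_dist(posterior(torch.cat([states, actions], dim=-1).flatten(1)))


def action_ll(states, actions, z):
    z_seq = z.unsqueeze(1).expand(-1, T, -1)
    return as_dist(policy(torch.cat([states, z_seq], dim=-1))).log_prob(actions).sum(dim=1)


def e_step(states, actions):
    # E_q[log q - log pi - log p_omega] = -E_q[log pi] + KL(q || p_omega), minimized over phi.
    q = infer(states, actions)
    z = q.rsample()
    kl = kl_divergence(q, as_dist(prior(states[:, 0])))
    loss = (-action_ll(states, actions, z) + kl).mean()
    e_opt.zero_grad()
    loss.backward()
    e_opt.step()   # steps phi only
    return loss.item()


def m_step(states, actions, s_T):
    # Hold q_phi fixed (no gradient into phi) and maximize the full ELBO.
    with torch.no_grad():
        q = infer(states, actions)
        z = q.sample()
    s0 = states[:, 0]
    state_ll = as_dist(tawm(torch.cat([s0, z], dim=-1))).log_prob(s_T)
    kl = kl_divergence(q, as_dist(prior(s0)))
    loss = -(action_ll(states, actions, z) + state_ll - kl).mean()
    m_opt.zero_grad()
    loss.backward()
    m_opt.step()   # steps theta, psi, omega only
    return loss.item()


# Hypothetical offline batch
states, actions, s_T = torch.randn(B, T, S_DIM), torch.randn(B, T, A_DIM), torch.randn(B, S_DIM)
for _ in range(3):
    e_loss = e_step(states, actions)
    m_loss = m_step(states, actions, s_T)
```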


```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence
import numpy as np
```

##### E-Step (Update $\phi$)

In this step, we minimize, with respect to $\phi$,

$$
\mathbb{E}_{\tau_T\sim\mathcal{D}}
\Bigg[
\mathbb{E}_{z\sim q_\phi(z\mid \bar{s},\bar{a})}
\bigg[
\log \frac{q_\phi\!\left(z\mid \bar{s},\bar{a}\right)}
{\pi_\theta\!\left(\bar{a}\mid \bar{s}, z\right)\,p_\omega\!\left(z\mid s_0\right)}
\bigg]
\Bigg].
$$

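Up to the additive constant $-\log\eta$, which does not depend on $\phi$, this is equivalent to minimizing $D_{\mathrm{KL}}(q_{\phi}\,\|\,p)$, where $p$ is the true posterior

$$
p(z \mid \bar{s}, \bar{a}) = \frac{1}{\eta}\,\pi_\theta(\bar{a}\mid \bar{s}, z)\,p_\omega(z\mid s_0),
\qquad
\eta = \int \pi_\theta(\bar{a}\mid \bar{s}, z)\,p_\omega(z\mid s_0)\,dz.
$$

Note that the TAWM likelihood $p_\psi$ does not appear in this posterior, so the E-step fits $q_\phi$ using only the policy likelihood and the skill prior.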
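The equivalence is easy to check numerically when $z$ takes finitely many values, so that the integral for $\eta$ becomes a sum. The sketch below uses a hypothetical discrete skill space with random probabilities (the model itself uses a continuous $z$) and confirms that the E-step objective equals $D_{\mathrm{KL}}(q\,\|\,p) - \log\eta$, and that it reaches its minimum $-\log\eta$ when $q$ equals the true posterior.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 5  # hypothetical number of discrete skills

prior = rng.dirichlet(np.ones(K))          # p_omega(z | s0)
likelihood = rng.uniform(0.01, 1.0, K)     # pi_theta(a_bar | s_bar, z) for each z
q = rng.dirichlet(np.ones(K))              # q_phi(z | s_bar, a_bar)

eta = np.sum(likelihood * prior)           # normalizing constant
true_post = likelihood * prior / eta       # p(z | s_bar, a_bar)

e_step_objective = np.sum(q * np.log(q / (likelihood * prior)))
kl = np.sum(q * np.log(q / true_post))
print(np.isclose(e_step_objective, kl - np.log(eta)))  # True

# Plugging in the true posterior gives the minimum value, -log(eta).
at_optimum = np.sum(true_post * np.log(true_post / (likelihood * prior)))
print(np.isclose(at_optimum, -np.log(eta)))  # True
```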



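```python
import torch
import torch.nn as nn

NUM_NEURONS = 256  # hidden width shared by all networks
Z_DIM = 256        # dimensionality of the skill variable z


# Skill posterior, q_phi(z | s_bar, a_bar)
class SkillPosterior(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(in_features=self.state_dim, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        # hidden_size must be an int; each direction gets half the width, so the
        # concatenated forward/backward summary has NUM_NEURONS features
        self.bi_gru = nn.GRU(input_size=NUM_NEURONS + self.action_dim,
                             hidden_size=NUM_NEURONS // 2,
                             bidirectional=True, batch_first=True)
        self.mean = MeanNetwork(out_dim=Z_DIM)
        self.std = StandardDeviationNetwork(out_dim=Z_DIM)

    def forward(self, state_sequence, action_sequence):
        # state_sequence: (batch, T, state_dim); action_sequence: (batch, T, action_dim)
        embedded_states = self.relu(self.fc1(state_sequence))
        # concatenate along the feature dimension
        concatenated = torch.cat([embedded_states, action_sequence], dim=-1)
        # h_n: (2, batch, NUM_NEURONS // 2), the final state of each GRU direction
        _, h_n = self.bi_gru(concatenated)
        # one summary vector per subtrajectory, since q_phi outputs a single z
        summary = torch.cat([h_n[0], h_n[1]], dim=-1)
        return self.mean(summary), self.std(summary)


# Low-level skill-conditioned policy, pi_theta(a_t | s_t, z)
class SkillPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(in_features=self.state_dim + Z_DIM, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        # heads output one mean/std per action dimension
        self.mean = MeanNetwork(out_dim=action_dim)
        self.std = StandardDeviationNetwork(out_dim=action_dim)

    def forward(self, state, z):
        # state: (..., state_dim); z: (..., Z_DIM); join along the feature dimension
        c = torch.cat([state, z], dim=-1)
        x = self.relu(self.fc1(c))
        x = self.relu(self.fc2(x))
        return self.mean(x), self.std(x)


# Temporally abstract world model, p_psi(s_T | s_0, z)
# Architecture assumed here: same MLP layout as SkillPolicy, Gaussian over s_T
class TAWM(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=state_dim + Z_DIM, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(out_dim=state_dim)
        self.std = StandardDeviationNetwork(out_dim=state_dim)

    def forward(self, s0, z):
        x = self.relu(self.fc1(torch.cat([s0, z], dim=-1)))
        x = self.relu(self.fc2(x))
        return self.mean(x), self.std(x)


# Skill prior, p_omega(z | s_0)
# Architecture assumed here: one embedding layer on s_0, Gaussian over z
class SkillPrior(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=state_dim, out_features=NUM_NEURONS)
        self.relu = nn.ReLU()
        self.mean = MeanNetwork(out_dim=Z_DIM)
        self.std = StandardDeviationNetwork(out_dim=Z_DIM)

    def forward(self, s0):
        x = self.relu(self.fc1(s0))
        return self.mean(x), self.std(x)


# Head that maps NUM_NEURONS features to a Gaussian mean
class MeanNetwork(nn.Module):
    def __init__(self, out_dim=NUM_NEURONS):
        super().__init__()
        self.fc1 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Head that maps NUM_NEURONS features to a strictly positive Gaussian std
class StandardDeviationNetwork(nn.Module):
    def __init__(self, out_dim=NUM_NEURONS):
        super().__init__()
        self.fc1 = nn.Linear(in_features=NUM_NEURONS, out_features=NUM_NEURONS)
        self.fc2 = nn.Linear(in_features=NUM_NEURONS, out_features=out_dim)
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # softplus keeps std positive; the small floor avoids std == 0 in float32
        x = self.softplus(x) + 1e-4
        return x
```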
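Because $q_\phi$ outputs a single skill per subtrajectory, the bidirectional GRU's per-step outputs have to be reduced to one vector before the mean and standard-deviation heads. One common option, assumed here rather than taken from the source, is to concatenate the final hidden state of each direction. The standalone check below, with arbitrary small sizes, confirms which time steps those states correspond to in PyTorch's `nn.GRU`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, IN_DIM, HIDDEN = 3, 7, 5, 4

gru = nn.GRU(input_size=IN_DIM, hidden_size=HIDDEN, bidirectional=True, batch_first=True)
x = torch.randn(B, T, IN_DIM)
output, h_n = gru(x)   # output: (B, T, 2 * HIDDEN); h_n: (2, B, HIDDEN)

# One vector per sequence: final forward state concatenated with final backward state.
summary = torch.cat([h_n[0], h_n[1]], dim=-1)
print(summary.shape)  # torch.Size([3, 8])

# The forward direction ends at the last time step; the backward direction ends at the first.
print(torch.allclose(h_n[0], output[:, -1, :HIDDEN]))  # True
print(torch.allclose(h_n[1], output[:, 0, HIDDEN:]))   # True
```

Note that `h_n` is laid out as `(num_layers * num_directions, batch, hidden)` even when `batch_first=True`, so indexing `h_n[0]` and `h_n[1]` is only valid for a single-layer GRU as used here.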




