# Hierarchical Goal Latents (HIGL)

In traditional RL, agents often struggle to explore effectively. Hierarchical RL tackles this by decomposing a task into temporally abstract sub-tasks. However, a major bottleneck lies in defining what the high-level policy should output. Instead of raw actions or hardcoded goals, HIGL proposes learning a continuous latent goal space, allowing flexible and reusable behaviours.

In short, HIGL learns a goal embedding space and trains a goal-conditioned low-level policy to reach high-level latent goals.

## Background
1. **Two-level HRL**:
    - High-level policy: $\pi_h(z_g|s_t)$, outputs a latent goal $z_g\in \mathbb{R}^d$ every $k$ steps.
    - Low-level policy: $\pi_l(a_t|s_t,z_g)$, tries to reach the current goal $z_g$ from the current state $s_t$.

2. **Goal Encoder $E: s \rightarrow z$**:
    - Learns to map state observations $s\in\mathcal{S}$ into latent goals $z\in\mathcal{Z}$.
    - Often implemented as a CNN or an MLP.
3. **Discriminator (Optional)**:
    - Encourages the latent goals to be informative and diverse, ensuring that goals correspond to distinct, reachable behaviours.

## Theoretical Analysis

We define the framework as a Semi-MDP:
- Low-level policy objective: Train a policy $\pi_l(a_t|s_t,z_g)$ to reach a goal $z_g$, the latent representation of a future state. We define the goal-reaching reward as:
$$r_t^{goal} = - ||E(s_{t+1})-z_g||^2$$
- High-level policy objective: Every $k$ steps, the high-level policy selects a new latent goal:
$$z_g \sim \pi_h(z_g|s_t)$$
It is trained via policy gradient, using a reward signal that reflects how successfully the goal was reached.
This signal can be the extrinsic reward from the environment or the goal-reaching reward defined above.
- Training the Encoder: The encoder E must ensure that latent distances reflect the true notion of reaching. There are two popular ways to train it:
   1. Temporal Contrastive Loss: Maximizes a lower bound on the mutual information between the current state $s_t$ and the future state $s_{t+k}$:
   $$\mathcal{L}_{contrastive} = - \log \frac{\exp(\mathrm{sim}(E(s_t),E(s_{t+k})))}{\sum_{j=1}^N \exp(\mathrm{sim}(E(s_t),E(s_j)))}$$
   where $\mathrm{sim}$ is cosine similarity or a dot product, and $s_j$ are negative samples.
   2. Reconstruction Loss: Makes the latent goal predictive of some future state or reward:
   $$\mathcal{L}_{reconstruction} = ||\hat{s}_{t+k} - s_{t+k}||^2$$
   where $\hat{s}_{t+k}$ is the future state predicted from the latent.
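
As a rough illustration of the goal-reaching reward in the low-level objective above, the snippet below evaluates $-||E(s_{t+1}) - z_g||^2$ on plain Python lists. The function name `goal_reward` and the example vectors are assumptions made for illustration. In the notebook the latents would come from the encoder.

```python
def goal_reward(next_latent, goal_latent):
    """Negative squared Euclidean distance between E(s_{t+1}) and z_g."""
    if len(next_latent) != len(goal_latent):
        raise ValueError("latent vectors must have the same dimension")
    return -sum((a - b) ** 2 for a, b in zip(next_latent, goal_latent))


# Hypothetical 3-dimensional latents, purely for illustration.
z_goal = [1.0, 0.0, 2.0]
z_far = [0.0, 0.0, 0.0]   # squared distance 1 + 0 + 4 = 5
z_near = [1.0, 0.5, 2.0]  # squared distance 0.25

print(goal_reward(z_far, z_goal))   # -5.0
print(goal_reward(z_near, z_goal))  # -0.25
```

The reward is never positive and reaches its maximum of zero only when the encoded next state coincides with the goal, so a low-level policy that maximizes it is pushed towards the goal in latent space.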

The high-level policy only makes decisions every k steps. The low-level policy continues to act, conditioned on the same latent goal $z_g$, until the next high-level decision.
This introduces a temporal abstraction:
- Reduces the frequency of high-level decisions.
- Encourages the low-level policy to learn to reach the goal.
- Improves sample efficiency and long-horizon planning.
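
The goal-holding schedule described above can be sketched in a few lines. This is only a toy illustration: `rollout_goal_schedule` and the counter standing in for the high-level policy are hypothetical names, and no environment or network is involved.

```python
import itertools


def rollout_goal_schedule(num_steps, k, sample_goal):
    """Return the goal in force at each step when a new goal is drawn every k steps."""
    goals = []
    current_goal = None
    for t in range(num_steps):
        if t % k == 0:
            # High-level decision point: pick a fresh latent goal.
            current_goal = sample_goal()
        # Low-level steps in between keep acting on the same goal.
        goals.append(current_goal)
    return goals


# Stand-in for the high-level policy: each call returns the next goal id.
goal_ids = itertools.count()
schedule = rollout_goal_schedule(num_steps=10, k=4, sample_goal=lambda: next(goal_ids))
print(schedule)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
```

With 10 low-level steps and $k = 4$, the high level only decides three times (at $t = 0, 4, 8$), which is the reduction in decision frequency mentioned above.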

## Loss Functions

**Low-level SAC or PPO Loss**:
$$r_t^{goal} = - ||E(s_{t+1}) - z_g||^2$$
Train with standard SAC or PPO on this per-step reward, which matches the goal-reaching reward defined in the theoretical analysis.

**High-Level PPO or REINFORCE**:
$$\mathcal{J}(\pi_h) = \mathbb{E}_{s_t,z_g}[R_{t:t+k}]$$
Here $R_{t:t+k}$ is the return collected over the $k$ steps spent pursuing goal $z_g$; the high-level policy is optimized on these returns.
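
To make $R_{t:t+k}$ concrete, here is a small sketch that splits an episode's rewards into consecutive $k$-step segments and sums each one. The undiscounted sum is an assumption, since the objective above does not specify discounting, and `segment_returns` is a name introduced here for illustration.

```python
def segment_returns(rewards, k):
    """Sum rewards over consecutive windows of k steps (the last window may be shorter)."""
    if k <= 0:
        raise ValueError("k must be positive")
    return [sum(rewards[i:i + k]) for i in range(0, len(rewards), k)]


# Hypothetical per-step rewards from one 7-step episode, with k = 3.
rewards = [1.0, 0.0, 2.0, -1.0, 0.5, 0.5, 3.0]
print(segment_returns(rewards, k=3))  # [3.0, 0.0, 3.0]
```

Each entry would be the training signal for one high-level decision, paired with the state and latent goal chosen at the start of that segment.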

**Optional Contrastive/InfoNCE Loss for Encoder**:
$$\mathcal{L}_{info} = - \log \frac{\exp(E(s_t)^T E(s_{t+k}))}{\sum_{j}\exp(E(s_t)^T E(s_j))}$$
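
A minimal sketch of this loss for a single anchor, written in plain Python so the arithmetic is easy to follow. It assumes the denominator runs over the positive and all negatives. The helper names (`dot`, `info_nce`) and the 2-d embeddings are illustrative only.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def info_nce(anchor, positive, negatives):
    """InfoNCE loss for one anchor with dot-product scores.

    Assumption: the denominator sums over the positive and every negative.
    """
    scores = [dot(anchor, positive)] + [dot(anchor, n) for n in negatives]
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(scores)
    log_denominator = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_denominator - scores[0]


# Hypothetical embeddings: in the first case the positive matches the anchor,
# in the second case the negatives do.
anchor = [1.0, 0.0]
loss_aligned = info_nce(anchor, positive=[1.0, 0.0], negatives=[[0.0, 1.0], [0.0, 1.0]])
loss_misaligned = info_nce(anchor, positive=[0.0, 1.0], negatives=[[1.0, 0.0], [1.0, 0.0]])
print(loss_aligned < loss_misaligned)  # True
```

The loss is smaller when the encoded future state scores higher against the anchor than the negatives do, which is the behaviour the encoder is trained towards.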

## Implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np


In [2]:
class ConvEncoder(nn.Module):
    def __init__(self, input_shape=(3,96,96), latent_dim=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), nn.ReLU(),
        )

        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            conv_out = self.conv(dummy)
            self.flattened = conv_out.view(1,-1).size(1)
        
        self.fc = nn.Linear(self.flattened, latent_dim)

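    def forward(self, x):
        # Accept a batch of images (B, C, H, W) or a batch of sequences (B, T, C, H, W).
        is_seq = x.dim() == 5
        if is_seq:
            b, t = x.shape[:2]
            x = x.flatten(0, 1)  # merge batch and time so Conv2d sees 4D input
        x = x / 255.0  # expects raw pixel values in [0, 255]
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        return x.view(b, t, -1) if is_seq else x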
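
The `flattened` size that the encoder infers with a dummy forward pass can also be checked by hand. Using the standard output-size formula for an unpadded, undilated convolution, the sketch below walks the default 96×96 input through the three layers. `conv_out_size` is a helper name introduced here for illustration.

```python
def conv_out_size(size, kernel_size, stride, padding=0):
    """Spatial output size of a convolution without dilation."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = 96
for _ in range(3):  # three Conv2d layers, each with kernel_size=4, stride=2
    side = conv_out_size(side, kernel_size=4, stride=2)

flattened = 128 * side * side  # 128 output channels in the last layer
print(side, flattened)  # 10 12800
```

The spatial side shrinks 96 → 47 → 22 → 10, so with the default arguments the linear layer receives 12800 features.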

In [3]:
class SkillEmbedding(nn.Module):
    def __init__(self, num_skills, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_skills, embedding_dim)

    def forward(self, x):
        return self.embedding(x)

In [4]:
class MI_Estimator(nn.Module):
    def __init__(self, latent_dim, embedding_dim):
        super().__init__()
        self.project_state = nn.Linear(latent_dim, embedding_dim)
        self.project_skill = nn.Linear(embedding_dim, embedding_dim)

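    def forward(self, encoded_states, skill_embeddings):
        z = self.project_state(encoded_states)
        s = self.project_skill(skill_embeddings)

        # Row i scores state i against every skill in the batch;
        # the matching (state, skill) pair sits on the diagonal.
        logits = torch.matmul(z, s.T)
        labels = torch.arange(len(z), device=z.device)
        nce_loss = F.cross_entropy(logits, labels)

        # Negated InfoNCE loss: higher is better, so this value is to be maximized.
        return -nce_loss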

In [5]:
class TransformerPolicy(nn.Module):
    def __init__(self, input_dim, skill_dim, action_dim, seq_len, hidden_dim=128, n_heads=4, n_layers=2):
        super().__init__()
        self.seq_len = seq_len
        self.input_proj = nn.Linear(input_dim+skill_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model = hidden_dim,
            nhead = n_heads,
            dim_feedforward = hidden_dim*2,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
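        self.head = nn.Linear(hidden_dim, action_dim)
        # Assumption: a log-std head is added so that update() can obtain the
        # log-probabilities SAC needs; the original policy was purely deterministic.
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def _features(self, state_seq, skill_seq):
        x = torch.cat([state_seq, skill_seq], dim=-1)
        x = self.input_proj(x)
        return self.transformer(x)[:, -1]  # last time step summarizes the sequence

    def forward(self, state_seq, skill_seq):
        # Deterministic (mean) action squashed to [-1, 1].
        return torch.tanh(self.head(self._features(state_seq, skill_seq)))

    def sample(self, state_seq, skill_seq):
        # Reparameterized tanh-Gaussian sample and its log-probability.
        h = self._features(state_seq, skill_seq)
        dist = torch.distributions.Normal(self.head(h), self.log_std_head(h).clamp(-5, 2).exp())
        u = dist.rsample()
        action = torch.tanh(u)
        log_prob = dist.log_prob(u) - torch.log(1 - action.pow(2) + 1e-6)
        return action, log_prob.sum(-1, keepdim=True)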

In [6]:
class Critic(nn.Module):
    def __init__(self, input_dim, skill_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(input_dim+skill_dim+action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        self.q2 = nn.Sequential(
            nn.Linear(input_dim+skill_dim+action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state, action, skill):
        x = torch.cat([state, action, skill], dim=-1)
        return self.q1(x), self.q2(x)

In [7]:
class LatentGoalSACAgent(nn.Module):
    def __init__(self, image_shape, num_skills, action_dim, latent_dim=64, seq_len=10, device="cuda"):
        super().__init__()
        self.device = device
        self.seq_len = seq_len
        self.encoder = ConvEncoder(input_shape=image_shape, latent_dim=latent_dim).to(device)
        self.skill_gen = SkillEmbedding(num_skills, latent_dim).to(device)

        self.policy = TransformerPolicy(
            input_dim = latent_dim,
            skill_dim= latent_dim,
            action_dim= action_dim,
            seq_len = seq_len,
        ).to(device)

        self.critic = Critic(latent_dim, latent_dim, action_dim).to(device)
        self.target_critic = Critic(latent_dim, latent_dim, action_dim).to(device)
        self.target_critic.load_state_dict(self.critic.state_dict())

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=3e-4)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.log_alpha = torch.tensor(np.log(0.1), dtype=torch.float32, requires_grad=True, device=device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)

        self.target_entropy = -action_dim

    @property
    def alpha(self):
        return self.log_alpha.exp()
    
    def act(self,obs_seq, skill_seq):
        with torch.no_grad():
            z_seq = self.encoder(obs_seq)
            a = self.policy(z_seq, skill_seq)
            return a

In [8]:
def update(agent, replay_buffer, batch_size=128, gamma=0.99, tau=0.005):
    obs_seq, skill_seq, actions, rewards, next_obs_seq, dones = replay_buffer.sample(batch_size, device=agent.device)

    # No optimizer updates the encoder in this notebook, so its outputs are used as fixed features.
    with torch.no_grad():
        z_seq = agent.encoder(obs_seq)
        next_z_seq = agent.encoder(next_obs_seq)

    # The buffer already stores skill embeddings (floats), not skill indices.
    skill_embeddings = skill_seq

    with torch.no_grad():
        # SAC needs log-probabilities, so the policy must provide sample() -> (action, log_prob).
        next_action, next_log_prob = agent.policy.sample(next_z_seq, skill_embeddings)
        q1_tgt, q2_tgt = agent.target_critic(next_z_seq[:,-1], next_action, skill_embeddings[:,-1])
        min_q_tgt = torch.min(q1_tgt, q2_tgt)
        target = rewards + (1-dones) * gamma * (min_q_tgt - agent.alpha * next_log_prob)

    q1, q2 = agent.critic(z_seq[:,-1], actions, skill_embeddings[:,-1])
    critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)

    agent.critic_optimizer.zero_grad()
    critic_loss.backward()
    agent.critic_optimizer.step()

    new_action, log_prob = agent.policy.sample(z_seq, skill_embeddings)
    q1_new, q2_new = agent.critic(z_seq[:,-1], new_action, skill_embeddings[:,-1])
    min_q_new = torch.min(q1_new, q2_new)

    # alpha is detached here; it is trained separately by alpha_loss below.
    policy_loss = (agent.alpha.detach() * log_prob - min_q_new).mean()

    agent.policy_optimizer.zero_grad()
    policy_loss.backward()
    agent.policy_optimizer.step()

    alpha_loss = -(agent.log_alpha * (log_prob + agent.target_entropy).detach()).mean()
    agent.alpha_optimizer.zero_grad()
    alpha_loss.backward()
    agent.alpha_optimizer.step()

    for target_param, param in zip(agent.target_critic.parameters(), agent.critic.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
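
The last loop of `update` is a soft (Polyak) target update. A toy sketch on plain floats shows how slowly the target network tracks the online critic with `tau=0.005`. `soft_update` and the scalar "parameters" are illustrative assumptions, not part of the notebook.

```python
def soft_update(target_params, online_params, tau):
    """Return tau * online + (1 - tau) * target for each parameter pair."""
    return [tau * p + (1.0 - tau) * t for t, p in zip(target_params, online_params)]


# Hypothetical scalar "parameters": the target starts at 0 and chases a fixed online value of 1.
target = [0.0]
online = [1.0]
for _ in range(100):
    target = soft_update(target, online, tau=0.005)

print(round(target[0], 4))  # about 0.3942, i.e. 1 - 0.995 ** 100
```

After 100 updates the target has covered only about 39% of the gap, which is what keeps the bootstrapped Q-targets stable.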

In [9]:
class GoalReplayBuffer:
    def __init__(self, max_size, obs_shape, action_dim, latent_dim, seq_len=10):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.seq_len = seq_len

        self.obs = torch.zeros((max_size, seq_len, *obs_shape))
        self.next_obs = torch.zeros((max_size, seq_len, *obs_shape))
        self.actions = torch.zeros((max_size, action_dim))
        self.rewards = torch.zeros((max_size, 1))
        self.dones = torch.zeros((max_size, 1))
        self.skills = torch.zeros((max_size, seq_len, latent_dim))

    def add(self,obs,skill,action, reward, next_obs, done):
        self.obs[self.ptr] = obs
        self.skills[self.ptr] = skill
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_obs[self.ptr] = next_obs
        self.dones[self.ptr] = done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
    
    def sample(self,batch_size, device="cuda"):
        # Uniformly sample stored transitions (with replacement).
        idxs = torch.randint(0, self.size, (batch_size,))
        return (
            self.obs[idxs].to(device),
            self.skills[idxs].to(device),
            self.actions[idxs].to(device),
            self.rewards[idxs].to(device),
            self.next_obs[idxs].to(device),
            self.dones[idxs].to(device)
        )
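
The pointer arithmetic in `add` makes the buffer circular: once it is full, new transitions overwrite the oldest ones. The sketch below isolates that bookkeeping. `RingIndex` is a hypothetical helper used only to illustrate the idea.

```python
class RingIndex:
    """Tracks the write pointer and fill level of a fixed-size circular buffer."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0

    def advance(self):
        """Return the slot to write to, then move the pointer with wrap-around."""
        slot = self.ptr
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
        return slot


index = RingIndex(max_size=3)
slots = [index.advance() for _ in range(5)]
print(slots, index.size)  # [0, 1, 2, 0, 1] 3
```

The fourth and fifth writes land on slots 0 and 1 again, while `size` saturates at the capacity so that sampling never reads unwritten slots.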

In [10]:
def train_latent_goal_agent(env, agent, replay_buffer, episodes=1000, train_after=500, seq_len=10):
    all_rewards = []
    for ep in range(episodes):
        obs_seq, skill_seq = [], []
        obs = env.reset()[0]
        tot_reward = 0
        done = False

        for t in range(200):
            obs_tensor = preprocess_obs(obs)
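            # SkillEmbedding has no sample(): draw a random skill index and look up its embedding.
            skill_id = torch.randint(0, agent.skill_gen.embedding.num_embeddings, (1,), device=agent.device)
            skill = agent.skill_gen(skill_id).squeeze(0).detach().cpu()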

            obs_seq.append(obs_tensor)
            skill_seq.append(skill)

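            if len(obs_seq) < seq_len:
                # Not enough history for the sequence policy yet, so act randomly.
                action = env.action_space.sample()
            else:
                obs_seq_tensor = torch.stack(obs_seq[-seq_len:]).unsqueeze(0).to(agent.device)
                skill_seq_tensor = torch.stack(skill_seq[-seq_len:]).unsqueeze(0).to(agent.device)
                # Drop the batch dimension before handing the action to the environment.
                action = agent.act(obs_seq_tensor, skill_seq_tensor)[0].cpu().numpy()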

            next_obs, reward, terminated, truncated, _ = env.step(action)
            tot_reward += reward
            done = terminated or truncated

            if len(obs_seq) >= seq_len:
                replay_buffer.add(
                    torch.stack(obs_seq[-seq_len:]),
                    torch.stack(skill_seq[-seq_len:]),
                    torch.tensor(action,dtype=torch.float32),
                    torch.tensor([reward], dtype=torch.float32),
                    # Next-state window: the current window shifted forward by one step.
                    torch.stack(obs_seq[-seq_len:][1:] + [preprocess_obs(next_obs)]),
                    torch.tensor([done], dtype=torch.float32)
                )

            obs = next_obs
            if done:
                break

            if replay_buffer.size >= train_after:
                for _ in range(5):
                    update(agent, replay_buffer)

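        # Record the episode return and log once per episode rather than once per step.
        all_rewards.append(tot_reward)
        if ep % 10 == 0:
            print(f"Episode {ep}, Steps {t + 1}, Total Reward: {tot_reward:.2f}")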
    
    return all_rewards

def preprocess_obs(obs):
    # HWC image array -> CHW float tensor. Pixels stay in [0, 255] because
    # ConvEncoder.forward already divides by 255.
    return torch.tensor(obs, dtype=torch.float32).permute(2,0,1)