# Latent Goal-Conditioned Hierarchical Reinforcement Learning

## Introduction

Traditional skill-learning methods like DIAYN or VALOR allow agents to learn diverse behaviors without external reward signals. However, these skills are often passive and lack direction towards solving meaningful long-term tasks. That's where goal-conditioned learning comes in.

In latent goal-conditioned HRL, the agent does not just execute skills; it learns to reach goals in a learned latent space. The idea is to:
- Learn compact representations of goals
- Condition the policy on a goal vector $g$ instead of just a skill ID
- Learn a goal-transition model that predicts how goals evolve
- Plan or adapt dynamically using these learned latent subgoals

This is useful for:
- Compositional behavior
- Skill planning
- Multi-task learning
- Hierarchical exploration

## Background

- **Goal-conditioned Policy**: A policy $\pi(a|s,g)$ takes the state and a latent goal as input and produces an action to reach that goal.
- **Goal Space**: Instead of using raw coordinates or pixel targets, we learn a goal space $\mathcal{Z}$, with goals $g\in\mathcal{Z}$, via an encoder. This makes the representation compact and transferable.
- **Latent Goal Transitions**: We model how latent goals evolve, using a simple feedforward network:
$$g_{t+1} = f(g_t, s_t, a_t)$$
- **Reward Signal**: Instead of hand-crafted rewards, we can use:
    - Distance in latent space between achieved and target goal
    - Discriminator scores
    - Contrastive learning signals
- **Hierarchy**: The high-level policy picks or predicts a new latent goal every $k$ steps. The low-level policy is trained to reach this goal.
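
The two-level control loop from the last bullet can be sketched as follows. This is a minimal illustration rather than the notebook's training code: `high_level_policy`, `low_level_policy`, and the toy dynamics are hypothetical stand-ins, and the state and goal are assumed to live in the same small vector space.

```python
import numpy as np

rng = np.random.default_rng(0)

def high_level_policy(state):
    """Hypothetical high-level policy: propose a latent goal near the state."""
    return state + rng.normal(scale=1.0, size=state.shape)

def low_level_policy(state, goal):
    """Hypothetical low-level policy: take a bounded step toward the goal."""
    return np.clip(goal - state, -0.5, 0.5)

def run_hierarchy(num_steps=12, k=4, dim=2):
    state = np.zeros(dim)
    goal = None
    goal_changes = []
    for t in range(num_steps):
        # The high level picks a new latent goal every k steps.
        if t % k == 0:
            goal = high_level_policy(state)
            goal_changes.append(t)
        # The low level acts to reach the current goal.
        action = low_level_policy(state, goal)
        state = state + action  # toy dynamics: the action moves the state directly
    return goal_changes, state

goal_changes, final_state = run_hierarchy()
print(goal_changes)  # [0, 4, 8]
```

With `k=4` the goal is refreshed at steps 0, 4, and 8, while the low-level policy runs at every step.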

## Theory

**Latent Goal Model**

We learn a representation $g = E(s)$ from a goal encoder $E$, where:
- $s$ is an observation
- $g \in \mathbb{R}^d$ is a vector in goal space

**Goal-Conditioned Policy**

The policy is trained to reach a goal $g$ by maximizing the expected goal-conditioned reward, optionally with entropy regularization (SAC-style):
$$\max_\pi \mathbb{E}_{s,a,g}[r(s,g)+\alpha \mathcal{H}(\pi(\cdot|s,g))]$$
where:
- $r(s,g) = -\|E(s)-g\|^2$ is a distance-based reward
- $\mathcal{H}$ is the entropy of the policy
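
As a small worked example of the distance-based reward, the sketch below evaluates $r(s,g)$ for a few hand-picked latent vectors. The function name `latent_distance_reward` is illustrative, and the `achieved` argument stands in for the encoder output $E(s)$.

```python
import numpy as np

def latent_distance_reward(achieved, goal):
    """r(s, g) = -||E(s) - g||^2, with `achieved` standing in for E(s)."""
    diff = np.asarray(achieved, dtype=np.float64) - np.asarray(goal, dtype=np.float64)
    return -np.sum(diff ** 2, axis=-1)

goal = np.array([0.0, 0.0])
far = latent_distance_reward([3.0, 4.0], goal)    # -(3^2 + 4^2) = -25.0
near = latent_distance_reward([0.3, 0.4], goal)   # about -0.25
at_goal = latent_distance_reward(goal, goal)      # zero, the maximum attainable reward
```

The reward is never positive and grows toward zero as the encoded state approaches the goal, so maximizing it pulls the agent toward $g$ in latent space.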

## Mathematical Formulation

- $s_t$: environment state
- $g_t$: latent goal
- $E(s)$: encoder mapping observations to the latent goal space
- $\pi(a|s,g)$: goal-conditioned policy
- $f(g_t,s_t,a_t)$: goal transition function
- $Q(s,a,g)$: goal-conditioned critic

**Goal-conditioned Reward**:
$$r(s_t,g) = -\|E(s_t)-g\|^2_2$$
**High-level Transition**:
$$g_{t+1} = f(g_t,s_t,a_t)$$
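
A minimal sketch of the transition function $f$ as a feedforward network, assuming PyTorch. The class name `GoalTransitionModel`, the hidden width, the residual parameterization, and the use of encoded 64-dimensional states are illustrative assumptions.

```python
import torch
import torch.nn as nn

class GoalTransitionModel(nn.Module):
    """Hypothetical model of g_{t+1} = f(g_t, s_t, a_t)."""

    def __init__(self, goal_dim, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(goal_dim + state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, goal_dim),
        )

    def forward(self, goal, state, action):
        x = torch.cat([goal, state, action], dim=-1)
        # Predict a change to the goal (a design assumption) rather than the goal itself.
        return goal + self.net(x)

model = GoalTransitionModel(goal_dim=64, state_dim=64, action_dim=3)
g = torch.randn(8, 64)   # batch of latent goals
s = torch.randn(8, 64)   # batch of encoded states (assumed latent, not raw pixels)
a = torch.randn(8, 3)    # batch of actions
g_next = model(g, s, a)
print(g_next.shape)  # torch.Size([8, 64])
```

One plausible way to fit such a model would be to regress its output onto the encoding of the next observation, though that training signal is an assumption here.
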
**SAC-style critic loss**:
$$\mathcal{L}_{Q} = \left(Q(s_t,a_t,g) - \left[r(s_t,g) + \gamma\, \mathbb{E}_{a' \sim \pi}\left[Q(s_{t+1},a',g)-\alpha \log \pi(a'|s_{t+1},g)\right]\right]\right)^2$$
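
To make the critic loss concrete, the sketch below evaluates the bracketed target and the squared error for a batch of two transitions. All numbers are made up for illustration, the helper names are hypothetical, and the expectation over $a'$ is approximated by a single sampled action per transition.

```python
import numpy as np

def soft_td_target(reward, next_q, next_log_prob, alpha=0.2, gamma=0.99):
    """y = r(s_t, g) + gamma * (Q(s_{t+1}, a', g) - alpha * log pi(a' | s_{t+1}, g))."""
    return reward + gamma * (next_q - alpha * next_log_prob)

def critic_loss(q_pred, target):
    """Squared TD error, averaged over the batch."""
    return np.mean((q_pred - target) ** 2)

reward = np.array([-1.0, -0.5])          # r(s_t, g)
next_q = np.array([-2.0, -1.0])          # Q(s_{t+1}, a', g)
next_log_prob = np.array([-1.0, -0.5])   # log pi(a' | s_{t+1}, g) for the sampled a'
q_pred = np.array([-3.0, -1.5])          # Q(s_t, a_t, g)

target = soft_td_target(reward, next_q, next_log_prob)  # approx [-2.782, -1.391]
loss = critic_loss(q_pred, target)                      # approx 0.0297
```

The entropy term enters through $-\alpha \log \pi$: a low-probability next action (very negative log-probability) raises the target, which is what rewards the policy for staying stochastic.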

## Implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

In [2]:
class GoalEncoder(nn.Module):
    def __init__(self, input_shape=(3, 96, 96), latent_dim=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), nn.ReLU()
        )

        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            conv_out = self.conv(dummy).view(1, -1).shape[1]

        self.fc = nn.Sequential(
            nn.Linear(conv_out, latent_dim),
            nn.ReLU()
        )

    def forward(self, x):
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        if x.ndim == 3:
            x = x.permute(2, 0, 1).unsqueeze(0)
        elif x.ndim == 4 and x.shape[-1] == 3:
            x = x.permute(0, 3, 1, 2)
        x = x / 255.0
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
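
The encoder finds the flattened convolutional feature size with a dummy forward pass. As a cross-check, the same number follows from the standard output-size formula for an unpadded convolution, $\lfloor (n - k) / s \rfloor + 1$. The helper `conv_out_size` below is illustrative and not part of the notebook.

```python
def conv_out_size(size, kernel_size, stride, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 96
sizes = []
for _ in range(3):                      # three conv layers, each with kernel 4 and stride 2
    size = conv_out_size(size, kernel_size=4, stride=2)
    sizes.append(size)

flat_features = 128 * size * size       # 128 output channels in the last conv layer
print(sizes, flat_features)             # [47, 22, 10] 12800
```

For a 96x96 input the feature map shrinks to 47, then 22, then 10 pixels per side, so the linear layer receives 12800 features.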

In [3]:
class GoalConditionedPolicy(nn.Module):
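    def __init__(self, obs_dim, goal_dim, action_dim):
        super().__init__()
        # The policy receives the encoded observation and the latent goal, both
        # produced by the same encoder, so the input width is 2 * goal_dim.
        # obs_dim is kept for interface compatibility but is not used.
        self.fc = nn.Sequential(
            nn.Linear(2 * goal_dim, 256), nn.ReLU(),
            nn.Linear(256, action_dim)
        )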

    def forward(self, state, goal):
        if goal.ndim == 3:
            goal = goal.squeeze(1)
        x = torch.cat([state, goal], dim=-1)
        return torch.tanh(self.fc(x))


In [4]:
class GoalConditionedCritic(nn.Module):
    def __init__(self, obs_dim, action_dim, goal_dim):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(obs_dim + action_dim + goal_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(obs_dim + action_dim + goal_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, state, action, goal):
        x = torch.cat([state, action, goal], dim=-1)
        return self.q1(x), self.q2(x)


In [5]:
class GoalCritic(nn.Module):
    def __init__(self, state_dim, action_dim, goal_dim):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim + goal_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim + goal_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, state, action, goal):
        if state.ndim == 4:
            state = state.view(state.size(0), -1)
        if goal.ndim == 3:
            goal = goal.squeeze(1)
        x = torch.cat([state, action, goal], dim=-1)
        return self.q1(x), self.q2(x)


In [13]:
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done, goal):
        self.buffer.append((state, action, reward, next_state, done, goal))

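    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done, goal = zip(*batch)

        def process_obs(x):
            # Stack raw HWC frames and convert to NCHW. Pixel scaling is left to
            # the encoder, which divides by 255 in its forward pass.
            x = torch.tensor(np.array(x), dtype=torch.float32)
            if x.ndim == 4 and x.shape[-1] == 3:
                x = x.permute(0, 3, 1, 2)
            return x

        # Goals are latent vectors, not images, so they are stacked without scaling.
        goal = torch.stack([torch.as_tensor(g, dtype=torch.float32) for g in goal])

        return (process_obs(state),
                torch.tensor(np.array(action), dtype=torch.float32),
                torch.tensor(reward, dtype=torch.float32).unsqueeze(1),
                process_obs(next_state),
                torch.tensor(done, dtype=torch.float32).unsqueeze(1),
                goal)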
    
    def __len__(self):
        return len(self.buffer)


In [14]:
def latent_goal_reward(achieved, desired):
    return -((achieved-desired)**2).sum(dim=1, keepdim=True)

In [15]:
def train(agent, buffer, encoder, critic, target_critic, policy, critic_optimizer, policy_optimizer, alpha=0.2, gamma=0.99, tau=0.005, batch_size=128):
    # Wait until the buffer holds at least one full batch.
    if len(buffer) < batch_size:
        return 0.0

    # Stored rewards are placeholders; the reward is recomputed in latent space below.
    state, action, _, next_state, done, goal = buffer.sample(batch_size)

    with torch.no_grad():
        # The encoder has no optimizer and stays frozen; the policy and critic
        # operate on its latent outputs rather than on raw images.
        z = encoder(state)
        next_z = encoder(next_state)

        # The goal stays fixed across the transition, matching Q(s_{t+1}, a', g).
        next_action = policy(next_z, goal)
        target_q1, target_q2 = target_critic(next_z, next_action, goal)
        target_q = torch.min(target_q1, target_q2)
        target_val = latent_goal_reward(z, goal) + (1 - done) * gamma * target_q

    q1, q2 = critic(z, action, goal)
    critic_loss = F.mse_loss(q1, target_val) + F.mse_loss(q2, target_val)

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # The policy is deterministic (tanh output), so there is no log-probability
    # for an entropy term; alpha is unused and the actor simply maximizes Q.
    new_action = policy(z, goal)
    new_q1, new_q2 = critic(z, new_action, goal)
    policy_loss = -torch.min(new_q1, new_q2).mean()

    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()

    # Polyak-average the online critic into the target critic.
    with torch.no_grad():
        for param, target_param in zip(critic.parameters(), target_critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    return critic_loss.item()
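
The target-network update at the end of `train` is a Polyak (exponential moving) average. The scalar sketch below, with a hypothetical `soft_update` helper, shows how slowly the target tracks the online value when `tau=0.005`.

```python
def soft_update(target, online, tau=0.005):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return tau * online + (1.0 - tau) * target

tau = 0.005
target, online = 0.0, 1.0
for _ in range(200):                     # about 1 / tau updates
    target = soft_update(target, online, tau)

# After n updates toward a fixed online value, the remaining gap is (1 - tau)^n.
closed_form = 1.0 - (1.0 - tau) ** 200   # roughly 0.63
```

After about `1 / tau` updates the target has closed roughly 63% of the gap, which is why a small `tau` keeps the bootstrapped targets stable.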

In [16]:
def sample_goal_from_obs(obs, encoder):
    obs = torch.tensor(obs, dtype=torch.float32).permute(2,0,1).unsqueeze(0)
    with torch.no_grad():
        goal = encoder(obs).squeeze(0)
    return goal

In [17]:
def collect_episodes(env, agent, encoder, buffer, max_steps=200):
    obs, _ = env.reset()
    # The latent goal is the encoding of the initial observation and stays
    # fixed for the whole episode.
    goal = sample_goal_from_obs(obs, encoder)
    tot_reward = 0.0

    for _ in range(max_steps):
        with torch.no_grad():
            # agent.act accepts the raw HWC observation and a batched goal.
            action = agent.act(obs, goal.unsqueeze(0))

        next_obs, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # The environment reward is discarded; train() recomputes the latent reward.
        buffer.push(obs, action, 0.0, next_obs, done, goal)

        # Log the negative latent distance between the new observation and the goal.
        achieved = sample_goal_from_obs(next_obs, encoder)
        tot_reward -= torch.linalg.norm(achieved - goal).item()

        obs = next_obs
        if done:
            break

    return tot_reward

In [18]:
class GoalAgent:
    def __init__(self, encoder, policy):
        self.encoder = encoder
        self.policy = policy

    def act(self, obs, goal):
        # as_tensor avoids an extra copy (and a warning) when obs is already a tensor.
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
        # The encoder converts HWC input to CHW and scales pixels to [0, 1],
        # so the observation is not normalized a second time here.
        with torch.no_grad():
            latent_obs = self.encoder(obs_tensor)
            action = self.policy(latent_obs, goal)
        return action.squeeze(0).cpu().numpy()


In [None]:
# Newer Gymnasium releases may register this environment as 'CarRacing-v3'.
env = gym.make('CarRacing-v2', render_mode='rgb_array')
obs_shape = (3, 96, 96)
latent_dim = 64
action_dim = env.action_space.shape[0]
encoder = GoalEncoder(input_shape=obs_shape, latent_dim=latent_dim)
# The policy and critic consume encoded observations, so their state
# dimension is latent_dim rather than the number of image channels.
policy = GoalConditionedPolicy(latent_dim, latent_dim, action_dim)
agent = GoalAgent(encoder, policy)
buffer = ReplayBuffer()
critic = GoalCritic(latent_dim, action_dim, latent_dim)
target_critic = GoalCritic(latent_dim, action_dim, latent_dim)
target_critic.load_state_dict(critic.state_dict())
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-4)
policy_optimizer = optim.Adam(policy.parameters(), lr=3e-4)

rewards = []

episodes = 200
for ep in range(episodes):
    ep_reward = collect_episodes(env, agent, agent.encoder, buffer)
    rewards.append(ep_reward)

    for _ in range(5):
        train(agent, buffer, encoder, critic, target_critic, policy, critic_optimizer, policy_optimizer)
    if ep % 10 == 0:
        print(f"Episode {ep}, Reward: {ep_reward:.2f}")

plt.plot(rewards)
plt.xlabel('Episodes')
plt.ylabel('Reward')
plt.title('Training Rewards')
plt.show()
env.close()