# Hierarchical RL

Standard RL learns a monolithic policy that maps states to actions. However, in complex, long-horizon environments such flat policies struggle with sparse rewards, explore inefficiently, and fail to transfer knowledge across tasks.
Hierarchical RL (HRL) addresses this by decomposing decision-making into multiple levels of abstraction, allowing an agent to reason over long time horizons and reuse behaviors across contexts.

## Hierarchical Decomposition
HRL decomposes the agent's behavior into two (or more) levels of policies:
- **High-level policy (Manager)**: selects goals, sub-policies, or options.
- **Low-level policy (Worker)**: executes the chosen sub-policy or primitive actions.

This supports temporal abstraction: high-level decisions are made less frequently, and each one spans an extended period of time.
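Below is a minimal sketch of this two-level loop. It assumes a fixed decision interval `k`. `run_two_level` and the toy `manager`/`worker` callables are hypothetical stand-ins, not part of any library.

```python
def run_two_level(num_steps, k, manager, worker):
    """Two-level control loop: the manager acts every k steps, the worker every step."""
    goal, manager_calls, trace = None, 0, []
    for t in range(num_steps):
        if t % k == 0:                   # high-level decision (infrequent)
            goal = manager(t)
            manager_calls += 1
        trace.append(worker(t, goal))    # low-level decision (every step)
    return manager_calls, trace

calls, trace = run_two_level(
    10, k=4,
    manager=lambda t: f"goal_{t // 4}",
    worker=lambda t, g: (g, t),
)
print(calls)      # 3 manager decisions (t = 0, 4, 8) for 10 worker steps
print(trace[5])   # ('goal_1', 5): step 5 still follows the goal set at t = 4
```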

## Options Framework
The most common formalism in HRL is the Options Framework. It extends the standard MDP to include temporally extended actions, called options.
### Options
An option $\omega$ is a tuple:
$$\omega = \langle I_\omega, \pi_\omega, \beta_\omega \rangle$$
- $I_\omega \subseteq S$: initiation set, the states in which the option can be started.
- $\pi_\omega(a|s)$: intra-option policy, the policy followed while the option is active.
- $\beta_\omega: S \rightarrow [0,1]$: termination condition, the probability that the option ends in state $s$.
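The tuple maps naturally onto a small data structure. The sketch below makes two simplifying assumptions: states are discrete and hashable, and the intra-option policy is deterministic. The `Option` class and the `go_right` example are hypothetical.

```python
import random
from dataclasses import dataclass
from typing import Callable, Hashable, Set

@dataclass
class Option:
    initiation_set: Set[Hashable]              # I_omega: states where the option may start
    policy: Callable[[Hashable], int]          # pi_omega: intra-option policy (deterministic here)
    termination: Callable[[Hashable], float]   # beta_omega(s): probability of ending in s

    def can_start(self, s):
        return s in self.initiation_set

    def terminates(self, s, rng=random):
        # Bernoulli draw with success probability beta_omega(s)
        return rng.random() < self.termination(s)

# Hypothetical "move right until state 5" option on a 1-D chain of states 0..9
go_right = Option(
    initiation_set=set(range(5)),
    policy=lambda s: +1,
    termination=lambda s: 1.0 if s >= 5 else 0.0,
)
print(go_right.can_start(2), go_right.can_start(7))   # True False
```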

## Semi-Markov Decision Process (SMDP)
Because options are temporally extended, we move from standard MDPs to Semi-Markov Decision Processes (SMDPs), where actions (options) may last for multiple time steps.
The agent makes a decision at time $t$, then executes the chosen option $\omega_t$ for $k$ steps, until its termination condition is met. Here $k$ is itself a random variable.
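Executing an option in the SMDP view can be sketched as follows. The sketch assumes a deterministic toy environment `env_step(s, a) -> (s_next, r)`, and all names are hypothetical. The function returns three things: the discounted option return, the random duration $k$, and the terminal state $s'$.

```python
import random

def execute_option(env_step, s, policy, termination, gamma=0.99, max_steps=1000, rng=random):
    """Run an option until beta(s) fires; return (sum_{t<k} gamma^t r_t, k, s')."""
    ret, k = 0.0, 0
    while True:
        a = policy(s)
        s, r = env_step(s, a)
        ret += (gamma ** k) * r          # discount relative to the option's start time
        k += 1
        if rng.random() < termination(s) or k >= max_steps:
            return ret, k, s

# Chain example: reward 1 per step, the option terminates on reaching state 3
step = lambda s, a: (s + a, 1.0)
R, k, s_next = execute_option(step, 0, policy=lambda s: 1,
                              termination=lambda s: float(s >= 3), gamma=0.5)
print(R, k, s_next)   # 1.75 3 3  (1 + 0.5 + 0.25)
```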

## Value function in HRL
Let $\Omega$ be the set of all options. The high-level policy chooses options $\omega \in \Omega$.
We define:
- $Q_\Omega(s,\omega)$: expected return starting from state s, choosing option $\omega$, and acting optimally afterward.
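- SMDP Bellman equation for option values:
$$Q_\Omega(s,\omega) = \mathbb{E}[\sum_{t=0}^{k-1}\gamma^t r_t + \gamma^k V_\Omega(s')]$$
where $s'$ is the state in which the option terminates after $k$ steps, and $V_\Omega(s') = \max_{\omega'} Q_\Omega(s',\omega')$ under optimal option selection.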
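One option execution yields a tuple $(R, k, s')$ with $R = \sum_{t=0}^{k-1}\gamma^t r_t$. From it, a tabular SMDP Q-learning update bootstraps with $\gamma^k$ instead of $\gamma$. The minimal sketch below assumes small discrete state and option spaces and the greedy value $V_\Omega(s') = \max_{\omega'} Q_\Omega(s',\omega')$.

```python
import numpy as np

def smdp_q_update(Q, s, option, discounted_return, k, s_next, gamma=0.99, lr=0.1, done=False):
    """One SMDP Q-learning step on a table Q[s, omega]; returns the TD target."""
    bootstrap = 0.0 if done else (gamma ** k) * Q[s_next].max()   # gamma^k * V_Omega(s')
    target = discounted_return + bootstrap
    Q[s, option] += lr * (target - Q[s, option])
    return target

Q = np.zeros((4, 2))
Q[3] = [2.0, 4.0]
target = smdp_q_update(Q, s=0, option=1, discounted_return=1.75, k=3, s_next=3, gamma=0.5, lr=0.5)
print(target, Q[0, 1])   # 2.25 1.125  (1.75 + 0.5**3 * 4)
```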
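- Intra-option Bellman equation:
We define an intra-option Q-function $Q(s,a,\omega)$: the value of taking action $a$ in state $s$ while executing option $\omega$:
$$Q(s,a,\omega) = r(s,a) + \gamma \mathbb{E}_{s'}[(1-\beta_\omega(s'))Q_\Omega(s',\omega) + \beta_\omega(s')\max_{\omega'}Q_\Omega(s',\omega')]$$
The option value is recovered by averaging over the intra-option policy: $Q_\Omega(s,\omega) = \sum_a \pi_\omega(a|s)\,Q(s,a,\omega)$. These equations underlie the critic of the Option-Critic architecture.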
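Below is a single-sample version of this target for tabular arrays. It assumes that one observed next state $s'$ stands in for the expectation. `intra_option_target` and the toy arrays are illustrative only.

```python
import numpy as np

def intra_option_target(r, s_next, option, Q_omega, beta, gamma=0.99):
    """r + gamma * U(omega, s'), with
    U(omega, s') = (1 - beta) Q_Omega(s', omega) + beta * max_omega' Q_Omega(s', omega')."""
    b = beta[s_next, option]
    u = (1.0 - b) * Q_omega[s_next, option] + b * Q_omega[s_next].max()
    return r + gamma * u

Q_omega = np.array([[0.0, 0.0], [1.0, 3.0]])   # rows: states, columns: options
beta = np.array([[0.0, 0.0], [0.25, 0.5]])      # beta[s, omega]
print(intra_option_target(1.0, 1, 0, Q_omega, beta, gamma=0.5))   # 1 + 0.5 * (0.75*1 + 0.25*3) = 1.75
```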

## Policy Optimization
We aim to learn:
- The high-level policy $\pi_\Omega(\omega|s)$.
- The low-level policy $\pi_\omega(a|s)$.
- The termination function $\beta_\omega(s)$.

Each component can be parameterized by a neural network and optimized with gradient-based methods.

### Intra-option policy
Let $\theta$ parameterize $\pi_\omega(a|s)$. The intra-option policy gradient is:
$$\nabla_\theta J = \mathbb{E}[\sum_t \nabla_\theta \log \pi_{\omega_t}(a_t|s_t) \cdot Q(s_t, a_t, \omega_t)]$$
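In practice, this gradient is obtained by differentiating a surrogate loss. The minimal PyTorch sketch below assumes a discrete (categorical) intra-option policy for simplicity; the agents later in this notebook use continuous actions. `intra_option_pg_loss` is a hypothetical helper.

```python
import torch

def intra_option_pg_loss(logits, actions, q_values):
    """Surrogate whose gradient is -mean_t grad log pi_{omega_t}(a_t|s_t) * Q(s_t, a_t, omega_t).
    logits: (T, num_actions) from the active option's policy; actions: (T,); q_values: (T,)."""
    log_probs = torch.log_softmax(logits, dim=-1).gather(1, actions.unsqueeze(1)).squeeze(1)
    return -(log_probs * q_values.detach()).mean()   # critic values are treated as constants

logits = torch.zeros(1, 2, requires_grad=True)       # uniform policy over 2 actions
loss = intra_option_pg_loss(logits, torch.tensor([0]), torch.tensor([2.0]))
loss.backward()
print(logits.grad)   # tensor([[-1., 1.]]): a descent step raises the logit of the chosen action
```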
### Termination function
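The termination gradient ($\phi$ parameterizes $\beta_\omega$) does two things. It encourages continuing an option while its advantage $A_\Omega(s,\omega) = Q_\Omega(s,\omega) - V_\Omega(s)$ is positive, and it encourages terminating the option otherwise:
$$\nabla_\phi J=-\mathbb{E}[\nabla_\phi\beta_\omega(s_{t+1}) \cdot (Q_\Omega(s_{t+1}, \omega)-V_\Omega(s_{t+1}))]$$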
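In code, this update is commonly written as a loss $\beta_\omega(s_{t+1}) \cdot A_\Omega(s_{t+1},\omega)$ in which the advantage is treated as a constant. Gradient descent on this loss lowers $\beta$ where the advantage is positive, so the option keeps running while it is doing well. A small PyTorch check follows. It assumes a sigmoid-parameterized $\beta_\omega$, and `termination_loss` is a hypothetical helper.

```python
import torch

def termination_loss(beta_logit, q_option, v_state):
    """beta(s') * A(s', omega), with the advantage detached."""
    beta = torch.sigmoid(beta_logit)
    advantage = (q_option - v_state).detach()
    return (beta * advantage).mean()

# Option better than average (A = 2): positive gradient, so descent lowers beta
logit = torch.tensor([0.0], requires_grad=True)       # beta = 0.5, dbeta/dlogit = 0.25
termination_loss(logit, torch.tensor([3.0]), torch.tensor([1.0])).backward()
print(logit.grad)   # tensor([0.5000])

# Option worse than average (A = -1): negative gradient, so descent raises beta
logit_bad = torch.tensor([0.0], requires_grad=True)
termination_loss(logit_bad, torch.tensor([0.0]), torch.tensor([1.0])).backward()
print(logit_bad.grad)   # tensor([-0.2500])
```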

## Temporal Abstraction and Efficiency

A key idea in HRL is acting over multiple time scales:
- High-level decisions (every $k$ steps) guide long-term planning.
- Low-level policies (every step) execute short-term actions.
This leads to:
- Better exploration
- More efficient credit assignment
- Modular skills that can be reused across tasks

## Skill Discovery
Skills (options) can be predefined or discovered autonomously. Autonomous discovery methods include:
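- Diversity-driven: DIAYN, which makes skills distinguishable from one another
- Goal-conditioned: HAC (Hierarchical Actor-Critic), where lower levels learn to reach sub-goals proposed by the level above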
- Information-theoretic: maximize mutual information between skills and outcomes
- Clustering trajectories: detect patterns and segment skills

## Implementation

### Option-Critic Architecture
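The Option-Critic architecture introduces temporal abstraction by modeling options, which act like skills or macro-actions, and it learns them end to end.
It consists of:
- an intra-option policy for each option
- a termination condition for each option
- a high-level policy over options that selects which option to run
- a critic that estimates the values used to train the other components

**Markov option $\omega = (\mathcal{I}_\omega, \pi_\omega, \beta_\omega)$**
- $\mathcal{I}_\omega$: initiation set
- $\pi_\omega$: intra-option policy
- $\beta_\omega$: termination condition

The agent selects an option $\omega$ and follows $\pi_\omega$ until $\beta_\omega$ signals termination. It then selects the next option (call-and-return execution).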

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import random

In [None]:
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, num_options):
        super().__init__()
        self.q_net = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_dim + action_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 1)
            ) for _ in range(num_options)
        ])

    def forward(self, state, action, option):
        x = torch.cat([state, action], dim=-1)
        return self.q_net[option](x)

In [None]:
class IntraOption(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.intra_option = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, state):
        return self.intra_option(state)

In [None]:
class Termination(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, state):
        return self.net(state)

In [None]:
class Manager(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_options),
        )

    def forward(self, state):
        logits = self.net(state)
        probs = F.softmax(logits, dim=-1)
        return probs

In [None]:
class OptionCriticAgent:
    def __init__(self, state_dim, action_dim, num_options, gamma=0.99, lr=1e-4):
        self.num_options = num_options
        self.gamma = gamma
        self.actor_heads = nn.ModuleList([IntraOption(state_dim, action_dim) for _ in range(num_options)])
        self.termination_heads = nn.ModuleList([Termination(state_dim) for _ in range(num_options)])
        self.critic = Critic(state_dim, action_dim, num_options)
        self.manager = Manager(state_dim, num_options)

        # Separate optimizers: the actor loss must not push the critic's Q-values up
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr)
        self.actor_optim = optim.Adam(
            list(self.manager.parameters()) +
            list(self.actor_heads.parameters()) +
            list(self.termination_heads.parameters()),
            lr=lr
        )

        self.current_option = None

    @staticmethod
    def _to_tensor(x):
        # numpy array -> float tensor with a leading batch dimension
        return torch.as_tensor(np.asarray(x), dtype=torch.float32).unsqueeze(0)

    def select_option(self, state):
        # High-level policy: sample omega ~ pi_Omega(.|s)
        with torch.no_grad():
            option_probs = self.manager(self._to_tensor(state))
        return torch.distributions.Categorical(option_probs).sample().item()

    def should_terminate(self, state, option):
        # Sample termination from Bernoulli(beta_omega(s))
        with torch.no_grad():
            termination_prob = self.termination_heads[option](self._to_tensor(state))
        return torch.rand(1).item() < termination_prob.item()

    def select_action(self, state, option, noise_std=0.1):
        # Deterministic intra-option policy plus Gaussian exploration noise
        with torch.no_grad():
            action = self.actor_heads[option](self._to_tensor(state)).squeeze(0)
        return (action + noise_std * torch.randn_like(action)).numpy()

    def option_values(self, state_tensor):
        # Q_Omega(s, omega) ~= Q(s, pi_omega(s), omega) for every option -> (batch, num_options)
        return torch.cat([self.critic(state_tensor, self.actor_heads[o](state_tensor), o)
                          for o in range(self.num_options)], dim=-1)

    def train(self, state, action, reward, next_state, terminated, option):
        s = self._to_tensor(state)
        a = self._to_tensor(action)
        s_next = self._to_tensor(next_state)
        r = torch.tensor([[reward]], dtype=torch.float32)

        # Critic: regress Q(s, a, omega) onto r + gamma * U(omega, s'), where
        # U(omega, s') = (1 - beta(s')) Q_Omega(s', omega) + beta(s') V_Omega(s')
        # and V_Omega(s') is the expectation of Q_Omega(s', .) under the manager's policy
        with torch.no_grad():
            q_next = self.option_values(s_next)
            v_next = (self.manager(s_next) * q_next).sum(dim=-1, keepdim=True)
            q_next_option = q_next[:, option:option + 1]
            beta_next = self.termination_heads[option](s_next)
            u_next = (1 - beta_next) * q_next_option + beta_next * v_next
            target = r + self.gamma * (1.0 - float(terminated)) * u_next
        critic_loss = F.mse_loss(self.critic(s, a, option), target)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # Intra-option policy: move pi_omega(s) towards actions with higher Q(s, a, omega)
        actor_loss = -self.critic(s, self.actor_heads[option](s), option).mean()

        # Termination: minimising beta(s') * A(s', omega) lowers beta where the option beats the average
        advantage = (q_next_option - v_next).detach()
        termination_loss = (self.termination_heads[option](s_next) * advantage).mean()

        # Manager: increase the probability of options with a high Q_Omega(s, omega)
        with torch.no_grad():
            q_s = self.option_values(s)
        manager_loss = -(self.manager(s) * q_s).sum(dim=-1).mean()

        self.actor_optim.zero_grad()
        (actor_loss + termination_loss + manager_loss).backward()
        self.actor_optim.step()

In [None]:
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
num_options = 2

agent = OptionCriticAgent(state_dim, action_dim, num_options)

for ep in range(300):
    state, _ = env.reset()
    agent.current_option = agent.select_option(state)
    total_reward = 0

    for step in range(200):
        # Call-and-return: keep the current option until its termination head fires
        if agent.should_terminate(state, agent.current_option):
            agent.current_option = agent.select_option(state)

        action = agent.select_action(state, agent.current_option)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Only true termination stops bootstrapping; time-limit truncation does not
        agent.train(state, action, reward, next_state, terminated, agent.current_option)

        state = next_state
        total_reward += reward
        if done:
            break

    if ep % 10 == 0:
        print(f"Episode: {ep}, Total Reward: {total_reward:.2f}")


### Soft Option-Critic
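This variant replaces the hard Q-learning targets with soft (maximum-entropy) Q-learning and a SAC-style actor, for better exploration and stability.
The soft Q-function for option $\omega$ is:
$$Q_\omega(s,a) = r + \gamma \mathbb{E}_{s'}[(1-\beta_\omega(s'))V_\omega(s') + \beta_\omega(s')V(s')]$$
The terms are:
- $V_\omega(s') = \mathbb{E}_{a'\sim\pi_\omega}[Q_\omega(s',a') - \alpha\log\pi_\omega(a'|s')]$: the soft value of continuing with $\omega$.
- $V(s') = \alpha\log\sum_{\omega'}\exp(Q_\Omega(s',\omega')/\alpha)$: the soft value of switching options.
- $\alpha$: the entropy temperature.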
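Below is a single-sample sketch of this soft target in PyTorch. It uses two value estimates:
- the continuation value $V_\omega(s') \approx Q_\omega(s',a') - \alpha\log\pi_\omega(a'|s')$, for one sampled $a'$;
- the switching value $V(s') = \alpha\log\sum_{\omega'}\exp(Q_\Omega(s',\omega')/\alpha)$, which is the soft value matching a Boltzmann policy over options.

The function name and its arguments are hypothetical.

```python
import math
import torch

def soft_option_target(r, q_next, log_prob_next, option_values_next, beta_next,
                       alpha=0.2, gamma=0.99, done=0.0):
    """r + gamma * (1 - done) * [(1 - beta) * V_omega(s') + beta * V(s')]."""
    v_continue = q_next - alpha * log_prob_next                               # soft value of staying
    v_switch = alpha * torch.logsumexp(option_values_next / alpha, dim=-1)   # soft max over options
    return r + gamma * (1.0 - done) * ((1.0 - beta_next) * v_continue + beta_next * v_switch)

t = soft_option_target(r=torch.tensor(0.0), q_next=torch.tensor(1.0),
                       log_prob_next=torch.tensor(0.0), option_values_next=torch.zeros(2),
                       beta_next=torch.tensor(0.5), alpha=1.0, gamma=1.0)
print(float(t))   # 0.5 * 1 + 0.5 * log(2) ≈ 0.8466
```

As $\alpha \to 0$, the switching value approaches $\max_{\omega'} Q_\Omega(s',\omega')$, which recovers the hard target of the intra-option Bellman equation.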

In [None]:
class SoftOptionActor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),   # must output 128 features to match the mean/log_std heads
        )
        self.mean = nn.Linear(128, action_dim)
        self.log_std = nn.Linear(128, action_dim)

    def forward(self, state):
        x = self.fc(state)
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20,2)
        std = torch.exp(log_std)
        return mean, std
    
    def sample(self, state):
        mean, std = self.forward(state)
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()
        action = torch.tanh(x_t)
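        # keepdim=True gives shape (batch, 1), matching the critics' Q-value shape
        log_prob = normal.log_prob(x_t).sum(dim=-1, keepdim=True)
        # tanh change-of-variables correction: log pi(a) = log N(x) - sum log(1 - tanh(x)^2)
        log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1, keepdim=True)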
        return action, log_prob

In [None]:
class SoftOptionCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(state_dim+action_dim, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.q1(sa), self.q2(sa)

In [None]:
class TerminationHead(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, 1), nn.Sigmoid()
        )

    def forward(self,state):
        return self.net(state)

In [None]:
class OptionValue(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, num_options)
        )

    def forward(self, state):
        return self.net(state)

In [None]:
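class SoftOptionCriticAgent:
    def __init__(self, state_dim, action_dim, num_options, alpha=0.2, gamma=0.99, tau=0.005, lr=3e-4):
        self.num_options = num_options
        self.alpha = alpha    # entropy temperature
        self.gamma = gamma
        self.tau = tau        # Polyak coefficient for the target critics

        self.actors = nn.ModuleList([SoftOptionActor(state_dim, action_dim) for _ in range(num_options)])
        self.critics = nn.ModuleList([SoftOptionCritic(state_dim, action_dim) for _ in range(num_options)])
        self.target_critics = nn.ModuleList([SoftOptionCritic(state_dim, action_dim) for _ in range(num_options)])
        self.target_critics.load_state_dict(self.critics.state_dict())
        self.terminations = nn.ModuleList([TerminationHead(state_dim) for _ in range(num_options)])
        self.option_value = OptionValue(state_dim, num_options)   # Q_Omega(s, omega) for every option

        # Critics get their own optimizers so the actor loss never updates them
        self.critic_optimizers = [optim.Adam(c.parameters(), lr=lr) for c in self.critics]
        self.actor_optimizers = [
            optim.Adam(list(self.actors[i].parameters()) + list(self.terminations[i].parameters()), lr=lr)
            for i in range(num_options)
        ]
        self.value_optimizer = optim.Adam(self.option_value.parameters(), lr=lr)
        self.current_option = None

    @staticmethod
    def _to_tensor(x):
        return torch.as_tensor(np.asarray(x), dtype=torch.float32).unsqueeze(0)

    def soft_state_value(self, states):
        # V(s) = alpha * log sum_omega exp(Q_Omega(s, omega) / alpha)
        return self.alpha * torch.logsumexp(self.option_value(states) / self.alpha, dim=-1, keepdim=True)

    def select_option(self, state):
        # Boltzmann policy over options: pi_Omega(omega|s) ∝ exp(Q_Omega(s, omega) / alpha)
        with torch.no_grad():
            probs = torch.softmax(self.option_value(self._to_tensor(state)) / self.alpha, dim=-1)
        return torch.distributions.Categorical(probs).sample().item()

    def should_terminate(self, state, option):
        # Sample termination from Bernoulli(beta_omega(s))
        with torch.no_grad():
            beta = self.terminations[option](self._to_tensor(state))
        return torch.rand(1).item() < beta.item()

    def select_action(self, state, option):
        with torch.no_grad():
            action, _ = self.actors[option].sample(self._to_tensor(state))
        return action.squeeze(0).numpy()

    def train(self, batch, option):
        # The buffer mixes transitions from all options, so each option learns off-policy from shared data
        states, actions, rewards, next_states, dones = batch

        # 1) Critic target: r + gamma (1 - d) [(1 - beta(s')) V_omega(s') + beta(s') V(s')]
        with torch.no_grad():
            next_actions, next_log_probs = self.actors[option].sample(next_states)
            q1_t, q2_t = self.target_critics[option](next_states, next_actions)
            v_option = torch.min(q1_t, q2_t) - self.alpha * next_log_probs.view(-1, 1)
            beta_next = self.terminations[option](next_states)
            v_switch = self.soft_state_value(next_states)
            target_q = rewards + self.gamma * (1 - dones) * ((1 - beta_next) * v_option + beta_next * v_switch)

        q1, q2 = self.critics[option](states, actions)
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)
        self.critic_optimizers[option].zero_grad()
        critic_loss.backward()
        self.critic_optimizers[option].step()

        # 2) Actor (SAC-style): minimise alpha * log pi(a|s) - min(Q1, Q2)
        new_actions, new_log_probs = self.actors[option].sample(states)
        new_log_probs = new_log_probs.view(-1, 1)
        q_pi = torch.min(*self.critics[option](states, new_actions))
        actor_loss = (self.alpha * new_log_probs - q_pi).mean()

        # 3) Termination: minimising beta(s') * (Q_Omega(s', omega) - V(s')) lowers beta
        #    where this option is better than the soft average over options
        with torch.no_grad():
            advantage = self.option_value(next_states)[:, option:option + 1] - self.soft_state_value(next_states)
        termination_loss = (self.terminations[option](next_states) * advantage).mean()

        self.actor_optimizers[option].zero_grad()
        (actor_loss + termination_loss).backward()
        self.actor_optimizers[option].step()

        # 4) Option values: regress Q_Omega(s, omega) onto the soft value of running omega from s
        value_target = (q_pi - self.alpha * new_log_probs).detach()
        value_loss = F.mse_loss(self.option_value(states)[:, option:option + 1], value_target)
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # 5) Polyak-average the target critic towards the online critic
        with torch.no_grad():
            for p, tp in zip(self.critics[option].parameters(), self.target_critics[option].parameters()):
                tp.mul_(1 - self.tau).add_(self.tau * p)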


In [None]:
from collections import deque

class ReplayBuffer:
    def __init__(self, max_size=1_000_000):
        # deque drops the oldest transition in O(1) once max_size is reached
        self.buffer = deque(maxlen=max_size)

    def __len__(self):
        return len(self.buffer)

    def add(self, *transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into numpy arrays first: building a tensor from a list of ndarrays is slow
        return (
            torch.as_tensor(np.array(states), dtype=torch.float32),
            torch.as_tensor(np.array(actions), dtype=torch.float32),
            torch.as_tensor(np.array(rewards), dtype=torch.float32).unsqueeze(-1),
            torch.as_tensor(np.array(next_states), dtype=torch.float32),
            torch.as_tensor(np.array(dones), dtype=torch.float32).unsqueeze(-1),
        )

In [None]:
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = env.action_space.high   # tanh-squashed actions lie in [-1, 1]; Pendulum accepts [-2, 2]
num_options = 2

agent = SoftOptionCriticAgent(state_dim, action_dim, num_options)
replay_buffer = ReplayBuffer()
total_rewards = []

episodes = 200
batch_size = 64
warmup_steps = 1000
update_every = 1

for ep in range(episodes):
    state, _ = env.reset()
    option = agent.select_option(state)
    ep_reward = 0

    for step in range(200):
        if agent.should_terminate(state, option):
            option = agent.select_option(state)

        action = agent.select_action(state, option)

        next_state, reward, terminated, truncated, _ = env.step(action * max_action)
        done = terminated or truncated

        # Store the unscaled action, and only true termination (not time-limit truncation)
        replay_buffer.add(state, action, reward, next_state, terminated)

        state = next_state
        ep_reward += reward

        if len(replay_buffer.buffer) > warmup_steps and step % update_every == 0:
            batch = replay_buffer.sample(batch_size)
            agent.train(batch, option)

        if done:
            break

    total_rewards.append(ep_reward)
    if ep % 10 == 0:
        print(f"Episode: {ep}, Total Reward: {ep_reward:.2f}")

plt.plot(total_rewards)
plt.xlabel("Episode")
plt.ylabel("Total reward")
plt.show()

### Feudal RL and DIAYN (Diversity is All You Need)

#### **Feudal RL**
Feudal RL introduces a hierarchical structure in which a manager sets goals and a worker executes actions to fulfill them.
The key idea is to decompose control into high-level goals and low-level actions.
The worker policy is:
$$\pi(a_t | s_t, g_t)$$
where $g_t$ is the goal set by the manager.
The manager network $f_m$ emits a new goal every $k$ steps and holds it fixed in between:
$$g_t = f_m(s_t)$$
The worker's intrinsic reward (the feudal signal) measures how much progress it makes towards the goal:
$$r_t^{intr} = g_t^\top\left(f(s_{t+1})-f(s_t)\right)$$
where:
- $f(\cdot)$ is a state encoder that maps raw states into a latent space
- $g_t$ is the goal vector
- $f(s_{t+1})-f(s_t)$ is the change in latent features

The inner product measures how well the movement aligns with the goal: the further the agent moves in the goal's direction, the higher the reward.
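Here is a tiny numeric illustration of the feudal signal. It assumes a 2-D latent space and a unit-norm goal, as produced by `F.normalize` in the manager below; the numbers are made up.

```python
import numpy as np

def feudal_intrinsic_reward(goal, f_s, f_s_next):
    """r_intr = g^T (f(s_{t+1}) - f(s_t)): projection of the latent step onto the goal."""
    return float(goal @ (f_s_next - f_s))

goal = np.array([1.0, 0.0])   # unit-norm goal: "move along the first latent axis"
f_s = np.array([0.0, 0.0])
print(feudal_intrinsic_reward(goal, f_s, np.array([2.0, 5.0])))    # 2.0: only the goal-aligned part counts
print(feudal_intrinsic_reward(goal, f_s, np.array([-1.0, 0.0])))   # -1.0: moving against the goal is penalised
```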
**Optimization Objective**
Worker:
$$J_w = \mathbb{E}[\sum_t r_t^{intr}]$$
Manager:
$$J_m = \mathbb{E}[\sum_t r_t^{extr}]$$
where $r_t^{extr}$ is the extrinsic reward.

In [None]:
class ConvEncoder(nn.Module):
    def __init__(self, image_shape, latent_dim):
        super().__init__()
        c,h,w = image_shape
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU()
        )

        with torch.no_grad():
            dummy = torch.zeros(1,c,h,w)
            conv_out_dim = self.conv(dummy).view(1,-1).shape[1]
        self.fc = nn.Linear(conv_out_dim, latent_dim)

    def forward(self, x):
        x = x / 255.
        x = self.conv(x)
        x = x.view(x.size(0),-1)
        return self.fc(x)

In [None]:
class Manager(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, latent_dim)
        )

    def forward(self, x):
        goal = self.fc(x)
        return F.normalize(goal, dim=-1)

In [None]:
class Worker(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim*2, 128), nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, encoder_state, goal):
        x = torch.cat([encoder_state, goal], dim=-1)
        return torch.tanh(self.fc(x))

In [None]:
def compute_intrinsic_reward(f_st, f_st1, goal):
    delta = f_st1 - f_st
    return torch.sum(goal*delta, dim=-1, keepdim=True)

In [None]:
class FeudalAgent(nn.Module):
    def __init__(self, image_shape, action_dim, latent_dim, goal_interval=10):
        super().__init__()
        self.encoder = ConvEncoder(image_shape, latent_dim)
        self.manager = Manager(latent_dim)
        self.worker = Worker(latent_dim, action_dim)

        self.goal_interval = goal_interval
        self.latent_dim = latent_dim
        self.action_dim = action_dim

        self.encoder_optim = optim.Adam(self.encoder.parameters(), lr=1e-4)
        self.manager_optim = optim.Adam(self.manager.parameters(), lr=1e-4)
        self.worker_optim = optim.Adam(self.worker.parameters(), lr=1e-4)

    def get_goal(self, encoded_state):
        return self.manager(encoded_state)
    
    def get_action(self, encoded_state, goal):
        return self.worker(encoded_state, goal)

In [None]:
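env = gym.make("CarRacing-v2", render_mode="rgb_array")   # on gymnasium >= 1.0 the id is "CarRacing-v3"
obs_shape = (3, 96, 96)
agent = FeudalAgent(obs_shape, env.action_space.shape[0], latent_dim=64)
worker_std = 0.3   # fixed exploration noise of the Gaussian worker policy (assumed value)
rewards = []

def to_tensor(obs):
    # HWC uint8 frame -> (1, C, H, W) float tensor
    return torch.as_tensor(obs, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

for ep in range(100):
    obs, _ = env.reset()
    ep_reward = 0.0

    for t in range(1000):
        # Re-encode every step so each backward pass gets a fresh graph
        f_st = agent.encoder(to_tensor(obs))

        # Manager emits a new goal every goal_interval steps
        if t % agent.goal_interval == 0:
            goal = agent.get_goal(f_st.detach())
            f_start = f_st.detach()
            segment_reward = 0.0

        # Worker: Gaussian policy around its tanh output, conditioned on the current goal
        dist = torch.distributions.Normal(agent.get_action(f_st, goal.detach()), worker_std)
        action = dist.sample()
        np_action = np.clip(action[0].numpy(), env.action_space.low, env.action_space.high)
        next_obs, reward, terminated, truncated, _ = env.step(np_action)
        done = terminated or truncated

        # Compute intrinsic reward r_intr = g^T (f(s_{t+1}) - f(s_t)), treated as a constant
        with torch.no_grad():
            f_st1 = agent.encoder(to_tensor(next_obs))
            r_intr = compute_intrinsic_reward(f_st.detach(), f_st1, goal)

        # Update Worker (and encoder): one-step REINFORCE on the intrinsic reward, no baseline
        worker_loss = -(dist.log_prob(action).sum(dim=-1, keepdim=True) * r_intr).mean()
        agent.worker_optim.zero_grad()
        agent.encoder_optim.zero_grad()
        worker_loss.backward()
        agent.worker_optim.step()
        agent.encoder_optim.step()

        # Update Manager at the end of each goal segment, in the spirit of FuN's transition policy
        # gradient: weight cos(f(s_end) - f(s_start), g) by the extrinsic reward collected meanwhile
        segment_reward += reward
        if (t + 1) % agent.goal_interval == 0 or done:
            direction = f_st1 - f_start
            manager_loss = -segment_reward * F.cosine_similarity(direction, goal, dim=-1).mean()
            agent.manager_optim.zero_grad()
            manager_loss.backward()
            agent.manager_optim.step()

        obs = next_obs
        ep_reward += reward
        if done:
            break

    rewards.append(ep_reward)
    if ep % 10 == 0:
        print(f"Episode: {ep}, Total Reward: {ep_reward:.2f}")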


#### DIAYN (Diversity is All You Need)
DIAYN proposes a framework in which agents learn useful, diverse skills without any extrinsic reward, in an unsupervised way. These skills can later be reused for downstream tasks or transfer learning.

DIAYN is based on the mutual information between states and skills:
$$\mathcal{I}(S;Z) = H(S) - H(S|Z) = H(Z) - H(Z|S)$$
where:
- $S$ is the state
- $Z$ is the skill

DIAYN maximizes this by:
- Fixing a uniform skill prior $p(z)$, which maximizes $H(Z)$ and keeps the skills diverse
- Making the skill easy to infer from the visited states, which minimizes $H(Z|S)$

No reward function is needed: the agent rewards itself for behaving differently across skills.
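To see the identity concretely, the snippet below computes both forms of $\mathcal{I}(S;Z)$ from a small, made-up joint distribution $p(s,z)$ over four states and two skills. This is a toy example, not DIAYN data, and the helper names are hypothetical.

```python
import numpy as np

def entropy(p):
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

def mutual_information(joint):
    """joint[s, z] = p(s, z), with all marginals nonzero. Returns (H(S) - H(S|Z), H(Z) - H(Z|S))."""
    p_s, p_z = joint.sum(axis=1), joint.sum(axis=0)
    h_s_given_z = sum(p_z[z] * entropy(joint[:, z] / p_z[z]) for z in range(joint.shape[1]))
    h_z_given_s = sum(p_s[s] * entropy(joint[s] / p_s[s]) for s in range(joint.shape[0]))
    return entropy(p_s) - h_s_given_z, entropy(p_z) - h_z_given_s

# Two skills with a uniform prior that visit disjoint states: the skill is perfectly inferable
joint = np.array([[0.25, 0.0], [0.25, 0.0], [0.0, 0.25], [0.0, 0.25]])
print(mutual_information(joint))   # both ≈ log 2 ≈ 0.693, the maximum for two skills
```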

**Skill-Conditioned Policy**:
$$\pi(a|s,z) \quad \text{where} \quad z \sim p(z)$$
- One shared policy $\pi(a|s,z)$
- Latent skill $z$ sampled from a uniform categorical prior $p(z)$

**Skill Discriminator**:
$$D_\phi(z|s)$$
Learns to classify the skill $z$ from the state $s$; it is used to compute the intrinsic reward:
$$r^{intr}(s,z) = \log D_\phi(z|s) - \log p(z)$$
This encourages:
- High skill identifiability
- Maximally diverse behaviors
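With $K$ skills and a uniform prior, $\log p(z) = -\log K$. The reward is therefore zero when the discriminator is uninformative and reaches its maximum, $\log K$, when the discriminator identifies the skill with certainty. A quick check with made-up discriminator outputs follows (`diayn_reward` is a hypothetical helper).

```python
import numpy as np

def diayn_reward(disc_probs, z):
    """r = log D(z|s) - log p(z) with a uniform prior p(z) = 1/K; disc_probs is D(.|s), length K."""
    K = len(disc_probs)
    return float(np.log(disc_probs[z]) - np.log(1.0 / K))

K = 4
print(diayn_reward(np.full(K, 1.0 / K), z=2))               # 0.0: the state reveals nothing about the skill
print(diayn_reward(np.array([0.0, 0.0, 1.0, 0.0]), z=2))    # log 4 ≈ 1.386: perfect identifiability
```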

**Optimization Objective**:
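1. Policy: maximize (equivalently, maximize the expected intrinsic reward, since $\log p(z)$ is constant under a uniform prior)
$$\mathbb{E}_{s,z}[\log D_\phi(z|s)]$$
2. Discriminator: minimize
$$\mathcal{L}_D = -\mathbb{E}_{s,z}[\log D_\phi(z|s)]$$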

In [None]:
class ConvEncoder(nn.Module):
    def __init__(self, image_shape, latent_dim):
        super().__init__()
        c, h, w = image_shape
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU()
        )
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            conv_out_dim = self.conv(dummy).view(1, -1).shape[1]
        self.fc = nn.Linear(conv_out_dim, latent_dim)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [None]:
class SkillPolicy(nn.Module):
    def __init__(self, latent_dim, skill_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim+skill_dim, 256), nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def forward(self, state_feat, skill_onehot):
        x = torch.cat([state_feat, skill_onehot], dim=-1)
        return torch.tanh(self.fc(x))

In [None]:
class Discriminator(nn.Module):
    def __init__(self, latent_dim, skill_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, skill_dim)
        )

    def forward(self, state_feat):
        return self.fc(state_feat)

In [None]:
class DIAYNAgent:
    def __init__(self,image_shape, action_dim, latent_dim=64, skill_dim=10, lr=1e-4):
        self.latent_dim = latent_dim
        self.skill_dim = skill_dim

        self.encoder = ConvEncoder(image_shape, latent_dim)
        self.skill_policy = SkillPolicy(latent_dim, skill_dim, action_dim)
        self.discriminator = Discriminator(latent_dim, skill_dim)

        self.policy_optim = optim.Adam(list(self.encoder.parameters())+list(self.skill_policy.parameters()), lr=lr)
        self.discriminator_optim = optim.Adam(self.discriminator.parameters(), lr=lr)
        self.skill_dist = torch.distributions.Categorical(torch.ones(skill_dim)/skill_dim)

    def sample_skill(self, batch_size=1):
        return self.skill_dist.sample((batch_size,))
    
    def one_hot(self, skill_idx):
        return F.one_hot(skill_idx, num_classes=self.skill_dim).float()
    
    def get_action(self, obs, skill_idx):
        obs_tensor = torch.tensor(obs, dtype=torch.float32).permute(2,0,1).unsqueeze(0)
        with torch.no_grad():
            feat = self.encoder(obs_tensor)
            skill_onehot = self.one_hot(torch.tensor([skill_idx]))
            action = self.skill_policy(feat, skill_onehot)
        return action[0].detach().numpy()
    
    def compute_intrinsic_reward(self, feat, skill_idx):
        with torch.no_grad():
            logits = self.discriminator(feat)
            log_probs = F.log_softmax(logits, dim=-1)
            log_pz = torch.log(torch.tensor(1.0/self.skill_dim))
            return log_probs[0, skill_idx]-log_pz
        

In [None]:
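env = gym.make("CarRacing-v2", render_mode="rgb_array")   # on gymnasium >= 1.0 the id is "CarRacing-v3"
image_shape = (3, 96, 96)
action_dim = env.action_space.shape[0]
agent = DIAYNAgent(image_shape, action_dim)
policy_std = 0.3   # fixed exploration noise of the Gaussian skill policy (assumed value)
log_pz = np.log(1.0 / agent.skill_dim)

def to_tensor(obs):
    # HWC uint8 frame -> (1, C, H, W) float tensor
    return torch.as_tensor(obs, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

rewards = []

for ep in range(100):
    skill = agent.sample_skill().item()   # z ~ p(z), fixed for the whole episode
    skill_onehot = agent.one_hot(torch.tensor([skill]))
    target = torch.tensor([skill], dtype=torch.long)
    obs, _ = env.reset()
    ep_reward = 0.0

    for t in range(500):
        feat = agent.encoder(to_tensor(obs))
        dist = torch.distributions.Normal(agent.skill_policy(feat, skill_onehot), policy_std)
        action = dist.sample()
        np_action = np.clip(action[0].numpy(), env.action_space.low, env.action_space.high)
        next_obs, _, terminated, truncated, _ = env.step(np_action)   # extrinsic reward is ignored
        done = terminated or truncated

        # Discriminator: classify the skill from the next state's features (encoder not updated here)
        with torch.no_grad():
            next_feat = agent.encoder(to_tensor(next_obs))
        disc_loss = F.cross_entropy(agent.discriminator(next_feat), target)
        agent.discriminator_optim.zero_grad()
        disc_loss.backward()
        agent.discriminator_optim.step()

        # Intrinsic reward r = log D(z|s') - log p(z), treated as a constant
        with torch.no_grad():
            log_q = F.log_softmax(agent.discriminator(next_feat), dim=-1)[0, skill]
            intrinsic_reward = (log_q - log_pz).item()

        # Policy (and encoder): one-step REINFORCE on the intrinsic reward.
        # This is a simplification; DIAYN itself trains the policy with SAC.
        policy_loss = -dist.log_prob(action).sum() * intrinsic_reward
        agent.policy_optim.zero_grad()
        policy_loss.backward()
        agent.policy_optim.step()

        obs = next_obs
        ep_reward += intrinsic_reward
        if done:
            break

    rewards.append(ep_reward)
    if ep % 10 == 0:
        print(f"Episode: {ep}, Skill: {skill}, Intrinsic Return: {ep_reward:.2f}")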