# Option-Critic

Option-Critic addresses a fundamental limitation of standard RL methods: their inability to naturally handle temporal abstraction. Humans chunk complex tasks into simpler subtasks; Option-Critic formalizes this by adding *options* — temporally extended actions — to the traditional RL framework.


Options are defined by three components:
- Policy $\pi_o$: specifies how the agent acts while following option $o$
- Initiation set $I_o$: states from which option $o$ may begin
- Termination condition $\beta_o$: probability of option $o$ terminating in each state

The Option-Critic architecture simultaneously learns:
- Intra-option policies $\pi_o$: which actions to take while an option is active
- Termination conditions $\beta_o$: when to stop following an option
- Policy over options $\pi_\Omega$: how to select among multiple options

## Math

An option $o$ is formally defined as a triple:
$$o = \langle I_o, \pi_o(a|s), \beta_o(s) \rangle$$
- $I_o \subseteq S$: initiation set of the option
- $\pi_o(a|s)$: policy within the option
- $\beta_o(s) \in [0,1]$: termination condition of the option

Option-Critic Objective:

Given a set of options $\mathcal{O}$, the objective is to maximize the expected discounted return:
$$J(\theta) = \mathbb{E}[\sum_{t=0}^\infty \gamma^t r(s_t,a_t)]$$


Bellman Equation for Options:

The option-value function $Q_\Omega(s,o)$:
$$Q_\Omega(s,o) = \sum_a\pi_o(a|s)\Big[r(s,a)+\gamma\sum_{s'}P(s'|s,a)\,U_\Omega(s',o)\Big]$$
where the option utility upon arrival $U_\Omega(s',o)$ is defined as:
$$U_\Omega(s',o) = (1-\beta_o(s'))\,Q_\Omega(s',o) + \beta_o(s')\,V_\Omega(s')$$
and $V_\Omega(s)$ is the state-value function over options:
$$V_\Omega(s) = \sum_o \pi_\Omega(o|s)\,Q_\Omega(s,o)$$

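Policy Gradient for Options:
The policy gradients for options can be derived as follows:
- For the intra-option policy $\pi_{o,\theta}$ (policy within the option), with $Q_U(s,o,a) = r(s,a) + \gamma\sum_{s'}P(s'|s,a)\,U_\Omega(s',o)$:
$$\nabla_{\theta}J = \mathbb{E}\left[\sum_t \gamma^t\, \nabla_{\theta} \log \pi_{o_t,\theta}(a_t|s_t)\, Q_U(s_t, o_t, a_t)\right]$$
- For the termination function $\beta_{o,\vartheta}$, with the advantage $A_\Omega(s,o) = Q_\Omega(s,o) - V_\Omega(s)$:
$$\nabla_{\vartheta}J = -\mathbb{E}\left[\sum_t \gamma^t\, \nabla_{\vartheta} \beta_{o_t,\vartheta}(s_{t+1})\, A_\Omega(s_{t+1}, o_t)\right]$$
These gradients adjust how the agent picks actions and when it terminates options, based on expected advantage.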

## Implementation

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

# Option policy network per option
class OptionPolicy(nn.Module):
    """Intra-option policies pi_o(a|s): one independent MLP head per option.

    Each head maps a state to a probability vector over ``action_dim``
    actions (Linear -> ReLU -> Linear -> Softmax over the last dimension).
    """

    def __init__(self, state_dim, action_dim, num_options):
        super().__init__()
        self.option_nets = nn.ModuleList(
            self._make_head(state_dim, action_dim) for _ in range(num_options)
        )

    @staticmethod
    def _make_head(state_dim, action_dim, hidden=128):
        """Build one option's action-probability network."""
        return nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, state, option):
        """Return pi_option(.|state); ``option`` indexes ``option_nets``."""
        head = self.option_nets[option]
        return head(state)

# Termination network
class Termination(nn.Module):
    """Per-option termination probabilities beta_o(s).

    A one-hidden-layer MLP (128 ReLU units) maps a state to ``num_options``
    outputs, each squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, state_dim, num_options):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_options),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return beta_o(state) for every option along the last dimension."""
        return self.net(state)

# Q-value over options and high-level option policy
class OptionCritic(nn.Module):
    """Container for all Option-Critic networks.

    Attributes:
        num_options: number of options K.
        q_option: MLP mapping a state to K values, Q_Omega(s, .).
        pi_o: MLP + softmax giving the policy over options pi_Omega(.|s).
        termination: ``Termination`` network giving beta_o(s) for every option.
        option_policy: ``OptionPolicy`` holding one action head per option.
    """

    def __init__(self, state_dim, action_dim, num_options):
        super().__init__()
        self.num_options = num_options
        width = 128
        # Submodules are created in this order; reordering them would change
        # the parameters produced under a fixed random seed.
        self.q_option = nn.Sequential(
            nn.Linear(state_dim, width),
            nn.ReLU(),
            nn.Linear(width, num_options),
        )
        self.pi_o = nn.Sequential(
            nn.Linear(state_dim, width),
            nn.ReLU(),
            nn.Linear(width, num_options),
            nn.Softmax(dim=-1),
        )
        self.termination = Termination(state_dim, num_options)
        self.option_policy = OptionPolicy(state_dim, action_dim, num_options)

    @staticmethod
    def _sample_index(probs):
        """Draw one index from the categorical ``probs`` and return it as an int.

        ``.item()`` needs a one-element result, so this presumably expects a
        single (unbatched) probability vector — confirm with callers.
        """
        return torch.multinomial(probs, 1).item()

    @torch.no_grad()
    def get_action(self, state, option):
        """Sample an action from option ``option``'s policy at ``state`` (no autograd)."""
        return self._sample_index(self.option_policy(state, option))

    @torch.no_grad()
    def get_option(self, state):
        """Sample an option from pi_Omega(.|state) (no autograd)."""
        return self._sample_index(self.pi_o(state))

In [6]:
from collections import deque
import random
import numpy as np

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Transitions live in ``self.buffer``, a ``deque(maxlen=capacity)``: once
    full, each push silently evicts the oldest transition.
    """

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def push(self, transition):
        """Append one transition (a tuple whose field layout is the same for every push)."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Sample ``batch_size`` distinct transitions uniformly without replacement.

        Returns:
            A tuple with one ``np.ndarray`` per transition field; each array
            stacks that field across the sampled transitions, so its leading
            dimension is ``batch_size``. The tuple can be unpacked, indexed
            and iterated more than once.

        Raises:
            ValueError: if ``batch_size`` is negative or exceeds the number of
                stored transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        # Materialize eagerly: a lazy ``map`` could only be consumed once.
        return tuple(np.array(field) for field in zip(*batch))

In [7]:
class OptionCriticTrainer:
    """Replay-based trainer for an Option-Critic model.

    Each ``train_step`` samples ``(state, action, reward, next_state, done,
    option)`` transitions and takes one optimizer step on the sum of:

    * critic loss: regress ``Q_Omega(s, o)`` toward
      ``r + gamma * (1 - done) * U(s', o)`` with
      ``U(s', o) = (1 - beta_o(s')) Q_Omega(s', o) + beta_o(s') max_o' Q_Omega(s', o')``;
    * intra-option policy loss: policy gradient on ``log pi_o(a|s)`` weighted
      by the one-step advantage ``target - Q_Omega(s, o)``, plus an entropy bonus;
    * termination loss: ``beta_o(s') * (Q_Omega(s', o) - max_o' Q_Omega(s', o') + termination_reg)``
      on non-terminal transitions.

    NOTE(review): ``model.pi_o`` (the policy over options) is not part of any
    of these losses, so it is not updated by this trainer.

    Args:
        env: Environment handle; stored but not used by ``train_step``.
        model: Object exposing ``q_option(states) -> (B, K)`` values,
            ``termination(states) -> (B, K)`` probabilities,
            ``option_policy(states, k) -> (B, A)`` action probabilities for
            option ``k``, and ``parameters()``.
        buffer: Replay buffer with a sized ``buffer`` container and
            ``sample(batch_size)`` returning the six transition fields.
        batch_size: Minibatch size; training is skipped until the buffer
            holds at least this many transitions.
        gamma: Discount factor.
        lr: Adam learning rate.
        termination_reg: Margin ``xi`` added to the termination advantage;
            positive values push beta down, i.e. favour longer options.
        entropy_reg: Weight of the intra-option policy entropy bonus.
    """

    def __init__(self, env, model, buffer, batch_size=64, gamma=0.99,
                 lr=0.001, termination_reg=0.01, entropy_reg=0.01):
        self.env = env
        self.model = model
        self.buffer = buffer
        self.batch_size = batch_size
        self.gamma = gamma
        self.termination_reg = termination_reg
        self.entropy_reg = entropy_reg
        # One optimizer over every model parameter; each loss term only
        # produces gradients for the sub-network(s) it involves.
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

    @torch.no_grad()
    def _bootstrap_targets(self, rewards, next_states, dones, options):
        """Return critic targets ``r + gamma * (1 - done) * U(s', o)`` of shape (B, 1).

        ``U(s', o)`` continues option ``o`` with probability ``1 - beta_o(s')``
        and otherwise switches to the greedy option ``max_o' Q_Omega(s', o')``.
        """
        next_q = self.model.q_option(next_states)
        q_continue = next_q.gather(1, options)
        q_switch = next_q.max(1, keepdim=True)[0]
        beta = self.model.termination(next_states).gather(1, options)
        utility = (1 - beta) * q_continue + beta * q_switch
        return rewards + self.gamma * (1 - dones) * utility

    def train_step(self):
        """Run one gradient step on a sampled minibatch.

        Returns:
            ``None`` if the buffer holds fewer than ``batch_size`` transitions;
            otherwise a dict with the scalar ``"critic"``, ``"policy"`` and
            ``"termination"`` loss values.
        """
        if len(self.buffer.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones, options = self.buffer.sample(self.batch_size)
        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
        rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
        next_states = torch.tensor(next_states, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
        options = torch.tensor(options, dtype=torch.int64).unsqueeze(1)

        # Critic: Q_Omega(s, o) -> r + gamma * U(s', o).
        q_all = self.model.q_option(states)
        q_so = q_all.gather(1, options)
        targets = self._bootstrap_targets(rewards, next_states, dones, options)
        critic_loss = nn.MSELoss()(q_so, targets)

        # Intra-option policy gradient.
        num_options = q_all.shape[1]
        # (B, K, A): every option's action probabilities; keep the row of the
        # option that was actually being followed in each transition.
        pi_all = torch.stack(
            [self.model.option_policy(states, k) for k in range(num_options)], dim=1
        )
        pi = pi_all[torch.arange(pi_all.shape[0]), options.squeeze(1)]
        log_pi = torch.log(pi.clamp_min(1e-8))  # clamp avoids log(0) = -inf
        entropy = -(pi * log_pi).sum(dim=1, keepdim=True)
        # One-step estimate of Q_U(s, o, a) - Q_Omega(s, o); no gradient into the critic.
        advantage = (targets - q_so).detach()
        policy_loss = -(log_pi.gather(1, actions) * advantage + self.entropy_reg * entropy).mean()

        # Termination gradient.
        beta_next = self.model.termination(next_states).gather(1, options)
        with torch.no_grad():
            next_q = self.model.q_option(next_states)
            termination_adv = (
                next_q.gather(1, options)
                - next_q.max(1, keepdim=True)[0]
                + self.termination_reg
            )
        # Descending beta * (A + xi) raises beta where switching is better and
        # lowers it where continuing is; terminal transitions are masked out.
        termination_loss = (beta_next * termination_adv * (1 - dones)).mean()

        loss = critic_loss + policy_loss + termination_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {
            "critic": critic_loss.item(),
            "policy": policy_loss.item(),
            "termination": termination_loss.item(),
        }

In [8]:
class DummyEnv:
    """Random-noise environment used to smoke-test the training loop.

    Observations are standard-normal vectors of length ``state_dim``; rewards
    are uniform in [0, 1); each step ends the episode with probability 0.05.
    ``action_dim`` is stored but the chosen action is ignored.
    """

    def __init__(self, state_dim, action_dim):
        self.state_dim, self.action_dim = state_dim, action_dim

    def _observe(self):
        """Draw a fresh standard-normal observation."""
        return np.random.randn(self.state_dim)

    def reset(self):
        return self._observe()

    def step(self, action):
        # Draw order (observation, reward, done) is fixed so seeded runs reproduce.
        observation = self._observe()
        reward = np.random.rand()
        done = np.random.rand() > 0.95
        return observation, reward, done, {}

def run_training(env, agent, trainer, episodes=100, max_steps=200):
    """Collect experience with options and train after every environment step.

    For each episode: reset ``env``, pick an option with
    ``agent.get_option``, then act with ``agent.get_action`` for up to
    ``max_steps`` steps. Every transition
    ``(state, action, reward, next_state, done, option)`` is pushed into
    ``trainer.buffer`` and followed by ``trainer.train_step()``. After a
    non-terminal step the current option terminates with probability
    ``agent.termination(next_state)[option]``, in which case a new option is
    chosen at ``next_state``.

    Args:
        env: Object with ``reset()`` and ``step(action) -> (obs, reward, done, info)``.
        agent: Object with ``get_option(state)``, ``get_action(state, option)``
            and ``termination(state)`` returning per-option probabilities.
        trainer: Object with a ``buffer.push(transition)`` and ``train_step()``.
        episodes: Number of episodes to run.
        max_steps: Maximum number of steps per episode.

    Returns:
        list of float: the undiscounted return of each episode.
    """
    episode_returns = []
    for _ in range(episodes):
        state = torch.tensor(env.reset(), dtype=torch.float32)
        option = agent.get_option(state)
        total_reward = 0.0

        for _ in range(max_steps):
            action = agent.get_action(state, option)
            next_state_np, reward, done, _info = env.step(action)
            next_state = torch.tensor(next_state_np, dtype=torch.float32)
            total_reward += float(reward)

            trainer.buffer.push((state.numpy(), action, reward, next_state.numpy(), done, option))
            trainer.train_step()

            if done:
                break

            # Sample the termination of the current option at s'; the probability
            # is only used for sampling, so no autograd graph is needed.
            with torch.no_grad():
                term_prob = agent.termination(next_state)[option].item()
            if torch.bernoulli(torch.tensor(term_prob)).item() == 1.0:
                option = agent.get_option(next_state)
            state = next_state

        episode_returns.append(total_reward)
    return episode_returns

# Setup: a small smoke-test run of the Option-Critic pipeline on DummyEnv.
state_dim, action_dim, num_options = 10, 4, 2

env = DummyEnv(state_dim=state_dim, action_dim=action_dim)
model = OptionCritic(state_dim=state_dim, action_dim=action_dim, num_options=num_options)
buffer = ReplayBuffer()
trainer = OptionCriticTrainer(env=env, model=model, buffer=buffer)

run_training(env=env, agent=model, trainer=trainer, episodes=10)
