# Soft-Option Critic (SOC)

SOC extends Option-Critic with maximum entropy RL principles: instead of maximizing return alone, the agent also maximizes policy entropy, which promotes stochastic, exploratory behavior.\
It introduces these key ideas:
- Soft Q-learning: an entropy-regularized objective in place of the standard Q-learning objective
- Options that are optimized using soft value functions
- Diverse behaviors encouraged through entropy regularization


## Background

### Maximum Entropy Reinforcement Learning
Instead of maximizing cumulative reward R, we maximize:
$$\mathbb{E}[\sum_t r(s_t, a_t)+\alpha \mathcal{H}(\pi(\cdot|s_t))]$$
where $\mathcal{H}$ is the entropy of the policy at state $s_t$ and $\alpha$ is a temperature parameter controlling the entropy/reward trade-off.
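As a small worked example of this objective, the sketch below evaluates the entropy-augmented return of a short trajectory under a discrete policy. The function names and the toy numbers are illustrative assumptions, not part of the original notebook.

```python
import math

def entropy(probs):
    """Shannon entropy H(pi(.|s)) of a discrete action distribution, in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def soft_return(rewards, action_probs, alpha):
    """Sum over t of r_t + alpha * H(pi(.|s_t)), undiscounted as in the formula above."""
    return sum(r + alpha * entropy(p) for r, p in zip(rewards, action_probs))

# Toy trajectory (hypothetical numbers): three steps, two actions.
rewards = [1.0, 1.0, 1.0]
uniform = [[0.5, 0.5]] * 3   # maximally stochastic policy
greedy = [[1.0, 0.0]] * 3    # deterministic policy, zero entropy

print(soft_return(rewards, uniform, alpha=0.1))  # 3 + 0.1 * 3 * ln(2)
print(soft_return(rewards, greedy, alpha=0.1))   # entropy bonus vanishes
```

With equal rewards, the stochastic policy scores higher; a larger `alpha` widens that gap.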

### SAC
Soft Actor-Critic (SAC) is an off-policy actor-critic method that uses stochastic policies and entropy regularization.\
Its soft Q-value and soft state value are:
$$Q^\pi(s,a)=r(s,a)+\gamma\mathbb{E}_{s'}[V^\pi(s')]$$
$$V^\pi(s)=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)-\alpha\log\pi(a|s)]$$
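For a discrete action set the soft state value can be computed directly. The sketch below (helper names and Q-values are hypothetical) evaluates $V^\pi(s)$ for a given policy and checks the known special case in which the policy is the Boltzmann distribution $\pi(a|s) \propto \exp(Q(s,a)/\alpha)$, where the soft value reduces to $\alpha \log \sum_a \exp(Q(s,a)/\alpha)$.

```python
import math

def soft_value(q_values, probs, alpha):
    """V(s) = E_{a~pi}[Q(s,a) - alpha * log pi(a|s)] for a discrete action set."""
    return sum(p * (q - alpha * math.log(p)) for q, p in zip(q_values, probs) if p > 0.0)

def boltzmann_policy(q_values, alpha):
    """pi(a|s) proportional to exp(Q(s,a) / alpha), computed stably."""
    m = max(q_values)
    weights = [math.exp((q - m) / alpha) for q in q_values]
    total = sum(weights)
    return [w / total for w in weights]

def log_sum_exp_value(q_values, alpha):
    """alpha * log sum_a exp(Q(s,a) / alpha)."""
    m = max(q_values)
    return m + alpha * math.log(sum(math.exp((q - m) / alpha) for q in q_values))

q = [1.0, 2.0, 0.5]  # hypothetical Q-values for one state
alpha = 0.5
pi = boltzmann_policy(q, alpha)
print(soft_value(q, pi, alpha), log_sum_exp_value(q, alpha))  # the two values agree
```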


## Math

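### Soft Option Value Function
$$Q_\Omega(s,o) = \mathbb{E}_{a\sim\pi_o(\cdot|s)}[Q_U(s,o,a)-\alpha\log\pi_o(a|s)]$$
### Option Termination
$$U(s',o) = (1-\beta_o(s'))Q_\Omega(s',o)+\beta_o(s')V_\Omega(s')$$
### State Value with Options
$$V_\Omega(s) = \mathbb{E}_{o\sim\pi_\Omega(\cdot|s)}[Q_\Omega(s,o)-\alpha\log\pi_\Omega(o|s)]$$

Here $\pi_o$ is the intra-option policy of option $o$, $\beta_o$ is its termination probability, $\pi_\Omega$ is the policy over options, and $Q_U(s,o,a)$ is the value of taking action $a$ in state $s$ while option $o$ is active.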
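The three definitions above can be checked numerically on a single state with tabular values. This is a minimal sketch: the array layout, the function names, and the toy numbers are assumptions, and the `alpha` argument is the temperature from the maximum entropy objective (`alpha=1.0` leaves the log terms unscaled).

```python
import numpy as np

def soft_option_values(q_u, pi_action, pi_option, alpha=1.0):
    """Q_Omega(s, o) and V_Omega(s) for a single state.

    q_u:       (num_options, num_actions) array of Q_U(s, o, a)
    pi_action: (num_options, num_actions) array, row o is the intra-option policy
    pi_option: (num_options,) array, the policy over options at s
    All probabilities are assumed strictly positive so the logs are finite.
    """
    q_omega = (pi_action * (q_u - alpha * np.log(pi_action))).sum(axis=1)
    v_omega = (pi_option * (q_omega - alpha * np.log(pi_option))).sum()
    return q_omega, v_omega

def continuation_value(q_omega, v_omega, beta):
    """U(s', o) = (1 - beta_o(s')) * Q_Omega(s', o) + beta_o(s') * V_Omega(s')."""
    return (1.0 - beta) * q_omega + beta * v_omega

# Toy numbers: two options, two actions, uniform policies.
q_u = np.array([[1.0, 0.0], [0.5, 0.5]])
pi_action = np.full((2, 2), 0.5)
pi_option = np.array([0.5, 0.5])
beta = np.array([0.0, 1.0])  # option 0 never terminates, option 1 always does

q_omega, v_omega = soft_option_values(q_u, pi_action, pi_option)
print(q_omega, v_omega, continuation_value(q_omega, v_omega, beta))
```

With `beta = 0` the continuation value equals the option's own soft value; with `beta = 1` it equals the state value over options.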

## Implementation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SoftOptionPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, num_options):
        super().__init__()
        self.policies = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.ReLU(),
                nn.Linear(64, action_dim),
            ) for _ in range(num_options)
        ])

    def forward(self, state, option):
        logits = self.policies[option](state)
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        return probs, log_probs

class SoftTermination(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_options),
            nn.Sigmoid()
        )

    def forward(self, state):
        return self.net(state)

class SoftOptionCritic(nn.Module):
    def __init__(self, state_dim, action_dim, num_options):
        super().__init__()
        self.q_option = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_options),
        )
        self.option_policy = SoftOptionPolicy(state_dim, action_dim, num_options)
        self.termination = SoftTermination(state_dim, num_options)
        self.pi_o = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_options),
        )

    def get_option_probs(self, state):
        logits = self.pi_o(state)
        return F.softmax(logits, dim=-1), F.log_softmax(logits, dim=-1)

    def get_action(self, state, option):
        with torch.no_grad():
            probs, _ = self.option_policy(state, option)
            return torch.multinomial(probs, 1).item()

    def get_option(self, state):
        with torch.no_grad():
            probs, _ = self.get_option_probs(state)
            return torch.multinomial(probs, 1).item()

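In [3]:
import gymnasium as gym
import numpy as np
import random
from collections import deque

# Replay buffer
class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, *transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        return map(np.array, zip(*batch))

# Training hyperparameters
gamma = 0.99
alpha = 0.1
batch_size = 64
num_options = 2
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
model = SoftOptionCritic(state_dim, action_dim, num_options)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
buffer = ReplayBuffer()
action_eye = torch.eye(action_dim)  # rows are one-hot action encodings

# Training loop (critic only: q_loss does not reach the policy,
# termination, or option-selection networks)
for episode in range(500):
    state = torch.tensor(env.reset()[0], dtype=torch.float32)
    option = model.get_option(state)

    for t in range(200):
        action = model.get_action(state, option)
        next_state_np, reward, terminated, truncated, _ = env.step(action)
        next_state = torch.tensor(next_state_np, dtype=torch.float32)

        # Store transition; only true termination disables bootstrapping
        buffer.push(state.numpy(), action, reward, next_state.numpy(), float(terminated), option)
        state = next_state

        # Terminate the current option with probability beta_o(s') and reselect
        with torch.no_grad():
            beta = model.termination(state)[option].item()
        if random.random() < beta:
            option = model.get_option(state)

        # Train
        if len(buffer.buffer) >= batch_size:
            states, actions, rewards, next_states, dones, options = buffer.sample(batch_size)
            states = torch.tensor(states, dtype=torch.float32)
            actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
            rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
            next_states = torch.tensor(next_states, dtype=torch.float32)
            dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
            options = torch.tensor(options, dtype=torch.int64).unsqueeze(1)

            # Q_U(s, o, a): the critic expects [state, one-hot action], i.e. width
            # state_dim + action_dim. Concatenating the raw action index instead
            # gives width state_dim + 1 and a shape mismatch in the first layer.
            q_input = torch.cat([states, action_eye[actions.squeeze(1)]], dim=-1)
            q_vals = model.q_option(q_input).gather(1, options)

            with torch.no_grad():
                # Q_U(s', o, a') for every action: (batch, action_dim, num_options)
                next_q = torch.stack([
                    model.q_option(torch.cat([next_states, action_eye[a].expand(batch_size, -1)], dim=-1))
                    for a in range(action_dim)
                ], dim=1)
                # Q_Omega(s', o) = E_a[Q_U(s', o, a) - alpha * log pi_o(a|s')]
                q_omega = []
                for o in range(num_options):
                    probs, log_probs = model.option_policy(next_states, o)
                    q_omega.append((probs * (next_q[:, :, o] - alpha * log_probs)).sum(dim=-1))
                q_omega = torch.stack(q_omega, dim=1)  # (batch, num_options)
                # V_Omega(s') = E_o[Q_Omega(s', o) - alpha * log pi_Omega(o|s')]
                opt_probs, opt_log_probs = model.get_option_probs(next_states)
                next_v = (opt_probs * (q_omega - alpha * opt_log_probs)).sum(dim=-1, keepdim=True)
                # U(s', o) = (1 - beta) * Q_Omega(s', o) + beta * V_Omega(s')
                betas = model.termination(next_states).gather(1, options)
                next_u = (1 - betas) * q_omega.gather(1, options) + betas * next_v
                target_q = rewards + gamma * (1 - dones) * next_u

            q_loss = F.mse_loss(q_vals, target_q)

            optimizer.zero_grad()
            q_loss.backward()
            optimizer.step()

        if terminated or truncated:
            break

    if episode % 50 == 0:
        print(f"Episode {episode} complete")
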

## Next Steps

### SOC for Continuous Actions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Gaussian Policy per option
class GaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU()
        )
        self.mean = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)

    def forward(self, state):
        x = self.fc(state)
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20, 2)
        std = log_std.exp()
        return mean, std

    def sample(self, state):
        mean, std = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        z = dist.rsample()
        action = torch.tanh(z)
        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        return action, log_prob.sum(-1, keepdim=True), mean, std

# Q-network (per option)
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.q = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q(x)

# Option selector
class OptionSelector(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, num_options)
        )

    def forward(self, state):
        logits = self.net(state)
        return F.softmax(logits, dim=-1), F.log_softmax(logits, dim=-1)

# Termination network
class Termination(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, num_options), nn.Sigmoid()
        )

    def forward(self, state):
        return self.net(state)

# Complete SOC+SAC Agent
class SoftOptionCriticContinuous(nn.Module):
    def __init__(self, state_dim, action_dim, num_options):
        super().__init__()
        self.num_options = num_options
        self.policies = nn.ModuleList([GaussianPolicy(state_dim, action_dim) for _ in range(num_options)])
        self.q_funcs = nn.ModuleList([QNetwork(state_dim, action_dim) for _ in range(num_options)])
        self.q_targets = nn.ModuleList([QNetwork(state_dim, action_dim) for _ in range(num_options)])
        self.option_selector = OptionSelector(state_dim, num_options)
        self.termination = Termination(state_dim, num_options)

    def sample_action(self, state, option):
        return self.policies[option].sample(state)

    def get_q(self, state, action, option):
        return self.q_funcs[option](state, action)

    def get_target_q(self, state, action, option):
        return self.q_targets[option](state, action)

    def get_option_probs(self, state):
        return self.option_selector(state)

    def get_termination_probs(self, state):
        return self.termination(state)


In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
import numpy as np
import random
from collections import deque

# Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, *transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        return map(np.array, zip(*batch))

# Hyperparameters
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])  # policies output tanh actions in [-1, 1]
num_options = 2
gamma = 0.99
alpha = 0.2
batch_size = 64
lr = 3e-4
tau = 0.005

# Initialize agent and target networks (targets start as copies of the critics)
agent = SoftOptionCriticContinuous(state_dim, action_dim, num_options)
agent.q_targets.load_state_dict(agent.q_funcs.state_dict())
buffer = ReplayBuffer()
optimizers = {
    "q": [optim.Adam(q.parameters(), lr=lr) for q in agent.q_funcs],
    "policy": [optim.Adam(pi.parameters(), lr=lr) for pi in agent.policies],
    "option": optim.Adam(agent.option_selector.parameters(), lr=lr),
    "termination": optim.Adam(agent.termination.parameters(), lr=lr)
}

# Soft update for target networks
def soft_update(target, source):
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1 - tau) * t_param.data)

# Main Training Loop (updates the critics and intra-option policies only;
# the option selector and termination network are not trained here)
for episode in range(500):
    state = torch.tensor(env.reset()[0], dtype=torch.float32)
    with torch.no_grad():
        option = agent.get_option_probs(state.unsqueeze(0))[0].multinomial(1).item()

    for t in range(200):
        with torch.no_grad():
            action, _, _, _ = agent.sample_action(state.unsqueeze(0), option)
        action_np = action.squeeze(0).numpy()  # shape (action_dim,), as env.step expects
        next_state_np, reward, terminated, truncated, _ = env.step(action_np * max_action)
        next_state = torch.tensor(next_state_np, dtype=torch.float32)

        # Only true termination disables bootstrapping
        buffer.push(state.numpy(), action_np, reward, next_state.numpy(), float(terminated), option)
        state = next_state

        # Terminate the current option with probability beta_o(s') and reselect
        with torch.no_grad():
            beta = agent.get_termination_probs(state.unsqueeze(0))[0, option].item()
            if random.random() < beta:
                option = agent.get_option_probs(state.unsqueeze(0))[0].multinomial(1).item()

        # Sample and train
        if len(buffer.buffer) >= batch_size:
            s, a, r, s2, d, o = buffer.sample(batch_size)
            s = torch.tensor(s, dtype=torch.float32)
            a = torch.tensor(a, dtype=torch.float32)
            r = torch.tensor(r, dtype=torch.float32).unsqueeze(1)
            s2 = torch.tensor(s2, dtype=torch.float32)
            d = torch.tensor(d, dtype=torch.float32).unsqueeze(1)
            o = torch.tensor(o, dtype=torch.int64).unsqueeze(1)

            # Critic loss: each sample uses the critic of the option it was collected under
            q_vals = torch.cat([agent.get_q(s, a, i) for i in range(num_options)], dim=1)
            q = q_vals.gather(1, o)
            with torch.no_grad():
                # Q_Omega(s', i): one-sample estimate of E_a[Q(s', i, a) - alpha * log pi_i(a|s')]
                q_omega = []
                for i in range(num_options):
                    a2, logp2, _, _ = agent.sample_action(s2, i)
                    q_omega.append(agent.get_target_q(s2, a2, i) - alpha * logp2)
                q_omega = torch.cat(q_omega, dim=1)  # (batch, num_options)
                opt_probs, opt_log_probs = agent.get_option_probs(s2)
                v_next = (opt_probs * (q_omega - alpha * opt_log_probs)).sum(dim=1, keepdim=True)
                # U(s', o) = (1 - beta) * Q_Omega(s', o) + beta * V_Omega(s')
                betas = agent.get_termination_probs(s2).gather(1, o)
                target = r + gamma * (1 - d) * ((1 - betas) * q_omega.gather(1, o) + betas * v_next)
            q_loss = F.mse_loss(q, target)

            for opt in optimizers["q"]:
                opt.zero_grad()
            q_loss.backward()
            for opt in optimizers["q"]:
                opt.step()

            # Policy loss: option i's policy is trained only on samples collected under option i
            policy_loss = 0.0
            for i in range(num_options):
                mask = (o == i).float()
                new_action, new_logp, _, _ = agent.sample_action(s, i)
                q_new = agent.get_q(s, new_action, i)
                policy_loss = policy_loss + (mask * (alpha * new_logp - q_new)).sum() / mask.sum().clamp(min=1)

            for opt in optimizers["policy"]:
                opt.zero_grad()
            policy_loss.backward()
            for opt in optimizers["policy"]:
                opt.step()

            # Soft update target networks
            for q_target_net, q_net in zip(agent.q_targets, agent.q_funcs):
                soft_update(q_target_net, q_net)

        if terminated or truncated:
            break

    if episode % 25 == 0:
        print(f"Episode {episode} complete.")

env.close()

### Transformer-based Option Selection
Replace the MLP option selector with a transformer or RNN that considers temporal context.\
In partially observable environments or those with temporally extended dynamics, it helps if option selection is based not just on the current state, but on a history of states or embeddings.

In [None]:
class TransformerOptionSelector(nn.Module):
    def __init__(self, state_dim, num_options, seq_len=10, d_model=128, nhead=4):
        super().__init__()
        self.embed = nn.Linear(state_dim, d_model)
        # Learned positional embedding (an addition: without it the encoder
        # has no information about the order of the time steps)
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.classifier = nn.Linear(d_model, num_options)
        self.seq_len = seq_len

    def forward(self, state_seq):
        # state_seq: (batch_size, seq_len, state_dim), oldest state first
        x = self.embed(state_seq) + self.pos_embed[:, :state_seq.size(1)]
        x = self.transformer(x)
        final_token = x[:, -1]  # use final embedding
        logits = self.classifier(final_token)  # compute once, reuse for both outputs
        return F.softmax(logits, dim=-1), F.log_softmax(logits, dim=-1)
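
A history-based selector needs a fixed-length window of recent states at every step. The sketch below shows one way to maintain such a window; the `StateHistory` class and its padding rule (repeat the first state until the window fills) are assumptions, not part of the original notebook.

```python
from collections import deque
import numpy as np

class StateHistory:
    """Rolling window of the last `seq_len` states, oldest first."""

    def __init__(self, seq_len):
        self.seq_len = seq_len
        self.buffer = deque(maxlen=seq_len)

    def reset(self, first_state):
        # Left-pad the window by repeating the first state of the episode.
        self.buffer.clear()
        for _ in range(self.seq_len):
            self.buffer.append(np.asarray(first_state, dtype=np.float32))

    def push(self, state):
        # Appending to a full deque drops the oldest state automatically.
        self.buffer.append(np.asarray(state, dtype=np.float32))

    def as_batch(self):
        """Return an array of shape (1, seq_len, state_dim)."""
        return np.stack(list(self.buffer))[None, ...]

history = StateHistory(seq_len=3)
history.reset([0.0, 0.0])
history.push([1.0, 1.0])
print(history.as_batch().shape)
```

The resulting array can be converted with `torch.from_numpy` and passed to a selector that expects a `(batch_size, seq_len, state_dim)` input.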


### Sparse Rewards + Intrinsic Motivation

To tackle hard-exploration tasks, intrinsic rewards can be used effectively. Examples include:

- **RND (Random Network Distillation)**: Encourages exploration by rewarding states that are hard to predict.
- **ICM (Intrinsic Curiosity Module)**: Drives exploration by rewarding transitions that are hard to model.
- **Novelty-based rewards**: Rewards the agent for visiting novel states.

#### Integration Steps:

1. Compute intrinsic reward $r_{\text{int}}$ at each timestep.
2. Combine intrinsic reward with external reward:
    $$r' = r_{\text{ext}} + \eta \cdot r_{\text{int}}$$
    where $\eta$ is a scaling factor for intrinsic rewards.
3. Train the Soft Option-Critic using the combined reward $r'$.

In [None]:
class RNDModel(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.target = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128)
        )
        self.predictor = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128)
        )
        for p in self.target.parameters():
            p.requires_grad = False  # fixed target

    def forward(self, state):
        with torch.no_grad():
            target_feat = self.target(state)
        pred_feat = self.predictor(state)
        return F.mse_loss(pred_feat, target_feat, reduction='none').mean(dim=1)
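
The integration steps above can be sketched with a simple count-based novelty bonus standing in for a learned intrinsic reward. The class name, the `bin_size` discretization, the $1/\sqrt{N(s)}$ bonus, and the `eta` value are illustrative assumptions; an RND or ICM error would take the place of `r_int` in the same way.

```python
import math
from collections import Counter

class CountNovelty:
    """Count-based novelty bonus: r_int = 1 / sqrt(N(s)) over discretized states."""

    def __init__(self, bin_size=0.5):
        self.bin_size = bin_size
        self.counts = Counter()

    def __call__(self, state):
        # Discretize the state so nearby states share a visit count.
        key = tuple(math.floor(x / self.bin_size) for x in state)
        self.counts[key] += 1
        return 1.0 / math.sqrt(self.counts[key])

def combined_reward(r_ext, r_int, eta=0.1):
    """r' = r_ext + eta * r_int."""
    return r_ext + eta * r_int

novelty = CountNovelty(bin_size=0.5)
for state, r_ext in [((0.1, 0.2), 0.0), ((0.1, 0.3), 0.0), ((2.0, 2.0), 1.0)]:
    r_int = novelty(state)  # step 1: intrinsic reward
    print(combined_reward(r_ext, r_int, eta=0.1))  # step 2: combined reward r'
```

The combined value is what would be stored in the replay buffer in place of the environment reward (step 3).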


### DIAYN-Style Skill Discovery

Use DIAYN-style skill discovery to learn diverse, distinguishable options without supervision.

#### Core Idea:
Train a discriminator to classify the option ID from state embeddings:
- The easier it is to predict the option from the state, the more distinguishable the skills.
- The resulting reward term lower-bounds the mutual information between states and options, $I(s; o)$, so maximizing it pushes the options apart.

#### DIAYN Components:
- **Discriminator**: $D(o \mid s)$
- **Intrinsic Reward**:
    $$r_{\text{disc}}(s, o) = \log D(o \mid s) - \log p(o)$$
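
The intrinsic reward above can be computed from discriminator logits as follows. This is a minimal NumPy sketch that assumes a uniform prior $p(o) = 1/K$ over $K$ options; the function names and example logits are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def diayn_reward(disc_logits, options, num_options):
    """r_disc(s, o) = log D(o|s) - log p(o), with p(o) = 1 / num_options (assumed).

    disc_logits: (batch, num_options) discriminator outputs for each state
    options:     (batch,) integer IDs of the option active in each state
    """
    log_d = log_softmax(disc_logits)[np.arange(len(options)), options]
    return log_d + np.log(num_options)

logits = np.array([[0.0, 0.0],      # discriminator cannot tell the options apart
                   [10.0, -10.0]])  # discriminator is confident the option is 0
options = np.array([1, 0])
print(diayn_reward(logits, options, num_options=2))
```

An uninformative discriminator yields a reward of zero, while a confident, correct prediction approaches $\log K$.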


### Mutual Information Regularization (InfoMax)

Encourage options to encode distinct representations by maximizing mutual information. This ensures that each option learns a unique and disentangled representation.

#### Loss Examples:
- **InfoNCE Loss**: Measures the similarity between the option ID and the learned state embedding.
- **Disentangled Representations**: Promotes diversity across options.

#### Loss Function:
$$
\mathcal{L}_{\text{InfoNCE}} = -\mathbb{E}_{(s, o)}\left[\log \frac{\exp(f(s)^\top W g(o))}{\sum_{o'} \exp(f(s)^\top W g(o'))}\right]
$$

Where:
- $f(s)$: State embedding function.
- $g(o)$: Option embedding function.
- $W$: Learnable weight matrix.

This loss encourages the embeddings $f(s)$ and $g(o)$ to align for the correct option $o$, while being distinct from other options $o'$.
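
A minimal NumPy sketch of this loss, assuming the embeddings $f(s)$ and $g(o)$ have already been computed and are passed in as arrays (the function name, argument layout, and toy embeddings are hypothetical):

```python
import numpy as np

def info_nce_loss(state_emb, option_emb, W, options):
    """InfoNCE loss with bilinear score f(s)^T W g(o').

    state_emb:  (batch, d_s) array of f(s)
    option_emb: (num_options, d_o) array of g(o') for every option
    W:          (d_s, d_o) weight matrix
    options:    (batch,) integer IDs of the option paired with each state
    """
    scores = state_emb @ W @ option_emb.T                # (batch, num_options)
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize the softmax
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(options)), options].mean()

# Toy embeddings: each state embedding points along its own option's embedding.
state_emb = 5.0 * np.eye(2)
option_emb = np.eye(2)
options = np.array([0, 1])

aligned = info_nce_loss(state_emb, option_emb, np.eye(2), options)
uninformative = info_nce_loss(state_emb, option_emb, np.zeros((2, 2)), options)
print(aligned, uninformative)
```

With $W = 0$ every option scores the same and the loss equals $\log K$; aligned embeddings drive it toward zero.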