# DIAYN (Diversity is All You Need)

Skill discovery in RL aims to learn useful behavioral primitives (skills/options) without external rewards. DIAYN proposes that diversity alone, without specific goals, can be a powerful unsupervised objective.
The central question is:
"Can we learn a diverse set of skills such that each skill leads to distinct behavior, even without any extrinsic reward?"
DIAYN does this by:
- Assigning a fixed skill at the beginning of each episode.
- Encouraging the policy to behave as distinctively as possible so that a discriminator can easily infer which skill was used, given only the observed state.
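
The first step can be sketched in a few lines. This is a minimal illustration, assuming a uniform prior over skills and zero-based skill IDs; `sample_skill` and `one_hot` are hypothetical helper names, not part of the implementation further down.

```python
import random


def sample_skill(num_skills, rng=random):
    """Draw a skill ID uniformly from {0, ..., num_skills - 1}."""
    return rng.randrange(num_skills)


def one_hot(skill_id, num_skills):
    """Encode a skill ID as a one-hot list, a common way to feed it to a policy."""
    return [1.0 if k == skill_id else 0.0 for k in range(num_skills)]


rng = random.Random(0)
for episode in range(3):
    z = sample_skill(4, rng)  # sampled once, then held fixed for the whole episode
    print(episode, z, one_hot(z, 4))
```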

## Background

Let $z \in \{1, 2, \dots, K\}$ be a categorical skill ID, and let the policy $\pi(a_t|s_t,z)$ be conditioned on that skill. The goal is to make the skill predictable from the agent's behavior, which means maximizing:
$$I(s_t;z)$$
where $I$ denotes mutual information.
A discriminator $q_\phi(z|s_t)$ is trained to predict the skill given a state.
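
To make the quantity concrete, here is a small sketch that evaluates $I(s;z)$ for a toy joint distribution over two states and two skills. The function name `mutual_information` and the two example tables are illustrative assumptions; values are in nats.

```python
import math


def mutual_information(joint):
    """I(s; z) in nats for a table joint[s][z] whose entries sum to 1."""
    p_s = [sum(row) for row in joint]        # marginal over states
    p_z = [sum(col) for col in zip(*joint)]  # marginal over skills
    mi = 0.0
    for i, row in enumerate(joint):
        for j, p in enumerate(row):
            if p > 0:
                mi += p * math.log(p / (p_s[i] * p_z[j]))
    return mi


# Each skill visits its own state: the skill is fully identifiable.
distinct = [[0.5, 0.0], [0.0, 0.5]]
# Both skills visit both states equally: the state says nothing about the skill.
overlapping = [[0.25, 0.25], [0.25, 0.25]]

print(mutual_information(distinct))     # log 2, about 0.693
print(mutual_information(overlapping))  # 0.0
```

The first table is the situation DIAYN pushes towards, the second is the degenerate case it tries to avoid.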

## Core Objective

The agent receives an intrinsic reward proportional to how easily the discriminator can identify the skill:
$$r_t^{intr} = \log q_\phi(z|s_t)$$
This leads to the following unsupervised RL objective:
$$\max_\pi \mathbb{E}_{z \sim p(z)} \mathbb{E}_{\tau  \sim \pi(\cdot | z)}[\sum_t \log q_\phi(z|s_t)]$$
where:
- $\pi(a|s,z)$: skill-conditioned policy
- $q_\phi(z|s)$: learned discriminator
- $p(z)$: uniform skill prior
- $s_t$: state (observation) at time $t$
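
A quick numeric look at the reward scale, as a sketch. The `subtract_prior` flag is an illustrative addition: the original DIAYN paper uses $\log q_\phi(z|s) - \log p(z)$, which shifts the chance-level reward to zero, whereas the formula above omits the constant prior term.

```python
import math


def intrinsic_reward(q_z_given_s, num_skills, subtract_prior=False):
    """log q(z|s), optionally minus log p(z) for a uniform prior over num_skills."""
    reward = math.log(q_z_given_s)
    if subtract_prior:
        reward -= math.log(1.0 / num_skills)
    return reward


K = 10
print(intrinsic_reward(1.0 / K, K))                       # chance level: log(0.1), about -2.303
print(intrinsic_reward(0.9, K))                           # confident discriminator: about -0.105
print(intrinsic_reward(1.0 / K, K, subtract_prior=True))  # chance level shifted to 0.0
```

Without the prior term every reward is non-positive, so the agent is always penalized per step; whether that matters depends on how episodes terminate.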

This aligns with empowerment and variational information maximization, where high mutual information implies better skill controllability.

## Variational Derivation of Mutual Info

The mutual information can be written as:
$$I(s;z) = H(z) - H(z|s)$$
Since $H(z)$ is constant (uniform prior), maximizing $I(s;z)$ is equivalent to minimizing $H(z|s)$, i.e. making the skill predictable from states.
The true posterior $p(z|s)$ is intractable, so we replace it with the discriminator $q_\phi(z|s)$, which gives a variational lower bound:
$$I(s;z) \geq H(z) + \mathbb{E}_{z \sim p(z), s\sim\pi(z)}[\log q_\phi(z|s)]$$
Dropping the constant $H(z)$, the expectation term becomes our intrinsic-reward training objective.
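
The step from the entropy form to the bound can be spelled out as follows, writing $p(z|s)$ for the true (intractable) posterior over skills and abbreviating the expectation over $z \sim p(z), s \sim \pi(z)$ as $\mathbb{E}_{z,s}$:
$$I(s;z) = H(z) - H(z|s) = H(z) + \mathbb{E}_{z,s}[\log p(z|s)]$$
$$= H(z) + \mathbb{E}_{z,s}[\log q_\phi(z|s)] + \mathbb{E}_{s}[D_{KL}(p(z|s) \,\|\, q_\phi(z|s))]$$
$$\geq H(z) + \mathbb{E}_{z,s}[\log q_\phi(z|s)]$$
The inequality holds because the KL divergence is non-negative, and it is tight when $q_\phi$ matches the true posterior. With a uniform prior over $K$ skills, $H(z) = \log K$ is a non-negative constant, so dropping it leaves a valid (looser) bound and does not change the maximizer.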

DIAYN ensures:
- Diversity: each skill reaches different parts of the state space.
- Disentanglement: the agent learns a latent representation where skills correspond to semantically distinct behaviors.
- No collapse: the discriminator prevents the policy from learning degenerate skills that lead to the same state.

## Challenge in pixel based envs

The state representation $s_t$ must be learned from raw pixel observations $o_t$. We can use a ConvNet encoder $f_\theta$ to map $o_t \to s_t \in \mathbb{R}^d$.
This requires the encoder to be trainable and expressive enough to:
- retain spatial structure
- encode meaningful behavioral info.

The discriminator is trained entirely on the latent state $s_t = f_\theta(o_t)$, the output of the encoder:
$$q_\phi(z|f_\theta(o_t))$$

Related methods:
- Option-Critic: DIAYN can be interpreted as a special case in which skills/options are selected stochastically and are not terminated until the episode ends.
- DADS, VALOR: related skill-discovery methods. DADS extends the idea to continuous latent skills with a learned skill-dynamics model, while VALOR decodes the skill from whole trajectories rather than from single states.

## Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class PixelEncoder(nn.Module):
    def __init__(self, input_shape=(3,96,96), latent_dim=256):
        super().__init__()
        c,h,w = input_shape

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2), nn.ReLU(),
        )

        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            out_dim = self.conv(dummy).view(1,-1).shape[1]
        
        self.fc = nn.Linear(out_dim, latent_dim)

    def forward(self, x):
        # x: (B, C, H, W) image batch with pixel values in [0, 255]
        x = x / 255.0
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten to (B, out_dim)
        return self.fc(x)
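
The `out_dim` that the encoder probes with a dummy tensor can also be checked by hand. The sketch below applies the standard output-size formula for an unpadded convolution (dilation 1) to the four layers above; `conv_out` is a hypothetical helper name.

```python
def conv_out(size, kernel=4, stride=2, padding=0):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 96
sizes = []
for _ in range(4):  # four Conv2d layers, each with kernel 4 and stride 2
    size = conv_out(size)
    sizes.append(size)

flat_dim = 256 * size * size  # 256 output channels in the last layer
print(sizes)     # [47, 22, 10, 4]
print(flat_dim)  # 4096
```

So for a 96x96 input the linear layer maps 4096 features to `latent_dim`.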

In [5]:
class SkillPolicy(nn.Module):
    def __init__(self, obs_dim, skill_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + skill_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, obs, skill):
        x = torch.cat([obs, skill], dim=-1)
        return torch.tanh(self.net(x))

In [6]:
class SkillDiscriminator(nn.Module):
    def __init__(self, obs_dim, num_skills, hidden_dim = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, num_skills)
        )

    def forward(self, obs):
        return self.net(obs)

In [8]:
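def compute_intrinsic_reward(discriminator, state_latent, skill_id):
    # state_latent: (B, latent_dim); skill_id: (B,) long tensor of skill IDs
    logits = discriminator(state_latent)
    log_probs = F.log_softmax(logits, dim=-1)
    # Select log q(z|s) for the skill that was actually active; result has shape (B,)
    return log_probs.gather(1, skill_id.unsqueeze(1)).squeeze(1)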


In [9]:
class DIAYNAgent(nn.Module):
    def __init__(self, image_shape, num_skill, action_dim, latent_dim=256):
        super().__init__()
        self.encoder = PixelEncoder(image_shape, latent_dim)
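        self.skill_embedding = nn.Embedding(num_skill, latent_dim)
        # The policy receives the embedded skill (width latent_dim), not a one-hot vector,
        # so its skill_dim must equal latent_dim for the concatenation in forward() to fit.
        self.skill_policy = SkillPolicy(latent_dim, latent_dim, action_dim)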
        self.skill_discriminator = SkillDiscriminator(latent_dim, num_skill)

    def act(self, obs, skill):
        with torch.no_grad():
            obs_encoded = self.encoder(obs)
            skill = self.skill_embedding(skill)
            return self.skill_policy(obs_encoded, skill)
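
`train_step` further down calls `agent.policy.sample(...)`, `agent.critic(...)` and `agent.target_critic(...)`, none of which are defined by the classes above. The following is one possible sketch of such SAC-style components, not the notebook's own implementation: `GaussianSkillPolicy` and `TwinCritic` are hypothetical names, the skill is fed as a one-hot vector instead of the learned embedding, and the log-std clamp range is an assumed value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LOG_STD_MIN, LOG_STD_MAX = -5.0, 2.0  # assumed clamp range for numerical stability


class GaussianSkillPolicy(nn.Module):
    """Hypothetical tanh-Gaussian policy conditioned on a one-hot skill."""

    def __init__(self, latent_dim, num_skills, action_dim, hidden_dim=256):
        super().__init__()
        self.num_skills = num_skills
        self.trunk = nn.Sequential(
            nn.Linear(latent_dim + num_skills, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def sample(self, latent, skill):
        # skill: (B,) long tensor of skill IDs
        z = F.one_hot(skill, self.num_skills).float()
        h = self.trunk(torch.cat([latent, z], dim=-1))
        log_std = self.log_std(h).clamp(LOG_STD_MIN, LOG_STD_MAX)
        dist = torch.distributions.Normal(self.mean(h), log_std.exp())
        u = dist.rsample()      # reparameterized sample, keeps gradients
        action = torch.tanh(u)  # squash into [-1, 1]
        # Change-of-variables correction for the tanh squashing
        logp = dist.log_prob(u) - torch.log(1.0 - action.pow(2) + 1e-6)
        return action, logp.sum(dim=-1, keepdim=True)


class TwinCritic(nn.Module):
    """Hypothetical pair of Q-networks Q(s, a, z)."""

    def __init__(self, latent_dim, num_skills, action_dim, hidden_dim=256):
        super().__init__()
        self.num_skills = num_skills
        in_dim = latent_dim + action_dim + num_skills
        self.q1 = self._make_q(in_dim, hidden_dim)
        self.q2 = self._make_q(in_dim, hidden_dim)

    @staticmethod
    def _make_q(in_dim, hidden_dim):
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, latent, action, skill):
        z = F.one_hot(skill, self.num_skills).float()
        x = torch.cat([latent, action, z], dim=-1)
        return self.q1(x), self.q2(x)


# Shape check with a tiny batch of 5 latents and 4 skills
policy = GaussianSkillPolicy(latent_dim=8, num_skills=4, action_dim=2)
critic = TwinCritic(latent_dim=8, num_skills=4, action_dim=2)
latent = torch.zeros(5, 8)
skill = torch.tensor([0, 1, 2, 3, 0])
action, logp = policy.sample(latent, skill)
q1, q2 = critic(latent, action, skill)
print(action.shape, logp.shape, q1.shape)
```

A target critic would be a second `TwinCritic` initialized with the same weights, and each component would get its own optimizer.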

In [10]:
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, obs, action, reward, next_obs, done, skill):
        self.buffer.append((obs, action, reward, next_obs, done, skill))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        obs, actions, rewards, next_obs, dones, skills = zip(*batch)
        return (
            torch.tensor(np.stack(obs), dtype=torch.float32).to(device),
            torch.tensor(np.stack(actions), dtype=torch.float32).to(device),
            torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(device),
            torch.tensor(np.stack(next_obs), dtype=torch.float32).to(device),
            torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(device),
            torch.tensor(skills, dtype=torch.long).to(device)
        )

    def __len__(self):
        return len(self.buffer)


In [11]:
def soft_update(target, source, tau):
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(t_param.data * (1.0 - tau) + s_param.data * tau)
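
To get a feel for `tau`, the same update rule can be applied to a single scalar. This is an illustrative sketch with a hypothetical `polyak` helper: after $n$ updates towards a fixed source, the target has closed a fraction $1 - (1 - \tau)^n$ of the gap.

```python
def polyak(target, source, tau):
    """One soft update of a scalar target towards source."""
    return target * (1.0 - tau) + source * tau


target, source, tau = 0.0, 1.0, 0.005
for _ in range(1000):
    target = polyak(target, source, tau)

closed_form = 1.0 - (1.0 - tau) ** 1000
print(target, closed_form)  # both about 0.9933
```

With `tau=0.005` the target network therefore needs on the order of a thousand updates to catch up with a change in the online critic.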

In [12]:
def train_step(agent, replay_buffer, batch_size=128, gamma=0.99, alpha=0.2):
    # NOTE: assumes a SAC-style agent exposing `policy` (with a `sample` method that
    # returns (action, log_prob)), `critic`, `target_critic`, and one optimizer per
    # component. DIAYNAgent above does not define these.
    if len(replay_buffer) < batch_size:
        return None

    batch = replay_buffer.sample(batch_size)
    # Order must match ReplayBuffer.sample: rewards come before next_obs.
    # `reward` is expected to hold the intrinsic reward log q(z|s), not the env reward.
    obs, action, reward, next_obs, done, skill = batch

    # Encode current and next observations
    latent_s = agent.encoder(obs)
    latent_next_s = agent.encoder(next_obs).detach()

    # === Critic update ===
    with torch.no_grad():
        next_action, logp_next = agent.policy.sample(latent_next_s, skill)
        target_q1, target_q2 = agent.target_critic(latent_next_s, next_action, skill)
        target_q = torch.min(target_q1, target_q2) - alpha * logp_next
        target = reward + gamma * (1 - done) * target_q

    # Gradients reach the encoder here; it is only updated if its parameters
    # are registered in critic_optimizer.
    q1, q2 = agent.critic(latent_s, action, skill)
    critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)

    agent.critic_optimizer.zero_grad()
    critic_loss.backward()
    agent.critic_optimizer.step()

    # === Policy update ===
    # Detach the latent: the encoder graph was freed by critic_loss.backward(),
    # so backpropagating through it a second time would raise a RuntimeError.
    latent_pi = latent_s.detach()
    action_pi, logp_pi = agent.policy.sample(latent_pi, skill)
    q1_pi, q2_pi = agent.critic(latent_pi, action_pi, skill)
    policy_loss = (alpha * logp_pi - torch.min(q1_pi, q2_pi)).mean()

    agent.policy_optimizer.zero_grad()
    policy_loss.backward()
    agent.policy_optimizer.step()

    # === Discriminator update ===
    logits = agent.skill_discriminator(latent_s.detach())
    disc_loss = F.cross_entropy(logits, skill)

    agent.disc_optimizer.zero_grad()
    disc_loss.backward()
    agent.disc_optimizer.step()

    # === Target Network Update ===
    soft_update(agent.target_critic, agent.critic, tau=0.005)

    return {
        "critic_loss": critic_loss.item(),
        "policy_loss": policy_loss.item(),
        "disc_loss": disc_loss.item()
    }


In [None]:
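def preprocess(obs):
    # Assumed helper (not defined elsewhere): HWC uint8 frame -> CHW float tensor
    return torch.as_tensor(obs, dtype=torch.float32, device=device).permute(2, 0, 1)

def evaluate_skills(env, agent, num_skills=10, episodes_per_skill=1, render=False):
    skill_trajectories = []

    for skill_id in range(num_skills):
        for _ in range(episodes_per_skill):
            obs, _ = env.reset()
            done = False
            episode = []
            skill = torch.tensor([skill_id], dtype=torch.long, device=device)

            while not done:
                obs_tensor = preprocess(obs).unsqueeze(0)
                # agent.act encodes the frame, embeds the skill ID, and runs the
                # policy without tracking gradients
                action = agent.act(obs_tensor, skill).cpu().numpy()[0]
                # Gymnasium reports termination and truncation separately
                obs, _, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                episode.append(obs)

                if render:
                    env.render()

            skill_trajectories.append((skill_id, episode))

    return skill_trajectories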

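def discriminator_accuracy(agent, dataloader):
    # Fraction of states whose skill the discriminator identifies correctly;
    # chance level is 1 / num_skills.
    total, correct = 0, 0
    with torch.no_grad():  # evaluation only, no gradients needed
        for obs_batch, skill_batch in dataloader:
            latent = agent.encoder(obs_batch.to(device))
            logits = agent.skill_discriminator(latent)
            pred = torch.argmax(logits, dim=1)
            correct += (pred == skill_batch.to(device)).sum().item()
            total += len(skill_batch)
    return correct / total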