# GRPO from Scratch

> Grouped Relative Policy Optimization

GRPO is in fact just a simpler version of PPO without the reward function. So most 

In [2]:
from __future__ import annotations

import os
from dataclasses import dataclass

import numpy as np
import torch as t
from torch import nn, optim, Tensor
import einx

import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

import gymnasium as gym

device = t.device("mps")

In [3]:
@dataclass
class EnvConfig:
    envs: int = 10
    steps: int = 128
    generations: int = 4

    @property
    def tot_envs(self): return self.envs * self.generations

@dataclass
class TrainConfig:
    episodes: int = 10
    lr: float = 1e-3
    gamma: float = 0.999 # discount factor
    ent_coef = 1e-2 # entropy bonus (encourage exploration)

env_cfg = EnvConfig()
trn_cfg = TrainConfig()

In [131]:
class NoopResetWrapper(gym.vector.VectorWrapper):
    def step(
        self, actions: ActType,
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through each of the environments returning the batched results.

        Returns:
            The batched environment step results
        """
        actions = iterate(self.action_space, actions)

        infos = {}
        for i, action in enumerate(actions):
            if self.autoreset_mode == AutoresetMode.NEXT_STEP:
                if self._autoreset_envs[i]:
                    self._env_obs[i], env_info = self.envs[i].reset()

                    self._rewards[i] = 0.0
                    self._terminations[i] = False
                    self._truncations[i] = False
                else:
                    (
                        self._env_obs[i],
                        self._rewards[i],
                        self._terminations[i],
                        self._truncations[i],
                        env_info,
                    ) = self.envs[i].step(action)
            elif self.autoreset_mode == AutoresetMode.DISABLED:
                if self._autoreset_envs[i]: continue
                    
                (
                    self._env_obs[i],
                    self._rewards[i],
                    self._terminations[i],
                    self._truncations[i],
                    env_info,
                ) = self.envs[i].step(action)
            elif self.autoreset_mode == AutoresetMode.SAME_STEP:
                (
                    self._env_obs[i],
                    self._rewards[i],
                    self._terminations[i],
                    self._truncations[i],
                    env_info,
                ) = self.envs[i].step(action)

                if self._terminations[i] or self._truncations[i]:
                    infos = self._add_info(
                        infos,
                        {"final_obs": self._env_obs[i], "final_info": env_info},
                        i,
                    )

                    self._env_obs[i], env_info = self.envs[i].reset()
            else:
                raise ValueError(f"Unexpected autoreset mode, {self.autoreset_mode}")

            infos = self._add_info(infos, env_info, i)

        # Concatenate the observations
        self._observations = concatenate(
            self.single_observation_space, self._env_obs, self._observations
        )
        self._autoreset_envs = np.logical_or(self._terminations, self._truncations)

        return (
            deepcopy(self._observations) if self.copy else self._observations,
            np.copy(self._rewards),
            np.copy(self._terminations),
            np.copy(self._truncations),
            infos,
        )

In [132]:
_envs = gym.make_vec(
    "LunarLander-v3",
    num_envs=env_cfg.tot_envs,
    max_episode_steps=env_cfg.steps, 
    vectorization_mode="sync",
    vector_kwargs=dict(autoreset_mode=gym.vector.AutoresetMode.DISABLED),
)
envs = gym.wrappers.vector.RecordEpisodeStatistics(
    _envs,
    buffer_length=env_cfg.tot_envs * env_cfg.steps
)

envs = NoopResetWrapper(envs)

obs_shape = envs.single_observation_space.shape[0]
action_shape = envs.single_action_space.n

In [133]:
def loss_func(rewards, log_probs):
    reward_per_generation = einx.sum("(env gen) [step] -> env gen", rewards, gen=env_cfg.generations)
    mean_reward = einx.mean("env [gen]", reward_per_generation)
    advantages = einx.subtract(
        "env gen, env -> env gen",
        reward_per_generation,
        mean_reward,
    )
    
    # std_rewards = einx.std("n_envs [n_generations]", reward_per_generation) + 1e-8
    # advantages = einx.divide("[n_envs n_generations], [n_envs]", advantages, std_rewards)
    
    discounts = torch.tensor([gamma ** i for i in range(n_steps_per_update)][::-1], device=device)
    per_step_advantages = einx.multiply(
        "env gen, step -> env gen step",
        advantages,
        discounts,
    )
    
    # ratio = torch.exp(action_log_probs - action_log_probs)
    pg_loss1 = -einx.multiply("env gen step, (env gen) step -> env gen step", per_step_advantages, log_probs).mean()
    return pg_loss1

In [134]:

agent = nn.Sequential(
    nn.Linear(obs_shape, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, action_shape),
).to(device)

opt = optim.RMSprop(agent.parameters(), lr=trn_cfg.lr)


def sample(logits):
    """
    Samples an action from logits

    Args:
        logits: (tot_envs, actions)

    Returns:
        actions: Sampled action (tot_envs,)
        log_probs: Log-prob for sampled action (tot_envs,)
        entropy: Entropy of distributions (tot_envs,)
    """
    pd = t.distributions.Categorical(logits=logits) # softmax
    actions = pd.sample()
    log_probs = pd.log_prob(actions)
    entropy = pd.entropy()
    return actions, log_probs, entropy

def maybe_reset_envs(envs, seed, tt_mask):
    consts = einx.solve("(env gen)", tt_mask, gen=env_cfg.generations)
    
    seeds = seed + einx.rearrange("env -> (env gen)", np.arange(env_cfg.envs), **consts)
    
    envs_to_reset = einx.all("(env [gen])", tt_mask, **consts) # all generations done
    reset_mask = einx.rearrange("env -> (env gen)", envs_to_reset, **consts)

    if np.any(reset_mask):
        states, info = envs.reset(seed=seeds.tolist(), options=dict(reset_mask=reset_mask))
        next_seed = seeds.max() + 1
        return states, info, next_seed.item()
    else:
        return None

# reset all envs at the start
states, info, seed = maybe_reset_envs(envs, seed=0, tt_mask=np.ones(env_cfg.tot_envs))

ep_log_probs = t.zeros((env_cfg.tot_envs, env_cfg.steps), dtype=t.float32, device=device)
ep_rewards = t.zeros((env_cfg.tot_envs, env_cfg.steps), dtype=t.float32, device=device)
ep_masks = t.full((env_cfg.tot_envs, env_cfg.steps), False, dtype=t.bool, device=device)
for step in range(env_cfg.steps):
    # sample next action
    logits = agent(t.from_numpy(states).to(device))
    actions, log_probs, entropy = sample(logits)

    # get next state from actions
    states, rewards, terminated, truncated, info = envs.step(actions.cpu().numpy())
    tt_mask = t.from_numpy(terminated | truncated).to(device)

    ep_log_probs[:, step] = log_probs
    ep_rewards[:, step] = t.from_numpy(rewards).float().to(device)
    ep_masks[:, step] = tt_mask

    res = maybe_reset_envs(envs, seed, tt_mask.cpu().numpy())
    if res is not None:
        states, info, seed = res

loss_func(ep_rewards, ep_log_probs)


NameError: name 'iterate' is not defined