# Trainer

> PPO-style trainer for RLHF: clipped policy-gradient loss, value-function loss and entropy bonus computed against a reference model.

In [None]:
#| default_exp trainer

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import pytorch_lightning as pl 

In [None]:
#| export
from typing import Callable, Tuple

import torch
from torchtyping import TensorType
from einops import rearrange

$L_t^{C L I P+V F+S}(\theta)=\hat{\mathbb{E}}_t\left[L_t^{C L I P}(\theta)-c_1 L_t^{V F}(\theta)+c_2 S\left[\pi_\theta\right]\left(s_t\right)\right]$

$L^{C L I P}(\theta)=\hat{\mathbb{E}}_t\left[\min \left(r_t(\theta) \hat{A}_t, \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t\right)\right]$

$\log \frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text {old }}}\left(a_t \mid s_t\right)} = \log \pi_\theta\left(a_t \mid s_t\right) - \log \pi_{\theta_{\text {old }}}\left(a_t \mid s_t\right)$

$r_t(\theta)=\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text {old }}}\left(a_t \mid s_t\right)}$

In [None]:
#| export
class RLHFTrainer:
    def __init__(
        self, model: Callable, ref_model: Callable, config
    ):
        self.model = model
        self.ref_model = ref_model
        self.epsilon = config.epsilon
        self.ent_coef = config.ent_coef
        self.vf_coef = config.vf_coef
    
    @classmethod
    def compute_advantage_and_return(
        self,
        rewards: TensorType["batch_size"],
        values: TensorType["batch_size"]
    ) -> Tuple[TensorType["batch_size"], TensorType["batch_size"]]:
        # copied from https://github.com/lvwerra/trl/blob/d2e8bcf8373726fb92d2110c500f7df6d0bd566d/trl/trainer/ppo_trainer.py#L686
        # TODO: understand this!!!!
        
        rewards = rearrange(rewards, 'b -> 1 b')
        values = rearrange(values, 'b -> 1 b')
        
        lastgaelam = 0
        advantages_reversed = []
        gen_len = len(rewards)
        
        gamma = 1
        lam = 0.95

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)

        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
        returns = advantages + values

        advantages = rearrange(advantages, '1 b -> b')
        returns = rearrange(returns, '1 b -> b')
        
        return advantages, returns

    def compute_loss(
        self,
        queries: TensorType["batch_size", "seq_len"],
        responses: TensorType["batch_size", "seq_len"],
        rewards: TensorType["batch_size"],
    ) -> TensorType["1"]:
        logprobs, values, entropies, ref_logprobs = self.forward_batch(queries, responses)
        # loss = self.loss(logprobs, ref_logprobs, values)
        
        ratio = (logprobs - ref_logprobs).exp()
        clipped_ratio = torch.clamp(ratio, min=1-self.epsilon, max=1+self.epsilon)
        
        # TODO: write the advantages
        # advantages = rewards
        # returns = rewards
        advantages, returns = self.compute_advantage_and_return(rewards, values)
        
        pg_loss_1 = ratio * advantages
        pg_loss_2 = ratio * clipped_ratio
        
        pg_loss = torch.min(pg_loss_1, pg_loss_2).mean()
        
        value_loss = (values - returns).pow(2).mean()
        
        loss = pg_loss - self.ent_coef * entropies.mean() + self.vf_coef * value_loss
        
        return loss
    
    def forward_batch(
        self,
        queries: TensorType["batch_size", "seq_len"],
        responses: TensorType["batch_size", "seq_len"]
    ) -> Tuple[
        TensorType["batch_size"], TensorType["batch_size"],
        TensorType["batch_size"], TensorType["batch_size"]
    ]:
        inputs = torch.cat([queries, responses], dim=1)
        
        _, logprobs, entropy, value = self.model(inputs)
        _, ref_logprob, _, _ = self.ref_model(inputs)
            
        return logprobs, entropy, value, ref_logprob
    
    def forward(
        self,
        input_ids: TensorType["batch", "seq_len", "n_dim"],
        attention_mask: TensorType["batch", "seq_len"]
    ) -> TensorType["batch", "log_probs"]:
        """Run both models without gradient tracking and compute a loss.

        `self.model` and `self.ref_model` are each called with
        `(input_ids, attention_mask)` inside `torch.no_grad()`, and the
        selected outputs are passed to `self.loss`.

        NOTE(review): `self.loss` is not defined in the visible part of this
        class, and the visible lines end without a `return` even though the
        signature annotates a tensor return value -- confirm whether this
        method is unfinished.
        """
        
        with torch.no_grad():
            # Each model output is unpacked as a 4-tuple; per the original
            # author's note below, the layout is presumably:
            # action_logits, action_logprobs, entropy, value
            _, logprobs, entropy, value = self.model(input_ids, attention_mask)
            # Only the second element of the reference model's output is kept.
            _, ref_logprob, _, _ = self.ref_model(input_ids, attention_mask)
        
        # NOTE(review): everything above was computed under no_grad, so these
        # inputs carry no gradient -- confirm that is intended for a loss.
        loss = self.loss(logprobs, entropy, value, ref_logprob)
        