# Trainer

> PPO-style trainer for RLHF: a clipped policy-gradient loss combined with a value-function loss and an entropy bonus.

In [None]:
#| default_exp trainer

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # export this notebook's `#| export` cells via nbdev

In [None]:
#| export
import pytorch_lightning as pl 

In [None]:
#| export
from typing import Callable, Optional, Tuple

import torch
from torchtyping import TensorType

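$L_t^{\text{CLIP}+\text{VF}+S}(\theta)=\hat{\mathbb{E}}_t\left[L_t^{\text{CLIP}}(\theta)-c_1 L_t^{\text{VF}}(\theta)+c_2 S\left[\pi_\theta\right]\left(s_t\right)\right]$

This objective is maximized, so the quantity minimized during training is its negation. The coefficients $c_1$ and $c_2$ correspond to `config.vf_coef` and `config.ent_coef`, and the clip range $\epsilon$ corresponds to `config.epsilon`.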

$L^{\text{CLIP}}(\theta)=\hat{\mathbb{E}}_t\left[\min \left(r_t(\theta) \hat{A}_t, \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t\right)\right]$

$\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text{old}}}\left(a_t \mid s_t\right)} = \exp\left(\log \pi_\theta\left(a_t \mid s_t\right) - \log \pi_{\theta_{\text{old}}}\left(a_t \mid s_t\right)\right)$

$r_t(\theta)=\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text {old }}}\left(a_t \mid s_t\right)}$

In [None]:
#| export
class RLHFTrainer:
    """PPO-style RLHF trainer: clipped surrogate + value loss + entropy bonus.

    Args:
        model: policy model. It is called on token ids (optionally with an
            attention mask) and unpacked as a 4-tuple
            ``(action_logits, action_logprobs, entropy, value)``.
        ref_model: reference model with the same call/return convention; it is
            only ever run under ``torch.no_grad()``.
        config: object exposing ``epsilon`` (PPO clip range), ``ent_coef``
            (entropy-bonus weight, c_2) and ``vf_coef`` (value-loss weight, c_1).
    """

    def __init__(
        self, model: Callable, ref_model: Callable, config
    ):
        self.model = model
        self.ref_model = ref_model
        self.epsilon = config.epsilon
        self.ent_coef = config.ent_coef
        self.vf_coef = config.vf_coef

    def loss(
        self,
        action_logprobs: torch.Tensor,
        entropy: TensorType["batch_size"],
        values: TensorType["batch_size"],
        prev_logprobs: torch.Tensor,
        advantages: Optional[torch.Tensor] = None,
        returns: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the scalar PPO loss to minimise.

        loss = -L^CLIP + vf_coef * L^VF - ent_coef * mean(entropy)

        Args:
            action_logprobs: log-probs of the taken actions under the current policy.
            entropy: per-sample policy entropy.
            values: value-head predictions.
            prev_logprobs: log-probs of the same actions under the policy used
                for the ratio denominator (pi_old).
            advantages: advantage estimates A_t, broadcastable against the ratio.
            returns: value targets for the squared-error value loss,
                broadcastable against ``values``.

        Raises:
            ValueError: if ``advantages`` or ``returns`` is not provided.
        """
        if advantages is None:
            raise ValueError("advantages must be provided to compute the PPO policy loss")
        if returns is None:
            raise ValueError("returns must be provided to compute the value-function loss")

        # r_t = pi_theta / pi_old, computed as exp of a log difference for stability.
        ratio = (action_logprobs - prev_logprobs).exp()
        clipped_ratio = torch.clamp(ratio, min=1 - self.epsilon, max=1 + self.epsilon)

        unclipped_objective = ratio * advantages
        clipped_objective = clipped_ratio * advantages
        # L^CLIP is an objective to *maximise*; negate it so that minimising the
        # returned loss improves the clipped surrogate.
        pg_loss = -torch.min(unclipped_objective, clipped_objective).mean()

        entropy_bonus = entropy.mean()
        # L^VF = (V(s_t) - V_t^targ)^2
        value_loss = ((values - returns) ** 2).mean()

        return pg_loss - self.ent_coef * entropy_bonus + self.vf_coef * value_loss

    def step(
        self,
        queries: TensorType["batch_size", "seq_len"],
        responses: TensorType["batch_size", "seq_len"],
        rewards: TensorType["batch_size"],
    ):
        """Run the rollout forward pass for one batch of query/response pairs.

        Returns:
            The ``(logprobs, ref_logprobs, values)`` tuple from ``forward_batch``.
        """
        # TODO: turn `rewards` into advantages/returns and optimise `loss`;
        # `rewards` is not used yet.
        return self.forward_batch(queries, responses)

    def forward_batch(
        self,
        queries: TensorType["batch_size", "seq_len"],
        responses: TensorType["batch_size", "seq_len"]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score concatenated query+response ids with the policy and reference models.

        Both passes run under ``torch.no_grad()``, so the outputs carry no
        gradient.

        Returns:
            ``(logprobs, ref_logprobs, values)``.
        """
        # Queries and responses are joined along dim 1 (the sequence axis).
        inputs = torch.cat([queries, responses], dim=1)

        with torch.no_grad():
            # NOTE(review): no attention mask is passed here, unlike `forward`.
            _, logprobs, _, value = self.model(inputs)
            _, ref_logprob, _, _ = self.ref_model(inputs)

        return logprobs, ref_logprob, value

    def forward(
        self,
        input_ids: TensorType["batch", "seq_len", "n_dim"],
        attention_mask: TensorType["batch", "seq_len"],
        advantages: Optional[torch.Tensor] = None,
        returns: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the differentiable PPO loss for one batch.

        Args:
            input_ids: model inputs.
            attention_mask: mask passed to both models.
            advantages: forwarded to ``loss``.
            returns: forwarded to ``loss``.

        Returns:
            The scalar loss from ``loss``.
        """
        # The policy pass must record gradients, otherwise the loss cannot be
        # back-propagated; only the reference model is evaluated without them.
        _, logprobs, entropy, value = self.model(input_ids, attention_mask)
        with torch.no_grad():
            _, ref_logprob, _, _ = self.ref_model(input_ids, attention_mask)

        # NOTE(review): reference-model log-probs stand in for pi_old in the
        # ratio; PPO normally uses the rollout-time policy here — confirm intent.
        return self.loss(
            logprobs, entropy, value, ref_logprob,
            advantages=advantages, returns=returns,
        )
        