# Reward Model

> Reward model, pairwise ranking loss, and a PyTorch Lightning training wrapper.

In [None]:
#| default_exp reward

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
from typing import Callable, Union, List

import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
import pytorch_lightning as pl 
from transformers import AutoModel, AutoTokenizer
from einops import rearrange
from torchtyping import TensorType

from instruct_goose.utils import load_yaml

  from .autonotebook import tqdm as notebook_tqdm


### Reward Model

In [None]:
#| export
class RewardModel(nn.Module):
    """Pretrained transformer backbone with a scalar reward head.

    The backbone is loaded with `AutoModel.from_pretrained(checkpoint)`. The
    head is dropout -> linear(hidden, 1) -> sigmoid, so every reward lies in
    the open interval (0, 1).

    Args:
        checkpoint: model name or path passed to `AutoModel.from_pretrained`.
        dropout: dropout probability applied before the linear projection.
    """
    def __init__(self, checkpoint: str, dropout: float = 0.1):
        super().__init__()
        self.model = AutoModel.from_pretrained(checkpoint)

        config = self.model.config
        # GPT-2 style configs name the hidden width `n_embd`; fall back to
        # `hidden_size` so other backbones can be used as well.
        n_embed = getattr(config, "n_embd", None)
        if n_embed is None:
            n_embed = config.hidden_size

        # custom head
        self.reward_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(n_embed, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        input_ids: TensorType["batch_size", "seq_len"],
        attention_mask: TensorType["batch_size", "seq_len"],
    ) -> TensorType["batch_size"]:
        """Return one reward per sequence in the batch.

        The reward is read from the hidden state of the last token that
        `attention_mask` marks as attended (non-zero), so padding tokens are
        never used as the summary position. A row whose mask is entirely zero
        falls back to the final position.

        Args:
            input_ids: token ids, one row per sequence.
            attention_mask: non-zero for real tokens, zero for padding.

        Returns:
            A 1-D tensor holding one reward per batch item.
        """
        last_hidden_state = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # For each item in the batch, locate the last attended token.
        # argmax over the reversed mask gives the distance of that token from
        # the end of the sequence (0 when the sequence is not right-padded).
        seq_len = attention_mask.shape[1]
        offset_from_end = (attention_mask != 0).long().flip(dims=[1]).argmax(dim=1)
        last_token_idx = seq_len - 1 - offset_from_end

        batch_idx = torch.arange(
            last_hidden_state.shape[0], device=last_hidden_state.device
        )
        last_token_state = last_hidden_state[batch_idx, last_token_idx]

        # The head acts on each position independently, so it only needs to
        # run on the selected hidden state rather than on the whole sequence.
        reward_scalar = self.reward_head(last_token_state).squeeze(-1)

        return reward_scalar

### Loss function

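$$\operatorname{loss}(\theta) = -\frac{1}{\binom{K}{2}} \, E_{(x, y_w, y_l) \sim D}\left[\log\left(\sigma\left(r_\theta(x, y_w) - r_\theta(x, y_l)\right)\right)\right]$$

Here $r_\theta(x, y)$ is the scalar reward for prompt $x$ and completion $y$, $y_w$ is the preferred completion, $y_l$ is the rejected one, $\sigma$ is the sigmoid function, and $K$ is the number of ranked completions per prompt.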

In [None]:
#| export
class PairwiseLoss(nn.Module):
    """Pairwise ranking loss for reward-model training.

    Computes `-mean(log(sigmoid(chosen - rejected)))` over the batch. The loss
    is always positive, equals `log(2)` when both rewards are equal, and
    approaches zero as the chosen reward exceeds the rejected one.
    """
    def forward(self, chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
        """Return the scalar loss for a batch of (chosen, rejected) reward pairs.

        Args:
            chosen_rewards: rewards of the preferred completions.
            rejected_rewards: rewards of the rejected completions, paired
                element-wise with `chosen_rewards`.

        Returns:
            A 0-dim tensor holding the mean negative log-likelihood.
        """
        assert len(chosen_rewards) == len(rejected_rewards)

        # logsigmoid is the numerically stable form of log(sigmoid(x)).
        # `.mean()` already averages over the pairs, so no further division
        # by the batch size is needed.
        log_probs = F.logsigmoid(chosen_rewards - rejected_rewards)
        return -log_probs.mean()

### Trainer

In [None]:
#| export
class LitRewardModel(pl.LightningModule):
    """Lightning wrapper that trains a reward model on preference pairs.

    Stores the given `model`, `loss_func` and learning rate `lr`; each
    training step scores the chosen and rejected inputs with `model` and
    feeds both score tensors to `loss_func`.
    """
    def __init__(
        self, model: Callable, loss_func: Callable,
        lr: Union[int, float] = 1e-3
    ):
        super().__init__()
        self.model, self.loss_func, self.lr = model, loss_func, lr

    def training_step(self, batch, batch_idx: int):
        """Compute the loss for one batch.

        `batch` must unpack into exactly four items, in this order: chosen
        input ids, chosen attention mask, rejected input ids, rejected
        attention mask.
        """
        chosen_ids, chosen_mask, rejected_ids, rejected_mask = batch

        # Score the chosen samples first, then the rejected ones.
        reward_chosen = self.model(chosen_ids, chosen_mask)
        reward_rejected = self.model(rejected_ids, rejected_mask)

        return self.loss_func(reward_chosen, reward_rejected)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return optim.Adam(self.model.parameters(), lr=self.lr)