# Reward Model

> Reward Model and Pairwise Loss function

In [None]:
#| default_exp reward

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
from torchtyping import TensorType

### Reward Model

In [None]:
#| export
class RewardModel(nn.Module):
    """Reward model: a pretrained transformer backbone with a scalar reward head.

    The reward of a sequence is computed from the backbone's hidden state at
    the last attended (non-padding) token, passed through dropout, a linear
    projection to a single unit and a sigmoid, so each reward lies in (0, 1).
    """
    def __init__(
        self, checkpoint: str, # `transformers`'s model path
        dropout: float = 0.1 # Dropout probability applied before the linear reward layer
    ):
        super().__init__()
        self.model = AutoModel.from_pretrained(checkpoint)

        config = self.model.config
        # Prefer the architecture-agnostic `hidden_size`; fall back to the
        # GPT-2-style `n_embd` for configs that only define that name.
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = config.n_embd

        # custom head
        self.reward_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        input_ids: TensorType["batch_size", "seq_len"],
        attention_mask: TensorType["batch_size", "seq_len"] = None,
    ) -> TensorType["batch_size"]: # A reward scalar for each item in a batch
        """Calculate reward for each item in a batch."""
        last_hidden_state = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        if attention_mask is None:
            # No padding information: the final position is the last token.
            final_hidden = last_hidden_state[:, -1]
        else:
            # With right padding the final position is a pad token, so pick
            # the highest position whose mask is non-zero for every row.
            # (Left-padded rows still resolve to the final position.)
            attended = attention_mask.to(last_hidden_state.device).ne(0).long()
            positions = torch.arange(attended.shape[1], device=attended.device)
            last_index = (attended * positions).argmax(dim=1)
            batch_index = torch.arange(attended.shape[0], device=attended.device)
            final_hidden = last_hidden_state[batch_index, last_index]

        # The head acts position-wise, so applying it to the selected hidden
        # state only avoids scoring positions that would be thrown away.
        return self.reward_head(final_hidden)[:, 0]

In [None]:
show_doc(RewardModel)

---

### RewardModel

>      RewardModel (checkpoint:str, dropout:float=0.1)

Reward model.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| checkpoint | str |  | `transformers`'s model path |
| dropout | float | 0.1 |  |

In [None]:
show_doc(RewardModel.forward)

---

### RewardModel.forward

>      RewardModel.forward (input_ids:typing.Annotated[torch.Tensor,{'__torchtyp
>                           ing__':True,'details':('batch_size','seq_len',),'cls
>                           _name':'TensorType'}], attention_mask:typing.Annotat
>                           ed[torch.Tensor,{'__torchtyping__':True,'details':('
>                           batch_size','seq_len',),'cls_name':'TensorType'}])

Calculate reward for each item in a batch.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| input_ids | Annotated |  |
| attention_mask | Annotated |  |
| **Returns** | **Annotated** | **A reward scalar for each item in a batch** |

### Pairwise Loss

$\operatorname{loss}(\theta)=-\frac{1}{\left(\begin{array}{c}
K \\
2
\end{array}\right)} E_{\left(x, y_w, y_l\right) \sim D}\left[\log \left(\sigma\left(r_\theta\left(x, y_w\right)-r_\theta\left(x, y_l\right)\right)\right)\right]$

In [None]:
#| export
class PairwiseLoss(nn.Module):
    """Pairwise ranking loss: ``-mean(log(sigmoid(chosen - rejected)))``.

    The loss is small when the chosen reward exceeds the rejected one and
    equals ``log(2)`` when the two rewards are identical.
    """
    def forward(
        self,
        chosen_rewards: TensorType["batch_size", 1], # The reward of the chosen prompt
        rejected_rewards: TensorType["batch_size", 1] # The reward of the rejected prompt
    ) -> TensorType[1]: # A scalar loss
        """Forward pass.

        Raises `ValueError` when the two reward tensors differ in shape.
        """
        # Compare full shapes, not just lengths: (B,) against (B, 1) would
        # otherwise broadcast silently into a (B, B) matrix of differences.
        if chosen_rewards.shape != rejected_rewards.shape:
            raise ValueError(
                "chosen_rewards and rejected_rewards must have the same shape, "
                f"got {tuple(chosen_rewards.shape)} and {tuple(rejected_rewards.shape)}"
            )

        # log-probability that the chosen item is preferred; `logsigmoid`
        # is numerically stable for large negative differences
        log_probs = nn.functional.logsigmoid(chosen_rewards - rejected_rewards)

        # `.mean()` already averages over the batch; no further division
        return -log_probs.mean()

In [None]:
show_doc(PairwiseLoss.forward)

---

### PairwiseLoss.forward

>      PairwiseLoss.forward (chosen_rewards:typing.Annotated[torch.Tensor,{'__to
>                            rchtyping__':True,'details':('batch_size',1,),'cls_
>                            name':'TensorType'}], rejected_rewards:typing.Annot
>                            ated[torch.Tensor,{'__torchtyping__':True,'details'
>                            :('batch_size',1,),'cls_name':'TensorType'}])

Forward pass.

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| chosen_rewards | Annotated | The reward of the chosen prompt |
| rejected_rewards | Annotated | The reward of the rejected prompt |
| **Returns** | **Annotated** | **A scalar loss** |