# Reward Model

## Model architecture

![reward](./llama2reward.png)

> 3.2.2 Reward Modeling
> 
> The model architecture and hyper-parameters are identical to those
> of the pretrained language models, except that the classification head for next-token prediction is replaced
> with a regression head for *outputting a scalar reward*.

```python
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification

torch.manual_seed(1)

# Build a tiny Llama with toy sizes, save it, then reload it with a scalar regression head
config = LlamaConfig(vocab_size = 100,      # default is 32000
                    hidden_size = 256,
                    intermediate_size = 512,
                    num_hidden_layers = 2,
                    num_attention_heads = 4,
                    num_key_value_heads = 4,
                    )
model = LlamaForCausalLM(config)
model.save_pretrained('./lm_pretrained')
# num_labels=1: the next-token head is replaced by a one-output `score` head (a single scalar reward)
rm_model = LlamaForSequenceClassification.from_pretrained('./lm_pretrained', num_labels=1)

print(rm_model)
```

```text
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at ./lm_pretrained and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
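The warning shows that the only newly initialized weight is `score.weight`, the regression head from the quote above. Below is a minimal plain-Python sketch of what that head computes, assuming (as far as I know the transformers implementation) a bias-free linear layer applied to the final hidden state of the last non-padding token, or of the last position when no `pad_token_id` is set. The function names and toy numbers are hypothetical.

```python
def last_non_pad_index(input_ids, pad_token_id):
    # index of the last token that is not padding (assumes right padding)
    idx = len(input_ids) - 1
    while idx > 0 and input_ids[idx] == pad_token_id:
        idx -= 1
    return idx


def reward_head(hidden_states, input_ids, weight, pad_token_id=None):
    # hidden_states: one vector of size hidden_size per token position (from the backbone)
    # weight: the score head, a single row of size hidden_size, no bias
    if pad_token_id is None:
        pos = len(input_ids) - 1  # no padding info: use the last position
    else:
        pos = last_non_pad_index(input_ids, pad_token_id)
    h = hidden_states[pos]
    # dot product -> one scalar reward for the whole sequence
    return sum(w * x for w, x in zip(weight, h))


hidden = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # toy hidden states, hidden_size = 2
print(reward_head(hidden, [7, 8, 0], [2.0, 3.0], pad_token_id=0))  # pools position 1 -> 3.0
```

Because only the last relevant position is scored, batched scoring with padding depends on `pad_token_id` being set; otherwise the head would read a padding position.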


## Training with a margin loss

$L=-\log(\sigma(r_{\theta}(x,y_c)-r_{\theta}(x,y_r)-m(r)))$

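where $r_{\theta}(x,y)$ is the scalar reward for prompt $x$ and response $y$, $y_c$ is the response the annotator preferred (chosen), $y_r$ is the rejected one, and the margin $m(r)$ is a discrete function of the preference rating $r$: the stronger the preference, the larger the margin.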
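The loss can be checked by hand. A plain-Python sketch (no torch) of the margin loss follows. The rating-to-margin table is an assumption: only 3.0 for "significantly better" is used in the notebook cell; the other values are recalled from the "margin large" variant in the Llama 2 paper and should be checked against it.

```python
import math


def neg_log_sigmoid(z: float) -> float:
    # -log(sigmoid(z)) = log(1 + exp(-z)), split by sign so exp never overflows
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


# ASSUMPTION: "margin large" values recalled from the Llama 2 paper;
# only 3.0 for "significantly better" appears in the notebook.
MARGIN_LARGE = {
    "significantly_better": 3.0,
    "better": 2.0,
    "slightly_better": 1.0,
    "negligibly_better_or_unsure": 0.0,
}


def margin_ranking_loss(r_chosen: float, r_rejected: float, margin: float = 0.0) -> float:
    # L = -log(sigmoid(r_c - r_r - m(r)))
    return neg_log_sigmoid(r_chosen - r_rejected - margin)


# chosen response scored 1.0 higher than the rejected one
print(margin_ranking_loss(2.0, 1.0))                                        # no margin, ~0.313
print(margin_ranking_loss(2.0, 1.0, MARGIN_LARGE["significantly_better"]))  # margin 3.0, ~2.127
```

With a margin of 3.0 the same reward gap of 1.0 costs far more, so for strongly preferred pairs the model is pushed to separate the two scores by more than the margin.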

```python
X_chosen = torch.randint(0, 100, (1, 10))    # toy token ids: prompt + chosen response y_c
X_rejected = torch.randint(0, 100, (1, 10))  # toy token ids: prompt + rejected response y_r
margin = 3.0  # "margin large" value for a "significantly better" rating

idx = {}
idx['input_ids'] = X_chosen
rm_chosen = rm_model(**idx).logits    # r_theta(x, y_c), shape (1, 1)

idx['input_ids'] = X_rejected
rm_rejected = rm_model(**idx).logits  # r_theta(x, y_r), shape (1, 1)

# -log(sigmoid(z)) computed with F.logsigmoid, which stays numerically stable for large |z|
loss = -F.logsigmoid(rm_chosen - rm_rejected)
loss_with_margin = -F.logsigmoid(rm_chosen - rm_rejected - margin)

print(f'chosen reward: {rm_chosen.item()}')
print(f'rejected reward: {rm_rejected.item()}')
print(f'reward model loss: {loss.item()}')
print(f'reward model loss with margin: {loss_with_margin.item()}')
```

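## Inference: reward score

```python
x = torch.randint(0, 100, (1, 10))  # toy token ids for one prompt + response
rm_model.eval()
with torch.no_grad():
    rm_score = rm_model(input_ids=x).logits  # shape (1, 1): one scalar reward for x
print(x)
print(x.shape)
print('reward result:', rm_score.item())
```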

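## Combining the two reward models

$$ R_c(g|p) = \begin{cases} R_s(g|p) & \text{if } \text{is\_safety}(p) \text{ or } R_s(g|p) < 0.15 \\ R_h(g|p) & \text{otherwise} \end{cases} $$

Here $p$ is the prompt, $g$ the generation, $R_s$ the safety reward model and $R_h$ the helpfulness reward model; $\text{is\_safety}(p)$ is true for prompts tagged as potentially eliciting unsafe responses.

```python
def llama2_reward_select(reward_safety, reward_helpfulness, is_safety=False, threshold=0.15):
    # safety-tagged prompt, or a low safety score: use the safety reward
    if is_safety or reward_safety < threshold:
        return reward_safety
    # otherwise use the helpfulness reward
    return reward_helpfulness

rc = llama2_reward_select(reward_safety=-0.3, reward_helpfulness=0.7)
print(rc)  # -0.3: safety score is below the threshold
rc = llama2_reward_select(reward_safety=1.3, reward_helpfulness=0.4)
print(rc)  # 0.4: safe enough, so the helpfulness score is used
rc = llama2_reward_select(reward_safety=1.3, reward_helpfulness=0.4, is_safety=True)
print(rc)  # 1.3: safety-tagged prompt always uses the safety score
```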

## Inverse sigmoid (logit)

> We also find it important to whiten
> the final linear scores (shown here by reversing the sigmoid with the logit function) in order to increase
> stability and balance properly with the KL penalty term (β) above.

$$\hat{R}_c(g|p)=\text{WHITEN}({\color{red}{\text{LOGIT}}}(R_c(g|p)))$$
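Why the logit: if a reward score $q$ is read as the sigmoid $\sigma(z)$ of a raw linear score $z$, the logit undoes the sigmoid exactly and returns the unbounded linear score that whitening then rescales:

$$\text{LOGIT}(q)=\log\frac{q}{1-q},\qquad \text{LOGIT}(\sigma(z))=\log\frac{1/(1+e^{-z})}{e^{-z}/(1+e^{-z})}=\log e^{z}=z$$

For example, $q=0.9$ gives $\log 9\approx 2.197$, $q=0.5$ gives $0$, and $q=0.01$ gives $-\log 99\approx -4.595$.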

```python
# In practice the reward model already outputs an unbounded scalar, so this step is not needed;
# the code below only reproduces the formula.
def inverse_sigmoid(x):
    # logit(x) = log(x / (1 - x)); equivalent to torch.logit(x)
    return torch.log(x / (1 - x))

for p in (0.9, 0.5, 0.01):
    sigmoid_output = torch.tensor([p])
    inverse_sigmoid_output = inverse_sigmoid(sigmoid_output)
    print("inverse sigmoid output:", inverse_sigmoid_output)
```

## Whiten

$$\hat{R}_c(g|p)={\color{red}{\text{WHITEN}}}({\text{LOGIT}}(R_c(g|p)))$$

```python
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    # rescale to zero mean and unit variance (torch.var is unbiased by default)
    mean, var = torch.mean(values), torch.var(values)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        # keep the original mean, only rescale the spread
        whitened += mean
    return whitened

values = torch.tensor([[0.8300, 1.2000, 3.3000, 4.6000]])
values_whiten = whiten(values)
print('before whitening:', values)
print('after whitening:', values_whiten)

# the same values shifted by +100 whiten to (approximately) the same result
values = torch.tensor([[100.8300, 101.2000, 103.3000, 104.6000]])
values_whiten = whiten(values)
print('before whitening:', values)
print('after whitening:', values_whiten)
```
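Chaining the last three steps for a batch of generations gives $\hat{R}_c=\text{WHITEN}(\text{LOGIT}(R_c))$. Below is a plain-Python sketch that assumes the safety and helpfulness scores are sigmoid outputs in (0, 1), so that both the 0.15 threshold and the logit apply on that scale, and that whitening is done across the batch. Both are assumptions, and the function names and batch values are hypothetical.

```python
import math
from statistics import mean, variance


def select_reward(r_safety, r_helpful, is_safety_prompt, threshold=0.15):
    # R_c: safety score for safety-tagged prompts or low safety scores, else helpfulness
    return r_safety if (is_safety_prompt or r_safety < threshold) else r_helpful


def logit(q, eps=1e-6):
    q = min(max(q, eps), 1 - eps)  # clamp so q = 0 or q = 1 does not blow up
    return math.log(q / (1 - q))


def whiten(xs, eps=1e-8):
    mu, var = mean(xs), variance(xs)  # sample variance, like torch.var's default
    return [(x - mu) / math.sqrt(var + eps) for x in xs]


def final_scores(batch):
    # batch: list of (r_safety, r_helpful, is_safety_prompt) tuples, one per generation
    selected = [select_reward(s, h, flag) for s, h, flag in batch]
    return whiten([logit(r) for r in selected])


batch = [(0.9, 0.7, False), (0.1, 0.8, False), (0.6, 0.3, True), (0.95, 0.5, False)]
print(final_scores(batch))  # selected rewards 0.7, 0.1, 0.6, 0.5, then logit, then whiten
```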


## KL penalty 

$$R(g|p)=\hat{R}_c(g|p)-\color{red}{\beta D_{KL}(\pi_{\theta}(g|p)||\pi_{0}(g|p))}$$
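For discrete distributions the penalty is $D_{KL}(\pi_{\theta}\|\pi_{0})=\sum_i \pi_{\theta}(i)\log\frac{\pi_{\theta}(i)}{\pi_{0}(i)}$. Below is a plain-Python sketch with hypothetical names. The sampled-token estimator (summing $\log\pi_{\theta}-\log\pi_{0}$ over the generated tokens) is a common way RLHF code approximates the sequence-level KL; it is an assumption here, not something the text specifies.

```python
import math


def kl_divergence(p, q):
    # KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def sampled_kl_estimate(logp_policy, logp_ref):
    # ASSUMPTION: Monte Carlo estimate of the sequence KL from the log-probs that
    # the policy and the reference model give to the tokens actually generated
    return sum(a - b for a, b in zip(logp_policy, logp_ref))


def penalized_reward(r_hat, kl, beta=0.01):
    # R(g|p) = R_hat_c(g|p) - beta * KL
    return r_hat - beta * kl


p_policy = [0.9, 0.1]  # toy next-token distribution of pi_theta
p_ref = [0.5, 0.5]     # toy next-token distribution of pi_0
kl = kl_divergence(p_policy, p_ref)
print(kl)                         # ≈ 0.3681
print(penalized_reward(1.0, kl))  # ≈ 0.99632
```

The KL term is zero when the policy matches the reference and grows as it drifts, so $\beta$ sets how much reward the policy gives up for moving away from $\pi_0$.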

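```python
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM

policy_model = LlamaForCausalLM(config)  # actor pi_theta being optimized
ref_model = LlamaForCausalLM(config)     # frozen reference policy pi_0 (the "old" model)
policy_model.eval()
ref_model.eval()

X = torch.randint(0, 100, (1, 10))  # toy token ids for prompt p + generation g

with torch.no_grad():
    # next-token log-probabilities at the last position, shape (1, vocab_size)
    logp_policy = F.log_softmax(policy_model(input_ids=X).logits[:, -1, :], dim=-1)
    logp_ref = F.log_softmax(ref_model(input_ids=X).logits[:, -1, :], dim=-1)

# probability each policy assigns to one sampled token
token = torch.randint(0, 100, (1, 1))
print('token index:', token.item())
print('reference policy prob:', logp_ref.gather(1, token).exp().item())
print('policy prob:', logp_policy.gather(1, token).exp().item())

# KL(pi_theta || pi_0) at this one position (toy: a full generation would sum over its tokens)
# F.kl_div(input, target, log_target=True) = sum(exp(target) * (target - input))
kl = F.kl_div(logp_ref, logp_policy, log_target=True, reduction='sum')
print('kl penalty:', kl.item())

# final reward: R = R_hat_c - beta * KL
beta = 0.01
with torch.no_grad():
    rm_score = rm_model(input_ids=X).logits  # score the same sequence
rm_ppo = rm_score - beta * kl
print('rm_score:', rm_score.item())
print('beta:', beta)
print('rm_score with kl:', rm_ppo.item())
```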