# Agent

> Policy model, RLHF training objective, and PyTorch Lightning wrapper for the agent.

In [None]:
#| default_exp agent

In [None]:
#| hide
# Regenerate the library's Python modules from the project notebooks (nbdev).
import nbdev
nbdev.nbdev_export()

In [None]:
#| export
from typing import Callable

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, AutoModelForCausalLM

In [None]:
#| export
class Agent(nn.Module):
    """Language-model policy that generates continuations for a batch of prompts.

    Args:
        checkpoint: Name or path of a pretrained Hugging Face checkpoint, passed
            to ``AutoModelForCausalLM.from_pretrained``.
    """

    def __init__(self, checkpoint):
        super().__init__()
        # `generate` needs a language-model head. The bare `AutoModel` backbone
        # has none and cannot be used for generation, so load the causal-LM
        # variant instead.
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        **generate_kwargs,
    ):
        """Generate continuations of ``input_ids`` with the wrapped model.

        Args:
            input_ids: Prompt token ids.
            attention_mask: Attention mask for ``input_ids``, forwarded to ``generate``.
            **generate_kwargs: Extra options forwarded unchanged to
                ``self.model.generate`` (e.g. ``max_new_tokens``).

        Returns:
            Whatever ``self.model.generate`` returns (by default, generated token ids).
        """
        # NOTE(review): AgentObjective applies softmax to `model(...)`, i.e. it
        # expects logits, not the token ids returned here. Confirm which model
        # object is passed there.
        output = self.model.generate(
            input_ids, attention_mask=attention_mask, **generate_kwargs
        )
        return output

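The agent is trained to maximize the RLHF objective below, where $r_\theta$ is the reward model, $\pi_\phi^{\mathrm{RL}}$ is the policy being trained, $\pi^{\mathrm{SFT}}$ is the supervised fine-tuned reference model, $\beta$ weights the penalty on the log-ratio between the two policies, and $\gamma$ weights the log-likelihood of pretraining data:

$$\begin{aligned} \operatorname{objective}(\phi)= & E_{(x, y) \sim D_{\pi_\phi^{\mathrm{RL}}}}\left[r_\theta(x, y)-\beta \log \left(\pi_\phi^{\mathrm{RL}}(y \mid x) / \pi^{\mathrm{SFT}}(y \mid x)\right)\right]+ \\ & \gamma E_{x \sim D_{\text {pretrain }}}\left[\log \left(\pi_\phi^{\mathrm{RL}}(x)\right)\right]\end{aligned}$$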

In [None]:
#| export
class AgentLoss(nn.Module):
    """Placeholder loss for the agent; not implemented yet.

    ``forward`` currently ignores its inputs and returns ``None``.
    """

    def __init__(self):
        super(AgentLoss, self).__init__()

    def forward(self, chosen_rewards, rejected_reward):
        # TODO: implement. The argument names suggest a pairwise
        # (chosen vs. rejected) comparison; confirm the intended formulation.
        return None

In [None]:
#| export
class AgentObjective(nn.Module):
    """RLHF objective for the agent policy (a quantity to maximize).

    Combines three terms: the reward-model score, minus ``beta`` times the
    log-ratio between the policy and the SFT reference model, plus ``gamma``
    times the policy's log-probabilities.

    Args:
        model: Policy being trained. It is called as ``model(input_ids, attention_mask)``
            and its output is normalized over the last dimension, so it is
            expected to return logits (presumably over the vocabulary).
        sft_model: Reference (SFT) model with the same call signature and output
            layout as ``model``.
        reward_model: Called as ``reward_model(input_ids, attention_mask)``; may
            return a tensor or a Python number.
        gamma: Weight of the log-likelihood term.
        beta: Weight of the policy/SFT log-ratio penalty.
    """

    def __init__(
        self, model: Callable, sft_model: Callable, reward_model: Callable,
        gamma: float, beta: float
    ):
        super().__init__()
        self.model = model
        self.sft_model = sft_model
        self.reward_model = reward_model
        # `forward` reads both weights; they must be stored on the instance.
        self.gamma = gamma
        self.beta = beta

    def forward(self, input_ids, attention_mask):
        """Compute the scalar objective for a batch (higher is better).

        Args:
            input_ids: Token ids fed to both the policy and the SFT model.
            attention_mask: Mask passed alongside ``input_ids``.

        Returns:
            A 0-dim tensor equal to
            ``mean(reward) - beta * mean(log_ratio) + gamma * mean(policy_log_probs)``.
        """
        model_logits = self.model(input_ids, attention_mask)
        # TODO: these should hold the sequences generated by the policy; until
        # they are implemented the reward model is called with None.
        model_input_ids = None
        model_attention_mask = None
        # log_softmax is mathematically log(softmax(x)) but stays finite for
        # low-probability entries, where softmax underflows to 0 and log gives -inf.
        model_log_probs = F.log_softmax(model_logits, dim=-1)

        # NOTE(review): gradients also flow into sft_model here. If it is meant
        # to stay frozen, wrap this call in torch.no_grad() or freeze its parameters.
        sft_logits = self.sft_model(input_ids, attention_mask)
        sft_log_probs = F.log_softmax(sft_logits, dim=-1)

        reward_score = self.reward_model(model_input_ids, model_attention_mask)
        reward_score = torch.as_tensor(
            reward_score, dtype=model_log_probs.dtype, device=model_log_probs.device
        )

        # log(pi_RL / pi_SFT), computed as a difference of log-probabilities.
        # NOTE(review): this is the unweighted mean over every output entry, not
        # the log-ratio of the sampled response y in the objective formula. Confirm intent.
        log_ratio = model_log_probs - sft_log_probs

        # Log-likelihood ("coherence") term of the generated text.
        # NOTE(review): the formula evaluates this on pretraining data; here it
        # uses the current batch.
        coherent = model_log_probs

        # Reduce each term on its own. Whenever broadcasting reward_score against
        # log_ratio is valid, the mean of the broadcast sum equals the sum of the
        # means. This form also works when a per-sequence reward does not
        # broadcast against the (..., vocab) log-ratio tensor.
        objective = (
            reward_score.mean()
            - self.beta * log_ratio.mean()
            + self.gamma * coherent.mean()
        )

        return objective
        

In [None]:
#| export
class LitAgent(pl.LightningModule):
    """PyTorch Lightning module that holds the agent model and its loss function.

    Args:
        model: The model to train.
        loss_func: Callable intended to compute the training loss.
    """

    def __init__(self, model: Callable, loss_func: Callable):
        super().__init__()
        self.model, self.loss_func = model, loss_func

    def training_step(self, batch, batch_idx):
        # TODO: not implemented yet; always returns None.
        # NOTE(review): Lightning treats a None return from training_step as
        # "skip this batch" under automatic optimization. Confirm against the
        # LightningModule docs before relying on this stub.
        return None