A typical PPO training loop looks something like this:

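```python
for data in collector:
    for epoch in range(n_epochs):
        advantage(data)  # recompute advantages with the current value network
        replay_buffer.extend(data)
        for batch in replay_buffer:
            loss = ppo_loss(batch)  # TorchRL PPO losses return several terms; sum them before backward()
            loss.backward()
            optim.step()
            optim.zero_grad()  # clear gradients so they don't accumulate across steps
```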

In [1]:
from torchrl.objectives import ClipPPOLoss

from data import get_prompt_dataloaders
from models.actor_critic import init_actor_critic
from models.reward import init_reward_model
from models.transformer import init_transformer
from utils import load_config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# YAML file holding the RLHF training settings (displayed in a later cell).
rlhf_config_path = "config/train_rlhf.yaml"
config = load_config(rlhf_config_path)

In [3]:
# Build the base transformer with inference=True; construction prints a compile
# notice (see the cell output). NOTE(review): compilation is presumably controlled
# by config["compile"] — confirm. `model` is not used by any cell shown here;
# presumably it is meant for generating responses during PPO rollouts.
model = init_transformer(config, inference=True)

compiling the model... (takes a ~minute)


In [4]:
# Build the reward model from the same config; it also prints a compile notice on
# construction (see the cell output). A later cell calls it with
# (input_ids, attention_mask).
reward_model = init_reward_model(config)

compiling the model... (takes a ~minute)


In [5]:
# Build the PPO networks: presumably the policy (`actor`), a value network
# (`critic`), and a value head (`critic_head`) — confirm against init_actor_critic.
# NOTE(review): only `actor` and `critic_head` are passed to the PPO loss below;
# `critic` is unused in the cells shown here.
actor, critic, critic_head = init_actor_critic(config)

In [6]:
# TorchRL clipped PPO objective built from the actor and the critic head.
# No other options are passed, so settings such as the clip range and the entropy
# bonus use TorchRL's defaults; the loaded config has no keys for them.
loss_fn = ClipPPOLoss(actor, critic_head)

In [7]:
# Prompt dataloaders: only the first returned loader is kept (the second is
# discarded — presumably a validation loader; confirm).
# NOTE(review): later cells call next(tdl), which requires an iterator rather than a
# plain DataLoader — confirm get_prompt_dataloaders returns one.
tdl, _ = get_prompt_dataloaders(config)

In [None]:
# Work-in-progress PPO training loop (compare the sketch at the top of the notebook).
# NOTE(review): the original reward line was truncated (`reward_model(batch.pr`),
# which made this cell a SyntaxError. It is completed here with the call form used
# by the exploratory reward_model cell later in the notebook; confirm these are
# the intended inputs.
# NOTE: as written, this loop only scores prompt batches; no training step happens yet.
for _ in range(config["max_iters"]):
    # form batch
    batch = next(tdl)
    # perhaps data = torch.stack([next(tdl) for _ in range(...)])?
    # The reward does not depend on `epoch`, and nothing in this loop updates
    # `reward_model`, so compute it once per batch instead of once per epoch.
    reward = reward_model(
        batch.transformer_data.input_ids,
        batch.transformer_data.attention_mask,
    )
    for epoch in range(config["num_epochs"]):
        # TODO: compute advantages, evaluate `loss_fn`, call backward(), step the
        # optimizer and zero its gradients (see the PPO sketch above).
        pass

In [8]:
# Display the loaded training configuration (shown in the output below).
config

{'episode_length': 50,
 'out_dir': 'out',
 'out_dir_reward': 'out_reward',
 'eval_interval': 2,
 'log_interval': 2,
 'eval_iters': 100,
 'always_save_checkpoint': True,
 'base_model': 'gpt2',
 'init_reward_from': 'resume',
 'init_base_from': 'resume',
 'dataset': 'openai_summarize_comparisons',
 'gradient_accumulation_steps': 1,
 'batch_size': 4,
 'block_size': 550,
 'dropout': 0.0,
 'learning_rate': 5e-06,
 'max_iters': 3000,
 'weight_decay': 0.01,
 'beta1': 0.9,
 'beta2': 0.999,
 'grad_clip': 10.0,
 'decay_lr': True,
 'warmup_iters': 20,
 'lr_decay_iters': 3000,
 'min_lr': 5e-07,
 'device': 'cuda',
 'dtype': 'bfloat16',
 'compile': True,
 'verbose': False,
 'ppo_batch_size': 16,
 'num_epochs': 4}

In [None]:
# Pull a single prompt batch from the loader for interactive inspection.
batch = next(tdl)

In [None]:
# Display the batch; its `transformer_data` fields are fed to the reward model next.
batch

In [None]:
# Exploratory check: run the reward model on the sampled batch's token ids and
# attention mask; the result is shown as the cell output.
token_ids = batch.transformer_data.input_ids
token_mask = batch.transformer_data.attention_mask
reward_model(token_ids, token_mask)