In [27]:
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase, set_skip_existing
from torchrl.data.replay_buffers import LazyTensorStorage, SamplerWithoutReplacement, TensorDictReplayBuffer
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import trange
from transformers import GPT2Tokenizer, GenerationConfig

from data import get_prompt_dataloaders
from models.actor_critic import init_actor_critic
from models.reward import init_reward_model
from models.transformer import init_transformer
from utils import load_config

In [2]:
# Release cached, currently unused GPU memory held by PyTorch before the models below are built.
torch.cuda.empty_cache()

In [3]:
# RLHF training configuration; later cells index it like a mapping (e.g. config["learning_rate"]).
config = load_config("config/train_rlhf.yaml")

In [4]:
# Language model that `generate` calls `.generate(...)` on; built with inference=True.
model = init_transformer(config, inference=True)

In [5]:
# GPT-2 tokenizer. The EOS token doubles as the padding token; the functions below
# hard-code that id as 50_256.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [6]:
# Reward model; called below with (input_ids, attention_mask) and returns a pair whose
# second element is used as the end-of-sequence score.
reward_model = init_reward_model(config)

In [7]:
# `actor` and `critic_head` feed the PPO loss; `critic` is wrapped as the value network for GAE.
actor, critic, critic_head = init_actor_critic(config)
# Put the critic in eval mode; the trailing ';' suppresses the module repr in the cell output.
critic.eval();

In [8]:
# Clipped PPO objective built from the actor and the critic head.
loss_fn = ClipPPOLoss(actor, critic_head)

In [9]:
# Only the first returned loader is kept (presumably the training split — confirm);
# it is consumed with next(tdl) in later cells.
tdl, _ = get_prompt_dataloaders(config)

### computing log probs

```python
input_ids = batch.transformer_data.input_ids.clone()
# mask out label
prompt_rindex = batch.transformer_data.prompt_rindex
label_idx = torch.arange(input_ids.shape[1], device=prompt_rindex.device) >= prompt_rindex[:, None]
input_ids[label_idx] = 50_256
# move padding tokens to left pad
input_ids = torch.stack([torch.roll(row, (row == 50_256).sum().item(), 0) for row in input_ids])
outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, generation_config=generation_config)

scores = torch.stack(outputs.scores, 1)
log_probs = scores.max(dim=-1).values - torch.logsumexp(scores, dim=-1)
```

### generate

In [42]:
@torch.no_grad()
def generate(batch, max_new_tokens=50):
    """Continue every prompt in ``batch`` and return the text plus per-step log-probs.

    The label part of each row (positions at or after ``prompt_rindex``) is blanked
    out with the pad/EOS id, the remaining prompt is rolled so that the padding sits
    on the left, and the notebook-level ``model`` generates up to ``max_new_tokens``
    tokens per row. No sampling options are set on the ``GenerationConfig``.

    Args:
        batch: object whose ``transformer_data`` exposes ``input_ids`` and
            ``prompt_rindex``.
        max_new_tokens: upper bound on the number of generated tokens per row.

    Returns:
        A ``(generated, log_probs)`` pair. ``generated`` has the shape of the input
        ids and holds, per row, all non-pad tokens of the generated sequence packed
        to the left and padded with 50256 on the right. ``log_probs`` holds, for each
        generation step, the log-probability of the highest-scoring token, zero-padded
        along the step axis up to ``max_new_tokens``.
    """
    pad_id = 50_256  # EOS id, reused as padding everywhere in this notebook

    transformer_data = batch.transformer_data
    prompt_rindex = transformer_data.prompt_rindex
    input_ids = transformer_data.input_ids.clone()

    # Hide the label: every position from prompt_rindex onwards becomes padding.
    positions = torch.arange(input_ids.shape[1], device=prompt_rindex.device)
    input_ids[positions >= prompt_rindex[:, None]] = pad_id

    # Rotate each row right by its number of pad tokens so the padding ends up on the left.
    left_padded_rows = [
        torch.roll(row, (row == pad_id).sum().item(), 0) for row in input_ids
    ]
    input_ids = torch.stack(left_padded_rows)

    # Ask generate() for a dict output that includes the per-step scores.
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
    )
    attention_mask = (input_ids != pad_id).to(torch.int64)
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
    )

    # Pack the non-pad tokens of every sample to the left of a pad-filled buffer.
    generated = torch.full_like(input_ids, pad_id)
    for row_idx, sample in enumerate(outputs.sequences):
        kept = sample[sample != pad_id]
        generated[row_idx, : kept.numel()] = kept

    # Log-softmax value of the top token at each step: max score minus the log-normaliser.
    scores = torch.stack(outputs.scores, 1)
    top_log_probs = torch.amax(scores, dim=-1) - torch.logsumexp(scores, dim=-1)
    # Generation may stop early; pad the step axis back up to max_new_tokens with zeros.
    missing_steps = max_new_tokens - scores.shape[1]
    log_probs = F.pad(top_log_probs, (0, missing_steps), value=0)
    return generated, log_probs

### rollout

In [43]:
@torch.no_grad()
def create_rollout_td(batch, generated, reward_model, log_probs, max_new_tokens=50):
    """Unroll generated sequences into a per-token trajectory for PPO.

    Step ``t`` of the trajectory observes the prompt plus the first ``t`` generated
    tokens (every later position is replaced by the pad id 50256); the action at
    step ``t`` is generated token ``t``.

    Args:
        batch: object whose ``transformer_data`` exposes ``prompt_rindex`` (index of
            the first generated position in each row), ``input_ids`` and
            ``attention_mask`` (the labelled sequences scored as a baseline).
        generated: ``(batch, seq_len)`` token ids, right-padded with 50256.
        reward_model: callable taking ``input_ids`` and ``attention_mask`` and
            returning a pair whose second element is one score per row.
        log_probs: stored unchanged under ``"sample_log_prob"``.
        max_new_tokens: number of steps in the trajectory.

    Returns:
        A ``TensorDict`` with batch size ``(batch, max_new_tokens)`` holding
        ``action``, ``input_ids``, ``attention_mask``, ``sample_log_prob`` and, under
        ``"next"``, the shifted ``input_ids`` / ``attention_mask`` plus ``done`` and
        ``reward`` of shape ``(batch, max_new_tokens, 1)``. The reward is the
        generated-minus-label reward-model score, placed on the ``done`` step only.
    """
    pad_token_id = 50_256
    device = generated.device
    prompt_rindex = batch.transformer_data.prompt_rindex.to(device)

    # visible[b, t, p] is True when position p is revealed at step t of row b.
    # Broadcasting replaces the nested Python loops / stacks and yields the same
    # (batch, max_new_tokens + 1, seq_len) tensor.
    positions = torch.arange(generated.shape[1], device=device)
    steps = torch.arange(max_new_tokens + 1, device=device)
    visible = positions < (prompt_rindex[:, None] + steps)[..., None]
    rollout_generated = torch.where(visible, generated[:, None, :], pad_token_id)
    rollout_attention_mask = (rollout_generated != pad_token_id).to(torch.int64)

    # The episode ends on the step whose action is the EOS token, or on the last
    # step when the generation budget is exhausted. Clamp to max_new_tokens - 1:
    # clamping to max_new_tokens would put the index outside the step axis, so
    # full-length generations would never be flagged done nor receive a reward.
    n_generated = (generated != pad_token_id).sum(dim=-1) - prompt_rindex
    done_idx = n_generated.clamp(max=max_new_tokens - 1)
    done = (
        torch.arange(max_new_tokens, device=device) == done_idx[:, None]
    ).unsqueeze(-1)

    # Score the fully revealed generation and the labelled reference sequence.
    _, end_scores = reward_model(
        input_ids=rollout_generated[:, -1], attention_mask=rollout_attention_mask[:, -1]
    )
    _, end_scores_labels = reward_model(
        batch.transformer_data.input_ids, batch.transformer_data.attention_mask
    )
    # Sparse reward: non-zero only on the done step.
    reward = done * (end_scores - end_scores_labels)[:, None, None]

    # Action t of row b is the token at position prompt_rindex[b] + t.
    action_idx = prompt_rindex[:, None] + torch.arange(max_new_tokens, device=device)
    action = generated[
        torch.arange(generated.shape[0], device=device)[:, None],
        action_idx,
    ]
    #TODO: KL
    td = {
        "action": action,
        "input_ids": rollout_generated[:, :-1].clone(),
        "attention_mask": rollout_attention_mask[:, :-1].clone(),
        "sample_log_prob": log_probs,
        "next": {
            "input_ids": rollout_generated[:, 1:].clone(),
            "attention_mask": rollout_attention_mask[:, 1:].clone(),
            "done": done,
            "reward": reward,
        }
    }
    return TensorDict(td, batch_size=done.shape[:2], device=device)

In [47]:
# Pull one prompt batch and generate continuations for it; `generate` returns the
# packed token ids together with the per-step log-probabilities.
batch = next(tdl)
generated, log_probs = generate(batch)

In [50]:
# Unroll the generated sequences into a per-token trajectory (default max_new_tokens=50).
td = create_rollout_td(batch, generated, reward_model, log_probs)

In [52]:
class VmapCritic(TensorDictModuleBase):
    """Apply a critic over the last batch dimension of a tensordict with ``torch.vmap``.

    The wrapper exposes the wrapped critic's ``in_keys`` / ``out_keys`` so it can be
    used wherever the critic itself would be (here: as the GAE value network).
    """

    def __init__(self, critic):
        super().__init__()
        self.in_keys = critic.in_keys
        self.out_keys = critic.out_keys
        self.module = critic

    def forward(self, tensordict):
        """Run the critic on every slice along the last batch dim and merge the result.

        The critic is switched to eval mode for the duration of the call
        (presumably to avoid train-time randomness under vmap — confirm) and its
        previous training flag is restored afterwards, even if the call raises.

        Returns:
            The input ``tensordict`` updated in place with the critic's outputs.
        """
        ndim = tensordict.ndim
        training = self.module.training
        self.module.eval()
        try:
            td = torch.vmap(self.module, (ndim - 1,))(tensordict)
        finally:
            # Always restore the original mode so a failure cannot leave the critic in eval.
            self.module.train(training)
        # vmap sends this dim to the beginning so we need to send it back where it belongs
        td = td.permute(*range(1, ndim), 0)
        return tensordict.update(td)

In [53]:
# Generalised Advantage Estimation; the critic is wrapped in VmapCritic so it is
# mapped over the last batch dimension of the rollout tensordict.
adv_fn = GAE(
    value_network=VmapCritic(critic), gamma=0.99, lmbda=0.95, average_gae=True
)

A typical PPO loop should look something like this:

```python
for data in collector:
    for epoch in range(n_epochs):
        advantage(data)
        replay_buffer.extend(data)
        for batch in replay_buffer:
            loss = ppo_loss(batch)  
            loss.backward()
            optim.step()
```

In [54]:
# Optimiser hyper-parameters come from the config file.
lr = config["learning_rate"]
wd = config["weight_decay"]
beta1 = config["beta1"]
beta2 = config["beta2"]

# AdamW over the parameters exposed by the PPO loss module.
optimizer = torch.optim.AdamW(
    loss_fn.parameters(), lr=lr, weight_decay=wd, betas=(beta1, beta2)
)

In [74]:
# Build a boolean mask over (batch, step) that is True up to and including the step
# on which each trajectory is done, and False for every step after it.
done = td["next", "done"]
mask = torch.zeros_like(done)
mask[..., 1:, :] = done[..., :-1, : ] # shift by one
# cumsum turns the shifted done flag into "an earlier step was done"; invert it.
# Squeeze only the trailing singleton dim: a bare .squeeze() would also drop a
# batch or step dimension that happens to have size 1.
mask = ~mask.cumsum(-2).bool().squeeze(-1)

In [None]:
# Replay buffer sized to hold one full rollout (episode_length steps for batch_size
# prompts); it is iterated without replacement in PPO minibatches.
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(config["episode_length"] * config["batch_size"]),
    batch_size=config["ppo_batch_size"],
    sampler=SamplerWithoutReplacement(),
)
losses = []
# NOTE(review): `grad_clip` was used below without ever being defined in this
# notebook; it is assumed to live in the config next to the other optimiser
# settings — confirm the key name.
grad_clip = config["grad_clip"]

for _ in trange(min(5, config["max_iters"])):
    # form batch: generate() returns (token ids, log-probs) and create_rollout_td
    # needs both, so unpack the pair and forward the log-probs.
    batch = next(tdl)
    generated, log_probs = generate(batch)
    td = create_rollout_td(batch, generated, reward_model, log_probs)
    with torch.no_grad():
        adv_fn(td)
    # flatten (batch, step) into a single dimension of transitions
    rb.extend(td.reshape(-1))

    # `minibatch` rather than `batch`, so the prompt batch above is not shadowed
    for minibatch in rb:
        loss_vals = loss_fn(minibatch.to(config["device"]))

        # total loss = sum of every entry whose key starts with "loss"
        loss_val = sum(
            value for key, value in loss_vals.items() if key.startswith("loss")
        )
        loss_val.backward()
        losses.append(loss_val.detach().cpu())
        gn = torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Scratch cell: display the advantage module.
adv_fn

In [None]:
# Scratch cell: pull another prompt batch for inspection (rebinds `batch`).
batch = next(tdl)

In [None]:
# Scratch cell: display the batch.
batch

In [None]:
# Scratch cell: run the reward model on the labelled sequences of the batch.
reward_model(batch.transformer_data.input_ids, batch.transformer_data.attention_mask)