In [14]:
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

from instruct_goose.agent import Agent
from instruct_goose.reward import RewardModel, LitRewardModel, RewardLoss
from instruct_goose.dataset import PairDataset, PromptDataset
from instruct_goose.utils import load_yaml

In [11]:
config = load_yaml("../configs/sentiment_config.yml")
reward_checkpoint = config["reward_model"]["model_path"]
reward_data_path = config["reward_data"]["data_path"]
batch_size = config["reward_data"]["batch_size"]

In [3]:
reward_dataset = load_dataset(reward_data_path)
tokenizer = AutoTokenizer.from_pretrained(reward_checkpoint)
tokenizer.pad_token = tokenizer.eos_token

Using custom data configuration Dahoas--rm-static-576a4467763bb58a
Found cached dataset parquet (/Users/education/.cache/huggingface/datasets/Dahoas___parquet/Dahoas--rm-static-576a4467763bb58a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 2/2 [00:00<00:00, 241.66it/s]


In [4]:
small_reward_dataset, _ = random_split(reward_dataset["train"], [10, len(reward_dataset["train"]) - 10])

In [5]:
reward_model = RewardModel(reward_checkpoint)

In [6]:
# pair_dataset = PairDataset(reward_dataset["train"], tokenizer, max_length=1024)
pair_dataset = PairDataset(small_reward_dataset, tokenizer, max_length=1024)

100%|██████████| 10/10 [00:00<00:00, 401.83it/s]


In [7]:
pair_dataset[0]

(tensor([[48902,    25,  3878,  ..., 50256, 50256, 50256]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0]]),
 tensor([[48902,    25,   314,  ..., 50256, 50256, 50256]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0]]))

In [8]:
dataloader = DataLoader(pair_dataset, batch_size=2, shuffle=True)

In [9]:
reward_loss = RewardLoss()
lit_reward = LitRewardModel(reward_model, reward_loss)

In [10]:
trainer = pl.Trainer(max_epochs=3, log_every_n_steps=1)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [11]:
trainer_fit = trainer.fit(lit_reward, dataloader)


  | Name      | Type        | Params
------------------------------------------
0 | model     | RewardModel | 124 M 
1 | loss_func | RewardLoss  | 0     
------------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
497.762   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Epoch 2: 100%|██████████| 5/5 [01:28<00:00, 17.67s/it, loss=-0.248, v_num=1]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 5/5 [01:29<00:00, 17.92s/it, loss=-0.248, v_num=1]


In [9]:
for i, batch in tqdm(enumerate(dataloader)):
    chosen_input_ids, chosen_attention_mask, rejected_input_ids, rejected_attention_mask = batch
    
    chosen_rewards = reward_model(chosen_input_ids, chosen_attention_mask)
    rejected_rewards = reward_model(rejected_input_ids, rejected_attention_mask)
    
    pass

5it [01:13, 14.73s/it]


In [14]:
chosen_rewards

tensor([[2.4260, 8.0676, 6.4001,  ..., 8.4922, 6.5932, 1.5312],
        [2.3436, 7.9532, 8.3668,  ..., 9.5768, 9.3080, 8.9846]],
       grad_fn=<ReshapeAliasBackward0>)

In [21]:
chosen_rewards.shape

torch.Size([2, 1024])

In [22]:
chosen_rewards[:, -1]

tensor([1.5312, 8.9846], grad_fn=<SelectBackward0>)

In [18]:
import einops
last_token = einops.rearrange(chosen_rewards, "b n -> (b n)")

In [20]:
last_token.shape

torch.Size([2048])

In [26]:
import torch
hiddens = torch.randn([2, 1, 1024, 1])

In [27]:
hiddens.shape

torch.Size([2, 1, 1024, 1])

In [28]:
hiddens[0][0][-1]

tensor([0.0683])

### Training a language model to aligns with human preferences

In [12]:
agent_data_path = config["agent_data"]["data_path"]
agent_path = config["model"]["model_path"]
model_tokenizer_path = config["model"]["tokenizer_path"]

In [13]:
model_tokenizer_path, agent_path

('mrm8488/bert-mini2bert-mini-finetuned-cnn_daily_mail-summarization',
 'mrm8488/bert-mini2bert-mini-finetuned-cnn_daily_mail-summarization')

In [None]:
agent = Agent()