In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader, random_split

from instruct_goose.agent import Agent
from instruct_goose.env import TextEnv
from instruct_goose.trainer import RLHFTrainer
from instruct_goose.utils import create_reference_model, ReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training split of the IMDB reviews dataset; each example's "text" field is
# later used as a prompt for the environment.
dataset = load_dataset("imdb", split="train")
# GPT-2 causal LM; the same object is handed to both Agent and TextEnv below.
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Reference model derived from `model`.
# NOTE(review): presumably a frozen copy used as the RLHF baseline — confirm
# against instruct_goose.utils.create_reference_model. It is not used in the
# cells visible here.
ref_model = create_reference_model(model)

# Tokenizer matching the GPT-2 checkpoint, plus its end-of-sequence token id
# (inspected in the next cells).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_token_id = tokenizer.eos_token_id

Found cached dataset imdb (/Users/education/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


In [3]:
# model.config

In [4]:
# Display the tokenizer's end-of-sequence token id.
eos_token_id

50256

In [5]:
# Decode the EOS id back to text to confirm which special token it maps to.
# Uses `eos_token_id` (set from the tokenizer above) instead of a hard-coded
# 50256 so the check stays correct if a different tokenizer is loaded.
tokenizer.decode(torch.tensor([eos_token_id]))

'<|endoftext|>'

In [6]:
# Keep the demo small: train on a tiny random subset of the reviews.
NUM_TRAIN_SAMPLES = 10
SPLIT_SEED = 0

# Seeded generators make both the subset selection and the shuffle order
# reproducible across kernel restarts (the unseeded version picked different
# prompts on every run).
# NOTE: this rebinds `dataset` to the subset, so re-running this cell without
# re-running the loading cell re-splits the already reduced subset.
dataset, _ = random_split(
    dataset,
    [NUM_TRAIN_SAMPLES, len(dataset) - NUM_TRAIN_SAMPLES],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# batch_size=1: the training loop below reads a single prompt per batch.
train_dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [7]:
# Agent built from the GPT-2 model; when called on a state in the training loop
# it returns (action, log_prob, entropy, value).
agent = Agent(model)
# Buffer the training loop appends
# (state, action, log_prob, value, reward, done) entries to.
replay_buffer = ReplayBuffer()

### Train the RL-based language model

In [8]:
# Roll out one episode per prompt and record every step in `replay_buffer`.
for prompt in train_dataloader:
    # The dataloader uses batch_size=1, so each batch holds exactly one review.
    prompt_text = prompt["text"][0]
    print(f"prompt={prompt_text}")

    env = TextEnv(model, tokenizer, observation_input=prompt_text)
    state = env.reset()
    done = False

    # NOTE: despite its name (and the "epoch=" log label) this counts the
    # generation steps taken within the current episode, starting at 1.
    epoch = 1
    # Plain truthiness test instead of the identity check `done is not True`,
    # which only holds for the `True` singleton and left loop termination
    # entirely to an inner `break`.
    # NOTE(review): there is no explicit step cap here; the loop relies on the
    # environment eventually reporting `done` — confirm TextEnv bounds length.
    while not done:
        # Pure rollout: no gradients are needed while sampling actions.
        with torch.no_grad():
            action, log_prob, entropy, value = agent(state)
            next_state, reward, terminated, truncated, info, done = env.step(action)

            # `env.predicted_token_ids` appears to hold every token generated
            # so far in this episode; decode it to show progress.
            predicted_token_ids = torch.tensor([env.predicted_token_ids]).squeeze(dim=0)
            generated_text = tokenizer.decode(predicted_token_ids)

            print(f"epoch={epoch}, reward={reward}")
            print(generated_text)
            epoch += 1

        # Store the step together with the state the action was taken from.
        replay_buffer.append(state, action, log_prob, value, reward, done)
        if done:
            print("----------- DONE-----------")
            print(f"value={value}, reward={reward}, done={done}")
        else:
            state = next_state

prompt=I saw this film at the Taos Film Festival last year, and was just overwhelmed by it. It's a rich, warm novel brought to the screen, beautifully acted, and well directed. More than anything, it reminded me of the films of David Lean, both in its ability to handle a complex story, and its knack for creating powerful scenes that affect you on several different levels. The best movie I've seen in years.
epoch=1, reward=0
I
epoch=2, reward=0
IThis
epoch=3, reward=0
IThisIn
epoch=4, reward=0
IThisInAl
epoch=5, reward=0
IThisInAlSteam
epoch=6, reward=0
IThisInAlSteamFrom
epoch=7, reward=0
IThisInAlSteamFromMr
epoch=8, reward=0
IThisInAlSteamFromMrI
epoch=9, reward=0
IThisInAlSteamFromMrIAre
epoch=10, reward=0
IThisInAlSteamFromMrIAreMuch
epoch=11, reward=0
IThisInAlSteamFromMrIAreMuch5
epoch=12, reward=0
IThisInAlSteamFromMrIAreMuch5Bill
epoch=13, reward=0
IThisInAlSteamFromMrIAreMuch5BillJeremy
epoch=14, reward=0
IThisInAlSteamFromMrIAreMuch5BillJeremyWelcome
epoch=15, reward=0
IThisI

KeyboardInterrupt: 