In [1]:
import os
# if not in root dir called rl-llm, go up one directory
if os.path.basename(os.getcwd()) != 'rl-llm':
    os.chdir('..')

In [2]:
import utils
from env_manager import EnvManager
import gym, babyai_text
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead, create_reference_model
from typing import Dict, List, Any, Tuple
from rich.pretty import pprint
import sys, io

In [3]:
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = PPOConfig(batch_size=4, mini_batch_size=4)
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    import warnings
    warnings.filterwarnings("ignore", message="Setting `pad_token_id` to `eos_token_id`.*")
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
ref_model = create_reference_model(model)
trainer = PPOTrainer(config, model, ref_model, tokenizer)
generation_kwargs = {
    "max_new_tokens": 20,
    "do_sample": True,
    "top_k": 10,
    "top_p": 0.95,
    "temperature": 0.8,
}



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [4]:
env_id = "BabyAI-GoToObj-v0" # "BabyAI-MixedTrainLocal-v0" # "BabyAI-GoToObj-v0"
num_envs = 4
env_managers = [EnvManager(gym.make(env_id, seed=i)) for i in range(num_envs)]
batch_size = 128

In [5]:
# Initialize variables
Q, R, W = [], [], []
query_tensors_per_episode = [[] for _ in range(num_envs)]
response_tensors_per_episode = [[] for _ in range(num_envs)]
rewards_per_episode = [[] for _ in range(num_envs)]

# Reset envs and initialize contexts
contexts = [[] for _ in range(num_envs)]
system_prompt_template = utils.get_system_prompt()
missions, text_obss = zip(*[env.reset() for env in env_managers])
for i, (context, mission, text_obs) in enumerate(zip(contexts, missions, text_obss)):
    system_prompt = system_prompt_template.replace("{goal}", mission)
    context.append({"role": "system", "content": system_prompt})
    context.append({"role": "user", "content": text_obs})
    if i==0: pprint(context[0]['content']); pprint(context[1]['content'])

while len(W) < batch_size:

    query_tensors_step = []
    for context in contexts:
        query_tensor = tokenizer.apply_chat_template(context, return_tensors="pt", add_generation_prompt=True).squeeze(0)
        query_tensors_step.append(query_tensor)
    
    response_tensors_step = trainer.generate(
        query_tensors_step,
        generation_kwargs=generation_kwargs,
        return_prompt=False,
    )
    response_texts_step = tokenizer.batch_decode(response_tensors_step, skip_special_tokens=True)
    
    for i, (env, response_text) in enumerate(zip(env_managers, response_texts_step)):

        query_tensors_per_episode[i].append(query_tensors_step[i])
        response_tensors_per_episode[i].append(response_tensors_step[i])

        text_obs, reward, done = env.step(response_text)
        rewards_per_episode[i].append(reward)
        contexts[i].append({"role": "assistant", "content": response_text})
        contexts[i].append({"role": "user", "content": text_obs})

        if i==0: pprint(response_text); pprint(text_obs); pprint(f"REWARD: {reward}"); pprint(f"DONE: {done}")

        if done:
            # Discount future rewards if successful
            success = True if reward > 0 else False
            for j in range(len(rewards_per_episode[i])-1):
                rewards_per_episode[i][j] += rewards_per_episode[i][-1]
            # Append trajectory to Q, R, W
            Q.extend(query_tensors_per_episode[i])
            R.extend(rewards_per_episode[i])
            W.extend(rewards_per_episode[i])
            # Reset env and contexts
            query_tensors_per_episode[i] = []
            response_tensors_per_episode[i] = []
            rewards_per_episode[i] = []
            mission, text_obs = env.reset()
            system_prompt = system_prompt_template.replace("{goal}", mission)
            contexts[i] = [{"role": "system", "content": system_prompt},
                           {"role": "user", "content": text_obs}]
            if i==0: pprint(contexts[i][0]['content']); pprint(contexts[i][1]['content'])
            

  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")


You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  logger.deprecation(
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
