In [None]:
import os
# if not in root dir, change to root dir
if os.path.basename(os.getcwd()) != "rl-llm":
    os.chdir("..")

In [2]:
import utils
import gym, babyai_text
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead, create_reference_model
from typing import Dict, List, Any, Tuple
from rich.pretty import pprint

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = PPOConfig(batch_size=4, mini_batch_size=4)
model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
ref_model = create_reference_model(model)

trainer = PPOTrainer(config, model, ref_model, tokenizer)

generation_kwargs = {
    "max_new_tokens": 20,
    "do_sample": True,
    "top_k": 10,
    "top_p": 0.95,
    "temperature": 0.8,
}



In [4]:
class EnvManager:

    def __init__(self, env: gym.Env, invalid_action_penalty: float = -0.1, consecutive_invalid_actions_allowed: int = 5):
        self.env = env
        self.invalid_action_penalty = invalid_action_penalty
        self.consecutive_invalid_actions_allowed = consecutive_invalid_actions_allowed
        self.consecutive_invalid_actions = 0
    
    def reset(self) -> Tuple[str, str]:
        self.consecutive_invalid_actions = 0
        obs, info = self.env.reset()
        mission = obs["mission"]
        text_obs = "\n".join(info["descriptions"])
        return mission, text_obs
    
    def step(self, text_action: str) -> Tuple[str, float, bool]:
        action = utils.text_to_action.get(text_action, None)
        if action is None:
            self.consecutive_invalid_actions += 1
            text_obs = "You entered an invalid action, the valid actions are: " + str(list(utils.text_to_action.keys()))
            reward = self.invalid_action_penalty
            done = self.consecutive_invalid_actions >= self.consecutive_invalid_actions_allowed
        else:
            obs, reward, done, info = self.env.step(action)
            text_obs = "\n".join(info["descriptions"])
        return text_obs, reward, done

In [5]:
num_envs = 4
experiences_needed = 10

query_tensors_all, response_tensors_all, rewards_all = [], [], [] # list of all experiences we have collected so far

envs = [EnvManager(gym.make("BabyAI-MixedTrainLocal-v0", seed=i)) for i in range(num_envs)]
missions, text_obss = zip(*[env.reset() for env in envs])

messagess = [[] for _ in range(num_envs)]
rewardss = [[] for _ in range(num_envs)]
system_prompt_template = utils.get_system_prompt()
for messages, mission, text_obs in zip(messagess, missions, text_obss):
    system_prompt = system_prompt_template.replace("{goal}", mission)
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": text_obs})

while len(rewards_all) < experiences_needed:

    query_tensors = tokenizer.apply_chat_template(
        messagess, 
        return_tensors="pt", 
        add_generation_prompt=True, 
        padding="longest",
        padding_side="left",
    ).to(device)
    query_tensors = [tensor for tensor in query_tensors]
    print(f"query_tensors: {[tensor.shape for tensor in query_tensors]}")

    response_tensors = trainer.generate(query_tensors, **generation_kwargs, return_prompt=False)
    print(f"response_tensors: {[tensor.shape for tensor in response_tensors]}")

    action_texts = tokenizer.batch_decode(
        response_tensors, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    for messages, action_text in zip(messagess, action_texts):
        messages.append({"role": "assistant", "content": action_text})
    print(f"action_texts: {action_texts}")

    break

    # step the envs
    # for each env take a step
    # if the episode is done:
    # - add the final reward to all previous rewards
    # - append the query_tensors, response_tensors, and rewards from that episode to the list of all experiences that we have collected so far
    # - reset the env(s) that was done
    # - keep collecting experiences until we have enough

    # Thought dump:
    # NOTE 1: Each time an episode ends, we add all the experiences from that episode to the list of all experiences we have collected so far
    # This means that we could end up with more than experiences_needed experiences in the list.
    # if that happends we just take the first experiences_needed experiences from the list
    # NOTE 2: There is no need for multiprocessing since the overhead from the environment is not that much. 
    # All we need to do for now is use batch inference with the LLM so that each forward pass of the LLM allows us to collect num_processes experiences
    # (that get appended once the episode ends)
    # NOTE 3: The resulting function should be a drop in replacement to the while loop in train.py. 
    # The only additional changes are probably just in setup_training.py which would setup multiple environments (and possibly env managers for convenience instead of one)
    # NOTE 4: Different envs could be done at different times, even though while debugging with SmolLM2-135M they will all probably end at the same time due to invalid actions.
    # NOTE 5: The root dir can be treated as a module. From the root dir you can run pip install -e . to install the package in editable mode, 
    # so we don't need to sys.path.append every time we want to import something.

  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


query_tensors: [torch.Size([347]), torch.Size([347]), torch.Size([347]), torch.Size([347])]
response_tensors: [torch.Size([20]), torch.Size([20]), torch.Size([20]), torch.Size([20])]
action_texts: ['To the right, 3 steps right,\nTo the left, 4 steps left,\n', 'I see you are trying to interact with the key and the ball, and it seems like you are', 'You see a purple key 1 step left, 1 step forward. You see a grey box', "The key to picking up a puzzle is using the 'toggle' action to interact with a certain"]


In [6]:
# expected function signature
def collect_trajectories_batched(
    envs: List[EnvManager], 
    trainer: PPOTrainer, 
    tokenizer: PreTrainedTokenizer, 
    generation_kwargs: Dict[str, Any], 
    experiences_needed: int,
    consecutive_invalid_actions_allowed: int = 5,
    invalid_action_penalty: float = -0.1,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    pass