In [None]:
import os
import warnings

# Move the working directory one level up. The target is resolved once and
# cached in PROJECT_ROOT so that re-running this cell in the same kernel does
# not climb one more directory each time (a bare os.chdir("..") would).
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.pardir)
os.chdir(PROJECT_ROOT)

# Hide FutureWarning messages emitted by the libraries imported below.
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
import utils
import gym, babyai_text

import torch
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead, create_reference_model

from pprint import pprint
from tqdm.notebook import tqdm

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# PPO settings. batch_size equals mini_batch_size, so each optimisation pass
# processes the whole batch as a single mini-batch.
config = PPOConfig(**{
    "model_name": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "batch_size": 16,
    "mini_batch_size": 16,
    "optimize_cuda_cache": True,
})

In [None]:
# Tokenizer and policy are both loaded from the checkpoint named in the PPO config.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

# Reference copy of the policy, built with the first 6 layers shared.
# TODO: choose an appropriate number of shared layers
reference_model = create_reference_model(model, num_shared_layers=6)

# Environment the policy is rolled out in.
env = gym.make("BabyAI-MixedTrainLocal-v0")

In [None]:
# Bundle the policy, its reference copy, the tokenizer and the PPO settings
# into the trainer used for both generation and optimisation.
trainer = PPOTrainer(config=config, model=model, ref_model=reference_model, tokenizer=tokenizer)

In [None]:
# Sampling settings forwarded to `trainer.generate` for every action.
generation_kwargs = dict(
    max_new_tokens=20,   # upper bound on generated tokens per action
    do_sample=True,      # sample instead of greedy decoding
    top_k=50,
    top_p=0.95,
    temperature=0.8,
    return_prompt=False,
)

In [None]:
def sample_trajectory(
    env,
    trainer,
    tokenizer,
    generation_kwargs,
    max_invalid_actions=8,
    invalid_action_penalty=0.1,
):
    """
    Roll out one episode with the current policy and record every generation step.

    Relies on the notebook-level names `utils` (system prompt and the
    `text_to_action` mapping) and `device`.

    Args:
        env: Environment to roll out in. `reset()` must return `(obs, info)` and
            `step(action)` must return `(obs, reward, done, info)`; `obs["mission"]`
            and `info["descriptions"]` are read.
        trainer: The PPO trainer whose `generate` method samples from the policy.
        tokenizer: The tokenizer corresponding to the model; it must provide a chat template.
        generation_kwargs: Keyword arguments forwarded to `trainer.generate`.
        max_invalid_actions: Number of invalid actions after which the episode is truncated.
        invalid_action_penalty: Magnitude of the negative reward given for each invalid action.

    Returns:
        trajectory: A list with one dict per generation step, holding "inputs"
            (the tokenised prompt), "outputs" (the first sequence returned by
            `trainer.generate`) and "reward". When the episode ends through the
            environment rather than by truncation, the reward of the final step
            is added to every earlier step.
    """
    obs, info = env.reset()
    mission = obs["mission"]
    text_obs = "\n".join(info["descriptions"])

    system_prompt = utils.get_system_prompt().replace("{goal}", mission)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text_obs},
    ]

    trajectory = []
    done, truncated = False, False
    invalid_actions_this_episode = 0
    while not (done or truncated):
        # add_generation_prompt=True ends the prompt with the assistant header,
        # so the policy only has to produce the action text itself.
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).squeeze(0).to(device)
        outputs = trainer.generate(inputs, **generation_kwargs)[0]

        # Special tokens (e.g. an end-of-turn marker) are dropped so they cannot
        # turn an otherwise valid action into an unknown string.
        action_text = tokenizer.decode(outputs, skip_special_tokens=True)
        # Still strip a role header in case the decoded text contains one.
        action_text = action_text.split("assistant\n")[-1].strip()
        messages.append({"role": "assistant", "content": action_text})

        if action_text not in utils.text_to_action:
            # Invalid action: the environment is not stepped; the policy is
            # penalised and told which actions exist.
            invalid_actions_this_episode += 1
            if invalid_actions_this_episode >= max_invalid_actions:
                truncated = True
            text_obs = "You entered an invalid action, the valid actions are: " + str(list(utils.text_to_action.keys()))
            reward = -invalid_action_penalty
        else:
            action = utils.text_to_action[action_text]
            obs, reward, done, info = env.step(action)
            text_obs = "\n".join(info["descriptions"])

        messages.append({"role": "user", "content": text_obs})
        trajectory.append({
            "inputs": inputs,
            "outputs": outputs,
            "reward": reward
        })

    if not truncated:
        final_reward = trajectory[-1]["reward"]
        # Share the terminal reward with every earlier step. The last step
        # already carries it, so it is skipped to avoid counting it twice.
        for experience in trajectory[:-1]:
            experience["reward"] += final_reward

    return trajectory

# Smoke-test the sampler with a single rollout and print every recorded step.
trajectory = sample_trajectory(
    env=env,
    trainer=trainer,
    tokenizer=tokenizer,
    generation_kwargs=generation_kwargs,
)
pprint(trajectory)

In [None]:
# Show what `trainer.generate` returned for the first step of the sampled episode.
trajectory[0]['outputs']

In [None]:
# Show the tokenised chat prompt that was fed to the policy at the first step.
trajectory[0]['inputs']

In [None]:
# Decode the first step's generated tokens back to text (special tokens are kept).
tokenizer.decode(trajectory[0]['outputs'])

In [None]:
# Show the reward stored for the first step of the sampled episode.
trajectory[0]['reward']

In [None]:
# Roll out episodes until at least one PPO batch worth of steps has been
# gathered. Whole episodes are appended, so the final count may exceed
# config.batch_size.
trajectories = []
while config.batch_size > len(trajectories):
    trajectories += sample_trajectory(env, trainer, tokenizer, generation_kwargs)
    print(f"Collected {len(trajectories)} experiences")

In [None]:
# PPOTrainer.step only accepts exactly config.batch_size samples, while the
# collection loop appends whole episodes and can overshoot; keep the first
# batch_size experiences.
batch_experiences = trajectories[:config.batch_size]

inputs = [experience["inputs"] for experience in batch_experiences]
outputs = [experience["outputs"] for experience in batch_experiences]
# Force a float dtype: the stored reward may be an int or a float.
rewards = [
    torch.tensor(experience["reward"], dtype=torch.float32, device=device)
    for experience in batch_experiences
]

# One PPO optimisation step over the collected (query, response, reward) triples.
stats = trainer.step(inputs, outputs, rewards)

# log_stats needs the decoded batch and the rewards in addition to the stats.
log_batch = {
    "query": [tokenizer.decode(query) for query in inputs],
    "response": [tokenizer.decode(response) for response in outputs],
}
trainer.log_stats(stats, log_batch, rewards)
pprint(stats)