In [2]:
import logging
from datetime import datetime
from typing import Dict, List, Any, Tuple
from rich.pretty import pprint
from types import SimpleNamespace
from tqdm import tqdm
import os, sys

import gym
import babyai_text
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from trl import (
    PPOConfig,
    PPOTrainer,
    AutoModelForCausalLMWithValueHead,
    create_reference_model,
)

In [None]:
args = {
        # Training config
        "model_id": "meta-llama/Llama-3.2-3B-Instruct",
        "env_id": "BabyAI-MixedTrainLocal-v0",
        "num_shared_layers": None,
        "num_steps_train": 2000,
        "num_envs": 4,
        "seed" : 30,
        # PPO config
        "batch_size": 4,
        "mini_batch_size": 4,
        # "gradient_accumulation_steps": 4, 
        "optimize_device_cache": True,
        "early_stopping": False,
        # Env config
        "consecutive_invalid_actions_allowed": 5,
        "invalid_action_penalty": -0.1,
        "max_steps_per_episode": 100,
        # Generation kwargs
        "max_new_tokens": 10,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 20,
        "top_p": 0.90,
        # PEFT config
        "use_peft": True,
        "lora_r": 32,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "lora_bias": "none",
    }
args = SimpleNamespace(**args)  # same type as argparse would return

In [4]:
envs = []
for i in range(args.num_envs):
    env = gym.make(args.env_id)
    env.seed(100 * args.seed + i)
    envs.append(env)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(args.model_id, padding_side="left")
model = AutoModelForCausalLMWithValueHead.from_pretrained(args.model_id)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [6]:
ref_model = create_reference_model(model, num_shared_layers=args.num_shared_layers)
config = PPOConfig(
        batch_size=args.batch_size,
        mini_batch_size=args.mini_batch_size,
        # gradient_accumulation_steps=args.gradient_accumulation_steps,
        optimize_device_cache=args.optimize_device_cache,
        early_stopping=args.early_stopping
    )
trainer = PPOTrainer(config, model, ref_model, tokenizer)



In [47]:
generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "temperature": args.temperature,
    }

In [8]:
system_prompt_msg = """You are an agent playing a simple navigation game. Your goal is to **{goal}**. The following are the possible actions you can take in the game, followed by a short description of each action:

turn left: turn to the left,
turn right: turn to the right,
go forward: take one step forward,
pick up: pick up the object below you,
drop: drop the object that you are holding,
toggle: manipulate the object in front of you.

In a moment I will present you an observation.

Tips:
- Once the desired object you want to interact or pickup in front of you, you can use the 'toggle' action to interact with it.
- It doesn't make sense to repeat the same action over and over if the observation doesn't change.

PLAY!
"""

In [23]:
num_envs = args.num_envs
obss, infos = zip(*[env.reset() for env in envs])
missions = [obs["mission"] for obs in obss]
text_obss = ['\n'.join(info['descriptions']) for info in infos]
contexts = [[] for _ in range(num_envs)]
for messages, mission, text_obs in zip(contexts, missions, text_obss):
    system_prompt = system_prompt_msg.replace("{goal}", mission)
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": text_obs})

In [27]:
contexts[0]

[{'role': 'system',
  'content': "You are an agent playing a simple navigation game. Your goal is to **pick up the purple key**. The following are the possible actions you can take in the game, followed by a short description of each action:\n\nturn left: turn to the left,\nturn right: turn to the right,\ngo forward: take one step forward,\npick up: pick up the object below you,\ndrop: drop the object that you are holding,\ntoggle: manipulate the object in front of you.\n\nIn a moment I will present you an observation.\n\nTips:\n- Once the desired object you want to interact or pickup in front of you, you can use the 'toggle' action to interact with it.\n- It doesn't make sense to repeat the same action over and over if the observation doesn't change.\n\nPLAY!\n"},
 {'role': 'user',
  'content': 'You see a wall 6 steps forward\nYou see a purple key 2 steps left and 5 steps forward\nYou see a grey key 2 steps left and 2 steps forward\nYou see a red key 2 steps left\nYou see a green ball

In [53]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
query_tenors_step = tokenizer.apply_chat_template(
    contexts, 
    return_tensors="pt", 
    add_generation_prompt=True,
    padding_side="left"
)

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [33]:
query_tenors_step.shape

torch.Size([4, 358])

In [57]:
query_tenors_step_trial = []
for conv in contexts:
    query_tenors_step_trial.append(tokenizer.apply_chat_template(
        conv, 
        return_tensors="pt", 
        add_generation_prompt=True
    ))

In [58]:
generation_kwargs

{'max_new_tokens': 10,
 'do_sample': True,
 'top_k': 20,
 'top_p': 0.9,
 'temperature': 0.8}

In [49]:
generated_tokens = trainer.generate(
    list(query_tenors_step),
    **generation_kwargs
)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [63]:
generated_tokens_trial = trainer.generate(
    query_tensor=query_tenors_step_trial,
    generation_kwargs=generation_kwargs
)

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [51]:
response_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

In [None]:
response_text

ssh -i ~/.ssh/id_ed25519 -p 40009 root@213.171.186.233 -L 8080:localhost:8080