In [26]:
import gymnasium as gym
from tqdm import tqdm
import sys
import os
import torch
import re
import json

rocket_dir = os.path.split(os.path.abspath(os.curdir))[:-1]
rocket_src_dir = os.path.join(*rocket_dir, 'src')
rocket_dir = os.path.join(*rocket_dir)

print(f'Appending to sys.path:', [rocket_dir, rocket_src_dir])

sys.path.append(rocket_dir)
sys.path.append(rocket_src_dir)

from src.environments.frozen_env import make_frozen_env, FrozenActions
from src.agents.Mistral_agent import MistralAgent

Appending to sys.path: ['/data/zago/rocket', '/data/zago/rocket/src']


In [27]:
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

agent = MistralAgent("mistralai/Mistral-7B-Instruct-v0.2",
                        action_enum=FrozenActions,
                         is_lora=True, 
                         padding_side="right", 
                         gradient_ckpt=False)
agent.critic.to('cuda')

Loading checkpoint shards: 100%|██████████████████| 3/3 [00:06<00:00,  2.30s/it]


trainable params: 27,262,976 || all params: 7,268,995,072 || trainable%: 0.3750583915652433


Sequential(
  (0): Linear(in_features=4096, out_features=1024, bias=True)
  (1): ReLU()
  (2): Linear(in_features=1024, out_features=512, bias=True)
  (3): ReLU()
  (4): Linear(in_features=512, out_features=1, bias=True)
)

In [28]:
runs_directory = 'runs_directory_prova'
envs = gym.vector.SyncVectorEnv(
            [make_frozen_env(
                runs_directory,
                area=8, # 8x8,
                seed=1234 + i + 1,
                size=(512, 512),
                is_slippery=False,
                fixed_orientation=True,
                save_video=False,
                save_video_every=20,
                save_stats=False,
                fov=3
            )
            for i in range(1)
            ])

In [29]:
def find_action_in_sentence_and_create_action_choices(generated_action, action_list, fixtemplate_action_choices):
        #extract the action from the generated_action.
        #substitute the action position with the other possible actions.
        #if no match is found return the defult action prompts.
        pattern = '|'.join(map(re.escape, action_list))
        match = re.search(pattern, generated_action)
        if match:
            match_found = 1
            matched_action = match.group()
            generated_action_choices = [generated_action.replace(matched_action, action) for action in action_list]
        else:
            match_found = 0
            generated_action_choices = fixtemplate_action_choices
        return generated_action_choices, match_found, matched_action

In [31]:
obs, info = envs.reset(seed=42)

_, contexts, _ = agent.create_prompt_for_action_and_value(text_description=info['obs'])
context = contexts[0]
print('_____________CONTEXT PROMPT______________')
print(context)
print('_____________CONTUINUATION______________')
continuation = agent.generate_continuation(context)
print(continuation)
print('_____________EXTRACT ACTION______________')

action_list = [a.value for a in agent.action_enum]
fixtemplate_action_choices = [f"Based on the information provided, the best action would be to {action}" for action in action_list]
action_choices, match_found, matched_action = find_action_in_sentence_and_create_action_choices(continuation, action_list, fixtemplate_action_choices)
print(matched_action)

print('_____________LLM ORACLE______________')
question = f'You must extract the action described in the following text. You must extract exactly one action, with no alternatives, from [{", ".join(action_list)}]. Text: {context}. Answer: the extracted action is'
extracted_action = agent.generate_continuation(question)
pattern = '|'.join(map(re.escape, action_list))
match = re.search(pattern, extracted_action)
print(match.group())



<s>[INST] You are an agent in a survival 2D game. You took action noop.
You see:
- a trap 1 steps to your east.
- a trap 3 steps to your south-east.
- a trap 3 steps to your south.
- the goal 14 steps to your south-east.
 What's the next best action? You must avoid the traps! Choose exactly one from [move west, move south, move east, move north] [/INST]
_____________CONTUINUATION______________
Based on the information provided, the next best action would be to move north to avoid the traps that are located to the south and south-east.</s>
_____________EXTRACT ACTION______________
move north
_____________LLM ORACLE______________
move north


### Generate dataset: we want to test accuracy of the methods

In [104]:
filename = 'dataset_with_agent_functions_500.json'

if os.path.exists(filename):
    with open(filename, 'r') as f:
        dataset = json.load(f)
    print(f'Loaded dataset with {len(dataset)} samples')
else:
    
    ############ SET DATASET SIZE
    size = 500
    
    #######################################################
    # collect data
    envs = gym.vector.SyncVectorEnv(
                [make_frozen_env(
                    runs_directory,
                    area=8, # 8x8,
                    size=(512, 512),
                    is_slippery=False,
                    fixed_orientation=True,
                    save_video=False,
                    save_video_every=20,
                    save_stats=False,
                    fov=3
                )
                for _ in range(1)
                ])

    observation, info = envs.reset()
    raw_dataset = [info['obs'][0]]

    while len(raw_dataset) < size:
        action = envs.action_space.sample()
        _, _, terminated, truncated, info = envs.step(action)

        if terminated or truncated:
            _, info = envs.reset()
        
        raw_dataset.append(info['obs'][0])

    #######################################################
    dataset = []

    for x in tqdm(raw_dataset):
        _, prompts, _ = agent.create_prompt_for_action_and_value(text_description=[x])    
        responses = [agent.generate_continuation(prompts)]
        
        # tag data with LLM (NOTE: it may contain errors!!!)    
        modified_responses = [f'[INST] You have a text which refers to a game, and it indicates the best action to take. You must extract exactly one action form [move north, move south, move east, move west] in less than 10 words. Don\'t tell me the reason, I just want the action. If there are multiple best actions, tell me the first suggested. The text is:\n{responses[0]}[/INST] Extracted action:']
        s = agent.generate_continuation(modified_responses)
            
        dataset.append({
            'text': responses[0],
            'action': s
            })
        
    with open(filename, 'w+') as f:
        json.dump(dataset, f, indent=4)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [29:45<00:00,  3.57s/it]


### Test methods

In [105]:
from thefuzz import process
from thefuzz import fuzz

first_method_correct = 0
second_method_correct = 0

for n in dataset:
    true_action = re.sub(r'.*(?P<direction>north|south|east|west).*', r'move \g<direction>', n['action'], flags=re.IGNORECASE)
    # print(f'True action: {true_action}')
    
    # first simple method
    pattern = '|'.join(map(re.escape, action_list))
    match = re.search(pattern, n['text'], flags=re.IGNORECASE)
    
    if match and match.group().lower().strip() == true_action:
        first_method_correct += 1
    
    # second method: if first fails try with fuzzy matching
    processed_action_list = [r'[^\.]*?'.join(a.split(' ')) for a in action_list]
    pattern = '|'.join(processed_action_list)
    if match is None:
        match = re.search(pattern, n['text'], flags=re.IGNORECASE) 
    
    if match:
        s = process.extractOne(match.group().lower().strip(), action_list, scorer=fuzz.token_set_ratio)[0]
    else:
        s = None
    
    if s == true_action:
        second_method_correct += 1
        
print('First method accuracy:', first_method_correct / len(dataset))
print('Second method accuracy:', second_method_correct / len(dataset))      


First method accuracy: 0.83
Second method accuracy: 0.95


In [106]:
def find_action_in_sentence_and_create_action_choices___2(generated_action, action_list, fixtemplate_action_choices):
        #extract the action from the generated_action.
        #substitute the action position with the other possible actions.
        #if no match is found return the defult action prompts.
        pattern = '|'.join(map(re.escape, action_list))
        match = re.search(pattern, generated_action, flags=re.IGNORECASE)
        if match:
            match_found = 1
            matched_action = match.group()
            generated_action_choices = [generated_action.replace(matched_action, action) for action in action_list]
            return generated_action_choices, match_found, matched_action
        
        processed_action_list = [r'[^\.]*?'.join(a.split(' ')) for a in action_list]
        pattern = '|'.join(processed_action_list)
        match = re.search(pattern, generated_action, flags=re.IGNORECASE)
        
        if match:
            match_found = 1
            matched_action = match.group()
            matched_action = process.extractOne(matched_action.lower().strip(), action_list, scorer=fuzz.token_set_ratio)[0]
            generated_action_choices = [generated_action.replace(matched_action, action) for action in action_list]
            return generated_action_choices, match_found, matched_action
        
        match_found = 0
        generated_action_choices = fixtemplate_action_choices
        matched_action = None
        
        return generated_action_choices, match_found, matched_action