In [1]:
import sys

# Put the parent directory on the import path; presumably this is where the
# project packages imported below (`agents`, `environments`) live.
# Guarded so that re-running this cell does not append duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import os
import time
import re
from dotenv import load_dotenv
from openai import OpenAI

from agents.alfworld_llm_policy import ALFWorldLLMPolicyAgent
from environments.ALFWorldEnvironment import ALFWorldEnvironment

# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()

# Chat client configured from the environment; a lookup yields None when the
# variable is not set.
client = OpenAI(
    api_key=os.environ.get("CUSTOM_API_KEY"),
    base_url=os.environ.get("CUSTOM_BASE_URL"),
)

  from tqdm.autonotebook import tqdm, trange


In [3]:
def query_llm(system_prompt, user_prompt, model, max_retries=None, retry_delay=1):
    '''
    Query the LLM with a system prompt and a user prompt.

    Args:
        system_prompt: Content of the "system" message.
        user_prompt: Content of the "user" message.
        model: Model identifier passed to the chat-completions call.
        max_retries: Number of retries allowed after the first failed call.
            None (the default) keeps retrying indefinitely.
        retry_delay: Seconds to sleep between attempts (default 1).

    Returns:
        The message content of the first choice in the response.

    Raises:
        Exception: the last API error, re-raised once `max_retries` retries
            have been used up (never raised when `max_retries` is None).
    '''
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    failures = 0
    while True:
        try:
            response = client.chat.completions.create(model=model, messages=messages)
            break
        except Exception as e:
            failures += 1
            if max_retries is not None and failures > max_retries:
                # Give up instead of looping forever on an error that will not go away.
                raise
            print(f"Error calling API: {e}, retrying...")
            time.sleep(retry_delay)

    # grab the content of the first choice (only one choice is returned)
    return response.choices[0].message.content

def save_to_file(response, log_path='output_logs.txt'):
    '''
    Append a response to a log file, preceded by a separator line.

    Args:
        response: Object to log; written with print(), so it is converted
            to text with str().
        log_path: File to append to (default 'output_logs.txt').
    '''
    # Explicit UTF-8 so non-ASCII model output is written the same way on every
    # platform instead of depending on the default locale encoding.
    with open(log_path, 'a', encoding='utf-8') as logs:
        print('----------------------------------------', file=logs)
        print(response, file=logs)
    
def get_user_prompt(state, action):
    '''
    Build the user prompt from a state description and an action.

    The result is the "**Current State:**" header on its own line, the state
    text, then "**Action:**" immediately followed by the action text and a
    trailing newline.
    '''
    pieces = ("**Current State:**\n", state, "\n**Action:**", action, "\n")
    return "".join(pieces)

def extract_state(response: str):
    '''
    Extract the next state from the LLM response.

    The text after a "next state:" label (case-insensitive) is looked up in
    three steps:
      1. label followed by a closing ``` fence -> text between the two;
      2. bold label "**next state:**" without a fence -> rest of the response;
      3. plain label without a fence -> rest of the response.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        A (next_state, status) tuple. status is "success" for step 1,
        "success on fallback regex" for steps 2 and 3, and "error" when no
        label is found, in which case the unmodified response is returned.
        Extracted states have surrounding whitespace stripped.
    '''
    flags = re.DOTALL | re.IGNORECASE
    # Optional ** on either side of the label so a bold label inside a fenced
    # block does not leak a stray "**" into the extracted state.
    fenced_regex = r"(?:\*\*)?next state:(?:\*\*)?(.*?)```"
    # Order matters: the bold form must be tried before the plain one.
    fallback_regexes = [r"\*\*next state:\*\*(.*)", r"next state:(.*)"]

    match = re.search(fenced_regex, response, flags)
    if match is not None:
        # The label is normally followed by a newline; drop that padding.
        return match.group(1).strip(), "success"

    for regex in fallback_regexes:
        match = re.search(regex, response, flags)
        if match is not None:
            return match.group(1).strip(), "success on fallback regex"

    print("Error: No match found with fallback regex, using full response as next state")
    return response, "error"

In [5]:
# Model queried with the transition prompt below.
model = 'open-mixtral-8x22b'
# model = 'mistral-large-2407'

# Read the prompt as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open('../prompts/prompt_alfworld_transition.txt', 'r', encoding='utf-8') as f:
    system_prompt = f.read()

# Hand-written state/action pair used to try the prompt on a single transition.
state = "You arrive at desk 1. On the desk 1, you see a book 1, a notepad 1, a pen 1, and a pencil 1."
action = "take book 1 from desk 1"
user_prompt = get_user_prompt(state, action)

In [6]:
# Send the example prompt once and show the raw model output.
response = query_llm(
    system_prompt=system_prompt,
    user_prompt=user_prompt,
    model=model,
)
print(response)

**Reasoning**
The given action follows the action format for **Pick up an Object**. Thus, in the next state, we expect the agent to pick up the book 1 from desk 1.

```plaintext
Next State:
You pick up the book 1 from the desk 1.
```


In [None]:
# Roll out the policy agent in the environment and, at every step, compare the
# LLM-predicted next state with the state the environment actually returns.
env = ALFWorldEnvironment(config_path='../configs/alfworld_env.yaml')
agent = ALFWorldLLMPolicyAgent(
    env,
    device='cuda',
    llm_model='gpt-4o-mini',
    prompt_buffer_prefix="alfworld",
    env_params={"system_prompt_path": "../prompts/prompt_alfworld_policy.txt"},
)

obs, _ = env.reset()
state = obs
print('Initial State: ', state['text_state'])

num_steps = 10
for _ in range(num_steps):
    action = agent.act(state)
    print('Action: ', action)

    # Predict the transition from the current text state before stepping.
    user_prompt = get_user_prompt(state['text_state'], action)
    pred_state = query_llm(system_prompt, user_prompt, model)
    pred_state, _ = extract_state(pred_state)
    print('Predicted Next State', pred_state)

    # NOTE(review): the ignored fourth value is presumably a "truncated" flag — confirm.
    state, reward, done, _, info = env.step(action)
    print('Actual Next State', state['text_state'])
    print('-----------------------------\n')

    # Stop once the episode is over rather than stepping a finished environment.
    if done:
        break


Initializing AlfredTWEnv...


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 416.10it/s]

Overall we have 1 games in split=train
Training with 1 games





Initial State:  -= Welcome to TextWorld, ALFRED! =-

You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 1, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a shelf 2, a shelf 1, and a sidetable 1.

Your task is to: put a alarmclock in desk.
Action:  go to sidetable 1
Predicted Next State 
You arrive at sidetable 1. On the sidetable 1, you see a alarmclock 1.

Actual Next State You arrive at loc 20. On the sidetable 1, you see a alarmclock 3, a alarmclock 2, a alarmclock 1, a creditcard 1, a desklamp 1, a keychain 1, a pen 3, a pen 2, and a pencil 3.
-----------------------------

Action:  take alarmclock 1 from sidetable 1
Predicted Next State 
You pick up the alarmclock 1 from the sidetable 1.

Actual Next State You pick up the alarmclock 1 from the sidetable 1.
-----------------------------

Action:  put alarmclock 1 in/on sidetable 1
Predicted Next State 
You put the alarmclock 1 on the sidetable 1.

Actual Next State You put t