# Test accuracy of LLM Transition Model

In [1]:
#enable autoreload
%load_ext autoreload
%autoreload 2

In [31]:
import sys

sys.path.append('..')

from environments.ElevatorEnvironment import ElevatorEnvironment
from agents.llmzero import LLMTransitionModel
from agents.random_agent import RandomAgent
from agents.elevator_expert import ElevatorExpertPolicyAgent

import numpy as np
import random
import matplotlib.pyplot as plt

Load default elevator environment

In [4]:
env = ElevatorEnvironment()

c:\Users\ianch\miniconda3\envs\aiplanning\Lib\site-packages\pyRDDLGym\Examples c:\Users\ianch\miniconda3\envs\aiplanning\Lib\site-packages\pyRDDLGym\Examples\manifest.csv
Available example environment(s):
CartPole_continuous -> A simple continuous state-action MDP for the classical cart-pole system by Rich Sutton, with actions that describe the continuous force applied to the cart.
CartPole_discrete -> A simple continuous state MDP for the classical cart-pole system by Rich Sutton, with discrete actions that apply a constant force on either the left or right side of the cart.
Elevators -> The Elevator domain models evening rush hours when people from different floors in a building want to go down to the bottom floor using elevators.
HVAC -> Multi-zone and multi-heater HVAC control problem
MarsRover -> Multi Rover Navigation, where a group of agent needs to harvest mineral.
MountainCar -> A simple continuous MDP for the classical mountain car control problem.
NewLanguage -> Example with

<op> is one of {<=, <, >=, >}
<rhs> is a deterministic function of non-fluents or constants only.
>> ( sum_{?f: floor} [ elevator-at-floor(?e, ?f) ] ) == 1


Define agent to generate trajectories for testing and initialize transition model

In [5]:
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

In [6]:
random_agent = RandomAgent(env, seed=SEED)
expert_agent = ElevatorExpertPolicyAgent()

params = {
    "env_params": {
        "system_prompt_path": "../prompts/prompt_elevator_transition.txt",
        "extract_state_regex": r"next state:(.*?)```",
        "extract_regex_fallback": [r"\*\*next state:\*\*(.*)", r"next state:(.*)"],
    },
    "load_prompt_buffer_path": None, # update this path to the path of the saved prompt buffer   
}

transition_model = LLMTransitionModel(**params, prompt_buffer_prefix="elevator_transition", debug=True)

Generate trajectories for both agents

In [7]:
# random agent
state, _ = env.reset(SEED)
done = False

random_agent_trajectory = []

while not done:
    action = random_agent.act(state)
    next_state, reward, done, _, _ = env.step(action)
    random_agent_trajectory.append((state, action, reward, next_state, done))
    state = next_state
    
# expert agent
state, _ = env.reset(SEED + 1)  # use a different seed for the expert agent
done = False

expert_agent_trajectory = []

while not done:
    action = expert_agent.act(state)
    next_state, reward, done, _, _ = env.step(action)
    expert_agent_trajectory.append((state, action, reward, next_state, done))
    state = next_state

In [8]:
trajectories_combined = random_agent_trajectory + expert_agent_trajectory

### Sanity check test

In [33]:
# test the transition model
trajectory = random.choice(trajectories_combined)
state_test, action, next_state_test = trajectory[0], trajectory[1], trajectory[3]
state_text, action_text, next_state_text = env.state_to_text(state_test), env.action_to_text(action), env.state_to_text(next_state_test)

print("State text:\n", state_text)
print("Action text:", action_text)
print("Next state text:\n", next_state_text)

State text:
 People waiting at floor 2: 3
People waiting at floor 3: 3
People waiting at floor 4: 1
People waiting at floor 5: 3
Elevator at floor 3.
There are 4 people in the elevator.
Elevator is moving down.
Elevator door is closed.

Action text: move
Next state text:
 People waiting at floor 2: 3
People waiting at floor 3: 3
People waiting at floor 4: 1
People waiting at floor 5: 3
Elevator at floor 2.
There are 4 people in the elevator.
Elevator is moving down.
Elevator door is closed.



In [63]:
predicted_next_state_text, status = transition_model.get_next_state(state_text, action_text)



In [62]:
print(predicted_next_state_text)


People waiting at floor 2: 2
People waiting at floor 3: 3
People waiting at floor 4: 0
People waiting at floor 5: 2
Elevator at floor 2.
There are 4 people in the elevator.
Elevator is moving down.
Elevator door is closed.


### Get results for all trajectories

In [12]:
import re

def try_get_int(text):
    try:
        return int(text)
    except:
        return None

def parse_state_text(state_text):
    lines = state_text.split("\n")
    people_waiting = []
    elevator_floor = None
    people_in_elevator = None
    elevator_direction = None
    door_status = None
    
    for line in lines:
        line = line.strip().lower()
        # assuming the lines are in order
        if "people waiting at floor" in line:
            split = line.split(":")
            # print("waiting", split)
            waiting = split[1].strip() if len(split) == 2 else None
                
            waiting = try_get_int(waiting)            
            people_waiting.append(waiting)
            
        if "elevator at floor" in line:
            split = line.split("floor")
            # print("floor", split)
            elevator_floor = split[1].strip() if len(split) == 2 else None
            if '.' in elevator_floor:
                elevator_floor = elevator_floor.split(".")[0]
            elevator_floor = try_get_int(elevator_floor)
            
        if "people in the elevator" in line:
            #use regex to find the only number in the line
            people_in_elevator = re.findall(r'\d+', line)
            people_in_elevator = int(people_in_elevator[0]) if len(people_in_elevator) else None
            
        if "elevator is moving" in line:
            elevator_direction = "up" if "up" in line else "down" if "down" in line else None
            
        if "elevator door is" in line:
            door_status = "open" if "open" in line else "closed" if "closed" in line else None
            
    # check for parse failure
    failure_flag = False
    
    if len(people_waiting) == 0 or None in people_waiting:
        failure_flag = True
    if None in [elevator_floor, people_in_elevator, elevator_direction, door_status]:
        failure_flag = True
        
    out = {
        "people_waiting": people_waiting,
        "elevator_floor": elevator_floor,
        "people_in_elevator": people_in_elevator,
        "elevator_direction": elevator_direction,
        "door_status": door_status,
        "failure_flag": failure_flag
    }  
    
    return out           
        

In [13]:
# sanity check for ground truth states
states = [trajectory[0] for trajectory in trajectories_combined]
state_texts = [env.state_to_text(state) for state in states]

for state_text in state_texts:
    parsed_state = parse_state_text(state_text)
    if parsed_state["failure_flag"]:
        print("Failed to parse state text")
        print(state_text)
        print(parsed_state)
else:
    print("All states parsed successfully")
    print(parsed_state)

All states parsed successfully
{'people_waiting': [1, 1, 2, 2], 'elevator_floor': 4, 'people_in_elevator': 0, 'elevator_direction': 'up', 'door_status': 'closed', 'failure_flag': False}


In [14]:
def eval_prediction(current_state_text, next_state_text, predicted_next_state_text):
    '''
    Evaluate the correctness of the predicted next state.
    
    For people waiting:
        - check if the number of people is not above 3.
        - check if the number is equal or more than the current state.
            - If it is more, check if the elevator picked up the passengers on that floor.
        
    For elevator floor, people in elevator, elevator direction, and door status:
        - check if the predicted values are the same as the ground truth values
        
    Returns a dictionary indicating if each field is correct or not.
    '''
    
    state = parse_state_text(current_state_text)
    next_state = parse_state_text(next_state_text)
    predicted_next_state = parse_state_text(predicted_next_state_text)
            
    elevator_floor_correct = next_state["elevator_floor"] == predicted_next_state["elevator_floor"]
    people_in_elevator_correct = next_state["people_in_elevator"] == predicted_next_state["people_in_elevator"]
    elevator_direction_correct = next_state["elevator_direction"] == predicted_next_state["elevator_direction"]
    door_status_correct = next_state["door_status"] == predicted_next_state["door_status"]
    
    people_waiting_correct = False
    
    if len(next_state["people_waiting"]) == len(predicted_next_state["people_waiting"]):
        for i, (waiting, predicted_waiting) in enumerate(zip(state["people_waiting"], state["people_waiting"])):
            if predicted_waiting is None:
                people_waiting_correct = False
                break
            if  predicted_waiting > 3:
                people_waiting_correct = False
                break
            if predicted_waiting < waiting:
                #check if the elevator picked up the passengers on that floor
                #pre-requisite: previous status all correct
                if not all([elevator_floor_correct, people_in_elevator_correct, elevator_direction_correct, door_status_correct]):
                    people_waiting_correct = False
                    break
                else:
                    floor = i + 2
                    
                    cond_1 = next_state["elevator_floor"] == floor
                    cond_2 = next_state["door_status"] == "open"
                    cond_3 = predicted_waiting == 0
                    
                    if not all([cond_1, cond_2, cond_3]):
                        people_waiting_correct = False
                        break 
                                  
        else:
            people_waiting_correct = True
    
    
    has_parse_failure = predicted_next_state["failure_flag"]
    
    out = {
        "people_waiting": people_waiting_correct,
        "elevator_floor": elevator_floor_correct,
        "people_in_elevator": people_in_elevator_correct,
        "elevator_direction": elevator_direction_correct,
        "door_status": door_status_correct,
        "has_parse_failure": has_parse_failure
    }
    
    all_correct = all([v for k, v in out.items() if k != "has_parse_failure"])
    out["all_correct"] = all_correct
    
    return out    

In [16]:
# sanity test with gt next state
next_states = [trajectory[3] for trajectory in trajectories_combined]
next_state_texts = [env.state_to_text(state) for state in next_states]

for state_text, next_state_text in zip(state_texts, next_state_texts):
    predicted_next_state_text = next_state_text
    prediction_eval = eval_prediction(state_text, next_state_text, predicted_next_state_text)
    
    #should be all correct
    if not prediction_eval["all_correct"]:
        print("Failed to predict next state")
        print(state_text)
        print(next_state_text)
        print(predicted_next_state_text)
        print(prediction_eval)
        
        break
    
else:
    print("All predictions correct")

All predictions correct


In [105]:
import tqdm
import time

predicted_states = []
status_list = []
eval_results = []
corrects = 0
total = 0

pbar = tqdm.tqdm(trajectories_combined)

transition_model.debug = False

for trajectory in pbar:
    state, action, reward, next_state, done = trajectory
    state_text, action_text, next_state_text = env.state_to_text(state), env.action_to_text(action), env.state_to_text(next_state)
    
    while True:
        try:
            predicted_next_state_text, status = transition_model.get_next_state(state_text, action_text)
            break
        except:
            print("Retrying..")
            time.sleep(1)
        
    
    predicted_states.append(predicted_next_state_text)
    status_list.append(status)
    
    eval_result = eval_prediction(state_text, next_state_text, predicted_next_state_text)
    
    eval_results.append(eval_result)
    
    if eval_result["all_correct"]:
        corrects += 1
    total += 1
    
    pbar.set_description(f"Correct: {corrects}/{total}")
    

Correct: 39/61:   0%|          | 0/400 [00:00<?, ?it/s]

Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state


Correct: 40/64:  16%|█▌        | 64/400 [00:06<00:39,  8.59it/s]

Error: No match found with fallback regex, using full response as next state


Correct: 41/65:  16%|█▋        | 65/400 [00:09<01:10,  4.74it/s]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 46/73:  18%|█▊        | 73/400 [00:35<10:00,  1.84s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 48/75:  19%|█▉        | 75/400 [00:42<13:21,  2.47s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 49/76:  19%|█▉        | 76/400 [00:46<15:18,  2.83s/it]

Retrying..
Retrying..
Retrying..
Retrying..


Correct: 49/77:  19%|█▉        | 77/400 [01:03<36:26,  6.77s/it]

Retrying..


Correct: 50/79:  20%|█▉        | 79/400 [01:14<31:06,  5.82s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 52/82:  20%|██        | 82/400 [01:24<21:06,  3.98s/it]

Retrying..


Correct: 54/86:  22%|██▏       | 86/400 [01:41<22:05,  4.22s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 56/90:  22%|██▏       | 87/400 [01:45<20:17,  3.89s/it]

Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state


Correct: 57/92:  23%|██▎       | 92/400 [01:54<13:04,  2.55s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 65/104:  26%|██▌       | 104/400 [02:23<10:52,  2.20s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 71/112:  28%|██▊       | 112/400 [02:44<11:15,  2.34s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 75/117:  29%|██▉       | 117/400 [03:07<18:51,  4.00s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 81/127:  32%|███▏      | 127/400 [03:50<19:17,  4.24s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 85/134:  33%|███▎      | 132/400 [04:07<14:36,  3.27s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 88/143:  36%|███▌      | 143/400 [04:33<12:58,  3.03s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 109/167:  42%|████▏     | 166/400 [05:11<05:12,  1.34s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 109/169:  42%|████▏     | 168/400 [05:17<06:40,  1.73s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 116/183:  46%|████▌     | 183/400 [05:56<08:18,  2.30s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl
Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 116/184:  46%|████▌     | 184/400 [06:01<10:23,  2.89s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 121/191:  48%|████▊     | 191/400 [06:32<13:50,  3.97s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 122/193:  48%|████▊     | 193/400 [06:40<12:55,  3.74s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl
Error: No match found with fallback regex, using full response as next state


Correct: 129/204:  51%|█████     | 204/400 [07:05<08:11,  2.51s/it]

Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 133/208:  52%|█████▏    | 208/400 [07:24<12:03,  3.77s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 140/218:  55%|█████▍    | 218/400 [08:07<13:06,  4.32s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 145/228:  57%|█████▋    | 228/400 [08:50<12:02,  4.20s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 147/231:  58%|█████▊    | 231/400 [09:01<10:36,  3.77s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 150/236:  59%|█████▉    | 236/400 [09:22<09:42,  3.55s/it]

Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 151/238:  60%|█████▉    | 238/400 [09:31<10:30,  3.89s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl
Error: No match found with fallback regex, using full response as next state


Correct: 151/239:  60%|█████▉    | 239/400 [09:35<10:45,  4.01s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 156/248:  62%|██████▏   | 248/400 [10:15<10:44,  4.24s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 163/258:  64%|██████▍   | 258/400 [11:00<10:29,  4.43s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 172/270:  68%|██████▊   | 270/400 [11:44<08:01,  3.70s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 173/273:  68%|██████▊   | 273/400 [11:57<08:08,  3.85s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 178/279:  70%|██████▉   | 279/400 [12:23<07:47,  3.86s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 179/280:  70%|███████   | 280/400 [12:28<08:53,  4.44s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 182/284:  71%|███████   | 284/400 [12:44<07:23,  3.82s/it]

Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 186/290:  72%|███████▎  | 290/400 [13:11<07:38,  4.17s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 192/300:  75%|███████▌  | 300/400 [13:52<06:06,  3.67s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl
Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 193/305:  76%|███████▋  | 305/400 [14:15<05:43,  3.62s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 193/306:  76%|███████▋  | 306/400 [14:17<04:58,  3.18s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 193/307:  77%|███████▋  | 307/400 [14:21<05:06,  3.30s/it]

Retrying..


Correct: 195/310:  78%|███████▊  | 310/400 [14:37<06:28,  4.32s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 205/325:  81%|████████▏ | 325/400 [15:18<03:39,  2.93s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl
Error: No match found with fallback regex, using full response as next state


Correct: 213/335:  84%|████████▍ | 335/400 [16:03<04:27,  4.12s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 213/336:  84%|████████▍ | 336/400 [16:05<03:55,  3.68s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 220/345:  86%|████████▋ | 345/400 [16:44<03:39,  3.99s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 227/355:  89%|████████▉ | 355/400 [17:25<03:14,  4.31s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 234/365:  91%|█████████▏| 365/400 [18:07<02:29,  4.26s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 244/378:  94%|█████████▍| 378/400 [18:50<01:23,  3.80s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 249/385:  96%|█████████▌| 383/400 [19:11<01:11,  4.20s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 249/388:  97%|█████████▋| 388/400 [19:22<00:32,  2.72s/it]

Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 250/390:  98%|█████████▊| 390/400 [19:34<00:39,  3.98s/it]

Saving prompt buffer to elevator_transition_20241108_013120.pkl


Correct: 250/391:  98%|█████████▊| 391/400 [19:35<00:29,  3.26s/it]

Error: No match found with fallback regex, using full response as next state
Retrying..


Correct: 252/396:  99%|█████████▉| 395/400 [19:54<00:20,  4.14s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 256/400: 100%|██████████| 400/400 [20:11<00:00,  3.03s/it]


In [106]:
accuracy = corrects / total
print("Accuracy:", accuracy)

Accuracy: 0.64


### Explore the failure cases

In [17]:
import pickle

# states_text = [env.state_to_text(state) for state in states]
# zipped = list(zip(states_text, predicted_states, eval_results))

# # save the results

# with open("elevator_transition_results_mistral.pkl", "wb") as f:
#     pickle.dump(zipped, f)

with open("elevator_transition_results_mistral.pkl", "rb") as f:
    zipped = pickle.load(f)

In [18]:
parse_failures = [z for z in zipped if z[2]["has_parse_failure"]]

print(f"Parse failures: {len(parse_failures)} out of {len(zipped)}")

Parse failures: 46 out of 400


In [117]:
parse_fail = parse_failures[3]
state_text, predicted_next_state_text, eval_result = parse_fail
print(state_text)
print(predicted_next_state_text)
print(eval_result)

People waiting at floor 2: 2
People waiting at floor 3: 1
People waiting at floor 4: 1
People waiting at floor 5: 3
Elevator at floor 1.
There are 0 people in the elevator.
Elevator is moving up.
Elevator door is closed.


People waiting at floor 2: 3 (2 existing + 1 new)
People waiting at floor 3: 1
People waiting at floor 4: 2 (1 existing + 1 new)
People waiting at floor 5: 3
Elevator at floor 1.
There are 0 people in the elevator.
Elevator is moving up.
Elevator door is closed.
{'people_waiting': True, 'elevator_floor': True, 'people_in_elevator': True, 'elevator_direction': True, 'door_status': True, 'has_parse_failure': True, 'all_correct': True}


Rerun with OPENAI

In [None]:
# change .env to use openai and restart the kernel

transition_model = LLMTransitionModel(**params, prompt_buffer_prefix="elevator_transition", debug=False)
transition_model.llm_model

'gpt-4o-mini'

In [35]:
print(os.getenv("USE_OPENAI_CUSTOM"))

False


In [None]:
predicted_states = []
status_list = []
eval_results = []
corrects = 0
total = 0

pbar = tqdm.tqdm(trajectories_combined)

transition_model.debug = False

for trajectory in pbar:
    state, action, reward, next_state, done = trajectory
    state_text, action_text, next_state_text = env.state_to_text(state), env.action_to_text(action), env.state_to_text(next_state)
    
    while True:
        try:
            predicted_next_state_text, status = transition_model.get_next_state(state_text, action_text)
            break
        except:
            print("Retrying..")
            time.sleep(1)
        
    
    predicted_states.append(predicted_next_state_text)
    status_list.append(status)
    
    eval_result = eval_prediction(state_text, next_state_text, predicted_next_state_text)
    
    eval_results.append(eval_result)
    
    if eval_result["all_correct"]:
        corrects += 1
    total += 1
    
    pbar.set_description(f"Correct: {corrects}/{total}")