# Test accuracy of LLM Transition Model

In [1]:
#enable autoreload
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append('..')

from environments.ElevatorEnvironment import ElevatorEnvironment
from agents.llmzero import LLMTransitionModel
from agents.random_agent import RandomAgent
from agents.elevator_expert import ElevatorExpertPolicyAgent

import numpy as np
import random
import matplotlib.pyplot as plt

  from pyRDDLGym.Visualizer.MovieGenerator import MovieGenerator
  from tqdm.autonotebook import tqdm, trange


Load default elevator environment

In [3]:
env = ElevatorEnvironment()

c:\Users\ianch\miniconda3\envs\aiplanning\Lib\site-packages\pyRDDLGym\Examples c:\Users\ianch\miniconda3\envs\aiplanning\Lib\site-packages\pyRDDLGym\Examples\manifest.csv
Available example environment(s):
CartPole_continuous -> A simple continuous state-action MDP for the classical cart-pole system by Rich Sutton, with actions that describe the continuous force applied to the cart.
CartPole_discrete -> A simple continuous state MDP for the classical cart-pole system by Rich Sutton, with discrete actions that apply a constant force on either the left or right side of the cart.
Elevators -> The Elevator domain models evening rush hours when people from different floors in a building want to go down to the bottom floor using elevators.
HVAC -> Multi-zone and multi-heater HVAC control problem
MarsRover -> Multi Rover Navigation, where a group of agent needs to harvest mineral.
MountainCar -> A simple continuous MDP for the classical mountain car control problem.
NewLanguage -> Example with

<op> is one of {<=, <, >=, >}
<rhs> is a deterministic function of non-fluents or constants only.
>> ( sum_{?f: floor} [ elevator-at-floor(?e, ?f) ] ) == 1


Define agent to generate trajectories for testing and initialize transition model

In [5]:
SEED = 117
np.random.seed(SEED)
random.seed(SEED)

In [6]:
random_agent = RandomAgent(env, seed=SEED)
expert_agent = ElevatorExpertPolicyAgent()

params = {
    "env_params": {
        "system_prompt_path": "../prompts/prompt_elevator_transition.txt",
        "extract_state_regex": r"next state:(.*?)```",
        "extract_regex_fallback": [r"\*\*next state:\*\*(.*)", r"next state:(.*)"],
    },
    "load_prompt_buffer_path": None, # update this path to the path of the saved prompt buffer   
}

transition_model = LLMTransitionModel(**params, prompt_buffer_prefix="elevator_transition_mistral", debug=True)

Generate trajectories for both agents

In [7]:
# random agent
state, _ = env.reset(SEED)
done = False

random_agent_trajectory = []

while not done:
    action = random_agent.act(state)
    next_state, reward, done, _, _ = env.step(action)
    random_agent_trajectory.append((state, action, reward, next_state, done))
    state = next_state
    
# expert agent
state, _ = env.reset(SEED + 1)  # use a different seed for the expert agent
done = False

expert_agent_trajectory = []

while not done:
    action = expert_agent.act(state)
    next_state, reward, done, _, _ = env.step(action)
    expert_agent_trajectory.append((state, action, reward, next_state, done))
    state = next_state

In [8]:
trajectories_combined = random_agent_trajectory + expert_agent_trajectory

In [10]:
import pickle

#load "elevator_transition_results_gpt.pkl
with open("elevator_transition_results_gpt.pkl", "rb") as f:
    results = pickle.load(f)
    
state_text_generated = [t[0] for t in trajectories_combined]
state_text_loaded = [t[0] for t in results]
state_text_generated = [env.state_to_text(s) for s in state_text_generated]
state_text_loaded == state_text_generated

False

### Sanity check test

In [33]:
# test the transition model
trajectory = random.choice(trajectories_combined)
state_test, action, next_state_test = trajectory[0], trajectory[1], trajectory[3]
state_text, action_text, next_state_text = env.state_to_text(state_test), env.action_to_text(action), env.state_to_text(next_state_test)

print("State text:\n", state_text)
print("Action text:", action_text)
print("Next state text:\n", next_state_text)

State text:
 People waiting at floor 2: 3
People waiting at floor 3: 3
People waiting at floor 4: 1
People waiting at floor 5: 3
Elevator at floor 3.
There are 4 people in the elevator.
Elevator is moving down.
Elevator door is closed.

Action text: move
Next state text:
 People waiting at floor 2: 3
People waiting at floor 3: 3
People waiting at floor 4: 1
People waiting at floor 5: 3
Elevator at floor 2.
There are 4 people in the elevator.
Elevator is moving down.
Elevator door is closed.



In [63]:
predicted_next_state_text, status = transition_model.get_next_state(state_text, action_text)



In [62]:
print(predicted_next_state_text)


People waiting at floor 2: 2
People waiting at floor 3: 3
People waiting at floor 4: 0
People waiting at floor 5: 2
Elevator at floor 2.
There are 4 people in the elevator.
Elevator is moving down.
Elevator door is closed.


### Get results for all trajectories

In [11]:
import re

def try_get_int(text):
    try:
        return int(text)
    except:
        return None

def parse_state_text(state_text):
    lines = state_text.split("\n")
    people_waiting = []
    elevator_floor = None
    people_in_elevator = None
    elevator_direction = None
    door_status = None
    
    for line in lines:
        line = line.strip().lower()
        # assuming the lines are in order
        if "people waiting at floor" in line:
            split = line.split(":")
            # print("waiting", split)
            waiting = split[1].strip() if len(split) == 2 else None
                
            waiting = try_get_int(waiting)            
            people_waiting.append(waiting)
            
        if "elevator at floor" in line:
            split = line.split("floor")
            # print("floor", split)
            elevator_floor = split[1].strip() if len(split) == 2 else None
            if '.' in elevator_floor:
                elevator_floor = elevator_floor.split(".")[0]
            elevator_floor = try_get_int(elevator_floor)
            
        if "people in the elevator" in line:
            #use regex to find the only number in the line
            people_in_elevator = re.findall(r'\d+', line)
            people_in_elevator = int(people_in_elevator[0]) if len(people_in_elevator) else None
            
        if "elevator is moving" in line:
            elevator_direction = "up" if "up" in line else "down" if "down" in line else None
            
        if "elevator door is" in line:
            door_status = "open" if "open" in line else "closed" if "closed" in line else None
            
    # check for parse failure
    failure_flag = False
    
    if len(people_waiting) == 0 or None in people_waiting:
        failure_flag = True
    if None in [elevator_floor, people_in_elevator, elevator_direction, door_status]:
        failure_flag = True
        
    out = {
        "people_waiting": people_waiting,
        "elevator_floor": elevator_floor,
        "people_in_elevator": people_in_elevator,
        "elevator_direction": elevator_direction,
        "door_status": door_status,
        "failure_flag": failure_flag
    }  
    
    return out           
        

In [12]:
# sanity check for ground truth states
states = [trajectory[0] for trajectory in trajectories_combined]
state_texts = [env.state_to_text(state) for state in states]

for state_text in state_texts:
    parsed_state = parse_state_text(state_text)
    if parsed_state["failure_flag"]:
        print("Failed to parse state text")
        print(state_text)
        print(parsed_state)
else:
    print("All states parsed successfully")
    print(parsed_state)

All states parsed successfully
{'people_waiting': [1, 0, 0, 3], 'elevator_floor': 1, 'people_in_elevator': 0, 'elevator_direction': 'up', 'door_status': 'closed', 'failure_flag': False}


In [13]:
def eval_prediction(current_state_text, next_state_text, predicted_next_state_text):
    '''
    Evaluate the correctness of the predicted next state.
    
    For people waiting:
        - check if the number of people is not above 3.
        - check if the number is equal or more than the current state.
            - If it is more, check if the elevator picked up the passengers on that floor.
        
    For elevator floor, people in elevator, elevator direction, and door status:
        - check if the predicted values are the same as the ground truth values
        
    Returns a dictionary indicating if each field is correct or not.
    '''
    
    state = parse_state_text(current_state_text)
    next_state = parse_state_text(next_state_text)
    predicted_next_state = parse_state_text(predicted_next_state_text)
            
    elevator_floor_correct = next_state["elevator_floor"] == predicted_next_state["elevator_floor"]
    people_in_elevator_correct = next_state["people_in_elevator"] == predicted_next_state["people_in_elevator"]
    elevator_direction_correct = next_state["elevator_direction"] == predicted_next_state["elevator_direction"]
    door_status_correct = next_state["door_status"] == predicted_next_state["door_status"]
    
    people_waiting_correct = False
    
    if len(next_state["people_waiting"]) == len(predicted_next_state["people_waiting"]):
        for i, (waiting, predicted_waiting) in enumerate(zip(state["people_waiting"], state["people_waiting"])):
            if predicted_waiting is None:
                people_waiting_correct = False
                break
            if  predicted_waiting > 3:
                people_waiting_correct = False
                break
            if predicted_waiting < waiting:
                #check if the elevator picked up the passengers on that floor
                #pre-requisite: previous status all correct
                if not all([elevator_floor_correct, people_in_elevator_correct, elevator_direction_correct, door_status_correct]):
                    people_waiting_correct = False
                    break
                else:
                    floor = i + 2
                    
                    cond_1 = next_state["elevator_floor"] == floor
                    cond_2 = next_state["door_status"] == "open"
                    cond_3 = predicted_waiting == 0
                    
                    if not all([cond_1, cond_2, cond_3]):
                        people_waiting_correct = False
                        break 
                                  
        else:
            people_waiting_correct = True
    
    
    has_parse_failure = predicted_next_state["failure_flag"]
    
    out = {
        "people_waiting": people_waiting_correct,
        "elevator_floor": elevator_floor_correct,
        "people_in_elevator": people_in_elevator_correct,
        "elevator_direction": elevator_direction_correct,
        "door_status": door_status_correct,
        "has_parse_failure": has_parse_failure
    }
    
    all_correct = all([v for k, v in out.items() if k != "has_parse_failure"])
    out["all_correct"] = all_correct
    
    return out    

In [14]:
# sanity test with gt next state
next_states = [trajectory[3] for trajectory in trajectories_combined]
next_state_texts = [env.state_to_text(state) for state in next_states]

for state_text, next_state_text in zip(state_texts, next_state_texts):
    predicted_next_state_text = next_state_text
    prediction_eval = eval_prediction(state_text, next_state_text, predicted_next_state_text)
    
    #should be all correct
    if not prediction_eval["all_correct"]:
        print("Failed to predict next state")
        print(state_text)
        print(next_state_text)
        print(predicted_next_state_text)
        print(prediction_eval)
        
        break
    
else:
    print("All predictions correct")

All predictions correct


In [15]:
transition_model.llm_model

'open-mixtral-8x22b'

In [16]:
import tqdm
import time

predicted_states = []
status_list = []
eval_results = []
corrects = 0
total = 0

pbar = tqdm.tqdm(trajectories_combined)

transition_model.debug = False

for trajectory in pbar:
    state, action, reward, next_state, done = trajectory
    state_text, action_text, next_state_text = env.state_to_text(state), env.action_to_text(action), env.state_to_text(next_state)
    
    while True:
        try:
            predicted_next_state_text, status = transition_model.get_next_state(state_text, action_text)
            break
        except:
            print("Retrying..")
            time.sleep(1)
        
    
    predicted_states.append(predicted_next_state_text)
    status_list.append(status)
    
    eval_result = eval_prediction(state_text, next_state_text, predicted_next_state_text)
    
    eval_results.append(eval_result)
    
    if eval_result["all_correct"]:
        corrects += 1
    total += 1
    
    pbar.set_description(f"Correct: {corrects}/{total}")
    

Correct: 8/10:   2%|▎         | 10/400 [00:49<32:08,  4.95s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 10/14:   4%|▎         | 14/400 [01:09<32:33,  5.06s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 17/22:   6%|▌         | 22/400 [01:38<20:01,  3.18s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 29/37:   9%|▉         | 37/400 [02:06<13:42,  2.27s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 35/43:  10%|█         | 42/400 [02:25<20:06,  3.37s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 38/46:  12%|█▏        | 46/400 [02:39<21:39,  3.67s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 46/54:  14%|█▎        | 54/400 [03:13<25:07,  4.36s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl
Error: No match found with fallback regex, using full response as next state


Correct: 49/57:  14%|█▍        | 57/400 [03:28<27:17,  4.77s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 56/64:  16%|█▌        | 64/400 [03:58<21:00,  3.75s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 57/65:  16%|█▋        | 65/400 [04:03<21:53,  3.92s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 61/70:  18%|█▊        | 70/400 [04:24<23:58,  4.36s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 66/76:  18%|█▊        | 73/400 [04:40<26:32,  4.87s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 68/78:  20%|█▉        | 78/400 [04:49<16:03,  2.99s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 69/79:  20%|█▉        | 79/400 [04:54<17:32,  3.28s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 79/92:  23%|██▎       | 91/400 [05:35<14:47,  2.87s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 86/99:  23%|██▎       | 93/400 [05:40<13:33,  2.65s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl
Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state


Correct: 95/110:  28%|██▊       | 110/400 [06:30<13:52,  2.87s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 97/113:  28%|██▊       | 113/400 [06:45<18:36,  3.89s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 103/119:  30%|██▉       | 119/400 [07:12<20:22,  4.35s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 108/124:  31%|███       | 124/400 [07:32<17:19,  3.77s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 111/128:  32%|███▏      | 128/400 [07:52<21:15,  4.69s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 117/134:  34%|███▎      | 134/400 [08:20<20:50,  4.70s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 127/145:  36%|███▌      | 144/400 [09:06<19:18,  4.53s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 136/154:  38%|███▊      | 153/400 [09:39<17:29,  4.25s/it]

Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state


Correct: 137/155:  39%|███▉      | 155/400 [09:44<14:06,  3.46s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 139/157:  39%|███▉      | 157/400 [09:54<17:15,  4.26s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 152/172:  42%|████▏     | 167/400 [10:14<08:38,  2.23s/it]

Error: No match found with fallback regex, using full response as next state
Error: No match found with fallback regex, using full response as next state


Correct: 155/175:  43%|████▎     | 173/400 [10:19<05:31,  1.46s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 158/178:  44%|████▍     | 178/400 [10:28<06:10,  1.67s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 162/182:  46%|████▌     | 182/400 [10:41<08:40,  2.39s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 163/183:  46%|████▌     | 183/400 [10:46<10:03,  2.78s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 173/193:  48%|████▊     | 193/400 [11:27<12:33,  3.64s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 183/205:  51%|█████▏    | 205/400 [12:13<12:23,  3.81s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 192/215:  54%|█████▍    | 215/400 [13:00<14:02,  4.55s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl
Error: No match found with fallback regex, using full response as next state


Correct: 201/225:  56%|█████▋    | 225/400 [13:48<13:30,  4.63s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 202/226:  56%|█████▋    | 226/400 [13:52<13:29,  4.65s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 208/233:  58%|█████▊    | 233/400 [14:25<13:04,  4.70s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 210/235:  59%|█████▉    | 235/400 [14:34<12:50,  4.67s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 215/240:  60%|██████    | 240/400 [14:59<12:53,  4.84s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 216/242:  60%|██████    | 242/400 [15:09<12:48,  4.87s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 218/244:  61%|██████    | 244/400 [15:18<11:59,  4.61s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 219/245:  61%|██████▏   | 245/400 [15:24<13:01,  5.04s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 228/255:  64%|██████▍   | 255/400 [16:12<12:26,  5.15s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 230/258:  64%|██████▍   | 258/400 [16:26<11:30,  4.86s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 234/265:  66%|██████▋   | 265/400 [17:02<11:18,  5.03s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 243/275:  69%|██████▉   | 275/400 [17:50<10:01,  4.81s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 246/278:  70%|██████▉   | 278/400 [18:08<11:15,  5.54s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 251/283:  71%|███████   | 283/400 [18:32<09:17,  4.77s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 252/285:  71%|███████▏  | 285/400 [18:41<09:06,  4.75s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl
Error: No match found with fallback regex, using full response as next state


Correct: 261/295:  74%|███████▍  | 295/400 [19:31<09:04,  5.19s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 267/305:  76%|███████▋  | 305/400 [20:18<07:29,  4.74s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 276/315:  79%|███████▉  | 315/400 [21:03<06:12,  4.39s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 283/325:  81%|████████▏ | 325/400 [21:51<06:04,  4.86s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 290/335:  84%|████████▍ | 335/400 [22:38<05:00,  4.62s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 294/339:  85%|████████▍ | 339/400 [22:56<04:32,  4.47s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 300/345:  86%|████████▋ | 345/400 [23:22<04:04,  4.44s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 305/351:  88%|████████▊ | 351/400 [23:51<04:01,  4.93s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 308/355:  89%|████████▉ | 355/400 [24:09<03:23,  4.52s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 317/365:  91%|█████████▏| 365/400 [24:54<02:36,  4.46s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 320/368:  92%|█████████▏| 368/400 [25:10<02:38,  4.95s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 326/375:  94%|█████████▍| 375/400 [25:46<02:06,  5.06s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 328/377:  94%|█████████▍| 377/400 [25:54<01:45,  4.57s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 335/385:  96%|█████████▋| 385/400 [26:31<01:08,  4.54s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 344/395:  99%|█████████▉| 395/400 [27:17<00:23,  4.70s/it]

Saving prompt buffer to elevator_transition_mistral_20241109_020603.pkl


Correct: 347/400: 100%|██████████| 400/400 [27:43<00:00,  4.16s/it]


In [17]:
accuracy = corrects / total
print("Accuracy:", accuracy)

Accuracy: 0.8675


### Explore the failure cases

In [17]:
import pickle

# states_text = [env.state_to_text(state) for state in states]
# zipped = list(zip(states_text, predicted_states, eval_results))

# # save the results

# with open("elevator_transition_results_mistral.pkl", "wb") as f:
#     pickle.dump(zipped, f)

with open("elevator_transition_results_mistral.pkl", "rb") as f:
    zipped = pickle.load(f)

In [18]:
parse_failures = [z for z in zipped if z[2]["has_parse_failure"]]

print(f"Parse failures: {len(parse_failures)} out of {len(zipped)}")

Parse failures: 46 out of 400


In [117]:
parse_fail = parse_failures[3]
state_text, predicted_next_state_text, eval_result = parse_fail
print(state_text)
print(predicted_next_state_text)
print(eval_result)

People waiting at floor 2: 2
People waiting at floor 3: 1
People waiting at floor 4: 1
People waiting at floor 5: 3
Elevator at floor 1.
There are 0 people in the elevator.
Elevator is moving up.
Elevator door is closed.


People waiting at floor 2: 3 (2 existing + 1 new)
People waiting at floor 3: 1
People waiting at floor 4: 2 (1 existing + 1 new)
People waiting at floor 5: 3
Elevator at floor 1.
There are 0 people in the elevator.
Elevator is moving up.
Elevator door is closed.
{'people_waiting': True, 'elevator_floor': True, 'people_in_elevator': True, 'elevator_direction': True, 'door_status': True, 'has_parse_failure': True, 'all_correct': True}


Rerun with OPENAI

In [36]:
# change .env to use openai and restart the kernel

transition_model = LLMTransitionModel(**params, prompt_buffer_prefix="elevator_transition", debug=False)
transition_model.llm_model

'gpt-4o-mini'

In [38]:
import tqdm

predicted_states = []
status_list = []
eval_results = []
corrects = 0
total = 0

pbar = tqdm.tqdm(trajectories_combined)

transition_model.debug = False

for trajectory in pbar:
    state, action, reward, next_state, done = trajectory
    state_text, action_text, next_state_text = env.state_to_text(state), env.action_to_text(action), env.state_to_text(next_state)
    
    while True:
        try:
            predicted_next_state_text, status = transition_model.get_next_state(state_text, action_text)
            break
        except:
            print("Retrying..")
            time.sleep(1)
        
    
    predicted_states.append(predicted_next_state_text)
    status_list.append(status)
    
    eval_result = eval_prediction(state_text, next_state_text, predicted_next_state_text)
    
    eval_results.append(eval_result)
    
    if eval_result["all_correct"]:
        corrects += 1
    total += 1
    
    pbar.set_description(f"Correct: {corrects}/{total}")
    
states_text = [env.state_to_text(state) for state in states]
zipped = list(zip(states_text, predicted_states, eval_results))

# save the results

with open("elevator_transition_results_gpt.pkl", "wb") as f:
    pickle.dump(zipped, f)

Correct: 16/16:   3%|▎         | 13/400 [00:47<20:59,  3.25s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 27/28:   7%|▋         | 28/400 [01:38<25:24,  4.10s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 38/42:  10%|█         | 41/400 [02:22<22:02,  3.68s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 48/53:  13%|█▎        | 53/400 [03:07<25:55,  4.48s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 57/65:  16%|█▋        | 65/400 [03:46<17:36,  3.15s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 65/75:  19%|█▉        | 75/400 [04:36<26:28,  4.89s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 75/86:  22%|██▏       | 86/400 [05:16<20:28,  3.91s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 91/104:  26%|██▌       | 104/400 [05:57<11:13,  2.28s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 103/117:  29%|██▉       | 117/400 [06:44<21:30,  4.56s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 109/127:  32%|███▏      | 127/400 [07:35<22:11,  4.88s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 125/143:  36%|███▌      | 143/400 [08:21<14:39,  3.42s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 146/169:  42%|████▏     | 168/400 [09:09<07:02,  1.82s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 160/183:  46%|████▌     | 183/400 [10:01<11:49,  3.27s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 169/193:  48%|████▊     | 193/400 [10:50<15:55,  4.62s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 183/208:  52%|█████▏    | 208/400 [11:40<14:59,  4.69s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 189/218:  55%|█████▍    | 218/400 [12:30<16:18,  5.38s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 194/225:  56%|█████▋    | 225/400 [13:02<12:54,  4.43s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 197/228:  57%|█████▋    | 228/400 [13:26<19:30,  6.81s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 206/238:  60%|█████▉    | 238/400 [14:27<19:11,  7.11s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 214/248:  62%|██████▏   | 248/400 [15:19<12:04,  4.76s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 221/258:  64%|██████▍   | 258/400 [16:10<11:42,  4.95s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 229/270:  68%|██████▊   | 270/400 [16:58<08:26,  3.89s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 234/276:  69%|██████▉   | 276/400 [17:29<10:24,  5.04s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 237/280:  70%|███████   | 280/400 [17:49<09:42,  4.85s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 246/290:  72%|███████▎  | 290/400 [18:37<08:47,  4.79s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 256/300:  75%|███████▌  | 300/400 [19:23<07:32,  4.53s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 258/303:  76%|███████▌  | 303/400 [19:38<07:44,  4.79s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 263/310:  78%|███████▊  | 310/400 [20:12<07:29,  5.00s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 273/325:  81%|████████▏ | 325/400 [20:59<04:46,  3.82s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 282/334:  84%|████████▎ | 334/400 [21:49<06:55,  6.29s/it]

Error: No match found with fallback regex, using full response as next state


Correct: 282/335:  84%|████████▍ | 335/400 [21:54<06:29,  5.99s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 290/345:  86%|████████▋ | 345/400 [22:50<05:33,  6.06s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 296/355:  89%|████████▉ | 355/400 [23:50<04:11,  5.58s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 304/365:  91%|█████████▏| 365/400 [24:40<02:52,  4.94s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 313/378:  94%|█████████▍| 378/400 [25:33<01:48,  4.93s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 321/390:  98%|█████████▊| 390/400 [26:31<00:49,  4.99s/it]

Saving prompt buffer to elevator_transition_20241108_035325.pkl


Correct: 330/400: 100%|██████████| 400/400 [27:19<00:00,  4.10s/it]


In [39]:
accuracy = corrects / total
accuracy

0.825