In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from extract import LOCM2
from utils import read_plan, read_json_file
from evaluator import ExecutabilityEvaluator
from collections import defaultdict
import json

In [2]:
def test_cross_exe(domain_name, train_traces , prefix_traces, gt_traces=[]):
    debug = {}
    try:
        polocm2 = LOCM2(state_param=True, viz=False, debug=debug)
        model = polocm2.extract_model(train_traces)
        learned_domain = model.to_pddl_domain(domain_name)
        gt_filename = os.path.join('../../data', 'goose-benchmarks', 'tasks', domain_name, 'domain.pddl')
        
        evaluator = ExecutabilityEvaluator(learned_domain,gt_filename, debug=False)

        exe = evaluator.get_cross_executabilities(prefix_traces, gt_traces)
        return exe
    except Exception as e:
        print(f"Error processing domain {domain_name}: {e}")
        return 0,0
    
def test_balanced_exe(domain_name, train_traces, test_traces, test_invalid_traces):
    debug = {}
    try:
        polocm2 = LOCM2(state_param=True, viz=False, debug=debug)
        model = polocm2.extract_model(train_traces)
        learned_domain = model.to_pddl_domain(domain_name)
        gt_filename = os.path.join('../../data', 'goose-benchmarks', 'tasks', domain_name, 'domain.pddl')
        
        evaluator = ExecutabilityEvaluator(learned_domain, gt_filename, debug=False)
        exe = evaluator.get_balanced_executability(test_traces, test_invalid_traces)

        return exe
    except Exception as e:
        print(f"Error processing domain {domain_name}: {e}")
        return 0,0

In [3]:
# Load the plain (ground-truth) traces, grouped by domain.
# Each line is "&&"-separated: the first field is the domain name and the
# last field is the plan text handed to read_plan.
plain_traces = defaultdict(list)
with open("../../data/plain_traces/plain_traces.txt", "r") as f:
    for line in f:
        # A blank line would otherwise register a bogus "\n" domain.
        if not line.strip():
            continue
        details = line.split("&&")

        domain_name = details[0]
        plan = details[-1]

        plain_traces[domain_name].append(read_plan(plan))

In [4]:
# Group the training entries from traces_plan_r1.json by their "domain" field;
# each entry is kept whole (later cells read its "traces", "id", ... keys).
train_traces = defaultdict(list)
training_entries = read_json_file("../../data/training_data/traces_plan_r1.json")
for entry in training_entries:
    train_traces[entry["domain"]].append(entry)
    


In [5]:
# Load random-walk traces, grouped by domain and split by label.
# Each line is "&&"-separated: field 0 is the domain, field 1 a label, and the
# last field the plan text. Label "rand" goes to valid_traces; any other label
# goes to invalid_traces.
valid_traces = defaultdict(list)
invalid_traces = defaultdict(list)

with open("../../data/plain_traces/random_walks.txt", "r") as f:
    for line in f:
        # A blank line splits into a single field, so details[1] would raise IndexError.
        if not line.strip():
            continue
        details = line.split("&&")
        domain_name = details[0]
        is_valid = details[1] == "rand"
        plan = details[-1]
        target = valid_traces if is_valid else invalid_traces
        target[domain_name].append(read_plan(plan))

In [None]:
# Cross executability: learn one model per training item and score it against
# the domain's plain traces (used as both prefix and ground-truth traces).
with open("./cross_exe_results.csv", "w") as results_file:
    results_file.write("ID, Domain, len, objs, l_exe, gt_exe\n")
    for domain, items in train_traces.items():
        print(f"Testing domain: {domain}")
        for item in items:
            traces = item["traces"]

            try:
                cross_exe = test_cross_exe(domain, traces, plain_traces[domain], plain_traces[domain])
            except Exception as e:
                # Safety net for anything escaping test_cross_exe's own handler.
                print(f"Error in cross executability for domain {domain}: {e}")
                cross_exe = 0, 0

            # One CSV row per training item; the two scores fill l_exe / gt_exe.
            results_file.write(f"{item['id']}, {domain}, {item['total_length']}, {item['number_of_objects']}, {cross_exe[0]}, {cross_exe[1]}\n")
            # Flush per row so partial results survive an interrupted run.
            results_file.flush()
            print(f"Cross Executability for {domain}: {cross_exe}")

Testing domain: blocksworld
Cross Executability for blocksworld: (0.49481446543722457, 0.030303030303030304)
Cross Executability for blocksworld: (0.54726748531232, 0.06060606060606061)
Cross Executability for blocksworld: (0.5893494360618644, 0.9949494949494949)
Cross Executability for blocksworld: (0.3736530849672458, 0.8580808080808081)
Cross Executability for blocksworld: (0.5789783669363135, 0.6851050184383518)
Cross Executability for blocksworld: (0.5789783669363135, 0.6809303350970018)
Cross Executability for blocksworld: (0.5051855345627754, 0.9696969696969697)
Cross Executability for blocksworld: (0.5893494360618644, 0.9941919191919192)
Cross Executability for blocksworld: (0.5893494360618644, 0.9904040404040404)
Cross Executability for blocksworld: (0.3736530849672458, 0.8442760942760943)
Cross Executability for blocksworld: (0.5789783669363135, 0.6882435465768799)
Cross Executability for blocksworld: (0.8188258265147725, 0.9026747824693893)
Cross Executability for blocksworl

In [None]:

# Balanced executability: learn one model per training item and score it on the
# domain's valid and invalid random-walk traces; one CSV row per item.
with open("./balanced_exe_results.csv", "w") as csv_out:
    csv_out.write("ID, Domain, len, objs, valid_exe, invalid_exe\n")
    for domain, entries in train_traces.items():
        print(f"Testing domain: {domain}")
        for entry in entries:
            try:
                scores = test_balanced_exe(
                    domain,
                    entry["traces"],
                    valid_traces[domain],
                    invalid_traces[domain],
                )
            except Exception as e:
                print(f"Error in balanced executability for domain {domain}: {e}")
                scores = 0,0

            # Record the (valid_exe, invalid_exe) pair for this training item.
            row = f"{entry['id']}, {domain}, {entry['total_length']}, {entry['number_of_objects']}, {scores[0]}, {scores[1]}\n"
            csv_out.write(row)
            # Flush per row so partial results survive an interrupted run.
            csv_out.flush()
            print(f"Balanced Executability for {domain}: {scores}")

Testing domain: blocksworld
Balanced Executability for blocksworld: (0.15850073419758487, 0.17101068727307364)
Balanced Executability for blocksworld: (0.2333423618833991, 0.24442387790960024)
Balanced Executability for blocksworld: (0.28044776143401734, 0.2989688392798139)
Balanced Executability for blocksworld: (0.1953716950739421, 0.21169369584782874)
Balanced Executability for blocksworld: (0.2445116775437157, 0.2645264406240145)
Balanced Executability for blocksworld: (0.2445116775437157, 0.2645264406240145)
Balanced Executability for blocksworld: (0.19686526926633768, 0.20924938222516934)
Balanced Executability for blocksworld: (0.28044776143401734, 0.2989688392798139)
Balanced Executability for blocksworld: (0.28044776143401734, 0.2989688392798139)
Balanced Executability for blocksworld: (0.1953716950739421, 0.21169369584782874)
Balanced Executability for blocksworld: (0.28554283604900293, 0.3046678743600073)
Balanced Executability for blocksworld: (0.41375391405604867, 0.437332

In [None]:
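# Disabled legacy loop, kept for reference: ran test_cross_exe per line of
# plain_traces.txt and wrote one row per task to res.csv.
# with open("../../data/plain_traces/plain_traces.txt", "r") as f: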
#     with open("res.csv", "w") as out:
#         lines = f.readlines()
#         for line in lines:
#             try:
#                 details = line.split("&&")
#                 name = f"{details[0]}-{details[2]}-{details[3]}"
            
#                 plan = details[-1]
#                 trace = read_plan(plan)
           

#                 domain_name = details[0]
            
#                 exe_l, exe_gt = test_cross_exe(domain_name, trace, trace)
#             except Exception as e:
#                 print(f"Error processing task {name}: {e}")
#                 exe_l, exe_gt = 0, 0
#             print(f"Executability for {name}: {exe_l}, {exe_gt}")
#             out.write(f"{name},{exe_l}, {exe_gt}\n")
#             out.flush()