In [64]:
import lark
from lark import Lark, Transformer, v_args, Visitor
from collections import Counter
import json

In [81]:
sexpr_grammar_path = "dsl/sexpr.lark"
# Read the grammar inside a context manager so the file handle is closed
# deterministically instead of being left for the garbage collector.
with open(sexpr_grammar_path) as grammar_file:
    sexpr_grammar = grammar_file.read()
# Parser for the S-expression grammar; parsing starts at the 'sexpr' rule.
sexpr_parser = Lark(sexpr_grammar, start='sexpr')

class SExprStats(Visitor):
    """Visitor that tallies how many times each atom appears in an S-expression.

    After ``visit(tree)`` has run, ``counts`` maps the ``value`` of each
    ``atom`` node's first child to the number of ``atom`` nodes carrying it.
    """

    def __init__(self) -> None:
        super().__init__()
        # atom text -> number of occurrences seen so far
        self.counts = Counter()

    def atom(self, atom):
        """Record one more occurrence of the text held by this ``atom`` node."""
        token = atom.children[0]
        # Feeding a one-element iterable to update() adds exactly 1 to its key.
        self.counts.update((token.value,))

In [82]:
# Small nested S-expression used to sanity-check the parser and the visitor.
sexpr_example = """
(
    (a 1)
    (
        (b 2)
        (a 3)
    )
)
"""

# Parse the sample and show the shape of the resulting tree.
sexpr_tree = sexpr_parser.parse(sexpr_example)
tree_dump = sexpr_tree.pretty()
print(tree_dump)

# Walk the tree once and display the per-atom tallies.
stats = SExprStats()
stats.visit(sexpr_tree)
atom_counts = stats.counts
print(atom_counts)

sexpr
  sexpr
    sexpr
      atom	a
    sexpr
      atom	1
  sexpr
    sexpr
      sexpr
        atom	b
      sexpr
        atom	2
    sexpr
      sexpr
        atom	a
      sexpr
        atom	3

Counter({'a': 2, 'b': 1, '2': 1, '3': 1, '1': 1})


In [106]:

def parse_entry(entry):
    """Turn one three-line log record into ``(task_id, valid, invalid)``.

    ``entry`` must unpack into exactly three strings:

    * the task line  -- the second space-separated field is the task id,
    * the valid line -- the field after the first ``": "`` is parsed as an int,
    * the invalid line -- likewise parsed as an int.
    """
    id_line, valid_line, invalid_line = entry

    def _second_field(text, separator):
        # Always take the piece right after the first separator occurrence.
        return text.split(separator)[1]

    return (
        _second_field(id_line, " "),
        int(_second_field(valid_line, ": ")),
        int(_second_field(invalid_line, ": ")),
    )

def parse_log(log_path):
    """Read a generation log and return ``(model, num_gens, entries)``.

    The first three lines of the file form a header: the text after ``": "``
    on the second line is returned as ``model`` and the text after ``": "`` on
    the third line is returned as the integer ``num_gens`` (the first header
    line is not used). Every following group of three lines is passed to
    ``parse_entry`` and the results are collected in ``entries``.

    Raises ``AssertionError`` when the number of lines after the header is not
    a multiple of three.
    """
    with open(log_path, "r") as handle:
        lines = handle.read().splitlines()

    header, body = lines[:3], lines[3:]

    model = header[1].split(": ")[1]
    num_gens = int(header[2].split(": ")[1])

    assert len(body) % 3 == 0

    # Drawing three times from one shared iterator yields consecutive,
    # non-overlapping triples of lines.
    line_iter = iter(body)
    entries = [
        parse_entry(triple)
        for triple in zip(line_iter, line_iter, line_iter)
    ]

    return model, num_gens, entries

def try_parse_sexpr(code):
    """Parse ``code`` with ``sexpr_parser``, retrying with one added parenthesis.

    Candidates are tried in order and the first one that parses wins:

    1. ``code`` unchanged        -> status ``"valid"``
    2. ``"(" + code``            -> status ``"paren_l"``
    3. ``code + ")"``            -> status ``"paren_r"``

    Returns ``(text, status, tree)`` where ``text`` is the candidate that
    parsed. If none parses, returns ``(code, "invalid", None)``. Any exception
    raised by the parser counts as a failed attempt.
    """
    # Each repair is built lazily, only once the previous attempt has failed.
    attempts = (
        ("valid", lambda src: src),
        ("paren_l", lambda src: "(" + src),
        ("paren_r", lambda src: src + ")"),
    )
    for status, build in attempts:
        candidate = build(code)
        try:
            tree = sexpr_parser.parse(candidate)
        except Exception:
            continue
        return candidate, status, tree
    return code, "invalid", None

# Directory holding one generation run: log.txt plus, for every task,
# <task_id>_valid.txt and <task_id>_invalid.txt.
# gens_path = "dsl/v0_3/generations/gens_test"
gens_path = "dsl/v0_3/generations/arga_gpt4o_m100_20240516"

# parse_log opens and reads the log itself, so no separate read is needed.
# `log` becomes a list of (task_id, n_valid, n_invalid) tuples.
model, num_gens, log = parse_log(gens_path + "/log.txt")
print(model, num_gens)

# task_id -> {"per_solution_counts": [...], "total_counts": Counter}
all_tasks_count = {}

print(len(log))
for task_id, n_valid, n_invalid in log:
    valid_file = gens_path + f"/{task_id}_valid.txt"
    invalid_file = gens_path + f"/{task_id}_invalid.txt"

    with open(valid_file, "r") as f:
        valid = json.load(f)
    with open(invalid_file, "r") as f:
        invalid = json.load(f)

    # Entries in the valid file are JSON strings with a "code" field; entries
    # in the invalid file are objects whose "response" is such a JSON string.
    # Named task_codes (not `all`) so the builtin all() is not shadowed.
    task_codes = [json.loads(item)["code"] for item in valid]
    task_codes += [json.loads(item["response"])["code"] for item in invalid]

    print(f"Task {task_id}:")
    # The log's per-task totals must match what the two files contain.
    assert n_valid + n_invalid == len(task_codes)
    print(f"{len(valid)}/{len(task_codes)} valid")

    global_counter = Counter()  # atom counts summed over this task's solutions
    task_stats = {
        "per_solution_counts": []
    }
    parsed = 0   # snippets that parsed, either as-is or after a one-paren repair
    skipped = 0  # snippets that still failed to parse after the repair attempts
    for code in task_codes:
        code, parse_status, code_tree = try_parse_sexpr(code)
        if parse_status == "invalid":
            skipped += 1
            continue
        parsed += 1

        stats = SExprStats()
        stats.visit(code_tree)
        counts = stats.counts
        global_counter += counts
        task_stats["per_solution_counts"].append({
            "code": code,
            "parse_status": parse_status,
            "counts": counts
        })
    task_stats["total_counts"] = global_counter

    # NOTE: "repaired" here counts every snippet that parsed, including those
    # that needed no repair at all.
    print(f"{parsed}/{len(task_codes)} repaired")
    print(f"{skipped}/{len(task_codes)} skipped")

    all_tasks_count[task_id] = task_stats

# Save the aggregated counts next to the generations they were computed from.
counts_path = gens_path + "/counts.json"
with open(counts_path, "w") as f:
    json.dump(all_tasks_count, f, indent=4)
print(f"Counts saved to {counts_path}")
    

gpt-4o 205
160
Task 00d62c1b:
106/150 valid
150/150 repaired
0/150 skipped
Task 025d127b:
139/150 valid
149/150 repaired
1/150 skipped
Task 05f2a901:
141/150 valid
149/150 repaired
1/150 skipped
Task 08ed6ac7:
111/175 valid
171/175 repaired
4/175 skipped
Task 0962bcdd:
113/150 valid
150/150 repaired
0/150 skipped
Task 0ca9ddb6:
101/175 valid
173/175 repaired
2/175 skipped
Task 0d3d703e:
140/150 valid
148/150 repaired
2/150 skipped
Task 0e206a2e:
115/150 valid
150/150 repaired
0/150 skipped
Task 150deff5:
103/150 valid
146/150 repaired
4/150 skipped
Task 1a07d186:
115/150 valid
147/150 repaired
3/150 skipped
Task 1b60fb0c:
113/150 valid
148/150 repaired
2/150 skipped
Task 1caeab9d:
113/150 valid
148/150 repaired
2/150 skipped
Task 1e0a9b12:
141/150 valid
150/150 repaired
0/150 skipped
Task 1f0c79e5:
134/150 valid
150/150 repaired
0/150 skipped
Task 2204b7a8:
106/150 valid
149/150 repaired
1/150 skipped
Task 22168020:
142/150 valid
150/150 repaired
0/150 skipped
Task 22233c11:
105/150 va