In [4]:
import json
import os
from os import listdir
from os.path import isfile, join
import sys
import re

In [51]:
def parse_entry(entry):
    """Decode one three-line task record from a generation log.

    Args:
        entry: A ``(task_line, valid_line, invalid_line)`` triple such as
            ``("Task 12", "Valid: 3", "Invalid: 4")``.

    Returns:
        ``(task_id, valid_count, invalid_count)``. The task id is the second
        space-separated word of the first line and is kept as a string. Each
        count is the second ``": "``-separated field of its line, converted
        with ``int``.
    """
    task_line, valid_line, invalid_line = entry

    def _count(line):
        # Second ": "-separated field as an int, e.g. "Valid: 3" -> 3.
        return int(line.split(": ")[1])

    task_id = task_line.split(" ")[1]
    return task_id, _count(valid_line), _count(invalid_line)

def parse_log(log_path):
    """Parse a generation log into its model name, generation count and task records.

    The file must start with a three-line header. The second and third header
    lines have the form ``"<label>: <model>"`` and ``"<label>: <int>"``. The
    header is followed by one three-line record per task, each decoded by
    :func:`parse_entry`. Blank lines after the header are ignored, so a
    trailing empty line does not break the grouping.

    Args:
        log_path: Path to the ``log.txt`` file.

    Returns:
        ``(model, num_gens, entries)``, where ``entries`` is a list of
        ``(task_id, valid_count, invalid_count)`` tuples in file order.

    Raises:
        ValueError: if the header is incomplete or the remaining non-blank
            lines do not form complete three-line records.
    """
    with open(log_path, "r") as f:
        lines = f.read().splitlines()

    if len(lines) < 3:
        raise ValueError(
            f"{log_path}: expected a 3-line header, got {len(lines)} line(s)"
        )
    header, body = lines[:3], lines[3:]

    model = header[1].split(": ")[1]
    num_gens = int(header[2].split(": ")[1])

    # Record lines are never empty, so dropping blank lines only removes
    # stray separators or trailing newlines that would misalign the triples.
    body = [line for line in body if line.strip()]
    # An explicit exception instead of `assert`, which is stripped under -O.
    if len(body) % 3 != 0:
        raise ValueError(
            f"{log_path}: task records must be 3 lines each, "
            f"got {len(body)} non-blank line(s) after the header"
        )

    # zip over one shared iterator yields consecutive, non-overlapping triples.
    records = zip(*[iter(body)] * 3)
    return model, num_gens, [parse_entry(record) for record in records]

def find_entry(entry, target_log):
    """Return the first record in *target_log* whose task id matches *entry*'s.

    Records are compared on their first element (the task id). Returns
    ``None`` when no record in *target_log* matches.
    """
    return next(
        (candidate for candidate in target_log if entry[0] == candidate[0]),
        None,
    )

def read_suggestions_files(dir, task_id):
    """Read the three raw suggestion files stored for *task_id* in *dir*.

    The files are ``<task_id>_valid.txt``, ``<task_id>_valid_programs.txt``
    and ``<task_id>_invalid.txt``. They are read in that order and their
    contents are returned unparsed.

    Returns:
        ``(valid, valid_programs, invalid)`` as strings.

    Raises:
        FileNotFoundError: if any of the three files is missing.
    """
    contents = []
    for suffix in ("valid", "valid_programs", "invalid"):
        with open(join(dir, f"{task_id}_{suffix}.txt"), "r") as handle:
            contents.append(handle.read())
    valid, valid_programs, invalid = contents
    return valid, valid_programs, invalid


def _load_task(source_dir, task_id):
    """Read one task's suggestion files from *source_dir* and decode the JSON ones.

    Returns ``(valid_programs, valid, invalid)``. ``valid_programs`` is the raw
    text; ``valid`` and ``invalid`` are the values decoded from JSON.
    """
    valid, valid_programs, invalid = read_suggestions_files(source_dir, task_id)
    return valid_programs, json.loads(valid), json.loads(invalid)


def _write_task_files(output_dir, task_id, valid_programs, valid, invalid):
    """Write one task's programs text and JSON-encoded suggestions into *output_dir*."""
    with open(join(output_dir, f"{task_id}_valid_programs.txt"), "w") as f:
        f.write(valid_programs)
    with open(join(output_dir, f"{task_id}_valid.txt"), "w") as f:
        f.write(json.dumps(valid, indent=4))
    with open(join(output_dir, f"{task_id}_invalid.txt"), "w") as f:
        f.write(json.dumps(invalid, indent=4))


def _clear_output_dir(output_dir):
    """Create *output_dir* if needed and delete the regular files directly inside it."""
    os.makedirs(output_dir, exist_ok=True)
    for name in listdir(output_dir):
        path = join(output_dir, name)
        # os.remove cannot delete directories; leave any subdirectories alone.
        if isfile(path):
            os.remove(path)


def merge(base_dir, target_dir, output_dir):
    """Merge two generation runs, ``base_dir`` and ``target_dir``, into ``output_dir``.

    Each run directory holds a ``log.txt`` (see :func:`parse_log`). For every
    task it also holds ``<task>_valid.txt`` and ``<task>_invalid.txt`` (JSON)
    plus ``<task>_valid_programs.txt`` (text).

    How tasks are merged:

    * Task in both runs: the suggestion lists are concatenated (base first),
      the program texts are joined with a blank line, and the counts are summed.
    * Task in only one run: its files and counts are copied unchanged. Base-only
      tasks come first in base-log order, then target-only tasks in target-log
      order.

    ``output_dir`` is created if missing, and the regular files directly
    inside it are deleted before writing.

    Raises:
        ValueError: if the runs used different models, if ``output_dir`` is
            one of the input directories, or if a merged task's suggestion
            lists disagree with the summed counts from the logs.
    """
    base_model, base_num_gens, base_log = parse_log(join(base_dir, "log.txt"))
    target_model, target_num_gens, target_log = parse_log(join(target_dir, "log.txt"))
    if base_model != target_model:
        raise ValueError(
            f"Model mismatch: base uses {base_model!r}, target uses {target_model!r}"
        )

    # output_dir is emptied below. If it were one of the inputs, that run's
    # files would be deleted before they are read.
    output_real = os.path.realpath(output_dir)
    if output_real in {os.path.realpath(base_dir), os.path.realpath(target_dir)}:
        raise ValueError(
            f"output_dir {output_dir!r} must differ from base_dir and target_dir"
        )

    # NOTE: the header records base + target generations. That is exact only
    # for tasks present in both runs; single-run tasks have fewer generations.
    total_num_gens = base_num_gens + target_num_gens

    print(f"""Merging
        {base_dir} 
    and 
        {target_dir}
    into 
        {output_dir}""")

    print(f"Models match: {base_model}")
    print(f"Base tasks: {len(base_log)}")
    print(f"Target tasks: {len(target_log)}")
    print(f"Base gens per task: {base_num_gens}")
    print(f"Target gens per task: {target_num_gens}")

    print("Merging suggestions...")
    _clear_output_dir(output_dir)

    # Index target records by task id, keeping the first one (same as find_entry).
    target_by_id = {}
    for target_entry in target_log:
        target_by_id.setdefault(target_entry[0], target_entry)

    log_lines = [
        "Timestamp: merged",
        f"Model: {base_model}",
        f"Number of responses: {total_num_gens}",
    ]
    updated, skipped, added = [], [], []

    for task_id, valid_count, invalid_count in base_log:
        programs, valid, invalid = _load_task(base_dir, task_id)

        target_entry = target_by_id.get(task_id)
        if target_entry is None:
            skipped.append(task_id)
        else:
            _, target_valid_count, target_invalid_count = target_entry
            target_programs, target_valid, target_invalid = _load_task(target_dir, task_id)
            programs = "\n\n".join([programs, target_programs])
            valid = valid + target_valid
            invalid = invalid + target_invalid
            valid_count += target_valid_count
            invalid_count += target_invalid_count
            # Cross-check the suggestion files against the log counts.
            if len(valid) != valid_count or len(invalid) != invalid_count:
                raise ValueError(
                    f"Task {task_id}: merged suggestions ({len(valid)} valid, "
                    f"{len(invalid)} invalid) do not match log counts "
                    f"({valid_count} valid, {invalid_count} invalid)"
                )
            updated.append(task_id)

        _write_task_files(output_dir, task_id, programs, valid, invalid)
        log_lines += [f"Task {task_id}", f"Valid: {valid_count}", f"Invalid: {invalid_count}"]

    # Tasks that only the target run produced are copied as-is; previously
    # they were silently dropped from the merge.
    base_ids = {task_id for task_id, _, _ in base_log}
    for task_id, valid_count, invalid_count in target_by_id.values():
        if task_id in base_ids:
            continue
        programs, valid, invalid = _load_task(target_dir, task_id)
        _write_task_files(output_dir, task_id, programs, valid, invalid)
        log_lines += [f"Task {task_id}", f"Valid: {valid_count}", f"Invalid: {invalid_count}"]
        added.append(task_id)

    print(f"Updated tasks: {len(updated)}")
    print(f"Skipped tasks: {len(skipped)}")
    print(f"Added tasks: {len(added)}")

    with open(join(output_dir, "log.txt"), "w") as f:
        f.write("\n".join(log_lines) + "\n")

    print("Done")
    

# Earlier merges, kept for reference (each pair was merged in turn):
# base_dir = "models/logs/gens_20240514T165229_gpt4o/"
# target_dir = "models/logs/gens_20240514T170401_gpt4o/"
#
# base_dir = "models/logs/gens_20240515T230905_gpt4o/"
# target_dir = "models/logs/gens_20240516T143724_gpt4o/"
#
# base_dir = "models/logs/gens_20240516T190000_gpt4o_merged"
# target_dir = "models/logs/gens_20240516T191518_gpt4o"
#
# base_dir = "models/logs/gens_20240516T202600_gpt4o_merged"
# target_dir = "models/logs/gens_20240516T202929_gpt4o"

if __name__ == "__main__":
    # Current merge: fold the latest run into the accumulated merged run.
    base_dir = "models/logs/gens_20240516T211900_gpt4o_merged"
    target_dir = "models/logs/gens_20240516T212117"
    output_dir = "models/logs/gens_merged/"

    merge(base_dir, target_dir, output_dir)

Merging
        models/logs/gens_20240516T211900_gpt4o_merged 
    and 
        models/logs/gens_20240516T212117
    into 
        models/logs/gens_merged/
Models match: gpt-4o
Base tasks: 160
Target tasks: 2
Base gens per task: 190
Target gens per task: 15
Merging suggestions...
Updated tasks: 2
Skipped tasks: 158
Done
