In [5]:
import json
import shlex
from glob import glob

import pandas as pd

def load_test_data(path):
    """Load one retrieval-output JSON file and index its samples by id.

    The file's top-level object is expected to hold three keys:
    ``"data"`` (a mapping whose values are iterables of sample dicts, each
    with an ``"id"``), ``"pseudo_subgraphs_args"`` and
    ``"retrieve_subgraphs_args"``.

    Args:
        path: Location of the JSON file to read.

    Returns:
        A dict with the keys ``"gen_subgraph_args"``,
        ``"retrieved_subgraph_agrs"`` (spelling kept for callers) and
        ``"data"``; the latter maps each sample id to its sample dict. When
        an id occurs more than once, the last occurrence wins.

    Raises:
        KeyError: If one of the expected top-level keys, or a sample's
            ``"id"``, is missing.
    """
    with open(path, "r") as handle:
        payload = json.load(handle)

    # Flatten the per-group sample lists into one id -> sample lookup.
    samples_by_id = {
        record["id"]: record
        for group in payload["data"].values()
        for record in group
    }

    result = {}
    result["gen_subgraph_args"] = payload["pseudo_subgraphs_args"]
    result["retrieved_subgraph_agrs"] = payload["retrieve_subgraphs_args"]
    result["data"] = samples_by_id
    return result


def check_valid_path(path, required_samples=490):
    """Tell whether every sample group in a retrieval file is complete enough.

    For each value of the file's top-level ``"data"`` mapping, counts the
    samples that contain a ``"retrieved_triplets"`` key. The file is valid
    when no group has fewer than ``required_samples`` such samples.

    Args:
        path: Location of the JSON file to inspect.
        required_samples: Minimum number of samples with retrieved triplets
            that each group must have.

    Returns:
        ``True`` if every group meets the threshold, ``False`` otherwise.
        In the ``False`` case the path and the per-group counts of *all*
        groups (not only the failing ones) are printed.
    """
    with open(path, "r") as handle:
        groups = json.load(handle)["data"]

    # One count per group, in the file's group order.
    counts = [
        len([entry for entry in group if "retrieved_triplets" in entry])
        for group in groups.values()
    ]

    if any(count < required_samples for count in counts):
        print(f"Invalid path: {path}")
        print(f"Invalid samples: {counts}")
        return False
    return True

def _first_match(candidates, text):
    """Return the first entry of ``candidates`` that occurs in ``text``, else ``None``."""
    return next((candidate for candidate in candidates if candidate in text), None)


def get_pd_report(retrieved_paths):
    """Summarise the run configuration of each retrieval-output file.

    For every path the file is loaded with ``load_test_data``, the retrieval
    parameters are read from its stored arguments, the LLM size / training-set
    size / beam count are recognised from known tokens in the stored paths,
    and the file is checked with ``check_valid_path``.

    Args:
        retrieved_paths: Iterable of paths to retrieval-output JSON files.

    Returns:
        A ``pandas.DataFrame`` with one row per successfully processed file
        and the columns ``llm``, ``training``, ``beams``, ``constraint``,
        ``unknown_relations``, ``unknown_each_connected_node``,
        ``complete_relations``, ``scoring_method``, ``valid`` and ``path``.
        The columns are present even when no file could be processed, so
        column-based filtering on the result never fails with ``KeyError``.

    Notes:
        A file that raises any ``Exception`` while being processed is
        reported with ``print`` and skipped; it does not abort the report.
    """
    # Declared explicitly so that an empty report still exposes every column.
    columns = [
        "llm",
        "training",
        "beams",
        "constraint",
        "unknown_relations",
        "unknown_each_connected_node",
        "complete_relations",
        "scoring_method",
        "valid",
        "path",
    ]

    pd_data = []
    for path in retrieved_paths:
        try:
            temp = load_test_data(path)
            retrieved_subgraph_agrs = temp["retrieved_subgraph_agrs"]
            gen_subgraph_args = temp["gen_subgraph_args"]

            k_unknown_relations = retrieved_subgraph_agrs["algorithm_top_k_unknown_relations"]
            k_unknown_each_connected_node = retrieved_subgraph_agrs[
                "algorithm_top_k_unknown_each_connected_node"
            ]
            k_complete_relations = retrieved_subgraph_agrs["algorithm_top_k_complete_relations"]
            scoring_method = retrieved_subgraph_agrs["scoring_method"]

            # Token lists below are ordered so that a longer token is tried
            # before a token it contains ("5000" before "500", "beams_10"
            # before "beams_1"); the first hit wins.
            model_path = gen_subgraph_args["specialized_model_path"]

            llm_tag = _first_match(
                ["llm_8b", "llm_7b", "llm_3b", "llm_1.5b", "llm_1b"], model_path
            )
            llm_size = None if llm_tag is None else llm_tag.split("_")[1]

            samples_tag = _first_match(["10000", "5000", "2000", "500", "100"], model_path)
            num_samples = None if samples_tag is None else int(samples_tag)

            input_file_path = retrieved_subgraph_agrs["input_file_path"]

            beams_tag = _first_match(
                ["beams_10", "beams_5", "beams_3", "beams_1"], input_file_path
            )
            num_beams = None if beams_tag is None else int(beams_tag.split("_")[1])

            # Constrained decoding is assumed unless the input path says otherwise.
            constraint = "without_constraint" not in input_file_path

            valid = check_valid_path(path)
            pd_data.append(
                {
                    "llm": llm_size,
                    "training": num_samples,
                    "beams": num_beams,
                    "constraint": constraint,
                    "unknown_relations": k_unknown_relations,
                    "unknown_each_connected_node": k_unknown_each_connected_node,
                    "complete_relations": k_complete_relations,
                    "scoring_method": scoring_method,
                    "valid": valid,
                    "path": path,
                }
            )
        except Exception as e:
            # Best effort: report the offending file and keep going.
            print(f"Error in {path}: {e}")
    return pd.DataFrame(pd_data, columns=columns)

In [8]:
# Gather every retrieval-output JSON file below the dev output folder
# (searched recursively) and tabulate its configuration.
retrieved_paths = glob(
    "/home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/**/*.json",
    recursive=True,
)
pd_data = get_pd_report(retrieved_paths)

# --- Setup to select -------------------------------------------------------
# Generation side.
llm = "3b"
training = 5000
beams = 5
constraint = True
# Retrieval side.
unknown_relations = 3
unknown_each_connected_node = 3
complete_relations = 1
scoring_method = "embedding"
# Only keep files that passed the completeness check.
valid = True

# The leading `True` lets any single criterion be commented out without
# breaking the expression; the `training` criterion is currently disabled.
setup_mask = (
    True
    & pd_data["llm"].eq(llm)
    # & pd_data["training"].eq(training)
    & pd_data["beams"].eq(beams)
    & pd_data["constraint"].eq(constraint)
    & pd_data["unknown_relations"].eq(unknown_relations)
    & pd_data["unknown_each_connected_node"].eq(unknown_each_connected_node)
    & pd_data["complete_relations"].eq(complete_relations)
    & pd_data["scoring_method"].eq(scoring_method)
    & pd_data["valid"].eq(valid)
)
filtered_setup_paths = pd_data.loc[setup_mask, "path"].tolist()

# Blank line separates this summary from anything printed while building the report.
print()
print(f"Found {len(filtered_setup_paths)} paths")
for path in filtered_setup_paths:
    print(path)

Error in /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_3b_base_5000_checkpoint-157_num_beams_3_retrieved.json: 'pseudo_subgraphs_args'
Error in /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_3b_base_2000_checkpoint-125_num_beams_3_retrieved.json: 'pseudo_subgraphs_args'
Error in /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_3b_base_5000_checkpoint-157_num_beams_5_retrieved.json: 'pseudo_subgraphs_args'
Error in /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_3b_base_2000_checkpoint-125_num_beams_5_retrieved.json: 'pseudo_subgraphs_args'
Error in /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_3b_base_2000_checkpoint-125_num_beams_1_retrieved.json: 'pseudo_subgraphs_args'


Error in /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_3b_base_5000_checkpoint-157_num_beams_1_retrieved.json: 'pseudo_subgraphs_args'
Invalid path: /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_8b_base_5000_checkpoint-313/specialized_llm_8b_base_5000_checkpoint-313_num_beams_5_retrieved_-4254586098498722734.json
Invalid samples: [500, 500, 500, 500, 160]
Invalid path: /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_8b_base_5000_checkpoint-313/specialized_llm_8b_base_5000_checkpoint-313_num_beams_5_retrieved_-532697754939918850.json
Invalid samples: [500, 500, 500, 500, 160]
Invalid path: /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs_dev/specialized_llm_1b_base_5000_checkpoint-157/specialized_llm_1b_base_5000_checkpoint-157_num_beams_10_retrieved_2382825554631171604.json
Invalid samples: [500, 500, 500, 500, 425]

Found 5 paths
/home/namb

In [8]:
# Shell snippet run once per selected retrieval file; `{{input_file_path}}`
# is substituted below. Because this is a regular (non-raw) string, every
# backslash-newline is a Python line continuation, so the `python ...`
# command is emitted on a single line in the generated file.
SCRIPT = """
export PYTHONPATH="$(pwd)":$PYTHONPATH
python workflow/pipeline/llm_reasoning.py \
    --input-file-path {{input_file_path}} \
    --output-folder /home/namb/hoangpv4/kg_fact_checking/data/output/reasoning_results_dev_different_retrieval_params \
    --num-workers 100 \
    --vllm-server-host http://localhost:8264 \
    --model-name Llama-3.3-70B-Instruct \
""".strip()

# Shell-quote each path so that one containing spaces or shell metacharacters
# is still passed as a single argument; paths made only of safe characters
# are inserted unchanged.
script_list = [
    SCRIPT.replace("{{input_file_path}}", shlex.quote(path))
    for path in filtered_setup_paths
]

# One snippet per path, separated by a blank line.
with open(
    "/home/namb/hoangpv4/kg_fact_checking/scripts/reasoning/reasoning_dev_different_retrieval_params.sh",
    "w",
) as f:
    f.write("\n\n".join(script_list))

In [None]:
# path = """
# /home/namb/hoangpv4/kg_fact_checking/data/output/retrieved_subgraphs/specialized_llm_3b_base_5000_checkpoint-157/specialized_llm_3b_base_5000_checkpoint-157_num_beams_5_retrieved_1631798178181607376.json
# """.strip()
# with open(path, "r") as f:
#     temp = json.load(f)
# temp.keys()
# VERIFY_PROMPT = """
# ### Task:
# Verify whether the fact in the given sentence is true or false based on the provided graph triplets. Use only the information in the triplets for verification.

# - The triplets provided represent all relevant knowledge that can be retrieved.
# - If the fact is a negation and the triplets do not include the fact, consider the fact as true.
# - Ignore questions and verify only the factual assertion within them. For example, in the question 'When was Daniel Martínez (politician) a leader of Montevideo?', focus on verifying the assertion 'Daniel Martínez (politician) a leader of Montevideo'.
# - Interpret the "~" symbol in triplets as indicating a reverse relationship. For example:
#   - "A ~loves B" means "B loves A".
#   - "A ~south of B" means "B is north of A".
# - The unit is not important. (e.g. "98400" is also same as 98.4kg)

# ### Response Format:
# Provide your response in the following JSON format without any additional explanations:
# {
#     "rationale": "A concise explanation for your decision",
#     "verdict": "true/false as the JSON value"
# }

# ### Triplets:
# {{triplets}}

# ### Claim:
# {{claim}}
# """.strip()

# from src.utils.batch_utils import BatchUtils
# data_dict = {}
# for samples in data.values():
#     for sample in samples:
#         data_dict[sample["id"]] = sample
# data_batch = [
#     {
#         "id": key,
#         "messages": [
#             {
#                 "role": "system",
#                 "content": COT_PROMPT_LLAMA3_70B.replace("{{claim}}", sample["claim"]),
#             }
#         ],
#     }
#     for key, sample in data_dict.items()
# ]
# BatchUtils.prepare_jsonl_for_batch_completions(
#     messages_with_ids=data_batch,
#     file_name="batch_openai_gpt4o_mini_factkg_test_cot.jsonl",
#     output_folder="/home/namb/hoangpv4/kg_fact_checking/data/batch_openai",
#     model="gpt-4o-mini",
#     temperature=0.0,
#     top_p=0.5,
#     max_tokens=256,
# )