In [None]:
# For the TruthHypo dataset: Classify the relation between core biomedical entities in each hypothesis using GPT-4.1 and output the results as a CSV file.

import json
import autogen
import openai

# Model configuration entries for GPT-4.1, as produced by autogen's
# config_list_from_models helper.
config_list = autogen.config_list_from_models(model_list=["gpt-4.1"])

# Bundle of LLM settings (model name, cache seed, sampling temperature,
# endpoint list, timeout and output-token budget).
# NOTE(review): none of the code visible in this file reads `gpt_config`;
# classify_relation_openai calls the OpenAI client directly, so these values
# (e.g. temperature) do not affect it -- confirm whether this is still needed.
gpt_config = dict(
    chat_model="gpt-4.1",
    cache_seed=42,
    temperature=0.7,
    config_list=config_list,
    timeout=540000,
    max_output_tokens=1500,
)

def classify_relation_openai(entities, hypothesis, mode="chemical_or_gene", model="gpt-4.1"):
    """Ask an OpenAI chat model which relation a hypothesis asserts.

    Args:
        entities: Iterable of entity names; they are joined with "; " in the
            user prompt. Non-string items are converted with ``str``.
        hypothesis: The hypothesis sentence to classify.
        mode: ``"disease"`` selects the stimulate/inhibit prompt; any other
            value selects the positive_correlate/negative_correlate prompt.
        model: Name of the chat model passed to the OpenAI client.

    Returns:
        The lower-cased label. When the reply is, or mentions exactly one of,
        the labels valid for ``mode`` (for example a quoted ``'stimulate'`` or
        ``stimulate.``), that canonical label is returned. Otherwise the
        stripped, lower-cased reply is returned unchanged; this is an empty
        string when the API returns no text content.
    """
    if mode == "disease":
        valid_labels = ("stimulate", "inhibit")
        system_prompt = (
            "RelationClassifierAgent:\n"
            "Given a disease and a gene, and a biomedical hypothesis sentence, "
            "classify the potential relation between these two entities as either:\n"
            "'stimulate' or 'inhibit'.\n"
            "Your classification must be strictly based on the content of the provided hypothesis. "
            "Do not infer or assume any relationship not clearly stated or implied in the hypothesis sentence.\n"
            "Always choose one of these two options. Do not answer with anything else."
        )
    else:
        valid_labels = ("positive_correlate", "negative_correlate")
        system_prompt = (
            "RelationClassifierAgent:\n"
            "Given two core biological entities (such as genes, proteins, chemicals, etc) and a biomedical hypothesis sentence, "
            "classify the potential relation between these two entities as one of:\n"
            "'positive_correlate' or 'negative_correlate'.\n"
            "Your classification must be strictly based on the content of the provided hypothesis. "
            "Do not infer or assume any relationship not clearly stated or implied in the hypothesis sentence.\n"
            "Respond ONLY with one of these two words (no explanation, no extra text).\n"
            "Always choose one of these two. Do not answer with anything else."
        )

    entities_str = "; ".join(str(entity) for entity in entities)
    user_prompt = (
        f"Entities: {entities_str}\n"
        f"Hypothesis:\n{hypothesis}\n\n"
        "Relation type:"
    )
    response = openai.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_completion_tokens=10,
    )

    # The SDK types message.content as optional; treat a missing reply as
    # empty text instead of crashing on None.strip().
    raw_answer = response.choices[0].message.content or ""
    answer = raw_answer.strip().lower()
    if answer in valid_labels:
        return answer

    # The prompts show the labels inside single quotes, so the model may echo
    # quotes or add punctuation. Recover the label only when it is unambiguous.
    mentioned = [label for label in valid_labels if label in answer]
    if len(mentioned) == 1:
        return mentioned[0]
    return answer

def main(input_jsonl, output_csv, disease_range=(100, 200)):
    """Classify the last hypothesis of every JSONL record and write a CSV.

    Every non-blank line of ``input_jsonl`` is parsed as a JSON object. Its
    ``core_genes`` list and the last item of its ``hypotheses`` list are sent
    to ``classify_relation_openai``. One CSV row is written per record, so
    row N of the output always corresponds to record N of the input. A record
    with a missing or invalid ``core_genes`` / ``hypotheses`` field gets a row
    with an empty relation (and a warning on stderr) instead of being dropped,
    because dropping it would shift every later row against positional labels.

    Args:
        input_jsonl: Path of the JSONL input file.
        output_csv: Path of the CSV file to (over)write. The header is
            ``core_genes,relation``.
        disease_range: Half-open ``(start, stop)`` range of zero-based record
            indices classified in ``"disease"`` mode; all other records use
            ``"chemical_or_gene"`` mode.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    import csv
    import sys

    disease_start, disease_stop = disease_range
    # newline="" plus an explicit "\n" terminator follows the csv module's
    # guidance and keeps one "\n"-terminated line per row.
    with open(input_jsonl, "r", encoding="utf-8") as fin, \
         open(output_csv, "w", encoding="utf-8", newline="") as fout:
        writer = csv.writer(fout, lineterminator="\n")
        writer.writerow(["core_genes", "relation"])

        idx = -1  # zero-based index over records (blank lines are not records)
        for line in fin:
            if not line.strip():
                continue
            idx += 1
            obj = json.loads(line)

            core_genes = obj.get("core_genes", [])
            hypos = obj.get("hypotheses", [])
            has_genes = bool(core_genes) and isinstance(core_genes, list)
            has_hypos = bool(hypos) and isinstance(hypos, list)
            genes = [str(gene) for gene in core_genes] if has_genes else []
            genes_str = ";".join(genes)

            if not (has_genes and has_hypos):
                print(
                    f"record {idx}: missing core_genes or hypotheses; "
                    "writing an empty relation",
                    file=sys.stderr,
                )
                writer.writerow([genes_str, ""])
                continue

            if disease_start <= idx < disease_stop:      # Disease-Gene
                mode = "disease"
            else:
                mode = "chemical_or_gene"
            relation = classify_relation_openai(genes, hypos[-1], mode=mode)
            # Collapse whitespace so a multi-line reply cannot spill one
            # record over several physical lines of the CSV.
            relation = " ".join(str(relation).split())
            writer.writerow([genes_str, relation])
            print(f"{genes_str}: {relation}")

if __name__ == "__main__":
    # Raw TruthHypo hypotheses in, per-record relation predictions out.
    # NOTE(review): the CSV is written to the working directory, while the
    # evaluation cell reads ../data/data_eval/Truthhypo_relation_results.csv --
    # confirm the file is moved there before evaluating.
    input_jsonl, output_csv = (
        r"../data/data_raw/Truthhypo_hypo.jsonl",
        r"Truthhypo_relation_results.csv",
    )
    main(input_jsonl, output_csv)


In [None]:
# This script evaluates relation classification performance (precision, recall, F1, accuracy) on the TruthHypo dataset, reporting detailed metrics by group.

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Load the predicted label of every record, in file order.
# Two line formats are accepted: "genes,label" (CSV) and "genes: label"
# (the console format printed during classification).
csv_file = r"../data/data_eval/Truthhypo_relation_results.csv"
relation_list = []
with open(csv_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        # The CSV header is not a prediction; counting it would shift every
        # prediction by one position against the ground-truth labels.
        if line.replace(" ", "").lower() == "core_genes,relation":
            continue
        # The label is the text after the LAST ':' or ',' on the line, so a
        # gene name that itself contains ':' or ',' cannot leak into it.
        cut = max(line.rfind(":"), line.rfind(","))
        if cut == -1:
            continue
        relation_list.append(line[cut + 1:].strip())

# Positional ground truth: six consecutive runs of 50 records each.
# NOTE(review): assumes the predictions are in the same record order as the
# dataset -- confirm against the data file.
gt_labels = (
    ["negative_correlate"]*50 +
    ["positive_correlate"]*50 +
    ["inhibit"]*50 +
    ["stimulate"]*50 +
    ["negative_correlate"]*50 +
    ["positive_correlate"]*50
)

# Fail early with a clear message. Without this check the scikit-learn calls
# below raise a less specific ValueError about inconsistent sample counts,
# possibly after some groups have already been printed.
if len(relation_list) < len(gt_labels):
    raise ValueError(
        f"Expected at least {len(gt_labels)} predictions but found "
        f"{len(relation_list)}; predictions are matched to ground truth by "
        "position, so missing rows make the evaluation meaningless."
    )

# Any predictions beyond the labelled records are ignored.
relation_list = relation_list[:len(gt_labels)]

# (group name, half-open index range, labels valid inside the group)
group_info = [
    ("Chemical-Gene", (0, 100),       ["negative_correlate", "positive_correlate"]),
    ("Disease-Gene",  (100, 200),     ["inhibit", "stimulate"]),
    ("Gene-Gene",     (200, 300),     ["negative_correlate", "positive_correlate"]),
]

# Table header followed by a horizontal rule.
print(f"{'Group/Class':<35}{'Precision':>10}{'Recall':>10}{'F1':>10}{'Acc':>10}{'Support':>10}")
print("-"*85)

for group_name, (start, stop), class_names in group_info:
    truth = gt_labels[start:stop]
    preds = relation_list[start:stop]

    # Per-class scores, restricted to the labels valid for this group.
    class_precision = precision_score(truth, preds, labels=class_names, average=None, zero_division=0)
    class_recall = recall_score(truth, preds, labels=class_names, average=None, zero_division=0)
    class_f1 = f1_score(truth, preds, labels=class_names, average=None, zero_division=0)
    class_support = [truth.count(name) for name in class_names]

    for name, prec, rec, f1, support in zip(
        class_names, class_precision, class_recall, class_f1, class_support
    ):
        # Per-class "Acc" is the share of this class's true records that were
        # predicted correctly.
        hits = sum(1 for t, q in zip(truth, preds) if t == q == name)
        if support > 0:
            class_acc = hits / support
        else:
            class_acc = 0.0
        row_title = group_name + ' / ' + name
        print(f"{row_title:<35}{prec:>10.3f}{rec:>10.3f}{f1:>10.3f}{class_acc:>10.3f}{support:>10}")

    # Group summary: macro-averaged scores plus plain accuracy over the group.
    group_precision = precision_score(truth, preds, labels=class_names, average="macro", zero_division=0)
    group_recall = recall_score(truth, preds, labels=class_names, average="macro", zero_division=0)
    group_f1 = f1_score(truth, preds, labels=class_names, average="macro", zero_division=0)
    group_acc = accuracy_score(truth, preds)
    summary_title = group_name + ' (macro avg)'
    print(f"{summary_title:<35}{group_precision:>10.3f}{group_recall:>10.3f}{group_f1:>10.3f}{group_acc:>10.3f}{(stop - start):>10}")
    print("-"*85)

# Overall summary across all records: macro-averaged scores over the four
# labels, plus plain accuracy.
all_labels = ["negative_correlate", "positive_correlate", "inhibit", "stimulate"]
overall_options = {"labels": all_labels, "average": "macro", "zero_division": 0}
macro_p_all = precision_score(gt_labels, relation_list, **overall_options)
macro_r_all = recall_score(gt_labels, relation_list, **overall_options)
macro_f_all = f1_score(gt_labels, relation_list, **overall_options)
acc_all = accuracy_score(gt_labels, relation_list)
overall_title = 'ALL DATA (macro avg)'
print(f"{overall_title:<35}{macro_p_all:>10.3f}{macro_r_all:>10.3f}{macro_f_all:>10.3f}{acc_all:>10.3f}{len(gt_labels):>10}")

