In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
from typing import List, Dict, Union
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from openai import OpenAI
from metrics import *
from agent import *
from prompt import *
from scipy.spatial.distance import jensenshannon

# re-run for error cases

In [29]:
client = OpenAI(
            api_key = "empty",
            base_url = "http://localhost:8000/v1",
            timeout=120.0
            )
    
class Response(BaseModel):
    reasoning: str = Field(description="Step-by-step explanation of how you interpreted the report to determine the cancer stage.")
    stage: str = Field(description="The cancer stage determined from the report.")
    
testing_schema = Response.model_json_schema()

def test_individual_report(dataset, patient_filename, prompt_method, prompt, stage_type, context):
    report = dataset[dataset.patient_filename == patient_filename]["text"].values[0]

    if context:
        formatted_prompt = prompt.format(report=report, context=context)
    else:
        formatted_prompt = prompt.format(report=report)

    messages = [{"role": "user", "content": formatted_prompt}]
    response = client.chat.completions.create(
        model = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages = messages,
        extra_body = {"guided_json": testing_schema},
        temperature = 0.1)
    
    response = json.loads(response.choices[0].message.content)

    dataset.loc[dataset["patient_filename"] == patient_filename, f'{prompt_method}_{stage_type}_reasoning'] = response["reasoning"]
    dataset.loc[dataset["patient_filename"] == patient_filename, f'{prompt_method}_{stage_type}_stage'] = response["stage"]

    return dataset

In [30]:
with open("/home/yl3427/cylab/rag_tnm/src/context.json", "r") as f:
    context = json.load(f)

rag_raw_t14 = context["rag_raw_t14"]
rag_raw_n03 = context["rag_raw_n03"]
ltm_zs_t14 = context["ltm_zs_t14"]
ltm_zs_n03 = context["ltm_zs_n03"]
ltm_rag1_t14 = context["ltm_rag1_t14"]
ltm_rag1_n03 = context["ltm_rag1_n03"]
ltm_rag2_t14 = context["ltm_rag2_t14"]
ltm_rag2_n03 = context["ltm_rag2_n03"]

df = pd.read_csv(f"/home/yl3427/cylab/rag_tnm/rag_result/0929_ltm_rag2.csv")

prompt_method = "ltm_rag1" # zscot, rag_raw, ltm_zs, ltm_rag1, ltm_rag2
stage_type = "n" # t, n
context = ltm_rag1_n03

# key: f"{prompt_method}_{stage_type}"
prompt = {"zscot_t": zscot_t14, "zscot_n": zscot_n03, 
          "rag_raw_t": rag_t14, "rag_raw_n": rag_n03, 
          "ltm_zs_t": ltm_t14, "ltm_zs_n": ltm_n03, 
          "ltm_rag1_t": ltm_t14, "ltm_rag1_n": ltm_n03, 
          "ltm_rag2_t": ltm_t14, "ltm_rag2_n": ltm_n03}

# T14

for idx in range(len(df)):     
    patient_filename = df.loc[idx, "patient_filename"]
    label_column = f'{prompt_method}_{stage_type}_stage'
    if not isinstance(df.loc[df["patient_filename"] == patient_filename, label_column].values.item(), str):
        print(idx) 
        print("before: ", df.loc[df["patient_filename"] == patient_filename, label_column].values.item())
        test_individual_report(df, patient_filename, prompt_method, prompt[f"{prompt_method}_{stage_type}"], stage_type, context)
        print("after: ", df.loc[df["patient_filename"] == patient_filename, label_column].values.item())
        print("label: ", df.loc[df["patient_filename"] == patient_filename, stage_type].values.item())

# df.to_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv", index=False)

164
before:  nan


APITimeoutError: Request timed out.

In [4]:
df = pd.read_csv("/home/yl3427/cylab/rag_tnm/rag_result/0929_ltm_rag2.csv")
df.columns

Index(['patient_filename', 't', 'text', 'n', 'zscot_t_reasoning',
       'zscot_t_stage', 'zscot_n_reasoning', 'zscot_n_stage',
       'rag_raw_t_reasoning', 'rag_raw_t_stage', 'rag_raw_n_reasoning',
       'rag_raw_n_stage', 'ltm_zs_t_reasoning', 'ltm_zs_t_stage',
       'ltm_zs_n_reasoning', 'ltm_zs_n_stage', 'ltm_rag1_t_reasoning',
       'ltm_rag1_t_stage', 'ltm_rag1_n_reasoning', 'ltm_rag1_n_stage',
       'ltm_rag2_t_reasoning', 'ltm_rag2_t_stage', 'ltm_rag2_n_reasoning',
       'ltm_rag2_n_stage'],
      dtype='object')

In [3]:
print(t14_calculate_metrics(df['t'], df['zscot_t_stage'])['overall'])
print(t14_calculate_metrics(df['t'], df['rag_raw_t_stage'])['overall'])
print(t14_calculate_metrics(df['t'], df['ltm_zs_t_stage'])['overall'])
print(t14_calculate_metrics(df['t'], df['ltm_rag1_t_stage'])['overall'])
print(t14_calculate_metrics(df['t'], df['ltm_rag2_t_stage'])['overall'])

NameError: name 'pred' is not defined

In [15]:
for idx, row in df.iterrows():
    if idx == 4:
        break
    report = row['text']
    print(report)

    

Path No.: Date Obtained: (Age: ). Date Received: F. See Addendum/Procedure. SPECIMEN: A:Lymph node, right axilla sentinel node, biopsy. B:Breast, right, lumpectomy. C:Lymph nodes, right axilla, dissection. DIAGNOSIS(ES): A. Lymph node, right axilla sentinel node, biopsy: - Carcinoma in 1 sentinel node following carcinoma of right breast. B. Breast, right, lumpectomy: - Carcinoma, invasive ductal type, moderately-differentiated, with focal micropapillary features,. Nottingham's score 5 (2+2+1). - Carcinoma, intraductal, comedo type with microcalcifications. - Lobular neoplasia, focal, classical type. - Fibrocystic disease, proliferative, with apocrine metaplasia, sclerosing adenosis and. microcalcifications. - Cicatricial fibrosis and organizing granulation tissue with fat necrosis, consistent with previous. biopsy site. - Fibroadenoma, microscopic. C. Lymph nodes, right axilla, dissection: - No evidence of carcinoma in 14 lymph nodes. Date Dictated: CLINICAL INFORMATION: Breast cancer.

In [11]:
df[['t', 'zscot_t_stage', 'rag_raw_t_stage', 'ltm_zs_t_stage', 'ltm_rag1_t_stage', 'ltm_rag2_t_stage']]

Unnamed: 0,t,zscot_t_stage,rag_raw_t_stage,ltm_zs_t_stage,ltm_rag1_t_stage,ltm_rag2_t_stage
0,1,T2,T3,T2,T3,T2
1,1,T2,T2,T2,T2,T2
2,1,T2,T1,T3,T3,T1c
3,1,T2,T1c,T3,T2,T2
4,1,T2,T3,T2,T3,T2
...,...,...,...,...,...,...
795,3,T2,T2,T2,T2,T2
796,0,T2,T1c,T2,T2,T2
797,2,T2,T2,T2,T3,T2
798,2,T2,T1,T2,T3,T2


In [None]:
t14_calculate_metrics(df['t'], df['cmem_t_ans_str'])

In [None]:
n03_calculate_metrics(n_zscot_df['n'], n_zscot_df['zs_n_ans_str'])

In [None]:
n03_calculate_metrics(df['n'], df['cmem_n_ans_str'])

# Filter data for Qualitative Analysis

### T14

In [None]:
t_train_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset0.csv")
memory_dict_t = {}
for idx, row in t_train_df.iterrows():
    memory_dict_t[f"{idx+1}"] = row['cmem_t_memory_str']


for i in ['10', '20', '30', '40', '50', '60', '70', '80', '90', '100']:
    print(f"Memory at {i}")
    print(memory_dict_t[i])
    print()
    print(len(memory_dict_t[i]))
    print("--------------------------------------------------")

In [None]:
t_test_df.columns

In [None]:
t_groundtruth_issue = {"TCGA-B6-A0IE.DFCA9C6E-710E-4645-9CFC-A908AAD583F3", "TCGA-JL-A3YX.25782EF0-8786-446E-ADBA-21F489844237", 
                       "TCGA-B6-A0IE.DFCA9C6E-710E-4645-9CFC-A908AAD583F3","TCGA-BH-A208.4F943D12-E769-45F3-86BE-75193786DD4E", 
                       "TCGA-AO-A0J9.1E3F3136-6D86-4470-85AA-55B11C9E24CD", "TCGA-BH-A1FM.DA6A0EC9-6E20-4E4A-9B7F-A32EFF7627AD", 
                       "TCGA-D8-A1JS.0EA57ABF-E3DA-4862-BAB8-A6E36408AC42", "TCGA-JL-A3YX.25782EF0-8786-446E-ADBA-21F489844237",
                       "TCGA-BH-A1FM.DA6A0EC9-6E20-4E4A-9B7F-A32EFF7627AD"}
t_groundtruth_issue

In [None]:
t_test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_0_outof_10runs_verified_wm_for_40.csv").sort_values(by="patient_filename")
t_zscot_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zscot_test_800_verified_wm.csv").sort_values(by="patient_filename")

split_ids = t_test_df.patient_filename
t_zscot_df = t_zscot_df[t_zscot_df.patient_filename.isin(split_ids)]
for memory_patient, zs_patient in zip(t_test_df.patient_filename, t_zscot_df.patient_filename):
    assert memory_patient == zs_patient

# output_dir = "/secure/shared_data/studio_label/zscot_error_t_verified"
# os.makedirs(output_dir, exist_ok=True)

ids_set = set()
for idx, (filename, label, memory_ans, zscot_ans, memory_rsn, zscot_rsn, memory_feedback, zscot_feedback, memory_final_stage, zscot_final_stage) in enumerate(zip(t_test_df.patient_filename, t_test_df.t, t_test_df.cmem_t_40reports_ans_str, t_zscot_df.zs_t_ans_str, t_test_df.cmem_t_40reasoning, t_zscot_df.zs_t_reasoning, t_test_df.t_feedback, t_zscot_df.t_feedback, t_test_df.t_final_stage, t_zscot_df.t_final_stage)):
    if filename in t_groundtruth_issue:
        continue
    # if (f"T{label+1}" in zs_ans.upper()) and (f"T{label+1}" in memory_ans.upper()): # cases where both are correct
    if (f"T{label+1}" not in zscot_ans.upper()) and (f"T{label+1}" in memory_ans.upper()): # cases where only memory was correct
    # if (f"T{label+1}" in zscot_ans.upper()) and (f"T{label+1}" not in memory_ans.upper()): # cases where only zs was correct
    # if (f"T{label+1}" not in zscot_ans.upper()) and (f"T{label+1}" not in memory_ans.upper()): # cases where both were wrong
        ids_set.add(filename)
        data = {
            "data": {
                "humanMachineDialogue": [
                    {"author": "Patient filename", "text": filename},
                    {"author": "Memory Reasoning", "text": memory_rsn},
                    {"author": "ZS Reasoning", "text": zscot_rsn}, 
                    {"author": "Answer", "text": f"T{label+1}"},
                    {"author": "Memory Answer", "text": memory_ans},
                    {"author": "ZS Answer", "text": zscot_ans},
                    {"author": "Memory Feedback", "text": memory_feedback},
                    {"author": "ZS Feedback", "text": zscot_feedback},
                    {"author": "Memory Final Answer", "text": memory_final_stage},
                    {"author": "ZS Final Answer", "text": zscot_final_stage} 
                ]
            }
        }
        print(idx)
        print(data["data"]["humanMachineDialogue"][3])
        print(data["data"]["humanMachineDialogue"][4])
        print(data["data"]["humanMachineDialogue"][5])
        print(data["data"]["humanMachineDialogue"][6])
        print(data["data"]["humanMachineDialogue"][7])
        print(data["data"]["humanMachineDialogue"][8])
        print(data["data"]["humanMachineDialogue"][9])
        print("--------------------------")
        
        # file_name = f"t3_{idx}.json"
        # file_path = os.path.join(output_dir, file_name)
        # with open(file_path, 'w') as json_file:
        #     json.dump(data, json_file, indent=4)

In [None]:
len(ids_set)

In [None]:
KEPA_T1 = 0
KEPA_T2 = 0
KEPA_T3 = 0
KEPA_T4 = 0
ZSCOT_T1 = 0
ZSCOT_T2 = 0
ZSCOT_T3 = 0
ZSCOT_T4 = 0
GT_T1 = 0
GT_T2 = 0
GT_T3 = 0
GT_T4 = 0

run_lst = [0, 1, 2, 3, 4, 5, 6, 8]
for run in run_lst:
    # print(f"Run {run}, memory 40")
    t_test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv").sort_values(by="patient_filename")
    t_zs_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zs_test_800.csv").sort_values(by="patient_filename")
    t_zscot_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zscot_test_800.csv").sort_values(by="patient_filename")

    split_ids = t_test_df.patient_filename
    t_zs_df = t_zs_df[t_zs_df.patient_filename.isin(split_ids)]
    t_zscot_df = t_zscot_df[t_zscot_df.patient_filename.isin(split_ids)]
    for memory_patient, zs_patient, zscot_patient in zip(t_test_df.patient_filename, t_zs_df.patient_filename, t_zscot_df.patient_filename):
        assert memory_patient == zs_patient and memory_patient == zscot_patient

    ids_set = set()
    label_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    memory_pred_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    memory_correct_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    zs_pred_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    zs_correct_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    zscot_pred_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    zscot_correct_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}

    three_common_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()}
    both_same_pred_but_wrong_answer_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()} # each key represents pred
    neither_correct_dict = {"T1": set(), "T2": set(), "T3": set(), "T4": set()} # each key represents label
    for idx, (filename, label, memory_ans, zs_ans, zscot_ans) in enumerate(zip(t_test_df.patient_filename, t_test_df.t, t_test_df.cmem_t_40reports_ans_str, t_zs_df.zs_t_ans_str, t_zscot_df.zs_t_ans_str)):
        if f"T{label+1}" == "T1":
            label_dict["T1"].add(filename)
            if f"T{label+1}" in zscot_ans.upper() and f"T{label+1}" in memory_ans.upper():
                three_common_dict["T1"].add(filename)
            elif f"T{label+1}" not in zscot_ans.upper() and f"T{label+1}" not in memory_ans.upper():
                neither_correct_dict["T1"].add(filename)
        if f"T{label+1}" == "T2":
            label_dict["T2"].add(filename)
            if f"T{label+1}" in zscot_ans.upper() and f"T{label+1}" in memory_ans.upper():
                three_common_dict["T2"].add(filename)
            elif f"T{label+1}" not in zscot_ans.upper() and f"T{label+1}" not in memory_ans.upper():
                neither_correct_dict["T2"].add(filename)
        if f"T{label+1}" == "T3":
            label_dict["T3"].add(filename)
            if f"T{label+1}" in zscot_ans.upper() and f"T{label+1}" in memory_ans.upper():
                three_common_dict["T3"].add(filename)
            elif f"T{label+1}" not in zscot_ans.upper() and f"T{label+1}" not in memory_ans.upper():
                neither_correct_dict["T3"].add(filename)
        if f"T{label+1}" == "T4":
            label_dict["T4"].add(filename)
            if f"T{label+1}" in zscot_ans.upper() and f"T{label+1}" in memory_ans.upper():
                three_common_dict["T4"].add(filename)
            elif f"T{label+1}" not in zscot_ans.upper() and f"T{label+1}" not in memory_ans.upper():
                neither_correct_dict["T4"].add(filename)
        # if (f"T{label+1}" in zs_ans.upper()) and (f"T{label+1}" in memory_ans.upper()): # cases where both are correct
        # if (f"T{label+1}" not in zs_ans.upper()) and (f"T{label+1}" in memory_ans.upper()): # cases where only memory was correct
        # if (f"T{label+1}" in zs_ans.upper()) and (f"T{label+1}" not in memory_ans.upper()): # cases where only zs was correct
        # if (f"T{label+1}" not in zs_ans.upper()) and (f"T{label+1}" not in memory_ans.upper()): # cases where both were wrong
        if "T1" in memory_ans.upper():
            memory_pred_dict["T1"].add(filename)
            if "T1" == f"T{label+1}":
                memory_correct_dict["T1"].add(filename)
        if "T2" in memory_ans.upper():
            memory_pred_dict["T2"].add(filename)
            if "T2" == f"T{label+1}":
                memory_correct_dict["T2"].add(filename)
        if "T3" in memory_ans.upper():
            memory_pred_dict["T3"].add(filename)
            if "T3" == f"T{label+1}":
                memory_correct_dict["T3"].add(filename)
        if "T4" in memory_ans.upper():
            memory_pred_dict["T4"].add(filename)
            if "T4" == f"T{label+1}":
                memory_correct_dict["T4"].add(filename)
        if "T1" in zs_ans.upper():
            zs_pred_dict["T1"].add(filename)
            if "T1" == f"T{label+1}":
                zs_correct_dict["T1"].add(filename)
        if "T2" in zs_ans.upper():
            zs_pred_dict["T2"].add(filename)
            if "T2" == f"T{label+1}":
                zs_correct_dict["T2"].add(filename)
        if "T3" in zs_ans.upper():
            zs_pred_dict["T3"].add(filename)
            if "T3" == f"T{label+1}":
                zs_correct_dict["T3"].add(filename)
        if "T4" in zs_ans.upper():
            zs_pred_dict["T4"].add(filename)
            if "T4" == f"T{label+1}":
                zs_correct_dict["T4"].add(filename)
        if "T1" in zscot_ans.upper():
            zscot_pred_dict["T1"].add(filename)
            if "T1" == f"T{label+1}":
                zscot_correct_dict["T1"].add(filename)
        if "T2" in zscot_ans.upper():
            zscot_pred_dict["T2"].add(filename)
            if "T2" == f"T{label+1}":
                zscot_correct_dict["T2"].add(filename)
        if "T3" in zscot_ans.upper():
            zscot_pred_dict["T3"].add(filename)
            if "T3" == f"T{label+1}":
                zscot_correct_dict["T3"].add(filename)
        if "T4" in zscot_ans.upper():
            zscot_pred_dict["T4"].add(filename)
            if "T4" == f"T{label+1}":
                zscot_correct_dict["T4"].add(filename)
        if "T1" in memory_ans.upper() and "T1" in zscot_ans.upper() and "T1" != f"T{label+1}":
            both_same_pred_but_wrong_answer_dict["T1"].add(filename)
        if "T2" in memory_ans.upper() and "T2" in zscot_ans.upper() and "T2" != f"T{label+1}":
            both_same_pred_but_wrong_answer_dict["T2"].add(filename)
        if "T3" in memory_ans.upper() and "T3" in zscot_ans.upper() and "T3" != f"T{label+1}":
            both_same_pred_but_wrong_answer_dict["T3"].add(filename)
        if "T4" in memory_ans.upper() and "T4" in zscot_ans.upper() and "T4" != f"T{label+1}":
            both_same_pred_but_wrong_answer_dict["T4"].add(filename)
            # ids_set.add(filename)
          
    df1 = pd.DataFrame(
        {
            "T1": [f'{len(memory_pred_dict["T1"])} pred ({len(memory_correct_dict["T1"])} correct)', f'{len(zs_pred_dict["T1"])} pred ({len(zs_correct_dict["T1"])} correct)', f'{len(zscot_pred_dict["T1"])} pred ({len(zscot_correct_dict["T1"])} correct)', f'{len(label_dict["T1"])}'],
            "T2": [f'{len(memory_pred_dict["T2"])} pred ({len(memory_correct_dict["T2"])} correct)', f'{len(zs_pred_dict["T2"])} pred ({len(zs_correct_dict["T2"])} correct)', f'{len(zscot_pred_dict["T2"])} pred ({len(zscot_correct_dict["T2"])} correct)', f'{len(label_dict["T2"])}'],
            "T3": [f'{len(memory_pred_dict["T3"])} pred ({len(memory_correct_dict["T3"])} correct)', f'{len(zs_pred_dict["T3"])} pred ({len(zs_correct_dict["T3"])} correct)', f'{len(zscot_pred_dict["T3"])} pred ({len(zscot_correct_dict["T3"])} correct)', f'{len(label_dict["T3"])}'],
            "T4": [f'{len(memory_pred_dict["T4"])} pred ({len(memory_correct_dict["T4"])} correct)', f'{len(zs_pred_dict["T4"])} pred ({len(zs_correct_dict["T4"])} correct)', f'{len(zscot_pred_dict["T4"])} pred ({len(zscot_correct_dict["T4"])} correct)', f'{len(label_dict["T4"])}']
        },
        index=["KEPA", "ZS", "ZSCOT", "Ground Truth"]
    )
    print(f"{len(memory_pred_dict['T1'])}, {len(memory_pred_dict['T2'])}, {len(memory_pred_dict['T3'])}, {len(memory_pred_dict['T4'])}", end="\t")
    print(f"{len(zs_pred_dict['T1'])}, {len(zs_pred_dict['T2'])}, {len(zs_pred_dict['T3'])}, {len(zs_pred_dict['T4'])}", end="\t")
    print(f"{len(zscot_pred_dict['T1'])}, {len(zscot_pred_dict['T2'])}, {len(zscot_pred_dict['T3'])}, {len(zscot_pred_dict['T4'])}", end="\t")
    print(f"{len(label_dict['T1'])}, {len(label_dict['T2'])}, {len(label_dict['T3'])}, {len(label_dict['T4'])}")
    KEPA_T1 += len(memory_pred_dict["T1"])
    KEPA_T2 += len(memory_pred_dict["T2"])
    KEPA_T3 += len(memory_pred_dict["T3"])
    KEPA_T4 += len(memory_pred_dict["T4"])
    ZSCOT_T1 += len(zscot_pred_dict["T1"])
    ZSCOT_T2 += len(zscot_pred_dict["T2"])
    ZSCOT_T3 += len(zscot_pred_dict["T3"])
    ZSCOT_T4 += len(zscot_pred_dict["T4"])
    GT_T1 += len(label_dict["T1"])
    GT_T2 += len(label_dict["T2"])
    GT_T3 += len(label_dict["T3"])
    GT_T4 += len(label_dict["T4"])
    # display(df1.transpose())
    memory_pv = np.array([len(memory_pred_dict["T1"]), len(memory_pred_dict["T2"]), len(memory_pred_dict["T3"]), len(memory_pred_dict["T4"])]) / (len(memory_pred_dict["T1"]) + len(memory_pred_dict["T2"]) + len(memory_pred_dict["T3"]) + len(memory_pred_dict["T4"]))
    zs_pv = np.array([len(zs_pred_dict["T1"]), len(zs_pred_dict["T2"]), len(zs_pred_dict["T3"]), len(zs_pred_dict["T4"])]) / (len(zs_pred_dict["T1"]) + len(zs_pred_dict["T2"]) + len(zs_pred_dict["T3"]) + len(zs_pred_dict["T4"]))
    zscot_pv = np.array([len(zscot_pred_dict["T1"]), len(zscot_pred_dict["T2"]), len(zscot_pred_dict["T3"]), len(zscot_pred_dict["T4"])]) / (len(zscot_pred_dict["T1"]) + len(zscot_pred_dict["T2"]) + len(zscot_pred_dict["T3"]) + len(zscot_pred_dict["T4"]))
    gt_pv = np.array([len(label_dict["T1"]), len(label_dict["T2"]), len(label_dict["T3"]), len(label_dict["T4"])]) / (len(label_dict["T1"]) + len(label_dict["T2"]) + len(label_dict["T3"]) + len(label_dict["T4"]))
    # print(memory_pv, zscot_pv, gt_pv)
    # print("Jensen-Shannon distance between kepa and ground truth: " ,jensenshannon(memory_pv, gt_pv))
    # print("Jensen-Shannon distance between zs and ground truth: " ,jensenshannon(zs_pv, gt_pv))
    # print("Jensen-Shannon distance between zscot and ground truth: " ,jensenshannon(zscot_pv, gt_pv))
                                                                                                                                        
    df2 = pd.DataFrame(
        {
            "T1": [f'{len(three_common_dict["T1"])}', f'{len(memory_pred_dict["T1"])}', f'{len(zscot_pred_dict["T1"])}', f'{len(both_same_pred_but_wrong_answer_dict["T1"])}'],
            "T2": [f'{len(three_common_dict["T2"])}', f'{len(memory_pred_dict["T2"])}', f'{len(zscot_pred_dict["T2"])}', f'{len(both_same_pred_but_wrong_answer_dict["T2"])}'],
            "T3": [f'{len(three_common_dict["T3"])}', f'{len(memory_pred_dict["T3"])}', f'{len(zscot_pred_dict["T3"])}', f'{len(both_same_pred_but_wrong_answer_dict["T3"])}'],
            "T4": [f'{len(three_common_dict["T4"])}', f'{len(memory_pred_dict["T4"])}', f'{len(zscot_pred_dict["T4"])}', f'{len(both_same_pred_but_wrong_answer_dict["T4"])}']
        },  
        index=["3 common", "KEPA", "ZSCOT", "Both same pred but wrong answer"]         
    )
    # display(df2.transpose())
 

# print("Total")
# df3 = pd.DataFrame(
#     {
#         "T1": [f'{KEPA_T1/len(run_lst)}', f'{ZSCOT_T1/len(run_lst)}', f'{GT_T1/len(run_lst)}'],
#         "T2": [f'{KEPA_T2/len(run_lst)}', f'{ZSCOT_T2/len(run_lst)}', f'{GT_T2/len(run_lst)}'],
#         "T3": [f'{KEPA_T3/len(run_lst)}', f'{ZSCOT_T3/len(run_lst)}', f'{GT_T3/len(run_lst)}'],
#         "T4": [f'{KEPA_T4/len(run_lst)}', f'{ZSCOT_T4/len(run_lst)}', f'{GT_T4/len(run_lst)}']
#     },
#     index=["KEPA", "ZSCOT", "Ground Truth"]
# )
# display(df3.transpose())

In [None]:
len(memory_pred_dict['T3'] - three_common_dict["T3"] - both_same_pred_but_wrong_answer_dict["T3"])

In [None]:
print("KEPA")
for k, v in memory_pred_dict.items():
    print(f"{k}: total {len(v)} -> {len(memory_correct_dict[k])} correct")

In [None]:
print("ZSCOT")
for k, v in zscot_pred_dict.items():
    print(f"{k}: total {len(v)} -> {len(zscot_correct_dict[k])} correct")

In [None]:
len(ids_set)

### N03

In [None]:
n_train_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/n03_memory_dataset1.csv")
memory_dict_n = {}
for idx, row in n_train_df.iterrows():
    memory_dict_n[f"{idx+1}"] = row['cmem_n_memory_str']


# for i in ['10', '20', '30', '40', '50', '60', '70', '80', '90', '100']:
#     print(f"Memory at {i}")
#     print(memory_dict_n[i])
#     print()
#     print(len(memory_dict_n[i]))
#     print("--------------------------------------------------")

print(memory_dict_n['40'])

In [None]:
n_groundtruth_issue = {"TCGA-BH-A1FJ.8169BE67-03C8-4F4D-9A60-200705B795AE", "TCGA-BH-A6R9.1DB8FAFB-FC4A-4401-8316-30FB5352335D",
                       "TCGA-E9-A1QZ.864BB34A-1008-480C-A3B5-A2C616E95C49", "TCGA-A8-A06Z.956F45E5-A8C6-4A4A-9D1F-D31912180584",
                       "TCGA-B6-A0IK.3A38A97C-2CBB-4802-9528-A4BBD62AEA4A", "TCGA-B6-A0WV.506BFD3B-240B-440E-B7A0-E596FC0B7F72",
                       "TCGA-E9-A3X8.00058FFD-35E6-4891-8B01-DAB3AE9EBF78", "TCGA-GM-A2DA.F3CD8E6E-B02F-4D5D-B895-6DF063F61603", 
                       "TCGA-B6-A0IH.12C64846-1CB3-42E4-B307-54C7AD12F530", "TCGA-B6-A0WW.F05F5886-DC5D-4685-B2BF-57A68A0BB7B9", 
                       "TCGA-B6-A0X1.D792031E-2CCE-4341-B3B3-C7D1D84F8F6B", "TCGA-GM-A2DA.F3CD8E6E-B02F-4D5D-B895-6DF063F61603",
                       "TCGA-B6-A1KN.72996825-1FFA-4C51-8DB0-DA74BCB595EB"}
n_groundtruth_issue

In [None]:
n_test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_1_outof_10runs_numerical_verified_for_40.csv").sort_values(by="patient_filename")
n_zscot_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zscot_test_800_numerical_verified.csv").sort_values(by="patient_filename")

split_ids = n_test_df.patient_filename
n_zscot_df = n_zscot_df[n_zscot_df.patient_filename.isin(split_ids)]
for memory_patient, zs_patient in zip(n_test_df.patient_filename, n_zscot_df.patient_filename):
    assert memory_patient == zs_patient

# output_dir = "/secure/shared_data/studio_label/n3"
# os.makedirs(output_dir, exist_ok=True)

ids_set = set()
for idx, (filename, label, memory_ans, zscot_ans, memory_rsn, zscot_rsn, memory_feedback, zscot_feedback, memory_final_stage, zscot_final_stage) in enumerate(zip(n_test_df.patient_filename, n_test_df.n, n_test_df[f"cmem_n_40reports_ans_str"], n_zscot_df.zs_n_ans_str, n_test_df[f"cmem_n_40reasoning"], n_zscot_df.zs_n_reasoning, n_test_df.n_feedback, n_zscot_df.n_feedback, n_test_df.n_final_stage, n_zscot_df.n_final_stage)):
    memory_ans = memory_ans.upper().replace("NO", "N0").replace("NL", "N1")
    zscot_ans = zscot_ans.upper().replace("NO", "N0").replace("NL", "N1")
    if filename in n_groundtruth_issue:
        continue
    # if (f"N{label}" in zscot_ans) and (f"N{label}" in memory_ans): # cases where both are correct
    # if (f"N{label}" not in zscot_ans) and (f"N{label}" in memory_ans): # cases where only memory was correct
    if (f"N{label}" in zscot_ans) and (f"N{label}" not in memory_ans): # cases where only zs was correct
    # if (f"N{label}" not in zscot_ans) and (f"N{label}" not in memory_ans): # cases where both were wrong
        ids_set.add(filename)
        data = {
            "data": {
                "humanMachineDialogue": [
                    {"author": "Patient filename", "text": filename},
                    {"author": "Memory Reasoning", "text": memory_rsn},
                    {"author": "ZS Reasoning", "text": zscot_rsn}, 
                    {"author": "Answer", "text": f"N{label}"},
                    {"author": "Memory Answer", "text": memory_ans},
                    {"author": "ZS Answer", "text": zscot_ans},
                    {"author": "Memory Feedback", "text": memory_feedback},
                    {"author": "ZS Feedback", "text": zscot_feedback},
                    {"author": "Memory Final Answer", "text": memory_final_stage},
                    {"author": "ZS Final Answer", "text": zscot_final_stage}
                ]
            }
        }
        print(idx)
        print(data["data"]["humanMachineDialogue"][3])
        print(data["data"]["humanMachineDialogue"][4])
        print(data["data"]["humanMachineDialogue"][5])
        print(data["data"]["humanMachineDialogue"][6])
        print(data["data"]["humanMachineDialogue"][7])
        print(data["data"]["humanMachineDialogue"][8])
        print(data["data"]["humanMachineDialogue"][9])
        print("--------------------------")   
        
        # file_name = f"n3_{idx}.json"
        # file_path = os.path.join(output_dir, file_name)
        # with open(file_path, 'w') as json_file:
        #     json.dump(data, json_file, indent=4)

In [None]:
len(ids_set)

In [None]:
KEPA_N0 = 0
KEPA_N1 = 0
KEPA_N2 = 0
KEPA_N3 = 0
ZSCOT_N0 = 0
ZSCOT_N1 = 0
ZSCOT_N2 = 0
ZSCOT_N3 = 0
GT_N0 = 0
GT_N1 = 0
GT_N2 = 0
GT_N3 = 0

run_lst = [0,1,3,4,5,6,7,9]
for run in run_lst:
    # print(f"Run {run}, memory 40")
    n_test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv").sort_values(by="patient_filename")
    n_zs_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zs_test_800.csv").sort_values(by="patient_filename")
    n_zscot_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zscot_test_800.csv").sort_values(by="patient_filename")

    split_ids = n_test_df.patient_filename
    n_zs_df = n_zs_df[n_zs_df.patient_filename.isin(split_ids)]
    n_zscot_df = n_zscot_df[n_zscot_df.patient_filename.isin(split_ids)]
    for memory_patient, zs_patient, zscot_patient in zip(n_test_df.patient_filename, n_zs_df.patient_filename, n_zscot_df.patient_filename):
        assert memory_patient == zs_patient and memory_patient == zscot_patient

    ids_set = set()
    label_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    memory_pred_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    memory_correct_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    zs_pred_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    zs_correct_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    zscot_pred_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    zscot_correct_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    three_common_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()}
    both_same_pred_but_wrong_answer_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()} # each key represents pred
    neither_correct_dict = {"N0": set(), "N1": set(), "N2": set(), "N3": set()} # each key represents label
    for idx, (filename, label, memory_ans, zs_ans, zscot_ans) in enumerate(zip(n_test_df.patient_filename, n_test_df.n, n_test_df[f"cmem_n_40reports_ans_str"], n_zs_df.zs_n_ans_str, n_zscot_df.zs_n_ans_str)):
        memory_ans = memory_ans.upper().replace("NO", "N0").replace("NL", "N1")
        zscot_ans = zscot_ans.upper().replace("NO", "N0").replace("NL", "N1")
        if f"N{label}" == "N0":
            label_dict["N0"].add(filename)
            if f"N{label}" in zscot_ans and f"N{label}" in memory_ans:
                three_common_dict["N0"].add(filename)
            elif f"N{label}" not in zscot_ans and f"N{label}" not in memory_ans:
                neither_correct_dict["N0"].add(filename)
        if f"N{label}" == "N1":
            label_dict["N1"].add(filename)
            if f"N{label}" in zscot_ans and f"N{label}" in memory_ans:
                three_common_dict["N1"].add(filename)
            elif f"N{label}" not in zscot_ans and f"N{label}" not in memory_ans:
                neither_correct_dict["N1"].add(filename)
        if f"N{label}" == "N2":
            label_dict["N2"].add(filename)
            if f"N{label}" in zscot_ans and f"N{label}" in memory_ans:
                three_common_dict["N2"].add(filename)
            elif f"N{label}" not in zscot_ans and f"N{label}" not in memory_ans:
                neither_correct_dict["N2"].add(filename)
        if f"N{label}" == "N3":
            label_dict["N3"].add(filename)
            if f"N{label}" in zscot_ans and f"N{label}" in memory_ans:
                three_common_dict["N3"].add(filename)
            elif f"N{label}" not in zscot_ans and f"N{label}" not in memory_ans:
                neither_correct_dict["N3"].add(filename)
        # if (f"N{label}" in zs_ans) and (f"N{label}" in memory_ans): # cases where both are correct
        # if (f"N{label}" not in zs_ans) and (f"N{label}" in memory_ans): # cases where only memory was correct
        # if (f"N{label}" in zs_ans) and (f"N{label}" not in memory_ans): # cases where only zs was correct
        # if (f"N{label}" not in zs_ans) and (f"N{label}" not in memory_ans): # cases where both were wrong
        if "N0" in memory_ans:
            memory_pred_dict["N0"].add(filename)
            if "N0" == f"N{label}":
                memory_correct_dict["N0"].add(filename)
        if "N1" in memory_ans:
            memory_pred_dict["N1"].add(filename)
            if "N1" == f"N{label}":
                memory_correct_dict["N1"].add(filename)
        if "N2" in memory_ans:
            memory_pred_dict["N2"].add(filename)
            if "N2" == f"N{label}":
                memory_correct_dict["N2"].add(filename)
        if "N3" in memory_ans:
            memory_pred_dict["N3"].add(filename)
            if "N3" == f"N{label}":
                memory_correct_dict["N3"].add(filename)
        if "N0" in zs_ans:
            zs_pred_dict["N0"].add(filename)
            if "N0" == f"N{label}":
                zs_correct_dict["N0"].add(filename)
        if "N1" in zs_ans:
            zs_pred_dict["N1"].add(filename)
            if "N1" == f"N{label}":
                zs_correct_dict["N1"].add(filename)
        if "N2" in zs_ans:
            zs_pred_dict["N2"].add(filename)
            if "N2" == f"N{label}":
                zs_correct_dict["N2"].add(filename)
        if "N3" in zs_ans:
            zs_pred_dict["N3"].add(filename)
            if "N3" == f"N{label}":
                zs_correct_dict["N3"].add(filename)
        if "N0" in zscot_ans:
            zscot_pred_dict["N0"].add(filename)
            if "N0" == f"N{label}":
                zscot_correct_dict["N0"].add(filename)
        if "N1" in zscot_ans:
            zscot_pred_dict["N1"].add(filename)
            if "N1" == f"N{label}":
                zscot_correct_dict["N1"].add(filename)
        if "N2" in zscot_ans:
            zscot_pred_dict["N2"].add(filename)
            if "N2" == f"N{label}":
                zscot_correct_dict["N2"].add(filename)
        if "N3" in zscot_ans:
            zscot_pred_dict["N3"].add(filename)
            if "N3" == f"N{label}":
                zscot_correct_dict["N3"].add(filename)
        if "N0" in memory_ans and "N0" in zscot_ans and "N0" != f"N{label}":
            both_same_pred_but_wrong_answer_dict["N0"].add(filename)
        if "N1" in memory_ans and "N1" in zscot_ans and "N1" != f"N{label}":
            both_same_pred_but_wrong_answer_dict["N1"].add(filename)
        if "N2" in memory_ans and "N2" in zscot_ans and "N2" != f"N{label}":
            both_same_pred_but_wrong_answer_dict["N2"].add(filename)
        if "N3" in memory_ans and "N3" in zscot_ans and "N3" != f"N{label}":
            both_same_pred_but_wrong_answer_dict["N3"].add(filename)
            # ids_set.add(filename)

    df1 = pd.DataFrame(
        {
            "N0": [f'{len(memory_pred_dict["N0"])} pred ({len(memory_correct_dict["N0"])} correct)', f'{len(zs_pred_dict["N0"])} pred ({len(zs_correct_dict["N0"])} correct)', f'{len(zscot_pred_dict["N0"])} pred ({len(zscot_correct_dict["N0"])} correct)', f'{len(label_dict["N0"])}'],
            "N1": [f'{len(memory_pred_dict["N1"])} pred ({len(memory_correct_dict["N1"])} correct)', f'{len(zs_pred_dict["N1"])} pred ({len(zs_correct_dict["N1"])} correct)', f'{len(zscot_pred_dict["N1"])} pred ({len(zscot_correct_dict["N1"])} correct)', f'{len(label_dict["N1"])}'],
            "N2": [f'{len(memory_pred_dict["N2"])} pred ({len(memory_correct_dict["N2"])} correct)', f'{len(zs_pred_dict["N2"])} pred ({len(zs_correct_dict["N2"])} correct)', f'{len(zscot_pred_dict["N2"])} pred ({len(zscot_correct_dict["N2"])} correct)', f'{len(label_dict["N2"])}'],
            "N3": [f'{len(memory_pred_dict["N3"])} pred ({len(memory_correct_dict["N3"])} correct)', f'{len(zs_pred_dict["N3"])} pred ({len(zs_correct_dict["N3"])} correct)', f'{len(zscot_pred_dict["N3"])} pred ({len(zscot_correct_dict["N3"])} correct)', f'{len(label_dict["N3"])}']
        },
        index=["KEPA", "ZS", "ZSCOT", "Ground Truth"]
    )
    # display(df1.transpose())
    print(f"{len(memory_pred_dict['N0'])}, {len(memory_pred_dict['N1'])}, {len(memory_pred_dict['N2'])}, {len(memory_pred_dict['N3'])}", end="\t")
    print(f"{len(zs_pred_dict['N0'])}, {len(zs_pred_dict['N1'])}, {len(zs_pred_dict['N2'])}, {len(zs_pred_dict['N3'])}", end="\t")
    print(f"{len(zscot_pred_dict['N0'])}, {len(zscot_pred_dict['N1'])}, {len(zscot_pred_dict['N2'])}, {len(zscot_pred_dict['N3'])}", end="\t")
    print(f"{len(label_dict['N0'])}, {len(label_dict['N1'])}, {len(label_dict['N2'])}, {len(label_dict['N3'])}")
    KEPA_N0 += len(memory_pred_dict["N0"])
    KEPA_N1 += len(memory_pred_dict["N1"])
    KEPA_N2 += len(memory_pred_dict["N2"])
    KEPA_N3 += len(memory_pred_dict["N3"])
    ZSCOT_N0 += len(zscot_pred_dict["N0"])
    ZSCOT_N1 += len(zscot_pred_dict["N1"])
    ZSCOT_N2 += len(zscot_pred_dict["N2"])
    ZSCOT_N3 += len(zscot_pred_dict["N3"])
    GT_N0 += len(label_dict["N0"])
    GT_N1 += len(label_dict["N1"])
    GT_N2 += len(label_dict["N2"])
    GT_N3 += len(label_dict["N3"])

    memory_pv = np.array([len(memory_pred_dict["N0"]), len(memory_pred_dict["N1"]), len(memory_pred_dict["N2"]), len(memory_pred_dict["N3"])]) / (len(memory_pred_dict["N0"]) + len(memory_pred_dict["N1"]) + len(memory_pred_dict["N2"]) + len(memory_pred_dict["N3"]))
    zs_pv = np.array([len(zs_pred_dict["N0"]), len(zs_pred_dict["N1"]), len(zs_pred_dict["N2"]), len(zs_pred_dict["N3"])]) / (len(zs_pred_dict["N0"]) + len(zs_pred_dict["N1"]) + len(zs_pred_dict["N2"]) + len(zs_pred_dict["N3"]))
    zscot_pv = np.array([len(zscot_pred_dict["N0"]), len(zscot_pred_dict["N1"]), len(zscot_pred_dict["N2"]), len(zscot_pred_dict["N3"])]) / (len(zscot_pred_dict["N0"]) + len(zscot_pred_dict["N1"]) + len(zscot_pred_dict["N2"]) + len(zscot_pred_dict["N3"]))
    gt_pv = np.array([len(label_dict["N0"]), len(label_dict["N1"]), len(label_dict["N2"]), len(label_dict["N3"])]) / (len(label_dict["N0"]) + len(label_dict["N1"]) + len(label_dict["N2"]) + len(label_dict["N3"]))
    # print(memory_pv, zscot_pv, gt_pv)
    # print("Jensen-Shannon distance between kepa and ground truth: " ,jensenshannon(memory_pv, gt_pv))
    # print("Jensen-Shannon distance between zs and ground truth: " ,jensenshannon(zs_pv, gt_pv))
    # print("Jensen-Shannon distance between zscot and ground truth: " ,jensenshannon(zscot_pv, gt_pv))
    df2 = pd.DataFrame(
        {
            "N0": [f'{len(three_common_dict["N0"])}', f'{len(memory_pred_dict["N0"])}', f'{len(zscot_pred_dict["N0"])}', f'{len(both_same_pred_but_wrong_answer_dict["N0"])}'],
            "N1": [f'{len(three_common_dict["N1"])}', f'{len(memory_pred_dict["N1"])}', f'{len(zscot_pred_dict["N1"])}', f'{len(both_same_pred_but_wrong_answer_dict["N1"])}'],
            "N2": [f'{len(three_common_dict["N2"])}', f'{len(memory_pred_dict["N2"])}', f'{len(zscot_pred_dict["N2"])}', f'{len(both_same_pred_but_wrong_answer_dict["N2"])}'],
            "N3": [f'{len(three_common_dict["N3"])}', f'{len(memory_pred_dict["N3"])}', f'{len(zscot_pred_dict["N3"])}', f'{len(both_same_pred_but_wrong_answer_dict["N3"])}']
        },  
        index=["3 common", "KEPA", "ZSCOT", "Both same pred but wrong answer"]         
    )
    # display(df2.transpose())
    # print()

# print("Total")
df3 = pd.DataFrame(
    {
        "N0": [f'{KEPA_N0/len(run_lst)}', f'{ZSCOT_N0/len(run_lst)}', f'{GT_N0/len(run_lst)}'],
        "N1": [f'{KEPA_N1/len(run_lst)}', f'{ZSCOT_N1/len(run_lst)}', f'{GT_N1/len(run_lst)}'],
        "N2": [f'{KEPA_N2/len(run_lst)}', f'{ZSCOT_N2/len(run_lst)}', f'{GT_N2/len(run_lst)}'],
        "N3": [f'{KEPA_N3/len(run_lst)}', f'{ZSCOT_N3/len(run_lst)}', f'{GT_N3/len(run_lst)}']
    },
    index=["KEPA", "ZSCOT", "Ground Truth"]
)
# display(df3.transpose())

In [None]:
len(memory_pred_dict['N3'] - three_common_dict["N3"] - both_same_pred_but_wrong_answer_dict["N3"])

In [None]:
len(neither_correct_dict['N3'])

In [None]:
len(ids_set)

# Retrieve reports that satisfy a specific condition

In [8]:
def check_consistency(ans, lst):
    if len(lst) != 10:
        print(lst)
        raise ValueError("Exactly 10 arguments are required.")

    lst = [arg.upper().replace("NO", "N0").replace("NL", "N1") for arg in lst]
    if f"N{ans}" not in lst:
        return False
    
    count_dict = {}
    for arg in lst:
        if "N0" in arg:
            arg = "N0"
        elif "N1" in arg:
            arg = "N1"
        elif "N2" in arg:
            arg = "N2"
        elif "N3" in arg:
            arg = "N3"
            
        if arg in count_dict:
            count_dict[arg] += 1
        else:
            count_dict[arg] = 1

        if count_dict[arg] >= 7:
            return False

    print(count_dict)
    return True

In [None]:
def fun():
    file_name = "TCGA-AO-A0J9.1E3F3136-6D86-4470-85AA-55B11C9E24CD"
    mem_reasoning = ""
    txt ="""...""" 
    weird_lst = {}
    for n in range(10):
        test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{n}_outof_10runs.csv")
        train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_memory_dataset{n}.csv")
        weird_lst[f"{n}_split"] = {}
        for i in range(len(test_df)):
            if file_name == test_df.iloc[i].patient_filename:
            # if txt.strip() == report.strip():
                # print(file_name)
                # print(zs_predict_prompt_n03.format(report=test_df.iloc[i].text))
                print(test_df.iloc[i].text)
                return
                # print(report)
                # print(idx)
            # if check_consistency(test_df.loc[i].n, test_df.loc[i][["cmem_n_10reports_ans_str", "cmem_n_20reports_ans_str", "cmem_n_30reports_ans_str", "cmem_n_40reports_ans_str", "cmem_n_50reports_ans_str", "cmem_n_60reports_ans_str", "cmem_n_70reports_ans_str", "cmem_n_80reports_ans_str", "cmem_n_90reports_ans_str", "cmem_n_100reports_ans_str"]].tolist()):
            #     print(test_df.loc[i].patient_filename)
            #     print(test_df.loc[i].n)
            #     weird_lst[f"{n}_split"][test_df.loc[i].patient_filename] = {"answer": f"N{test_df.loc[i].n}", "report": test_df.loc[i].text, "kepa(mem_reas_pred)": [(mem, reas, pred) for mem, reas, pred in zip(train_df.cmem_n_memory_str.tolist()[9::10],test_df.loc[i][["cmem_n_10reasoning", "cmem_n_20reasoning", "cmem_n_30reasoning", "cmem_n_40reasoning", "cmem_n_50reasoning", "cmem_n_60reasoning", "cmem_n_70reasoning", "cmem_n_80reasoning", "cmem_n_90reasoning", "cmem_n_100reasoning"]].tolist(), test_df.loc[i][["cmem_n_10reports_ans_str", "cmem_n_20reports_ans_str", "cmem_n_30reports_ans_str", "cmem_n_40reports_ans_str", "cmem_n_50reports_ans_str", "cmem_n_60reports_ans_str", "cmem_n_70reports_ans_str", "cmem_n_80reports_ans_str", "cmem_n_90reports_ans_str", "cmem_n_100reports_ans_str"]].tolist())]}
                
fun()

In [None]:
weird_lst

In [None]:
# error_dict = {}
# id_lst = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/weird_n03.csv")['Unnamed: 0'].tolist()
# answer_lst = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/weird_n03.csv")['answer'].tolist()
# report_lst = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/weird_n03.csv")['report'].tolist()

# for patient, answer, report in zip(id_lst, answer_lst, report_lst):
#     error_dict[patient] = {}
#     error_dict[patient]["answer"] = f"N{answer}"
#     error_dict[patient]["report"] = report
#     error_dict[patient]["kepa"] = {"N0": [], "N1": [], "N2": [], "N3": []}

# for patient in id_lst:
#     for i in range(10):
#         obj = weird_lst[f'{i}split'].get(patient)
#         if obj is not None:
#             error_dict[patient]["answer"] = obj["answer"]
#             error_dict[patient]["report"] = obj["report"]
#             for n in range(10):
#                 pred = obj["kepa(mem_reas_pred)"][n][2].upper().replace("NO", "N0").replace("NL", "N1")
#                 if "N0" in pred:
#                     error_dict[patient]["kepa"]["N0"].append({"memory": obj["kepa(mem_reas_pred)"][n][0], "reasoning": obj["kepa(mem_reas_pred)"][n][1]})
#                 elif "N1" in pred:
#                     error_dict[patient]["kepa"]["N1"].append({"memory": obj["kepa(mem_reas_pred)"][n][0], "reasoning": obj["kepa(mem_reas_pred)"][n][1]})
#                 elif "N2" in pred:
#                     error_dict[patient]["kepa"]["N2"].append({"memory": obj["kepa(mem_reas_pred)"][n][0], "reasoning": obj["kepa(mem_reas_pred)"][n][1]})
#                 elif "N3" in pred:
#                     error_dict[patient]["kepa"]["N3"].append({"memory": obj["kepa(mem_reas_pred)"][n][0], "reasoning": obj["kepa(mem_reas_pred)"][n][1]})

In [None]:
# with open('weird_n03.json', 'w') as json_file:
#     json.dump(error_dict, json_file, indent=4)

# Quantative Analysis

In [None]:
# zs cot
t_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zscot_test_800.csv")
print(t14_calculate_metrics(t_df['t'], t_df['zs_t_ans_str']))

n_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zscot_test_800.csv")
print(n03_calculate_metrics(n_df['n'], n_df['zs_n_ans_str']))

# zs
t_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zs_test_800.csv")
print(t14_calculate_metrics(t_df['t'], t_df['zs_t_ans_str']))

n_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zs_test_800.csv")
print(n03_calculate_metrics(n_df['n'], n_df['zs_n_ans_str']))

# ensReas
t_df = pd.read_csv("/secure/shared_data/rag_tnm_results/t14_results/mixtral_ensReas_step1/brca_t14_merged_df_800.csv")
print(t14_calculate_metrics(t_df['t'], t_df['sc_ans']))

n_df = pd.read_csv("/secure/shared_data/rag_tnm_results/n03_results/mixtral_ensReas_step1/brca_n03_merged_df.csv")
print(n03_calculate_metrics(n_df['n'], n_df['sc_ans']))

In [42]:
def calculate_mean_std(results, cat):
    precision_list = [result[cat]['precision'] for result in results]
    recall_list = [result[cat]['recall'] for result in results]
    f1_list = [result[cat]['f1'] for result in results]
    support_list = [result[cat]['support'] for result in results]
    num_errors_list = [result[cat]['num_errors'] for result in results]
    
    mean_precision = sum(precision_list) / len(precision_list)
    mean_recall = sum(recall_list) / len(recall_list)
    mean_f1 = sum(f1_list) / len(f1_list)
    
    std_precision = (sum([(x - mean_precision)**2 for x in precision_list]) / len(precision_list))**0.5
    std_recall = (sum([(x - mean_recall)**2 for x in recall_list]) / len(recall_list))**0.5
    std_f1 = (sum([(x - mean_f1)**2 for x in f1_list]) / len(f1_list))**0.5
    
    return {
        'mean_precision': round(mean_precision, 3),
        'mean_recall': round(mean_recall, 3),
        'mean_f1': round(mean_f1, 3),
        'std_precision': round(std_precision, 3),
        'std_recall': round(std_recall, 3),
        'std_f1': round(std_f1, 3),
        'sum_support': sum(support_list),
        'sum_num_errors': sum(num_errors_list),
        'raw_mean_precision': mean_precision,
        'raw_mean_recall': mean_recall,
        'raw_mean_f1': mean_f1,
    }

In [43]:
def calculate_mean_std(results, cat):
    precision_list = [result[cat]['precision'] for result in results]
    recall_list = [result[cat]['recall'] for result in results]
    f1_list = [result[cat]['f1'] for result in results]
    support_list = [result[cat]['support'] for result in results]
    num_errors_list = [result[cat]['num_errors'] for result in results]
    
    mean_precision = sum(precision_list) / len(precision_list)
    mean_recall = sum(recall_list) / len(recall_list)
    mean_f1 = sum(f1_list) / len(f1_list)
    
    std_precision = (sum([(x - mean_precision)**2 for x in precision_list]) / len(precision_list))**0.5
    std_recall = (sum([(x - mean_recall)**2 for x in recall_list]) / len(recall_list))**0.5
    std_f1 = (sum([(x - mean_f1)**2 for x in f1_list]) / len(f1_list))**0.5
    
    return {
        'mean_precision': round(mean_precision, 3),
        'mean_recall': round(mean_recall, 3),
        'mean_f1': round(mean_f1, 3),
        'std_precision': round(std_precision, 3),
        'std_recall': round(std_recall, 3),
        'std_f1': round(std_f1, 3),
        'sum_support': sum(support_list),
        'sum_num_errors': sum(num_errors_list),
        'raw_mean_precision': mean_precision,
        'raw_mean_recall': mean_recall,
        'raw_mean_f1': mean_f1,
    }

def output_tabular_performance(results, categories = ['T1', 'T2', 'T3', 'T4']):
    precisions =[]
    recalls = []
    f1s = []

    for category in categories:
        eval = calculate_mean_std(results, category)
        print("{} {:.3f}({:.3f}) {:.3f}({:.3f}) {:.3f}({:.3f})".format(category, eval["mean_precision"], eval["std_precision"], eval["mean_recall"], eval["std_recall"], eval["mean_f1"], eval["std_f1"]))
        
        # for calculating macro average
        precisions.append(eval['raw_mean_precision'])
        recalls.append(eval['raw_mean_recall'])
        f1s.append(eval['raw_mean_f1'])

    print("MacroAvg. {:.3f} {:.3f} {:.3f}".format(round(sum(precisions)/len(precisions), 3), round(sum(recalls)/len(recalls), 3), round(sum(f1s)/len(f1s), 3)))

### T14

In [None]:
# t14
zs_t = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zs_test_800.csv")
zscot_t = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zscot_test_800.csv")
ensReas_t = pd.read_csv("/secure/shared_data/rag_tnm_results/t14_results/mixtral_ensReas_step1/brca_t14_merged_df_800.csv")

zs_t_results = []
zscot_t_results = []
ensReas_t_results = []
kepa_t_results = []

for run in range(1):
    print(run)
    split_ids = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_test_{run}.csv").patient_filename
    zs_t_split = zs_t[zs_t.patient_filename.isin(split_ids)]
    zs_t_results.append(t14_calculate_metrics(zs_t_split['t'], zs_t_split['zs_t_ans_str']))

    zscot_t_split = zscot_t[zscot_t.patient_filename.isin(split_ids)]
    zscot_t_results.append(t14_calculate_metrics(zscot_t_split['t'], zscot_t_split['zs_t_ans_str']))

    ensReas_t_split = ensReas_t[ensReas_t.patient_filename.isin(split_ids)]
    ensReas_t_results.append(t14_calculate_metrics(ensReas_t_split['t'], ensReas_t_split['sc_ans']))

    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv")
    kepa_t_results.append(t14_calculate_metrics(df['t'], df['cmem_t_40reports_ans_str']))

In [None]:
# zscot
output_tabular_performance(zscot_t_results)

In [None]:
# kepa
output_tabular_performance(kepa_t_results)

In [None]:
results = zscot_t_results

categories = ['T1', 'T2', 'T3', 'T4']
metrics = {category: calculate_mean_std(results, category) for category in categories}
metrics

In [None]:
precisions =[]
recalls = []
f1s = []
for key, value in metrics.items():
    precisions.append(value['raw_mean_precision'])
    recalls.append(value['raw_mean_recall'])
    f1s.append(value['raw_mean_f1'])
    
# print(round(sum(precisions)/len(precisions), 3), round(sum(recalls)/len(recalls), 3), round(sum(f1s)/len(f1s), 3))
# print in dictionary
print({'macro_average_precision': round(sum(precisions)/len(precisions), 3), 'macro_average_recall': round(sum(recalls)/len(recalls), 3), 'macro_average_f1': round(sum(f1s)/len(f1s), 3)})

### N03

In [None]:
# n03
zs_n = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zs_test_800.csv")
zscot_n = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zscot_test_800.csv")
ensReas_n = pd.read_csv("/secure/shared_data/rag_tnm_results/n03_results/mixtral_ensReas_step1/brca_n03_merged_df.csv")

zs_n_results = []
zscot_n_results = []
ensReas_n_results = []
kepa_n_results = []

for run in range(1, 2):
    print(run)
    split_ids = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_test_{run}.csv").patient_filename
    zs_n_split = zs_n[zs_n.patient_filename.isin(split_ids)]
    zs_n_results.append(n03_calculate_metrics(zs_n_split['n'], zs_n_split['zs_n_ans_str']))

    zscot_n_split = zscot_n[zscot_n.patient_filename.isin(split_ids)]
    zscot_n_results.append(n03_calculate_metrics(zscot_n_split['n'], zscot_n_split['zs_n_ans_str']))

    ensReas_n_split = ensReas_n[ensReas_n.patient_filename.isin(split_ids)]
    ensReas_n_results.append(n03_calculate_metrics(ensReas_n_split['n'], ensReas_n_split['sc_ans']))

    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv")
    kepa_n_results.append(n03_calculate_metrics(df['n'], df['cmem_n_40reports_ans_str']))

In [None]:
# zscot
output_tabular_performance(zscot_n_results, ['N0', 'N1', 'N2', 'N3'])

In [None]:
# kepa
output_tabular_performance(kepa_n_results, ['N0', 'N1', 'N2', 'N3'])

In [None]:
results = kepa_n_results
categories = ['N0', 'N1', 'N2', 'N3']
metrics = {category: calculate_mean_std(results, category) for category in categories}
metrics

In [None]:
precisions =[]
recalls = []
f1s = []
for key, value in metrics.items():
    precisions.append(value['raw_mean_precision'])
    recalls.append(value['raw_mean_recall'])
    f1s.append(value['raw_mean_f1'])

# round(sum(precisions)/len(precisions), 3), round(sum(recalls)/len(recalls), 3), round(sum(f1s)/len(f1s), 3)
# print in dictionary
print({'macro_average_precision': round(sum(precisions)/len(precisions), 3), 'macro_average_recall': round(sum(recalls)/len(recalls), 3), 'macro_average_f1': round(sum(f1s)/len(f1s), 3)})

# Plot scores for 10 splits, given 10 memories

### T14

In [None]:
# individual graph

zs_t = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zs_test_800.csv")
zscot_t = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zscot_test_800.csv")
df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0914_rag_test.csv")

for run in range(10):
    # t14 training data to extract memory
    t_train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset{run}.csv")
    memory_tup = []
    for idx, row in t_train_df.iterrows():
        # if row["cmem_t_is_updated"] == True:
        memory_tup.append((idx+1,row['cmem_t_memory_str']))
    memory_tup = memory_tup[9::10]
    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv")
    split_ids = df.patient_filename
    zs_t_split = zs_t[zs_t.patient_filename.isin(split_ids)]
    zscot_t_split = zscot_t[zscot_t.patient_filename.isin(split_ids)]
    df_split = df[df.patient_filename.isin(split_ids)]

    for i, _ in memory_tup:
        if len(df[df[f"cmem_t_{i}reports_is_parsed"]==False]) > 0:
            print(f"parsing error at memory {i}")


   # gather y-axis data
    precision_lst = []
    recall_lst = []
    f1_lst = []

    zs_precision = t14_calculate_metrics(zs_t_split['t'], zs_t_split['zs_t_ans_str'])['overall']['macro_precision']
    zs_recall = t14_calculate_metrics(zs_t_split['t'], zs_t_split['zs_t_ans_str'])['overall']['macro_recall']
    zs_f1 = t14_calculate_metrics(zs_t_split['t'], zs_t_split['zs_t_ans_str'])['overall']['macro_f1']

    zscot_precision = t14_calculate_metrics(zscot_t_split['t'], zscot_t_split['zs_t_ans_str'])['overall']['macro_precision']
    zscot_recall = t14_calculate_metrics(zscot_t_split['t'], zscot_t_split['zs_t_ans_str'])['overall']['macro_recall']
    zscot_f1 = t14_calculate_metrics(zscot_t_split['t'], zscot_t_split['zs_t_ans_str'])['overall']['macro_f1']

    rag_precision = t14_calculate_metrics(df_split['t'], df_split['cmem_t_ans_str'])['overall']['macro_precision']
    rag_recall = t14_calculate_metrics(df_split['t'], df_split['cmem_t_ans_str'])['overall']['macro_recall']
    rag_f1 = t14_calculate_metrics(df_split['t'], df_split['cmem_t_ans_str'])['overall']['macro_f1']

    x_idx = []
    for i, _ in memory_tup:
        x_idx.append(i)
        result = t14_calculate_metrics(df['t'], df[f'cmem_t_{i}reports_ans_str'])['overall']
        precision_lst.append(result['macro_precision'])
        recall_lst.append(result['macro_recall'])
        f1_lst.append(result['macro_f1'])


    plt.figure(figsize=(15, 10))

    plt.plot(x_idx, precision_lst, label='Memory Precision', color='blue', marker='o')
    plt.plot(x_idx, recall_lst, label='Memory Recall', color='green', marker='o')
    plt.plot(x_idx, f1_lst, label='Memory F1 Score', color='red', marker='o')

    # plt.axhline(y=zs_precision, color='blue', linestyle='--', label='ZS Precision')
    # plt.axhline(y=zs_recall, color='green', linestyle='--', label='ZS Recall')
    # plt.axhline(y=zs_f1, color='red', linestyle='--', label='ZS F1 Score')
    plt.axhline(y=rag_precision, color='blue', linestyle='--', label='RAG Precision')
    plt.axhline(y=rag_recall, color='green', linestyle='--', label='RAG Recall')
    plt.axhline(y=rag_f1, color='red', linestyle='--', label='RAG F1 Score')

    plt.axhline(y=zscot_precision, color='blue', linestyle='-.', label='ZSCOT Precision')
    plt.axhline(y=zscot_recall, color='green', linestyle='-.', label='ZSCOT Recall')
    plt.axhline(y=zscot_f1, color='red', linestyle='-.', label='ZSCOT F1 Score')
    
    plt.xlabel(f'# of Reports for Memory (t14_train_{run}.csv)')
    plt.ylabel('Scores')
    plt.title(f'Testing Results on 700 test Reports (t14_test_{run}.csv)')
    plt.legend()
    plt.grid(True)

    plt.show()

In [None]:
# Average (with new metric)

zs_t = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zs_test_800.csv")
zscot_t = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_t14_zscot_test_800.csv")
df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0914_rag_test.csv")


zs_t_results = []
zscot_t_results = []
rag_t_results = []

total_run = 10
for run in range(total_run):
    split_ids = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_test_{run}.csv").patient_filename
    zs_t_split = zs_t[zs_t.patient_filename.isin(split_ids)]
    zs_t_results.append(t14_calculate_metrics(zs_t_split['t'], zs_t_split['zs_t_ans_str'])['overall'])

    zscot_t_split = zscot_t[zscot_t.patient_filename.isin(split_ids)]
    zscot_t_results.append(t14_calculate_metrics(zscot_t_split['t'], zscot_t_split['zs_t_ans_str'])['overall'])
    
    df_split = df[df.patient_filename.isin(split_ids)]
    rag_t_results.append(t14_calculate_metrics(df_split['t'], df_split['cmem_t_ans_str'])['overall'])

zs_precision_avg = sum([rs['macro_precision'] for rs in zs_t_results])/len(zs_t_results)
zs_recall_avg = sum([rs['macro_recall'] for rs in zs_t_results])/len(zs_t_results)
zs_f1_avg = sum([rs['macro_f1'] for rs in zs_t_results])/len(zs_t_results)

zscot_precision_avg = sum([rs['macro_precision'] for rs in zscot_t_results])/len(zscot_t_results)
zscot_recall_avg = sum([rs['macro_recall'] for rs in zscot_t_results])/len(zscot_t_results)
zscot_f1_avg = sum([rs['macro_f1'] for rs in zscot_t_results])/len(zscot_t_results)

rag_precision_avg = sum([rs['macro_precision'] for rs in rag_t_results])/len(rag_t_results)
rag_recall_avg = sum([rs['macro_recall'] for rs in rag_t_results])/len(rag_t_results)
rag_f1_avg = sum([rs['macro_f1'] for rs in rag_t_results])/len(rag_t_results)

x_axis = np.array(range(1, 11)) * 10

memory_precision_cumulative = []
memory_recall_cumulative = []
memory_f1_cumulative = []


for run in range(total_run):
    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv")

    for i in np.array(range(1, 11)): # memory (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
        result = t14_calculate_metrics(df['t'], df[f'cmem_t_{i*10}reports_ans_str'])['overall']
        if run == 0:
            memory_precision_cumulative.append(result['macro_precision'])
            memory_recall_cumulative.append(result['macro_recall'])
            memory_f1_cumulative.append(result['macro_f1'])
        else:
            memory_precision_cumulative[i-1] += result['macro_precision']
            memory_recall_cumulative[i-1] += result['macro_recall']
            memory_f1_cumulative[i-1] += result['macro_f1']


# average
precision_avg = [p / total_run for p in memory_precision_cumulative]
recall_avg = [r / total_run for r in memory_recall_cumulative]
f1_avg = [f / total_run for f in memory_f1_cumulative]


plt.figure(figsize=(15, 10))

plt.plot(x_axis, precision_avg, label='Average KEPA Precision', color='blue', marker='o')
plt.plot(x_axis, recall_avg, label='Average KEPA Recall', color='green', marker='o')
plt.plot(x_axis, f1_avg, label='Average KEPA F1 Score', color='red', marker='o')


plt.axhline(y=zscot_precision_avg, color='blue', linestyle=':', label='ZSCOT Precision')
plt.axhline(y=zscot_recall_avg, color='green', linestyle=':', label='ZSCOT Recall')
plt.axhline(y=zscot_f1_avg, color='red', linestyle=':', label='ZSCOT F1 Score')

plt.axhline(y=rag_precision_avg, color='blue', linestyle='--', label='RAG Precision')
plt.axhline(y=rag_recall_avg, color='green', linestyle='--', label='RAG Recall')
plt.axhline(y=rag_f1_avg, color='red', linestyle='--', label='RAG F1 Score')

plt.text(x_axis[-1] + 2, zscot_precision_avg, f'{zscot_precision_avg:.3f}', fontsize=9, ha='left', va='center', color='blue')
plt.text(x_axis[-1] + 2, zscot_recall_avg, f'{zscot_recall_avg:.3f}', fontsize=9, ha='left', va='center', color='green')
plt.text(x_axis[-1] + 2, zscot_f1_avg, f'{zscot_f1_avg:.3f}', fontsize=9, ha='left', va='center', color='red')


plt.xlabel('Number of Training Reports')
plt.ylabel('Scores')
# plt.title(f'The Average of 10 Results on 700 Test Reports (t14)')
plt.legend()
plt.grid(True)

plt.show()


### N03

In [None]:
# individual graph

zs_n = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zs_test_800.csv")
zscot_n = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zscot_test_800.csv")
df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0914_rag_test.csv")

for run in range(10):
    # n03 training data to extract memory
    t_train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_memory_dataset{run}.csv")

    memory_tup = []
    for idx, row in t_train_df.iterrows():
        # if row["cmem_t_is_updated"] == True:
        memory_tup.append((idx+1,row['cmem_n_memory_str']))
    memory_tup = memory_tup[9::10]
    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv")
    split_ids = df.patient_filename
    zs_n_split = zs_n[zs_n.patient_filename.isin(split_ids)]
    zscot_n_split = zscot_n[zscot_n.patient_filename.isin(split_ids)]
    df_split = df[df.patient_filename.isin(split_ids)]


    for i, _ in memory_tup:
        if len(df[df[f"cmem_n_{i}reports_is_parsed"]==False]) > 0:
            print(f"parsing error at memory {i}")


   # gather y-axis data
    precision_lst = []
    recall_lst = []
    f1_lst = []

    zs_precision = n03_calculate_metrics(zs_n_split['n'], zs_n_split['zs_n_ans_str'])['overall']['macro_precision']
    zs_recall = n03_calculate_metrics(zs_n_split['n'], zs_n_split['zs_n_ans_str'])['overall']['macro_recall']
    zs_f1 = n03_calculate_metrics(zs_n_split['n'], zs_n_split['zs_n_ans_str'])['overall']['macro_f1']

    zscot_precision = n03_calculate_metrics(zscot_n_split['n'], zscot_n_split['zs_n_ans_str'])['overall']['macro_precision']
    zscot_recall = n03_calculate_metrics(zscot_n_split['n'], zscot_n_split['zs_n_ans_str'])['overall']['macro_recall']
    zscot_f1 = n03_calculate_metrics(zscot_n_split['n'], zscot_n_split['zs_n_ans_str'])['overall']['macro_f1']

    rag_precision = n03_calculate_metrics(df_split['n'], df_split['cmem_n_ans_str'])['overall']['macro_precision']
    rag_recall = n03_calculate_metrics(df_split['n'], df_split['cmem_n_ans_str'])['overall']['macro_recall']
    rag_f1 = n03_calculate_metrics(df_split['n'], df_split['cmem_n_ans_str'])['overall']['macro_f1']

    x_idx = []
    for i, _ in memory_tup:
        x_idx.append(i)
        result = n03_calculate_metrics(df['n'], df[f'cmem_n_{i}reports_ans_str'])['overall']
        precision_lst.append(result['macro_precision'])
        recall_lst.append(result['macro_recall'])
        f1_lst.append(result['macro_f1'])


    plt.figure(figsize=(15, 10))

    plt.plot(x_idx, precision_lst, label='Memory Precision', color='blue', marker='o')
    plt.plot(x_idx, recall_lst, label='Memory Recall', color='green', marker='o')
    plt.plot(x_idx, f1_lst, label='Memory F1 Score', color='red', marker='o')
    
    # plt.axhline(y=zs_precision, color='blue', linestyle='--', label='ZS Precision')
    # plt.axhline(y=zs_recall, color='green', linestyle='--', label='ZS Recall')
    # plt.axhline(y=zs_f1, color='red', linestyle='--', label='ZS F1 Score')

    plt.axhline(y=rag_precision, color='blue', linestyle='--', label='RAG Precision')
    plt.axhline(y=rag_recall, color='green', linestyle='--', label='RAG Recall')
    plt.axhline(y=rag_f1, color='red', linestyle='--', label='RAG F1 Score')

    plt.axhline(y=zscot_precision, color='blue', linestyle='-.', label='ZSCOT Precision')
    plt.axhline(y=zscot_recall, color='green', linestyle='-.', label='ZSCOT Recall')
    plt.axhline(y=zscot_f1, color='red', linestyle='-.', label='ZSCOT F1 Score')
    
    plt.xlabel(f'# of Reports for Memory (n03_train_{run}.csv)')
    plt.ylabel('Scores')
    plt.title(f'Testing Results on 700 test Reports (n03_test_{run}.csv)')
    plt.legend()
    plt.grid(True)

    plt.show()

In [None]:
# Average (with new metric)

zs_n = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zs_test_800.csv")
zscot_n = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0716_n03_zscot_test_800.csv")
df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/result/0914_rag_test.csv")

zs_n_results = []
zscot_n_results = []
rag_n_results = []

total_run = 10
for run in range(total_run):
    # if run == 8:
    #     continue
    split_ids = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_test_{run}.csv").patient_filename
    zs_n_split = zs_n[zs_n.patient_filename.isin(split_ids)]
    zs_n_results.append(n03_calculate_metrics(zs_n_split['n'], zs_n_split['zs_n_ans_str'])['overall'])

    zscot_n_split = zscot_n[zscot_n.patient_filename.isin(split_ids)]
    zscot_n_results.append(n03_calculate_metrics(zscot_n_split['n'], zscot_n_split['zs_n_ans_str'])['overall'])

    df_split = df[df.patient_filename.isin(split_ids)]
    rag_n_results.append(n03_calculate_metrics(df_split['n'], df_split['cmem_n_ans_str'])['overall'])


zs_precision_avg = sum([rs['macro_precision'] for rs in zs_n_results])/len(zs_n_results)
zs_recall_avg = sum([rs['macro_recall'] for rs in zs_n_results])/len(zs_n_results)
zs_f1_avg = sum([rs['macro_f1'] for rs in zs_n_results])/len(zs_n_results)

zscot_precision_avg = sum([rs['macro_precision'] for rs in zscot_n_results])/len(zscot_n_results)
zscot_recall_avg = sum([rs['macro_recall'] for rs in zscot_n_results])/len(zscot_n_results)
zscot_f1_avg = sum([rs['macro_f1'] for rs in zscot_n_results])/len(zscot_n_results)

rag_precision_avg = sum([rs['macro_precision'] for rs in rag_n_results])/len(rag_n_results)
rag_recall_avg = sum([rs['macro_recall'] for rs in rag_n_results])/len(rag_n_results)
rag_f1_avg = sum([rs['macro_f1'] for rs in rag_n_results])/len(rag_n_results)


x_axis = np.array(range(1, 11)) * 10

memory_precision_cumulative = []
memory_recall_cumulative = []
memory_f1_cumulative = []
devided_by = 0
for run in range(total_run):
    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv")

    for i in np.array(range(1, 11)): # memory (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
        result = n03_calculate_metrics(df['n'], df[f'cmem_n_{i*10}reports_ans_str'])['overall']
        if run == 0:
            memory_precision_cumulative.append(result['macro_precision'])
            memory_recall_cumulative.append(result['macro_recall'])
            memory_f1_cumulative.append(result['macro_f1'])
        else:
            memory_precision_cumulative[i-1] += result['macro_precision']
            memory_recall_cumulative[i-1] += result['macro_recall']
            memory_f1_cumulative[i-1] += result['macro_f1']
    devided_by += 1


# average
precision_avg = [p / devided_by for p in memory_precision_cumulative]
recall_avg = [r / devided_by for r in memory_recall_cumulative]
f1_avg = [f / devided_by for f in memory_f1_cumulative]


plt.figure(figsize=(15, 10))

plt.plot(x_axis, precision_avg, label='Average KEPA Precision', color='blue', marker='o')
plt.plot(x_axis, recall_avg, label='Average KEPA Recall', color='green', marker='o')
plt.plot(x_axis, f1_avg, label='Average KEPA F1 Score', color='red')


plt.axhline(y=zscot_precision_avg, color='blue', linestyle=':', label='ZSCOT Precision')
plt.axhline(y=zscot_recall_avg, color='green', linestyle=':', label='ZSCOT Recall')
plt.axhline(y=zscot_f1_avg, color='red', linestyle=':', label='ZSCOT F1 Score')

plt.axhline(y=rag_precision_avg, color='blue', linestyle='--', label='RAG Precision')
plt.axhline(y=rag_recall_avg, color='green', linestyle='--', label='RAG Recall')
plt.axhline(y=rag_f1_avg, color='red', linestyle='--', label='RAG F1 Score')

plt.text(x_axis[-1] + 2, zscot_precision_avg, f'{zscot_precision_avg:.3f}', fontsize=9, ha='left', va='center', color='blue')
plt.text(x_axis[-1] + 2, zscot_recall_avg, f'{zscot_recall_avg:.3f}', fontsize=9, ha='left', va='center', color='green')
plt.text(x_axis[-1] + 2, zscot_f1_avg, f'{zscot_f1_avg:.3f}', fontsize=9, ha='left', va='center', color='red')

plt.xlabel('Number of Training Reports')
plt.ylabel('Scores')
# plt.title(f'The Average of 10 Results on 700 Test Reports (n03)')
plt.legend()
plt.grid(True)

plt.show()

# Re-run for Error cases

In [None]:
client = OpenAI(api_key = "empty",
                base_url = "http://localhost:8000/v1")
    
class TestingResponse(BaseModel):
    predictedStage: str = Field(description="predicted cancer stage")
    reasoning: str = Field(description="reasoning to support predicted cancer stage") 

testing_schema = TestingResponse.model_json_schema()

def test_individual_report(dataset: pd.DataFrame, patient_filename: str, memory_tup: tuple, category = 'n'):
    num, memory = memory_tup
    report = dataset[dataset.patient_filename == patient_filename]["text"].values[0]

    if category.lower()[0] == 'n':
        prompt = testing_predict_prompt_n03.format(memory=memory, report=report)
    else:
        prompt = testing_predict_prompt_t14.format(memory=memory, report=report)
    
    prompt = system_instruction + "\n" + prompt
    messages = [{"role": "user", "content": prompt}]
    response = client.chat.completions.create(
        model = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages = messages,
        extra_body = {"guided_json": testing_schema},
        temperature = 0.1)
    # response = json.loads(response.choices[0].message.content.replace("\\", "\\\\"))
    response = json.loads(response.choices[0].message.content)

    dataset.loc[dataset["patient_filename"] == patient_filename, f"cmem_{category}_{num}reports_is_parsed"] = True
    dataset.loc[dataset["patient_filename"] == patient_filename, f"cmem_{category}_{num}reports_ans_str"] = response["predictedStage"]
    dataset.loc[dataset["patient_filename"] == patient_filename, f"cmem_{category}_{num}reports_reasoning"] = response["reasoning"]

    return dataset

In [None]:
# T14
for run in range(10):
    print(f"{run}th split")

    t_train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset{run}.csv")
    memory_tup = []
    for idx, row in t_train_df.iterrows():
        memory_tup.append((idx+1,row['cmem_t_memory_str']))
    memory_tup = memory_tup[9::10]

    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv")
    for num, memory in memory_tup:
        print(f"{num}th memory")
        for idx in range(len(df)):     
            patient_filename = df.loc[idx, "patient_filename"]
            if not isinstance(df.loc[df["patient_filename"] == patient_filename, f"cmem_t_{num}reports_ans_str"].values.item(), str):
                print(idx) 
                print("before: ", df.loc[df["patient_filename"] == patient_filename, f"cmem_t_{num}reports_ans_str"].values.item())
                test_individual_report(df, patient_filename, (num, memory), 't')
                print("after: ", df.loc[df["patient_filename"] == patient_filename, f"cmem_t_{num}reports_ans_str"].values.item())
                print("label: ", df.loc[df["patient_filename"] == patient_filename, "t"].values.item())
    
    df.to_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv", index=False)

In [None]:
# N03
for run in range(10):
    print(f"{run}th split")

    t_train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_memory_dataset{run}.csv")
    memory_tup = []
    for idx, row in t_train_df.iterrows():
        memory_tup.append((idx+1,row['cmem_n_memory_str']))
    memory_tup = memory_tup[9::10]

    df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv")
    for num, memory in memory_tup:
        print(f"{num}th memory")
        for idx in range(len(df)):     
            patient_filename = df.loc[idx, "patient_filename"]
            if not isinstance(df.loc[df["patient_filename"] == patient_filename, f"cmem_n_{num}reports_ans_str"].values.item(), str):
                print(idx) 
                print("before: ", df.loc[df["patient_filename"] == patient_filename, f"cmem_n_{num}reports_ans_str"].values.item())
                test_individual_report(df, patient_filename, (num, memory), 'n')
                print("after: ", df.loc[df["patient_filename"] == patient_filename, f"cmem_n_{num}reports_ans_str"].values.item())
                print("label: ", df.loc[df["patient_filename"] == patient_filename, "n"].values.item())
    
    df.to_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv", index=False)

# Check the Difference in Performance Based on the Order of Fields in the Schema

In [None]:
# run = 6

# test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/0718_t14_dynamic_test_{run}_outof_10runs.csv")
# for i in np.array(range(1, 11)): # memory (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
#     result = t14_calculate_metrics(test_df['t'], test_df[f'cmem_t_{i*10}reports_ans_str'])['overall']
#     print(result)


# test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/0718_t14_dynamic_test_{run}_outof_10runs_revised.csv")
# for i in np.array(range(1, 11)): # memory (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
#     if len(test_df[test_df[f"cmem_t_{i*10}reports_is_parsed"]==False]) > 0:
#         print(len(test_df[test_df[f"cmem_t_{i*10}reports_is_parsed"]==False]))
#     print(t14_calculate_metrics(test_df['t'], test_df[f'cmem_t_{i*10}reports_ans_str'])['overall'])

In [None]:
# n_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/0718_n03_dynamic_test_5_outof_10runs_reordered.csv")
# for i in np.array(range(1, 11)): # memory (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
#     result = n03_calculate_metrics(n_df['n'], n_df[f'cmem_n_{i*10}reports_ans_str'])['overall']
#     print(result)

# n_df = pd.read_csv("/home/yl3427/cylab/selfCorrectionAgent/0718_n03_dynamic_test_5_outof_10runs.csv")
# for i in np.array(range(1, 11)): # memory (10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
#     result = n03_calculate_metrics(n_df['n'], n_df[f'cmem_n_{i*10}reports_ans_str'])['overall']
#     print(result)

# Plot Memory length

In [None]:
# individual memory string length for T14

for i in range(10):
    train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset{i}.csv")
    x_indices = []
    y_str_length_mem = []
    y_str_length_rules = []
 
    for idx, row in train_df.iterrows():
        x_indices.append(idx+1)
        y_str_length_mem.append(len(row['cmem_t_memory_str']))
        y_str_length_rules.append(len(row['cmem_t_rules_str']))
    
    plt.figure(figsize=(15, 10))

    plt.plot(x_indices, y_str_length_mem, label='Memory String Length', color='blue', marker='o')
    plt.plot(x_indices, y_str_length_rules, label='Rules String Length', color='red', marker='o')

    plt.xlabel(f'Index of Memory Dataset (t14_memory_dataset{i}.csv)')
    plt.ylabel('Length')
    # plt.title(f'Length of Memory and Rules')
    plt.legend()
    plt.grid(True)

    plt.show()

In [None]:
# average memory string length for T14
y_str_length_mem_arr = np.array([0]*100)
y_str_length_rules_arr = np.array([0]*100)
x_indices = np.array(range(1, 101))

for i in range(10):
    train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset{i}.csv")

    for idx, row in train_df.iterrows():
        y_str_length_mem_arr[idx] += len(row['cmem_t_memory_str'])
        y_str_length_rules_arr[idx] += len(row['cmem_t_rules_str']) 

plt.figure(figsize=(15, 10))

plt.plot(x_indices, y_str_length_mem_arr/10, label='Threshold 80', color='blue', marker='o')
plt.plot(x_indices, y_str_length_rules_arr/10, label='Threshold 0', color='red', marker='o')

plt.xlabel(f'Number of Training Reports (T14)')
plt.ylabel('Average Length of Memory')
# plt.title(f'Length of Memory and Rules')
plt.legend()
plt.grid(True)

plt.show()


In [None]:
# individual number of rules in memory for T14
for i in range(10):
    train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset{i}.csv")
    x_indices = []

    y_num_rules_mem = []
    y_num_rules_rules = []
    for idx, row in train_df.iterrows():
        x_indices.append(idx+1)
        y_num_rules_mem.append(len(row['cmem_t_memory_str'].split("\n")))
        y_num_rules_rules.append(len(row['cmem_t_rules_str'].split("\n")))
    
    plt.figure(figsize=(15, 10))

    plt.plot(x_indices, y_num_rules_mem, label='Memory Num Rules', color='blue', marker='o')
    plt.plot(x_indices, y_num_rules_rules, label='Rules Num Rules', color='red', marker='o')

    plt.xlabel(f'Index of Memory Dataset (t14_memory_dataset{i}.csv)')
    plt.ylabel('Number')
    plt.title(f'Number of Rules')
    plt.legend()
    plt.grid(True)

    max_y = max(max(y_num_rules_mem), max(y_num_rules_rules))
    plt.yticks(range(0, int(max_y) + 1))

    plt.show()

In [None]:
# average number of rules in memory for T14
y_num_rules_mem_arr = np.array([0]*100)
y_num_rules_rules_arr = np.array([0]*100)
x_indices = np.array(range(1, 101))

for i in range(10):
    train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/t14_memory_dataset{i}.csv")

    for idx, row in train_df.iterrows():
        y_num_rules_mem_arr[idx] += len(row['cmem_t_memory_str'].split("\n"))
        y_num_rules_rules_arr[idx] += len(row['cmem_t_rules_str'].split("\n"))

plt.figure(figsize=(15, 10))

plt.plot(x_indices, y_num_rules_mem_arr/10, label='Threshold 80', color='blue', marker='o')
plt.plot(x_indices, y_num_rules_rules_arr/10, label='Threshold 0', color='red', marker='o')

plt.xlabel(f'Number of Training Reports')
plt.ylabel('Average Number of Rules')

plt.legend()
plt.grid(True)

plt.show()

### n

In [None]:
# individual memory string length for N03
for i in range(10):
    train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_memory_dataset{i}.csv")
    x_indices = []
    y_str_length_mem = []
    y_str_length_rules = []
 
    for idx, row in train_df.iterrows():
        x_indices.append(idx+1)
        y_str_length_mem.append(len(row['cmem_n_memory_str']))
        y_str_length_rules.append(len(row['cmem_n_rules_str']))
    
    plt.figure(figsize=(15, 10))

    plt.plot(x_indices, y_str_length_mem, label='Memory String Length', color='blue', marker='o')
    plt.plot(x_indices, y_str_length_rules, label='Rules String Length', color='red', marker='o')

    plt.xlabel(f'Index of Memory Dataset (n03_memory_dataset{i}.csv)')
    plt.ylabel('Length')
    # plt.title(f'Length of Memory and Rules')
    plt.legend()
    plt.grid(True)

    plt.show()

In [None]:
# average memory string length for N03
y_str_length_mem_arr = np.array([0]*100)
y_str_length_rules_arr = np.array([0]*100)
x_indices = np.array(range(1, 101))

for i in range(10):
    train_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/n03_memory_dataset{i}.csv")

    for idx, row in train_df.iterrows():
        y_str_length_mem_arr[idx] += len(row['cmem_n_memory_str'])
        y_str_length_rules_arr[idx] += len(row['cmem_n_rules_str']) 

plt.figure(figsize=(15, 10))

plt.plot(x_indices, y_str_length_mem_arr/10, label='Threshold 80', color='blue', marker='o')
plt.plot(x_indices, y_str_length_rules_arr/10, label='Threshold 0', color='red', marker='o')

plt.xlabel(f'Number of Training Reports (N03)')
plt.ylabel('Average Length of Memory')
# plt.title(f'Length of Memory and Rules')
plt.legend()
plt.grid(True)

plt.show()


# Create format instruction

In [None]:
# class TrainingResponse(BaseModel):
#     reasoning: str = Field(description="reasoning to support predicted cancer stage")
#     predictedStage: str = Field(description="predicted cancer stage")
#     rules: List[str] = Field(description="list of rules") 

# class TestingResponse(BaseModel):
#     reasoning: str = Field(description="reasoning to support predicted cancer stage") 
#     predictedStage: str = Field(description="predicted cancer stage")
     

In [None]:
# type(TrainingResponse.model_json_schema())

In [None]:
# parser = PydanticOutputParser(pydantic_object=TestingResponse)
# format_instruction=parser.get_format_instructions()
# print(type(format_instruction))
# print(format_instruction)

In [None]:
# res=TestingResponse.model_validate_json('{\n  "predictedStage": "T2",\n  "reasoning": "The largest dimension of the tumor is 3.7 cm, which falls within the range for T2 (greater than 2 cm but not greater than 5 cm)."\n}')

In [None]:
# res.reasoning

# vendiagram

In [None]:
from matplotlib_venn import venn3

# Define the subset sizes again, considering the total values for each set
subsets = {'100': 7, '010': 1, '001': 14, '111': 49, '110': 0, '101': 2}

# Re-create the Venn diagram with adjustments
plt.figure(figsize=(8, 8))
venn = venn3(subsets=subsets, set_labels=('kepa', 'zscot', 'label'))

# Adjust circle sizes if necessary manually (note: matplotlib_venn doesn't always allow this directly)
# So we can only guide by the correct inputs

plt.title("Adjusted Venn Diagram")
plt.show()
