In [19]:
import pandas as pd
from src.prompts import system_instruction
from huggingface_hub import InferenceClient
from tqdm import tqdm

from src.metrics import t14_performance_report, n03_performance_report, m01_performance_report
from sklearn.metrics import classification_report
import warnings
warnings.filterwarnings("ignore")

# Inference client for the text-generation server reachable at the local URL below
# (presumably the server hosting the Med42 model -- confirm before running).
client = InferenceClient(model="http://127.0.0.1:8080")

In [20]:
# Chat layout used for every request: a system turn, a prompter turn and an
# assistant turn that is left open for the model to complete.
MED42_PROMPT_TEMPLATE = (
    "<|system|>:{system_instruction}\n"
    "<|prompter|>:{prompt}\n"
    "<|assistant|>:"
)

# Shared wording of the stage-mention question. "%s" is replaced by the stage
# letter (T, N or M); "{report}" stays a str.format placeholder that is filled
# in later with the report text.
# An earlier draft of the T prompt (no longer used) asked instead:
#   "Is the pathologic T (pT) stage of the patient's cancer mentioned in the
#    report? Please answer only Yes or No."
_STAGE_MENTION_PROMPT = (
    "You are provided with a pathology report for a cancer patient. \n"
    "\n"
    "Here is the report:\n"
    "```\n"
    "{report}\n"
    "```\n"
    "\n"
    "Is the %s stage explicitly mentioned in the provided pathology report? "
    "Please answer only Yes or No.\n"
)

T_stage_mention_detecting_prompt = _STAGE_MENTION_PROMPT % "T"
N_stage_mention_detecting_prompt = _STAGE_MENTION_PROMPT % "N"
M_stage_mention_detecting_prompt = _STAGE_MENTION_PROMPT % "M"

# BRCA T Category

In [21]:
# Ask the model, one BRCA T-stage test report at a time, whether the T stage is
# explicitly mentioned; the raw replies are collected in ans_list.
t14_brca_testing = pd.read_csv("/secure/shared_data/tcga_path_reports/t14_data/BRCA_T14_testing.csv")

ans_list = []
for report_text in tqdm(t14_brca_testing["text"], total=t14_brca_testing.shape[0]):
    filled_text_prompt = T_stage_mention_detecting_prompt.format(report=report_text)
    lm_formated_prompt = MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction,
        prompt=filled_text_prompt,
    )
    # No sampling and a 12-token budget: only a short Yes/No reply is requested.
    decoded_ans = client.text_generation(prompt=lm_formated_prompt, do_sample=False, max_new_tokens=12)
    ans_list.append(decoded_ans)

  0%|          | 0/146 [00:00<?, ?it/s]

100%|██████████| 146/146 [03:05<00:00,  1.27s/it]


In [22]:
t14_brca_testing["stage_mention"] = ans_list
# Number of reports answered "No". "no" must occur as a whole word: a plain
# case-insensitive substring search would also count replies that merely contain
# words such as "node", "unknown" or "diagnosis".
int(t14_brca_testing["stage_mention"].str.contains(r"\bno\b", case=False).sum())

37

In [23]:
t14_ZSCOT_path = "/secure/shared_data/tnm/t14_res/med42-t0.7-tp0.95-nrs1.csv"
# Load the ZS-COT results and attach the stage-mention reply of each BRCA test
# report; the merge keeps only rows whose patient_filename occurs in both frames.
t14_ZSCOT_df = pd.read_csv(t14_ZSCOT_path)
t14_ZSCOT_brca = t14_ZSCOT_df.merge(t14_brca_testing[["patient_filename", "stage_mention"]], on="patient_filename")
print("BRCA:", t14_ZSCOT_brca.shape)

# A reply counts as "No" only when "no" appears as a whole word; a bare
# substring test would also match words such as "node" or "unknown".
mention_flag = t14_ZSCOT_brca["stage_mention"].str.contains(r"\bno\b", case=False)
# .copy() gives independent frames, so columns assigned to the subsets later
# are not written onto a slice of t14_ZSCOT_brca.
t14_ZSCOT_brca_hasMention = t14_ZSCOT_brca[~mention_flag].copy()
t14_ZSCOT_brca_noMention = t14_ZSCOT_brca[mention_flag].copy()
print(t14_ZSCOT_brca_hasMention.shape, t14_ZSCOT_brca_noMention.shape)

BRCA: (146, 8)
(109, 8) (37, 8)


In [24]:
# T-stage performance on the BRCA reports whose stage-mention reply was not flagged as "No".
_ = t14_performance_report(t14_ZSCOT_brca_hasMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          T1       0.87      0.83      0.85        24
          T2       0.91      0.92      0.91        63
          T3       0.69      0.75      0.72        12
          T4       1.00      0.50      0.67         2

    accuracy                           0.87       101
   macro avg       0.87      0.75      0.79       101
weighted avg       0.87      0.87      0.87       101

Precision: 0.8064409961214533 (CI: 0.5695539941829004 0.9305555555555556 )
Recall: 0.765293477705456 (CI: 0.58003484249916 0.9275659602283293 )
F1: 0.773263615784203 (CI: 0.5712708914415336 0.9155401640890521 )


In [25]:
# T-stage performance on the BRCA reports whose stage-mention reply was flagged as "No".
_ = t14_performance_report(t14_ZSCOT_brca_noMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          T1       0.67      0.57      0.62         7
          T2       0.83      0.83      0.83        24
          T3       0.00      0.00      0.00         1
          T4       0.50      0.50      0.50         2

    accuracy                           0.74        34
   macro avg       0.50      0.48      0.49        34
weighted avg       0.75      0.74      0.74        34

Precision: 0.48305864157892064 (CI: 0.3079967948717949 0.6955946969696968 )
Recall: 0.46481297095845037 (CI: 0.2636838161838162 0.7084953703703704 )
F1: 0.45943471844685174 (CI: 0.281437106918239 0.6542196998480242 )


In [31]:
# Save the BRCA "no T-stage mention" subset. index=False is not passed, so the
# DataFrame index is written as an extra leading column.
t14_ZSCOT_brca_noMention.to_csv("/secure/shared_data/rag_tnm_results/t14_results/brca_t14_noMention.csv")

# BRCA N Category

In [26]:
# Ask the model, one BRCA N-stage test report at a time, whether the N stage is
# explicitly mentioned; the raw replies are collected in ans_list.
n03_brca_testing = pd.read_csv("/secure/shared_data/tcga_path_reports/n03_data/BRCA_N03_testing.csv")

ans_list = []
for report_text in tqdm(n03_brca_testing["text"], total=n03_brca_testing.shape[0]):
    filled_text_prompt = N_stage_mention_detecting_prompt.format(report=report_text)
    lm_formated_prompt = MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction,
        prompt=filled_text_prompt,
    )
    # No sampling and a 12-token budget: only a short Yes/No reply is requested.
    decoded_ans = client.text_generation(prompt=lm_formated_prompt, do_sample=False, max_new_tokens=12)
    ans_list.append(decoded_ans)

100%|██████████| 131/131 [02:31<00:00,  1.16s/it]


In [27]:
n03_brca_testing["stage_mention"] = ans_list
# Number of reports answered "No". "no" must occur as a whole word: a plain
# case-insensitive substring search would also count replies that merely contain
# words such as "node", "unknown" or "diagnosis".
int(n03_brca_testing["stage_mention"].str.contains(r"\bno\b", case=False).sum())

49

In [28]:
n03_ZSCOT_path = "/secure/shared_data/tnm/n03_res/med42-t0.7-tp0.95-nrs1.csv"
# Load the ZS-COT results and attach the stage-mention reply of each BRCA test
# report; the merge keeps only rows whose patient_filename occurs in both frames.
n03_ZSCOT_df = pd.read_csv(n03_ZSCOT_path)
n03_ZSCOT_brca = n03_ZSCOT_df.merge(n03_brca_testing[["patient_filename", "stage_mention"]], on="patient_filename")
print("BRCA:", n03_ZSCOT_brca.shape)

# A reply counts as "No" only when "no" appears as a whole word; a bare
# substring test would also match words such as "node" or "unknown".
mention_flag = n03_ZSCOT_brca["stage_mention"].str.contains(r"\bno\b", case=False)
# .copy() gives independent frames, so columns assigned to the subsets later
# (the report function below adds two) are not written onto a slice.
n03_ZSCOT_brca_hasMention = n03_ZSCOT_brca[~mention_flag].copy()
n03_ZSCOT_brca_noMention = n03_ZSCOT_brca[mention_flag].copy()
print(n03_ZSCOT_brca_hasMention.shape, n03_ZSCOT_brca_noMention.shape)

BRCA: (131, 8)
(82, 8) (49, 8)


In [29]:
# N-stage performance on the BRCA reports whose stage-mention reply was not flagged as "No".
_ = n03_performance_report(n03_ZSCOT_brca_hasMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          N0       1.00      0.94      0.97        16
          N1       0.92      0.90      0.91        39
          N2       0.62      0.80      0.70        10
          N3       0.78      0.70      0.74        10

    accuracy                           0.87        75
   macro avg       0.83      0.83      0.83        75
weighted avg       0.88      0.87      0.87        75

Precision: 0.8264161501966703 (CI: 0.730248006566604 0.9132142857142856 )
Recall: 0.8310111636678434 (CI: 0.7247888661635478 0.9202989483634645 )
F1: 0.8189161115757877 (CI: 0.7201918956439536 0.9057276600819024 )


In [30]:
def modifiled_n03_performance_report(df, ans_col="ans_str"):
    """Print a per-class classification report for N-stage predictions.

    Each answer in ``ans_col`` is mapped to a class code (N0 -> 0, N1 -> 1,
    N2 -> 2, N3 -> 3; the first tag found in that order wins, matched
    case-insensitively). Rows without any of these tags are excluded from the
    report. The class names shown in the report are derived from the codes that
    actually occur among the labels and predictions, so the report stays
    correctly labeled on subsets where some N categories are absent.

    Args:
        df: DataFrame with the free-text model answers in ``ans_col`` and the
            ground-truth class codes in column ``"n"`` (assumed to use the same
            0-3 integer coding as the predictions -- TODO confirm against the data).
        ans_col: Name of the column holding the answer text.

    Returns:
        None; the report is printed.

    Side effects:
        Adds (or overwrites) the columns ``Has_Valid_Prediction`` and
        ``coded_pred`` on ``df`` in place.
    """
    stage_tags = ("N0", "N1", "N2", "N3")

    def _to_code(answer):
        # str() guards against NaN / non-string cells; upper() keeps the coding
        # consistent with the case-insensitive notion of a valid prediction.
        text = str(answer).upper()
        for code, tag in enumerate(stage_tags):
            if tag in text:
                return code
        # invalid answer: no N category found
        return -1

    coded_pred_list = [_to_code(answer) for answer in df[ans_col]]
    # A prediction is valid exactly when it could be coded, so -1 can never
    # reach the report.
    df['Has_Valid_Prediction'] = [code != -1 for code in coded_pred_list]
    df['coded_pred'] = coded_pred_list

    valid_rows = df[df['Has_Valid_Prediction']]
    coded_pred = valid_rows['coded_pred'].to_list()
    n_labels = valid_rows["n"].to_list()

    if not coded_pred:
        print("No valid N-stage prediction found; nothing to report.")
        return

    # Name exactly the classes that are present instead of a fixed list, and
    # pass them explicitly so names and codes cannot get out of step.
    labels = sorted(set(n_labels) | set(coded_pred))
    target_names = [f"N{int(label)}" for label in labels]
    print(classification_report(n_labels, coded_pred, labels=labels, target_names=target_names))

# N-stage report on the BRCA reports whose stage-mention reply was flagged as "No".
_ = modifiled_n03_performance_report(n03_ZSCOT_brca_noMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          N0       0.91      1.00      0.95        39
          N1       1.00      0.25      0.40         4
          N3       0.00      0.00      0.00         1

    accuracy                           0.91        44
   macro avg       0.64      0.42      0.45        44
weighted avg       0.89      0.91      0.88        44



In [32]:
# Save the BRCA "no N-stage mention" subset. index=False is not passed, so the
# DataFrame index is written as an extra leading column. If
# modifiled_n03_performance_report was run on this frame beforehand, the
# Has_Valid_Prediction and coded_pred columns it adds are saved as well.
n03_ZSCOT_brca_noMention.to_csv("/secure/shared_data/rag_tnm_results/n03_results/brca_n03_noMention.csv")

# BRCA M Category

In [15]:
# Ask the model, one BRCA M-stage test report at a time, whether the M stage is
# explicitly mentioned; the raw replies are collected in ans_list.
m01_brca_testing = pd.read_csv("/secure/shared_data/tcga_path_reports/m01_data/BRCA_M01_testing.csv")

ans_list = []
for report_text in tqdm(m01_brca_testing["text"], total=m01_brca_testing.shape[0]):
    filled_text_prompt = M_stage_mention_detecting_prompt.format(report=report_text)
    lm_formated_prompt = MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction,
        prompt=filled_text_prompt,
    )
    # No sampling and a 12-token budget: only a short Yes/No reply is requested.
    decoded_ans = client.text_generation(prompt=lm_formated_prompt, do_sample=False, max_new_tokens=12)
    ans_list.append(decoded_ans)

  0%|          | 0/138 [00:00<?, ?it/s]

100%|██████████| 138/138 [02:40<00:00,  1.16s/it]


In [16]:
m01_brca_testing["stage_mention"] = ans_list
# Number of reports answered "No". "no" must occur as a whole word: a plain
# case-insensitive substring search would also count replies that merely contain
# words such as "node", "unknown" or "diagnosis".
int(m01_brca_testing["stage_mention"].str.contains(r"\bno\b", case=False).sum())

105

In [17]:
# Load the ZS-COT results (stored as two batch files) and stack them.
tmp1 = pd.read_csv("/secure/shared_data/tnm/m01_res/med42-t0.7-tp0.95-nrs5_batch1.csv")
tmp2 = pd.read_csv("/secure/shared_data/tnm/m01_res/med42-t0.7-tp0.95-nrs5_batch2.csv")
m01_ZSCOT_df = pd.concat([tmp1, tmp2])
# Attach the stage-mention reply of each BRCA test report; the merge keeps only
# rows whose patient_filename occurs in both frames.
m01_ZSCOT_brca = m01_ZSCOT_df.merge(m01_brca_testing[["patient_filename", "stage_mention"]], on="patient_filename")
print("BRCA:", m01_ZSCOT_brca.shape)

# A reply counts as "No" only when "no" appears as a whole word; a bare
# substring test would also match words such as "node" or "unknown".
mention_flag = m01_ZSCOT_brca["stage_mention"].str.contains(r"\bno\b", case=False)
# .copy() gives independent frames, so columns assigned to the subsets later
# are not written onto a slice of m01_ZSCOT_brca.
m01_ZSCOT_brca_hasMention = m01_ZSCOT_brca[~mention_flag].copy()
m01_ZSCOT_brca_noMention = m01_ZSCOT_brca[mention_flag].copy()
print(m01_ZSCOT_brca_hasMention.shape, m01_ZSCOT_brca_noMention.shape)

BRCA: (138, 18)
(33, 18) (105, 18)


In [18]:
# M-stage performance on the BRCA reports whose stage-mention reply was not flagged as "No".
_ = m01_performance_report(m01_ZSCOT_brca_hasMention, ans_col="ans_str_0")

tn=14, fp=16, fn=1, tp=1
              precision    recall  f1-score   support

          M0       0.93      0.47      0.62        30
          M1       0.06      0.50      0.11         2

    accuracy                           0.47        32
   macro avg       0.50      0.48      0.36        32
weighted avg       0.88      0.47      0.59        32

Precision: 0.49835754449251957 (CI: 0.4075994318181818 0.585906862745098 )
Recall: 0.4524569624204844 (CI: 0.171875 0.8064516129032258 )
F1: 0.3628455002637719 (CI: 0.23809523809523808 0.5174603174603173 )


In [7]:
# M-stage performance on the BRCA reports whose stage-mention reply was flagged as "No".
_ = m01_performance_report(m01_ZSCOT_brca_noMention, ans_col="ans_str_0")

tn=68, fp=32, fn=1, tp=2
              precision    recall  f1-score   support

          M0       0.99      0.68      0.80       100
          M1       0.06      0.67      0.11         3

    accuracy                           0.68       103
   macro avg       0.52      0.67      0.46       103
weighted avg       0.96      0.68      0.78       103

Precision: 0.5225604985685791 (CI: 0.4855072463768116 0.5694444444444444 )
Recall: 0.6656147667850195 (CI: 0.321050418148611 0.8706757425742574 )
F1: 0.45516168658343753 (CI: 0.38690476190476186 0.5433136912205315 )


# LUAD T Category

In [33]:
# Ask the model, one LUAD T-stage test report at a time, whether the T stage is
# explicitly mentioned; the raw replies are collected in ans_list.
t14_luad_testing = pd.read_csv("/secure/shared_data/tcga_path_reports/t14_data/LUAD_T14_testing.csv")

ans_list = []
for report_text in tqdm(t14_luad_testing["text"], total=t14_luad_testing.shape[0]):
    filled_text_prompt = T_stage_mention_detecting_prompt.format(report=report_text)
    lm_formated_prompt = MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction,
        prompt=filled_text_prompt,
    )
    # No sampling and a 12-token budget: only a short Yes/No reply is requested.
    decoded_ans = client.text_generation(prompt=lm_formated_prompt, do_sample=False, max_new_tokens=12)
    ans_list.append(decoded_ans)

  0%|          | 0/77 [00:00<?, ?it/s]

100%|██████████| 77/77 [02:01<00:00,  1.58s/it]


In [34]:
t14_luad_testing["stage_mention"] = ans_list
# Number of reports answered "No". "no" must occur as a whole word: a plain
# case-insensitive substring search would also count replies that merely contain
# words such as "node", "unknown" or "diagnosis".
int(t14_luad_testing["stage_mention"].str.contains(r"\bno\b", case=False).sum())

17

In [35]:
t14_ZSCOT_path = "/secure/shared_data/tnm/t14_res/med42-t0.7-tp0.95-nrs1.csv"
# Load the ZS-COT results and attach the stage-mention reply of each LUAD test
# report; the merge keeps only rows whose patient_filename occurs in both frames.
t14_ZSCOT_df = pd.read_csv(t14_ZSCOT_path)
t14_ZSCOT_luad = t14_ZSCOT_df.merge(t14_luad_testing[["patient_filename", "stage_mention"]], on="patient_filename")
print("LUAD:", t14_ZSCOT_luad.shape)

# A reply counts as "No" only when "no" appears as a whole word; a bare
# substring test would also match words such as "node" or "unknown".
mention_flag = t14_ZSCOT_luad["stage_mention"].str.contains(r"\bno\b", case=False)
# .copy() gives independent frames, so columns assigned to the subsets later
# (the report function below adds two) are not written onto a slice.
t14_ZSCOT_luad_hasMention = t14_ZSCOT_luad[~mention_flag].copy()
t14_ZSCOT_luad_noMention = t14_ZSCOT_luad[mention_flag].copy()
print(t14_ZSCOT_luad_hasMention.shape, t14_ZSCOT_luad_noMention.shape)

LUAD: (77, 8)
(60, 8) (17, 8)


In [36]:
# T-stage performance on the LUAD reports whose stage-mention reply was not flagged as "No".
_ = t14_performance_report(t14_ZSCOT_luad_hasMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          T1       1.00      0.95      0.97        19
          T2       1.00      1.00      1.00        28
          T3       0.90      1.00      0.95         9
          T4       1.00      1.00      1.00         4

    accuracy                           0.98        60
   macro avg       0.97      0.99      0.98        60
weighted avg       0.98      0.98      0.98        60

Precision: 0.9739981782242813 (CI: 0.9166666666666666 1.0 )
Recall: 0.9861578083897027 (CI: 0.952734375 1.0 )
F1: 0.9781584380773338 (CI: 0.9283263305322129 1.0 )


In [37]:
def modified_t14_performance_report(df, ans_col="ans_str"):
    """Print a per-class classification report for T-stage predictions.

    Each answer in ``ans_col`` is mapped to a class code following the coding
    of the ``t`` column (T1 -> 0, T2 -> 1, T3 -> 2, T4 -> 3; the first tag found
    in that order wins, matched case-insensitively). Rows without any of these
    tags are excluded from the report. The class names shown in the report are
    derived from the codes that actually occur among the labels and
    predictions, so the report stays correctly labeled on subsets where some
    T categories are absent.

    Args:
        df: DataFrame with the free-text model answers in ``ans_col`` and the
            ground-truth class codes in column ``"t"``.
        ans_col: Name of the column holding the answer text.

    Returns:
        None; the report is printed.

    Side effects:
        Adds (or overwrites) the columns ``Has_Valid_Prediction`` and
        ``coded_pred`` on ``df`` in place.
    """
    stage_tags = ("T1", "T2", "T3", "T4")

    def _to_code(answer):
        # str() guards against NaN / non-string cells; upper() keeps the coding
        # consistent with the case-insensitive notion of a valid prediction.
        text = str(answer).upper()
        for code, tag in enumerate(stage_tags):
            if tag in text:
                return code
        # invalid answer: no T category found
        return -1

    coded_pred_list = [_to_code(answer) for answer in df[ans_col]]
    # A prediction is valid exactly when it could be coded, so -1 can never
    # reach the report.
    df['Has_Valid_Prediction'] = [code != -1 for code in coded_pred_list]
    df['coded_pred'] = coded_pred_list

    valid_rows = df[df['Has_Valid_Prediction']]
    coded_pred = valid_rows['coded_pred'].to_list()
    t_labels = valid_rows["t"].to_list()

    if not coded_pred:
        print("No valid T-stage prediction found; nothing to report.")
        return

    # Name exactly the classes that are present instead of a fixed list, and
    # pass them explicitly so names and codes cannot get out of step.
    # Code k corresponds to category T(k+1).
    labels = sorted(set(t_labels) | set(coded_pred))
    target_names = [f"T{int(label) + 1}" for label in labels]
    print(classification_report(t_labels, coded_pred, labels=labels, target_names=target_names))

# T-stage report on the LUAD reports whose stage-mention reply was flagged as "No".
modified_t14_performance_report(t14_ZSCOT_luad_noMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          T1       0.50      0.25      0.33         4
          T2       0.56      0.71      0.63         7
          T3       0.00      0.00      0.00         1

    accuracy                           0.50        12
   macro avg       0.35      0.32      0.32        12
weighted avg       0.49      0.50      0.48        12



In [38]:
# Save the LUAD "no T-stage mention" subset. index=False is not passed, so the
# DataFrame index is written as an extra leading column. If
# modified_t14_performance_report was run on this frame beforehand, the
# Has_Valid_Prediction and coded_pred columns it adds are saved as well.
t14_ZSCOT_luad_noMention.to_csv("/secure/shared_data/rag_tnm_results/t14_results/luad_t14_noMention.csv")

# LUAD N Category

In [39]:
# Ask the model, one LUAD N-stage test report at a time, whether the N stage is
# explicitly mentioned; the raw replies are collected in ans_list.
n03_luad_testing = pd.read_csv("/secure/shared_data/tcga_path_reports/n03_data/LUAD_N03_testing.csv")

ans_list = []
for report_text in tqdm(n03_luad_testing["text"], total=n03_luad_testing.shape[0]):
    filled_text_prompt = N_stage_mention_detecting_prompt.format(report=report_text)
    lm_formated_prompt = MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction,
        prompt=filled_text_prompt,
    )
    # No sampling and a 12-token budget: only a short Yes/No reply is requested.
    decoded_ans = client.text_generation(prompt=lm_formated_prompt, do_sample=False, max_new_tokens=12)
    ans_list.append(decoded_ans)

100%|██████████| 82/82 [02:36<00:00,  1.90s/it]


In [40]:
n03_luad_testing["stage_mention"] = ans_list
# Number of reports answered "No". "no" must occur as a whole word: a plain
# case-insensitive substring search would also count replies that merely contain
# words such as "node", "unknown" or "diagnosis".
int(n03_luad_testing["stage_mention"].str.contains(r"\bno\b", case=False).sum())

30

In [41]:
n03_ZSCOT_path = "/secure/shared_data/tnm/n03_res/med42-t0.7-tp0.95-nrs1.csv"
# Load the ZS-COT results and attach the stage-mention reply of each LUAD test
# report; the merge keeps only rows whose patient_filename occurs in both frames.
n03_ZSCOT_df = pd.read_csv(n03_ZSCOT_path)
n03_ZSCOT_luad = n03_ZSCOT_df.merge(n03_luad_testing[["patient_filename", "stage_mention"]], on="patient_filename")
print("LUAD:", n03_ZSCOT_luad.shape)

# A reply counts as "No" only when "no" appears as a whole word; a bare
# substring test would also match words such as "node" or "unknown".
mention_flag = n03_ZSCOT_luad["stage_mention"].str.contains(r"\bno\b", case=False)
# .copy() gives independent frames, so columns assigned to the subsets later
# (the report function below adds two) are not written onto a slice.
n03_ZSCOT_luad_hasMention = n03_ZSCOT_luad[~mention_flag].copy()
n03_ZSCOT_luad_noMention = n03_ZSCOT_luad[mention_flag].copy()
print(n03_ZSCOT_luad_hasMention.shape, n03_ZSCOT_luad_noMention.shape)

LUAD: (82, 8)
(52, 8) (30, 8)


In [42]:
# N-stage performance on the LUAD reports whose stage-mention reply was not flagged as "No".
_ = n03_performance_report(n03_ZSCOT_luad_hasMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          N0       1.00      0.97      0.98        29
          N1       0.86      1.00      0.92        12
          N2       1.00      0.89      0.94         9
          N3       1.00      1.00      1.00         1

    accuracy                           0.96        51
   macro avg       0.96      0.96      0.96        51
weighted avg       0.97      0.96      0.96        51

Precision: 0.9595605515893324 (CI: 0.9 1.0 )
Recall: 0.9585178763715139 (CI: 0.8752638888888888 1.0 )
F1: 0.954488209501788 (CI: 0.8842447017920595 1.0 )


In [43]:
def modifiled_n03_performance_report(df, ans_col="ans_str"):
    """Print a per-class classification report for N-stage predictions.

    Each answer in ``ans_col`` is mapped to a class code (N0 -> 0, N1 -> 1,
    N2 -> 2, N3 -> 3; the first tag found in that order wins, matched
    case-insensitively). Rows without any of these tags are excluded from the
    report. The class names shown in the report are derived from the codes that
    actually occur among the labels and predictions, so the report stays
    correctly labeled on subsets where some N categories are absent.

    Args:
        df: DataFrame with the free-text model answers in ``ans_col`` and the
            ground-truth class codes in column ``"n"`` (assumed to use the same
            0-3 integer coding as the predictions -- TODO confirm against the data).
        ans_col: Name of the column holding the answer text.

    Returns:
        None; the report is printed.

    Side effects:
        Adds (or overwrites) the columns ``Has_Valid_Prediction`` and
        ``coded_pred`` on ``df`` in place.
    """
    stage_tags = ("N0", "N1", "N2", "N3")

    def _to_code(answer):
        # str() guards against NaN / non-string cells; upper() keeps the coding
        # consistent with the case-insensitive notion of a valid prediction.
        text = str(answer).upper()
        for code, tag in enumerate(stage_tags):
            if tag in text:
                return code
        # invalid answer: no N category found
        return -1

    coded_pred_list = [_to_code(answer) for answer in df[ans_col]]
    # A prediction is valid exactly when it could be coded, so -1 can never
    # reach the report.
    df['Has_Valid_Prediction'] = [code != -1 for code in coded_pred_list]
    df['coded_pred'] = coded_pred_list

    valid_rows = df[df['Has_Valid_Prediction']]
    coded_pred = valid_rows['coded_pred'].to_list()
    n_labels = valid_rows["n"].to_list()

    if not coded_pred:
        print("No valid N-stage prediction found; nothing to report.")
        return

    # Name exactly the classes that are present instead of a fixed list, and
    # pass them explicitly so names and codes cannot get out of step.
    labels = sorted(set(n_labels) | set(coded_pred))
    target_names = [f"N{int(label)}" for label in labels]
    print(classification_report(n_labels, coded_pred, labels=labels, target_names=target_names))

# N-stage report on the LUAD reports whose stage-mention reply was flagged as "No".
_ = modifiled_n03_performance_report(n03_ZSCOT_luad_noMention, ans_col="ans_str_0")

              precision    recall  f1-score   support

          N0       1.00      0.95      0.98        22
          N1       1.00      0.33      0.50         3
          N2       0.00      0.00      0.00         0

    accuracy                           0.88        25
   macro avg       0.67      0.43      0.49        25
weighted avg       1.00      0.88      0.92        25



In [None]:
# Save the LUAD "no N-stage mention" subset. index=False is not passed, so the
# DataFrame index is written as an extra leading column. If
# modifiled_n03_performance_report was run on this frame beforehand, the
# Has_Valid_Prediction and coded_pred columns it adds are saved as well.
n03_ZSCOT_luad_noMention.to_csv("/secure/shared_data/rag_tnm_results/n03_results/luad_n03_noMention.csv")

# LUAD M Category

In [8]:
# Ask the model, one LUAD M-stage test report at a time, whether the M stage is
# explicitly mentioned; the raw replies are collected in ans_list.
m01_luad_testing = pd.read_csv("/secure/shared_data/tcga_path_reports/m01_data/LUAD_M01_testing.csv")

ans_list = []
for report_text in tqdm(m01_luad_testing["text"], total=m01_luad_testing.shape[0]):
    filled_text_prompt = M_stage_mention_detecting_prompt.format(report=report_text)
    lm_formated_prompt = MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction,
        prompt=filled_text_prompt,
    )
    # No sampling and a 12-token budget: only a short Yes/No reply is requested.
    decoded_ans = client.text_generation(prompt=lm_formated_prompt, do_sample=False, max_new_tokens=12)
    ans_list.append(decoded_ans)

100%|██████████| 63/63 [01:01<00:00,  1.02it/s]


In [9]:
m01_luad_testing["stage_mention"] = ans_list
# Number of reports answered "No". "no" must occur as a whole word: a plain
# case-insensitive substring search would also count replies that merely contain
# words such as "node", "unknown" or "diagnosis".
int(m01_luad_testing["stage_mention"].str.contains(r"\bno\b", case=False).sum())

51

In [10]:
# Load the ZS-COT results (stored as two batch files) and stack them.
tmp1 = pd.read_csv("/secure/shared_data/tnm/m01_res/med42-t0.7-tp0.95-nrs5_batch1.csv")
tmp2 = pd.read_csv("/secure/shared_data/tnm/m01_res/med42-t0.7-tp0.95-nrs5_batch2.csv")
m01_ZSCOT_df = pd.concat([tmp1, tmp2])
# Attach the stage-mention reply of each LUAD test report; the merge keeps only
# rows whose patient_filename occurs in both frames.
m01_ZSCOT_luad = m01_ZSCOT_df.merge(m01_luad_testing[["patient_filename", "stage_mention"]], on="patient_filename")
# This frame holds LUAD reports; the label previously read "BRCA:" (copy-paste slip).
print("LUAD:", m01_ZSCOT_luad.shape)

# A reply counts as "No" only when "no" appears as a whole word; a bare
# substring test would also match words such as "node" or "unknown".
mention_flag = m01_ZSCOT_luad["stage_mention"].str.contains(r"\bno\b", case=False)
# .copy() gives independent frames, so columns assigned to the subsets later
# are not written onto a slice of m01_ZSCOT_luad.
m01_ZSCOT_luad_hasMention = m01_ZSCOT_luad[~mention_flag].copy()
m01_ZSCOT_luad_noMention = m01_ZSCOT_luad[mention_flag].copy()
print(m01_ZSCOT_luad_hasMention.shape, m01_ZSCOT_luad_noMention.shape)

BRCA: (63, 18)
(12, 18) (51, 18)


In [11]:
# M-stage performance on the LUAD reports whose stage-mention reply was not flagged as "No".
_ = m01_performance_report(m01_ZSCOT_luad_hasMention, ans_col="ans_str_0")

tn=8, fp=2, fn=0, tp=2
              precision    recall  f1-score   support

          M0       1.00      0.80      0.89        10
          M1       0.50      1.00      0.67         2

    accuracy                           0.83        12
   macro avg       0.75      0.90      0.78        12
weighted avg       0.92      0.83      0.85        12

Precision: 0.7511115079365079 (CI: 0.5 1.0 )
Recall: 0.8422174963924965 (CI: 0.353125 1.0 )
F1: 0.7494292861387726 (CI: 0.4135714285714286 1.0 )


In [12]:
# M-stage performance on the LUAD reports whose stage-mention reply was flagged as "No".
_ = m01_performance_report(m01_ZSCOT_luad_noMention, ans_col="ans_str_0")

tn=35, fp=13, fn=2, tp=0
              precision    recall  f1-score   support

          M0       0.95      0.73      0.82        48
          M1       0.00      0.00      0.00         2

    accuracy                           0.70        50
   macro avg       0.47      0.36      0.41        50
weighted avg       0.91      0.70      0.79        50

Precision: 0.47412418082206487 (CI: 0.4342105263157895 0.5 )
Recall: 0.36645461187328565 (CI: 0.3040019132653061 0.42857142857142855 )
F1: 0.41257697382484215 (CI: 0.3708465189873418 0.45368490205446715 )
