In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import json
from huggingface_hub import login

from tqdm import tqdm
import torch
from transformers import AutoTokenizer, pipeline

import pandas as pd
import numpy as np
import re

from baseline import *

# Read the JSON config file
with open('config.json', 'r') as f:
    config = json.load(f)

# Get the token from the JSON file
hg_token = config['HuggingFace']['token']
# Login using the token
login(token=hg_token)

# LLM folder
llm_folder = "/PHShome/jn180/llm_public_host"
# Data folder
data_folder = "/PHShome/cs1839/capstone_data/"
# results table path
results_df_path = data_folder + "results.csv"

# data to inference 
medication_status_test = pd.read_csv(data_folder + "medication_status_test.csv")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /PHShome/cs1839/.cache/huggingface/token
Login successful


# Pipeline

In [30]:
name_model_paths ={   
    # "Bio_ClinicalBERT": "/PHShome/jn180/llm_public_host/Bio_ClinicalBERT",

    # "Llama-3.1-8B": "/netapp3/raw_data3/share/llm_public_host/Llama-3.1-8B",
    "Llama-3.1-8B-Instruct": "/netapp3/raw_data3/share/llm_public_host/Llama-3.1-8B-Instruct",

    "Llama-3.2-1B-Instruct": "/netapp3/raw_data3/share/llm_public_host/Llama-3.2-1B-Instruct",
    "Llama-3.2-3B-Instruct": "/netapp3/raw_data3/share/llm_public_host/Llama-3.2-3B-Instruct",

    "Qwen2-7B-Instruct": "/PHShome/jn180/llm_public_host/Qwen2-7B-Instruct",
    "Qwen2.5-14B-Instruct": "/netapp3/raw_data3/share/llm_public_host/Qwen2.5-14B-Instruct",

    "meditron-7b": "/PHShome/jn180/llm_public_host/meditron-7b",

    # "Mistral-7B-Instruct-v0.3": "/netapp3/raw_data3/share/llm_public_host/Mistral-7B-Instruct-v0.3"

}

import os
# Set the environment variable to specify the GPUs

os.environ["CUDA_VISIBLE_DEVICES"] = "2"


name_dataset = "MIT"
data_folder = "/PHShome/cs1839/capstone_data/"
results_df_path = data_folder + "results.csv"
medication_status_test = pd.read_csv(data_folder + "medication_status_test.csv")

# prompt_template = """
# Identify and categorize the medications mentioned in the following medical note. Extract all medications the patient has taken before, is currently taking, and any other medications mentioned.
# Note: Adjust the number of medications in each category based on the input. Write None if no other medication mentioned. Strictly follow the output format.
# Expected Output Format:
# "
# - Current Medications (Active): Medication_1, Medication_2
# - Discontinued Medications: Medication_3, Medication_4
# - Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6
# END"

# Input Medical Note:
# {}

# Output:
# """

prompt_template = """
Input Medical Note:
{}

Create a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.

Expected Output Format:
"
- Current Medications (Active): Medication_1, Medication_2
- Discontinued Medications: Medication_3, Medication_4
- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6
END"

Output:
"""

for model_name, model_path in name_model_paths.items():
    df = run_pipeline(model_path=model_path,
                        input_df=medication_status_test[medication_status_test['index']==96],
                        prompt_template=prompt_template,
                        batch_size=16,
                        max_token_output=80,
                        use_sampling=False)
df

Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.95it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


Unnamed: 0,index,snippet,active_medications,discontinued_medications,neither_medications,active_medications_pred,discontinued_medications_pred,neither_medications_pred,extraction_precision,extraction_recall,conditional_accuracy,conditional_macro_f1,conditional_macro_precision,conditional_macro_recall
0,96,She did have a capsule study done during her p...,"[imuran, remicade]",[6 mp],[],[iv imuran],[6 mp],"[remicade, capsule study]",0.5,0.666667,0.5,0.333333,0.333333,0.333333


# Result

# Metrics

## Task 1: Medication Extraction

- **Precision**: Measures the proportion of correctly predicted medications out of all predicted medications.

$$
  \text{Precision} = \frac{\text{True Positives (TP)}}{\text{True Positives (TP)} + \text{False Positives (FP)}}
  $$

- **Recall**: Measures the proportion of correctly predicted medications out of all actual medications.

 $$
  \text{Recall} = \frac{\text{True Positives (TP)}}{\text{True Positives (TP)} + \text{False Negatives (FN)}}
  $$

## Task 2: Status Classification

- **Conditional Accuracy**: Measures the proportion of correct status predictions out of all correctly extracted medications from Task 1.
  $$
  \text{Conditional Accuracy} = \frac{\text{Correct Predictions for the Classes}}{\text{Total Correctly Extracted Medications from Task 1}}
  $$

- **Conditional Macro F1**: Combines precision and recall for each status class, calculates the F1-score for each, then averages them across classes.
  $$
  \text{F1\text{-}score} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}
  $$

---

# Example
## Task 1

| Active Medication | Discontinued Medication | Active Medication (Predicted) | Discontinued Medication (Predicted) |
|-------------------|-------------------------|-------------------------------|-------------------------------------|
| A                 | B                       | A                             | C                                   |


True Set: A, B

Pred Set: A, C



Precision = 1/2

Recall = 1/2


## Task 2
conditional metrics will only consider: A (C is not correctly extracted, removed)
| Active Medication | Discontinued Medication | Active Medication (Predicted) | Discontinued Medication (Predicted) |
|-------------------|-------------------------|-------------------------------|-------------------------------------|
| A                 | B                       | A                             |                                     |

conditional_accuracy = 1/2 

conditional_precision:
- Active: 1
- **Discountinued: 1**

conditional_recall:
- Active: 1
- Discountinued: 0



| Active Medication | Discontinued Medication | Active Medication (Predicted) | Discontinued Medication (Predicted) |
|-------------------|-------------------------|-------------------------------|-------------------------------------|
| A, C              |                         | A                             |C                                    |
| A                 |B, C                     | A                             |C                                    |
| A, B              |                         |                               |                                     |

conditional_acc =（A+A+C）/ (A+C+A+C) = 3/4 

conditional_precision_active = (A+A)/(A+A) = 1

conditional_precision_discountinued = C / (C+C) = 1/2

conditional_recall_active = (A+A)/ (A+A+A+B+c) = 2/5

conditional_recall_discountinued = (C)/ (B+C) = 1/2


In [17]:
import pandas as pd
from baseline import calculate_metrics_by_dataset

# Example usage
data = {
    'active_medications': [['A', 'C'], ['A'], ['A', 'B']],
    'discontinued_medications': [['E'], ['B', 'C'], []],
    'neither_medications': [['D'], [], []],
    'active_medications_pred': [['A','E'], ['A'], ['A','B']],
    'discontinued_medications_pred': [['C','D'], ['C'], []],
    'neither_medications_pred': [[], [], ['E']]
}

# Create the DataFrame
mimic_iv = pd.DataFrame(data)

# Run the function on the dataset
extraction_precision, extraction_recall, extraction_f1, conditional_accuracy, conditional_macro_f1, conditional_macro_precision, conditional_macro_recall = calculate_metrics_by_dataset(mimic_iv, 'MIMIC')

# Print the results
print(f"Extraction Precision: {extraction_precision:.3f}")
print(f"Extraction Recall: {extraction_recall:.3f}")
print(f"Extraction F1: {extraction_f1:.3f}")
print(f"Conditional Accuracy: {conditional_accuracy:.3f}")
print(f"Conditional Macro Precision: {conditional_macro_precision:.3f}")
print(f"Conditional Macro Recall: {conditional_macro_recall:.3f}")
print(f"Conditional Macro F1: {conditional_macro_f1:.3f}")

mimic_iv[['active_medications', 'discontinued_medications', 'neither_medications', 'active_medications_pred', 'discontinued_medications_pred', 'neither_medications_pred']]

Extraction Precision: 0.889
Extraction Recall: 0.889
Extraction F1: 0.889
Conditional Accuracy: 0.556
Conditional Macro Precision: 0.378
Conditional Macro Recall: 0.378
Conditional Macro F1: 0.378


Unnamed: 0,active_medications,discontinued_medications,neither_medications,active_medications_pred,discontinued_medications_pred,neither_medications_pred
0,"[A, C]",[E],[D],"[A, E]","[C, D]",[]
1,[A],"[B, C]",[],[A],[C],[]
2,"[A, B]",[],[],"[A, B]",[],[E]


In [30]:
# Load the data and sort it by the specified columns
import pandas as pd

# Data folder
data_folder = "/PHShome/cs1839/capstone_data/"
# results table path
results_df_path = data_folder + "results.csv"

result_df = pd.read_csv(results_df_path).round(3)

result_df = result_df._append({'Prompt': 'Create a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.',
                   'Dataset': 'MIT', 
                   'Model': 'GPT-3 + R(32 LOC)(0-Shot)',
                   'extraction_precision': 0.87,
                   'extraction_recall': 0.83,
                   'extraction_f1': round(2 * 0.87 * 0.83 / (0.87 + 0.83),3),
                   'conditional_accuracy': 0.85,
                   'conditional_macro_f1': 0.69,
                   'conditional_macro_precision': '--',
                   'conditional_macro_recall': '--'}, ignore_index=True)


result_df = result_df._append({'Prompt': 'Create a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.',
                   'Dataset': 'MIT', 
                   'Model': 'GPT-3 + R(8 LOC)(1-Shot)',
                   'extraction_precision': 0.90,
                   'extraction_recall': 0.92,
                    'extraction_f1': round(2 * 0.90 * 0.92 / (0.90 + 0.92),3),
                   'conditional_accuracy': 0.89,
                   'conditional_macro_f1': 0.62,
                   'conditional_macro_precision': '--',
                   'conditional_macro_recall': '--'}, ignore_index=True)

result_df[result_df.Dataset != 'Internal Data'].\
sort_values(
    by=['Prompt', 'Dataset', 'extraction_f1', 'conditional_accuracy', 'conditional_macro_f1'],
    ascending=[True, False, False, False, False] 
).set_index(['Prompt', 'Dataset', 'Model'])


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,extraction_precision,extraction_recall,extraction_f1,conditional_accuracy,conditional_macro_f1,conditional_macro_precision,conditional_macro_recall
Prompt,Dataset,Model,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
"Create a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.",MIT,GPT-3 + R(8 LOC)(1-Shot),0.9,0.92,0.91,0.89,0.62,--,--
"Create a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.",MIT,GPT-3 + R(32 LOC)(0-Shot),0.87,0.83,0.85,0.85,0.69,--,--
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Llama-3.1-70B-Instruct,0.788,0.909,0.844,0.727,0.812,0.846,0.782
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Llama-3.1-8B-Instruct,0.784,0.909,0.842,0.64,0.729,0.781,0.71
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Llama-3.1-8B,0.829,0.844,0.837,0.532,0.47,0.522,0.434
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Qwen2.5-32B-Instruct,0.773,0.853,0.811,0.675,0.732,0.775,0.705
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,meditron-70b,0.727,0.909,0.808,0.487,0.539,0.581,0.54
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Mistral-Nemo-Instruct-2407,0.753,0.862,0.804,0.653,0.699,0.744,0.666
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Qwen2.5-14B-Instruct,0.731,0.847,0.785,0.652,0.718,0.776,0.669
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued, or neither.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\n- Other Mentioned Medications (neither active nor discontinued): Medication_5, Medication_6\nEND\n\nOutput:\n",MIT,Qwen2-7B-Instruct,0.689,0.862,0.766,0.494,0.567,0.602,0.554


In [32]:
result_df[result_df.Dataset == 'Internal Data'].\
sort_values(
    by=['Prompt', 'Dataset', 'extraction_recall', 'conditional_accuracy', 'conditional_macro_f1'],
    ascending=[True, False, False, False, False] 
).set_index(['Prompt', 'Dataset', 'Model'])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,extraction_precision,extraction_recall,extraction_f1,conditional_accuracy,conditional_macro_f1,conditional_macro_precision,conditional_macro_recall
Prompt,Dataset,Model,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Qwen2.5-14B-Instruct,0.209,0.84,0.335,0.169,0.686,0.776,0.623
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,meditron-70b,0.32,0.761,0.451,0.223,0.389,0.635,0.397
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Mistral-Nemo-Instruct-2407,0.498,0.708,0.585,0.416,0.607,0.847,0.515
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Llama-3.1-70B-Instruct,0.579,0.68,0.626,0.539,0.76,0.898,0.665
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Llama-3.1-8B,0.183,0.667,0.287,0.123,0.426,0.596,0.368
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Llama-3.2-3B-Instruct,0.186,0.638,0.289,0.123,0.439,0.532,0.374
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Qwen2-7B-Instruct,0.227,0.588,0.327,0.171,0.532,0.7,0.429
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Qwen2.5-32B-Instruct,0.344,0.565,0.427,0.276,0.508,0.7,0.399
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Qwen2.5-72B-Instruct,0.771,0.546,0.639,0.707,0.584,0.914,0.446
"Input Medical Note:\n{}\n\nCreate a bulleted list of which medications are mentioned and whether they are active, discontinued. This dataset will only be evaluated on the following medications: 'prochlorperazine', 'compazine', 'navane', 'fluphenazine', 'haldol', 'haloperidol', 'pimozide', 'STELAZINE', 'THORAZINE', 'prolixin', 'perphenazine', 'sertraline', 'memantine', 'chlorproMAZINE', 'loxapine'.\n\nExpected Output Format:\n- Current Medications (Active): Medication_1, Medication_2\n- Discontinued Medications: Medication_3, Medication_4\nEND\n\nOutput:\n",Internal Data,Mistral-7B-Instruct-v0.3,0.195,0.461,0.274,0.153,0.479,0.75,0.354
