Home = Low Severity. The patient was stable enough to recover on their own.

Home With Service Facility = Medium Severity. The patient was stable, but still needed a nurse or therapist to visit them.

Extended Care Facility = High Severity. The patient was not stable enough to go home and needed 24/7 care in a skilled nursing or rehab facility.

Expired = Critical Severity. The patient's condition was fatal.

...

In [2]:
import pandas as pd
import numpy as np
import json
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
import shutil
from typing import List, Dict, Any, Tuple

# Hugging Face and Transformers
try:
    from datasets import Dataset, DatasetDict, load_from_disk
    from transformers import (
        AutoTokenizer,
        AutoModelForSequenceClassification,
        TrainingArguments,
        Trainer,
        EarlyStoppingCallback,
        pipeline
    )
except ImportError:
    print("ImportError: Hugging Face libraries not found.")
    print("Please run: pip install torch transformers datasets scikit-learn pandas numpy matplotlib seaborn accelerate")
    exit()

# SKLearn
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    classification_report,
    confusion_matrix
)

# CONFIGURATION

In [1]:
# Data files
# CSV splits read with pandas; each read is capped by the row limits below.
TRAIN_FILE = "data/bigdata/train.csv"
VAL_FILE = "data/bigdata/val.csv"
TEST_FILE = "data/bigdata/test.csv"

# Lighter data settings
# Only the first N rows of each CSV are loaded (passed as `nrows=`).
N_TRAIN_ROWS = 5000  # rows read from TRAIN_FILE
N_TEST_ROWS = 500    # rows read from BOTH VAL_FILE and TEST_FILE

# Preprocessing
# Name of the CSV column holding the class label to predict.
TARGET_COLUMN = 'discharge_disposition'

# Using ClinicalBERT for its domain-specific knowledge
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
# Inputs are truncated and padded to exactly this many tokens.
MAX_TOKEN_LENGTH = 512
# Directory where the tokenized DatasetDict is saved; it is deleted and
# recreated on every preprocessing run.
PROCESSED_DATA_DIR = "data/processed_data/tpaul/processed_triage_data"
# JSON file that receives label2id / id2label / num_labels.
LABEL_INFO_FILE = "data/info/tpaul/label_info.json"

# Training
# Trainer output directory; the final model and tokenizer are saved here too.
MODEL_OUTPUT_DIR = "models/tpaul/triage_classifier_model"

# --- Lighter training settings ---
# Per-device batch size for training, evaluation and test prediction.
BATCH_SIZE = 2  # fine-tuning additionally uses gradient_accumulation_steps=4

LEARNING_RATE = 2e-5
NUM_EPOCHS = 3

# Evaluation
# Path of the PNG written for the test-set confusion matrix.
CONFUSION_MATRIX_FILE = "evaluation/tpaul/confusion_matrix.png"

# PREPROCESSING

In [2]:
def get_label_maps() -> Dict[str, Any]:
    """Build label <-> id mappings from the target column and save them as JSON.

    Reads only ``TARGET_COLUMN`` from the first ``N_TRAIN_ROWS`` rows of the
    training CSV and the first ``N_TEST_ROWS`` rows of the validation and test
    CSVs. Missing labels become ``'Unknown'``; all labels are stringified and
    stripped of surrounding whitespace. Ids are assigned in sorted label order.

    Returns:
        A dict with keys ``'label2id'`` (str -> int), ``'id2label'``
        (int -> str) and ``'num_labels'`` (int). The same dict is written to
        ``LABEL_INFO_FILE``; note that JSON turns the integer keys of
        ``id2label`` into strings on disk.

    Raises:
        FileNotFoundError: if any of the three CSV files is missing.
        ValueError: if the sampled rows contain no labels at all.
    """
    print("Creating complete label mappings from a sample of data...")
    try:
        df_train = pd.read_csv(TRAIN_FILE, usecols=[TARGET_COLUMN], nrows=N_TRAIN_ROWS)
        df_val = pd.read_csv(VAL_FILE, usecols=[TARGET_COLUMN], nrows=N_TEST_ROWS)
        df_test = pd.read_csv(TEST_FILE, usecols=[TARGET_COLUMN], nrows=N_TEST_ROWS)
    except FileNotFoundError as e:
        print(f"Error: {e}. Make sure {TRAIN_FILE}, {VAL_FILE}, and {TEST_FILE} exist.")
        # Re-raise instead of substituting empty lists: pd.concat() rejects
        # plain lists, which used to hide the real cause behind a TypeError.
        raise

    labels = pd.concat(
        [df[TARGET_COLUMN] for df in (df_train, df_val, df_test)],
        ignore_index=True,
    )
    labels = labels.fillna('Unknown').astype(str).str.strip()

    # sorted() works for any array type returned by unique(); plain str keys
    # keep the mapping JSON-serializable.
    unique_labels = sorted(str(label) for label in labels.unique())
    if not unique_labels:
        raise ValueError(
            f"No labels found in column '{TARGET_COLUMN}' of the sampled rows."
        )

    # Create mappings
    label2id = {label: i for i, label in enumerate(unique_labels)}
    id2label = {i: label for label, i in label2id.items()}
    num_labels = len(unique_labels)

    print(f"Found {num_labels} unique labels in the sample.")

    label_info = {'label2id': label2id, 'id2label': id2label, 'num_labels': num_labels}

    # Ensure parent directories exist before saving
    os.makedirs(os.path.dirname(LABEL_INFO_FILE) or '.', exist_ok=True)

    # Save mappings to disk
    with open(LABEL_INFO_FILE, 'w', encoding='utf-8') as f:
        json.dump(label_info, f)

    return label_info

def preprocess(label2id: Dict[str, int]):
    """Tokenize the train/val/test CSVs and save them as a DatasetDict.

    Loads the first ``N_TRAIN_ROWS`` / ``N_TEST_ROWS`` rows of each CSV, builds
    one text per row from the ``chief_complaint`` and
    ``history_of_present_illness`` columns, tokenizes it with the
    ``MODEL_NAME`` tokenizer (truncated and padded to ``MAX_TOKEN_LENGTH``),
    and encodes ``TARGET_COLUMN`` into integer ``labels`` when that column is
    present. The result is written to ``PROCESSED_DATA_DIR``, replacing any
    previous contents.

    Args:
        label2id: Mapping from cleaned label string to integer class id.

    Raises:
        FileNotFoundError: if any of the three CSV files is missing.
        ValueError: if a row's label is not in ``label2id`` and the mapping
            has no ``'Unknown'`` entry to fall back on.
    """
    print("\n--- STARTING PREPROCESSING ---")

    # Load Data
    print(f"Loading datasets (Train: {N_TRAIN_ROWS} rows, Val/Test: {N_TEST_ROWS} rows)...")
    try:
        df_train = pd.read_csv(TRAIN_FILE, nrows=N_TRAIN_ROWS)
        df_val = pd.read_csv(VAL_FILE, nrows=N_TEST_ROWS)
        df_test = pd.read_csv(TEST_FILE, nrows=N_TEST_ROWS)
    except FileNotFoundError as e:
        print(f"ERROR: {e}")
        print(f"Please ensure {TRAIN_FILE}, {VAL_FILE}, and {TEST_FILE} are in the correct paths.")
        # No dummy data is created; continuing with empty lists only produced
        # an unrelated AttributeError on `.shape` below, so fail here instead.
        raise

    print(f"Train shape: {df_train.shape}, Val shape: {df_val.shape}, Test shape: {df_test.shape}")

    print(f"Loading tokenizer: {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Resolve the fallback id once and lazily: `label2id['Unknown']` as a
    # .get() default is evaluated eagerly and raised KeyError whenever the
    # mapping had no 'Unknown' entry, even if every label was known.
    unknown_id = label2id.get('Unknown')

    def _as_text(value: Any) -> str:
        """Return ``value`` as a string, or '' for None/NaN."""
        return str(value) if value is not None and pd.notna(value) else ""

    def _encode_label(raw_label: Any) -> int:
        """Map a raw label cell to its class id, falling back to 'Unknown'."""
        if raw_label is not None and pd.notna(raw_label):
            cleaned = str(raw_label).strip()
        else:
            cleaned = 'Unknown'
        label_id = label2id.get(cleaned, unknown_id)
        if label_id is None:
            raise ValueError(
                f"Label {cleaned!r} is not in label2id and no 'Unknown' "
                "fallback class is defined."
            )
        return label_id

    def preprocess_function(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
        """Tokenize one batch of rows and attach integer labels if available."""
        cc_list = [_as_text(cc) for cc in examples["chief_complaint"]]
        hpi_list = [_as_text(hpi) for hpi in examples["history_of_present_illness"]]

        text = [f"CHIEF COMPLAINT: {cc} | HISTORY: {hpi}" for cc, hpi in zip(cc_list, hpi_list)]

        tokenized_inputs = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_TOKEN_LENGTH
        )

        if TARGET_COLUMN in examples:
            tokenized_inputs["labels"] = [
                _encode_label(label) for label in examples[TARGET_COLUMN]
            ]

        return tokenized_inputs

    # Convert Pandas to Hugging Face Datasets
    print("Converting Pandas DataFrames to Hugging Face Datasets...")
    ds_train = Dataset.from_pandas(df_train)
    ds_val = Dataset.from_pandas(df_val)
    ds_test = Dataset.from_pandas(df_test)

    print("Tokenizing datasets...")
    tokenized_train = ds_train.map(preprocess_function, batched=True)
    tokenized_val = ds_val.map(preprocess_function, batched=True)
    tokenized_test = ds_test.map(preprocess_function, batched=True)

    # Create the final DatasetDict
    processed_dataset = DatasetDict({
        'train': tokenized_train,
        'validation': tokenized_val,
        'test': tokenized_test
    })
    print("\nTokenization complete. Processed dataset:")
    print(processed_dataset)

    # Save Processed Data
    print(f"Saving processed dataset to disk at {PROCESSED_DATA_DIR}...")
    if os.path.exists(PROCESSED_DATA_DIR):
        # Remove stale output first so old files cannot mix with the new save.
        print(f"Removing old processed data at {PROCESSED_DATA_DIR}")
        shutil.rmtree(PROCESSED_DATA_DIR)
    processed_dataset.save_to_disk(PROCESSED_DATA_DIR)

    print("--- PREPROCESSING COMPLETE ---")


# MODEL FINE-TUNING

In [3]:
def compute_metrics(eval_pred: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
    """Compute accuracy and weighted F1/precision/recall for an evaluation pass.

    Args:
        eval_pred: A pair ``(scores, labels)``; the predicted class for each
            row is the argmax of ``scores`` along axis 1.

    Returns:
        A dict with keys ``'accuracy'``, ``'f1_weighted'``,
        ``'precision_weighted'`` and ``'recall_weighted'``.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=1)

    # Shared settings for the three class-weighted metrics; zero_division=0
    # reports 0 for classes with no predicted/true samples.
    weighted = {'average': 'weighted', 'zero_division': 0}

    return {
        'accuracy': accuracy_score(gold, predicted),
        'f1_weighted': f1_score(gold, predicted, **weighted),
        'precision_weighted': precision_score(gold, predicted, **weighted),
        'recall_weighted': recall_score(gold, predicted, **weighted),
    }

def finetune(label_info: Dict[str, Any], device: torch.device):
    """Fine-tune the pre-trained classifier on the processed dataset.

    Loads the DatasetDict from ``PROCESSED_DATA_DIR``, trains ``MODEL_NAME``
    with a classification head sized from ``label_info``, evaluates on the
    validation split, and saves the model and tokenizer to
    ``MODEL_OUTPUT_DIR``. Prints an error and returns early if the processed
    data cannot be found.

    Args:
        label_info: Dict providing ``'num_labels'``, ``'label2id'`` and
            ``'id2label'``.
        device: Device the model is moved to before training.
    """
    print("\n--- STARTING MODEL FINE-TUNING ---")

    # Processed data produced by the preprocessing step
    print(f"Loading processed dataset from {PROCESSED_DATA_DIR}...")
    try:
        dataset_splits = load_from_disk(PROCESSED_DATA_DIR)
    except FileNotFoundError:
        print(f"Error: Processed data not found at {PROCESSED_DATA_DIR}.")
        return

    # Pre-trained encoder with a classification head for our label set
    print(f"Loading pre-trained model: {MODEL_NAME}...")
    classifier = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=label_info['num_labels'],
        label2id=label_info['label2id'],
        id2label=label_info['id2label'],
    )
    classifier.to(device)

    print("Defining training arguments...")
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # Evaluate and checkpoint once per epoch; the checkpoint with the lowest
    # eval_loss is reloaded when training ends.
    arg_values = dict(
        output_dir=MODEL_OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=4,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=torch.cuda.is_available(),
        report_to="none",
    )
    run_config = TrainingArguments(**arg_values)

    # Early stopping with a patience of one evaluation
    stop_early = EarlyStoppingCallback(early_stopping_patience=1)
    runner = Trainer(
        model=classifier,
        args=run_config,
        train_dataset=dataset_splits['train'],
        eval_dataset=dataset_splits['validation'],
        compute_metrics=compute_metrics,
        callbacks=[stop_early],
    )

    runner.train()

    # Report validation metrics
    print("Evaluating best model on validation set...")
    validation_metrics = runner.evaluate()
    print("\nValidation Results:")
    print(json.dumps(validation_metrics, indent=2))

    # Persist the model, plus the base tokenizer alongside it so the output
    # directory can be loaded on its own later.
    print(f"Saving final model to {MODEL_OUTPUT_DIR}")
    runner.save_model(MODEL_OUTPUT_DIR)
    AutoTokenizer.from_pretrained(MODEL_NAME).save_pretrained(MODEL_OUTPUT_DIR)

    print("--- FINE-TUNNING COMPLETE ---")

# EVALUATION 

In [4]:
def evaluate_and_infer(label_info: Dict[str, Any], device: torch.device):
    """Evaluate the fine-tuned model on the test split and run a sample inference.

    Loads the model and tokenizer from ``MODEL_OUTPUT_DIR`` and the test split
    from ``PROCESSED_DATA_DIR``, prints a classification report, saves a
    confusion-matrix heatmap to ``CONFUSION_MATRIX_FILE``, and finally
    classifies one hard-coded example through a text-classification pipeline.
    Prints an error and returns early if the model, the processed data, or
    the test labels are unavailable.

    Args:
        label_info: Dict providing ``'id2label'`` (keys may be ints or their
            string form, e.g. after a JSON round-trip) and ``'num_labels'``.
        device: Device used for the model and the inference pipeline.
    """
    print("\n--- STARTING EVALUATION & INFERENCE ---")

    print(f"Loading fine-tuned model from {MODEL_OUTPUT_DIR}...")
    try:
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_OUTPUT_DIR)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_OUTPUT_DIR)
        model.to(device)
    except OSError:
        print(f"Error: Model not found at {MODEL_OUTPUT_DIR}.")
        return

    print(f"Loading processed test data from {PROCESSED_DATA_DIR}...")
    try:
        processed_dataset = load_from_disk(PROCESSED_DATA_DIR)
        test_dataset = processed_dataset['test']
    except FileNotFoundError:
        print(f"Error: Processed data not found at {PROCESSED_DATA_DIR}.")
        return

    # Normalize keys to int so lookups work for both in-memory and JSON-loaded maps.
    id2label = {int(k): v for k, v in label_info['id2label'].items()}

    # Report every known class, including ones absent from the test split.
    all_label_indices = list(range(label_info['num_labels']))
    all_label_names = [id2label.get(i, f"UNK_{i}") for i in all_label_indices]

    print("Running predictions on the test set...")
    os.makedirs("./temp_eval", exist_ok=True)
    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./temp_eval",
            per_device_eval_batch_size=BATCH_SIZE,
            # Match fine-tuning: never start external experiment loggers.
            report_to="none",
        )
    )

    predictions = trainer.predict(test_dataset)
    y_pred = np.argmax(predictions.predictions, axis=1)
    y_true = predictions.label_ids
    if y_true is None:
        # Without a 'labels' column there is nothing to score against.
        print("Error: Test dataset has no labels; cannot compute evaluation metrics.")
        return

    print("\n--- Test Set Classification Report ---")

    report = classification_report(
        y_true,
        y_pred,
        labels=all_label_indices,
        target_names=all_label_names,
        zero_division=0
    )

    print(report)

    print("Generating confusion matrix...")
    cm = confusion_matrix(y_true, y_pred, labels=all_label_indices)

    # Grow the figure with the number of classes so tick labels stay readable.
    figsize_x = max(10, label_info['num_labels'])
    figsize_y = max(8, label_info['num_labels'])

    plt.figure(figsize=(figsize_x, figsize_y))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=all_label_names, yticklabels=all_label_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix for Triage Prediction (Test Set)')
    plt.tight_layout()

    os.makedirs(os.path.dirname(CONFUSION_MATRIX_FILE) or '.', exist_ok=True)
    plt.savefig(CONFUSION_MATRIX_FILE)
    print(f"Confusion matrix saved to {CONFUSION_MATRIX_FILE}")

    print("\n--- Running Sample Inference ---")

    # Pass the selected device explicitly: -1 means "CPU" and would ignore a
    # non-CUDA accelerator. CUDA keeps the integer index used previously.
    triage_pipe = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=0 if device.type == "cuda" else device
    )

    # Test with a new patient
    new_chief_complaint = "Patient presents with chest pain and difficulty breathing."
    new_history = "73-year-old male with a history of hypertension and diabetes. Smoker for 30 years."

    # Same text template as used during preprocessing.
    input_text = f"CHIEF COMPLAINT: {new_chief_complaint} | HISTORY: {new_history}"

    print(f"\nInput Text:\n{input_text}")

    prediction = triage_pipe(input_text)
    print("\n--- Prediction ---")
    print(json.dumps(prediction, indent=2))

    # top_k=None is the supported replacement for the deprecated
    # return_all_scores=True and returns a score for every class.
    all_scores = triage_pipe(input_text, top_k=None)
    # Depending on the library version the result for one input is either a
    # flat list of dicts or that list wrapped in an outer list; accept both.
    if all_scores and isinstance(all_scores[0], list):
        all_scores = all_scores[0]
    print("\n--- All scores ---")
    all_scores_sorted = sorted(all_scores, key=lambda x: x['score'], reverse=True)
    print(json.dumps(all_scores_sorted, indent=2))

    print("--- EVALUATION COMPLETE ---")


# MAIN EXECUTION

In [None]:
def main():
    """
    Run the whole triage pipeline end to end.

    Stages, in order: build label maps, preprocess data, fine-tune the
    model, then evaluate and run a sample inference. If a stage raises,
    the error is printed and the remaining stages are skipped.
    """
    print("====== STARTING TRIAGE MODEL ======")

    # Pick the accelerator: MPS first, then CUDA, otherwise CPU.
    device_name = "cpu"
    if torch.backends.mps.is_available():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    device = torch.device(device_name)

    print(f"Using device: {device}")

    # Stage 0: label mappings (needed by every later stage)
    try:
        label_info = get_label_maps()
    except Exception as exc:
        print(f"Fatal Error during label map creation: {exc}")
        return

    # Stage 1: tokenize and persist the datasets
    try:
        preprocess(label_info['label2id'])
    except Exception as exc:
        print(f"Fatal Error during Preprocessing: {exc}")
        return

    # Stage 2: train the classifier
    try:
        finetune(label_info, device)
    except Exception as exc:
        print(f"Fatal Error during Fine-tuning: {exc}")
        return

    # Stage 3: test-set metrics and sample inference
    try:
        evaluate_and_infer(label_info, device)
    except Exception as exc:
        print(f"Fatal Error during Evaluation: {exc}")
        return

    rule = "=" * 60
    print("\n" + rule)
    print("====== TRIAGE MODEL COMPLETE ======")
    print(rule)


if __name__ == "__main__":
    main()