In [1]:
!pip install transformers bertviz transformers-interpret optuna

Collecting bertviz
  Downloading bertviz-1.4.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers-interpret
  Downloading transformers_interpret-0.10.0-py3-none-any.whl.metadata (45 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.9/45.9 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting optuna
  Downloading optuna-4.2.1-py3-none-any.whl.metadata (17 kB)
Collecting boto3 (from bertviz)
  Downloading boto3-1.37.18-py3-none-any.whl.metadata (6.7 kB)
Collecting captum>=0.3.1 (from transformers-interpret)
  Downloading captum-0.7.0-py3-none-any.whl.metadata (26 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.15.1-py3-none-any.whl.metadata (7.2 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Collecting jedi>=0.16 (from ipython<8.0.0,>=7.31.1->transformers-interpret)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from 

In [21]:
import os
import json
import shutil
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
from bertviz import head_view
from transformers_interpret import SequenceClassificationExplainer
import optuna

os.environ["WANDB_DISABLED"] = "true"

In [22]:
df = pd.read_csv("/content/MLHC_train_classification_3.csv")

# Map triage levels to numerical labels
triage_mapping = {
    "Immediate": 0,
    "Emergent": 1,
    "Urgent": 2,
    "Semi-urgent": 3,
    "Nonurgent": 4
}
df["triage_value"] = df["triage_level"].map(triage_mapping)
df.dropna(inplace=True)

# Split the dataset (with stratification)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["text_data"].tolist(),
    df["triage_value"].tolist(),
    test_size=0.2,
    random_state=42,
    stratify=df["triage_value"].tolist()
)

In [23]:
model_checkpoint = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
max_length = 256

def tokenize_data(texts):
    return tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)

train_encodings = tokenize_data(train_texts)
val_encodings = tokenize_data(val_texts)

class TriageDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

train_dataset = TriageDataset(train_encodings, train_labels)
val_dataset = TriageDataset(val_encodings, val_labels)

In [18]:
# Model Setup and Training
config = AutoConfig.from_pretrained(
    model_checkpoint,
    num_labels=5,
    problem_type="single_label_classification",
)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probabilities = F.softmax(torch.tensor(logits), dim=-1).numpy()
    auc = roc_auc_score(labels, probabilities, multi_class='ovr', labels=[0, 1, 2, 3, 4])
    predictions = np.argmax(probabilities, axis=-1)
    accuracy = accuracy_score(labels, predictions)
    return {"eval_auc": auc, "accuracy": accuracy}

# TrainingArguments with a scheduler
training_args = TrainingArguments(
    output_dir="./clinicalbert_triage",
    # Evaluate every X steps instead of every epoch
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    learning_rate=3e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="eval_auc",
    greater_is_better=True,
    report_to="none",
    warmup_steps=200,
    fp16=torch.cuda.is_available(),
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

trainer.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.1018,1.117243,0.690489,0.522992
400,1.0541,1.000289,0.754252,0.575107
600,0.98,0.97403,0.752648,0.587983
800,0.9791,0.967226,0.755246,0.59534
1000,0.962,0.965086,0.756861,0.602085
1200,0.9351,0.978715,0.756255,0.587983
1400,0.9152,0.992933,0.755032,0.594114
1600,0.8825,0.986699,0.760125,0.575107
1800,0.8595,0.992483,0.755572,0.591048
2000,0.8655,0.992732,0.754662,0.583691


TrainOutput(global_step=2040, training_loss=0.9593689207937203, metrics={'train_runtime': 157.3341, 'train_samples_per_second': 207.298, 'train_steps_per_second': 12.966, 'total_flos': 4290799108231680.0, 'train_loss': 0.9593689207937203, 'epoch': 5.0})

In [31]:
fixed_epochs = 5
fixed_warmup_steps = 200

# Define model_init for fresh initialization.
def model_init():
    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config)

def hp_space(trial):
    return {
        "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
        "weight_decay": trial.suggest_float("weight_decay", 0.0, 0.1, step=0.01),
        "hidden_dropout_prob": trial.suggest_float("hidden_dropout_prob", 0.1, 0.5, step=0.1)
    }

hp_training_args = TrainingArguments(
    output_dir="./clinicalbert_triage_hp",
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    learning_rate=2e-5,  # placeholder
    per_device_train_batch_size=16,  # placeholder
    per_device_eval_batch_size=16,
    num_train_epochs=fixed_epochs,
    weight_decay=0.01,  # placeholder
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="eval_auc",
    greater_is_better=True,
    report_to="none",
    warmup_steps=fixed_warmup_steps,
    fp16=torch.cuda.is_available(),
)
hp_trainer = Trainer(
    model_init=model_init,
    args=hp_training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
best_run = hp_trainer.hyperparameter_search(
    hp_space=hp_space,
    direction="maximize",
    n_trials=10,
    compute_objective=lambda metrics: metrics["eval_auc"]
)
print("Best hyperparameters:", best_run.hyperparameters)
print("Best run:", best_run)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
[I 2025-03-22 17:21:46,749] A new study created in memory with name: no-name-c5f2a695-1181-4ef7-9090-335ffd5ca20c
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classif

Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0038,0.993376,0.751387,0.602698
400,0.9008,0.982801,0.769048,0.576947


[I 2025-03-22 17:22:57,843] Trial 0 finished with value: 0.7690475829125978 and parameters: {'learning_rate': 4.931280201666149e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.06, 'hidden_dropout_prob': 0.2}. Best is trial 0 with value: 0.7690475829125978.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.088,1.015781,0.740799,0.54813
400,0.9818,0.969296,0.764417,0.597793
600,0.9504,0.977269,0.756533,0.589209
800,0.9195,0.993645,0.763742,0.581852
1000,0.8536,1.001426,0.752657,0.586757


[I 2025-03-22 17:24:28,766] Trial 1 finished with value: 0.7526572357218823 and parameters: {'learning_rate': 3.8992578177666106e-05, 'per_device_train_batch_size': 32, 'weight_decay': 0.06, 'hidden_dropout_prob': 0.4}. Best is trial 0 with value: 0.7690475829125978.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0477,1.005864,0.728832,0.597793
400,0.9253,0.965343,0.774789,0.591662


[I 2025-03-22 17:25:39,690] Trial 2 finished with value: 0.7747886692725368 and parameters: {'learning_rate': 2.2983682443227746e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.03, 'hidden_dropout_prob': 0.2}. Best is trial 2 with value: 0.7747886692725368.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0261,0.999352,0.746316,0.602085
400,0.9151,0.97067,0.778018,0.57572


[I 2025-03-22 17:26:54,995] Trial 3 finished with value: 0.7780183602503673 and parameters: {'learning_rate': 2.866764812379614e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.07, 'hidden_dropout_prob': 0.30000000000000004}. Best is trial 3 with value: 0.7780183602503673.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0908,1.109307,0.664025,0.522992
400,1.0672,1.018032,0.724933,0.572655
600,0.975,0.975097,0.757595,0.59534
800,0.9786,0.989944,0.756038,0.593501
1000,0.9684,0.965357,0.763189,0.590435
1200,0.9561,0.968381,0.760306,0.589209
1400,0.9242,0.973402,0.76324,0.611281
1600,0.8987,0.974393,0.76808,0.595953
1800,0.8947,0.975489,0.765361,0.593501
2000,0.8874,0.975961,0.765113,0.599019


[I 2025-03-22 17:29:33,743] Trial 4 finished with value: 0.7651126932244133 and parameters: {'learning_rate': 1.2704348684768619e-05, 'per_device_train_batch_size': 16, 'weight_decay': 0.1, 'hidden_dropout_prob': 0.5}. Best is trial 3 with value: 0.7780183602503673.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0836,1.018146,0.736551,0.574494


[I 2025-03-22 17:29:49,962] Trial 5 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.1428,1.073447,0.668003,0.542612


[I 2025-03-22 17:30:06,208] Trial 6 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.074,1.058464,0.679934,0.575107


[I 2025-03-22 17:30:31,414] Trial 7 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0038,0.98574,0.753253,0.601471
400,0.9024,0.988024,0.769412,0.574494


[I 2025-03-22 17:31:40,544] Trial 8 finished with value: 0.7694116988027743 and parameters: {'learning_rate': 4.892484444042487e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.1, 'hidden_dropout_prob': 0.2}. Best is trial 3 with value: 0.7780183602503673.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
Trying to set hidden_dropout_prob in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Auc,Accuracy
200,1.0754,1.004695,0.721341,0.571429


[I 2025-03-22 17:31:56,816] Trial 9 pruned. 


Best hyperparameters: {'learning_rate': 2.866764812379614e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.07, 'hidden_dropout_prob': 0.30000000000000004}
Best run: BestRun(run_id='3', objective=0.7780183602503673, hyperparameters={'learning_rate': 2.866764812379614e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.07, 'hidden_dropout_prob': 0.30000000000000004}, run_summary=None)


In [34]:
from torch.utils.data import ConcatDataset
import shutil

# Combine training and validation datasets for final training.
final_dataset = ConcatDataset([train_dataset, val_dataset])

# Extract best hyperparameters from the hyperparameter search.
best_params = best_run.hyperparameters
print("Best hyperparameters:", best_params)

# Set up final training arguments using best hyperparameters.
# Disable evaluation because we're training on the complete dataset.
final_training_args = TrainingArguments(
    output_dir="./final_clinicalbert_model",
    eval_strategy="no",  # Disable evaluation
    learning_rate=best_params["learning_rate"],
    per_device_train_batch_size=best_params["per_device_train_batch_size"],
    per_device_eval_batch_size=best_params["per_device_train_batch_size"],
    num_train_epochs=fixed_epochs,       # Fixed at 5 epochs
    weight_decay=best_params["weight_decay"],
    warmup_steps=fixed_warmup_steps,       # Fixed warmup steps
    logging_dir="./logs_final",
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="epoch",
    load_best_model_at_end=False,  # Not needed if evaluation is disabled
    report_to="none",
    fp16=torch.cuda.is_available(),
)

# Update model configuration to use the best dropout value.
final_config = AutoConfig.from_pretrained(
    model_checkpoint,
    num_labels=5,
    problem_type="single_label_classification",
)
final_config.hidden_dropout_prob = best_params["hidden_dropout_prob"]
final_config.attention_probs_dropout_prob = best_params["hidden_dropout_prob"]

# Initialize the final model with updated configuration.
final_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=final_config)

# Create a new Trainer for final training on the full dataset.
# Notice: We have removed the EarlyStoppingCallback since evaluation is disabled.
final_trainer = Trainer(
    model=final_model,
    args=final_training_args,
    train_dataset=final_dataset,
    compute_metrics=compute_metrics,  # Optional, not used if eval is disabled.
)

# Train the final model on the complete dataset.
final_trainer.train()

# Save the final model and tokenizer for later testing.
final_model.save_pretrained("./final_clinicalbert_model")
tokenizer.save_pretrained("./final_clinicalbert_model")
print("Final model and tokenizer saved to ./final_clinicalbert_model")

# Optionally, zip the final model directory for easy transfer.
shutil.make_archive("final_clinicalbert_model", 'zip', "./final_clinicalbert_model")
print("Final model folder zipped as final_clinicalbert_model.zip")

Best hyperparameters: {'learning_rate': 2.866764812379614e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.07, 'hidden_dropout_prob': 0.30000000000000004}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
50,1.4618
100,1.1769
150,1.1694
200,1.1278
250,1.0916
300,1.0619
350,0.9926
400,0.9931
450,0.9912
500,0.9778


Final model and tokenizer saved to ./final_clinicalbert_model
Final model folder zipped as final_clinicalbert_model.zip


In [41]:
# Load model and tokenizer (using your final model folder)
viz_model = AutoModel.from_pretrained("./final_clinicalbert_model", output_attentions=True)
viz_tokenizer = AutoTokenizer.from_pretrained("./final_clinicalbert_model")

# Choose a sample from the validation set.
sample_text_viz = val_texts[0]
# For a concise view, take a slice of the text (e.g. words 30 to 40)
truncated_text_viz = " ".join(sample_text_viz.split())

# Tokenize the sample text.
viz_inputs = viz_tokenizer.encode_plus(
    truncated_text_viz,
    return_tensors="pt",
    max_length=50,
    truncation=True,
)

# Run the model to get outputs including attentions.
viz_outputs = viz_model(**viz_inputs)
attentions = viz_outputs[-1]  # List of attention tensors (one per layer)

# Convert input IDs to tokens.
viz_tokens = viz_tokenizer.convert_ids_to_tokens(viz_inputs["input_ids"][0])

# Optionally, apply threshold filtering for visualization.
threshold = 0.018
filtered_attentions = [
    torch.where(layer_att > threshold, layer_att, torch.zeros_like(layer_att))
    for layer_att in attentions
]

# Visualize the filtered attention using BERTViz.
head_view(filtered_attentions, viz_tokens)

attentions_tensor = torch.stack(attentions, dim=0).squeeze(1)  # shape: [num_layers, num_heads, seq_len, seq_len]
avg_attentions = attentions_tensor.mean(dim=0).mean(dim=0)       # shape: [seq_len, seq_len]
token_importance = avg_attentions.sum(dim=0)  # shape: [seq_len]
importance_scores = token_importance.detach().cpu().numpy()
sorted_indices = np.argsort(importance_scores)[::-1]

print("\nToken importances (only tokens with importance > 0.5):")
for idx in sorted_indices:
    if importance_scores[idx] > 0.5:
        print(f"Token: {viz_tokens[idx]:15s} Importance: {importance_scores[idx]:.4f}")

<IPython.core.display.Javascript object>


Token importances (only tokens with importance > 0.5):
Token: [SEP]           Importance: 19.1947
Token: [CLS]           Importance: 2.4330
Token: vehicle         Importance: 1.1000
Token: headache        Importance: 1.0185
Token: arrived         Importance: 0.8910
Token: :               Importance: 0.8616
Token: .               Importance: 0.8435
Token: di              Importance: 0.7465
Token: injury          Importance: 0.7344
Token: following       Importance: 0.7341
Token: private         Importance: 0.7324
Token: pain            Importance: 0.7261
Token: ##igo           Importance: 0.7231
Token: reported        Importance: 0.7058
Token: female          Importance: 0.7050
Token: visit           Importance: 0.6737
Token: a               Importance: 0.6592
Token: head            Importance: 0.6545
Token: -               Importance: 0.6451
Token: pain            Importance: 0.6445
Token: complaint       Importance: 0.6220
Token: primary         Importance: 0.6161
Token: ed          

In [36]:
explainer = SequenceClassificationExplainer(final_model, tokenizer)
samples_by_label = defaultdict(list)
for text, label in zip(train_texts, train_labels):
    samples_by_label[label].append(text)
for label in sorted(samples_by_label.keys()):
    sample_text_expl = samples_by_label[label][1]  # choose one representative sample per label
    print(f"\n=== Triage Level {label} ===")
    print(f"Sample text:\n{sample_text_expl}\n")
    attributions = explainer(sample_text_expl)
    explainer.visualize()


=== Triage Level 0 ===
Sample text:
A 67.0-year-old female arrived by ambulance at the ED with an non-injury visit. The patient reported the following primary complaint(s): Convulsions. Recorded vital signs include temperature Blank, pulse 82.0, blood pressure 141.0/67.0, respiratory rate 16.0, and O₂ saturation 100.0%. Pain scale was noted as 0.0. The patient has a total of 0.0 chronic condition(s), including: none. Possible cause(s) related to this visit: no specific causes reported.



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,LABEL_2 (0.53),LABEL_2,4.69,"[CLS] a 67 . 0 - year - old female arrived by ambulance at the ed with an non - injury visit . the patient reported the following primary complaint ( s ) : con ##vu ##ls ##ions . recorded vital signs include temperature blank , pulse 82 . 0 , blood pressure 141 . 0 / 67 . 0 , respiratory rate 16 . 0 , and o ##₂ sat ##uration 100 . 0 % . pain scale was noted as 0 . 0 . the patient has a total of 0 . 0 chronic condition ( s ) , including : none . possible cause ( s ) related to this visit : no specific causes reported . [SEP]"
,,,,



=== Triage Level 1 ===
Sample text:
A 65.0-year-old male arrived by ambulance at the ED with an non-injury visit. The patient reported the following primary complaint(s): Other disease of circulatory system. Recorded vital signs include temperature 98.4, pulse 93.0, blood pressure 87.0/53.0, respiratory rate 18.0, and O₂ saturation 100.0%. Pain scale was noted as 0.0. The patient has a total of 0.0 chronic condition(s), including: none. Possible cause(s) related to this visit: no specific causes reported.



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,LABEL_1 (0.43),LABEL_1,4.55,"[CLS] a 65 . 0 - year - old male arrived by ambulance at the ed with an non - injury visit . the patient reported the following primary complaint ( s ) : other disease of c ##ir ##cula ##tory system . recorded vital signs include temperature 98 . 4 , pulse 93 . 0 , blood pressure 87 . 0 / 53 . 0 , respiratory rate 18 . 0 , and o ##₂ sat ##uration 100 . 0 % . pain scale was noted as 0 . 0 . the patient has a total of 0 . 0 chronic condition ( s ) , including : none . possible cause ( s ) related to this visit : no specific causes reported . [SEP]"
,,,,



=== Triage Level 2 ===
Sample text:
A 73.0-year-old female arrived by private vehicle at the ED with an non-injury visit. The patient reported the following primary complaint(s): Back pain, ache, soreness, discomfort, Postoperative visit. Recorded vital signs include temperature 99.0, pulse 71.0, blood pressure 134.0/76.0, respiratory rate 18.0, and O₂ saturation 99.0%. Pain scale was noted as 9.0. The patient has a total of 3.0 chronic condition(s), including: none. Possible cause(s) related to this visit: no specific causes reported.



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
3.0,LABEL_2 (0.43),LABEL_3,-2.54,"[CLS] a 73 . 0 - year - old female arrived by private vehicle at the ed with an non - injury visit . the patient reported the following primary complaint ( s ) : back pain , ache , sore ##ness , discomfort , post ##oper ##ative visit . recorded vital signs include temperature 99 . 0 , pulse 71 . 0 , blood pressure 134 . 0 / 76 . 0 , respiratory rate 18 . 0 , and o ##₂ sat ##uration 99 . 0 % . pain scale was noted as 9 . 0 . the patient has a total of 3 . 0 chronic condition ( s ) , including : none . possible cause ( s ) related to this visit : no specific causes reported . [SEP]"
,,,,



=== Triage Level 3 ===
Sample text:
A 11.0-year-old male arrived by private vehicle at the ED with an non-injury visit. The patient reported the following primary complaint(s): Earache, pain. Recorded vital signs include temperature 98.7, pulse 68.0, blood pressure 118.0/68.0, respiratory rate 21.0, and O₂ saturation 99.0%. Pain scale was noted as 0.0. The patient has a total of 0.0 chronic condition(s), including: none. Possible cause(s) related to this visit: no specific causes reported.



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
3.0,LABEL_3 (0.66),LABEL_3,6.1,"[CLS] a 11 . 0 - year - old male arrived by private vehicle at the ed with an non - injury visit . the patient reported the following primary complaint ( s ) : ear ##ache , pain . recorded vital signs include temperature 98 . 7 , pulse 68 . 0 , blood pressure 118 . 0 / 68 . 0 , respiratory rate 21 . 0 , and o ##₂ sat ##uration 99 . 0 % . pain scale was noted as 0 . 0 . the patient has a total of 0 . 0 chronic condition ( s ) , including : none . possible cause ( s ) related to this visit : no specific causes reported . [SEP]"
,,,,



=== Triage Level 4 ===
Sample text:
A 26.0-year-old female arrived by private vehicle at the ED with an non-injury visit. The patient reported the following primary complaint(s): Earache, pain. Recorded vital signs include temperature 97.9, pulse 95.0, blood pressure 156.0/75.0, respiratory rate 18.0, and O₂ saturation 99.0%. Pain scale was noted as 6.0. The patient has a total of 0.0 chronic condition(s), including: none. Possible cause(s) related to this visit: no specific causes reported.



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
3.0,LABEL_3 (0.60),LABEL_3,3.79,"[CLS] a 26 . 0 - year - old female arrived by private vehicle at the ed with an non - injury visit . the patient reported the following primary complaint ( s ) : ear ##ache , pain . recorded vital signs include temperature 97 . 9 , pulse 95 . 0 , blood pressure 156 . 0 / 75 . 0 , respiratory rate 18 . 0 , and o ##₂ sat ##uration 99 . 0 % . pain scale was noted as 6 . 0 . the patient has a total of 0 . 0 chronic condition ( s ) , including : none . possible cause ( s ) related to this visit : no specific causes reported . [SEP]"
,,,,
