In [1]:
import pandas as pd
import torch
import string
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

os.environ["WANDB_DISABLED"] = "true"  # Disable Weights & Biases logging

In [2]:
# Load dataset
df = pd.read_csv("MLHC_train_classification_2.csv")

# Text cleaning function
def clean_text(text):
    text = text.lower().translate(str.maketrans("", "", string.punctuation))  # Lowercase & remove punctuation
    text = text.replace(" ", "")  # Remove spaces
    return text

# Mapping triage levels to numbers
triage_mapping = {
    "immediate": 1,
    "emergent": 2,
    "urgent": 3,
    "semiurgent": 4,
    "nonurgent": 5
}

df["triage_value"] = df["triage_level"].astype(str).apply(clean_text).map(triage_mapping)
df.dropna(inplace=True)

In [3]:

# Split into train and validation sets
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["text_data"].tolist(), df["triage_value"].tolist(), test_size=0.2, random_state=42
)

In [4]:
# Load ClinicalBERT tokenizer
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Tokenization function
def tokenize_data(texts):
    return tokenizer(texts, padding="max_length", truncation=True, max_length=512)


In [5]:
# Tokenizing train and validation texts
train_encodings = tokenize_data(train_texts)
val_encodings = tokenize_data(val_texts)

In [6]:
# Define a PyTorch Dataset class
class TriageDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])  # Ensure labels are tensors
        return item

In [7]:
# Convert to PyTorch dataset format
train_dataset = TriageDataset(train_encodings, train_labels)
val_dataset = TriageDataset(val_encodings, val_labels)

In [9]:
# Set device (MPS for Apple Silicon, CUDA for Nvidia GPUs, or CPU)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Load ClinicalBERT model for classification
model = BertForSequenceClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=5)
model.to(device)  # Move model to the correct device

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [None]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="./clinicalbert_triage",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
)



In [11]:
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # Ensure this is a PyTorch Dataset
    tokenizer=tokenizer  # Add tokenizer for logging purposes
)

# Train the model
trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,0.9993,0.879425
2,0.8015,0.875104
3,0.7074,0.877757


TrainOutput(global_step=2448, training_loss=0.8639618810874964, metrics={'train_runtime': 1556.5305, 'train_samples_per_second': 12.572, 'train_steps_per_second': 1.573, 'total_flos': 5148958929878016.0, 'train_loss': 0.8639618810874964, 'epoch': 3.0})

In [12]:
# Evaluate on validation set
results = trainer.evaluate()
print(results)

{'eval_loss': 0.8777569532394409, 'eval_runtime': 34.2217, 'eval_samples_per_second': 47.66, 'eval_steps_per_second': 5.961, 'epoch': 3.0}


In [13]:
# Function to predict triage level for new patient cases
def predict_triage(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
    inputs = {key: val.to(device) for key, val in inputs.items()}  # Move input tensors to the correct device
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits).item()
    return predicted_class + 1  # Convert 0-4 back to triage level 1-5

In [14]:
from sklearn.metrics import accuracy_score

# Get model predictions on validation dataset
preds_output = trainer.predict(val_dataset)

# Convert logits to class predictions
preds = np.argmax(preds_output.predictions, axis=1)

# Compute accuracy
accuracy = accuracy_score(val_labels, preds)
print(f"Model Accuracy: {accuracy:.4f}")

Model Accuracy: 0.5868


In [15]:
import shutil

# Define the model save path
model_save_path = "./clinicalbert_triage_model"

# Save the fine-tuned model and tokenizer
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

# Save training arguments
import json
training_args_dict = training_args.to_dict()
with open(f"{model_save_path}/training_args.json", "w") as f:
    json.dump(training_args_dict, f)

print(f"Model and tokenizer saved to {model_save_path}")

# Zip the model folder for easy transfer
shutil.make_archive("clinicalbert_triage_model", 'zip', model_save_path)
print("Model folder zipped for transfer.")

Model and tokenizer saved to ./clinicalbert_triage_model
Model folder zipped for transfer.
