In [27]:
import pandas as pd
import torch
import string
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, BertTokenizer

os.environ["WANDB_DISABLED"] = "true"

In [28]:
df = pd.read_csv("MLHC_train_classification_2.csv")

triage_mapping = {
    "Immediate.": 0,
    "Emergent.": 1,
    "Urgent.": 2,
    "Semi-urgent.": 3,
    "Non-urgent.": 4
}

df["triage_value"] = df["triage_level"].map(triage_mapping)
df.dropna(inplace=True)

In [29]:
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["text_data"].tolist(),
    df["triage_value"].tolist(),
    test_size=0.2,
    random_state=42,
    stratify=df["triage_value"].tolist()
)


In [70]:
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
max_length = df["text_data"].apply(len).max()
def tokenize_data(texts):
    return tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)

In [71]:
train_encodings = tokenize_data(train_texts)
val_encodings = tokenize_data(val_texts)

In [72]:
class TriageDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

In [73]:
train_dataset = TriageDataset(train_encodings, train_labels)
val_dataset = TriageDataset(val_encodings, val_labels)

In [74]:
# Define compute_metrics function
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Get predicted class by taking the argmax over the logits
    predictions = np.argmax(logits, axis=1)
    accuracy = accuracy_score(labels, predictions)
    return {"accuracy": accuracy}

In [75]:
config = AutoConfig.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    num_labels=5,
    problem_type="single_label_classification",
    hidden_dropout_prob=0.3,
    attention_probs_dropout_prob=0.3,

)

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

model = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", config=config)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [76]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.3, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.3, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [77]:
training_args = TrainingArguments(
    output_dir="./clinicalbert_triage",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size= 32,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
)



In [78]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.0053,0.943734,0.571699
2,0.8723,0.944764,0.585597
3,0.9419,0.903163,0.576121
4,0.9623,0.903913,0.590651
5,0.8476,0.908695,0.581807
6,0.884,0.910796,0.608339
7,0.9131,0.908022,0.585597
8,0.8048,0.915876,0.596336
9,0.8565,0.92271,0.603917
10,0.8549,0.922854,0.596968


TrainOutput(global_step=1980, training_loss=0.8897822404148603, metrics={'train_runtime': 863.9632, 'train_samples_per_second': 73.255, 'train_steps_per_second': 2.292, 'total_flos': 1.092836537398368e+16, 'train_loss': 0.8897822404148603, 'epoch': 10.0})

In [79]:
# Evaluate on validation set
results = trainer.evaluate()
print(results)

{'eval_loss': 0.9228537082672119, 'eval_accuracy': 0.5969677826910929, 'eval_runtime': 6.2999, 'eval_samples_per_second': 251.272, 'eval_steps_per_second': 7.937, 'epoch': 10.0}


In [81]:
from sklearn.metrics import accuracy_score

# Get model predictions on validation dataset
preds_output = trainer.predict(val_dataset)

# Convert logits to class predictions
preds = np.argmax(preds_output.predictions, axis=1)

# Compute accuracy
accuracy = accuracy_score(val_labels, preds)
print(f"Model Accuracy: {accuracy:.4f}")

Model Accuracy: 0.5970


In [82]:
import shutil

# Define the model save path
model_save_path = "./clinicalbert_triage_model"

# Save the fine-tuned model and tokenizer
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

# Save training arguments
import json
training_args_dict = training_args.to_dict()
with open(f"{model_save_path}/training_args.json", "w") as f:
    json.dump(training_args_dict, f)

print(f"Model and tokenizer saved to {model_save_path}")

# Zip the model folder for easy transfer
shutil.make_archive("clinicalbert_triage_model", 'zip', model_save_path)
print("Model folder zipped for transfer.")

Model and tokenizer saved to ./clinicalbert_triage_model
Model folder zipped for transfer.
