# Introduction

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In this notebook, I attempt to produce a classifier that maps medical transcripts to their associated medical specialty. Since each transcript has exactly one specialty, this is a multi-class (single-label) text sequence classification problem. I outline my strategy below:

1. Inspect and understand data
2. Clean data
3. Pre-process data (ex: remove special characters, normalization)
4. Produce train, validation and test splits
5. Choose a tokenizer and a model
6. Tokenize the data
7. Instantiate model and training loop
8. Train
9. Evaluate training on validation and test sets.
10. Lock or re-adjust training strategy based on results

In [None]:
# pip install all required packages
!pip install torch
!pip install evaluate
!pip install datasets
!pip install transformers
!pip install accelerate -U

**1. Inspect Data**

In [None]:
# Imports
import re
import torch
import evaluate
import numpy as np
import pandas as pd

from tqdm.auto import tqdm
from functools import partial
from collections import Counter
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from nltk.stem import WordNetLemmatizer
from transformers import DataCollatorWithPadding
from datasets import load_dataset, ClassLabel, DatasetDict
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from torch.optim import AdamW  # transformers.AdamW is deprecated; use PyTorch's implementation
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_scheduler, TrainingArguments, Trainer

# Use GPU if available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# LOAD DATA as a csv file
data = load_dataset("csv", data_files="/content/drive/MyDrive/Colab Notebooks/mtsamples.csv")

# We inspect our data by looking at a few fields of interest. Namely 'medical_specialty' and 'transcription.'
# 1. How many unique labels exist?
set(data['train']['medical_specialty'])

# A quick glance determines that there are around 40 unique labels.
# Our first focus is to determine what specialties are relevant for our classification problem.

In [74]:
Counter(data['train']['medical_specialty'])

Counter({' Allergy / Immunology': 7,
         ' Bariatrics': 18,
         ' Cardiovascular / Pulmonary': 372,
         ' Neurology': 223,
         ' Dentistry': 27,
         ' Urology': 158,
         ' General Medicine': 259,
         ' Surgery': 1103,
         ' Speech - Language': 9,
         ' SOAP / Chart / Progress Notes': 166,
         ' Sleep Medicine': 20,
         ' Rheumatology': 10,
         ' Radiology': 273,
         ' Psychiatry / Psychology': 53,
         ' Podiatry': 47,
         ' Physical Medicine - Rehab': 21,
         ' Pediatrics - Neonatal': 70,
         ' Pain Management': 62,
         ' Orthopedic': 355,
         ' Ophthalmology': 83,
         ' Office Notes': 51,
         ' Obstetrics / Gynecology': 160,
         ' Neurosurgery': 94,
         ' Nephrology': 81,
         ' Letters': 23,
         ' Lab Medicine - Pathology': 8,
         ' IME-QME-Work Comp etc.': 16,
         ' Hospice - Palliative Care': 6,
         ' Hematology - Oncology': 90,
         ' Gastr

**2. Clean data**

In [53]:
# From looking at the labels, I determine there are a handful that are not medical specialties, but fall
# into the category of 'notes'. We remove these labels:
notes_labels = [" Consult - History and Phy.",
         " SOAP / Chart / Progress Notes",
         " Discharge Summary",
         " Emergency Room Reports",
         " Office Notes",
         " Letters",
         " IME-QME-Work Comp etc."]

# I also exclude 'General Medicine': strictly speaking it is not a specialty, since
# many of the other specialties could be considered subsets of it.

# Additionally, I'd like to get a sense of the frequency of each specialty in the dataset.
# Ideally we do NOT want to include specialties with very few entries, since
# there simply won't be enough data to split between training, validation and test sets.
# I choose 75 as the cutoff point, though this is somewhat arbitrary.
labels = data['train'][:]['medical_specialty']
label_freq = Counter(labels)
fewer_than_75 = []

# Add specialties with fewer than 75 entries
for key, value in label_freq.items():
  if value < 75:
    fewer_than_75.append(key)

# Filter data according to a few properties:
# a. Filter empty entries (checked first, so the later membership tests never see None)
# b. Filter out any specialties that belong to the notes_labels
# c. Filter any specialties with fewer than 75 entries
# d. Filter 'General Medicine' (exact match; `in " General Medicine"` would be a substring test)
filtered_data = data.filter(lambda x: x["medical_specialty"] is not None
                            and x['transcription'] is not None
                            and x["medical_specialty"] not in notes_labels
                            and x["medical_specialty"] not in fewer_than_75
                            and x["medical_specialty"] != " General Medicine")

# Extra filter step to remove any transcriptions with 10 or fewer words.
filtered_data = filtered_data.filter(lambda x: len(x["transcription"].split()) > 10)
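The filtering rules in the cell above can be sketched in plain Python, independent of the `datasets` library. `rare_labels` and `keep_example` are hypothetical helper names used only for illustration. The sketch checks for `None` before any membership test (a substring test like `None in " General Medicine"` raises `TypeError`) and compares labels by exact membership rather than substring containment.

```python
from collections import Counter

def rare_labels(labels, min_count=75):
    """Return the set of labels that occur fewer than min_count times."""
    counts = Counter(labels)
    return {label for label, n in counts.items() if n < min_count}

def keep_example(example, excluded, min_words=10):
    """Return True if an example should be kept for training."""
    specialty = example.get("medical_specialty")
    text = example.get("transcription")
    # Check for missing values first so the later tests never see None.
    if specialty is None or text is None:
        return False
    # Exact membership test against the set of excluded labels.
    if specialty in excluded:
        return False
    # Keep only transcriptions longer than min_words words.
    return len(text.split()) > min_words

toy_labels = [" Surgery"] * 80 + [" Dentistry"] * 27 + [" General Medicine"] * 90
excluded = rare_labels(toy_labels) | {" General Medicine", " Letters"}
print(sorted(excluded))
```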

**3. Pre-process**

In [54]:
# Now that the dataset is pruned, we still need to pre-process the transcripts before tokenizing.
# We define a function and apply it with Dataset.map, which runs it over every example and
# caches the result (it can also run in parallel by passing num_proc).

def sanitize(example):
    # Modify the column to remove special characters from the transcriptions
    # Generally, special characters can make it more difficult for a classifier to
    # abstract features.
    example["transcription"] = re.sub(r'[^a-zA-Z0-9]', ' ', example["transcription"])
    return example

# Sanitize the data of special chars
filtered_data_sanitized = filtered_data.map(sanitize)
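A quick way to see what `sanitize` does is to run the same regex on a made-up sentence. Every character outside `a-zA-Z0-9` becomes a space, so punctuation and slashes split tokens and runs of spaces appear. The optional collapse step (a hypothetical helper, not part of the notebook) only makes the result easier to read, since a WordPiece tokenizer splits on whitespace anyway. Note that the character class is ASCII-only, so accented letters are replaced too.

```python
import re

def sanitize_text(text):
    """Replace every character that is not an ASCII letter or digit with a space."""
    return re.sub(r"[^a-zA-Z0-9]", " ", text)

def sanitize_and_collapse(text):
    """Sanitize, then collapse runs of whitespace into single spaces."""
    return " ".join(sanitize_text(text).split())

sample = "BP: 120/80, HR 72. Mitral-valve prolapse."
print(repr(sanitize_text(sample)))
print(sanitize_and_collapse(sample))  # BP 120 80 HR 72 Mitral valve prolapse
```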

In [55]:
# Generate Labels for the data
# Here we need to assign a class label (a number) to each specialty.
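# Obtain unique class labels (sorted, so the label ids are the same in every session)
unique_labels = sorted(set(filtered_data_sanitized['train']['medical_specialty']))
class_labels = ClassLabel(names=unique_labels)

# Map original column to class labels
filtered_data_withlabel = filtered_data_sanitized.map(lambda x: {"labels": class_labels.str2int(x['medical_specialty'])})
# Cast the integer ids to the ClassLabel feature (required for stratified splitting).
# class_encode_column would re-sort the ids as strings ("0", "1", "10", ...), breaking the mapping to class_labels.
filtered_data_withlabel = filtered_data_withlabel.cast_column("labels", class_labels)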
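Two details in this step affect reproducibility. First, the iteration order of a Python `set` of strings can change between interpreter sessions (string hashing is randomized), so `list(set(...))` may assign different ids to the same specialty after a restart. That matters when reloading a saved checkpoint or hand-picking per-class weights; sorting the names fixes the order. Second, as far as I can tell from the `datasets` implementation, `class_encode_column` converts a non-string column to strings and sorts the distinct values, which orders integer ids lexicographically. The plain-Python sketch below illustrates both points; `build_label_map` is a hypothetical helper.

```python
def build_label_map(specialties):
    """Map each unique specialty to an integer id, in sorted (deterministic) order."""
    names = sorted(set(specialties))
    return {name: i for i, name in enumerate(names)}

label_map = build_label_map([" Urology", " Surgery", " Neurology", " Surgery"])
print(label_map)  # {' Neurology': 0, ' Surgery': 1, ' Urology': 2}

# Sorting integer ids as strings orders them lexicographically,
# so "10", "11" and "12" land between "1" and "2".
print(sorted(str(i) for i in range(13)))
```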

**4. Produce train, validation and test splits**

In [56]:
# We use the train_test_split functionality to create our splits.
# Note: all splits are stratified by label so that each class is represented
# proportionally in every split. This prevents, for example, a rare class from ending up
# exclusively in one split.

# We take about 70% of the data for training
train_dataset = filtered_data_withlabel["train"].train_test_split(train_size=0.7, stratify_by_column='labels', seed=40)

# We split the remaining 30% between validation and test
val_dataset = train_dataset["test"].train_test_split(test_size=0.5, stratify_by_column='labels', seed=40)
val_dataset["validation"] = val_dataset.pop("train")

# Combine datasets into one 'processed dataset'
combined_processed_data = DatasetDict({'train': train_dataset['train'],
                                       'validation': val_dataset['validation'],
                                       'test': val_dataset['test']})
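The idea behind `stratify_by_column` can be sketched in plain Python: group example indices by label, shuffle each group, and cut every group at the same fraction. This is a simplified illustration with a hypothetical `stratified_split` helper, not the `datasets` library's own implementation; because each class is rounded separately, the overall split is only approximately 70/30.

```python
import random
from collections import defaultdict

def stratified_split(labels, train_frac=0.7, seed=40):
    """Split example indices so each label keeps (roughly) the same proportion in both parts."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    train_idx, rest_idx = [], []
    for indices in by_label.values():
        rng.shuffle(indices)
        cut = round(len(indices) * train_frac)  # same fraction for every class
        train_idx.extend(indices[:cut])
        rest_idx.extend(indices[cut:])
    return sorted(train_idx), sorted(rest_idx)

toy_labels = [0] * 100 + [1] * 20
train_idx, rest_idx = stratified_split(toy_labels)
print(len(train_idx), len(rest_idx))  # 84 36
```

Applying the same function to `rest_idx` with `train_frac=0.5` would mirror the second validation/test split.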



**5. Tokenize the data**

In [57]:
# For this problem, I chose Bio_ClinicalBERT as the pre-trained model to fine-tune.
# Bio_ClinicalBERT is a good candidate because it has already been trained on a corpus
# of clinical text. As such, it may be well suited
# for fine-tuning on medical transcripts as a downstream task.

# Set the model checkpoint and load the tokenizer (the model itself is loaded in step 6).
checkpoint = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Define a tokenization function.
# As before, we use map to apply the tokenizer to every transcription.
def tokenize_function(example):
    # Truncate to 512 tokens, the maximum sequence length BERT models accept
    return tokenizer(example['transcription'], truncation=True, max_length=512)

In [58]:
# Tokenize the dataset
tokenized_dataset = combined_processed_data.map(tokenize_function, batched=True)
# Remove columns that the trainer does not need. We only wish to have the tokenized transcriptions, class labels, and attention masks
tokenized_dataset = tokenized_dataset.remove_columns(["Unnamed: 0", "description", "medical_specialty", "sample_name", "transcription", "keywords"])
# Set format to torch in preparation for training
tokenized_dataset.set_format("torch")

# Here we define a DataCollator, which dynamically pads our input tensors.
# Since the text sequences have different lengths, every sequence in a batch must be padded to the
# length of the longest sequence in that batch before it is passed to the model. We could instead pad
# globally, to the longest sequence in the whole dataset, but that sequence may be much longer than most others.
# Most sequences would then be heavily extended with pad tokens, which carry no information
# (they are masked out by the attention mask) but still cost computation, slowing down training.
# Padding batch-wise keeps the speed benefit of processing batches as tensors while
# avoiding most of the padding.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# set batch_size
BATCH_SIZE = 16

# Create DataLoader objects that batch our data for the training and evaluation loops.
# Only the training data needs shuffling; evaluation order does not affect the metrics.
train_dataloader = DataLoader(tokenized_dataset["train"], shuffle=True, batch_size=BATCH_SIZE, collate_fn=data_collator)
eval_dataloader = DataLoader(tokenized_dataset["validation"], shuffle=False, batch_size=BATCH_SIZE, collate_fn=data_collator)
test_dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=BATCH_SIZE, collate_fn=data_collator)

Map: 2283 / 489 / 490 examples (train / validation / test)
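A toy illustration of what per-batch (dynamic) padding does, and of how much padding it saves compared with padding everything to the 512-token maximum. `pad_batch` and `count_padding` are hypothetical helpers operating on plain lists of token ids; `pad_id=0` is an assumption (BERT-style vocabularies typically use 0 for `[PAD]`, while `DataCollatorWithPadding` reads the real id from the tokenizer).

```python
def pad_batch(sequences, pad_id=0):
    """Pad every sequence in a batch to the length of the longest one in that batch."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    # 1 marks a real token, 0 marks padding (ignored by attention).
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask

def count_padding(batches, global_max=None):
    """Count pad tokens when padding per batch (global_max=None) or to a fixed global length."""
    total = 0
    for batch in batches:
        target = global_max if global_max is not None else max(len(s) for s in batch)
        total += sum(target - len(s) for s in batch)
    return total

batches = [[[5] * 10, [5] * 12], [[5] * 300, [5] * 512]]
print(count_padding(batches))                  # per-batch padding
print(count_padding(batches, global_max=512))  # pad everything to 512
```

With one batch of short sequences and one of long sequences, per-batch padding adds 214 pad tokens versus 1214 when every sequence is padded to 512.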

**6. Instantiate model and training loop**

In [None]:
# Instantiate the model, with one output per unique label.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=class_labels.num_classes)
model.to(device)
# Instantiate the optimizer with a learning rate of 5e-5. AdamW (Adam with decoupled weight decay)
# is a common choice for fine-tuning BERT-style models.
optimizer = AdamW(model.parameters(), lr=5e-5)

# Set epoch number
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)

# Set the lr scheduler. The "linear" schedule decays the learning rate linearly from its initial
# value to 0 over num_training_steps (with no warmup here); it does not depend on the loss or its gradients.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
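To the best of my understanding, the `"linear"` schedule in `transformers` has the shape sketched below: an optional linear warmup, then a linear decay that reaches 0 at `num_training_steps` and stays there. The plain-Python reimplementation (`linear_lr_multiplier` is a hypothetical name) makes one consequence visible: once `num_training_steps` steps have run, the learning rate is 0, so reusing the same scheduler for a second training run would not update the weights.

```python
def linear_lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear warmup to 1, then linear decay to 0 (clamped at 0)."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

base_lr = 5e-5
total_steps = 429  # e.g. 3 epochs x 143 batches (2283 training examples / batch size 16, rounded up)
for step in (0, total_steps // 2, total_steps, total_steps + 50):
    print(step, base_lr * linear_lr_multiplier(step, 0, total_steps))
```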


**7. Train**

In [60]:
def compute_f1_score(model, dataloader):

  '''
  Computes macro (class average) F1 score

  Args:
  model: torch model we instantiated with checkpoint
  dataloader: The evaluation dataloader, which contains our validation set

  Returns:
  f1 score of type (float)
  '''

  # Keeps track of predictions and ground truths
  predictions = []
  ground_truth = []

  # loads each batch from the validation set
  for batch in dataloader:
      batch = {k: v.to(device) for k, v in batch.items()}
      # computes model output
      with torch.no_grad():
          outputs = model(**batch)
      # consolidates prediction
      logits = outputs.logits
      prediction = torch.argmax(logits, dim=-1)

      predictions.append(prediction)
      ground_truth.append(batch['labels'])
  # Flatten the per-batch tensors into a single tensor each, then convert to lists
  flattened_predictions = torch.cat(predictions, dim=0).tolist()
  flattened_ground_truth = torch.cat(ground_truth, dim=0).tolist()

  # Pass lists to f1 score metric to compute.
  precision, recall, f1, support = precision_recall_fscore_support(flattened_ground_truth, flattened_predictions, average="macro")
  return f1
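`compute_f1_score` selects checkpoints by macro F1, while the evaluation cells further down report a support-weighted F1; on imbalanced data the two can differ a lot. The plain-Python sketch below computes both (mirroring, as I understand it, scikit-learn's `average="macro"` and `average="weighted"`, with undefined precision or recall treated as 0). The function names are hypothetical.

```python
from collections import Counter

def per_class_f1(y_true, y_pred):
    """Return {label: F1} for every label that appears in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = {}
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores[label] = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return scores

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1: every class counts equally."""
    scores = per_class_f1(y_true, y_pred)
    return sum(scores.values()) / len(scores)

def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights equal to each class's support in y_true."""
    scores = per_class_f1(y_true, y_pred)
    support = Counter(y_true)
    return sum(scores[c] * support[c] for c in scores) / len(y_true)

# A classifier that always predicts the majority class 0.
y_true = [0] * 8 + [1, 2]
y_pred = [0] * 10
print(round(macro_f1(y_true, y_pred), 3), round(weighted_f1(y_true, y_pred), 3))  # 0.296 0.711
```

A majority-class predictor scores about 0.71 weighted F1 but only about 0.30 macro F1 here, so the two averages should not be compared directly.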


In [None]:
# Keep track of loss
losses = []
validation_f1_scores = []

count = 0

# Track f1
best_f1 = 0.0
current_f1 = 0.0
eval_interval = 50

model.train()
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    for batch in train_dataloader:
        count += 1

        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        losses.append(loss.detach())  # detach so the stored losses don't keep the computation graph alive
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        # Evaluate f1 periodically
        if count % eval_interval == 0:  # Evaluate at a fixed interval
            model.eval()
            with torch.no_grad():
                validation_f1 = compute_f1_score(model, eval_dataloader)
                validation_f1_scores.append(validation_f1)
                print(f"Validation F1: {validation_f1:.4f}")

                # Save the model if the validation F1 improves
                if validation_f1 > best_f1:
                    best_f1 = validation_f1
                    torch.save(model.state_dict(), "/content/drive/MyDrive/Colab Notebooks/best_model.pth")

            model.train()

In [None]:
# Inspect losses and f1 scores:
losses_float = [tensor.item() for tensor in losses]
plt.plot(losses_float)

print(validation_f1_scores)

**8. Evaluate training on validation and test sets.**

In [76]:
# Start by loading the model with the best macro F1 score:
model.load_state_dict(torch.load("/content/drive/MyDrive/Colab Notebooks/best_model.pth"))
model.eval()
# Compute model predictions for the validation dataset
predictions = []
ground_truth = []

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    prediction = torch.argmax(logits, dim=-1)

    predictions.append(prediction)
    ground_truth.append(batch['labels'])

# Flatten the per-batch tensors into a single tensor each, then convert to lists
flattened_predictions = torch.cat(predictions, dim=0).tolist()
flattened_ground_truth = torch.cat(ground_truth, dim=0).tolist()

In [70]:
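# Compute support-weighted precision, recall and F1 (pass average=None for per-class values).
# Note: checkpoints were selected by macro F1, so this weighted score is not directly comparable.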
precision, recall, f1, support = precision_recall_fscore_support(flattened_ground_truth, flattened_predictions, average="weighted")

In [73]:
f1

0.37228258977225015

**10. Lock or re-adjust**

From our initial experiment, we reach an aggregate F1 score of 0.465. While this is not a high F1 score, class imbalance is a likely reason the classifier struggles to predict accurately. As an additional experiment, we run the same training but **upweight classes that are sparse**.



**Adjustment 1: Upweight classes**

In [None]:
# Set custom class weights, one per label id: 1.5 for the sparser classes, 1 otherwise

class_weights = [1,1,1.5,1,1.5,1.5,1,1,1.5,1,1,1.5,1]
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
class_weights

tensor([1.0000, 1.0000, 1.5000, 1.0000, 1.5000, 1.5000, 1.0000, 1.0000, 1.5000,
        1.0000, 1.0000, 1.5000, 1.0000], device='cuda:0')

In [None]:
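# RE-RUN the training loop with weighted classes.
# First re-instantiate the model, optimizer and scheduler: the linear schedule from the first run
# has already decayed the learning rate to 0, so reusing it would leave the weights unchanged.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=class_labels.num_classes)
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

# Keep track of loss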
losses = []
validation_f1_scores = []

count = 0

# Track f1
best_f1 = 0.0
current_f1 = 0.0
eval_interval = 50

model.train()
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    for batch in train_dataloader:
        count += 1

        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
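        # outputs.loss is already a batch mean, so it cannot be re-weighted per example.
        # Recompute a class-weighted cross-entropy from the logits instead
        # (weighted mean: sum(w_y * loss) / sum(w_y)).
        loss_weighted = torch.nn.functional.cross_entropy(outputs.logits, batch['labels'], weight=class_weights)
        losses.append(loss_weighted.detach())
        loss_weighted.backward()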

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        # Evaluate f1 periodically
        if count % eval_interval == 0:  # Evaluate at a fixed interval
            model.eval()
            with torch.no_grad():
                validation_f1 = compute_f1_score(model, eval_dataloader)
                validation_f1_scores.append(validation_f1)
                print(f"Validation F1: {validation_f1:.4f}")

                # Save the model if the validation F1 improves
                if validation_f1 > best_f1:
                    best_f1 = validation_f1
                    torch.save(model.state_dict(), "/content/drive/MyDrive/Colab Notebooks/best_model.pth")

            model.train()
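A class-weighted cross-entropy weights each example's loss before averaging and normalizes by the sum of the weights (to my understanding, this is what `torch.nn.functional.cross_entropy(..., weight=...)` computes with the default mean reduction). Scaling an already-averaged batch loss by the batch's mean weight is not equivalent: every example still contributes equally relative to the others. The plain-Python sketch below compares the two on a toy batch; all function names and values are hypothetical.

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy of one example: -log softmax(logits)[label], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def weighted_mean_ce(batch_logits, labels, class_weights):
    """Class-weighted mean: sum(w_y * loss) / sum(w_y)."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    weights = [class_weights[y] for y in labels]
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)

def scaled_mean_ce(batch_logits, labels, class_weights):
    """Unweighted mean loss multiplied by the batch's mean weight."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    mean_weight = sum(class_weights[y] for y in labels) / len(labels)
    return sum(losses) / len(losses) * mean_weight

toy_logits = [[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]]
toy_labels = [0, 0, 1]  # the last two examples are misclassified
toy_weights = [1.0, 1.5]
print(weighted_mean_ce(toy_logits, toy_labels, toy_weights), scaled_mean_ce(toy_logits, toy_labels, toy_weights))
```

With equal weights the two coincide; with unequal weights they diverge, because only the first one shifts emphasis toward examples of the upweighted class.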

In [None]:
# Inspect f1 scores
validation_f1_scores

In [None]:
# RE EVALUATE predictions

# Load the model with the best macro F1 score:
model.load_state_dict(torch.load("/content/drive/MyDrive/Colab Notebooks/best_model.pth"))
model.eval()
# Compute model predictions for the validation dataset
predictions = []
ground_truth = []

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    prediction = torch.argmax(logits, dim=-1)

    predictions.append(prediction)
    ground_truth.append(batch['labels'])

# Flatten the per-batch tensors into a single tensor each, then convert to lists
flattened_predictions = torch.cat(predictions, dim=0).tolist()
flattened_ground_truth = torch.cat(ground_truth, dim=0).tolist()

In [None]:
# Compute support-weighted precision, recall and F1 (pass average=None for per-class values).
precision, recall, f1, support = precision_recall_fscore_support(flattened_ground_truth, flattened_predictions, average='weighted')

**Determination and further revision**

While custom weighting improved the F1 score for underrepresented classes, it came at the cost of losing performance on some of the majority classes.

I suspect there is a set of weights that would give a better F1 score, but I decided to try one more alternative that does not involve further tuning of the class weights. Inspecting the dataset, I noticed that each row has a 'keywords' field.

In [None]:
filtered_data_sanitized['train']['keywords'][0]

'cardiovascular / pulmonary, 2-d m-mode, doppler, aortic valve, atrial enlargement, diastolic function, ejection fraction, mitral, mitral valve, pericardial effusion, pulmonary valve, regurgitation, systolic function, tricuspid, tricuspid valve, normal lv '

One idea is to condense each transcription into a set of keywords, which may make training the classifier easier and faster. In practice, however, new transcripts may not come with a list of keywords as they do in this dataset. So instead of using the keywords column directly, I build my own medical vocabulary from it and use it to identify and retain only the vocabulary words in each transcript:

**Adjustment 2: Building custom medical dictionary to condense transcriptions**

In [77]:
# 1. Remove rows whose keywords field is empty, or whose keywords string is 5 characters or shorter
# (len() here counts characters, not keywords)
filtered_data_kwds = filtered_data_withlabel.filter(lambda x: x['keywords'] is not None)
filtered_data_kwds = filtered_data_kwds.filter(lambda x: len(x['keywords']) > 5)

# Sanitize the keywords of non-alphanumeric characters:
# Credit: https://stackoverflow.com/questions/68028334/replace-characters-other-than-a-za-z0-9-and-decimal-values-with-space-using-rege
def sanitize_keywords(example):
    # Replace special characters in the keywords with spaces, exactly as was done
    # for the transcriptions, so that words in both fields can match.
    example["keywords"] = re.sub(r'[^a-zA-Z0-9]', ' ', example["keywords"])
    return example

filtered_data_sanitized_kwds = filtered_data_kwds.map(sanitize_keywords)

all_keywords = [w.split() for w in filtered_data_sanitized_kwds['train']['keywords'] if w is not None]
all_keywords_flattened = [item for sublist in all_keywords for item in sublist]

In [None]:
# Inspect the keywords
medical_dict = set(all_keywords_flattened)
medical_dict
# Note: this dictionary is far from perfect, since some medical terms contain special characters,
# so removing them splits those terms into pieces (ex: 'bi-polar' becomes the two entries 'bi' and 'polar').
# Since the transcriptions have undergone the same sanitization, however, I believe such words will still be caught.
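A toy check of the reasoning above: once keywords and transcripts pass through the same regex, a hyphenated term such as 'bi-polar' becomes the tokens 'bi' and 'polar' on both sides, so the pieces still match. Matching is case-sensitive, though, so a capitalized word at the start of a sentence is missed unless words are lower-cased first. This assumes the keyword strings are lower-case, as in the example keywords shown earlier; the keyword and transcript strings below are made up for illustration.

```python
import re

def sanitize_text(text):
    """Replace every non-alphanumeric character with a space (same regex as the notebook)."""
    return re.sub(r"[^a-zA-Z0-9]", " ", text)

keywords = "psychiatry / psychology, bi-polar disorder, depression"
vocabulary = set(sanitize_text(keywords).split())

transcript = sanitize_text("Bi-polar disorder; history of depression.")
exact = [w for w in transcript.split() if w in vocabulary]
lowered = [w.lower() for w in transcript.split() if w.lower() in vocabulary]
print(exact)    # ['polar', 'disorder', 'depression']
print(lowered)  # ['bi', 'polar', 'disorder', 'depression']
```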


In [10]:
# We now define a condense function that keeps only the words of a transcription that appear in the medical dictionary.
def condense(example, medical_dict):

  # Words already kept; a set gives fast lookups and ensures no duplicate words in the condensed transcript.
  appeared = set()
  # Words retained in their original order (joined into the condensed transcript below)
  condensed_words = []

  # isalpha ensures we also remove numbers from the transcripts, since they may not offer
  # any useful information for this classification problem.
  # Words are lower-cased before matching, assuming the keywords (and so medical_dict) are lower-case.
  for word in example['transcription'].split():
    word = word.lower()
    if word in medical_dict and word not in appeared and word.isalpha():
      condensed_words.append(word)
      appeared.add(word)

  example['transcription'] = " ".join(condensed_words)

  return example

mapping = partial(condense, medical_dict=medical_dict)
filtered_data_condensed_kwrds = filtered_data_sanitized_kwds.map(mapping)

Map: 2789 examples

In [None]:
# Inspect the data
filtered_data_condensed_kwrds['train']['transcription'][0]

# Great, we now have a condensed set of transcriptions!
# We run the exact same model, but now on the condensed transcripts instead.
# So as not to bloat this notebook, I produced these results by re-running some of the cells above with small modifications.
# Those cells aren't shown below, but the model results are in the pdf!