# Clinical Text Classification with PyHealth

Welcome to the PyHealth tutorial on clinical text classification. In this notebook, we will use PyHealth to fine-tune a clinical language model that predicts the medical specialty of a transcribed medical report.

## Environment Setup

To begin, we install a few packages that PyHealth depends on. Then we clone PyHealth from the `zhenbang/f-image_text_support` branch on GitHub and add it to the Python path.

In [None]:
!pip install mne pandarallel rdkit transformers



In [None]:
!rm -rf PyHealth
!git clone -b zhenbang/f-image_text_support https://github.com/sunlabuiuc/PyHealth.git

Cloning into 'PyHealth'...
remote: Enumerating objects: 7110, done.
remote: Counting objects: 100% (1208/1208), done.
remote: Compressing objects: 100% (471/471), done.
remote: Total 7110 (delta 864), reused 940 (delta 731), pack-reused 5902
Receiving objects: 100% (7110/7110), 104.52 MiB | 21.86 MiB/s, done.
Resolving deltas: 100% (4631/4631), done.


In [None]:
import sys


sys.path.append("./PyHealth")

## Download Data

Next, we download the clinical text dataset: medical transcription data scraped from mtsamples.com. It contains transcribed medical reports, each paired with its medical specialty. You can find more information about the dataset [here](https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions).

In [None]:
!wget -N https://storage.googleapis.com/pyhealth/medical_transcriptions_data/MedicalTranscriptions.zip

--2023-07-16 00:44:02--  https://storage.googleapis.com/pyhealth/medical_transcriptions_data/MedicalTranscriptions.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.114.128, 172.253.119.128, 108.177.111.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.114.128|:443... connected.
HTTP request sent, awaiting response... 304 Not Modified
File ‘MedicalTranscriptions.zip’ not modified on server. Omitting download.



In [None]:
!unzip -q -o MedicalTranscriptions.zip

In [None]:
!ls -1 MedicalTranscriptions

mtsamples.csv
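To see what the raw file looks like before PyHealth processes it, here is a minimal sketch that reads an `mtsamples.csv`-style file with Python's `csv` module. The column names (`description`, `medical_specialty`, `sample_name`, `transcription`) match the record shown in Step 1. The helper name `load_mtsamples` is hypothetical, and this is not how PyHealth loads the file internally. Two choices are assumptions: surrounding whitespace is stripped (the raw specialty labels carry a leading space, as the statistics in Step 1 show), and rows with an empty transcription are skipped.

```python
import csv
import io
from typing import Dict, List, TextIO


def load_mtsamples(fileobj: TextIO) -> List[Dict[str, str]]:
    """Hypothetical reader for mtsamples.csv-style data (not PyHealth's loader)."""
    records = []
    for row in csv.DictReader(fileobj):
        text = (row.get("transcription") or "").strip()
        if not text:
            # Assumption: a row without transcription text is unusable for this task.
            continue
        records.append({
            "description": (row.get("description") or "").strip(),
            # Raw labels look like " Surgery"; strip the surrounding whitespace.
            "medical_specialty": (row.get("medical_specialty") or "").strip(),
            "sample_name": (row.get("sample_name") or "").strip(),
            "transcription": text,
        })
    return records


# Tiny in-memory example; with the real data you would use
# open("MedicalTranscriptions/mtsamples.csv", newline="") instead.
demo = io.StringIO(
    "description,medical_specialty,sample_name,transcription\n"
    '" A short note.", Surgery, Demo ,"SUBJECTIVE: example text."\n'
    '" No text.", Radiology, Demo 2 ,\n'
)
rows = load_mtsamples(demo)
print(len(rows), rows[0]["medical_specialty"])  # prints: 1 Surgery
```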


We now build the transcription classifier with PyHealth's five-stage pipeline: load the data, define the task, define the model, train, and evaluate.

## Step 1. Load Data in PyHealth

The first step is to load the data into PyHealth's internal data structure. Import the matching dataset class from PyHealth and pass it the root directory that contains the raw dataset. PyHealth handles the processing automatically.

In [None]:
from pyhealth.datasets import MedicalTranscriptionsDataset


# Directory created by the unzip step above (Colab's working directory is /content)
root = "/content/MedicalTranscriptions"
base_dataset = MedicalTranscriptionsDataset(root)

Once the data is loaded, we can inspect it, for example by printing summary statistics or looking at individual records.

In [None]:
base_dataset.stat()

Statistics of MedicalTranscriptionsDataset:
Number of samples: 4999
Number of classes: 40
Class distribution: Counter({' Surgery': 1103, ' Consult - History and Phy.': 516, ' Cardiovascular / Pulmonary': 372, ' Orthopedic': 355, ' Radiology': 273, ' General Medicine': 259, ' Gastroenterology': 230, ' Neurology': 223, ' SOAP / Chart / Progress Notes': 166, ' Obstetrics / Gynecology': 160, ' Urology': 158, ' Discharge Summary': 108, ' ENT - Otolaryngology': 98, ' Neurosurgery': 94, ' Hematology - Oncology': 90, ' Ophthalmology': 83, ' Nephrology': 81, ' Emergency Room Reports': 75, ' Pediatrics - Neonatal': 70, ' Pain Management': 62, ' Psychiatry / Psychology': 53, ' Office Notes': 51, ' Podiatry': 47, ' Dermatology': 29, ' Dentistry': 27, ' Cosmetic / Plastic Surgery': 27, ' Letters': 23, ' Physical Medicine - Rehab': 21, ' Sleep Medicine': 20, ' Endocrinology': 19, ' Bariatrics': 18, ' IME-QME-Work Comp etc.': 16, ' Chiropractic': 14, ' Rheumatology': 10, ' Diets and Nutritions': 10, 
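The class distribution is heavily imbalanced. On the full dataset, the largest class (' Surgery', 1103 of 4999 samples) covers about 22% of the data, while several classes have 20 or fewer examples. A minimal sketch of how to quantify this from a `Counter` of labels; `summarize_labels` and the `rare_threshold` cutoff are hypothetical choices for illustration.

```python
from collections import Counter
from typing import Dict


def summarize_labels(counts: Counter, rare_threshold: int = 20) -> Dict[str, object]:
    """Summarize a label Counter: size, majority class share, and rare classes."""
    total = sum(counts.values())
    majority_label, majority_count = counts.most_common(1)[0]
    # Labels in this dataset carry a leading space, so strip before reporting.
    rare = sorted(label.strip() for label, n in counts.items() if n < rare_threshold)
    return {
        "num_classes": len(counts),
        "total": total,
        "majority_label": majority_label.strip(),
        "majority_share": majority_count / total,
        "rare_classes": rare,
    }


# Example with three of the counts printed above (not the full distribution).
counts = Counter({" Surgery": 1103, " Radiology": 273, " Rheumatology": 10})
print(summarize_labels(counts))
```

The majority share is worth keeping in mind: a classifier that always predicts the largest class already reaches that accuracy.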

In [None]:
base_dataset.patients[0]

{'description': ' A 23-year-old white female presents with complaint of allergies.',
 'medical_specialty': ' Allergy / Immunology',
 'sample_name': ' Allergic Rhinitis ',
 'transcription': 'SUBJECTIVE:,  This 23-year-old white female presents with complaint of allergies.  She used to have allergies when she lived in Seattle but she thinks they are worse here.  In the past, she has tried Claritin, and Zyrtec.  Both worked for short time but then seemed to lose effectiveness.  She has used Allegra also.  She used that last summer and she began using it again two weeks ago.  It does not appear to be working very well.  She has used over-the-counter sprays but no prescription nasal sprays.  She does have asthma but doest not require daily medication for this and does not think it is flaring up.,MEDICATIONS: , Her only medication currently is Ortho Tri-Cyclen and the Allegra.,ALLERGIES: , She has no known medicine allergies.,OBJECTIVE:,Vitals:  Weight was 130 pounds and blood pressure 124/7

## Step 2. Define the Task

The next step is to define the machine learning task. The task tells PyHealth how to turn each patient's data into a list of samples with the desired features and labels. This dataset contains no patient identifiers, so we treat each medical transcript as belonging to a distinct patient.

For this dataset, PyHealth provides a default transcription classification task. It takes the transcription text as input and predicts the transcript's medical specialty, a single label out of 40 classes.
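Conceptually, a task is a function that maps raw records to samples. The following is a minimal, hypothetical re-implementation, not PyHealth's code. The `transcription` and `label` keys mirror the task's `input_schema` and `output_schema` printed below. The `patient_id` key and the choice to drop records without text are assumptions for illustration. Each record gets its own patient ID, following the one-transcript-per-patient assumption above.

```python
from typing import Dict, List


def transcription_task(records: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Hypothetical task: one sample per transcript, text in, specialty out."""
    samples = []
    for idx, record in enumerate(records):
        text = record.get("transcription", "")
        if not text:
            continue  # assumption: records without text produce no sample
        samples.append({
            # No patient identifiers exist, so each transcript is its own "patient".
            "patient_id": str(idx),
            "transcription": text,                         # input feature
            "label": record["medical_specialty"].strip(),  # target class
        })
    return samples


demo = [
    {"transcription": "SUBJECTIVE: allergies.", "medical_specialty": " Allergy / Immunology"},
    {"transcription": "", "medical_specialty": " Radiology"},
]
print(transcription_task(demo))
```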

In [None]:
base_dataset.default_task

MedicalTranscriptionsClassification(task_name='MedicalTranscriptionsClassification', input_schema={'transcription': 'text'}, output_schema={'label': 'label'})

In [None]:
sample_dataset = base_dataset.set_task()

Generating samples for MedicalTranscriptionsClassification: 100%|██████████| 4999/4999 [00:00<00:00, 526764.29it/s]


Here is an example of a single sample, represented as a dictionary. The dictionary contains keys for feature names, label names, and other metadata associated with the sample.

In [None]:
sample_dataset[0]

{'transcription': 'SUBJECTIVE:,  This 23-year-old white female presents with complaint of allergies.  She used to have allergies when she lived in Seattle but she thinks they are worse here.  In the past, she has tried Claritin, and Zyrtec.  Both worked for short time but then seemed to lose effectiveness.  She has used Allegra also.  She used that last summer and she began using it again two weeks ago.  It does not appear to be working very well.  She has used over-the-counter sprays but no prescription nasal sprays.  She does have asthma but doest not require daily medication for this and does not think it is flaring up.,MEDICATIONS: , Her only medication currently is Ortho Tri-Cyclen and the Allegra.,ALLERGIES: , She has no known medicine allergies.,OBJECTIVE:,Vitals:  Weight was 130 pounds and blood pressure 124/78.,HEENT:  Her throat was mildly erythematous without exudate.  Nasal mucosa was erythematous and swollen.  Only clear drainage was seen.  TMs were clear.,Neck:  Supple wi

Finally, we will split the entire dataset into training, validation, and test sets using the ratios of 70%, 10%, and 20%, respectively. We will then obtain the corresponding data loaders for each set.
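A sample-level split shuffles the sample indices once and cuts the shuffled list at the ratio boundaries. A minimal sketch under that assumption follows. `split_by_ratio` is a hypothetical name, and PyHealth's `split_by_sample` may compute the boundaries differently, for example by truncating instead of rounding, so split sizes can differ by a sample.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_by_ratio(
    items: Sequence[T], ratios: Sequence[float], seed: int = 0
) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle once with a fixed seed, then cut into train/val/test by ratio."""
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError("ratios must have three entries summing to 1")
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # reproducible permutation
    n = len(indices)
    # Round instead of truncating to avoid float artifacts such as 0.7 + 0.1 < 0.8.
    train_end = round(n * ratios[0])
    val_end = train_end + round(n * ratios[1])
    train = [items[i] for i in indices[:train_end]]
    val = [items[i] for i in indices[train_end:val_end]]
    test = [items[i] for i in indices[val_end:]]  # remainder goes to test
    return train, val, test


train, val, test = split_by_ratio(list(range(100)), [0.7, 0.1, 0.2], seed=42)
print(len(train), len(val), len(test))  # prints: 70 10 20
```

Because the three slices come from one permutation, the splits are disjoint and together cover every sample.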

In [None]:
from pyhealth.datasets import split_by_sample


train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=sample_dataset,
    ratios=[0.7, 0.1, 0.2]
)

In [None]:
from pyhealth.datasets import get_dataloader


train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
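With `batch_size=32`, a PyTorch-style data loader that keeps the last partial batch (the default `drop_last=False`) yields ceil(n / 32) batches per split. The sketch below mimics that batching in plain Python. It collates each batch of sample dictionaries into a dictionary of lists, one list per key. That collation format is an assumption about what `get_dataloader` produces, and `iterate_batches`/`num_batches` are hypothetical helpers.

```python
import math
from typing import Dict, Iterator, List


def iterate_batches(
    samples: List[Dict[str, object]], batch_size: int
) -> Iterator[Dict[str, List[object]]]:
    """Yield batches as {key: [value for each sample in the batch]}."""
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        yield {key: [s[key] for s in chunk] for key in chunk[0]}


def num_batches(n_samples: int, batch_size: int) -> int:
    # The last batch may be smaller than batch_size (no drop_last).
    return math.ceil(n_samples / batch_size)


data = [{"transcription": f"note {i}", "label": "Surgery"} for i in range(70)]
batches = list(iterate_batches(data, 32))
print(len(batches), [len(b["label"]) for b in batches])  # prints: 3 [32, 32, 6]
```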

## Step 3. Define the Model

Next, we define the deep learning model. PyHealth's `TransformersModel` can load pretrained language models from Hugging Face's Transformers library; you select one by passing its model identifier through the `model_name` argument. Here we use Bio_ClinicalBERT (`emilyalsentzer/Bio_ClinicalBERT`), a BERT model further pretrained on clinical notes.

In [None]:
from pyhealth.models import TransformersModel


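model = TransformersModel(
    model_name="emilyalsentzer/Bio_ClinicalBERT",  # Hugging Face Hub model identifier
    dataset=sample_dataset,          # samples from which the label set is derived
    feature_keys=["transcription"],  # input text field of each sample
    label_key="label",               # target field of each sample
    mode="multiclass",               # exactly one specialty per transcript
)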

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
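The unused weights listed in the warning belong to BERT's pretraining heads (`cls.predictions`, `cls.seq_relationship`), which classification does not need. In multiclass mode, a classification layer on top of the encoder produces one logit per specialty. The logits go through a softmax and are trained with cross-entropy. Below is a plain-Python sketch of that final step for a single sample. In practice this runs on tensors, for example with `torch.nn.functional.cross_entropy`. The helper names are hypothetical.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_and_loss(logits: List[float], label: int) -> Tuple[int, float]:
    """Return the predicted class index and the cross-entropy loss for `label`."""
    probs = softmax(logits)
    pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
    loss = -math.log(probs[label])
    return pred, loss


pred, loss = predict_and_loss([0.1, 2.0, -1.0], label=1)
print(pred, round(loss, 3))
```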


In [None]:
model

TransformersModel(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

## Step 4. Training

In this step, we train the model with PyHealth's `Trainer` class. The trainer wraps the training loop, periodic evaluation, logging, and selection of the best checkpoint.

In [None]:
from pyhealth.trainer import Trainer


trainer = Trainer(model=model)

TransformersModel(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

Metrics: None

Device: cuda


Before training, let's evaluate the model as-is to get a baseline. The Bio_ClinicalBERT encoder is pretrained, but the classification layer on top of it has not been trained yet.

In [None]:
print(trainer.evaluate(test_dataloader))

Evaluation: 100%|██████████| 32/32 [00:15<00:00,  2.04it/s]

{'accuracy': 0.018108651911468814, 'f1_macro': 0.001017293997965412, 'f1_micro': 0.018108651911468814, 'loss': 3.722870334982872}
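As a reference point, suppose the untrained classification layer assigned nearly uniform probabilities over the 40 classes, independent of the true label. The expected cross-entropy and accuracy would then be roughly:

$$
\mathcal{L} \approx -\ln\frac{1}{40} = \ln 40 \approx 3.689, \qquad \text{accuracy} \approx \frac{1}{40} = 0.025
$$

The measured loss (3.72) and accuracy (0.018) are close to these values, consistent with a model that starts at chance level before fine-tuning.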





Now, let's start training. During training, the trainer evaluates the model on the validation set and keeps the weights that score best on the monitored metric (here, `accuracy`).
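The training log below reports `Monitor: accuracy` and `Monitor criterion: max`. That is, a new score counts as an improvement only if it exceeds the best one seen so far; for a loss you would use `min`. Here is a minimal sketch of that bookkeeping. `BestTracker` is a hypothetical class, not PyHealth's implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BestTracker:
    criterion: str = "max"  # "max" for accuracy/F1, "min" for loss
    best_score: Optional[float] = None
    best_epoch: Optional[int] = None

    def update(self, score: float, epoch: int) -> bool:
        """Record `score` and return True if it improves on the best so far."""
        if self.criterion not in ("max", "min"):
            raise ValueError("criterion must be 'max' or 'min'")
        improved = (
            self.best_score is None
            or (self.criterion == "max" and score > self.best_score)
            or (self.criterion == "min" and score < self.best_score)
        )
        if improved:
            self.best_score, self.best_epoch = score, epoch
        return improved


tracker = BestTracker("max")
for epoch, acc in enumerate([0.20, 0.18, 0.25]):
    if tracker.update(acc, epoch):
        print(f"New best accuracy score ({acc:.4f}) at epoch-{epoch}")
```

In a real training loop, the `improved` branch is where you would save a checkpoint, for example with `torch.save(model.state_dict(), path)`. You would then reload that checkpoint at the end, as the `Loaded best model` log line indicates.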

In [None]:
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=1,
    monitor="accuracy"
)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f8769161e40>
Monitor: accuracy
Monitor criterion: max
Epochs: 1

Epoch 0 / 1:   0%|          | 0/109 [00:00<?, ?it/s]

--- Train epoch-0, step-109 ---
loss: 3.1713
Evaluation: 100%|██████████| 16/16 [00:07<00:00,  2.17it/s]

--- Eval epoch-0, step-109 ---
accuracy: 0.2540
f1_macro: 0.0116
f1_micro: 0.2540
loss: 2.9753

New best accuracy score (0.2540) at epoch-0, step-109
Loaded best model


## Step 5. Evaluation

At the end of training, the trainer automatically loads the best saved model weights. We can therefore evaluate the fine-tuned Bio_ClinicalBERT model on the test set directly with PyHealth's `Trainer.evaluate()` method.
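The evaluation reports accuracy, macro-F1, and micro-F1. Below is a plain-Python sketch of how these metrics are defined. It is an independent re-implementation: PyHealth likely delegates to a library such as scikit-learn, and `classification_metrics` is a hypothetical name. The sketch also explains why `f1_micro` equals `accuracy` in every output in this notebook. In single-label multiclass classification, each wrong prediction counts once as a false positive and once as a false negative. Micro-F1 therefore reduces to the fraction of correct predictions. Macro-F1 averages the per-class F1 scores, so rare classes weigh as much as frequent ones.

```python
from collections import Counter
from typing import Dict, Sequence


def classification_metrics(y_true: Sequence[str], y_pred: Sequence[str]) -> Dict[str, float]:
    """Accuracy, macro-F1 and micro-F1 for single-label multiclass predictions."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # missed the true class t
    # Per-class F1 = 2TP / (2TP + FP + FN), over labels seen in y_true or y_pred.
    labels = set(y_true) | set(y_pred)
    f1s = []
    for c in labels:
        denom = 2 * tp[c] + fp[c] + fn[c]
        f1s.append(2 * tp[c] / denom if denom else 0.0)
    correct = sum(tp.values())
    micro_f1 = 2 * correct / (2 * correct + sum(fp.values()) + sum(fn.values()))
    return {
        "accuracy": correct / len(y_true),
        "f1_macro": sum(f1s) / len(f1s),
        "f1_micro": micro_f1,
    }


print(classification_metrics(["a", "a", "b", "c"], ["a", "b", "b", "a"]))
```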

In [None]:
print(trainer.evaluate(test_dataloader))

Evaluation: 100%|██████████| 32/32 [00:14<00:00,  2.14it/s]

{'accuracy': 0.22032193158953722, 'f1_macro': 0.010030228084638637, 'f1_micro': 0.22032193158953722, 'loss': 3.106539271771908}
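A useful sanity check is to compare these scores with a trivial classifier that always predicts the most frequent class. Such a predictor has accuracy k/n, where k is the size of the majority class. Its only non-zero per-class F1 is for the majority class, with precision k/n and recall 1, giving 2k/(k+n). Macro-F1 divides that value by the number of classes. A minimal sketch; `majority_baseline` is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, Sequence


def majority_baseline(y_true: Sequence[str]) -> Dict[str, float]:
    """Scores of a classifier that always predicts the most frequent label."""
    counts = Counter(y_true)
    n = len(y_true)
    _, k = counts.most_common(1)[0]
    accuracy = k / n
    # Majority class: precision k/n, recall 1 -> F1 = 2k / (k + n); others score 0.
    f1_majority = 2 * k / (k + n)
    return {
        "accuracy": accuracy,
        "f1_macro": f1_majority / len(counts),  # averaged over labels in y_true
        "f1_micro": accuracy,                   # equals accuracy for single-label data
    }


print(majority_baseline(["Surgery"] * 3 + ["Radiology", "Urology"]))
```

Plugging in the dataset-wide counts from Step 1 (' Surgery': 1103 of 4999 samples, 40 classes), the baseline reaches accuracy ≈ 0.22 and macro-F1 ≈ 0.009, assuming the test split has similar proportions and contains every class. The one-epoch model's test scores (accuracy 0.2203, macro-F1 0.0100) are close to this baseline. This suggests the model may still be predicting mostly the majority class. The training log shows Adam with lr=0.001, whereas BERT-style models are commonly fine-tuned with much smaller learning rates (around 2e-5 to 5e-5). Lowering the learning rate and training for more epochs may therefore help.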



