In [None]:
import transformers
from transformers import LongformerTokenizerFast, LongformerForSequenceClassification, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import torch

import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from torchmetrics.functional import accuracy, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

#matplotlib setting/format specifications

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

#Controls size of figures we create
rcParams['figure.figsize'] = 8, 6

# Seed every RNG the pipeline touches in one call (Python `random`, NumPy,
# torch CPU/CUDA). workers=True additionally gives each DataLoader worker a
# reproducible seed, which matters because the data module uses num_workers=2.
RANDOM_SEED = 42
pl.seed_everything(RANDOM_SEED, workers=True)

<torch._C.Generator at 0x7faeb9b9dbb0>

In [None]:
sw_notes = pd.read_csv(filepath_or_buffer="Final_SWmerged.csv")  # raw notes + label columns

In [None]:
def clean_text(df):
  """Return a whitespace-normalized copy of one row's ``TEXT`` field.

  Meant for ``DataFrame.apply(clean_text, axis=1)``: ``df`` is a single row
  (anything indexable by ``"TEXT"``), not a whole DataFrame.

  Newlines, tabs and runs of spaces are collapsed to ONE space instead of being
  deleted, so the words on either side of a line break or double space stay
  separate tokens. Missing text (NaN/None) becomes ``""`` rather than the
  literal string ``"nan"``.
  """
  text = df["TEXT"]
  if pd.isna(text):
    return ""
  # \s+ matches \n, \r, \t and repeated spaces in a single pass.
  return re.sub(r"\s+", " ", str(text)).strip()

In [None]:
# Add a normalized copy of each note's text alongside the raw TEXT column.
sw_notes = sw_notes.assign(TEXT_CLEAN=sw_notes.apply(clean_text, axis=1))

In [None]:
# The multi-label targets, named explicitly instead of slicing columns [11:15]
# so a reordered or extended CSV cannot silently change which columns are used.
# Order matters: it fixes the output-unit order and the report's target_names.
LABEL_COLUMNS = [
    "Involved_Support",
    "Communication_Facilitation",
    "Counseling",
    "Practical_Assistance",
]
missing_label_columns = [col for col in LABEL_COLUMNS if col not in sw_notes.columns]
if missing_label_columns:
    raise KeyError(f"label columns missing from sw_notes: {missing_label_columns}")

## Dataset Class

In [None]:
# Tokenizer matching the Clinical-Longformer checkpoint used by the model.
# The original passed `max_length=1000`, which is not a tokenizer constructor
# option and was silently ignored; `model_max_length` is the actual cap.
# SWNotesDataset also passes its own max_length on every encode call.
tokenizer = LongformerTokenizerFast.from_pretrained(
    "yikuan8/Clinical-Longformer",
    model_max_length=1000,
)

Downloading (…)okenizer_config.json:   0%|          | 0.00/347 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

In [None]:
class SWNotesDataset(Dataset):
  """Map-style dataset turning rows of a notes DataFrame into model inputs.

  Each item is a dict with the raw note text (``'SW_note'``), 1-D
  ``input_ids`` / ``attention_mask`` tensors padded/truncated to ``max_len``
  tokens, and a float32 multi-hot ``labels`` vector.
  """

  def __init__(self, data: pd.DataFrame, tokenizer: "PreTrainedTokenizerBase", max_len=512,
               label_columns=None, text_column="TEXT_CLEAN"):
    """
    Args:
      data: one row per note; must contain ``text_column`` and every label column.
      tokenizer: Hugging Face tokenizer exposing ``encode_plus`` (this notebook
        passes a ``LongformerTokenizerFast``, not a BERT tokenizer).
      max_len: token length every example is padded/truncated to.
      label_columns: names of the 0/1 label columns; when omitted, the
        notebook-level ``LABEL_COLUMNS`` is used (the original behaviour).
      text_column: column holding the note text to encode.
    """
    self.data = data
    self.tokenizer = tokenizer
    self.max_len = max_len
    self.label_columns = list(LABEL_COLUMNS if label_columns is None else label_columns)
    self.text_column = text_column

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    data_row = self.data.iloc[index]
    note = data_row[self.text_column]
    # A row taken from a mixed-dtype frame is an object-dtype Series; cast
    # explicitly so the label vector is always float32 for the BCE loss.
    label_values = data_row[self.label_columns].to_numpy(dtype=np.float32)

    encoding = self.tokenizer.encode_plus(
        note,
        add_special_tokens=True,
        truncation=True,
        max_length=self.max_len,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )

    return {
        'SW_note': note,
        # return_tensors='pt' yields shape (1, max_len); flatten to (max_len,)
        # so the DataLoader can batch items.
        'input_ids': encoding['input_ids'].flatten(),
        'attention_mask': encoding['attention_mask'].flatten(),
        'labels': torch.tensor(label_values, dtype=torch.float32),
    }

In [None]:
# 5-fold CV over the notes (shuffled with the global seed). Each fold yields a
# pair of positional index arrays: (training rows, held-out rows).
kf = KFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
folds = kf.split(sw_notes)

(
  (trn_idx1, tst_idx1),
  (trn_idx2, tst_idx2),
  (trn_idx3, tst_idx3),
  (trn_idx4, tst_idx4),
  (trn_idx5, tst_idx5),
) = folds

In [None]:
# Materialise the train / held-out DataFrames for every fold (positional .iloc).
df_train1, df_train2, df_train3, df_train4, df_train5 = (
    sw_notes.iloc[idx] for idx in (trn_idx1, trn_idx2, trn_idx3, trn_idx4, trn_idx5)
)
df_test1, df_test2, df_test3, df_test4, df_test5 = (
    sw_notes.iloc[idx] for idx in (tst_idx1, tst_idx2, tst_idx3, tst_idx4, tst_idx5)
)

In [None]:
class SWNotesDataModule(pl.LightningDataModule):
  """LightningDataModule serving one train/test fold of the notes data.

  Args:
    df_train: rows used for training (shuffled every epoch).
    df_test: held-out rows served by ``test_dataloader``, and also by
      ``val_dataloader`` when ``df_val`` is not given.
    tokenizer: tokenizer forwarded to ``SWNotesDataset``.
    batch_size: examples per batch for every loader.
    max_len: token length forwarded to ``SWNotesDataset``.
    df_val: optional separate validation rows. When omitted, validation falls
      back to ``df_test`` (the original behaviour). In that case EarlyStopping,
      which monitors val_loss, sees the same rows the final metrics are
      reported on, so those metrics are optimistic.
    num_workers: DataLoader worker processes (previously fixed at 2).
  """

  def __init__(self, df_train, df_test, tokenizer, batch_size=8, max_len=512,
               df_val=None, num_workers=2):
    super().__init__()
    self.batch_size = batch_size
    self.df_train = df_train
    self.df_test = df_test
    self.df_val = df_val
    self.tokenizer = tokenizer
    self.max_len = max_len
    self.num_workers = num_workers

  def _make_dataset(self, frame):
    """Wrap a DataFrame in SWNotesDataset with this module's tokenizer/max_len."""
    return SWNotesDataset(frame, self.tokenizer, self.max_len)

  def _make_loader(self, dataset, shuffle=False):
    """Build a DataLoader with this module's batch size and worker count."""
    return DataLoader(
      dataset,
      batch_size=self.batch_size,
      shuffle=shuffle,
      num_workers=self.num_workers
    )

  def setup(self, stage=None):
    self.train_dataset = self._make_dataset(self.df_train)
    self.test_dataset = self._make_dataset(self.df_test)
    # Without an explicit validation split, validate on the test fold.
    if self.df_val is None:
      self.val_dataset = self.test_dataset
    else:
      self.val_dataset = self._make_dataset(self.df_val)

  def train_dataloader(self):
    return self._make_loader(self.train_dataset, shuffle=True)

  def val_dataloader(self):
    return self._make_loader(self.val_dataset)

  def test_dataloader(self):
    return self._make_loader(self.test_dataset)



In [None]:
# Training hyperparameters: epoch cap, notes per batch, tokens kept per note.
N_EPOCHS, BATCH_SIZE, MAX_LEN = 10, 4, 1000

In [None]:
# Data module for the fold being trained this run (switched by hand per run).
data_module = SWNotesDataModule(
  df_train=df_train5,
  df_test=df_test5,
  tokenizer=tokenizer,
  batch_size=BATCH_SIZE,
  max_len=MAX_LEN,
)

## Building Classification Model

In [None]:
#Pytorch lightning method

class SWRoleClassifier(pl.LightningModule):
  """Multi-label social-work role classifier on top of Clinical-Longformer.

  ``forward`` returns ``(loss, probabilities)``. ``probabilities`` is
  ``sigmoid(logits)`` with one column per label. ``loss`` is the binary
  cross-entropy against ``labels``, or ``0`` when no labels are passed.
  """

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None, lr: float = 2e-5):
    """
    Args:
      n_classes: number of independent binary labels (size of the output layer).
      n_training_steps: total optimizer steps for the linear LR schedule; when
        None, Lightning's ``trainer.estimated_stepping_batches`` is used.
      n_warmup_steps: linear warm-up steps; None means no warm-up.
      lr: peak AdamW learning rate (previously hard-coded to 2e-5).
    """
    super().__init__()
    # Record constructor args in checkpoints so load_from_checkpoint can
    # rebuild the module.
    self.save_hyperparameters()
    self.bert = LongformerForSequenceClassification.from_pretrained(
      "yikuan8/Clinical-Longformer", return_dict=True, num_labels=n_classes
    )
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.lr = lr
    # BCEWithLogitsLoss folds the sigmoid into the loss in a numerically stable
    # form; sigmoid followed by BCELoss loses precision once the sigmoid
    # saturates to 0/1.
    self.criterion = nn.BCEWithLogitsLoss()

  def forward(self, input_ids, attention_mask, labels=None):
    logits = self.bert(input_ids, attention_mask=attention_mask).logits
    probabilities = torch.sigmoid(logits)
    loss = 0
    if labels is not None:
      # The loss is computed on raw logits (see __init__); labels are float multi-hot.
      loss = self.criterion(logits, labels)
    return loss, probabilities

  def _shared_step(self, batch, stage):
    """Run forward on a batch and log ``{stage}_loss``; return (loss, probs, labels)."""
    labels = batch["labels"]
    loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
    self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
    return loss, outputs, labels

  def training_step(self, batch, batch_idx):
    loss, outputs, labels = self._shared_step(batch, "train")
    # Detach so the per-epoch outputs kept for training_epoch_end hold no graph.
    return {"loss": loss, "predictions": outputs.detach(), "labels": labels}

  def validation_step(self, batch, batch_idx):
    loss, _, _ = self._shared_step(batch, "val")
    return loss

  def test_step(self, batch, batch_idx):
    loss, _, _ = self._shared_step(batch, "test")
    return loss

  def training_epoch_end(self, outputs):
    """Log label-wise training accuracy at a 0.5 threshold over the epoch.

    The original hook built these tensors and then discarded them.
    """
    if not outputs:
      return
    predictions = torch.cat([out["predictions"] for out in outputs]).detach().cpu()
    labels = torch.cat([out["labels"] for out in outputs]).detach().cpu().int()
    label_accuracy = ((predictions > 0.5).int() == labels).float().mean()
    self.log("train_acc", label_accuracy, logger=True)

  def configure_optimizers(self):
    # torch's AdamW replaces the deprecated transformers.AdamW. eps and
    # weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0),
    # because torch's defaults differ (1e-8, 1e-2).
    optimizer = optim.AdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)

    total_steps = self.n_training_steps
    if total_steps is None:
      total_steps = self.trainer.estimated_stepping_batches
    warmup_steps = self.n_warmup_steps or 0

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=warmup_steps,
      num_training_steps=total_steps
    )

    return dict(
      optimizer=optimizer,
      lr_scheduler=dict(
        scheduler=scheduler,
        interval='step'
      )
    )

In [None]:
# Scheduler horizon for the fold actually loaded in `data_module` (the original
# always used df_train1). The train DataLoader keeps its last partial batch
# (no drop_last), so steps per epoch is the CEILING of len / BATCH_SIZE.
steps_per_epoch = int(np.ceil(len(data_module.df_train) / BATCH_SIZE))
total_training_steps = steps_per_epoch * N_EPOCHS

# Linear warm-up over the first 20% of the steps.
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

(502, 2510)

In [None]:
# Fresh classifier for this fold; the LR schedule uses the step counts above.
model = SWRoleClassifier(
  n_training_steps=total_training_steps,
  n_warmup_steps=warmup_steps,
  n_classes=len(LABEL_COLUMNS),
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/929 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/595M [00:00<?, ?B/s]

Some weights of the model checkpoint at yikuan8/Clinical-Longformer were not used when initializing LongformerForSequenceClassification: ['longformer.embeddings.position_ids', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initial

## Training

In [None]:
# TensorBoard logs are written under lightning_logs/SWroles2.
logger = TensorBoardLogger(save_dir="lightning_logs", name="SWroles2")

# Stop when val_loss has not improved for 3 consecutive validation checks.
early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=3)

In [None]:
# Keep the checkpoint with the lowest validation loss. Lightning's default
# checkpoint callback monitors no metric, so best_model_path (used in the
# evaluation section) would point at the last saved epoch, not the best one.
checkpoint_callback = ModelCheckpoint(
  monitor="val_loss",
  mode="min",
  save_top_k=1,
)

trainer = pl.Trainer(
  logger=logger,
  log_every_n_steps=40,
  callbacks=[checkpoint_callback, early_stopping_callback],
  max_epochs=N_EPOCHS,
  # accelerator/devices replaces the deprecated `gpus=1`.
  accelerator="gpu",
  devices=1,
)

  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train this fold; validation runs after every epoch and drives EarlyStopping.
trainer.fit(model, datamodule=data_module)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                                | Params
------------------------------------------------------------------
0 | bert      | LongformerForSequenceClassification | 148 M 
1 | criterion | BCELoss                             | 0     
------------------------------------------------------------------
148 M     Trainable params
0         Non-trainable params
148 M     Total params
594.650   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## Evaluation

In [None]:
# Reload the checkpoint selected by the trainer's checkpoint callback and put
# it into inference mode (eval + frozen parameters).
best_checkpoint = trainer.checkpoint_callback.best_model_path
if not best_checkpoint:
  raise RuntimeError("No checkpoint was saved; run trainer.fit first.")

trained_model = SWRoleClassifier.load_from_checkpoint(
  best_checkpoint,
  n_classes=len(LABEL_COLUMNS)
)

trained_model.eval()
trained_model.freeze()

# `device` was not defined in any earlier cell; use the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_model = trained_model.to(device)

Some weights of the model checkpoint at yikuan8/Clinical-Longformer were not used when initializing LongformerForSequenceClassification: ['longformer.embeddings.position_ids', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initial

In [None]:
# Evaluate on exactly the held-out fold the data module was built with, so the
# fold is chosen in one place only. Previously df_test5 was hard-coded here too
# and had to be kept in sync by hand, risking evaluation on training rows.
test_data = SWNotesDataset(
    data=data_module.df_test,
    tokenizer=tokenizer,
    max_len=MAX_LEN
  )

predictions = []
labels = []

In [None]:
# Score every held-out note one at a time. The result lists are (re)created
# here so re-running this cell never appends to a previous run's results, and
# inputs follow the model's own device instead of a separately defined name.
predictions = []
labels = []
with torch.no_grad():
  for item in tqdm(test_data):
    _, prediction = trained_model(
      item["input_ids"].unsqueeze(dim=0).to(trained_model.device),
      item["attention_mask"].unsqueeze(dim=0).to(trained_model.device)
    )
    predictions.append(prediction.flatten().cpu())
    labels.append(item["labels"].int())

  0%|          | 0/251 [00:00<?, ?it/s]

In [None]:
# Collapse the per-note lists into (n_notes, n_labels) CPU tensors.
predictions, labels = (torch.stack(seq).detach().cpu() for seq in (predictions, labels))

In [None]:
# Fold 1 report at a 0.5 probability threshold. This cell reads whatever
# `predictions`/`labels` are in memory; the output below was presumably
# captured in the run where fold 1 was selected (the fold is switched by hand).
upper, lower = 1, 0
y_true = labels.numpy()
y_pred = np.where(predictions.numpy() > 0.5, upper, lower)

print(classification_report(y_true, y_pred, target_names=LABEL_COLUMNS, zero_division=0))

                            precision    recall  f1-score   support

          Involved_Support       0.93      0.94      0.93        94
Communication_Facilitation       0.70      0.92      0.79        59
                Counseling       0.63      0.63      0.63        35
      Practical_Assistance       0.69      0.75      0.72        57

                 micro avg       0.77      0.84      0.81       245
                 macro avg       0.74      0.81      0.77       245
              weighted avg       0.78      0.84      0.81       245
               samples avg       0.51      0.53      0.51       245



In [None]:
# Fold 2 report at a 0.5 probability threshold. This cell reads whatever
# `predictions`/`labels` are in memory; the output below was presumably
# captured in the run where fold 2 was selected (the fold is switched by hand).
upper, lower = 1, 0
y_true = labels.numpy()
y_pred = np.where(predictions.numpy() > 0.5, upper, lower)

print(classification_report(y_true, y_pred, target_names=LABEL_COLUMNS, zero_division=0))



                            precision    recall  f1-score   support

          Involved_Support       0.92      0.91      0.91        86
Communication_Facilitation       0.75      0.82      0.79        56
                Counseling       0.67      0.76      0.71        29
      Practical_Assistance       0.82      0.79      0.80        57

                 micro avg       0.82      0.84      0.83       228
                 macro avg       0.79      0.82      0.80       228
              weighted avg       0.82      0.84      0.83       228
               samples avg       0.51      0.52      0.50       228



In [None]:
# Fold 3 report at a 0.5 probability threshold. This cell reads whatever
# `predictions`/`labels` are in memory; the output below was presumably
# captured in the run where fold 3 was selected (the fold is switched by hand).
upper, lower = 1, 0
y_true = labels.numpy()
y_pred = np.where(predictions.numpy() > 0.5, upper, lower)

print(classification_report(y_true, y_pred, target_names=LABEL_COLUMNS, zero_division=0))


                            precision    recall  f1-score   support

          Involved_Support       0.86      0.93      0.89        90
Communication_Facilitation       0.67      0.78      0.72        68
                Counseling       0.76      0.71      0.74        35
      Practical_Assistance       0.85      0.69      0.77        59

                 micro avg       0.79      0.81      0.80       252
                 macro avg       0.78      0.78      0.78       252
              weighted avg       0.79      0.81      0.80       252
               samples avg       0.53      0.53      0.52       252



In [None]:
# Fold 4 report at a 0.5 probability threshold. This cell reads whatever
# `predictions`/`labels` are in memory; the output below was presumably
# captured in the run where fold 4 was selected (the fold is switched by hand).
upper, lower = 1, 0
y_true = labels.numpy()
y_pred = np.where(predictions.numpy() > 0.5, upper, lower)

print(classification_report(y_true, y_pred, target_names=LABEL_COLUMNS, zero_division=0))


                            precision    recall  f1-score   support

          Involved_Support       0.91      0.96      0.93        90
Communication_Facilitation       0.64      0.73      0.68        51
                Counseling       0.88      0.74      0.80        38
      Practical_Assistance       0.89      0.69      0.78        58

                 micro avg       0.83      0.81      0.82       237
                 macro avg       0.83      0.78      0.80       237
              weighted avg       0.84      0.81      0.82       237
               samples avg       0.51      0.50      0.49       237



In [None]:
# Fold 5 report at a 0.5 probability threshold. This cell reads whatever
# `predictions`/`labels` are in memory; the output below was presumably
# captured in the run where fold 5 was selected (the fold is switched by hand).
upper, lower = 1, 0
y_true = labels.numpy()
y_pred = np.where(predictions.numpy() > 0.5, upper, lower)

print(classification_report(y_true, y_pred, target_names=LABEL_COLUMNS, zero_division=0))


                            precision    recall  f1-score   support

          Involved_Support       0.86      0.89      0.88        76
Communication_Facilitation       0.69      0.95      0.80        55
                Counseling       0.59      0.73      0.65        22
      Practical_Assistance       0.54      0.81      0.65        42

                 micro avg       0.70      0.87      0.77       195
                 macro avg       0.67      0.84      0.74       195
              weighted avg       0.71      0.87      0.78       195
               samples avg       0.43      0.45      0.43       195

