# Run clinicalBERT with pytorch
- For understanding, this note simply converts [clinicalBERT](https://github.com/EmilyAlsentzer/clinicalBERT) to pytorch 
- clinicalBERT huggingface model can be downloaded [here](https://huggingface.co/emilyalsentzer/Bio_ClinicalBERT).
- Run for 2014 i2b2 de-identification challenge.

## About preprocessed clinical notes
- In this code, preprocessed clinical notes are used.
- The **preprocessed clinical notes** are as follows:
        Identification  O
        of  O
        APC2    O
        ,   O
        a   O
        homologue   O
        of  O
        the O
        adenomatous B-Disease
        polyposis   I-Disease
        coli    I-Disease
        tumour  I-Disease
        suppressor  O
        .   O

        The O
        adenomatous B-Disease
        polyposis   I-Disease
- Already tokenized by space with ground-truth labels(0, B-Disease etc) and each sentence is seperated by blank.
- This process can be done with original clinicalBERT repo `clinicalBERT/downstream_tasks/i2b2_preprocessing/i2b2_2014_deid_hf_risk/`

In [3]:
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from seqeval.metrics import f1_score, accuracy_score
import numpy as np
import itertools

In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

### Get label list
- This label is for the 2014 challenge.

In [4]:
label_list = ["B-IDNUM", "I-IDNUM", "B-HOSPITAL", "I-HOSPITAL", 'B-PATIENT', 'I-PATIENT', 'B-PHONE', 'I-PHONE',
        'B-DATE', 'I-DATE', 'B-DOCTOR', 'I-DOCTOR', 'B-LOCATION-OTHER', 'I-LOCATION-OTHER', 'B-AGE', 'I-AGE', 'B-BIOID', 'I-BIOID',
        'B-STATE', 'I-STATE','B-ZIP', 'I-ZIP', 'B-HEALTHPLAN', 'I-HEALTHPLAN', 'B-ORGANIZATION', 'I-ORGANIZATION',
        'B-MEDICALRECORD', 'I-MEDICALRECORD', 'B-CITY', 'I-CITY', 'B-STREET', 'I-STREET', 'B-COUNTRY', 'I-COUNTRY',
        'B-URL', 'I-URL',
        'B-USERNAME', 'I-USERNAME', 'B-PROFESSION', 'I-PROFESSION', 'B-FAX', 'I-FAX', 'B-EMAIL', 'I-EMAIL', 'B-DEVICE', 'I-DEVICE',
        'O', "X", "[CLS]", "[SEP]"]
label_list.append("PAD")
tag2idx = {t: i for i, t in enumerate(list(set(label_list)))}

### Create data iterator 
- `read_data` converts preprocessed clinical notes to `[[token1, token2, token3], [label1, label2, label3]]` format.
- This function is from the original clinicalBERT.

In [1]:
def read_data(input_file):

        with open(input_file) as f:
            lines = []
            words = []
            labels = []
            for line in f:
                line = line.strip()
                if len(line) == 0: #i.e. we're in between sentences
                    assert len(words) == len(labels)
                    if len(words) == 0:
                        continue
                    lines.append([words, labels])
                    words = []
                    labels = []
                    continue

                word = line.split()[0]
                label = line.split()[-1]
                words.append(word)
                labels.append(label)

            return lines

In [6]:
train = 'preprocessed/train.tsv'
valid = 'preprocessed/dev.tsv'

In [7]:
tr_tokenized_text_and_labels = read_data(train)
val_tokenized_text_and_labels = read_data(valid)

In [8]:
print(tr_tokenized_text_and_labels[:1])

[[['Record', 'date:', '2077-05-25'], ['O', 'O', 'B-DATE']]]


In [9]:
tr_sentences = [token_label_pair[0] for token_label_pair in tr_tokenized_text_and_labels]
tr_labels = [token_label_pair[1] for token_label_pair in tr_tokenized_text_and_labels]
val_sentences = [token_label_pair[0] for token_label_pair in val_tokenized_text_and_labels]
val_labels = [token_label_pair[1] for token_label_pair in val_tokenized_text_and_labels]

In [10]:
print(tr_sentences[:1])
print(tr_labels[:1])

[['Record', 'date:', '2077-05-25']]
[['O', 'O', 'B-DATE']]


In [11]:
tr_sentences = list(itertools.chain(*tr_sentences))
tr_labels = list(itertools.chain(*tr_labels))
val_sentences = list(itertools.chain(*val_sentences))
val_labels = list(itertools.chain(*val_labels)) # flatten list

In [12]:
print(tr_sentences[:1])
print(tr_labels[:1])

['Record']
['O']


In [13]:
def tokenize_and_preserve_labels(sentence, text_labels):

    tokenized_sentence = []
    labels = []

    tokenized_word = tokenizer.tokenize(sentence)
    n_subwords = len(tokenized_word)

    tokenized_sentence.extend(tokenized_word)

    labels.extend([text_labels] * n_subwords)

    return tokenized_sentence, labels

In [14]:
tr_tokenized_texts_and_labels = [
    tokenize_and_preserve_labels(sent, labs)
    for sent, labs in zip(tr_sentences, tr_labels)
]
val_tokenized_texts_and_labels = [
    tokenize_and_preserve_labels(sent, labs)
    for sent, labs in zip(val_sentences, val_labels)
]

In [15]:
tr_tokenized_texts_and_labels[:5]

[(['record'], ['O']),
 (['date', ':'], ['O', 'O']),
 (['207', '##7', '-', '05', '-', '25'],
  ['B-DATE', 'B-DATE', 'B-DATE', 'B-DATE', 'B-DATE', 'B-DATE']),
 (['internal'], ['O']),
 (['medicine'], ['O'])]

In [16]:
tr_tokenized_texts = [token_label_pair[0] for token_label_pair in tr_tokenized_texts_and_labels]
tr_labels = [token_label_pair[1] for token_label_pair in tr_tokenized_texts_and_labels]
val_tokenized_texts = [token_label_pair[0] for token_label_pair in val_tokenized_texts_and_labels]
val_labels = [token_label_pair[1] for token_label_pair in val_tokenized_texts_and_labels]

In [17]:
print(tr_tokenized_texts[:3])
print(tr_labels[:3])

[['record'], ['date', ':'], ['207', '##7', '-', '05', '-', '25']]
[['O'], ['O', 'O'], ['B-DATE', 'B-DATE', 'B-DATE', 'B-DATE', 'B-DATE', 'B-DATE']]


In [18]:
from keras.preprocessing.sequence import pad_sequences

max_len = 128
bs = 16

tr_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tr_tokenized_texts],
                          maxlen=max_len, dtype="long", value=0.0,
                          truncating="post", padding="post")
tr_tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in tr_labels],
                     maxlen=max_len, value=tag2idx["PAD"], padding="post",
                     dtype="long", truncating="post")
tr_masks = [[float(i != 0.0) for i in ii] for ii in tr_inputs]

val_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in val_tokenized_texts],
                          maxlen=max_len, dtype="long", value=0.0,
                          truncating="post", padding="post")
val_tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in val_labels],
                     maxlen=max_len, value=tag2idx["PAD"], padding="post",
                     dtype="long", truncating="post")
val_masks = [[float(i != 0.0) for i in ii] for ii in val_inputs]


In [19]:
tr_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tr_tokenized_texts],
                          maxlen=max_len, dtype="long", value=0.0,
                          truncating="post", padding="post")
tr_tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in tr_labels],
                     maxlen=max_len, value=tag2idx["PAD"], padding="post",
                     dtype="long", truncating="post")

In [20]:
tr_inputs = torch.tensor(tr_inputs)
val_inputs = torch.tensor(val_inputs)
tr_tags = torch.tensor(tr_tags)
val_tags = torch.tensor(val_tags)
tr_masks = torch.tensor(tr_masks)
val_masks = torch.tensor(val_masks)

In [21]:
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [22]:
len(train_dataloader)

27433

### Load model
- There should be an label information(`num_labels`) when loading the model to run appropriately.
- Because the clinicalBERT contains various types of pre-trained parameters for i2b2 challenge.

In [23]:
model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", 
                                                        num_labels=len(tag2idx),
                                                        output_attentions = False,
                                                        output_hidden_states = False)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint 

## For fine-tuning the model

In [None]:
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=5e-5,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0 # weight_decay & wight_decay_rate
)


In [25]:
epochs = 4
max_grad_norm = 1.0

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tag_values = list(set(label_list))
use_gpu = torch.cuda.is_available()
print(use_gpu)

True


In [27]:
 CUDA_LAUNCH_BLOCKING=1

In [28]:
model.cuda()

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

### Training

In [30]:
from tqdm import tqdm, trange

loss_values, validation_loss_values = [], [] ## Store the average loss if you want to plot them.
for _ in trange(epochs, desc="Epoch"):
    # ========================================
    #               Training
    # ========================================
    model.train()
    total_loss = 0 # Reset the total loss for this epoch.

    for step, batch in enumerate(train_dataloader):

        batch = tuple(t.to(device) for t in batch) # add batch to gpu
        b_input_ids, b_input_mask, b_labels = batch
        model.zero_grad() # Always clear any previously calculated gradients before performing a backward pass.
        
        outputs = model(input_ids = b_input_ids, token_type_ids=None,
                        attention_mask=b_input_mask, labels=b_labels)
        
        loss = outputs[0]
        loss.backward()
        
        total_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        
        optimizer.step() # update parameters
        scheduler.step() # Update the learning rate.
    
    print("started")

    avg_train_loss = total_loss / len(train_dataloader) # Calculate the average loss over the training data.
    print("Average train loss: {}".format(avg_train_loss))

    loss_values.append(avg_train_loss) # Store the loss value for plotting the learning curve.
    
    print("pass training")


    # ========================================
    #               Validation
    # ========================================
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0 # Reset the validation loss for this epoch.
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():
            # Calculate logit predictions.
            # This will return the logits rather than the loss because we have not provided labels.
            outputs = model(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
        
        logits = outputs[1].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy() # Move logits and labels to CPU

        eval_loss += outputs[0].mean().item()
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.extend(label_ids)
        
    print("pass validation")
    
    eval_loss = eval_loss / len(valid_dataloader)
    validation_loss_values.append(eval_loss)
    print("Validation loss: {}".format(eval_loss))
    pred_tags = [tag_values[p_i] for p, l in zip(predictions, true_labels)
                                  for p_i, l_i in zip(p, l) if tag_values[l_i] != "PAD"]
    valid_tags = [tag_values[l_i] for l in true_labels
                                   for l_i in l if tag_values[l_i] != "PAD"] 
    print("Validation Accuracy: {}".format(accuracy_score(pred_tags, valid_tags)))
    print("Validation F1-Score: {}".format(f1_score(pred_tags, valid_tags)))
    print()

'''
Average train loss: 0.13390333881295455
Validation loss: 0.10936674067364055
Validation Accuracy: 0.977215821504629
'''

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

started
Average train loss: 0.09026757150200894
pass training
pass validation
Validation loss: 0.11875062298172881
Validation Accuracy: 0.9747526091457671


Epoch:  25%|██▌       | 1/4 [1:06:15<3:18:45, 3975.10s/it]

Validation F1-Score: 0.73874440166575

started
Average train loss: 0.07032495812503851
pass training
pass validation
Validation loss: 0.12839929612759995
Validation Accuracy: 0.9751723108777363


Epoch:  50%|█████     | 2/4 [2:12:19<2:12:23, 3971.92s/it]

Validation F1-Score: 0.736087025067003

started
Average train loss: 0.05619299204360645
pass training
pass validation
Validation loss: 0.1280268517501515
Validation Accuracy: 0.9756293194303248


Epoch:  75%|███████▌  | 3/4 [3:18:25<1:06:10, 3970.19s/it]

Validation F1-Score: 0.7420736932305055

started
Average train loss: 0.04911095145447775
pass training
pass validation
Validation loss: 0.1280268517501515
Validation Accuracy: 0.9756293194303248


Epoch: 100%|██████████| 4/4 [4:24:27<00:00, 3966.81s/it]  

Validation F1-Score: 0.7420736932305055






'\nAverage train loss: 0.13390333881295455\nValidation loss: 0.10936674067364055\nValidation Accuracy: 0.977215821504629\n'

In [31]:
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

print("Classification report: \n", (classification_report(valid_tags, pred_tags)))

  _warn_prf(average, modifier, msg_start, len(result))


Classification report: 
                  precision    recall  f1-score   support

          B-AGE       0.58      0.66      0.62       148
         B-CITY       0.30      0.29      0.29       135
      B-COUNTRY       1.00      0.05      0.10        37
         B-DATE       0.98      0.98      0.98      3559
       B-DEVICE       0.00      0.00      0.00        25
       B-DOCTOR       0.39      0.48      0.43       725
          B-FAX       0.00      0.00      0.00        16
     B-HOSPITAL       0.62      0.46      0.52       349
        B-IDNUM       0.83      0.53      0.65       241
B-MEDICALRECORD       0.72      0.94      0.81       257
 B-ORGANIZATION       0.00      0.00      0.00        65
      B-PATIENT       0.44      0.28      0.34       372
        B-PHONE       0.75      0.84      0.79       204
   B-PROFESSION       0.14      0.07      0.10        41
        B-STATE       0.50      0.53      0.51        36
       B-STREET       0.39      0.35      0.37        20
     