## ICD-10 Prediction using Pre-Trained BERT Model
This script trains a [BertForSequenceClassification model](https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext) to predict the ICD-10 code given the text of a cancer pathology report.

In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as torch_dataset
from torch import optim
from transformers import AutoTokenizer, BertForSequenceClassification
import numpy as np
import pandas as pd
import re
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.utils import shuffle

In [2]:
def load_data(fn, Xcol="c.path_notes", ycol='c.icd10_after_spilt'):
    """Read a tab-separated file and return its text and label columns.

    Args:
        fn: Path (or anything ``pandas.read_csv`` accepts) of the TSV file.
        Xcol: Name of the column holding the input values.
        ycol: Name of the column holding the labels.

    Returns:
        A pair ``(X, labels)``: ``X`` is the 1-D array of ``Xcol`` values and
        ``labels`` is the ``ycol`` values reshaped to a single-column 2-D
        array, the layout ``OneHotEncoder`` expects.
    """
    frame = pd.read_csv(fn, sep='\t')
    texts = frame[Xcol].values
    # One row per sample, one column: (n,) -> (n, 1).
    label_column = frame[ycol].values.reshape(-1, 1)
    return texts, label_column

In [3]:
def process(text):
    """Normalize a raw report string before it is handed to the tokenizer.

    The substitutions run in a fixed order (later steps rely on earlier ones):

    1. lowercase the text;
    2. strip the periods from ``dr.``, ``m.d.``, ``a.m.`` and ``p.m.``;
    3. replace decimal numbers with the placeholder ``floattoken``;
    4. collapse runs of two or more periods into one;
    5. replace every run of characters that are not word characters,
       ``|``, ``.``, ``?`` or ``!`` with a single space;
    6. put a space before each ``.`` and around each ``?`` / ``!``;
    7. delete runs of three or more digits.

    All patterns are raw strings: sequences such as ``\\.`` and ``\\d`` are
    invalid escapes in ordinary string literals and trigger a
    ``SyntaxWarning`` on recent Python versions.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string.
    """
    text = text.lower()

    # Abbreviations first, so their periods are not later split off as
    # stand-alone "." tokens.
    text = re.sub(r"dr\.", 'dr', text)
    text = re.sub(r'm\.d\.', 'md', text)
    text = re.sub(r'a\.m\.', 'am', text)
    text = re.sub(r'p\.m\.', 'pm', text)

    # Decimal numbers must be replaced before periods are touched.
    text = re.sub(r"\d+\.\d+", 'floattoken', text)
    text = re.sub(r"\.{2,}", '.', text)

    # NOTE(review): inside a character class "|" is a literal, not
    # alternation, so pipe characters survive this step. Kept as-is to
    # preserve the existing preprocessing output.
    text = re.sub(r'[^\w_|\.|\?|!]+', ' ', text)

    # Detach sentence punctuation from the neighbouring words.
    text = re.sub(r'\.', ' .', text)
    text = re.sub(r'\?', ' ? ', text)
    text = re.sub(r'!', ' ! ', text)

    # Drop long digit runs entirely.
    text = re.sub(r'\d{3,}', '', text)
    return text

In [4]:
def encode(data, tokenizer, max_length=100):
    """Tokenize every text in ``data`` into fixed-length id and mask tensors.

    Each text is first normalized with ``process`` and then passed to
    ``tokenizer.encode_plus`` with special tokens enabled, padding to
    ``max_length`` and truncation turned on.

    Args:
        data: Iterable of raw text strings.
        tokenizer: Object exposing a HuggingFace-style ``encode_plus`` method.
        max_length: Length every sequence is padded / truncated to.

    Returns:
        A pair ``(input_ids, attention_mask)`` of ``torch.long`` tensors, one
        row per text (presumably ``max_length`` columns each, given the
        padding and truncation settings).
    """
    encodings = [
        tokenizer.encode_plus(
            process(text),
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
        )
        for text in data
    ]

    id_rows = [enc['input_ids'] for enc in encodings]
    mask_rows = [enc['attention_mask'] for enc in encodings]

    ids_tensor = torch.tensor(id_rows, dtype=torch.long)
    mask_tensor = torch.tensor(mask_rows, dtype=torch.long)
    return ids_tensor, mask_tensor

In [5]:
class MyDataset(torch_dataset):
    """Map-style dataset over three parallel, row-indexable containers.

    Attributes are stored exactly as given; no copying or type conversion is
    performed here.
    """

    def __init__(self, ii, am, y):
        # ii: input ids, am: attention masks, y: labels -- one row per sample.
        self.ii, self.am, self.y = ii, am, y

    def __len__(self):
        """Return the number of samples (size of ``ii``'s first dimension)."""
        n_samples = self.ii.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask, label)`` triple at ``idx``."""
        return self.ii[idx], self.am[idx], self.y[idx]

In [6]:
def train(model, optimizer, dataloader, epoch, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: Callable as ``model(ii, am, labels=y)`` returning an object
            with a ``loss`` attribute.
        optimizer: Optimizer stepping the model's parameters.
        dataloader: Iterable of ``(ii, am, y)`` batches; must support ``len``
            and expose ``.dataset``.
        epoch: Epoch number, used only in the progress message.
        device: Device every batch tensor is moved to.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.train()  # enable training-mode behaviour
    running_loss = 0

    for step, batch in enumerate(dataloader):
        ii, am, y = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()  # reset gradients left over from the last step
        batch_loss = model(ii, am, labels=y).loss
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        # Progress report on every 60th batch only.
        if step % 60 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * ii.size(0), len(dataloader.dataset),
                100. * step / len(dataloader), batch_loss.item()))

    return running_loss / len(dataloader)

In [7]:
def score(model, dataloader, device):
    """Compute the micro-averaged F1 score of ``model`` over ``dataloader``.

    Predictions and targets from all batches are collected and scored in a
    single ``f1_score`` call, so every sample carries equal weight (averaging
    per-batch scores would over-weight a short final batch).

    Args:
        model: Callable as ``model(ii, am)`` returning an object with a
            ``logits`` attribute of per-class scores.
        dataloader: Iterable of ``(ii, am, y)`` batches. ``y`` is reduced with
            an argmax over dimension 1, i.e. it is expected to hold one
            score/indicator per class for each sample.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        The micro-averaged F1 score, or ``nan`` if the dataloader yields no
        batches.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    with torch.no_grad():
        for ii, am, y in dataloader:
            ii, am = ii.to(device), am.to(device)
            logits = model(ii, am).logits  # forward pass
            # Take the argmax in torch and move the result to the CPU:
            # a tensor living on a CUDA device cannot be handed to NumPy.
            pred_chunks.append(logits.argmax(dim=1).cpu())
            true_chunks.append(torch.as_tensor(y).argmax(dim=1).cpu())

    if not pred_chunks:
        # Nothing to score; mirrors the mean of an empty list.
        return np.nan

    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(true_chunks).numpy()
    return f1_score(y_true, y_pred, average='micro')

In [8]:
def main(epochs=10, max_datapoints=None, max_seq_length=100, shuffle=True):
    """Load the data, build the model and run the train / validate loop.

    Steps: read ``icd10sitesonly.txt``, optionally shuffle it, one-hot encode
    the labels, tokenize the texts, split into train / validation / test
    sets, then train a ``BertForSequenceClassification`` model for ``epochs``
    epochs, printing the training loss and validation score after each one.

    Args:
        epochs: Number of passes over the training set.
        max_datapoints: If not ``None``, keep only this many leading samples
            (taken after shuffling and after the label encoder is fitted).
        max_seq_length: Token length every text is padded / truncated to.
        shuffle: Whether to shuffle the samples before anything else.
    """
    # The ``shuffle`` flag shadows the file-level ``sklearn.utils.shuffle``
    # import inside this function, so calling ``shuffle(...)`` would call a
    # bool. Import the function under a distinct local name instead.
    from sklearn.utils import shuffle as shuffle_arrays

    X, labels = load_data("icd10sitesonly.txt")

    if shuffle:
        X, labels = shuffle_arrays(X, labels)

    # Use the encoder's default (sparse) output and densify it explicitly:
    # the ``sparse=`` keyword was renamed in newer scikit-learn releases,
    # whereas this form works on old and new versions alike.
    ohe = OneHotEncoder()
    y = torch.tensor(ohe.fit_transform(labels).toarray(), dtype=torch.float32)

    if max_datapoints is not None:
        X = X[:max_datapoints]
        y = y[:max_datapoints]

    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")

    input_ids, attn_mask = encode(X, tokenizer, max_length=max_seq_length)

    # 10% held out for testing, then 20% of the remainder for validation.
    ii_train, ii_test, am_train, am_test, y_train, y_test = train_test_split(input_ids, attn_mask, y, test_size=0.1)
    ii_train, ii_val, am_train, am_val, y_train, y_val = train_test_split(ii_train, am_train, y_train, test_size=0.2)

    train_set = MyDataset(ii_train, am_train, y_train)
    val_set = MyDataset(ii_val, am_val, y_val)
    test_set = MyDataset(ii_test, am_test, y_test)

    dataloader_train = DataLoader(train_set, batch_size=64, shuffle=True)
    dataloader_val = DataLoader(val_set, batch_size=64, shuffle=False)
    dataloader_test = DataLoader(test_set, batch_size=64, shuffle=False)

    # One output unit per one-hot column, i.e. per distinct label seen.
    model = BertForSequenceClassification.from_pretrained(
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        num_labels=y.shape[1]
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # NOTE(review): lr=0.01 is far larger than the rates usually used when
    # fine-tuning BERT-style models -- confirm this is intentional.
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    # The scheduler is stepped once per epoch below, so with step_size=70
    # the learning rate is only halved after 70 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=70, gamma=0.5)

    loss_train_list=[]
    acc_val_list=[]
    for epoch in range(epochs):
        #print("Current learning rate:", scheduler.get_last_lr())
        loss_train = train(model, optimizer, dataloader_train, epoch, device)
        loss_train_list.append(loss_train)
        print('epoch', epoch, 'training loss:', loss_train)
        acc = score(model, dataloader_val, device)
        acc_val_list.append(acc)
        print('epoch', epoch, 'validation accuracy:', acc)
        scheduler.step()