# AFF Disease Prediction Using BEHRT for MLM

This notebook pre-trains BEHRT on patients' coded visit sequences using a Masked Language Modeling (MLM) objective; the resulting model is intended as the starting point for AFF disease prediction.

In [None]:

import sys
sys.path.insert(0, '../')

from common.common import create_folder
from common.pytorch import load_model
from dataLoader.MLM import MLMLoader
import pytorch_pretrained_bert as Bert
from model.utils import age_vocab
from model.MLM import BertForMaskedLM
from model.optimiser import adam

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import os
import time
import pickle
import sklearn.metrics as skm
from sklearn.metrics import roc_auc_score, average_precision_score


## Step 1: Define Model and Training Configurations

In [None]:

# Define BERT Configuration Class
class BertConfig(Bert.modeling.BertConfig):
    """BERT configuration extended with BEHRT's segment and age vocabulary sizes.

    ``config`` is a plain dict of hyper-parameters. ``hidden_size`` is
    mandatory (a missing key raises ``KeyError``); every other key is read
    with ``dict.get`` and is forwarded as ``None`` when absent.
    """

    def __init__(self, config):
        # (keyword expected by the parent constructor, key looked up in ``config``)
        optional_keys = (
            ('vocab_size_or_config_json_file', 'vocab_size'),
            ('num_hidden_layers', 'num_hidden_layers'),
            ('num_attention_heads', 'num_attention_heads'),
            ('intermediate_size', 'intermediate_size'),
            ('hidden_act', 'hidden_act'),
            ('hidden_dropout_prob', 'hidden_dropout_prob'),
            ('attention_probs_dropout_prob', 'attention_probs_dropout_prob'),
            # note: the config key is singular ('max_position_embedding')
            ('max_position_embeddings', 'max_position_embedding'),
            ('initializer_range', 'initializer_range'),
        )
        parent_kwargs = {param: config.get(key) for param, key in optional_keys}
        parent_kwargs['hidden_size'] = config['hidden_size']
        super().__init__(**parent_kwargs)

        # BEHRT-specific sizes the parent configuration does not know about.
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')

# Training Configurations
global_params = {
    'max_seq_len': 64,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
    # patients with fewer 'SEP' tokens than this are dropped when loading data
    'min_visit': 3,
    'gradient_accumulation_steps': 1,
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01,
}

# Pick the device at runtime: a hard-coded 'cuda' makes every `.to(device)`
# call fail on machines without a GPU, so fall back to the CPU there.
train_params = {
    'batch_size': 256,
    'use_cuda': torch.cuda.is_available(),
    'max_len_seq': global_params['max_seq_len'],
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}


## Step 2: Load Preprocessed Data and Vocabulary

In [None]:

# Load preprocessed data.
# NOTE(review): pickle.load can execute arbitrary code -- only open files
# produced by a trusted preprocessing step.
with open('./T20_BFC_BEHRT_group_data_sickFinal_mlm_op1.pkl', 'rb') as f:
    group_data_sickFinal_mlm = pd.DataFrame(pickle.load(f))
# Work on the DataFrame rather than the raw unpickled object: the filtering
# below needs pandas column access and `.apply`, which e.g. a plain dict lacks.
data = group_data_sickFinal_mlm

# Load vocabulary
with open('./vocab2_new.pkl', 'rb') as f:
    vocab2 = pickle.load(f)

# Create word dictionary: special tokens take ids 0-4 and vocabulary tokens
# are appended after them. Assigning `len(word_dict)` (instead of a fixed
# offset) keeps ids unique -- no token shares an id with 'UNK' -- and
# contiguous; a token that already has an id (e.g. a special token repeated
# in vocab2) keeps it, so every id stays below len(word_dict), which is the
# vocab_size later given to the model.
word_dict = {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3, 'UNK': 4}
for w in vocab2:
    if w not in word_dict:
        word_dict[w] = len(word_dict)
BertVocab = word_dict

# Create age vocabulary
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

# Filter patients with sufficient visits; visits are counted as the number of
# 'SEP' tokens in each patient's 'd2' code sequence.
data['length'] = data['d2'].apply(lambda codes: sum(1 for tok in codes if tok == 'SEP'))
data = data[data['length'] >= global_params['min_visit']].reset_index(drop=True)


## Step 3: Prepare DataLoader for Training

In [None]:

# Wrap the filtered patient table in the project's MLM dataset, then batch it.
Dset = MLMLoader(
    data,
    BertVocab,
    ageVocab,
    max_len=train_params['max_len_seq'],
    code='d2',
    age='AGE2',
)
trainload = torch.utils.data.DataLoader(
    dataset=Dset,
    batch_size=train_params['batch_size'],
    shuffle=True,  # reshuffled at the start of every epoch
    num_workers=3,
)


## Step 4: Initialize BEHRT Model

In [None]:

# Define BEHRT Model Configuration.
# Vocabulary sizes are derived from the vocabularies built earlier and the
# position table is sized to the maximum sequence length. hidden_size must be
# divisible by num_attention_heads (288 / 12 = 24 per head).
model_config = dict(
    vocab_size=len(BertVocab),
    hidden_size=288,
    seg_vocab_size=2,
    age_vocab_size=len(ageVocab.keys()),
    max_position_embedding=train_params['max_len_seq'],
    hidden_dropout_prob=0.1,
    num_hidden_layers=6,
    num_attention_heads=12,
    attention_probs_dropout_prob=0.1,
    intermediate_size=512,
    hidden_act='gelu',
    initializer_range=0.02,
)

conf = BertConfig(model_config)
model = BertForMaskedLM(conf).to(train_params['device'])

# Define Optimizer; the project's `adam` helper is given (name, parameter)
# pairs rather than bare parameters.
optim = adam(params=list(model.named_parameters()), config=optim_param)


## Step 5: Define the Evaluation Metric and Training Function

In [None]:

def cal_acc(label, pred):
    """Score one batch of masked-token predictions.

    Only positions whose label is not -1 are scored; -1 marks positions that
    are excluded (presumably the unmasked tokens).

    Args:
        label: tensor of target token ids, indexed in step with the rows of
            ``pred``.
        pred: tensor of prediction scores with one row per position and one
            column per vocabulary entry.

    Returns:
        ``(precision, recall, f1)``, micro-averaged over the scored positions.
        For single-label multiclass predictions all three micro averages are
        equal to accuracy, so accuracy is computed once and returned three
        times. Returns ``(0.0, 0.0, 0.0)`` when no position is scored.
    """
    label = label.detach().cpu().numpy()
    ind = np.where(label != -1)[0]
    if ind.size == 0:
        return 0.0, 0.0, 0.0
    truelabel = label[ind]
    # Log-softmax only subtracts a per-row constant, so argmax over the raw
    # scores picks the same class; take one vectorized argmax over all rows.
    outs = pred.detach().cpu().numpy()[ind].argmax(axis=-1)
    acc = float(np.mean(outs == truelabel))
    return acc, acc, acc

def train(e, loader, log_every=200):
    """Run one MLM training epoch over ``loader``.

    Uses the module-level ``model``, ``optim``, ``train_params`` and
    ``global_params``.

    Args:
        e: epoch index, used only in log messages.
        loader: iterable of batches; each batch is a sequence of tensors that
            unpacks to (age_ids, input_ids, posi_ids, segment_ids, attMask,
            masked_label).
        log_every: print the running loss and masked-token metrics every
            ``log_every`` steps (starting at step 0).

    Returns:
        ``(mean per-batch loss over the epoch, elapsed wall-clock seconds)``.
        The loss reported is the model's unscaled loss, independent of the
        gradient-accumulation setting.
    """
    model.train()
    accum_steps = global_params['gradient_accumulation_steps']
    device = train_params['device']
    tr_loss = 0.0
    temp_loss = 0.0
    temp_steps = 0  # steps accumulated into temp_loss since the last log line
    n_steps = 0
    start = time.time()

    for step, batch in enumerate(loader):
        batch = tuple(t.to(device) for t in batch)
        age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch
        loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, masked_lm_labels=masked_label)

        # Record the unscaled loss before dividing it for gradient accumulation.
        batch_loss = loss.item()
        tr_loss += batch_loss
        temp_loss += batch_loss
        temp_steps += 1
        n_steps += 1

        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()

        if step % log_every == 0:
            precision, recall, f1 = cal_acc(label, pred)
            # Average over the steps actually accumulated (1 at step 0).
            print(f"Epoch: {e}, Step: {step}, Loss: {temp_loss / temp_steps:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
            temp_loss = 0.0
            temp_steps = 0

        if (step + 1) % accum_steps == 0:
            optim.step()
            optim.zero_grad()

    # Apply gradients left over from a final, incomplete accumulation group so
    # they do not leak into the next epoch.
    if n_steps % accum_steps != 0:
        optim.step()
        optim.zero_grad()

    # max(..., 1) avoids division by zero for an empty loader.
    return tr_loss / max(n_steps, 1), time.time() - start


## Step 6: Train the Model

In [None]:

# Train the model
file_config = {
    'model_path': './',
    'model_name': 'T20_BFC_MLM_3_op1_v1',
    'file_name': 'MLM_log_3_op1'
}
num_epochs = 30

create_folder(file_config['model_path'])

log_path = os.path.join(file_config['model_path'], file_config['file_name'])
model_file = os.path.join(file_config['model_path'], file_config['model_name'])
best_loss = float('inf')

# `with` guarantees the log is closed even if training raises; flushing after
# each epoch makes progress visible while training runs.
with open(log_path, "w") as f:
    f.write('{}\t{}\t{}\n'.format('epoch', 'loss', 'time'))
    for e in range(num_epochs):
        loss, time_cost = train(e, trainload)
        f.write('{}\t{:.4f}\t{:.2f}\n'.format(e, loss, time_cost))
        f.flush()
        # Checkpoint the weights of the lowest-loss epoch to model_name so the
        # pre-trained model is kept after training.
        if loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), model_file)
