# AFF Disease Prediction Using BEHRT

This notebook implements AFF disease prediction using BEHRT.

In [None]:

import sys 
sys.path.insert(0, '../') 
from torch.utils.data import DataLoader
import pandas as pd
from common.common import create_folder
import numpy as np
import os
import torch
import torch.nn as nn
import pytorch_pretrained_bert as Bert
from model import optimiser
import sklearn.metrics as skm
from sklearn.metrics import roc_auc_score, auc, f1_score, average_precision_score, confusion_matrix, roc_curve
import matplotlib.pyplot as plt
import seaborn as sns
import random
from dataLoader.NextXVisit2 import NextVisit
from model.NextXVisit2 import BertForMultiLabelPrediction
from common.common import load_obj
from model.utils import age_vocab
from sklearn.preprocessing import MultiLabelBinarizer
import warnings
warnings.filterwarnings(action='ignore')


In [None]:

# Read the pickled token vocabulary from the working directory.
import pickle
with open('./vocab2_new.pkl', 'rb') as vocab_file:
    vocab2_new = pickle.load(vocab_file)

# Logical names of the vocabulary and data splits.
# (Not referenced by the visible cells of this notebook.)
file_config = {
    'vocab': 'vocab2_new',
    'train': 'train',
    'test': 'test',
}

# Settings handed to optimiser.adam when the optimiser is built.
optim_config = {
    'lr': 1.857576788683168e-05,
    'warmup_proportion': 0.14677794038068517,
    'weight_decay': 0.005229456209250363
}

# Run-wide settings shared by the data loaders, the model and the training loop.
global_params = {
    'batch_size': 256,                 # batch size of every DataLoader
    'gradient_accumulation_steps': 1,  # batches accumulated per optimiser step
    'device': 'cuda',                  # device the model and batches are moved to
    'output_dir':'../ndp',             # folder the best checkpoint is written to
    'best_name': 'ndp_best_j_0',       # file name of the best checkpoint
    'max_len_seq': 64,                 # sequence length (also the position-embedding size)
    'max_age': 110,                    # passed to age_vocab
    'age_year': False,                 # not referenced by the visible cells
    'age_symbol': None,                # passed to age_vocab as `symbol`
    'min_visit': 3                     # not referenced by the visible cells
}

# Checkpoint used to initialise the fine-tuning model.
pretrain_model_path = './T20_BFC_MLM_3_op1_v1'

# Token -> id map: the reserved special tokens first, then the loaded vocabulary.
# NOTE(review): vocabulary ids start at 4, which is also the id of 'UNK', so the
# first vocabulary entry shares its id with 'UNK'. Left unchanged on purpose:
# the pretrained checkpoint was presumably built with this exact mapping, and
# renumbering would misalign its embeddings. Confirm before changing.
word_dict = {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3, 'UNK' :4}
for token_id, token in enumerate(vocab2_new, start=4):
    word_dict[token] = token_id
vocab_size = len(word_dict)

BertVocab = word_dict
# age_vocab returns two values; only the first one is kept as the age vocabulary.
ageVocab, _ = age_vocab(symbol=global_params['age_symbol'], max_age=global_params['max_age'])


In [None]:

# Hyper-parameters consumed by BertConfig when the network is built.
model_config = {
    'vocab_size': len(BertVocab),  # one entry per token in the word vocabulary
    'hidden_size': 288,
    'seg_vocab_size': 2,
    'age_vocab_size': len(ageVocab),  # one entry per key of the age vocabulary
    'max_position_embedding': global_params['max_len_seq'],  # tied to the sequence length
    'hidden_dropout_prob': 0.1,
    'num_hidden_layers': 10,
    'num_attention_heads': 12,
    'attention_probs_dropout_prob': 0.1,
    'intermediate_size': 512,
    'hidden_act': 'gelu',
    'initializer_range': 0.04,
}

# Feature switches passed to BertForMultiLabelPrediction as `feature_dict`.
feature_dict = {
    'word': True,
    'seg': True,
    'age': True,
    'position': True
}

# Adapter from the plain `model_config` dict to the library's BertConfig.
class BertConfig(Bert.modeling.BertConfig):
    """BertConfig built from a dict, extended with segment and age vocabulary sizes."""

    def __init__(self, config):
        """Populate the base configuration from *config*.

        `hidden_size` is mandatory (a missing key raises KeyError); every other
        entry is read with `dict.get`, so a missing key is forwarded as None.
        """
        base_kwargs = {
            'vocab_size_or_config_json_file': config.get('vocab_size'),
            'hidden_size': config['hidden_size'],
            'num_hidden_layers': config.get('num_hidden_layers'),
            'num_attention_heads': config.get('num_attention_heads'),
            'intermediate_size': config.get('intermediate_size'),
            'hidden_act': config.get('hidden_act'),
            'hidden_dropout_prob': config.get('hidden_dropout_prob'),
            'attention_probs_dropout_prob': config.get('attention_probs_dropout_prob'),
            'max_position_embeddings': config.get('max_position_embedding'),
            'initializer_range': config.get('initializer_range'),
        }
        super().__init__(**base_kwargs)

        # Extra fields that the base class does not know about.
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')


In [None]:

def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Train / validation / test splits of the task data.
train = _load_pickle('../task/d138_5y_onset_Train2_jan2024_aug06_sep26_F.pkl')
valid = _load_pickle('../task/d138_5y_onset_Valid2_jan2024_aug06_sep26_F.pkl')
test = _load_pickle('../task/d138_5y_onset_Test2_jan2024_aug06_sep26_F.pkl')

# Give every split a fresh 0..n-1 index.
for split_frame in (train, valid, test):
    split_frame.index = range(len(split_frame))


In [None]:

# Build one NextVisit dataset and DataLoader per split.
# All three splits share the same vocabularies, sequence length and column
# names; only the source dataframe and the shuffling differ.
# NOTE: the validation dataset previously used `labelVocab`, a name that is
# never defined in this notebook (NameError). It now uses BertVocab like the
# train and test datasets. NOTE(review): confirm BertVocab is the intended
# label vocabulary for the 'd138' label column.
_dataset_kwargs = dict(
    token2idx=BertVocab,
    label2idx=BertVocab,
    age2idx=ageVocab,
    max_len=global_params['max_len_seq'],
    code='disease_sequenceF',
    age='age2_sequenceF',
    label='d138',
)

# Training split: shuffled every epoch.
Dset = NextVisit(dataframe=train, **_dataset_kwargs)
trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)

# Test split: fixed order.
Dset2 = NextVisit(dataframe=test, **_dataset_kwargs)
testload = DataLoader(dataset=Dset2, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)

# Validation split: fixed order as well. The metrics are computed over the whole
# split, so shuffling adds nothing and only makes runs harder to reproduce.
Dset3 = NextVisit(dataframe=valid, **_dataset_kwargs)
validload = DataLoader(dataset=Dset3, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)


In [None]:

# # Model initialization
# conf = BertConfig(model_config)
# model = BertForMultiLabelPrediction(conf, num_labels=1, feature_dict=feature_dict)
# model = model.to(global_params['device'])

# # Optimizer
# optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)


In [None]:
def load_model(path, model):
    """Initialise *model* from the checkpoint stored at *path* and return it.

    Only checkpoint entries whose key also exists in ``model.state_dict()``
    are copied. Checkpoint keys unknown to the model are ignored, and model
    parameters missing from the checkpoint keep their current values.

    Args:
        path: location of a state dict saved with ``torch.save``.
        model: the ``torch.nn.Module`` to update in place.

    Returns:
        The same *model* object, with the matching weights overwritten.

    Raises:
        RuntimeError: propagated from ``load_state_dict`` when a matching key
            has a tensor of a different shape.
    """
    # Deserialise onto the CPU so a checkpoint written on another GPU index
    # still loads; load_state_dict copies the values into the model's own
    # parameters, wherever they live.
    pretrained_dict = torch.load(path, map_location='cpu')
    model_dict = model.state_dict()
    # Keep only the entries the model knows about, then overlay them on the
    # model's current state.
    model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)
    return model

# Build the fine-tuning network, then initialise it from the pretrained checkpoint.
# `model` must exist before load_model is called; the result is bound back to
# `model` (the earlier target name `mode` was a typo and was never read).
conf = BertConfig(model_config)
model = BertForMultiLabelPrediction(conf, num_labels=1, feature_dict=feature_dict)
model = load_model(pretrain_model_path, model)

In [None]:
# Place the model on the configured device, then build the optimiser from its
# (name, parameter) pairs and the optimiser settings.
model = model.to(global_params['device'])
optim = optimiser.adam(config=optim_config, params=[*model.named_parameters()])

In [None]:
import sklearn
def precision(logits, label):
    """Score one batch of logits against its labels.

    The logits are squashed with a sigmoid, detached and moved to the CPU; the
    labels are moved to the CPU as well.

    Returns:
        A tuple ``(average_precision, probabilities, labels)`` where the last
        two are the CPU tensors that were scored.
    """
    probabilities = torch.sigmoid(logits).detach().cpu()
    label = label.cpu()
    score = sklearn.metrics.average_precision_score(
        label.numpy(), probabilities.numpy(), average='samples'
    )
    return score, probabilities, label

def precision_test(logits, label):
    """Compute average precision and ROC AUC for logits against labels.

    Both tensors are converted with ``.numpy()`` directly, so the caller has
    to pass tensors that already live on the CPU and need no gradient.

    Returns:
        A tuple ``(average_precision, roc_auc, probabilities, label)``, where
        ``probabilities`` is the sigmoid of *logits*.
    """
    probabilities = torch.sigmoid(logits)
    y_true = label.numpy()
    y_score = probabilities.numpy()
    average_precision = sklearn.metrics.average_precision_score(y_true, y_score, average='samples')
    roc_auc = sklearn.metrics.roc_auc_score(y_true, y_score, average='samples')
    return average_precision, roc_auc, probabilities, label

In [None]:
#binary class
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

def train(e, log_every=500):
    """Run one training epoch over ``trainload``.

    Uses the notebook-level ``model``, ``optim``, ``trainload`` and
    ``global_params``. Gradients are accumulated over
    ``global_params['gradient_accumulation_steps']`` batches per optimiser step.

    Args:
        e: epoch number, used only in the progress line.
        log_every: print a progress line every this many batches (default 500).

    Returns:
        None.
    """
    device = global_params['device']
    accumulation_steps = global_params['gradient_accumulation_steps']

    model.train()
    running_loss = 0.0
    batches_since_log = 0
    pending_update = False  # True while gradients wait for an optimiser step

    for step, batch in enumerate(trainload):
        age_ids, input_ids, posi_ids, segment_ids, attMask, targets = (
            tensor.to(device) for tensor in batch
        )

        loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)

        # Scale the loss so the accumulated gradient is an average over the window.
        if accumulation_steps > 1:
            loss = loss / accumulation_steps
        loss.backward()
        pending_update = True

        running_loss += loss.item()
        batches_since_log += 1

        if step % log_every == 0:
            prec, _, _ = precision(logits, targets)
            # Average over the batches actually seen since the last report; the
            # very first report covers a single batch, not `log_every` of them.
            print("epoch: {}\t| Cnt: {}\t| Loss: {}\t| precision: {}".format(
                e, step + 1, running_loss / batches_since_log, prec))
            running_loss = 0.0
            batches_since_log = 0

        if (step + 1) % accumulation_steps == 0:
            optim.step()
            optim.zero_grad()
            pending_update = False

    # Apply a trailing, incomplete accumulation window so its gradients do not
    # leak into the first optimiser step of the next epoch.
    if pending_update:
        optim.step()
        optim.zero_grad()

def evaluation():
    """Score the model on ``validload`` and plot the resulting ROC curve.

    Uses the notebook-level ``model``, ``validload`` and ``global_params``.

    Returns:
        A tuple ``(aps, roc, loss_sum)``: average precision and ROC AUC from
        ``precision_test`` over the whole split, and the sum of the per-batch
        losses.
    """
    device = global_params['device']
    model.eval()

    logit_chunks = []
    target_chunks = []
    loss_sum = 0
    for batch in validload:
        age_ids, input_ids, posi_ids, segment_ids, attMask, targets = (
            tensor.to(device) for tensor in batch
        )

        with torch.no_grad():
            loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)

        loss_sum += loss.item()
        # Collect everything on the CPU so the metrics see the full split at once.
        logit_chunks.append(logits.cpu())
        target_chunks.append(targets.cpu())

    y = torch.cat(logit_chunks, dim=0)
    y_label = torch.cat(target_chunks, dim=0)

    aps, roc, _, _ = precision_test(y, y_label)

    # The ROC curve is drawn from the raw logits.
    fpr, tpr, _ = roc_curve(y_label.detach().numpy(), y.detach().numpy())
    roc_auc = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

    return aps, roc, loss_sum

In [None]:
# Fine-tune for 15 epochs. After each epoch the model is scored with
# evaluation(), and the weights are saved whenever the average precision
# improves on the best value seen so far.
best_pre, best_roc, best_loss = 0.0, 0.0, 0.0
eval_auprc_log, eval_auroc_log, eval_loss_log = [], [], []

for e in range(15):
    train(e)
    aps, roc, test_loss = evaluation()

    # Per-epoch history of the evaluation metrics.
    eval_auprc_log.append(aps)
    eval_auroc_log.append(roc)
    eval_loss_log.append(test_loss)

    if aps > best_pre:
        print("** ** * Saving fine - tuned model ** ** * ")
        # Save the wrapped module when the model exposes a `module` attribute.
        model_to_save = getattr(model, 'module', model)
        create_folder(global_params['output_dir'])
        output_model_file = os.path.join(global_params['output_dir'], global_params['best_name'])
        torch.save(model_to_save.state_dict(), output_model_file)

        best_pre, best_roc, best_loss = aps, roc, test_loss

    print(f'roc : {roc}', f'aps : {aps}')

In [None]:
# Final evaluation: reload the best checkpoint and score it on the test split.
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, roc_curve, confusion_matrix

# Load the best model written by the training loop (same folder and file name).
model_path = os.path.join(global_params['output_dir'], global_params['best_name'])
conf = BertConfig(model_config)
model = BertForMultiLabelPrediction(conf, num_labels=1, feature_dict=feature_dict)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# Move the model to the same device as the batches
print("Moving model to device...")
model = model.to(global_params['device'])
print("Model moved to device.")

# Put the model in evaluation mode
model.eval()

# Run the model over the test split, collecting logits and targets on the CPU.
test_logits = []
test_targets = []
for batch in testload:
    age_ids, input_ids, posi_ids, segment_ids, attMask, targets = batch

    age_ids = age_ids.to(global_params['device'])
    input_ids = input_ids.to(global_params['device'])
    posi_ids = posi_ids.to(global_params['device'])
    segment_ids = segment_ids.to(global_params['device'])
    attMask = attMask.to(global_params['device'])
    targets = targets.to(global_params['device'])

    # Called with labels, the model returns (loss, logits), as in evaluation().
    with torch.no_grad():
        _, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)

    test_logits.append(logits.cpu())
    test_targets.append(targets.cpu())

# Flatten to 1-D arrays for scikit-learn.
# NOTE(review): assumes one output per sample (num_labels=1) and 0/1 targets.
test_targets = torch.cat(test_targets, dim=0).numpy().ravel().astype(int)
# The model emits logits; convert them to probabilities before thresholding.
test_predictions = torch.sigmoid(torch.cat(test_logits, dim=0)).numpy().ravel()
# Hard 0/1 decisions at a probability threshold of 0.5.
test_labels = (test_predictions >= 0.5).astype(int)

# Ranking metrics use the probabilities; F1 uses the hard decisions.
auroc = roc_auc_score(test_targets, test_predictions)
auprc = average_precision_score(test_targets, test_predictions)
f1 = f1_score(test_targets, test_labels)

# Print the metrics
print(f"Metrics for the best model:")
print(f"  AUROC: {auroc}")
print(f"  AUPRC: {auprc}")
print(f"  F1 Score: {f1}")

# Draw the ROC curve
fpr, tpr, _ = roc_curve(test_targets, test_predictions)
plt.figure()
plt.plot(fpr, tpr, label=f'AUROC = {auroc:.2f}')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='best')
plt.show()

# Compute and plot the confusion matrix
cm = confusion_matrix(test_targets, test_labels)
plt.figure()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()