In [1]:
import sys

# Put the project code root first on the import path; presumably this is what
# makes the `common`, `ACM` and `Problems` packages imported below resolvable.
# NOTE(review): absolute, user-specific path — consider reading it from an
# environment variable so the notebook runs on other machines.
CODE_ROOT = '/home/yikuan/project/Code/'
# Drop earlier copies first so re-running this cell does not keep growing
# sys.path, while the root still ends up with top priority.
while CODE_ROOT in sys.path:
    sys.path.remove(CODE_ROOT)
sys.path.insert(0, CODE_ROOT)

In [None]:
from common.pytorch import save_model
from common.common import create_folder, load_obj
import os
import torch
from ACM.model.utils.utils import age_vocab
import pandas as pd
import pytorch_pretrained_bert as Bert
from ACM.dataLoader.HF import HF_data
from ACM.model.bertWhitenedGP import BertHF
import gpytorch
from Problems.BEHRT.tools.optimiser import adam
import torch.nn as nn
from sklearn.metrics import average_precision_score,roc_auc_score
# from Utils.evaluation import uncertain_cal

In [None]:
class BertConfig(Bert.modeling.BertConfig):
    """BERT configuration built from a flat hyper-parameter dict.

    Forwards the standard BERT fields to ``pytorch_pretrained_bert``'s
    ``BertConfig`` and additionally stores ``seg_vocab_size``,
    ``age_vocab_size`` and ``num_labels`` as plain attributes.

    Args:
        config: mapping of hyper-parameters. ``'hidden_size'`` is required
            (a missing key raises ``KeyError``); every other key is read with
            ``.get`` and is passed/stored as ``None`` when absent.
    """

    def __init__(self, config):
        # Keys whose names match the parent constructor's keyword arguments.
        shared_keys = (
            'num_hidden_layers',
            'num_attention_heads',
            'intermediate_size',
            'hidden_act',
            'hidden_dropout_prob',
            'attention_probs_dropout_prob',
            'max_position_embeddings',
            'initializer_range',
        )
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            **{key: config.get(key) for key in shared_keys}
        )
        # Extra fields used by the downstream model but unknown to the parent class.
        for extra in ('seg_vocab_size', 'age_vocab_size', 'num_labels'):
            setattr(self, extra, config.get(extra))


class TrainConfig:
    """Plain attribute bag holding the training-loop settings.

    Each field below is copied from ``config`` with ``.get``: keys missing
    from ``config`` become ``None`` and unknown keys are ignored.
    """

    def __init__(self, config):
        fields = (
            'batch_size',
            'use_cuda',
            'max_len_seq',
            'train_loader_workers',
            'test_loader_workers',
            'device',
            'output_dir',
            'output_name',
            'best_name',
        )
        for field in fields:
            setattr(self, field, config.get(field))

In [None]:
# Input locations: vocabulary object (read with load_obj), training cut and test cut.
# NOTE(review): train and test come from different cut directories
# (pureicd_ageInMonth vs 01_model_improvement/separate_med_diag) — confirm they
# share the same preprocessing and age encoding.
file_config = dict(
    vocab='/home/shared/yikuan/HF/Data/PureICD/Full_vocab',
    train='/home/shared/01_data/03_cuts/pureicd_ageInMonth/data_fract/fract_10.parquet',
    test='/home/shared/01_data/03_cuts/01_model_improvement/separate_med_diag/testV1BertDM.parquet',
)

In [None]:
# Sequence-length and age-vocabulary settings used by the vocab, loader and model cells.
global_params = dict(
    max_seq_len=256,
    max_age=110,
    month=1,          # passed to age_vocab as `mon`; presumably the age bucket width in months
    age_symbol=None,
)

# Optimiser settings handed to `adam(...)` when the optimiser is built.
optim_param = dict(
    lr=3e-5,
    warmup_proportion=0.1,
    weight_decay=0.01,
)

In [None]:
# Token vocabulary; its 'token2idx' mapping is used by the data loader and model config.
BertVocab = load_obj(file_config['vocab'])
# Age vocabulary; the second value returned by age_vocab is not needed here.
ageVocab, _ = age_vocab(
    max_age=global_params['max_age'],
    mon=global_params['month'],
    symbol=global_params['age_symbol'],
)

In [None]:
# Raw train / test frames (train read first, then test); both feed HF_data below.
trainData_raw, testData_raw = (
    pd.read_parquet(file_config[split]) for split in ('train', 'test')
)

In [None]:
# Training-loop settings; wrapped by TrainConfig in the next cell.
train_params = dict(
    batch_size=64,
    use_cuda=True,
    max_len_seq=global_params['max_seq_len'],
    train_loader_workers=3,
    test_loader_workers=0,
    device='cuda:1',  # NOTE(review): hard-coded GPU index
    output_dir='/home/shared/yikuan/HF/model/PureICD',
    output_name='GPabs_diag_med_MLM_test.bin',       # rewritten by train() every epoch
    best_name='GPabs_diag_med_MLM_test_best.bin',    # best-so-far checkpoint from the epoch loop
)

In [None]:
trainConfig = TrainConfig(train_params)

# Wrap the raw train/test frames into loaders. `code='code'` and `age='age'` are
# presumably the dataframe column names for the token and age sequences — confirm in HF_data.
data_set = HF_data(trainConfig, trainData_raw, testData_raw, BertVocab['token2idx'], ageVocab, code='code', age='age')

# trainData / testData are iterated batch-by-batch by train() and evaluation() below;
# trainData.dataset is also used to size the ELBO.
# NOTE(review): the meaning of the argument 4 is not visible here (presumably a
# sampling weight/ratio for the weighted sampler) — name it in a config cell.
trainData = data_set.get_weighted_sample_train(4)
testData = data_set.get_test_loader()

In [None]:
# Model hyper-parameters, turned into a BertConfig in the next cell.
model_param = dict(
    vocab_size=len(BertVocab['token2idx'].keys()),
    hidden_size=150,
    num_hidden_layers=4,
    num_attention_heads=6,      # 150 / 6 = 25 dims per head
    hidden_act='gelu',
    # NOTE(review): smaller than hidden_size (standard BERT uses 4 * hidden) — confirm intentional.
    intermediate_size=108,
    max_position_embeddings=global_params['max_seq_len'],
    seg_vocab_size=2,
    age_vocab_size=len(ageVocab.keys()),
    # NOTE(review): prior_prec / prec_init are not read by BertConfig — possibly leftovers.
    prior_prec=1e0,
    prec_init=1e0,
    initializer_range=0.02,
    num_labels=1,
    hidden_dropout_prob=0.29,
    attention_probs_dropout_prob=0.38,
)

In [None]:
modelConfig = BertConfig(model_param)

# Presumably a BERT encoder topped by a GP output layer (exposed as model.gp_layer,
# which the ELBO below uses). The model stays on CPU until the pretrained
# weights are loaded, then it is moved to the device.
# NOTE(review): n_dim / grid_size / ard_num_dims are magic numbers — consider moving them to model_param.
model = BertHF(modelConfig, n_dim=24,grid_size=40, ard_num_dims=1)

# Bernoulli likelihood for the single binary label. VariationalELBO's num_data is
# meant to be the total number of training examples (it rescales the KL term).
# NOTE(review): trainData comes from a weighted sampler — confirm len(trainData.dataset)
# is the intended num_data.
likelihood = gpytorch.likelihoods.BernoulliLikelihood().to(trainConfig.device)
mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=len(trainData.dataset))

In [None]:
def load_model(path, model, exclude=('bert.pooler.dense.weight', 'bert.pooler.dense.bias'),
               map_location='cpu'):
    """Copy matching weights from a saved state dict into ``model``.

    Only keys present in both the checkpoint and ``model.state_dict()`` are
    copied. Keys listed in ``exclude`` (by default the BERT pooler weights)
    are skipped. Every other entry of ``model`` keeps its current value.

    Args:
        path: path to a ``torch.save``-d state dict.
        model: ``nn.Module`` updated in place.
        exclude: checkpoint keys that are never copied.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint written on another GPU can be read without that device.
            ``load_state_dict`` copies the values into the model's own tensors,
            so the model's device is unaffected.

    Returns:
        The same ``model`` instance, to allow ``model = load_model(...)``.

    Raises:
        RuntimeError: from ``load_state_dict`` when a shared key has a
            different tensor shape in the checkpoint.
    """
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    pretrained = torch.load(path, map_location=map_location)
    skip = set(exclude)
    model_state = model.state_dict()
    # 1. keep only keys the model knows about, minus the excluded ones
    shared = {k: v for k, v in pretrained.items() if k in model_state and k not in skip}
    # 2. overlay them on the current state so every key is present for the strict load
    model_state.update(shared)
    # 3. load the merged state dict back into the model
    model.load_state_dict(model_state)
    return model

# Initialise from the pretrained checkpoint (presumably the masked-language-model run,
# judging by the path). load_model skips the pooler weights. Then move the model to
# the training device.
# NOTE(review): absolute checkpoint path — consider adding it to file_config.
model = load_model('/home/shared/yikuan/HF/MLM/PureICD_diag_med.bin', model)
model = model.to(trainConfig.device)
# Only the model's parameters are optimised. Presumably the Bernoulli likelihood has
# no trainable parameters — confirm if the likelihood is changed.
optim = adam(list(model.named_parameters()),config=optim_param)

In [None]:
def precision(logits, label):
    """Average precision of ``sigmoid(logits)`` against binary labels.

    Returns:
        ``(average_precision, probabilities, labels)``. The probabilities and
        labels are CPU tensors, and the probabilities are detached from the graph.
    """
    probs = logits.sigmoid().detach().cpu()
    label = label.cpu()
    ap = average_precision_score(label.numpy(), probs.numpy())
    return ap, probs, label

def precision_test(logits, label):
    """Average precision and ROC AUC of scores against binary labels.

    Args:
        logits: 1-D tensor of scores. Despite the name, ``evaluation`` passes
            predicted probabilities (the likelihood mean). Both metrics depend
            only on the ranking of the scores, so logits would work as well.
        label: 1-D tensor of {0, 1} targets with the same length as ``logits``.

    Returns:
        ``(average_precision, roc_auc, label)``, where ``label`` is returned unchanged.

    Raises:
        ValueError: from ``roc_auc_score`` when ``label`` contains a single class.
    """
    # Accept tensors on any device and with grad history; convert each one once.
    y_score = logits.detach().cpu().numpy()
    y_true = label.detach().cpu().numpy()
    tempprc = average_precision_score(y_true, y_score)
    auc = roc_auc_score(y_true, y_score)
    return tempprc, auc, label

In [None]:
def _batch_to_device(batch, device):
    """Unpack one loader batch and move the model inputs onto ``device``.

    The batch must be an 8-tuple
    ``(age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _)``.
    The last two fields are dropped, and ``targets`` is flattened to 1-D.
    """
    age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch
    return (
        age_ids.to(device),
        input_ids.to(device),
        posi_ids.to(device),
        segment_ids.to(device),
        attMask.to(device),
        targets.view(-1).to(device),
    )


def train(e, trainload, model, likelihood, optim, objective=None, device=None, log_every=500):
    """Run one training epoch and checkpoint the model.

    Minimises the negative variational ELBO batch by batch. Every ``log_every``
    steps it prints the mean loss over the batches seen since the previous
    report. At the end it saves ``model`` to
    ``trainConfig.output_dir/trainConfig.output_name``, overwriting the
    previous epoch's file.

    Args:
        e: epoch index, only used in the log line.
        trainload: iterable of 8-tuple batches (see ``_batch_to_device``).
        model: model whose output is passed to the objective.
        likelihood: GPyTorch likelihood, switched to train mode.
        optim: optimiser over ``model``'s parameters.
        objective: marginal-log-likelihood object. Defaults to the
            notebook-level ``mll``.
        device: target device. Defaults to ``train_params['device']``.
        log_every: reporting period in steps.

    Returns:
        Mean loss over the epoch (``nan`` if ``trainload`` is empty).
    """
    objective = mll if objective is None else objective
    device = train_params['device'] if device is None else device

    model.train()
    likelihood.train()
    # Start from clean gradients regardless of what ran before.
    optim.zero_grad()

    epoch_loss, epoch_steps = 0.0, 0
    window_loss, window_steps = 0.0, 0
    for step, batch in enumerate(trainload):
        age_ids, input_ids, posi_ids, segment_ids, attMask, targets = _batch_to_device(batch, device)

        output = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)
        loss = -objective(output, targets)
        loss.backward()
        optim.step()
        optim.zero_grad()

        loss_value = loss.item()
        epoch_loss += loss_value
        epoch_steps += 1
        window_loss += loss_value
        window_steps += 1

        if step % log_every == 0:
            # Divide by the number of batches actually accumulated; a fixed
            # divisor under-reports the very first report (step 0, one batch).
            print("epoch: {}\t|step: {}\t|Loss: {}".format(e, step, window_loss / window_steps))
            window_loss, window_steps = 0.0, 0

    # Save the model after every epoch (same file each time).
    output_model_file = os.path.join(trainConfig.output_dir, trainConfig.output_name)
    create_folder(trainConfig.output_dir)
    save_model(output_model_file, model)

    return epoch_loss / epoch_steps if epoch_steps else float('nan')


def evaluation(testload, model, likelihood, device=None):
    """Score ``model`` on ``testload`` and return ``(average_precision, roc_auc)``.

    Args:
        testload: iterable of 8-tuple batches (see ``_batch_to_device``).
        model: model whose output is fed through ``likelihood``.
        likelihood: GPyTorch likelihood, switched to eval mode.
        device: target device. Defaults to ``train_params['device']``.

    Returns:
        ``(average_precision, roc_auc)`` over the whole loader.

    Raises:
        ValueError: if ``testload`` yields no batches.
    """
    device = train_params['device'] if device is None else device

    model.eval()
    likelihood.eval()

    probs, labels = [], []
    with torch.no_grad():
        for batch in testload:
            age_ids, input_ids, posi_ids, segment_ids, attMask, targets = _batch_to_device(batch, device)
            predictive = likelihood(
                model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets))
            # Mean of the predictive distribution. With the Bernoulli likelihood
            # used in this notebook this is P(y = 1), not a logit.
            probs.append(predictive.mean.float().cpu())
            labels.append(targets.cpu())

    if not probs:
        raise ValueError("evaluation: testload yielded no batches")

    y = torch.cat(probs, dim=0).view(-1)
    y_label = torch.cat(labels, dim=0).view(-1)
    tempprc, auc, _ = precision_test(y, y_label)
    return tempprc, auc

In [None]:
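# NOTE: disabled uncertainty-evaluation experiment, kept for reference. Re-enabling it
# needs `uncertain_cal` (its import is commented out in the imports cell) plus
# `precision_GP` and `roc_score_GP`, which are not defined anywhere above.
# def evaluation_GP(testload, model, likelihood, cuda, n_sample=10):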
#     model.eval()
#     y = []
#     y_label = []
#     variance_list = []
#     predictive_list = []
#     epis_list = []
#     alea_list = []
#     age_list = []
#     patid_list = []
    
#     loss_temp = 0

# #     model.eval()
#     likelihood.eval()
#     with torch.no_grad():
#         for step, batch in enumerate(testload):
#             age_ids, input_ids, posi_ids, segment_ids, attMask, targets, patid, last_age = batch

#             age_ids = age_ids.to(cuda)
#             input_ids = input_ids.to(cuda)
#             posi_ids = posi_ids.to(cuda)
#             segment_ids = segment_ids.to(cuda)
#             attMask = attMask.to(cuda)
#             targets = targets.view(-1).to(cuda)
#             last_age = last_age.view(-1)
            
#             output_list = []
#             for _ in range(n_sample):
#                 output = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)
#                 output_list.append(output.sample())
                
#             logits = torch.mean(torch.sigmoid(torch.stack(output_list, dim=1)),dim=1).cpu()
#             alea, epis, predictive, variance = uncertain_cal(torch.sigmoid(torch.stack(output_list, dim=1)), dim=1)
#             targets = targets.cpu()
            
#             alea_list.append(alea.cpu())
#             epis_list.append(epis.cpu())
#             variance_list.append(variance.cpu())
#             predictive_list.append(predictive.cpu())
            
#             y_label.append(targets)
#             y.append(logits)
# #             variance_list.append(variance)
#             patid_list.append(patid.view(-1))
#             age_list.append(last_age)
            
# #             if step % 500 == 0:
# #                 print('GP: {}'.format(step))

#         y_label = torch.cat(y_label, dim=0).view(-1)
#         y = torch.cat(y, dim=0).view(-1)
#         variance_list = torch.cat(variance_list, dim=0).view(-1)
#         patid_list = torch.cat(patid_list, dim=0).view(-1)
#         age_list = torch.cat(age_list, dim=0).view(-1)
#         alea_list = torch.cat(alea_list, dim=0).view(-1)
#         epis_list = torch.cat(epis_list,dim=0).view(-1)
#         predictive_list = torch.cat(predictive_list, dim=0).view(-1)

#         tempprc, output, label = precision_GP(y.view(-1), y_label.view(-1))
#         score, _, _ = roc_score_GP(y.view(-1), y_label.view(-1))
#         return tempprc, score, y, y_label, variance_list, patid_list, age_list, epis_list, alea_list, predictive_list

In [None]:
# Train for a fixed number of epochs, keeping the checkpoint with the best test-set
# average precision. NOTE: re-running this cell continues training the same
# `model` / `optim` objects from where they stopped.
# NOTE(review): selecting the best epoch on the test set leaks test information into
# the reported score — a separate validation split would avoid that.
n_epochs = 20
best_pre = 0
for e in range(n_epochs):
    train(e, trainData, model, likelihood, optim)
    # evaluation returns (average precision, ROC AUC). Despite its name, `auc_test`
    # holds the average precision.
    auc_test, auc_score = evaluation(testData, model, likelihood)

    print('test average precision: {}, ROC AUC: {}'.format(auc_test, auc_score))
    if auc_test > best_pre:
        # Save the best-so-far model (selected on average precision).
        output_model_file = os.path.join(trainConfig.output_dir, trainConfig.best_name)
        create_folder(trainConfig.output_dir)
        save_model(output_model_file, model)
        best_pre = auc_test