In [1]:
import sys 
sys.path.insert(0, '/home/yikuan/project/Code/')

In [2]:
from common.pytorch import save_model, load_model
from common.common import create_folder, load_obj
import os
import torch
import torch.nn.functional as F
from ACM.model.utils.utils import age_vocab
import pandas as pd
import pytorch_pretrained_bert as Bert
from ACM.dataLoader.HF import HF_data
from ACM.model.bertDBKLEmbedding import BertHF
import gpytorch
from ACM.model.optimiser import adam
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score

In [3]:
class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings=config.get('max_position_embeddings'),
            initializer_range=config.get('initializer_range'),
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        self.num_labels = config.get('num_labels')
        self.prior_rate = config.get('prior_rate')


class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')
        self.device1 = config.get('device1')

In [4]:
file_config = {
    'vocab':'/home/yikuan/project/Code/ACM/data/Full_vocab',
    'train': '/home/shared/yikuan/ACM/data/Diabetes/diabetes_clean_train.parquet',
    'test': '/home/shared/yikuan/ACM/data/Diabetes/diabetes_clean_test.parquet'
}

In [5]:
global_params = {
    'max_seq_len': 256,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

In [6]:
BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [7]:
trainData_raw = pd.read_parquet(file_config['train']).reset_index(drop=True)
testData_raw = pd.read_parquet(file_config['test']).reset_index(drop=True)

In [8]:
train_params = {
    'batch_size': 128,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'train_loader_workers': 3,
    'test_loader_workers': 3,
    'device': 'cuda:1',
    'device1': 'cuda:0',
    'output_dir': '/home/shared/yikuan/ACM/model/Diabetes',
    'output_name': 'behrtDBKLEmbeddingsPrior.bin',
    'best_name': 'behrtDBKLEmbeddingsPrior_best.bin',
}

In [9]:
trainConfig = TrainConfig(train_params)

data_set = HF_data(trainConfig, trainData_raw, testData_raw, BertVocab['token2idx'], ageVocab, code='code', age='age')

trainData = data_set.get_weighted_sample_train(4)
testData = data_set.get_test_loader()

In [10]:
model_param = {
    'vocab_size': len(BertVocab['token2idx'].keys()),
    'hidden_size': 150,
    'num_hidden_layers': 4,
    'num_attention_heads': 6,
    'hidden_act': 'gelu',
    'intermediate_size': 108,
    'max_position_embeddings': global_params['max_seq_len'],
    'seg_vocab_size': 2,
    'age_vocab_size': len(ageVocab.keys()),
    'prior_prec': 1e0,
    'prec_init': 1e0,
    'initializer_range': 0.02,
    'num_labels': 1,
    'hidden_dropout_prob': 0.29,
    'attention_probs_dropout_prob': 0.38,
    'prior_rate': 1
}

feature_dict = {
    'word': True,
    'age': False,
    'seg': False,
    'norm': True
}

In [11]:
modelConfig = BertConfig(model_param)

model = BertHF(modelConfig, n_dim=24,grid_size=40, ard_num_dims=24, feature_dict=feature_dict, cuda1=trainConfig.device, cuda2=trainConfig.device1, split=True)

likelihood = gpytorch.likelihoods.BernoulliLikelihood().to(trainConfig.device1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=len(trainData.dataset))

In [12]:
pretrained_dict = torch.load('/home/shared/yikuan/HF/MLM/PureICD_diag_med.bin')
# pretrained_dict = torch.load('/home/shared/yikuan/ACM/model/Diabetes/behrtKISSGP_best.bin')
model_dict = model.state_dict()
name_dict = {
    'bert.embeddings.word_embeddings.weight_posterior.loc': 'bert.embeddings.word_embeddings.weight',
#     'bert.embeddings.segment_embeddings.weight_posterior.loc': 'bert.embeddings.segment_embeddings.weight',
#     'bert.embeddings.age_embeddings.weight_posterior.loc': 'bert.embeddings.age_embeddings.weight'
}
for k,v in pretrained_dict.items():
    if (k in model_dict) and (k not in ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']):
#     if k in model_dict:
        model_dict[k] = v
for k,v in name_dict.items():
    model_dict[k] = pretrained_dict[v]
    
# model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [13]:
# def load_model(path, model):
#     # load pretrained model and update weights
#     pretrained_dict = torch.load(path)
#     model_dict = model.state_dict()
#     # 1. filter out unnecessary keys
#     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k not in ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']}
#     # 2. overwrite entries in the existing state dict
#     model_dict.update(pretrained_dict)
#     # 3. load the new state dict
#     model.load_state_dict(model_dict)
#     return model

# model = load_model('/home/shared/yikuan/HF/MLM/PureICD_diag_med.bin', model)
model.allocateGPU()
optim = adam(list(model.named_parameters()),config=optim_param)

t_total value of -1 results in schedule not being applied


In [14]:
def precision(logits, label):
    sig = nn.Sigmoid()
    output=sig(logits)
    label, output=label.cpu(), output.detach().cpu()
    tempprc= average_precision_score(label.numpy(),output.numpy())
    auc = roc_auc_score(label.numpy(),output.numpy())
    return tempprc, output, label

def precision_test(logits, label):
#     sig = nn.Sigmoid()
#     output=sig(logits)
    tempprc= average_precision_score(label.numpy(),logits.numpy())
    auc = roc_auc_score(label.numpy(),logits.numpy())
    return tempprc, auc

In [15]:
def get_beta(batch_idx, m, beta_type):
    if beta_type == "Blundell":
        beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1) 
    elif beta_type == "Soenderby":
        beta = min(epoch / (num_epochs // 4), 1)
    elif beta_type == "Standard":
        beta = 1 / m 
    else:
        beta = 0
    return beta

In [16]:
def train(e, trainload, model, likelihood, optim, m, beta_type):
    model.train()
    likelihood.train()

    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt = 0
    for step, batch in enumerate(trainload):
        cnt += 1
        
        beta = get_beta(step, m, beta_type)
        
        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch

        age_ids = age_ids.to(train_params['device'])
        input_ids = input_ids.to(train_params['device'])
        posi_ids = posi_ids.to(train_params['device'])
        segment_ids = segment_ids.to(train_params['device'])
        attMask = attMask.to(train_params['device'])
        targets = targets.view(-1).to(train_params['device1'])

        output, kl = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)
        loss = -mll(output, targets)
        loss = loss + beta*kl
        
        loss.backward()

        temp_loss += loss.item()
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        if step % 500 == 0:
            #             prec, a, b = precision(logits, targets)
            #             print("epoch: {}\t| Cnt: {}\t| Loss: {}\t| precision: {}".format(e, cnt, temp_loss / 500, prec))
            print("epoch: {}\t|step: {}\t|Loss: {}".format(e, step, temp_loss / 500))
            temp_loss = 0

        optim.step()
        optim.zero_grad()

    # Save a trained model
    output_model_file = os.path.join(trainConfig.output_dir, trainConfig.output_name)
    create_folder(trainConfig.output_dir)
    save_model(output_model_file, model)


def evaluation(testload, model, likelihood, n_sample=10):
    model.eval()
    y = []
    y_label = []
    loss_temp = 0

    model.eval()
    likelihood.eval()
    with torch.no_grad():
        for step, batch in enumerate(testload):
            age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch

            age_ids = age_ids.to(train_params['device'])
            input_ids = input_ids.to(train_params['device'])
            posi_ids = posi_ids.to(train_params['device'])
            segment_ids = segment_ids.to(train_params['device'])
            attMask = attMask.to(train_params['device'])
            targets = targets.view(-1).to(train_params['device'])

#             output = likelihood(
#                 model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)[0])
            output_list = []
            for _ in range(n_sample):
                output, kl = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)
                output_list.append(output.sample())
            output_list = torch.sigmoid(torch.stack(output_list, dim=0))
            output_list = torch.mean(output_list, dim=0)
            

            logits = output_list.cpu()
            targets = targets.cpu()

            y_label.append(targets)
            y.append(logits)

        y_label = torch.cat(y_label, dim=0)
        y = torch.cat(y, dim=0)

        tempprc, auc = precision_test(y.view(-1), y_label.view(-1))
        return tempprc, auc

In [17]:
best_pre = 0
for e in range(20):
    train(e, trainData, model, likelihood,optim, len(trainData), "Blundell")
#     auc_train, time_cost_train = evaluation(trainData, model)
    auc_test, auc = evaluation(testData, model, likelihood, n_sample=20)
    print('test precision: {}, auc {}'.format(auc_test, auc))
    if auc_test >best_pre:
        # Save a trained model
        output_model_file = os.path.join(trainConfig.output_dir, trainConfig.best_name)
        create_folder(trainConfig.output_dir)
        save_model(output_model_file, model)
        best_pre = auc_test

epoch: 0	|step: 0	|Loss: 828.2166875
epoch: 0	|step: 500	|Loss: 831.7932905516625
epoch: 0	|step: 1000	|Loss: 3.1514769811630248
epoch: 0	|step: 1500	|Loss: 2.856434576511383
epoch: 0	|step: 2000	|Loss: 2.7063269419670104
epoch: 0	|step: 2500	|Loss: 2.5711948308944703
epoch: 0	|step: 3000	|Loss: 2.4021508004665373
epoch: 0	|step: 3500	|Loss: 2.207681831359863
epoch: 0	|step: 4000	|Loss: 2.047792819261551
epoch: 0	|step: 4500	|Loss: 1.9352261986732482
epoch: 0	|step: 5000	|Loss: 1.826802354812622
** ** * Saving fine - tuned model ** ** * 
test precision: 0.4427256305615666, auc 0.7613473609816449
** ** * Saving fine - tuned model ** ** * 
epoch: 1	|step: 0	|Loss: 821.95875
epoch: 1	|step: 500	|Loss: 823.5440304970741
epoch: 1	|step: 1000	|Loss: 1.66576011633873
epoch: 1	|step: 1500	|Loss: 1.5786875529289246
epoch: 1	|step: 2000	|Loss: 1.5081306006908417
epoch: 1	|step: 2500	|Loss: 1.4366275864839553
epoch: 1	|step: 3000	|Loss: 1.3742806398868561
epoch: 1	|step: 3500	|Loss: 1.29890493786

KeyboardInterrupt: 