In [1]:
import sys 
sys.path.insert(0, '/home/yikuan/project/Code/')

In [2]:
from common.pytorch import save_model
from common.common import create_folder,load_obj
from common.pytorch import load_model as torch_load_model
import os
import torch
from ACM.model.utils.utils import age_vocab
import pandas as pd
import pytorch_pretrained_bert as Bert
from ACM.dataLoader.HF import HF_data
from ACM.model.bertBayesianEmbeddingOutput import BertHF
from ACM.model.optimiser import adam
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score

In [3]:
class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings=config.get('max_position_embeddings'),
            initializer_range=config.get('initializer_range'),
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        self.num_labels = config.get('num_labels')


class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')
        self.num_samples = config.get('num_samples')

In [4]:
file_config = {
    'vocab':'/home/yikuan/project/Code/ACM/data/Full_vocab',
    'train': '/home/shared/yikuan/ACM/data/Depression/Depression_clean_train.parquet',
    'test': '/home/shared/yikuan/ACM/data/Depression/Depression_clean_test.parquet'
}

In [5]:
global_params = {
    'max_seq_len': 256,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

In [6]:
BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [7]:
trainData = pd.read_parquet(file_config['train']).reset_index(drop=True)
testData = pd.read_parquet(file_config['test']).reset_index(drop=True)

In [8]:
train_params = {
    'batch_size': 64,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'train_loader_workers': 3,
    'test_loader_workers': 3,
    'num_samples': 10,
    'device': 'cuda:1',
    'output_dir': '/home/shared/yikuan/ACM/model/Depression',
    'output_name': 'behrtBayesianEmbeddingsOutput.bin',
    'best_name': 'behrtBayesianEmbeddingsOutput_best.bin',
}

In [9]:
trainConfig = TrainConfig(train_params)

data_set = HF_data(trainConfig, trainData, testData, BertVocab['token2idx'], ageVocab, code='code', age='age')

trainData = data_set.get_weighted_sample_train(4)
testData = data_set.get_test_loader()

In [10]:
model_param = {
    'vocab_size': len(BertVocab['token2idx'].keys()),
    'hidden_size': 150,
    'num_hidden_layers': 4,
    'num_attention_heads': 6,
    'hidden_act': 'gelu',
    'intermediate_size': 108,
    'max_position_embeddings': global_params['max_seq_len'],
    'seg_vocab_size': 2,
    'age_vocab_size': len(ageVocab.keys()),
    'prior_prec': 1e0,
    'prec_init': 1e0,
    'initializer_range': 0.02,
    'num_labels': 1,
    'hidden_dropout_prob': 0.29,
    'attention_probs_dropout_prob': 0.38
}

feature_dict = {
    'word': True,
    'age': False,
    'seg': False,
    'norm': True
}

In [11]:
modelConfig = BertConfig(model_param)
model = BertHF(modelConfig, num_labels=modelConfig.num_labels, n_dim=150, feature_dict=feature_dict)

In [12]:
def load_model(path, model):
    mapdict = {
        'bert.embeddings.word_embeddings.weight_posterior.loc': 'bert.embeddings.word_embeddings.weight',
#         'bert.embeddings.age_embeddings.weight': 'bert.embeddings.age_embeddings.weight_posterior.loc'
    }
    
#     # load pretrained model and update weights
    pretrained_dict = torch.load(path)
#     model_dict = model.state_dict()
#     # 1. filter out unnecessary keys
# #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k not in ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']}
#     pretrained_dict = {k: v for k, v in pretrained_dict.items()}
#     # 2. overwrite entries in the existing state dict
#     model_dict.update(pretrained_dict)
    
#     pretrained_dict = {mapdict.get(k): v for k, v in pretrained_dict.items() if k in list(mapdict.keys())}
    
#     model_dict.update(pretrained_dict)
    
#     # 3. load the new state dict
#     model.load_state_dict(model_dict)
    model_dict = model.state_dict()

    for k,v in pretrained_dict.items():
        if (k in model_dict) and (k not in ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']):
            model_dict[k] = v
    for k,v in mapdict.items():
        model_dict[k] = pretrained_dict[v]
    model.load_state_dict(model_dict)
    return model

# model = torch_load_model('/home/shared/yikuan/ACM/model/Diabetes/behrtBayesianEmbeddings_best.bin', model)
model = load_model('/home/shared/yikuan/HF/MLM/PureICD_diag_med.bin', model)
model = model.to(trainConfig.device)
optim = adam(list(model.named_parameters()),config=optim_param)

t_total value of -1 results in schedule not being applied


In [13]:
def precision(logits, label):
    sig = nn.Sigmoid()
    output=sig(logits)
    label, output=label.cpu(), output.detach().cpu()
    tempprc= average_precision_score(label.numpy(),output.numpy())
    auc = roc_auc_score(label.numpy(),output.numpy())
    return tempprc, auc

def precision_test(prob, label):
#     sig = nn.Sigmoid()
#     output=sig(logits)
    tempprc= average_precision_score(label.numpy(),prob.numpy())
    auc = roc_auc_score(label.numpy(),prob.numpy())
    return tempprc, auc

In [14]:
def get_beta(batch_idx, m, beta_type):
    if beta_type == "Blundell":
        beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1) 
    elif beta_type == "Soenderby":
        beta = min(epoch / (num_epochs // 4), 1)
    elif beta_type == "Standard":
        beta = 1 / m 
    else:
        beta = 0
    return beta

In [15]:
def train(e, trainload, model, optim, m, beta_type):
    model.train()

    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt = 0
    for step, batch in enumerate(trainload):
        cnt += 1
        beta = get_beta(step, m, beta_type)

        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch

        age_ids = age_ids.to(train_params['device'])
        input_ids = input_ids.to(train_params['device'])
        posi_ids = posi_ids.to(train_params['device'])
        segment_ids = segment_ids.to(train_params['device'])
        attMask = attMask.to(train_params['device'])
        targets = targets.view(-1).to(train_params['device'])

        loss, _ = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets, beta=beta)
        loss.backward()

        temp_loss += loss.item()
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        if step % 500 == 0:
            #             prec, a, b = precision(logits, targets)
            #             print("epoch: {}\t| Cnt: {}\t| Loss: {}\t| precision: {}".format(e, cnt, temp_loss / 500, prec))
            print("epoch: {}\t|step: {}\t|Loss: {}".format(e, step, temp_loss / 500))
            temp_loss = 0

        optim.step()
        optim.zero_grad()

    # Save a trained model
    output_model_file = os.path.join(trainConfig.output_dir, trainConfig.output_name)
    create_folder(trainConfig.output_dir)
    save_model(output_model_file, model)


def evaluation(testload, model):
    model.eval()
    y = []
    y_label = []
    loss_temp = 0
    sig = nn.Sigmoid()
    
    with torch.no_grad():
        for step, batch in enumerate(testload):
            age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch

            age_ids = age_ids.to(train_params['device'])
            input_ids = input_ids.to(train_params['device'])
            posi_ids = posi_ids.to(train_params['device'])
            segment_ids = segment_ids.to(train_params['device'])
            attMask = attMask.to(train_params['device'])
            targets = targets.view(-1).to(train_params['device'])
            
            logits_prob = []
            for i in range(trainConfig.num_samples):
                logits =model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=None)
                logits_prob.append(logits)
            
            logits_prob = torch.mean(sig(torch.stack(logits_prob, dim=1)), dim=1)
            
            # get mean of logits
            
            
            targets = targets.cpu()

            y_label.append(targets)
            y.append(logits_prob.cpu())

        y_label = torch.cat(y_label, dim=0)
        y = torch.cat(y, dim=0)

        tempprc, auc = precision_test(y.view(-1), y_label.view(-1))
        return tempprc, auc

In [16]:
best_pre = 0.0
for e in range(20):
    train(e, trainData, model,optim, len(trainData), "Blundell")
#     auc_train, time_cost_train = evaluation(trainData, model)
    auc_test, auc = evaluation(testData, model)
    print('test precision: {}, auc:{}'.format(auc_test, auc))
    if auc_test >best_pre:
        # Save a trained model
        output_model_file = os.path.join(trainConfig.output_dir, trainConfig.best_name)
        create_folder(trainConfig.output_dir)
        save_model(output_model_file, model)
        best_pre = auc_test

epoch: 0	|step: 0	|Loss: 828.8224375
epoch: 0	|step: 500	|Loss: 829.1026728112101
epoch: 0	|step: 1000	|Loss: 0.4390993211865425
epoch: 0	|step: 1500	|Loss: 0.4371802775263786
epoch: 0	|step: 2000	|Loss: 0.43502347719669343
epoch: 0	|step: 2500	|Loss: 0.4353513500392437
epoch: 0	|step: 3000	|Loss: 0.4392665742635727
epoch: 0	|step: 3500	|Loss: 0.4382819624841213
epoch: 0	|step: 4000	|Loss: 0.43890997382998465
epoch: 0	|step: 4500	|Loss: 0.4402759317755699
epoch: 0	|step: 5000	|Loss: 0.4299034067094326
epoch: 0	|step: 5500	|Loss: 0.42331701096892355
epoch: 0	|step: 6000	|Loss: 0.43830069231987
epoch: 0	|step: 6500	|Loss: 0.4324846982955933
epoch: 0	|step: 7000	|Loss: 0.42691379112005234
epoch: 0	|step: 7500	|Loss: 0.4305554014444351
epoch: 0	|step: 8000	|Loss: 0.42616144585609433
epoch: 0	|step: 8500	|Loss: 0.42814017111063
epoch: 0	|step: 9000	|Loss: 0.4301676575839519
epoch: 0	|step: 9500	|Loss: 0.4257418938875198
epoch: 0	|step: 10000	|Loss: 0.433092533826828
epoch: 0	|step: 10500	|L

epoch: 5	|step: 500	|Loss: 763.5314534712732
epoch: 5	|step: 1000	|Loss: 0.41862406802177426
epoch: 5	|step: 1500	|Loss: 0.41541320756077765
epoch: 5	|step: 2000	|Loss: 0.41664264419674873
epoch: 5	|step: 2500	|Loss: 0.4183664395213127
epoch: 5	|step: 3000	|Loss: 0.41411837765574455
epoch: 5	|step: 3500	|Loss: 0.4216207355856895
epoch: 5	|step: 4000	|Loss: 0.41628529185056684
epoch: 5	|step: 4500	|Loss: 0.4166868434250355
epoch: 5	|step: 5000	|Loss: 0.4104483119547367
epoch: 5	|step: 5500	|Loss: 0.41959629964828493
epoch: 5	|step: 6000	|Loss: 0.417732636898756
epoch: 5	|step: 6500	|Loss: 0.41669280475378034
epoch: 5	|step: 7000	|Loss: 0.4119749778807163
epoch: 5	|step: 7500	|Loss: 0.4141731736958027
epoch: 5	|step: 8000	|Loss: 0.4132458476424217
epoch: 5	|step: 8500	|Loss: 0.4202275656461716
epoch: 5	|step: 9000	|Loss: 0.4155992553830147
epoch: 5	|step: 9500	|Loss: 0.4165930699110031
epoch: 5	|step: 10000	|Loss: 0.42042563939094546
epoch: 5	|step: 10500	|Loss: 0.4215616232454777
epoch:

KeyboardInterrupt: 