In [1]:
import sys 
sys.path.insert(0, '/home/yikuan/project/Code/')

In [2]:
from common.pytorch import save_model
from common.common import create_folder, load_obj
import os
import torch
from ACM.model.utils.utils import age_vocab
import pandas as pd
import pytorch_pretrained_bert as Bert
from ACM.dataLoader.HF import HF_data
from ACM.model.bertKISSGP import BertHF
import gpytorch
from ACM.model.optimiser import adam
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score

In [3]:
class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings=config.get('max_position_embeddings'),
            initializer_range=config.get('initializer_range'),
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        self.num_labels = config.get('num_labels')


class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')

In [4]:
file_config = {
    'vocab':'/home/yikuan/project/Code/ACM/data/Full_vocab',
    'train': '/home/shared/yikuan/ACM/data/Depression/Depression_clean_train.parquet',
    'test': '/home/shared/yikuan/ACM/data/Depression/Depression_clean_test.parquet'
}

In [5]:
global_params = {
    'max_seq_len': 256,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

In [6]:
BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [7]:
trainData_raw = pd.read_parquet(file_config['train']).reset_index(drop=True)
testData_raw = pd.read_parquet(file_config['test']).reset_index(drop=True)

In [8]:
train_params = {
    'batch_size': 64,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'train_loader_workers': 3,
    'test_loader_workers': 3,
    'device': 'cuda:0',
    'output_dir': '/home/shared/yikuan/ACM/model/Depression',
    'output_name': 'behrtKISSGP.bin',
    'best_name': 'behrtKISSGP_best.bin',
}

In [9]:
trainConfig = TrainConfig(train_params)

data_set = HF_data(trainConfig, trainData_raw, testData_raw, BertVocab['token2idx'], ageVocab, code='code', age='age')

trainData = data_set.get_weighted_sample_train(4)
testData = data_set.get_test_loader()

In [10]:
model_param = {
    'vocab_size': len(BertVocab['token2idx'].keys()),
    'hidden_size': 150,
    'num_hidden_layers': 4,
    'num_attention_heads': 6,
    'hidden_act': 'gelu',
    'intermediate_size': 108,
    'max_position_embeddings': global_params['max_seq_len'],
    'seg_vocab_size': 2,
    'age_vocab_size': len(ageVocab.keys()),
    'prior_prec': 1e0,
    'prec_init': 1e0,
    'initializer_range': 0.02,
    'num_labels': 1,
    'hidden_dropout_prob': 0.29,
    'attention_probs_dropout_prob': 0.38
}

In [11]:
modelConfig = BertConfig(model_param)

model = BertHF(modelConfig, n_dim=24,grid_size=40, ard_num_dims=24)

likelihood = gpytorch.likelihoods.BernoulliLikelihood().to(trainConfig.device)
mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=len(trainData.dataset))

In [12]:
def load_model(path, model):
    # load pretrained model and update weights
    pretrained_dict = torch.load(path)
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k not in ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    return model

model = load_model('/home/shared/yikuan/HF/MLM/PureICD_diag_med.bin', model)
model = model.to(trainConfig.device)
optim = adam(list(model.named_parameters()),config=optim_param)

t_total value of -1 results in schedule not being applied


In [13]:
def precision(logits, label):
    sig = nn.Sigmoid()
    output=sig(logits)
    label, output=label.cpu(), output.detach().cpu()
    tempprc= average_precision_score(label.numpy(),output.numpy())
    auc = roc_auc_score(label.numpy(),output.numpy())
    return tempprc, output, label

def precision_test(logits, label):
#     sig = nn.Sigmoid()
#     output=sig(logits)
    tempprc= average_precision_score(label.numpy(),logits.numpy())
    auc = roc_auc_score(label.numpy(),logits.numpy())
    return tempprc, auc

In [14]:
def train(e, trainload, model, likelihood, optim):
    model.train()
    likelihood.train()

    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt = 0
    for step, batch in enumerate(trainload):
        cnt += 1

        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch

        age_ids = age_ids.to(train_params['device'])
        input_ids = input_ids.to(train_params['device'])
        posi_ids = posi_ids.to(train_params['device'])
        segment_ids = segment_ids.to(train_params['device'])
        attMask = attMask.to(train_params['device'])
        targets = targets.view(-1).to(train_params['device'])

        output = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)
        loss = -mll(output, targets)
        loss.backward()

        temp_loss += loss.item()
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        if step % 500 == 0:
            #             prec, a, b = precision(logits, targets)
            #             print("epoch: {}\t| Cnt: {}\t| Loss: {}\t| precision: {}".format(e, cnt, temp_loss / 500, prec))
            print("epoch: {}\t|step: {}\t|Loss: {}".format(e, step, temp_loss / 500))
            temp_loss = 0

        optim.step()
        optim.zero_grad()

    # Save a trained model
    output_model_file = os.path.join(trainConfig.output_dir, trainConfig.output_name)
    create_folder(trainConfig.output_dir)
    save_model(output_model_file, model)


def evaluation(testload, model, likelihood):
    model.eval()
    y = []
    y_label = []
    loss_temp = 0

    model.eval()
    likelihood.eval()
    with torch.no_grad():
        for step, batch in enumerate(testload):
            age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _, _ = batch

            age_ids = age_ids.to(train_params['device'])
            input_ids = input_ids.to(train_params['device'])
            posi_ids = posi_ids.to(train_params['device'])
            segment_ids = segment_ids.to(train_params['device'])
            attMask = attMask.to(train_params['device'])
            targets = targets.view(-1).to(train_params['device'])

            output = likelihood(
                model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets))

            logits = output.mean.float().cpu()
            targets = targets.cpu()

            y_label.append(targets)
            y.append(logits)

        y_label = torch.cat(y_label, dim=0)
        y = torch.cat(y, dim=0)

        tempprc, auc = precision_test(y.view(-1), y_label.view(-1))
        return tempprc, auc

In [15]:
best_pre = 0
for e in range(20):
    train(e, trainData, model, likelihood,optim)
#     auc_train, time_cost_train = evaluation(trainData, model)
    auc_test, auc = evaluation(testData, model, likelihood)
    print('test precision: {}, {}'.format(auc_test, auc))
    if auc_test >best_pre:
        # Save a trained model
        output_model_file = os.path.join(trainConfig.output_dir, trainConfig.best_name)
        create_folder(trainConfig.output_dir)
        save_model(output_model_file, model)
        best_pre = auc_test

epoch: 0	|step: 0	|Loss: 0.008768121719360351
epoch: 0	|step: 500	|Loss: 3.804774790763855
epoch: 0	|step: 1500	|Loss: 2.9021462812423704
epoch: 0	|step: 2000	|Loss: 2.7255283317565917
epoch: 0	|step: 2500	|Loss: 2.6061324028968813
epoch: 0	|step: 3000	|Loss: 2.5140697963237764
epoch: 0	|step: 3500	|Loss: 2.420725441932678
epoch: 0	|step: 4000	|Loss: 2.2917193882465363
epoch: 0	|step: 4500	|Loss: 2.1990071861743927
epoch: 0	|step: 5000	|Loss: 2.080232113599777
epoch: 0	|step: 5500	|Loss: 1.9751528468132018
epoch: 0	|step: 6000	|Loss: 1.888222297668457
epoch: 0	|step: 6500	|Loss: 1.8062055149078369
epoch: 0	|step: 7000	|Loss: 1.7255313392877578
epoch: 0	|step: 7500	|Loss: 1.6390013983249665
epoch: 0	|step: 8000	|Loss: 1.5826390161514283
epoch: 0	|step: 8500	|Loss: 1.5191547157764436
epoch: 0	|step: 9000	|Loss: 1.4799510362148285
epoch: 0	|step: 9500	|Loss: 1.3793616534471511
epoch: 0	|step: 10000	|Loss: 1.3217113312482833
epoch: 0	|step: 10500	|Loss: 1.281630457162857
epoch: 0	|step: 11

epoch: 5	|step: 0	|Loss: 0.0008092639446258545
epoch: 5	|step: 500	|Loss: 0.41199562850594523
epoch: 5	|step: 1000	|Loss: 0.4069055047929287
epoch: 5	|step: 1500	|Loss: 0.4122260991036892
epoch: 5	|step: 2000	|Loss: 0.41373373809456826
epoch: 5	|step: 2500	|Loss: 0.41429139095544815
epoch: 5	|step: 3000	|Loss: 0.41204263213276865
epoch: 5	|step: 3500	|Loss: 0.4052334398329258
epoch: 5	|step: 4000	|Loss: 0.40777027887105943
epoch: 5	|step: 4500	|Loss: 0.40784852385520937
epoch: 5	|step: 5000	|Loss: 0.41132773491740227
epoch: 5	|step: 5500	|Loss: 0.40612172228097915
epoch: 5	|step: 6000	|Loss: 0.4113026424050331
epoch: 5	|step: 6500	|Loss: 0.4128769714832306
epoch: 5	|step: 7000	|Loss: 0.41051990085840223
epoch: 5	|step: 7500	|Loss: 0.4085178347826004
epoch: 5	|step: 8000	|Loss: 0.4067641128897667
epoch: 5	|step: 8500	|Loss: 0.4004742228090763
epoch: 5	|step: 9000	|Loss: 0.41285968297719955
epoch: 5	|step: 9500	|Loss: 0.41167802557349203
epoch: 5	|step: 10000	|Loss: 0.4017868762910366
ep

epoch: 10	|step: 1000	|Loss: 0.39370824828743933
epoch: 10	|step: 1500	|Loss: 0.4066557643711567
epoch: 10	|step: 2000	|Loss: 0.39464907044172287
epoch: 10	|step: 2500	|Loss: 0.4070910278856754
epoch: 10	|step: 3000	|Loss: 0.3966818631887436
epoch: 10	|step: 3500	|Loss: 0.4002802145779133
epoch: 10	|step: 4000	|Loss: 0.3974001688659191
epoch: 10	|step: 4500	|Loss: 0.4024540068805218
epoch: 10	|step: 5000	|Loss: 0.3989296560287476
epoch: 10	|step: 5500	|Loss: 0.3997889396250248
epoch: 10	|step: 6000	|Loss: 0.3959154840409756
epoch: 10	|step: 6500	|Loss: 0.395023844152689
epoch: 10	|step: 7000	|Loss: 0.40415511006116867
epoch: 10	|step: 7500	|Loss: 0.3998586537539959
epoch: 10	|step: 8000	|Loss: 0.39782161632180213
epoch: 10	|step: 8500	|Loss: 0.4029886754751205
epoch: 10	|step: 9000	|Loss: 0.3995700365900993
epoch: 10	|step: 9500	|Loss: 0.39412993550300596
epoch: 10	|step: 10000	|Loss: 0.3980654469430447
epoch: 10	|step: 10500	|Loss: 0.392073753207922
epoch: 10	|step: 11000	|Loss: 0.400

KeyboardInterrupt: 