In [1]:
import sys 
sys.path.insert(0, '/home/yikuan/project/Code/')

In [2]:
from common.pytorch import save_model
from common.common import create_folder, load_obj
import os
import torch
from ACM.model.utils.utils import age_vocab
import pandas as pd
import pytorch_pretrained_bert as Bert
from ACM.dataLoader.HF import HF_data
from ACM.model.behrt import BertHF
from ACM.model.optimiser import adam
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score

In [3]:
class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings=config.get('max_position_embeddings'),
            initializer_range=config.get('initializer_range'),
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        self.num_labels = config.get('num_labels')


class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')

In [4]:
file_config = {
    'vocab':'/home/yikuan/project/Code/ACM/data/Full_vocab',
    'train': '/home/shared/yikuan/ACM/data/Depression/Depression_clean_train.parquet',
    'test': '/home/shared/yikuan/ACM/data/Depression/Depression_clean_test.parquet'
}

In [5]:
global_params = {
    'max_seq_len': 256,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

In [6]:
BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [7]:
trainData = pd.read_parquet(file_config['train']).reset_index(drop=True)
testData = pd.read_parquet(file_config['test']).reset_index(drop=True)

In [8]:
train_params = {
    'batch_size': 64,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'train_loader_workers': 3,
    'test_loader_workers': 3,
    'device': 'cuda:1',
    'output_dir': '/home/shared/yikuan/ACM/model/Depression',
    'output_name': 'behrt.bin',
    'best_name': 'behrt_best.bin',
}

In [9]:
trainConfig = TrainConfig(train_params)

data_set = HF_data(trainConfig, trainData, testData, BertVocab['token2idx'], ageVocab, code='code', age='age')

trainData = data_set.get_weighted_sample_train(4)
testData = data_set.get_test_loader()

In [10]:
model_param = {
    'vocab_size': len(BertVocab['token2idx'].keys()),
    'hidden_size': 150,
    'num_hidden_layers': 4,
    'num_attention_heads': 6,
    'hidden_act': 'gelu',
    'intermediate_size': 108,
    'max_position_embeddings': global_params['max_seq_len'],
    'seg_vocab_size': 2,
    'age_vocab_size': len(ageVocab.keys()),
    'prior_prec': 1e0,
    'prec_init': 1e0,
    'initializer_range': 0.02,
    'num_labels': 1,
    'hidden_dropout_prob': 0.29,
    'attention_probs_dropout_prob': 0.38
}

In [11]:
modelConfig = BertConfig(model_param)
model = BertHF(modelConfig, num_labels=modelConfig.num_labels, n_dim=150)

In [12]:
def load_model(path, model):
    # load pretrained model and update weights
    pretrained_dict = torch.load(path)
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k not in ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    return model

model = load_model('/home/shared/yikuan/HF/MLM/PureICD_diag_med.bin', model)
model = model.to(trainConfig.device)
optim = adam(list(model.named_parameters()),config=optim_param)

t_total value of -1 results in schedule not being applied


In [13]:
def precision(logits, label):
    sig = nn.Sigmoid()
    output=sig(logits)
    label, output=label.cpu(), output.detach().cpu()
    prc= average_precision_score(label.numpy(),output.numpy())
    auc = roc_auc_score(label.numpy(),output.numpy())
    return prc, auc

def precision_test(logits, label):
    sig = nn.Sigmoid()
    output=sig(logits)
    prc= average_precision_score(label.numpy(),output.numpy())
    auc = roc_auc_score(label.numpy(),output.numpy())
    return prc, auc

In [14]:
def train(e, trainload, model, optim):
    model.train()

    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt = 0
    for step, batch in enumerate(trainload):
        cnt += 1

        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _,_ = batch

        age_ids = age_ids.to(train_params['device'])
        input_ids = input_ids.to(train_params['device'])
        posi_ids = posi_ids.to(train_params['device'])
        segment_ids = segment_ids.to(train_params['device'])
        attMask = attMask.to(train_params['device'])
        targets = targets.view(-1).to(train_params['device'])

        loss, _ = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)
        loss.backward()

        temp_loss += loss.item()
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        if step % 500 == 0:
            #             prec, a, b = precision(logits, targets)
            #             print("epoch: {}\t| Cnt: {}\t| Loss: {}\t| precision: {}".format(e, cnt, temp_loss / 500, prec))
            print("epoch: {}\t|step: {}\t|Loss: {}".format(e, step, temp_loss / 500))
            temp_loss = 0

        optim.step()
        optim.zero_grad()

    # Save a trained model
    output_model_file = os.path.join(trainConfig.output_dir, trainConfig.output_name)
    create_folder(trainConfig.output_dir)
    save_model(output_model_file, model)


def evaluation(testload, model):
    model.eval()
    y = []
    y_label = []
    loss_temp = 0
    with torch.no_grad():
        for step, batch in enumerate(testload):
            age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _,_ = batch

            age_ids = age_ids.to(train_params['device'])
            input_ids = input_ids.to(train_params['device'])
            posi_ids = posi_ids.to(train_params['device'])
            segment_ids = segment_ids.to(train_params['device'])
            attMask = attMask.to(train_params['device'])
            targets = targets.view(-1).to(train_params['device'])

            logits =model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=None)
            targets = targets.cpu()

            y_label.append(targets)
            y.append(logits.cpu())

        y_label = torch.cat(y_label, dim=0)
        y = torch.cat(y, dim=0)

        tempprc, prc = precision_test(y.view(-1), y_label.view(-1))
        return tempprc, prc

In [15]:
best_pre = 0
for e in range(20):
    train(e, trainData, model,optim)
#     auc_train, time_cost_train = evaluation(trainData, model)
    auc_test, auc = evaluation(testData, model)
    print('test prc: {}, auc{}'.format(auc_test, auc))
    if auc_test >best_pre:
        # Save a trained model
        output_model_file = os.path.join(trainConfig.output_dir, trainConfig.best_name)
        create_folder(trainConfig.output_dir)
        save_model(output_model_file, model)
        best_pre = auc_test

epoch: 0	|step: 0	|Loss: 0.0013849666118621825
epoch: 0	|step: 500	|Loss: 0.4600920511484146
epoch: 0	|step: 1000	|Loss: 0.43003752520680427
epoch: 0	|step: 1500	|Loss: 0.4301555517315865
epoch: 0	|step: 2000	|Loss: 0.4221146545410156
epoch: 0	|step: 2500	|Loss: 0.426911899805069
epoch: 0	|step: 3000	|Loss: 0.422379921913147
epoch: 0	|step: 3500	|Loss: 0.4210122010707855
epoch: 0	|step: 4000	|Loss: 0.4199312850832939
epoch: 0	|step: 4500	|Loss: 0.424438825905323
epoch: 0	|step: 5000	|Loss: 0.4217379032969475
epoch: 0	|step: 5500	|Loss: 0.4201628498733044
epoch: 0	|step: 6000	|Loss: 0.4155524201989174
epoch: 0	|step: 6500	|Loss: 0.41815310961008073
epoch: 0	|step: 7000	|Loss: 0.4222968499660492
epoch: 0	|step: 7500	|Loss: 0.41978851452469823
epoch: 0	|step: 8000	|Loss: 0.4156451911628246
epoch: 0	|step: 8500	|Loss: 0.4214037240743637
epoch: 0	|step: 9000	|Loss: 0.4193955805301666
epoch: 0	|step: 9500	|Loss: 0.4215896881520748
epoch: 0	|step: 10000	|Loss: 0.4155043041408062
epoch: 0	|ste

epoch: 5	|step: 0	|Loss: 0.0006321948766708374
epoch: 5	|step: 500	|Loss: 0.40258555123209955
epoch: 5	|step: 1000	|Loss: 0.4039654767513275
epoch: 5	|step: 1500	|Loss: 0.3990810188651085
epoch: 5	|step: 2000	|Loss: 0.4041485235691071
epoch: 5	|step: 2500	|Loss: 0.3985883340239525
epoch: 5	|step: 3000	|Loss: 0.40369786059856416
epoch: 5	|step: 3500	|Loss: 0.40100994789600375
epoch: 5	|step: 4000	|Loss: 0.40069636118412016
epoch: 5	|step: 4500	|Loss: 0.4013688159286976
epoch: 5	|step: 5000	|Loss: 0.40321208247542384
epoch: 5	|step: 5500	|Loss: 0.4081941133439541
epoch: 5	|step: 6000	|Loss: 0.3997779286503792
epoch: 5	|step: 6500	|Loss: 0.4030659210085869
epoch: 5	|step: 7000	|Loss: 0.4006944591403008
epoch: 5	|step: 7500	|Loss: 0.4001564683318138
epoch: 5	|step: 8000	|Loss: 0.3993958587050438
epoch: 5	|step: 8500	|Loss: 0.4020997706055641
epoch: 5	|step: 9000	|Loss: 0.4058608768582344
epoch: 5	|step: 9500	|Loss: 0.40648863703012467
epoch: 5	|step: 10000	|Loss: 0.3958017321527004
epoch: 

Exception ignored in: <function _DataLoaderIter.__del__ at 0x7fea1bfb8598>
Traceback (most recent call last):
  File "/home/yikuan/anaconda/envs/py3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 677, in __del__
    self._shutdown_workers()
  File "/home/yikuan/anaconda/envs/py3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 659, in _shutdown_workers
    w.join()
  File "/home/yikuan/anaconda/envs/py3/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/home/yikuan/anaconda/envs/py3/lib/python3.7/multiprocessing/popen_fork.py", line 48, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/home/yikuan/anaconda/envs/py3/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 


KeyboardInterrupt: 