In [None]:
import argparse
import json
import os
import random
from datetime import datetime
from pprint import pprint
from shutil import copyfile, move

import numpy as np
import torch

from data_utils.mediqa_utils import submit, eval_model,mediqa_name_list
from data_utils.label_map import DATA_META, GLOBAL_MAP, DATA_TYPE, DATA_SWAP, TASK_TYPE, generate_decoder_opt
from data_utils.log_wrapper import create_logger
from data_utils.utils import set_environment
from data_utils.mediqa2019_evaluator_allTasks_final import eval_mediqa_official
from mt_dnn.batcher import BatchGen
from mt_dnn.model import MTDNNModel
from bert.modeling import BertModel
from bert.modeling import BertConfig
#import pdb

### Load and merge the JSON configs

In [None]:
def _load_json_config(path):
    """Parse the JSON file at ``path``, closing the file handle once it is read."""
    with open(path, 'r') as reader:
        return json.load(reader)


# Merge order matters: on duplicate keys model_config overrides train_config,
# which overrides data_config.
data_config = _load_json_config('data_config.json')
train_config = _load_json_config('train_config.json')
model_config = _load_json_config('model_config.json')
config = {}
config.update(data_config)
config.update(train_config)
config.update(model_config)
print(len(config))

### Default optional settings to `None`

In [None]:
# Optional settings that the JSON files may omit. setdefault keeps any value
# already loaded from data/train/model_config.json; only missing keys become None.
config.setdefault('init_config', None)
config.setdefault('resume_scoring', None)
config.setdefault('test_datasets', None)
config.setdefault('predict_split', None)
config.setdefault('mediqa_pairloss', None)
config.setdefault('batch_size_eval', None)

### CUDA flag

In [None]:
# True when PyTorch can use a CUDA device in this kernel; later cells pass it to
# BatchGen(gpu=...) and use it to decide whether to move the model to the GPU.
cuda_flag = bool(torch.cuda.is_available())

### Pre-process config

In [None]:
def _as_name_list(value):
    """Return ``value`` as a list of dataset names.

    Accepts a comma-separated string (empty segments are dropped) or an
    already-split sequence, so re-running this cell on normalised values works.
    """
    if isinstance(value, str):
        return [name for name in value.split(',') if name]
    return list(value)


config['mtl_observe_datasets'] = _as_name_list(config['mtl_observe_datasets'])
if config['float_medquad']:
    # Presumably switches medquad to a single-output (float target) task:
    # DATA_META holds the class count that get_train_dataset reads as nclass.
    TASK_TYPE['medquad'] = 1
    DATA_META['medquad'] = 1

config['float_target'] = False

if config['batch_size_eval'] is None:
    config['batch_size_eval'] = config['batch_size']

output_dir = config['output_dir']
data_dir = config['data_dir']

config['train_datasets'] = _as_name_list(config['train_datasets'])
if config['test_datasets'] is None:
    config['test_datasets'] = config['train_datasets']
else:
    config['test_datasets'] = _as_name_list(config['test_datasets'])
if len(config['train_datasets']) == 1:
    config['mtl_observe_datasets'] = config['train_datasets']

config['external_datasets'] = _as_name_list(config['external_datasets'])
# External datasets go after the main ones (train() slices them off the end of
# the per-task batch counts). Names already present are skipped so re-running
# this cell does not append them twice. A new list is built so the lists that
# test_datasets / mtl_observe_datasets may alias are not mutated.
config['train_datasets'] = config['train_datasets'] + [
    name for name in config['external_datasets'] if name not in config['train_datasets']]

os.makedirs(output_dir, exist_ok=True)
output_dir = os.path.abspath(output_dir)

set_environment(config['seed'], cuda_flag)
log_path = os.path.join(output_dir, config['log_file'])
logger = create_logger(__name__, to_disk=True, log_file=log_path)
logger.info(config['answer_opt'])

#TODO: Check whether this should be loaded from Minio
# Optional per-task overrides (task prefix -> dropout probability) read by get_train_dataset.
tasks_config = {}
if os.path.exists(config['task_config_path']):
    with open(config['task_config_path'], 'r') as reader:
        tasks_config = json.load(reader)

### JSON dump utility

In [None]:
def dump(path, data):
    """Serialise ``data`` as JSON into ``path``, truncating any existing content."""
    with open(path, mode='w') as handle:
        json.dump(data, fp=handle)

### Load dev and test dataset lists 

In [None]:
def _task_index_maps(opt):
    """Rebuild the task-id lookup tables that get_train_dataset assigns.

    Walks ``opt['train_datasets']`` in order, giving each new prefix the next id in
    ``tasks`` and each new class count (``DATA_META[prefix]``) the next id in
    ``tasks_class`` -- the same first-appearance order get_train_dataset uses.
    """
    tasks = {}
    tasks_class = {}
    for dataset in opt['train_datasets']:
        prefix = dataset.split('_')[0]
        if prefix in tasks:
            continue
        tasks[prefix] = len(tasks)
        nclass = DATA_META[prefix]
        if nclass not in tasks_class:
            tasks_class[nclass] = len(tasks_class)
    return tasks, tasks_class


def _load_eval_batches(path, opt, dataset, task_id, data_type, task_type):
    """Wrap the evaluation file at ``path`` in a non-training BatchGen, or return None if it is missing."""
    if not os.path.exists(path):
        return None
    pw_task = False
    #TODO: Check whether this should be loaded from Minio
    return BatchGen(BatchGen.load(path, False, pairwise=pw_task, maxlen=opt['max_seq_len'],
                                  opt=opt, dataset=dataset),
                    batch_size=opt['batch_size_eval'],
                    gpu=cuda_flag, is_train=False,
                    task_id=task_id,
                    maxlen=opt['max_seq_len'],
                    pairwise=pw_task,
                    data_type=data_type,
                    task_type=task_type,
                    dataset_name=dataset)


def get_dev_test_dataset(opt, nclass_list):
    """Build evaluation batchers for every dataset in ``opt['test_datasets']``.

    Args:
        opt: merged config dict; ``opt['label_size']`` is set to ``nclass_list``.
        nclass_list: per-task class counts returned by get_train_dataset.

    Returns:
        dict with ``'dev_list'`` and ``'test_list'``: one entry per test dataset,
        in order, holding a BatchGen or None when the corresponding file is missing.
        The dev file is ``{dataset}_{opt['predict_split']}.json`` when
        ``predict_split`` is set, else ``{dataset}_dev.json``.
    """
    opt['label_size'] = nclass_list
    logger.info(opt['label_size'])
    # The id maps are locals of get_train_dataset, so rebuild them here.
    tasks, tasks_class = _task_index_maps(opt)
    dev_split = opt['predict_split'] if opt['predict_split'] is not None else 'dev'

    dev_data_list = []
    test_data_list = []
    for dataset in opt['test_datasets']:
        prefix = dataset.split('_')[0]
        # With mtl_opt > 0, datasets with the same class count share a task id.
        task_id = tasks_class[DATA_META[prefix]] if opt['mtl_opt'] > 0 else tasks[prefix]
        task_type = TASK_TYPE[prefix]

        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]

        dev_path = os.path.join(data_dir, '{}_{}.json'.format(dataset, dev_split))
        dev_data_list.append(_load_eval_batches(dev_path, opt, dataset, task_id, data_type, task_type))

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data_list.append(_load_eval_batches(test_path, opt, dataset, task_id, data_type, task_type))

    # Return only after every test dataset has been processed.
    return {'dev_list': dev_data_list, 'test_list': test_data_list}

### Load train dataset list

In [None]:
def get_train_dataset(opt, data_dir):
    """Load one training BatchGen per distinct task prefix in ``opt['train_datasets']``.

    Side effects on ``opt``: sets ``data_dir``, replaces ``answer_opt`` with the
    per-task decoder options and sets ``tasks_dropout_p``.

    Returns:
        dict with ``'train_list'`` (one BatchGen per loaded dataset), ``'nclass_list'``
        (class counts, per prefix when ``mtl_opt < 1`` and per distinct class
        count when ``mtl_opt > 0``), plus ``'tasks'`` (prefix -> id) and
        ``'tasks_class'`` (class count -> id).
    """
    opt['data_dir'] = data_dir
    batch_size = opt['batch_size']
    train_data_list = []
    tasks = {}
    tasks_class = {}
    nclass_list = []
    decoder_opts = []
    dropout_list = []

    for dataset in opt['train_datasets']:
        prefix = dataset.split('_')[0]
        if prefix in tasks:
            # Only the first dataset seen for a given prefix is loaded.
            continue
        assert prefix in DATA_META
        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]
        nclass = DATA_META[prefix]
        task_id = len(tasks)

        if opt['mtl_opt'] > 0:
            # Datasets with the same class count share a task id.
            task_id = tasks_class.get(nclass, len(tasks_class))

        task_type = TASK_TYPE[prefix]
        pw_task = False

        # Tasks that share an id keep the smaller decoder option.
        dopt = generate_decoder_opt(prefix, opt['answer_opt'])
        if task_id < len(decoder_opts):
            decoder_opts[task_id] = min(decoder_opts[task_id], dopt)
        else:
            decoder_opts.append(dopt)

        tasks[prefix] = len(tasks)
        if opt['mtl_opt'] < 1:
            nclass_list.append(nclass)

        if nclass not in tasks_class:
            tasks_class[nclass] = len(tasks_class)
            if opt['mtl_opt'] > 0:
                nclass_list.append(nclass)

        # Per-task dropout override from tasks_config, else the global default.
        dropout_p = opt['dropout_p']
        if tasks_config and prefix in tasks_config:
            dropout_p = tasks_config[prefix]
        dropout_list.append(dropout_p)

        train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
        logger.info('Loading {} as task {}'.format(train_path, task_id))

        #TODO: Check whether this should be loaded from Minio
        train_data = BatchGen(BatchGen.load(train_path, True, pairwise=pw_task, maxlen=opt['max_seq_len'],
                                            opt=opt, dataset=dataset),
                              batch_size=batch_size,
                              dropout_w=opt['dropout_w'],
                              gpu=cuda_flag,
                              task_id=task_id,
                              maxlen=opt['max_seq_len'],
                              pairwise=pw_task,
                              data_type=data_type,
                              task_type=task_type,
                              dataset_name=dataset)
        train_data.reset()
        train_data_list.append(train_data)

    opt['answer_opt'] = decoder_opts
    opt['tasks_dropout_p'] = dropout_list

    return {'train_list': train_data_list, 'nclass_list': nclass_list,
            'tasks': tasks, 'tasks_class': tasks_class}

### Update state dict utility

In [None]:
def update_state_dict(opt, model_path, state_dict, resume_configs_path='config/resume_configs.json'):
    """Load initial weights from ``model_path`` and merge the model config into ``opt``.

    Args:
        opt: merged config dict. It is updated in place with the checkpoint's config,
            or with a default BertConfig when no checkpoint exists. Reads
            ``init_config``, ``finetune``, ``resume_scoring``, ``retain_scoring`` and
            ``bert_dropout_p``.
        model_path: checkpoint path passed to ``torch.load``.
        state_dict: returned unchanged when ``model_path`` does not exist.
        resume_configs_path: JSON list of config keys kept when fine-tuning.

    Returns:
        The loaded checkpoint dict (with ``'config'`` and ``'state'``), or
        ``state_dict`` if no checkpoint was found. Callers must use the return
        value: rebinding the argument here does not reach the caller.
    """
    if not os.path.exists(model_path):
        logger.error('#' * 20)
        logger.error('Could not find the init model!\n The parameters will be initialized randomly!')
        logger.error('#' * 20)
        bert_config = BertConfig(vocab_size_or_config_json_file=30522).to_dict()
        opt.update(bert_config)
        return state_dict

    #TODO: Check whether this should be loaded from Minio
    state_dict = torch.load(model_path)

    if opt['init_config'] is not None:  # load huggingface model
        with open(opt['init_config']) as reader:
            init_config = json.load(reader)
        state_dict = {'config': init_config, 'state': state_dict}

    if opt.get('finetune', False):
        # only resume config and state
        for key in set(state_dict.keys()) - {'config', 'state'}:
            del state_dict[key]

        with open(resume_configs_path) as reader:
            resume_configs = json.load(reader)
        for key in set(state_dict['config'].keys()) - set(resume_configs):
            del state_dict['config'][key]

        if opt['resume_scoring'] is not None:
            # Copy scoring head `resume_scoring` into slot 0. Only the list index
            # is rewritten (not every '0' in the key).
            # other scorings will be deleted during loading process, since finetune only has one task
            source_prefix = 'scoring_list.{}'.format(opt['resume_scoring'])
            state = state_dict['state']
            for key in list(state.keys()):
                if 'scoring_list.0' in key:
                    state[key] = state[key.replace('scoring_list.0', source_prefix, 1)]
        elif not opt['retain_scoring']:
            stale_keys = [k for k in state_dict['state'] if 'scoring_list' in k]
            for key in stale_keys:
                print('deleted previous weight:', key)
                del state_dict['state'][key]

    model_config = state_dict['config']
    model_config['attention_probs_dropout_prob'] = opt['bert_dropout_p']
    model_config['hidden_dropout_prob'] = opt['bert_dropout_p']
    opt.update(model_config)
    return state_dict

### Evaluate on dev and test datasets

In [None]:
def dev_test_eval(config, model, dev_data_list, test_data_list, dev_split):
        
        this_performance = {}
        for idx, dataset in enumerate(config['test_datasets']):
            prefix = dataset.split('_')[0]
            dev_data = dev_data_list[idx]
            if dev_data is not None:
                dev_metrics, dev_predictions, scores, golds, dev_ids= eval_model(model, dev_data, dataset=prefix,
                                                                                 use_cuda=cuda_flag)
                #TODO: Check whether this should be saved to Minio
                score_file = os.path.join(output_dir, '{}_{}_scores_{}.json'.format(dataset, dev_split, epoch))
                results = {'metrics': dev_metrics, 'predictions': dev_predictions, 'uids': dev_ids, 'scores': scores}
                dump(score_file, results)
                
                #TODO: Check whether this should be saved to Minio
                official_score_file = os.path.join(output_dir, '{}_{}_scores_{}.csv'.format(dataset, dev_split, epoch))
                submit(official_score_file, results,dataset_name=prefix, threshold=2.0+config['mediqa_score_offset'])
                
                if prefix in mediqa_name_list:
                    logger.warning('self test numbers:{}'.format(dev_metrics))
                    
                    #TODO: Check whether this should be loaded from Minio
                    if '_' in dataset:
                        affix = dataset.split('_')[1]
                        ground_truth_path=os.path.join(config['data_root'],'mediqa/task3_qa/gt_{}_{}.csv'.format(dev_split,affix))
                    else:
                        ground_truth_path=os.path.join(config['data_root'],'mediqa/task3_qa/gt_{}.csv'.format(dev_split))
                    
                    official_result=eval_mediqa_official(pred_path=official_score_file, ground_truth_path=ground_truth_path, 
                        eval_qa_more=config['mediqa_eval_more'])
                    
                    logger.warning("MediQA dev eval result:{}".format(official_result))
                    
                    if config['mediqa_eval_more']:
                        dev_metrics={'ACC':official_result['score']*100,'Spearman':official_result['score_secondary']*100,
                                    'F1':dev_metrics['F1'], 'MRR':official_result['meta']['MRR'], 'MAP':official_result['MAP'],
                                    'P@1':official_result['meta']['P@1']}
                    else:
                        dev_metrics={'ACC':official_result['score']*100,'Spearman':official_result['score_secondary']*100}

                for key, val in dev_metrics.items():
                    logger.warning("Task {0} -- epoch {1} -- Dev {2}: {3:.3f}".format(dataset, epoch, key, val))
            if config['predict_split'] is not None:
                continue
            print('args.mtl_observe_datasets:',config['mtl_observe_datasets'], dataset)
            if dataset in config['mtl_observe_datasets']:
                this_performance[dataset]=np.mean([val for val in dev_metrics.values()])
            test_data = test_data_list[idx]
            if test_data is not None:
                test_metrics, test_predictions, scores, golds, test_ids= eval_model(model, test_data, dataset=prefix,
                                                                                 use_cuda=cuda_flag, with_label=False)
                for key, val in test_metrics.items():
                    logger.warning("Task {0} -- epoch {1} -- Test {2}: {3:.3f}".format(dataset, epoch, key, val))
                score_file = os.path.join(output_dir, '{}_test_scores_{}.json'.format(dataset, epoch))
                results = {'metrics': test_metrics, 'predictions': test_predictions, 'uids': test_ids, 'scores': scores}
                dump(score_file, results)
                # if dataset in mediqa_name_list:
                official_score_file = os.path.join(output_dir, '{}_test_scores_{}.csv'.format(dataset, epoch))
                submit(official_score_file, results,dataset_name=prefix, threshold=2.0+config['mediqa_score_offset'])
                logger.info('[new test scores saved.]')
                
        return this_performance    

### Train util per epoch

In [None]:
def train_per_epoch(model, train_data_list, epoch, config):
    
        logger.warning('At epoch {}'.format(epoch))
        if epoch==0 and config['freeze_bert_first']:
            model.network.freeze_bert()
            logger.warning('Bert freezed.')
        if epoch==1 and config['freeze_bert_first']:
            model.network.unfreeze_bert()
            logger.warning('Bert unfreezed.')
        start = datetime.now()
        all_indices=[]
        if len(config['external_datasets'])> 0 and config['external_include_ratio'] > 0:
            main_indices = []
            extra_indices = []
            for data_idx,batcher in enumerate(train_data_list):
                if batcher. dataset_name not in config['external_datasets']:
                    main_indices += [data_idx] * len(batcher)
                else:
                    extra_indices += [data_idx] * len(batcher)

            random_picks=int(min(len(main_indices) * config['external_include_ratio'], len(extra_indices)))
            extra_indices = np.random.choice(extra_indices, random_picks, replace=False)
            if config['mix_opt'] > 0:
                extra_indices = extra_indices.tolist()
                random.shuffle(extra_indices)
                all_indices = extra_indices + main_indices
            else:
                all_indices = main_indices + extra_indices.tolist()
        else:
            for i in range(1, len(train_data_list)):
                all_indices += [i] * len(train_data_list[i])
            if config['mix_opt'] > 0:
                random.shuffle(all_indices)
            all_indices += [0] * len(train_data_list[0])
        if config['mix_opt'] < 1:
            random.shuffle(all_indices)
        if config['test_mode']:
            all_indices=all_indices[:2]
        if config['predict_split'] is not None:
            all_indices=[]
            dev_split=config['predict_split']
        else:
            dev_split='dev'

        for i in range(len(all_indices)):
            task_id = all_indices[i]
            batch_meta, batch_data= next(all_iters[task_id])
            model.update(batch_meta, batch_data)
            if (model.updates) % config['log_per_updates'] == 0 or model.updates == 1:
                logger.info('Task [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'.format(task_id,
                    model.updates, model.train_loss.avg,
                    str((datetime.now() - start) / (i + 1) * (len(all_indices) - i - 1)).split('.')[0]))
        
        return {'dev_split': dev_split}

### Model saving util

In [None]:
def save_model(model, config, output_dir, epoch, best_dataset_performance):
    
    #TODO: Check whether this needs to be saved directly to Minio
    if not config['not_save_model']:

        model_name = 'model_last.pt' if config['save_last'] else 'model_{}.pt'.format(epoch) 
        model_file = os.path.join(output_dir, model_name)

        if config['save_last'] and os.path.exists(model_file):
            model_temp=os.path.join(output_dir, 'model_secondlast.pt')
            copyfile(model_file, model_temp)

        model.save(model_file)

        if config['save_best'] and best_epoch==epoch:
            best_path = os.path.join(output_dir,'best_model.pt')
            copyfile(model_file,best_path)

            for dataset in config['mtl_observe_datasets']:
                if best_dataset_performance[dataset]['epoch'] == epoch:
                    best_path = os.path.join(output_dir,'best_model_{}.pt'.format(dataset))
                    copyfile(model_file, best_path)   

### Final Training script

In [None]:
def train(config, data_dir, output_dir):
    
    logger.info('Launching the MT-DNN training')
    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)
    
    train_data_dict = get_train_dataset(config, data_dir)
    train_data_list = train_data_dict['train_list']
    nclass_list = train_data_dict['nclass_list']
    
    dev_test_dict = get_dev_test_dataset(config, nclass_list)
    dev_data_list = dev_test_dict['dev_list']
    test_data_list = dev_test_dict['test_list']

    all_iters =[iter(item) for item in train_data_list]
    all_lens = [len(bg) for bg in train_data_list]
    num_all_batches = config['epochs'] * sum(all_lens)

    if len(config['external_datasets']) > 0 and config['external_include_ratio'] > 0:
        num_in_domain_batches = config['epochs']* sum(all_lens[:-len(config['.external_datasets'])])
        num_all_batches = num_in_domain_batches * (1 + config['external_include_ratio'])
    # pdb.set_trace()

    model_path = config['init_checkpoint']
    state_dict = None
    update_state_dict(config, model_path, state_dict)
    

    model = MTDNNModel(config, state_dict=state_dict, num_train_step=num_all_batches)
    ####model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ###print network
    # logger.info('\n{}\n{}\n'.format(headline, model.network))
    
    #TODO: Check whether it needs to be saved to Minio
    # dump config
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(config)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    if config['freeze_layers'] > 0:
        model.network.freeze_layers(config['freeze_layers'])

    if cuda_flag:
        model.cuda()
        
    best_epoch=-1
    best_performance=0 
    best_dataset_performance={dataset:{'perf':0,'epoch':-1} for dataset in config['mtl_observe_datasets']}
    
     for epoch in range(config['epochs']):
        
        train_dict = train_per_epoch(model, train_data_list, epoch, config)
        for train_data in train_data_list:
            train_data.reset()

        this_performance = dev_test_eval(config, model, dev_data_list, test_data_list, train_dict['dev_split'])
        print('this_performance:',this_performance)
    
        if config['predict_split'] is not None:
            break
        epoch_performance = sum([val for val in this_performance.values()])
        if epoch_performance>best_performance:
            print('changed:',epoch_performance,best_performance)
            best_performance=epoch_performance
            best_epoch=epoch

        for dataset in config['mtl_observe_datasets']:
            if best_dataset_performance[dataset]['perf'] < this_performance[dataset]:
                best_dataset_performance[dataset]= {'perf':this_performance[dataset],
                                                   'epoch':epoch} 


        print('current best:',best_performance,'at epoch', best_epoch)
        
        save_model(model, config, output_dir, epoch, best_dataset_performance)
    
    return {'model': model, 
            'train_data_list': train_data_list, 
            'dev_data_list': dev_data_list, 
            'test_data_list': test_data_list, 
            'best_dataset_performance': best_dataset_performance}