In [2]:
import importlib
import logging
import os
import xml.etree.ElementTree as ET
from argparse import ArgumentParser
from collections import defaultdict

import mojimoji
import numpy as np
import torch
from pyknp import Juman
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from sklearn.model_selection import train_test_split, GroupKFold, GroupShuffleSplit, KFold
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm

from model import RelationClassifier
from utils import *

In [4]:
# Run on GPU when available; the BERT model lives in a different place on each machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Local (CPU) copy vs. shared-server (GPU) copy of the same Japanese BERT model.
_BERT_URL_BY_DEVICE = {
    'cpu': '/Users/fei-c/Resources/embed/L-12_H-768_A-12_E-30_BPE',
    'cuda': '/larch/share/bert/Japanese_models/Wikipedia/L-12_H-768_A-12_E-30_BPE',
}
BERT_URL = _BERT_URL_BY_DEVICE[str(device)]

# Input text is pre-segmented by Juman and joined with spaces, so BERT's basic
# tokenizer (and lower-casing) is disabled; only subword splitting is applied.
tokenizer = BertTokenizer.from_pretrained(BERT_URL, do_lower_case=False, do_basic_tokenize=False)
juman = Juman()

logger = logging.getLogger('Data_Process')
logger.setLevel(logging.INFO)

# Re-running this cell must not stack duplicate handlers on the same logger.
if logger.hasHandlers():
    logger.handlers.clear()

# Console output at INFO level with a timestamped format.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(
    logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
logger.addHandler(console_handler)

In [None]:
# Fine-grained relation label -> 6-class scheme
# (AFTER, OVERLAP-OR-AFTER, OVERLAP, BEFORE-OR-OVERLAP, BEFORE, VAGUE).
merge_map_6c = {
    'after': 'AFTER',
    'met-by': 'AFTER',
    'overlapped-by': 'OVERLAP-OR-AFTER',
    'finishes': 'OVERLAP-OR-AFTER',
    'during': 'OVERLAP',
    'started-by': 'OVERLAP',
    'equal': 'OVERLAP',
    'starts': 'BEFORE-OR-OVERLAP',
    'contains': 'OVERLAP',
    'finished-by': 'OVERLAP',
    'overlaps': 'BEFORE-OR-OVERLAP',
    'meets': 'BEFORE',
    'before': 'BEFORE',
    'is_included': 'OVERLAP',
    'identity': 'OVERLAP',
    'includes': 'OVERLAP',
    'vague': 'VAGUE',
}

# The 4-class scheme (AFTER, OVERLAP, BEFORE, VAGUE) is the 6-class scheme with the
# two mixed classes folded into OVERLAP; key order follows merge_map_6c.
_FOLD_INTO_OVERLAP = {'OVERLAP-OR-AFTER', 'BEFORE-OR-OVERLAP'}
merge_map_4c = {
    rel: ('OVERLAP' if coarse in _FOLD_INTO_OVERLAP else coarse)
    for rel, coarse in merge_map_6c.items()
}

In [None]:
def save_checkpoint(model, checkpoint_dir):
    """Save a fine-tuned model's weights and config plus the tokenizer vocabulary.

    Files are written under the pytorch_pretrained_bert default names
    (WEIGHTS_NAME / CONFIG_NAME) so the directory can be reloaded with
    ``from_pretrained``.

    Args:
        model: the model to save. If it exposes ``.module`` (e.g. a DataParallel
            wrapper), the wrapped model is saved instead.
        checkpoint_dir: output directory; created if it does not exist yet.
    """
    # Save the underlying model rather than a DataParallel/DDP wrapper.
    model_to_save = model.module if hasattr(model, 'module') else model

    # torch.save / to_json_file fail on a missing directory, so create it first.
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_model_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    checkpoint_config_file = os.path.join(checkpoint_dir, CONFIG_NAME)

    torch.save(model_to_save.state_dict(), checkpoint_model_file)
    model_to_save.config.to_json_file(checkpoint_config_file)
    # Module-level tokenizer from the setup cell.
    tokenizer.save_vocabulary(checkpoint_dir)


def flatten_list(l):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat


def merge_lab(lab, lab_type='6c'):
    """Map a fine-grained relation label onto a coarse label scheme.

    Args:
        lab: fine-grained label, e.g. 'before' or 'overlaps'.
        lab_type: '6c' (merge_map_6c) or '4c' (merge_map_4c).

    Returns:
        The coarse label; 'VAGUE' when `lab` is not in the chosen map; None when
        `lab_type` is neither '6c' nor '4c'.
    """
    if lab_type == '6c':
        label_map = merge_map_6c
    elif lab_type == '4c':
        label_map = merge_map_4c
    else:
        return None
    return label_map.get(lab, 'VAGUE')

    
def vote_labels(lab_a, lab_b, lab_c, lab_type=None, comp=None):
    """Combine three annotators' relation labels into one.

    Args:
        lab_a, lab_b, lab_c: the three raw labels (relTypeA/B/C attributes).
        lab_type: '4c' or '6c' to map every label through merge_lab first; any
            other value keeps the raw labels.
        comp: if truthy, require unanimous agreement; otherwise use a majority vote.

    Returns:
        The agreed label. With `comp`, None when the labels are not unanimous.
        Without `comp`, the vague class of the active scheme when no two labels
        agree: 'VAGUE' for merged schemes, 'vague' for raw labels.
    """
    merged = lab_type in ['4c', '6c']
    if merged:
        lab_a = merge_lab(lab_a, lab_type=lab_type)
        lab_b = merge_lab(lab_b, lab_type=lab_type)
        lab_c = merge_lab(lab_c, lab_type=lab_type)

    if comp:
        return lab_a if lab_a == lab_b == lab_c else None

    if lab_a == lab_b or lab_a == lab_c:
        return lab_a
    if lab_b == lab_c:
        return lab_b
    # No majority. The merge maps send 'vague' -> 'VAGUE', so in a merged scheme the
    # fallback must be 'VAGUE' too; a lowercase 'vague' would become a separate class.
    return 'VAGUE' if merged else 'vague'

    
def mask_mention(sent_mask, mention_id):
    """Return a 0/1 list flagging the positions of `sent_mask` equal to `mention_id`."""
    flags = []
    for tag in sent_mask:
        flags.append(1 if tag == mention_id else 0)
    return flags


def extract_sents_from_xml(xml_file, verbose):
    """Parse one annotated XML document into per-sentence tokens, mention labels and TLINKs.

    Args:
        xml_file: path to the XML document.
        verbose: if truthy, log per-sentence details at DEBUG level.

    Returns:
        (doc_deunk_toks, doc_toks, doc_masks, tlinks):
        - doc_deunk_toks: per kept sentence, explore_unk(subword tokens, Juman words).
        - doc_toks: per kept sentence, BERT subword tokens.
        - doc_masks: per kept sentence, match_sbp_label output built from the
          word-level labels (mention id for EVENT/TIMEX3 text, 'O' otherwise).
        - tlinks: {task: {'<source id>-<target id>': voted label}}.
        Sentences whose explore_unk output is empty are dropped.
    """
    doc_deunk_toks, doc_toks, doc_masks = [], [], []
    eiid2eid = {}
    tlinks = defaultdict(lambda: dict())
    root = ET.parse(xml_file).getroot()
    for text_node in root.findall('TEXT'):
        # TLINKs refer to event *instances*; map each instance id to its event id.
        for event_node in text_node.iter('MAKEINSTANCE'):
            eiid2eid[event_node.attrib['eiid']] = event_node.attrib['eventID']
        logger.info('event num: %i' % len(eiid2eid))

        for tlink_node in text_node.iter('TLINK'):
            task = tlink_node.attrib['task']
            if task in ['DCT', 'T2E']:
                tlink_key = '%s-%s' % (tlink_node.attrib['timeID'],
                                       eiid2eid[tlink_node.attrib['relatedToEventInstance']])
            elif task in ['E2E', 'MAT']:
                tlink_key = '%s-%s' % (eiid2eid[tlink_node.attrib['eventInstanceID']],
                                       eiid2eid[tlink_node.attrib['relatedToEventInstance']])
            else:
                # Skip explicitly; otherwise the previous TLINK's key would be reused.
                logger.warning('skipping TLINK with unknown task %r in %s', task, xml_file)
                continue
            tlinks[task][tlink_key] = vote_labels(tlink_node.attrib['relTypeA'],
                                                  tlink_node.attrib['relTypeB'],
                                                  tlink_node.attrib['relTypeC'])

        for sent_node in text_node.iter('sentence'):
            sent_toks, sent_masks = [], []
            for tag in sent_node.iter():
                try:
                    if tag.text and tag.text.strip():
                        text_seg = juman.analysis(mojimoji.han_to_zen(tag.text.strip()))
                        sent_toks += [w.midasi for w in text_seg.mrph_list()]
                        # NOTE(review): label count uses len(text_seg) while tokens use
                        # mrph_list(); assumes both have the same length.
                        if tag.tag in ['EVENT', 'event'] and 'eid' in tag.attrib:
                            sent_masks += [tag.attrib['eid']] * len(text_seg)
                        elif tag.tag in ['TIMEX3'] and 'tid' in tag.attrib:
                            sent_masks += [tag.attrib['tid']] * len(text_seg)
                        else:
                            sent_masks += ['O'] * len(text_seg)
                    # Text after a child element's closing tag belongs to no mention.
                    if tag.tag != 'sentence' and tag.tail and tag.tail.strip():
                        tail_seg = juman.analysis(mojimoji.han_to_zen(tag.tail.strip()))
                        sent_toks += [w.midasi for w in tail_seg.mrph_list()]
                        sent_masks += ['O'] * len(tail_seg)
                except Exception:
                    # Best effort: skip the failing fragment but record it with a traceback.
                    logger.exception('failed to process %s <%s> %r %r',
                                     xml_file, tag.tag, tag.text, tag.attrib)

            sbp_toks = tokenizer.tokenize(' '.join(sent_toks))
            deunk_toks = explore_unk(sbp_toks, sent_toks)
            sbp_masks = match_sbp_label(sbp_toks, sent_masks)

            if deunk_toks:
                doc_deunk_toks.append(deunk_toks)
                doc_toks.append(sbp_toks)
                doc_masks.append(sbp_masks)

                if verbose:
                    logger.debug('%s %i', sent_toks, len(sent_toks))
                    logger.debug('%s %i', sent_masks, len(sent_masks))
                    logger.debug('%s %i', sbp_toks, len(sbp_toks))
                    logger.debug('%s %i', sbp_masks, len(sbp_masks))

    return doc_deunk_toks, doc_toks, doc_masks, tlinks


def make_tlink_instances(doc_masks, doc_tlinks, task, verbose):
    """Attach to every sentence the TLINKs of `task` whose mentions it contains.

    For 'DCT', only the target mention has to occur in the sentence, and the
    source mask is all zeros. For 'T2E', 'E2E' and 'MAT', both mentions must occur
    in the same sentence.

    Args:
        doc_masks: per sentence, a list of per-token labels (mention id or 'O').
        doc_tlinks: {task: {'<source id>-<target id>': label}}.
        task: which TLINK task to use.
        verbose: gate for DCT debug logging.

    Returns:
        (doc_sour_masks, doc_targ_masks, doc_labs, number of TLINKs for `task`,
         number of (sentence, TLINK) instances generated).
    """
    task_links = doc_tlinks[task]
    logger.info('tlink num: %i' % len(task_links))
    generated = 0
    doc_sour_masks, doc_targ_masks, doc_labs = [], [], []

    for sent_mask in doc_masks:
        sour_rows, targ_rows, lab_row = [], [], []
        for key, rel in task_links.items():
            sour_id, targ_id = key.split('-')
            if task == 'DCT':
                if targ_id not in sent_mask:
                    continue
                targ_mask = mask_mention(sent_mask, targ_id)
                sour_rows.append([0] * len(sent_mask))
                targ_rows.append(targ_mask)
                lab_row.append(rel)
                if verbose:
                    logger.debug('%s\t%s' % (key, rel))
                    logger.debug('%s' % targ_mask)
                generated += 1
            elif task in ('T2E', 'E2E', 'MAT'):
                if sour_id not in sent_mask or targ_id not in sent_mask:
                    continue
                sour_mask = mask_mention(sent_mask, sour_id)
                targ_mask = mask_mention(sent_mask, targ_id)
                sour_rows.append(sour_mask)
                targ_rows.append(targ_mask)
                lab_row.append(rel)
                logger.debug('%s\t%s' % (key, rel))
                logger.debug('%s' % sour_mask)
                logger.debug('%s' % targ_mask)
                generated += 1
        doc_sour_masks.append(sour_rows)
        doc_targ_masks.append(targ_rows)
        doc_labs.append(lab_row)

    logger.info('generated tlink num: %i' % generated)
    return doc_sour_masks, doc_targ_masks, doc_labs, len(task_links), generated


def flatten_tlink_instance(deunk_toks, toks, sour_masks, targ_masks, labs):
    """Expand sentence-level lists into one entry per TLINK instance.

    Each sentence's token lists are repeated once per label attached to that
    sentence, so the i-th token list lines up with the i-th source mask, target
    mask and label of the flattened outputs.
    """
    f_deunk_toks, f_toks = [], []
    for sent_deunk, sent_tok, sent_labs in zip(deunk_toks, toks, labs):
        repeat = len(sent_labs)
        f_deunk_toks.extend([sent_deunk] * repeat)
        f_toks.extend([sent_tok] * repeat)

    f_sour_masks, f_targ_masks, f_labs = [], [], []
    for sent in sour_masks:
        f_sour_masks.extend(sent)
    for sent in targ_masks:
        f_targ_masks.extend(sent)
    for sent in labs:
        f_labs.extend(sent)

    assert len(f_deunk_toks) == len(f_toks) == len(f_sour_masks) == len(f_targ_masks) == len(f_labs)
    return f_deunk_toks, f_toks, f_sour_masks, f_targ_masks, f_labs


def batch_make_tlink_instances(data_dir, task='E2E', verbose=0):
    """Build flattened TLINK instances for `task` from every XML file in `data_dir`.

    Args:
        data_dir: directory scanned (in sorted order) for '*.xml' documents.
        task: TLINK task, e.g. 'DCT', 'T2E', 'E2E' or 'MAT'.
        verbose: forwarded to the per-document helpers.

    Returns:
        (f_deunk_toks, f_toks, f_sour_masks, f_targ_masks, f_labs), one entry per instance.
    """
    total_tlink_num, total_tlink_count = 0, 0
    deunk_toks, toks, sour_masks, targ_masks, labs = [], [], [], [], []
    for file in sorted(os.listdir(data_dir)):
        # Skip anything that is not an XML document. The whole body must be skipped,
        # not just the path assignment; otherwise the previous file is processed again.
        if not file.endswith(".xml"):
            continue
        dir_file = os.path.join(data_dir, file)

        # sentence-level data
        doc_deunk_toks, doc_toks, doc_masks, doc_tlinks = extract_sents_from_xml(
            dir_file,
            verbose
        )

        # instance-level data
        doc_sour_masks, doc_targ_masks, doc_labs, tlink_num, tlink_count = make_tlink_instances(
            doc_masks,
            doc_tlinks,
            task,
            verbose
        )

        assert len(doc_deunk_toks) == len(doc_masks) == len(doc_sour_masks) == len(doc_targ_masks) == len(doc_labs)
        total_tlink_num += tlink_num
        total_tlink_count += tlink_count
        deunk_toks += doc_deunk_toks
        toks += doc_toks
        sour_masks += doc_sour_masks
        targ_masks += doc_targ_masks
        labs += doc_labs
    logger.info('%i\t%i' % (total_tlink_num, total_tlink_count))
    logger.info('sent num: %i, data instance num: %i' % (len(deunk_toks), sum([len(d) for d in labs])))

    # flatten sentence-level to instance-level tlinks
    f_deunk_toks, f_toks, f_sour_masks, f_targ_masks, f_labs = flatten_tlink_instance(
        deunk_toks,
        toks,
        sour_masks,
        targ_masks,
        labs
    )

    logger.info('flat data instance num: %i' % len(f_deunk_toks))

    return f_deunk_toks, f_toks, f_sour_masks, f_targ_masks, f_labs


def convert_to_np(toks, sour_masks, targ_masks, labs, lab2ix):
    """Prefix every instance with [CLS], pad to a common width and stack into numpy arrays.

    The width is the longest token sequence + 1 (for [CLS]). Token sequences are
    padded with '[PAD]'; attention and mention masks are padded with 0. Mention
    masks get a leading 0 for the [CLS] position.

    Returns:
        (token ids, attention masks, source masks, target masks, label ids).
    """
    width = max(len(t) for t in toks) + 1
    tok_rows, attn_rows, sm_rows, tm_rows = [], [], [], []
    for inst_tok, inst_sm, inst_tm, _ in zip(toks, sour_masks, targ_masks, labs):
        tok_rows.append(tokenizer.convert_tokens_to_ids(
            padding_1d(['[CLS]'] + inst_tok, width, pad_tok='[PAD]')
        ))
        attn_rows.append(padding_1d([1] * (len(inst_tok) + 1), width, pad_tok=0))
        sm_rows.append(padding_1d([0] + inst_sm, width, pad_tok=0))
        tm_rows.append(padding_1d([0] + inst_tm, width, pad_tok=0))
    lab_ids = [lab2ix[lab] for lab in labs]
    assert len(tok_rows) == len(attn_rows) == len(sm_rows) == len(tm_rows) == len(lab_ids)
    return np.array(tok_rows), np.array(attn_rows), np.array(sm_rows), np.array(tm_rows), np.array(lab_ids)


def generate_group_ids(toks):
    """Give consecutive instances that share the same token sequence the same group id.

    Ids start at 1 and increase by one each time an instance's token sequence
    differs from the previous one. Only *adjacent* duplicates are grouped.
    Presumably intended for grouped CV splits, so that TLINKs from one sentence stay
    in the same fold.

    Args:
        toks: sequence of per-instance token sequences (lists or numpy rows).

    Returns:
        list[int] with one group id per instance.
    """
    group_ids = []
    group_id = 1
    for index, inst_tok in enumerate(toks):
        # Start a new group only when the sentence changes (the old condition was
        # inverted and split identical sentences into different groups).
        if index > 0 and not np.array_equal(inst_tok, toks[index - 1]):
            group_id += 1
        group_ids.append(group_id)
    assert len(toks) == len(group_ids)
    return group_ids
    
    
def merge_word_mention_boundaries(flat_word_ids, flat_doc_toks, mention_offsets):
    """Re-number word ids so that no word straddles a mention boundary.

    Tokens with equal ids belong to the same word. A mention boundary that falls
    inside a word (the tokens on both sides share an id) splits that word: every
    token from the boundary onward has its id shifted by one more.

    Args:
        flat_word_ids: per-token word ids for the whole document (list or 1-D array).
        flat_doc_toks: the document tokens. No longer used (they were only printed
            for debugging); kept for interface compatibility.
        mention_offsets: iterable of (mid, mtype, offs_b, offs_e, mention) with token
            offsets [offs_b, offs_e). Presumably sorted by position; boundaries that
            go backwards still bump the shift but re-number nothing — TODO confirm.

    Returns:
        A copy of flat_word_ids with the adjusted ids.
    """
    n_toks = len(flat_word_ids)
    flat_new_word_ids = flat_word_ids.copy()
    shift = 0  # number of splits made so far
    done = 0   # tokens [0, done) already hold their final id

    def _split_before(boundary):
        # Finalize ids before `boundary` with the current shift, then start a new word.
        nonlocal shift, done
        while done < boundary:
            flat_new_word_ids[done] = flat_word_ids[done] + shift
            done += 1
        shift += 1

    for _mid, _mtype, offs_b, offs_e, _mention in mention_offsets:
        # Boundaries at 0 or at the end of the document cannot fall inside a word.
        # Each boundary is checked independently: a mention starting at 0 may still
        # end in the middle of a word.
        if 0 < offs_b < n_toks and flat_word_ids[offs_b - 1] == flat_word_ids[offs_b]:
            _split_before(offs_b)
        if 0 < offs_e < n_toks and flat_word_ids[offs_e - 1] == flat_word_ids[offs_e]:
            _split_before(offs_e)

    # Remaining tokens take the accumulated shift.
    while done < n_toks:
        flat_new_word_ids[done] = flat_word_ids[done] + shift
        done += 1

    assert len(flat_word_ids) == len(flat_new_word_ids)
    return flat_new_word_ids


def attach_word_ids(doc_words):
    """Tag every character of every word with that word's document-level index.

    Args:
        doc_words: sentences, each a sequence of words (strings).

    Returns:
        Per sentence, a list with one entry per character holding the index of
        its word, counted across the whole document from 0. Empty words still
        consume an index.
    """
    word_ids = []
    word_index = 0
    for sentence in doc_words:
        sentence_ids = []
        for word in sentence:
            sentence_ids.extend(word_index for _ in word)
            word_index += 1
        word_ids.append(sentence_ids)
    return word_ids

In [15]:
def retrieve_mention(sent_toks, ment_mask):
    """Return the tokens whose mask value is 1, each wrapped in a one-element list."""
    assert len(sent_toks) == len(ment_mask)
    picked = []
    for tok, flag in zip(sent_toks, ment_mask):
        if flag == 1:
            picked.append([tok])
    return picked


def _juman_words(text):
    """Segment stripped `text` with Juman after full-width conversion; full-width spaces become '[JSP]'."""
    normalized = mojimoji.han_to_zen(text.strip()).replace('\u3000', '[JSP]')
    return [w.midasi for w in juman.analysis(normalized).mrph_list()]


def extract_sents_from_xml_v2(xml_file, lab_type=None, comp=None):
    """Parse one annotated XML document into sentence tokens, mention masks and voted TLINKs.

    Args:
        xml_file: path to the XML document.
        lab_type: label scheme forwarded to vote_labels ('4c', '6c' or None).
        comp: forwarded to vote_labels; if truthy, only unanimous TLINKs are kept.

    Returns:
        (doc_deunk_toks, doc_toks, doc_mid2smask, doc_tlinks):
        - doc_deunk_toks / doc_toks: per kept sentence, explore_unk output and BERT
          subword tokens.
        - doc_mid2smask: mention id (eid/tid) -> [sentence index, mask over that
          sentence's subword tokens, as returned by match_sbp_mask].
        - doc_tlinks: task -> list of (source mention id, target mention id, label).
        Sentences whose explore_unk output is empty are not kept and do not
        advance the sentence index.
    """
    doc_deunk_toks, doc_toks = [], []
    eiid2eid = {}
    doc_mid2smask = {}
    doc_tlinks = defaultdict(lambda: list())
    root = ET.parse(xml_file).getroot()
    for text_node in root.findall('TEXT'):

        # TLINKs refer to event *instances*; map each instance id to its event id.
        for event_node in text_node.iter('MAKEINSTANCE'):
            eiid2eid[event_node.attrib['eiid']] = event_node.attrib['eventID']
        logger.debug('event num: %i' % len(eiid2eid))

        for tlink_node in text_node.iter('TLINK'):
            task = tlink_node.attrib['task']
            if task in ['DCT', 'T2E']:
                sour_mid = tlink_node.attrib['timeID']
                targ_mid = eiid2eid[tlink_node.attrib['relatedToEventInstance']]
            elif task in ['E2E', 'MAT']:
                sour_mid = eiid2eid[tlink_node.attrib['eventInstanceID']]
                targ_mid = eiid2eid[tlink_node.attrib['relatedToEventInstance']]
            else:
                # Skip explicitly; otherwise the previous TLINK's mention ids would be reused.
                logger.warning('skipping TLINK with unknown task %r in %s', task, xml_file)
                continue
            voted_label = vote_labels(
                tlink_node.attrib['relTypeA'],
                tlink_node.attrib['relTypeB'],
                tlink_node.attrib['relTypeC'],
                lab_type=lab_type,
                comp=comp
            )
            # None means the annotators disagreed under `comp`: drop the link.
            if voted_label:
                doc_tlinks[task].append((sour_mid, targ_mid, voted_label))

        s_id = 0
        for sent_node in text_node.iter('sentence'):
            logger.debug('sentence %i' % s_id)
            sent_toks, tmp_mids = [], []
            for tag in sent_node.iter():
                try:
                    if tag.text and tag.text.strip():
                        text_seg = _juman_words(tag.text)
                        mid = None
                        if tag.tag in ['EVENT', 'event'] and 'eid' in tag.attrib:
                            mid = tag.attrib['eid']
                        elif tag.tag in ['TIMEX3'] and 'tid' in tag.attrib:
                            mid = tag.attrib['tid']
                        if mid is not None:
                            tmp_mids.append(mid)
                            # 0 for the words seen so far, 1 for this mention's words;
                            # right-padded to the full sentence length further below.
                            doc_mid2smask[mid] = [
                                s_id,
                                [0] * len(sent_toks) + [1] * len(text_seg)
                            ]
                        sent_toks += text_seg
                    # Text after a child element's closing tag belongs to no mention.
                    if tag.tag != 'sentence' and tag.tail and tag.tail.strip():
                        sent_toks += _juman_words(tag.tail)
                except Exception:
                    # Best effort: skip the failing fragment but record it with a traceback.
                    logger.exception('failed to process %s <%s> %r %r',
                                     xml_file, tag.tag, tag.text, tag.attrib)

            # Subword-tokenize the Juman words.
            sbp_toks = tokenizer.tokenize(' '.join(sent_toks))
            deunk_toks = explore_unk(sbp_toks, sent_toks)
            # Build debug strings (and run the asserting retrieve_mention) only when
            # DEBUG logging is actually enabled.
            debug_on = logger.isEnabledFor(logging.DEBUG)
            if debug_on:
                logger.debug(str(len(sent_toks)) + ' ' + '/'.join(sent_toks))
                logger.debug(str(len(deunk_toks)) + ' ' + '/'.join(deunk_toks))

            # Pad each mention's word-level mask to the sentence, then project it onto subwords.
            for mid in tmp_mids:
                sent_mask = padding_1d(doc_mid2smask[mid][1], len(sent_toks))
                sbp_mask = match_sbp_mask(sbp_toks, sent_mask)
                doc_mid2smask[mid][1] = sbp_mask
                if debug_on:
                    logger.debug('%s, sent_id: %i' % (mid, doc_mid2smask[mid][0]))
                    logger.debug(str(len(sent_mask)) + ' ' + ' '.join([str(i) for i in sent_mask]))
                    logger.debug(str(len(sbp_mask)) + ' ' + ' '.join([str(i) for i in sbp_mask]))
                    logger.debug('%s', retrieve_mention(deunk_toks, sbp_mask))

            logger.debug('[EOS]')

            if deunk_toks:
                doc_deunk_toks.append(deunk_toks)
                doc_toks.append(sbp_toks)
                s_id += 1

    return doc_deunk_toks, doc_toks, doc_mid2smask, doc_tlinks


def make_tlink_instances_v2(doc_deunk_toks, doc_toks, doc_mid2smask, doc_tlinks, task=None):
    """Build one classifier input per TLINK of `task` in a document.

    For 'DCT' the input is the target's sentence, with an all-zero source mask.
    For 'T2E'/'E2E'/'MAT':
    - a same-sentence pair uses that sentence;
    - a cross-sentence pair concatenates the source sentence and then the target
      sentence, and sent_mask is 1 on the target-sentence part.
    TLINKs whose mentions were not located in any kept sentence are skipped.

    Returns:
        Six parallel lists: deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels.
    """
    deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels = [], [], [], [], [], []
    for sour_mid, targ_mid, rel in doc_tlinks[task]:
        logger.debug('%s\t%s\t%s' % (sour_mid, targ_mid, rel))
        # Targets can be missing just like sources (e.g. a mention with no text);
        # skip instead of raising KeyError.
        if targ_mid not in doc_mid2smask:
            logger.debug('skip: target mention %s not found in any sentence' % targ_mid)
            continue
        targ_sid, targ_smask = doc_mid2smask[targ_mid]
        if task in ['DCT']:
            deunk_tok = doc_deunk_toks[targ_sid]
            tok = doc_toks[targ_sid]
            sour_mask = [0] * len(targ_smask)
            targ_mask = targ_smask
            sent_mask = [0] * len(targ_smask)
        elif task in ['T2E', 'E2E', 'MAT']:
            if sour_mid not in doc_mid2smask:
                continue
            sour_sid, sour_smask = doc_mid2smask[sour_mid]
            if targ_sid == sour_sid:
                deunk_tok = doc_deunk_toks[targ_sid]
                tok = doc_toks[targ_sid]
                sour_mask = sour_smask
                targ_mask = targ_smask
                sent_mask = [0] * len(targ_smask)
            else:
                deunk_tok = doc_deunk_toks[sour_sid] + doc_deunk_toks[targ_sid]
                tok = doc_toks[sour_sid] + doc_toks[targ_sid]
                sour_mask = sour_smask + [0] * len(targ_smask)
                targ_mask = [0] * len(sour_smask) + targ_smask
                sent_mask = [0] * len(sour_smask) + [1] * len(targ_smask)
        else:
            # Unsupported task: the per-instance variables would be undefined or stale.
            logger.warning('unsupported task %r; skipping TLINK' % (task,))
            continue

        logger.debug(' '.join(deunk_tok))
        logger.debug(' '.join(tok))
        logger.debug(' '.join([str(i) for i in sour_mask]))
        logger.debug(' '.join([str(i) for i in targ_mask]))
        logger.debug(' '.join([str(i) for i in sent_mask]))
        deunk_toks.append(deunk_tok)
        toks.append(tok)
        sour_masks.append(sour_mask)
        targ_masks.append(targ_mask)
        sent_masks.append(sent_mask)
        rels.append(rel)
        assert len(deunk_tok) == len(tok) == len(sour_mask) == len(targ_mask) == len(sent_mask)
    return deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels


def batch_make_tlink_instances_v2(file_list, task=None, lab_type=None, comp=None):
    """Build TLINK instances for `task` from several XML documents.

    Args:
        file_list: iterable of XML file paths. A single path (str or PathLike) is
            also accepted: a directory expands to its '*.xml' files in sorted
            order, and anything else is treated as one file. Previously a string
            was iterated character by character.
        task, lab_type, comp: forwarded to extract_sents_from_xml_v2 /
            make_tlink_instances_v2.

    Returns:
        Six parallel lists (deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels),
        concatenated over documents in file order.
    """
    if isinstance(file_list, (str, os.PathLike)):
        path = os.fspath(file_list)
        if os.path.isdir(path):
            file_list = [os.path.join(path, f) for f in sorted(os.listdir(path)) if f.endswith('.xml')]
        else:
            file_list = [path]

    deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels = [], [], [], [], [], []
    for dir_file in file_list:
        doc_deunk_toks, doc_toks, doc_mid2smask, doc_tlinks = extract_sents_from_xml_v2(
            dir_file,
            lab_type=lab_type,
            comp=comp
        )
        inst_deunk_toks, inst_toks, inst_sour_masks, inst_targ_masks, inst_sent_masks, inst_rels = make_tlink_instances_v2(
            doc_deunk_toks,
            doc_toks,
            doc_mid2smask,
            doc_tlinks,
            task=task
        )
        deunk_toks += inst_deunk_toks
        toks += inst_toks
        sour_masks += inst_sour_masks
        targ_masks += inst_targ_masks
        sent_masks += inst_sent_masks
        rels += inst_rels
        # Logged after the document has actually been processed.
        logger.debug('[Done] processing %s' % dir_file)
    return deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels


def convert_to_np_v2(deunk_toks, toks, sour_masks, targ_masks, sent_masks, labs, lab2ix):
    """Prefix every instance with [CLS], pad to a common width and stack into numpy arrays.

    The width is the longest token sequence + 1 (for [CLS]). Tokens are padded
    with '[PAD]'; attention, mention and sentence masks are padded with 0 and get
    a leading 0 for [CLS]. `deunk_toks` is accepted but not used.

    Returns:
        (token ids, attention masks, source masks, target masks, sentence masks, label ids).
    """
    max_len = max(len(t) for t in toks)
    logger.info('max seq length %i' % (max_len))
    width = max_len + 1  # room for the leading [CLS]

    def _pad(seq, fill):
        return padding_1d(seq, width, pad_tok=fill)

    tok_rows, attn_rows, sm_rows, tm_rows, sent_rows = [], [], [], [], []
    for inst_tok, inst_sm, inst_tm, inst_sent_m, _ in zip(toks, sour_masks, targ_masks, sent_masks, labs):
        tok_rows.append(tokenizer.convert_tokens_to_ids(_pad(['[CLS]'] + inst_tok, '[PAD]')))
        attn_rows.append(_pad([1] * (len(inst_tok) + 1), 0))
        sm_rows.append(_pad([0] + inst_sm, 0))
        tm_rows.append(_pad([0] + inst_tm, 0))
        sent_rows.append(_pad([0] + inst_sent_m, 0))
    lab_ids = [lab2ix[lab] for lab in labs]
    assert len(tok_rows) == len(attn_rows) == len(sm_rows) == len(tm_rows) == len(sent_rows) == len(lab_ids)
    return tuple(np.array(rows) for rows in (tok_rows, attn_rows, sm_rows, tm_rows, sent_rows, lab_ids))


def doc_kfold(data_dir, n_splits=5, random_state=0):
    """Split the XML documents of `data_dir` into shuffled K folds at document level.

    Args:
        data_dir: directory scanned (in sorted order) for '*.xml' files.
        n_splits: number of folds (default 5, as before).
        random_state: shuffling seed (default 0, as before).

    Returns:
        list of (train_indices, test_indices) pairs of Python int lists. The indices
        refer to the sorted list of '*.xml' paths in data_dir.
    """
    file_list = [
        os.path.join(data_dir, file)
        for file in sorted(os.listdir(data_dir))
        if file.endswith(".xml")
    ]
    logger.info("[Number] %i files in '%s'" % (len(file_list), data_dir))
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    return [(train.tolist(), test.tolist()) for train, test in kfold.split(file_list)]

2019-08-09 23:37:56,193 - Data_Process - INFO - [Number] 54 files in 'data/merge/BCCWJ-TIMEX'


In [None]:
def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats any non-empty string (including 'False') as
    True, so an explicit parser is needed. Raises ValueError, which argparse turns
    into a usage error, for anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('expected True or False, got %r' % value)


parser = ArgumentParser(description='Bert-based Temporal Relation Classifier')

parser.add_argument("-t", "--task", dest="task",
                    help="classification task, i.g. DCT, T2E, E2E and MAT")
parser.add_argument("-l", "--lab", dest="lab_type",
                    help="lab_type, i.g. 4c, 6c or None")
parser.add_argument("-c", "--comp", dest="comp",
                    help="complete match, True or False", type=_str2bool)

# parse_known_args: inside Jupyter, sys.argv carries the kernel's own '-f <file>',
# which would make parse_args() abort.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    logger.warning('ignoring unrecognized arguments: %s' % ' '.join(_unknown_args))

logger.info('[args] task: %s, label type: %s, complete agree: %s' % (
    args.task,
    args.lab_type,
    str(args.comp)
))

In [None]:
data_dir = 'data/merge/BCCWJ-TIMEX'

# Document-level folds; the indices refer to `xml_files` below (same sorted order).
data_splits = doc_kfold(data_dir)

# batch_make_tlink_instances_v2 expects a list of file paths. Passing the directory
# string made it iterate over the characters of the path.
xml_files = [
    os.path.join(data_dir, file)
    for file in sorted(os.listdir(data_dir))
    if file.endswith('.xml')
]

deunk_toks, toks, sour_masks, targ_masks, sent_masks, labs = batch_make_tlink_instances_v2(
    xml_files,
    task=args.task,
    lab_type=args.lab_type,
    comp=args.comp
)
logger.info('Full data size %i ...' % len(labs))

In [None]:
# convert_xml_to_brat('data/merge/BCCWJ-TIMEX/00001_A_PN1c_00001.xml')

In [None]:
# data_dir = 'data/merge/BCCWJ-TIMEX'
# for file in os.listdir(data_dir):
#     if file.endswith(".xml"):
#         dir_file = os.path.join(data_dir, file)
#         convert_xml_to_brat(dir_file)

In [None]:
# doc_deunk_toks, doc_toks, doc_mid2smask, doc_tlinks = extract_sents_from_xml_v2(
#     'data/merge/BCCWJ-TIMEX/00003_A_PN1e_00001.xml'
# )
# deunk_toks, toks, sour_masks, targ_masks, sent_masks, rels = make_tlink_instances_v2(doc_deunk_toks, doc_toks, doc_mid2smask, doc_tlinks, task='MAT')

In [None]:
# deunk_toks, toks, sour_masks, targ_masks, labs = batch_make_tlink_instances(
#     'data/merge/BCCWJ-TIMEX', 
#     task='DCT', 
#     verbose=0
# )

In [None]:
lab2ix = get_label2ix(labs)

# Frequency of every label, in first-seen order.
lab2count = {}
for lab in labs:
    lab2count[lab] = lab2count.get(lab, 0) + 1
logger.info(str(lab2count))

# Share of the most frequent label, i.e. the accuracy of always predicting it.
logger.info('major vote: %.2f%%' % (100 * max(lab2count.values()) / sum(lab2count.values())))

In [None]:
# Pad every instance to a common length and stack the inputs into numpy arrays.
np_arrays = convert_to_np_v2(deunk_toks, toks, sour_masks, targ_masks, sent_masks, labs, lab2ix)
toks_ids_np, tok_masks_np, sour_masks_np, targ_masks_np, sent_masks_np, lab_ids_np = np_arrays
logger.info(str(toks_ids_np.shape))

In [None]:
# toks_ids_np, tok_masks_np, sour_masks_np, targ_masks_np, lab_ids_np = convert_to_np(toks, 
#                                                                                     sour_masks, 
#                                                                                     targ_masks, 
#                                                                                     labs, 
#                                                                                     lab2ix)

In [None]:
# group_ids_np = np.array(generate_group_ids(toks_ids_np))  # those tlinks given the same sentences are grouped
# print(group_ids_np)

In [None]:
# Instance-level 5-fold split; only the first fold is used.
# NOTE(review): TLINK instances that share a sentence (or document) can end up in both
# train and test here. `data_splits` holds document-level folds if a leak-free split
# is wanted.
gss = KFold(n_splits=5, shuffle=True, random_state=0)
train_split, test_split = next(iter(gss.split(toks_ids_np, lab_ids_np)))
logger.info('train size: %i, test size: %i' % (len(train_split), len(test_split)))


In [None]:
NUM_EPOCHS = 10
BATCH_SIZE = 16

# Arrays in the order the training loop unpacks each batch:
# token ids, attention mask, source mask, target mask, sentence mask, label id.
_model_inputs = (toks_ids_np, tok_masks_np, sour_masks_np, targ_masks_np, sent_masks_np, lab_ids_np)


def _to_tensor_dataset(indices):
    """Select `indices` from every input array, move them to `device`, and wrap them as a TensorDataset."""
    return TensorDataset(*(torch.from_numpy(arr[indices]).to(device) for arr in _model_inputs))


train_tensors = _to_tensor_dataset(train_split)
test_tensors = _to_tensor_dataset(test_split)
train_dataloader = DataLoader(train_tensors, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_tensors, batch_size=BATCH_SIZE, shuffle=False)
logger.info('Train batch num: %i, Test batch num: %i' % (len(train_dataloader), len(test_dataloader)))

In [None]:
model = RelationClassifier.from_pretrained(BERT_URL, num_labels=len(lab2ix))
model.to(device)

# Two parameter groups: weight decay 0.01, except for biases and LayerNorm parameters (0.0).
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
decay_params, no_decay_params = [], []
for name, param in param_optimizer:
    if any(nd in name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = BertAdam(
    optimizer_grouped_parameters,
    lr=5e-5,
    warmup=0.1,
    t_total=NUM_EPOCHS * len(train_dataloader),
)

# Epoch 0: evaluate before any fine-tuning.
eval_lab(model, test_dataloader, 0)

for epoch in range(1, NUM_EPOCHS + 1):
    for batch in tqdm(train_dataloader):
        b_tok, b_mask, b_sour_mask, b_targ_mask, b_sent_mask, b_lab = batch
        model.train()
        model.zero_grad()
        # The sentence mask is passed as token_type_ids.
        loss = model(b_tok, b_sour_mask, b_targ_mask,
                     token_type_ids=b_sent_mask, attention_mask=b_mask, labels=b_lab)
        loss.backward()
        optimizer.step()
    eval_lab(model, test_dataloader, epoch)

    