In [1]:
from tqdm import tqdm
import json
import copy
from transformers import AutoTokenizer

In [2]:
# Relation name -> integer id. convert_feature uses it both to index the
# multi-hot label vector and (via len) to size that vector.
# Opened with a context manager so the handle is closed deterministically, and
# with an explicit encoding so the result does not depend on the platform default.
with open('data/rel2id.json', 'r', encoding='utf-8') as rel2id_file:
    rel2id = json.load(rel2id_file)

# (head mention name, tail mention name, relation id) triples seen while
# converting the training split; consulted when converting the dev split.
fact_in_train = set()

# NOTE(review): not referenced by any of the visible cells - confirm whether it is still needed.
span_wrong_dict = set()

In [3]:
def convert_feature(file_name, output_file, max_seq_length=512, is_training=True, is_test=False):
    """Write one BERT classification feature per ordered entity pair of every document.

    The input is a JSON list of documents. Each document provides ``sents``
    (a list of token lists), ``vertexSet`` (a list of entities, each a list of
    mentions with ``name``, ``sent_id`` and ``pos``) and, unless ``is_test`` is
    set, ``labels`` (items with ``h``, ``t``, ``r`` and ``evidence``).

    For every ordered pair of distinct entities the document is wordpiece
    tokenized, every mention of the first entity is wrapped in
    ``[unused0]``/``[unused1]`` and every mention of the second one in
    ``[unused2]``/``[unused3]``, and the result is written as one JSON object
    per line with the keys ``input_ids``, ``input_mask``, ``segment_ids``,
    ``labels``, ``evidences`` and ``intrain``. Pairs whose marked document is
    longer than ``max_seq_length - 2`` wordpieces are skipped.

    Args:
        file_name: Path of the input JSON file.
        output_file: Path of the JSON-lines file to create (overwritten).
        max_seq_length: Fixed length every feature is zero-padded to,
            including ``[CLS]`` and ``[SEP]``.
        is_training: When true, every labelled fact is added to the
            module-level ``fact_in_train`` set. When false (and not
            ``is_test``), that set is consulted to fill the ``intrain`` flag,
            so the training split has to be converted first.
        is_test: When true, ``labels`` is not read; ``labels`` and ``intrain``
            are written as ``null`` and ``evidences`` as an empty list.

    Returns:
        None. Results go to ``output_file``; progress, the first six features
        and the final sample counts are printed.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    i_line = 0
    max_len_for_doc = max_seq_length - 2  # leave room for [CLS] and [SEP]

    pos_samples = 0
    neg_samples = 0

    print('convert features...')
    with open(output_file, 'w', encoding='utf-8') as w:
        # Read with an explicit encoding: the data contains non-ASCII names and
        # the output is written as UTF-8, so do not rely on the platform default.
        with open(file_name, 'r', encoding='utf-8') as f:
            data_samples = json.load(f)
            for sample in tqdm(data_samples):

                if not is_test:
                    labels = sample['labels']
                # Wordpiece-tokenize first and, per sentence, map each word
                # index to the offset of its first wordpiece.
                sents = []
                sent_map = []
                for sent in sample['sents']:
                    new_sent = []
                    new_map = {}
                    for i_t, token in enumerate(sent):
                        tokens_wordpiece = tokenizer.tokenize(token)
                        new_map[i_t] = len(new_sent)
                        new_sent.extend(tokens_wordpiece)
                    # One-past-the-end entry so an end position equal to the
                    # sentence length can be mapped as well. Keyed by len(sent)
                    # rather than the leaked loop variable, which is undefined
                    # (or stale from the previous sentence) for an empty sentence.
                    new_map[len(sent)] = len(new_sent)
                    sent_map.append(new_map)
                    sents.append(new_sent)

                entitys = sample['vertexSet']

                # Collect the entity pairs that carry at least one relation.
                train_triple = {}
                if not is_test:
                    for label in labels:
                        evidence = label['evidence']
                        r = int(rel2id[label['r']])
                        # The same entity pair may hold several relations, so
                        # keep a list per pair ...
                        if (label['h'], label['t']) not in train_triple:
                            train_triple[(label['h'], label['t'])] = [{'relation': r, 'evidence': evidence}]
                        else:  # ... but only store each relation once.
                            in_triple = False
                            for tmp_r in train_triple[(label['h'], label['t'])]:
                                if tmp_r['relation'] == r:
                                    in_triple = True
                                    break
                            if not in_triple:
                                train_triple[(label['h'], label['t'])].append({'relation': r, 'evidence': evidence})

                        intrain = False
                        # Register (training split) or look up (dev split) the
                        # fact for every combination of mention names.
                        for e1i in entitys[label['h']]:
                            for e2i in entitys[label['t']]:
                                if is_training:
                                    fact_in_train.add((e1i['name'], e2i['name'], r))
                                elif not is_test:
                                    # Dev-split lookup.
                                    if (e1i['name'], e2i['name'], r) in fact_in_train:
                                        for train_tmp in train_triple[(label['h'], label['t'])]:
                                            train_tmp['intrain'] = True
                                        intrain = True
                        # NOTE(review): this resets the flag on every relation
                        # stored for the pair, including ones an earlier label
                        # marked True - confirm that this is intended.
                        if not intrain:
                            for train_tmp in train_triple[(label['h'], label['t'])]:
                                train_tmp['intrain'] = False

                # Visit every ordered pair of distinct entities; pairs without
                # a relation get the NA label (index 0).
                for e1, entity1 in enumerate(entitys):
                    for e2, entity2 in enumerate(entitys):
                        if e1 != e2:
                            # Wrap ALL mentions of entity 1 in [unused0]/[unused1] and
                            # ALL mentions of entity 2 in [unused2]/[unused3] to locate them:
                            # [unused0] Hirabai Badodekar [unused1] , Gangubai Hangal , Mogubai Kurdikar ) ,
                            # made the [unused2] Indian [unused3] classical music so much greater .

                            entity1_ = copy.deepcopy(entity1)
                            entity2_ = copy.deepcopy(entity2)
                            for e in entity1_:
                                e['first'] = True  # mention of entity 1
                            for e in entity2_:
                                e['first'] = False  # mention of entity 2
                            new_sents = copy.deepcopy(sents)
                            # Sort mentions back to front by position (ties on the
                            # start are ordered by end, descending) so earlier
                            # insertions do not shift positions still to be processed.
                            sorted_entity = sorted(entity1_ + entity2_, key=lambda x: (x['pos'][0], x['pos'][1]),
                                                   reverse=True)

                            start_end_dict = set()  # (start offset, sentence) pairs that already got a marker
                            for se in sorted_entity:
                                map_start = sent_map[se['sent_id']][se['pos'][0]]
                                map_end = sent_map[se['sent_id']][se['pos'][1]]

                                # If a mention already started at this offset: start +1, end +2.
                                # Otherwise the start is unchanged and the end is +1
                                # (because one marker has been inserted at the start).
                                if (map_start, se['sent_id']) in start_end_dict:
                                    map_start_fin = map_start + 1
                                    map_end_fin = map_end + 2
                                else:
                                    map_start_fin = map_start
                                    map_end_fin = map_end + 1
                                    start_end_dict.add((map_start, se['sent_id']))

                                if se['first']:  # the mixed sort interleaves both entities; tell them apart
                                    new_sents[se['sent_id']].insert(map_start_fin, '[unused0]')
                                    new_sents[se['sent_id']].insert(map_end_fin, '[unused1]')
                                else:
                                    new_sents[se['sent_id']].insert(map_start_fin, '[unused2]')
                                    new_sents[se['sent_id']].insert(map_end_fin, '[unused3]')

                            doc_tokens = []
                            for sent in new_sents:
                                doc_tokens.extend(sent)

                            # Over-long documents are dropped rather than truncated.
                            if len(doc_tokens) > max_len_for_doc:
                                continue
                                # TODO doc_tokens = doc_tokens[:max_len_for_doc]

                            tokens = ['[CLS]'] + doc_tokens + ['[SEP]']
                            segment_ids = [0] * (len(doc_tokens) + 2)
                            input_ids = tokenizer.convert_tokens_to_ids(tokens)
                            input_mask = [1] * len(input_ids)

                            intrain = None
                            relation_label = None
                            evidence = []
                            if not is_test:
                                if (e1, e2) not in train_triple:
                                    # No annotated relation: one-hot on index 0.
                                    relation_label = [0] * len(rel2id)
                                    relation_label[0] = 1
                                    evidence = []
                                    intrain = False
                                    neg_samples += 1
                                else:
                                    relation_label = [0] * len(rel2id)
                                    # A pair may hold several relations: multi-hot label.
                                    for train_tmp in train_triple[(e1, e2)]:
                                        relation_label[train_tmp['relation']] = 1
                                        evidence.append(train_tmp['evidence'])
                                    # Only the first stored relation's flag is reported.
                                    intrain = train_triple[(e1, e2)][0]['intrain']
                                    pos_samples += 1

                            # Zero-pad up to the sequence length.
                            while len(input_ids) < max_seq_length:
                                input_ids.append(0)
                                input_mask.append(0)
                                segment_ids.append(0)

                            assert len(input_ids) == max_seq_length
                            assert len(input_mask) == max_seq_length
                            assert len(segment_ids) == max_seq_length

                            # Echo the first six features for a visual sanity check.
                            if i_line <= 5:
                                print('#' * 100)
                                print('E1:', [e['name'] for e in entity1])
                                print('E2:', [e['name'] for e in entity2])
                                print('intrain:', intrain)
                                print('Evidence:', evidence)
                                print('tokens:', tokens)
                                print('segment ids:', segment_ids)
                                print('input ids:', input_ids)
                                print('input mask', input_mask)
                                print('relation_label:', relation_label)

                            i_line += 1

                            feature = {'input_ids': input_ids,
                                       'input_mask': input_mask,
                                       'segment_ids': segment_ids,
                                       'labels': relation_label,
                                       'evidences': evidence,
                                       'intrain': intrain}

                            w.write(json.dumps(feature, ensure_ascii=False) + '\n')

    print(output_file, 'final samples', i_line)
    print('pos samples:', pos_samples)
    print('neg samples:', neg_samples)

In [4]:
# Convert the splits in this order: the training split fills fact_in_train,
# which the dev split conversion then consults for its `intrain` flag.
file_list = ['data/train_annotated.json', 'data/dev.json', 'data/test.json']
for file_name in file_list:
    path_parts = file_name.split('/')
    base_name = path_parts[-1].split('.json')[0]
    # e.g. data/dev.json -> data/dev_cls_data.txt
    output_file = '{}/{}_cls_data.txt'.format(path_parts[0], base_name)
    convert_feature(
        file_name,
        output_file,
        is_training='train' in file_name,
        is_test='test' in file_name,
    )

convert features...


  0%|          | 1/3053 [00:00<10:23,  4.90it/s]

####################################################################################################
E1: ['Zest Airways, Inc.', 'Asian Spirit and Zest Air', 'AirAsia Zest', 'AirAsia Zest']
E2: ['Ninoy Aquino International Airport', 'Ninoy Aquino International Airport']
intrain: False
Evidence: []
tokens: ['[CLS]', '[unused0]', 'ze', '##st', 'airways', ',', 'inc', '.', '[unused1]', 'operated', 'as', '[unused0]', 'air', '##asia', 'ze', '##st', '[unused1]', '(', 'formerly', '[unused0]', 'asian', 'spirit', 'and', 'ze', '##st', 'air', '[unused1]', ')', ',', 'was', 'a', 'low', '-', 'cost', 'airline', 'based', 'at', 'the', '[unused2]', 'nino', '##y', 'aquino', 'international', 'airport', '[unused3]', 'in', 'pas', '##ay', 'city', ',', 'metro', 'manila', 'in', 'the', 'philippines', '.', 'it', 'operated', 'scheduled', 'domestic', 'and', 'international', 'tourist', 'services', ',', 'mainly', 'feeder', 'services', 'linking', 'manila', 'and', 'cebu', 'with', '24', 'domestic', 'destinations', 'in', 

100%|██████████| 3053/3053 [12:33<00:00,  4.05it/s]


data/train_annotated_cls_data.txt final samples 1189412
pos samples: 35416
neg samples: 1153996
convert features...


  0%|          | 0/1000 [00:00<?, ?it/s]

####################################################################################################
E1: ['Lark Force', 'Lark Force', 'Lark Force']
E2: ['Australian Army']
intrain: True
Evidence: [[0, 1]]
tokens: ['[CLS]', '[unused0]', 'lark', 'force', '[unused1]', 'was', 'an', '[unused2]', 'australian', 'army', '[unused3]', 'formation', 'established', 'in', 'march', '1941', 'during', 'world', 'war', 'ii', 'for', 'service', 'in', 'new', 'britain', 'and', 'new', 'ireland', '.', 'under', 'the', 'command', 'of', 'lieutenant', 'colonel', 'john', 'scan', '##lan', ',', 'it', 'was', 'raised', 'in', 'australia', 'and', 'deployed', 'to', 'ra', '##bau', '##l', 'and', 'ka', '##vie', '##ng', ',', 'aboard', 'ss', 'kato', '##omba', ',', 'mv', 'ne', '##pt', '##una', 'and', 'hm', '##at', 'zealand', '##ia', ',', 'to', 'defend', 'their', 'strategically', 'important', 'harbour', '##s', 'and', 'airfields', '.', 'the', 'objective', 'of', 'the', 'force', ',', 'was', 'to', 'maintain', 'a', 'forward', 'air', 

100%|██████████| 1000/1000 [04:08<00:00,  4.03it/s]


data/dev_cls_data.txt final samples 392062
pos samples: 11437
neg samples: 380625
convert features...


  0%|          | 1/1000 [00:00<02:22,  7.02it/s]

####################################################################################################
E1: ['Miguel Riofrio Sánchez', 'Riofrio', 'Miguel Riofrío']
E2: ['September 7 , 1822']
intrain: None
Evidence: []
tokens: ['[CLS]', '[unused0]', 'miguel', 'rio', '##fr', '##io', 'sanchez', '[unused1]', '(', '[unused2]', 'september', '7', ',', '1822', '[unused3]', '–', 'october', '11', ',', '1879', ')', 'was', 'an', 'ecuador', '##an', 'poet', ',', 'novelist', ',', 'journalist', ',', 'or', '##ator', ',', 'and', 'educator', '.', 'he', 'was', 'born', 'in', 'the', 'city', 'of', 'lo', '##ja', '.', 'he', 'is', 'best', 'known', 'today', 'as', 'the', 'author', 'of', 'ecuador', "'", 's', 'first', 'novel', 'la', 'em', '##an', '##ci', '##pad', '##a', '(', '1863', ')', '.', 'owing', 'to', 'the', 'book', "'", 's', 'length', ',', 'usually', 'less', 'than', '100', 'pages', 'long', ',', 'many', 'experts', 'have', 'argued', 'that', 'it', 'is', 'really', 'a', 'novella', 'rather', 'than', 'a', 'full', 'nov

100%|██████████| 1000/1000 [04:00<00:00,  4.17it/s]

data/test_cls_data.txt final samples 388216
pos samples: 0
neg samples: 0



