In [None]:
# Python 3.8.10
# pytorch: 1.10.1+cpu

In [1]:
import re
import json
import itertools
import ner_A
import ner_filter
import rel_ext_A

Build char-word based NER Task...
build gaz embedding...
Build the Gaz bilstm...
build batched crf...


Some weights of the model checkpoint at ./relation-extraction/pretrained_models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Read data

In [2]:
with open('./annual_report_23.json') as r:
    corpus = json.loads(r.read())

### Cleaning

In [3]:
# Keep only Chinese-language documents.
corpus = [stock for stock in corpus if stock['language'] == 'zh']

# Normalise document text towards what the NER model saw in training.
# Order matters:
#   * the age pattern needs its fullwidth brackets, so it runs before the
#     bracket-stripping step;
#   * spaces are removed before the Latin/CJK boundary comma is inserted, so
#     "ABC 公司" becomes "ABC，公司" rather than being left untouched.
r1 = r'(?<=[a-zA-Z.])(?=[\u4e00-\u9fff])'  # Latin letter/dot immediately followed by a CJK ideograph
r3 = r'（\d+歲）'                           # age written in fullwidth brackets, e.g. （45歲）

TEXT_CLEANING_STEPS = [
    (re.compile(r3), '，'),                # age in this format is unfamiliar to the model
    (re.compile('[－、•╱]'), '，'),        # punctuation uncommon in the training data
    (re.compile('[（） ]'), ''),           # fullwidth brackets and ASCII spaces, uncommon in training data
    (re.compile(r1), '，'),                # comma between English text and following Chinese text
    (re.compile('[\uf098\uf099]'), '，'),  # unwanted private-use-area characters
]


def clean_text(text: str) -> str:
    """Apply every step of TEXT_CLEANING_STEPS, in order, to one text string."""
    for pattern, replacement in TEXT_CLEANING_STEPS:
        text = pattern.sub(replacement, text)
    return text


# Clean every document of every stock in place (later cells read `corpus`).
for stock in corpus:
    for doc in stock['text']:
        doc['text'] = clean_text(doc['text'])

### NER labelling

In [4]:
# Run the NER model over the cleaned corpus:
#   1. convert the documents into the model's input format,
#   2. predict entity spans,
#   3. merge the predictions back onto the documents.
texts, instances = ner_A.get_compatible_input(corpus)
ner_predictions = ner_A.predict(instances)
labelled_corpus = ner_A.get_ner_labelled_corpus(
    corpus,
    texts,
    ner_predictions,
)

  cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)


NER counts {'NAME': 1063, 'ORG': 2842}


### Filter unwanted NER

In [5]:
# Remove unwanted NER spans. Besides the filtered corpus, ner_filter returns
# lists of organisation and person names it accepted/rejected; these are
# inspected in the cells at the end of the notebook.
(
    labelled_corpus,
    valid_org,
    valid_person,
    invalid_org,
    invalid_person,
) = ner_filter.filter_ner(labelled_corpus)

NER counts {'NAME': 549, 'ORG': 1870}
# Unique Valid person/org: 397/1389
# Unique Invalid person/org: 219/389


In [6]:
# # Inspecting (in)valid person/org lists.
# import numpy as np
# def get_count(ls):
#     ls = np.array(ls)
#     return sorted([((ls==x).sum(), x) for x in np.unique(ls)], reverse=True)

# get_count(valid_person)  # or valid_org / invalid_person / invalid_org

### Relation extraction (Person to Person or Person to Org)

In [7]:
# Score every person-person and person-organization entity pair in each
# document with the relation-extraction model and keep only meaningful
# relations. Pairs the model cannot score are recorded in `failed_pairs`
# (instead of only printing a context-free error message).

# Length budget passed to rel_ext_A.gen_item. The errors in this cell's
# output ("size of tensor a (540) must match ... tensor b (512)") show that
# some items still exceed the encoder's 512 positions; those pairs fail.
REL_MAX_LENGTH = 500

valid_relation = []
invalid_relation = []
failed_pairs = []  # (subject text, object text, error message)

for stock in labelled_corpus:
    for doc in stock['text']:
        doc['relation_list'] = []
        for h, t in itertools.combinations(doc['ner'], 2):
            pair_labels = (h['label_'], t['label_'])
            is_per_per = pair_labels == ('NAME', 'NAME')
            is_per_org = pair_labels in (('NAME', 'ORG'), ('ORG', 'NAME'))

            # Only person-to-person and person-to-organization pairs are scored.
            # NOTE(review): two mentions with identical text (e.g. the same
            # person named twice) are still paired and scored — confirm intended.
            if not (is_per_per or is_per_org):
                continue

            try:
                item = rel_ext_A.gen_item(doc, h, t, max_length=REL_MAX_LENGTH)
                model_input = rel_ext_A.tokenize_item(item)
                logits = rel_ext_A.predict(model_input)
                probas = rel_ext_A.get_proba(logits, mode='GROUPED_SUM')
                # Presumably probas is sorted best-first with (proba, label)
                # entries, making probas[0][1] the top label — confirm in rel_ext_A.
                predicate = probas[0][1]
            except Exception as e:
                # Best effort: keep going, but remember which pair failed and why.
                # Failures typically come from inputs longer than the encoder accepts.
                failed_pairs.append((h['text'], t['text'], str(e)))
                print(f"Skipped pair {h['text']!r} / {t['text']!r}: {e}")
                continue

            relation = dict(
                predicate=predicate,
                subject_type=h['label_'],
                object_type=t['label_'],
                subject=h['text'],
                object=t['text'],
            )

            # Only keep meaningful relations: any known person-person relation,
            # and only 'Work' for person-organization pairs.
            keep = (
                (is_per_per and predicate != 'Unknown')
                or (is_per_org and predicate == 'Work')
            )
            if keep:
                doc['relation_list'].append(relation)
                valid_relation.append(relation)
            else:
                invalid_relation.append(relation)

print(f'Relations kept: {len(valid_relation)}, dropped: {len(invalid_relation)}, '
      f'failed pairs: {len(failed_pairs)}')


Give up stripping
The size of tensor a (540) must match the size of tensor b (512) at non-singleton dimension 1
Give up stripping
The size of tensor a (540) must match the size of tensor b (512) at non-singleton dimension 1
Give up stripping
The size of tensor a (540) must match the size of tensor b (512) at non-singleton dimension 1
Give up stripping
The size of tensor a (526) must match the size of tensor b (512) at non-singleton dimension 1
Give up stripping
Give up stripping
The size of tensor a (556) must match the size of tensor b (512) at non-singleton dimension 1
Give up stripping
The size of tensor a (556) must match the size of tensor b (512) at non-singleton dimension 1


In [9]:
# Persist the labelled corpus (documents with their `ner` spans and the
# `relation_list` built in the relation-extraction cell).
with open('final_corpus.json', 'w') as out_file:
    json.dump(labelled_corpus, out_file)

In [None]:
# print('Identified', len(valid_relation), 'relations')
# print('Dropped', len(invalid_relation), 'relations')
# for relation in sorted(valid_relation, key=lambda x: x['predicate'])[:10]:
#     print(relation['subject'], '>', relation['predicate'], '>', relation['object'])

In [2]:
# Listing valid Chinese organization names: the first 10 unique names, in
# sorted order, that contain at least one CJK ideograph (U+4E00-U+9FFF, the
# same range the cleaning step uses). Names without any such character
# (purely Latin, digits, punctuation, empty) are skipped.
display([name for name in sorted(set(valid_org)) if re.search(r'[\u4e00-\u9fff]', name)][:10])

['三六零安全科技股份有限公司上交所',
 '三號幹線郊野公園段有限公司',
 '上海吉祥航空股份有限公司上海證券交易所',
 '上海同濟科技實業股份有限公司',
 '上海國際港務集團股份有限公司',
 '上海大眾公用事業集團股份有限公司',
 '上海復旦張江生物醫藥股份有限公司',
 '上海振華重工集團股份有限公司',
 '上海時代航運有限公司',
 '上海東方明珠新媒體股份有限公司上海證']

In [3]:
# Listing valid Chinese person names: the first 10 unique names, in sorted
# order, that contain at least one CJK ideograph (U+4E00-U+9FFF, the same
# range the cleaning step uses). Names without any such character
# (purely Latin, digits, punctuation, empty) are skipped.
display([name for name in sorted(set(valid_person)) if re.search(r'[\u4e00-\u9fff]', name)][:10])

['丁良輝', '中國領', '于正人', '井賢棟', '付丹偉', '代者', '伍成業', '何平何平', '何成效', '何漢明']