In [8]:
import os
import json
import unidecode
import numpy as np
import random
from tqdm import tqdm
import spacy
from spacy import displacy
from spacy.tokenizer import Tokenizer
from spacy.util import compile_infix_regex
from spacy.lang.char_classes import ALPHA, ALPHA_LOWER, ALPHA_UPPER, HYPHENS
from spacy.lang.char_classes import CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS
from spacy.util import compile_infix_regex
from spacy.matcher import Matcher
from datasets import load_dataset, concatenate_datasets
from collections import defaultdict, Counter
import matplotlib
import matplotlib.pyplot as plt
import transformers

In [9]:
# Customize the spaCy tokenizer: load the large English pipeline and install a
# reduced infix rule set. The hyphen, arithmetic-operator and '/' infix patterns
# (left commented out in earlier revisions) are deliberately omitted, so those
# characters are not used as split points inside a token.
nlp = spacy.load("en_core_web_lg")

infixes = [
    *LIST_ELLIPSES,
    *LIST_ICONS,
    # period between a lowercase letter/quote and an uppercase letter/quote
    r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES),
    # comma squeezed between two letters
    r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
    # ':', '<', '>' or '=' after a letter/digit and before a letter
    r"(?<=[{a}0-9])[:<>=](?=[{a}])".format(a=ALPHA),
]

infix_re = compile_infix_regex(infixes)
nlp.tokenizer.infix_finditer = infix_re.finditer

In [10]:
# Raw splits (one JSON object per line) are read from DIR.
# DIR = "BioInfer"
DIR, train_json, test_json = 'KBP', "train.json", "test.json"

def load_data(src_dir=".", file_name="train.json"):
    """Read a JSON-Lines file into a list of parsed objects.

    Args:
        src_dir: directory containing the file.
        file_name: file name inside ``src_dir``.

    Returns:
        One parsed object per non-blank line, in file order.
    """
    raw_json = os.path.join(src_dir, file_name)
    # `with` guarantees the handle is closed; previously it stayed open for the
    # lifetime of the kernel.
    with open(raw_json, 'r', encoding='utf-8') as file:
        # Blank lines (e.g. a trailing empty line) are skipped instead of making
        # json.loads raise.
        return [json.loads(line) for line in file if line.strip()]

# Load both raw splits and report how many sentences each contains.
train_data, test_data = (load_data(DIR, name) for name in (train_json, test_json))
print(len(train_data), len(test_data))

23784 289


In [11]:
def process(exp, res, sentid, sent2query, only_ent, ner_dict, re_dict, re2ner):
    """Build query-centric tagging instances for one raw sentence.

    For every de-duplicated entity mention, one instance is produced. In that
    instance the mention is tagged ``B-/I-<entity label>``, and each relation
    target of the mention (relations whose ``em1Text`` equals the mention text)
    is tagged ``B-/I-<relation label>``. Mentions and targets are located by
    matching their first token only.

    Args:
        exp: raw example with 'sentText', 'entityMentions' and 'relationMentions'.
        res: output list; this sentence's instances are appended to it in place.
        sentid: sentence index stored as 'sentID' on each instance.
        sent2query: sentid -> token indices already used as query starts. It is
            updated in place so that a repeated mention maps to its next
            occurrence.
        only_ent: if True, the whole sentence is dropped when any relation
            target was tagged.
        ner_dict: entity label -> count, updated in place (label 'None' not counted).
        re_dict: relation label -> count, updated in place (label 'None' skipped).
        re2ner: relation label -> list of query entity labels, updated in place.

    Mentions or targets whose text tokenizes to nothing are skipped; previously
    this raised IndexError.
    """
    text_doc = nlp(exp['sentText'])
    exp_tokens = [unidecode.unidecode(tok.text) for tok in text_doc if tok.text != '\r\n']
    re_count = 0
    sent_res = []
    # Order-preserving de-duplication (mentions are dicts, hence unhashable):
    # the last occurrence of each duplicate is kept.
    entMentions = [item for idx, item in enumerate(exp['entityMentions']) if item not in exp['entityMentions'][idx + 1:]]
    reMentions = [item for idx, item in enumerate(exp['relationMentions']) if item not in exp['relationMentions'][idx + 1:]]

    for NER in entMentions:
        query, ner_tag = NER['text'], NER['label']
        query_tokens = [tok.text for tok in nlp(query)]
        if not query_tokens:
            continue
        query_first = unidecode.unidecode(query_tokens[0])
        # First occurrence of the mention's first token not already claimed by an
        # earlier query of this sentence.
        query_id = next(
            (idx for idx, tok in enumerate(exp_tokens)
             if tok == query_first and idx not in sent2query[sentid]),
            None,
        )
        if query_id is None:
            continue

        if ner_tag != 'None':
            ner_dict[ner_tag] = ner_dict.get(ner_tag, 0) + 1
        tokens = np.array(["O"] * len(exp_tokens), dtype='object')
        tokens[query_id] = "B-" + ner_tag
        tokens[query_id + 1:query_id + len(query_tokens)] = "I-" + ner_tag
        sent2query[sentid].append(query_id)

        for RE in reMentions:
            target, re_tag = RE['em2Text'], RE['label']
            if RE['em1Text'] != query or re_tag == 'None':
                continue
            re_dict[re_tag] = re_dict.get(re_tag, 0) + 1
            re2ner[re_tag].append(ner_tag)
            target_tokens = [tok.text for tok in nlp(target)]
            if not target_tokens:
                continue
            target_first = unidecode.unidecode(target_tokens[0])
            # First occurrence of the target's first token that is not the query start.
            target_id = next(
                (idx for idx, tok in enumerate(exp_tokens)
                 if tok == target_first and idx != query_id),
                None,
            )
            if target_id is not None:
                re_count += 1
                tokens[target_id] = "B-" + re_tag
                tokens[target_id + 1:target_id + len(target_tokens)] = "I-" + re_tag

        sent_res.append({
            "tokens": exp_tokens, "ner_tags": tokens.tolist(),
            "query_ids": query_id, "sentID": sentid, "instanceID": len(res) + len(sent_res),
        })

    # Keep the sentence unless only entity-only sentences were requested and a
    # relation target was tagged.
    if not only_ent or re_count == 0:
        res += sent_res
        

def process_data(dataset, only_ent=False):
    """Run `process` over every raw example of a split.

    Args:
        dataset: iterable of raw examples (see `process`).
        only_ent: forwarded to `process`; when True, only sentences without a
            tagged relation target are kept.

    Returns:
        (instances, ner_dict, re_dict, re2ner): the generated instances, the
        entity-label counts, the relation-label counts, and relation label ->
        list of query entity labels.
    """
    instances = []
    sent2query = defaultdict(list)
    ner_counts, re_counts = defaultdict(int), defaultdict(int)
    relation_to_ner = defaultdict(list)
    for sent_idx, example in tqdm(enumerate(dataset)):
        process(example, instances, sent_idx, sent2query, only_ent, ner_counts, re_counts, relation_to_ner)
    return instances, ner_counts, re_counts, relation_to_ner


def saving(f_path, res):
    """Write `res` to `f_path` as JSON Lines: one `json.dumps` record per line."""
    with open(f_path, 'w') as out:
        out.writelines(json.dumps(record) + '\n' for record in res)

In [12]:
# Build (or reload) the query-centric KBP splits. Processing the train split took
# over 10 minutes, so the results are cached as JSON Lines under DIR.
# DIR = "BioInfer"
DIR = "KBP"
train_json = 'train_KBP.json'
train_file = os.path.join(DIR, train_json)
test_json = 'test_KBP.json'
test_file = os.path.join(DIR, test_json)

# Label statistics are only produced when a split is processed from scratch.
# Resetting them to None means later cells see None rather than stale values
# from an earlier kernel session (or a NameError on a fresh kernel) when the
# cached files were used.
ner_dict_test = re_dict_test = re2ner_test = None
ner_dict_train = re_dict_train = re2ner_train = None

if os.path.exists(test_file):
    test_processed_data = load_dataset('json', data_files=test_file)['train']
else:
    test_processed_data, ner_dict_test, re_dict_test, re2ner_test = process_data(test_data)
    saving(test_file, test_processed_data)

if os.path.exists(train_file):
    train_processed_data = load_dataset('json', data_files=train_file)['train']
else:
    train_processed_data, ner_dict_train, re_dict_train, re2ner_train = process_data(train_data)
    saving(train_file, train_processed_data)

289it [00:04, 58.51it/s]
23784it [10:27, 37.92it/s]


## Test pattern extraction

In [7]:
# Load the processed BioInfer splits as one DatasetDict, then unpack both splits.
train_file, test_file = 'BioInfer/train_BIF.json', 'BioInfer/test_BIF.json'
datasets = load_dataset('json', data_files=dict(train=train_file, test=test_file))
train_BIF, test_BIF = datasets['train'], datasets['test']

Using custom data configuration default-60954ea1da94cb08


Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/yufli/.cache/huggingface/datasets/json/default-60954ea1da94cb08/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9...


0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

Dataset json downloaded and prepared to /Users/yufli/.cache/huggingface/datasets/json/default-60954ea1da94cb08/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9. Subsequent calls will reuse this data.


In [11]:
# # stats instances that belong to the sentences that contain relations
# re_sent = set([sentID for sentID, tags in zip(train['sentID'], train['ner_tags']) if any('/' in tag for tag in tags)])
# re_instances = set([insID for sentID, insID in zip(train['sentID'], train['instanceID']) if sentID in re_sent])
# ent_sent = set(train['sentID']) - re_sent
# ent_instances = set(train['instanceID']) - re_instances
# # sum([int(any('/' in tag for tag in tags)) for tags in train['ner_tags']])
# print("# sents {}, # re sents {}, # re ins {}, # ent sents {}, # ent ins {}".format(len(set(train['sentID'])), len(re_sent), len(re_instances), len(ent_sent), len(ent_instances)))

### Baseline Attention

In [None]:
# Attention analysis on the entity part (baseline, quarter split): for each test
# instance, write the query token's attention row as colored HTML.
import torch  # the original relied on a later cell having imported torch

DIR = 'attention'
eval_data = 'quarter_base_ent'
batch_size = 20
test_data = load_dataset('json', data_files=os.path.join(DIR, 'quarter_ent_data.json'))['train']
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
tokenizer = torch.load('tokenizer.pt')
# Per-batch dumps. The i-th sorted file is assumed to hold batch i
# (TODO confirm the file names sort in batch order, e.g. zero-padded indices).
f_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'attention' in x)
l_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'label' in x)

# Each batch file is loaded once instead of once per instance.
loaded_f_index, batch_attentions = None, None
with open('attention_base_ent.html', 'w', encoding='utf-8') as f:
    for i, instance in enumerate(test_data):
        f_index, ins_index = divmod(i, batch_size)
        if f_index != loaded_f_index:
            batch = torch.load(os.path.join(DIR, eval_data, f_list[f_index]), map_location=torch.device('cpu'))  # B X T X T
            batch_attentions = np.round(batch.numpy(), decimals=3)
            loaded_f_index = f_index
        queryID = instance['query_ids'][0]
        # Row `queryID` of this instance's T x T matrix
        # (presumably the query token's attention over all positions).
        attentions = batch_attentions[ins_index][queryID]
        input_ids = instance['input_ids']
        words = [tokenizer.convert_ids_to_tokens([int(input_id)])[0].lstrip('Ġ') for input_id in input_ids]  # remove "Ġ"
        tags = instance['ner_tags']
        tokens = instance['tokens']
        # NOTE(review): tags containing '/' are treated as target (relation) tags and
        # the other non-"O" tags as query tags. The KBP labels printed later use '/'
        # for entity types and ':' for relations, so confirm which dataset produced
        # these files.
        query_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" not in tag)
        target_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" in tag)

        if target_entities:
            targetID = min(idx for idx, tag in enumerate(tags) if tag != "O" and "/" in tag)
            info = "query entity (%d): {%s}, target entity (%d): {%s}" % (queryID, query_entities, targetID, target_entities)
        else:
            info = "query entity (%d): {%s}" % (queryID, query_entities)

        # NOTE(review): `colorize` is not defined in the visible notebook; TODO confirm
        # it is defined before this cell runs.
        s = colorize(words, attentions)
        f.write(f'{info}<br />')
        f.write(f'{s}<br />')


In [None]:
# Attention analysis on the relation part (baseline, quarter split): for each test
# instance, write the query token's attention row as colored HTML.
import torch  # the original relied on a later cell having imported torch

DIR = 'attention'
eval_data = 'quarter_base_re'
batch_size = 20
test_data = load_dataset('json', data_files=os.path.join(DIR, 'quarter_re_data.json'))['train']
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
tokenizer = torch.load('tokenizer.pt')
# Per-batch dumps. The i-th sorted file is assumed to hold batch i
# (TODO confirm the file names sort in batch order, e.g. zero-padded indices).
f_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'attention' in x)
l_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'label' in x)

# Each batch file is loaded once instead of once per instance.
loaded_f_index, batch_attentions = None, None
with open('attention_base_re.html', 'w', encoding='utf-8') as f:
    for i, instance in enumerate(test_data):
        f_index, ins_index = divmod(i, batch_size)
        if f_index != loaded_f_index:
            batch = torch.load(os.path.join(DIR, eval_data, f_list[f_index]), map_location=torch.device('cpu'))  # B X T X T
            batch_attentions = np.round(batch.numpy(), decimals=3)
            loaded_f_index = f_index
        queryID = instance['query_ids'][0]
        # Row `queryID` of this instance's T x T matrix
        # (presumably the query token's attention over all positions).
        attentions = batch_attentions[ins_index][queryID]
        input_ids = instance['input_ids']
        words = [tokenizer.convert_ids_to_tokens([int(input_id)])[0].lstrip('Ġ') for input_id in input_ids]  # remove "Ġ"
        tags = instance['ner_tags']
        tokens = instance['tokens']
        # NOTE(review): tags containing '/' are treated as target (relation) tags and
        # the other non-"O" tags as query tags. The KBP labels printed later use '/'
        # for entity types and ':' for relations, so confirm which dataset produced
        # these files.
        query_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" not in tag)
        target_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" in tag)

        if target_entities:
            targetID = min(idx for idx, tag in enumerate(tags) if tag != "O" and "/" in tag)
            info = "query entity (%d): {%s}, target entity (%d): {%s}" % (queryID, query_entities, targetID, target_entities)
        else:
            info = "query entity (%d): {%s}" % (queryID, query_entities)

        # NOTE(review): `colorize` is not defined in the visible notebook; TODO confirm
        # it is defined before this cell runs.
        s = colorize(words, attentions)
        f.write(f'{info}<br />')
        f.write(f'{s}<br />')


In [None]:
# Attention analysis on the entity part (relation-trained baseline dumps): for each
# test instance, write the query token's attention row as colored HTML.
import torch  # the original relied on a later cell having imported torch

DIR = 'attention'
eval_data = 'relation_base_ent'
batch_size = 20
test_data = load_dataset('json', data_files=os.path.join(DIR, 'quarter_ent_data.json'))['train']
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
tokenizer = torch.load('tokenizer.pt')
# Per-batch dumps. The i-th sorted file is assumed to hold batch i
# (TODO confirm the file names sort in batch order, e.g. zero-padded indices).
f_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'attention' in x)
l_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'label' in x)

# Each batch file is loaded once instead of once per instance.
loaded_f_index, batch_attentions = None, None
with open('attention_relation_base_ent.html', 'w', encoding='utf-8') as f:
    for i, instance in enumerate(test_data):
        f_index, ins_index = divmod(i, batch_size)
        if f_index != loaded_f_index:
            batch = torch.load(os.path.join(DIR, eval_data, f_list[f_index]), map_location=torch.device('cpu'))  # B X T X T
            batch_attentions = np.round(batch.numpy(), decimals=3)
            loaded_f_index = f_index
        queryID = instance['query_ids'][0]
        # Row `queryID` of this instance's T x T matrix
        # (presumably the query token's attention over all positions).
        attentions = batch_attentions[ins_index][queryID]
        input_ids = instance['input_ids']
        words = [tokenizer.convert_ids_to_tokens([int(input_id)])[0].lstrip('Ġ') for input_id in input_ids]  # remove "Ġ"
        tags = instance['ner_tags']
        tokens = instance['tokens']
        # NOTE(review): tags containing '/' are treated as target (relation) tags and
        # the other non-"O" tags as query tags. The KBP labels printed later use '/'
        # for entity types and ':' for relations, so confirm which dataset produced
        # these files.
        query_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" not in tag)
        target_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" in tag)

        if target_entities:
            targetID = min(idx for idx, tag in enumerate(tags) if tag != "O" and "/" in tag)
            info = "query entity (%d): {%s}, target entity (%d): {%s}" % (queryID, query_entities, targetID, target_entities)
        else:
            info = "query entity (%d): {%s}" % (queryID, query_entities)

        # NOTE(review): `colorize` is not defined in the visible notebook; TODO confirm
        # it is defined before this cell runs.
        s = colorize(words, attentions)
        f.write(f'{info}<br />')
        f.write(f'{s}<br />')


In [None]:
# Attention analysis on the relation part (relation-trained baseline dumps): for each
# test instance, write the query token's attention row as colored HTML.
import torch  # the original relied on a later cell having imported torch

DIR = 'attention'
eval_data = 'relation_base_re'
batch_size = 20
test_data = load_dataset('json', data_files=os.path.join(DIR, 'quarter_re_data.json'))['train']
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
tokenizer = torch.load('tokenizer.pt')
# Per-batch dumps. The i-th sorted file is assumed to hold batch i
# (TODO confirm the file names sort in batch order, e.g. zero-padded indices).
f_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'attention' in x)
l_list = sorted(x for x in os.listdir(os.path.join(DIR, eval_data)) if 'label' in x)

# Each batch file is loaded once instead of once per instance.
loaded_f_index, batch_attentions = None, None
with open('attention_relation_base_re.html', 'w', encoding='utf-8') as f:
    for i, instance in enumerate(test_data):
        f_index, ins_index = divmod(i, batch_size)
        if f_index != loaded_f_index:
            batch = torch.load(os.path.join(DIR, eval_data, f_list[f_index]), map_location=torch.device('cpu'))  # B X T X T
            batch_attentions = np.round(batch.numpy(), decimals=3)
            loaded_f_index = f_index
        queryID = instance['query_ids'][0]
        # Row `queryID` of this instance's T x T matrix
        # (presumably the query token's attention over all positions).
        attentions = batch_attentions[ins_index][queryID]
        input_ids = instance['input_ids']
        words = [tokenizer.convert_ids_to_tokens([int(input_id)])[0].lstrip('Ġ') for input_id in input_ids]  # remove "Ġ"
        tags = instance['ner_tags']
        tokens = instance['tokens']
        # NOTE(review): tags containing '/' are treated as target (relation) tags and
        # the other non-"O" tags as query tags. The KBP labels printed later use '/'
        # for entity types and ':' for relations, so confirm which dataset produced
        # these files.
        query_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" not in tag)
        target_entities = ' '.join(tok for tok, tag in zip(tokens, tags) if tag != "O" and "/" in tag)

        if target_entities:
            targetID = min(idx for idx, tag in enumerate(tags) if tag != "O" and "/" in tag)
            info = "query entity (%d): {%s}, target entity (%d): {%s}" % (queryID, query_entities, targetID, target_entities)
        else:
            info = "query entity (%d): {%s}" % (queryID, query_entities)

        # NOTE(review): `colorize` is not defined in the visible notebook; TODO confirm
        # it is defined before this cell runs.
        s = colorize(words, attentions)
        f.write(f'{info}<br />')
        f.write(f'{s}<br />')


In [None]:
# test torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import re
import spacy
from spacy.tokenizer import Tokenizer
from scipy.stats import entropy
from sklearn.metrics import roc_curve, auc, average_precision_score

In [None]:
# Small English pipeline whose tokenizer only splits on whitespace: every run of
# non-space characters becomes one token. NOTE: this rebinds the global `nlp`
# that `process` uses.
nlp = spacy.load('en_core_web_sm')
nlp.tokenizer = Tokenizer(
    nlp.vocab,
    token_match=re.compile(r'\S+').match,
)

In [None]:
# Toy batch of token ids with shape (1, 9). `Variable` has been a deprecated
# no-op wrapper since PyTorch 0.4, so the LongTensor is built directly.
inputs = [1,78,32,5,24,7,22,16,72]
input_tensor = torch.tensor(inputs, dtype=torch.long).unsqueeze(0)

In [None]:
# Display the toy (1, 9) token-id tensor.
input_tensor

In [None]:
def common_get_entropy(pos : torch.Tensor):
    """Shannon entropy (natural log) of softmax(pos) along the last dimension.

    Args:
        pos: logits, e.g. of shape (N, k).

    Returns:
        np.ndarray of entropies with shape ``pos.shape[:-1]``.
    """
    # Detach and move to CPU first: scipy converts its input with np.asarray,
    # which fails for tensors that require grad or live on a GPU.
    pred_prob = F.softmax(pos, dim=-1).detach().cpu().numpy()  # (N, k)
    etp = entropy(pred_prob, axis=-1)  # np.ndarray
    return etp

In [None]:
# Entropy of softmax over one random 4-logit row (the value changes between runs; no seed is set).
common_get_entropy(torch.rand(1,4))[0]

In [None]:
# Tokenize a pre-split sentence. With the whitespace-only tokenizer configured
# earlier, "person's" should remain a single token.
doc = nlp(' '.join(["I", "am", "a", "good", "person's", "person"]))

In [None]:
# Inspect the tokens of `doc`.
list(doc)

In [None]:
# Per-row KL(target_att || pa) between two random attention distributions.
# Rows are normalized inline: the original called `normalize`, which is only
# defined in a later cell, so this cell failed on a fresh kernel.
torch.manual_seed(0)  # reproducible random inputs
B = 16
T = 25
kl_loss = nn.KLDivLoss(reduction='none')
pa = torch.rand(B, T).sigmoid()
target_att = torch.rand(B, T).sigmoid()
# Sigmoid outputs are strictly positive, so every row sum is > 0 (no 0/0).
pa_dist = pa / pa.sum(dim=1, keepdim=True)
target_dist = target_att / target_att.sum(dim=1, keepdim=True)
# KLDivLoss expects log-probabilities as input and probabilities as target.
loss = kl_loss(pa_dist.log(), target_dist).sum(dim=1)
loss

In [None]:
def normalize(pa):
    """Scale each row of a 2-D tensor so that it sums to 1.

    Rows whose sum is 0 would divide 0/0; the resulting NaNs are replaced by 0.
    """
    row_totals = pa.sum(dim=1, keepdim=True)
    scaled = pa / row_totals
    return scaled.masked_fill(torch.isnan(scaled), 0)

In [None]:
# Sanity check: every normalized row should sum to 1.
normalize(target_att).sum(dim=1)

In [None]:
# Empty dict used to check what `.keys()` of an empty mapping looks like.
a = dict()

In [None]:
# Keys of `a` as a list (iterating a dict yields its keys).
list(a)

In [None]:
import torch
import torch.nn.functional as F
from scipy.stats import entropy
# Random (1, 10) tensor of uniform [0, 1) values (no seed, so it changes each run).
a = torch.rand(1,10)
a

In [None]:
# `y` holds the index of the largest entry in each row. The max values go to `_`;
# note this overwrites IPython's `_` last-output variable.
_, y = a.max(dim=1)

In [None]:
def common_get_entropy(pos : torch.Tensor):
    """Entropy of softmax(pos) along the last dim, normalized to [0, 1] by log(k).

    NOTE: this shares its name with an earlier definition in this notebook;
    whichever cell ran last wins.

    Args:
        pos: logits of shape (..., k).

    Returns:
        np.ndarray with shape ``pos.shape[:-1]``. 1.0 means a uniform
        distribution; values near 0 mean a sharply peaked one.
    """
    num_classes = pos.size(-1)
    # Detach and move to CPU so scipy's np.asarray conversion also works for grad/GPU tensors.
    pred_prob = F.softmax(pos, dim=-1).detach().cpu().numpy()  # (N, k)
    etp = entropy(pred_prob, axis=-1)
    if num_classes < 2:
        # log(1) == 0: a single class carries no uncertainty. Avoid 0/0 -> nan.
        return np.zeros_like(etp)
    return etp / np.log(num_classes)

In [None]:
def common_get_maxpos(pos : torch.Tensor):
    """Uncertainty score: 1 - (max softmax probability) along the last dimension.

    Uses dim=-1, consistent with `common_get_entropy`. Results for 2-D (N, k)
    input are unchanged, and 1-D logits now work too (dim=1 used to fail there).

    Returns:
        np.ndarray of shape ``pos.shape[:-1]``; values near 0 mean a confident prediction.
    """
    test_pred_pos, _ = torch.max(F.softmax(pos, dim=-1), dim=-1)
    return 1 - test_pred_pos.detach().cpu().numpy()

In [None]:
# Compare both uncertainty scores on: all-equal logits (two scales), one
# moderately dominant logit, and one overwhelmingly dominant logit.
pos1 = torch.zeros(1, 7)
pos2 = torch.ones(1, 7)
pos3 = torch.tensor([[0, 0, 0, 5, 0, 0, 0]], dtype=torch.float32)
pos4 = torch.tensor([[0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 50]], dtype=torch.float32)
print(*[common_get_entropy(p)[0] for p in (pos1, pos2, pos3, pos4)])
print(*[common_get_maxpos(p)[0] for p in (pos1, pos2, pos3, pos4)])

In [None]:
# Small integer samples for the next-greater-element experiment.
a, c = [1, 6, 5, 3, 7, 16, 2], [3, 9, 22, 8, 1, 3]

In [None]:
def next_greater_element(arr):
    """Map each value u in `arr` to the first later value in `arr` that is greater than u.

    Uses a monotonic stack in a single left-to-right pass. The result list is
    indexed by the *value* u, not by u's position, so every value must be a
    valid index into a list of length ``len(arr)`` (e.g. an argsort of another
    list). Values with no greater successor map to None.
    """
    result = [None] * len(arr)
    pending = []  # values still waiting for a greater successor
    for value in arr:
        while pending and pending[-1] < value:
            result[pending.pop()] = value
        pending.append(value)
    return result

In [None]:
# Indices of `a` ordered by ascending value (an argsort), then the next-larger
# index for each of them.
large_arr = sorted(range(len(a)), key=a.__getitem__)
odd = next_greater_element(large_arr)

In [None]:
# De-duplication demo: `set` does not preserve insertion order (for strings the
# order can even differ between interpreter runs, because string hashing is
# randomized by default), while `dict.fromkeys` keeps first-occurrence order.
a = [1,5,3,8,0,4,1]
print(list(set(a)))
print(list(dict.fromkeys(a)))
b = '4sgh485saj.0?3'
print(''.join(list(set(b))))
print(''.join(list(dict.fromkeys(b))))

In [None]:
# str.replace returns a new string with every 'play' replaced by 'S'; `w` itself is unchanged.
w = 'sfdk3o40s.?@s-34://s2 is a good thing to play!'
w.replace('play', 'S')

## KBP class analysis

In [3]:
# Load the processed KBP splits together and show their features and sizes.
train_file, test_file = "KBP/train_KBP.json", "KBP/test_KBP.json"
kbp_datasets = load_dataset('json', data_files=dict(train=train_file, test=test_file))
kbp_train, kbp_test = kbp_datasets['train'], kbp_datasets['test']
print('train: {}\ntest: {}'.format(kbp_train, kbp_test))

Using custom data configuration default-86ab6f03c9ffd4b9
Reusing dataset json (/Users/yufli/.cache/huggingface/datasets/json/default-86ab6f03c9ffd4b9/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9)


train: Dataset({
    features: ['tokens', 'ner_tags', 'query_ids', 'sentID', 'instanceID'],
    num_rows: 144646
})
test: Dataset({
    features: ['tokens', 'ner_tags', 'query_ids', 'sentID', 'instanceID'],
    num_rows: 919
})


In [4]:
def cluster_tags(examples):
    """Strip the two-character IOB prefix ("B-"/"I-") from every non-"O" tag.

    Intended as a batched `datasets.map` function: ``examples['ner_tags']`` (a list
    of tag lists) is replaced, and the same mapping is returned. Relation tags
    (containing ':') and entity tags are treated identically, and fine-grained
    entity types are kept as-is.
    NOTE(review): dropping the prefix loses the B-/I- span boundaries; confirm
    this is intended (a commented-out variant kept the prefix).
    """
    examples['ner_tags'] = [
        [tag if tag == 'O' else tag[2:] for tag in tag_seq]
        for tag_seq in examples['ner_tags']
    ]
    return examples

In [5]:
def get_ie(dataset, relation_marker=':'):
    """Collect tagged tokens and token counts per entity label and per relation label.

    The two-character IOB prefix is stripped (``tag[2:]``). A label containing
    ``relation_marker`` is treated as a relation label; any other non-"O" label
    is treated as an entity label.

    Args:
        dataset: iterable of instances with aligned 'tokens' and 'ner_tags' lists.
        relation_marker: substring that identifies relation labels. The default
            ':' matches KBP labels such as 'per:children'.

    Returns:
        (ner_tokens, re_tokens, ner_count, re_count): label -> list of tokens and
        label -> number of tagged tokens, for entities and for relations.
    """
    ner_test = defaultdict(list)
    re_test = defaultdict(list)
    ner_count = defaultdict(int)
    re_count = defaultdict(int)
    # Iterating the dataset directly (not enumerate(...)) lets tqdm read len() and
    # show progress and ETA.
    for instance in tqdm(dataset):
        for token, tag in zip(instance['tokens'], instance['ner_tags']):
            if tag == 'O':
                continue
            label = tag[2:]
            if relation_marker in label:
                re_test[label].append(token)
                re_count[label] += 1
            else:
                ner_test[label].append(token)
                ner_count[label] += 1
    return ner_test, re_test, ner_count, re_count

In [6]:
# Strip IOB prefixes from every KBP train tag (batched; always recomputed rather
# than reusing a cached result).
new_kbp_train = kbp_train.map(cluster_tags, batched=True, load_from_cache_file=False)

  0%|          | 0/145 [00:00<?, ?ba/s]

In [7]:
# Same prefix stripping for the KBP test split.
new_kbp_test = kbp_test.map(cluster_tags, batched=True, load_from_cache_file=False)

  0%|          | 0/1 [00:00<?, ?ba/s]

In [8]:
# Save the prefix-stripped splits as JSON Lines. The number displayed for this
# cell is the return value of the second `to_json` call (presumably bytes written).
new_kbp_train.to_json("KBP/KBP_train_cluster.json")
new_kbp_test.to_json("KBP/KBP_test_cluster.json")

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

415845

In [10]:
# Per-label token lists and counts for the KBP train split (tags still IOB-prefixed here).
ner_train, re_train, ner_count_train, re_count_train = get_ie(kbp_train)

144646it [00:15, 9266.92it/s]


In [11]:
# Per-label token lists and counts for the KBP test split.
ner_test, re_test, ner_count_test, re_count_test = get_ie(kbp_test)

919it [00:00, 7527.96it/s]


In [12]:
# Relation labels seen in the train split.
re_train.keys()

dict_keys(['per:country_of_death', 'per:children', 'per:parents', 'per:country_of_birth', 'per:religion', 'per:countries_of_residence'])

In [13]:
# Relation labels seen in the test split.
re_test.keys()

dict_keys(['per:country_of_death', 'per:countries_of_residence', 'per:country_of_birth', 'per:children', 'org:founded_by', 'org:parents', 'per:parents', 'org:shareholders', 'per:employee_or_member_of', 'org:subsidiaries', 'org:member_of', 'per:religion'])

In [14]:
# Entity labels of the train split, most frequent first (ties keep first-seen order).
sorted_ner_count_train = dict(Counter(ner_count_train).most_common())
sorted_ner_count_train

{'/location/city,/location': 20797,
 '/person': 17650,
 '/title': 15978,
 '/location/country,/location': 8710,
 '/person/author,/person': 7794,
 '/location': 7422,
 '/person/monarch,/person': 5757,
 '/religion,/religion/religion': 2957,
 '/person,/person/politician': 2891,
 '/person/monarch,/person/politician,/person': 2651,
 '/person/politician,/person': 2626,
 '/organization': 2112,
 '/language': 2031,
 '/written_work': 2004,
 '/god': 2000,
 '/person/artist,/person': 1991,
 '/event': 1958,
 '/person/musician,/person/artist,/person': 1929,
 '/person/religious_leader,/person': 1911,
 '/location,/building': 1887,
 '/location,/location/province': 1836,
 '/person/politician,/person/author,/person': 1774,
 '/person/actor,/person/artist,/person': 1581,
 '/person/soldier,/person/actor,/person/politician,/person/author,/person': 1499,
 '/person/actor,/person': 1478,
 '/location,/organization,/location/city,/organization/company': 1454,
 '/organization/sports_team,/organization': 1438,
 '/peop

In [15]:
# Entity labels of the test split, most frequent first (ties keep first-seen order).
sorted_ner_count_test = dict(Counter(ner_count_test).most_common())
sorted_ner_count_test

{'/person': 356,
 '/organization': 173,
 '/location/city,/location': 118,
 '/location': 107,
 '/organization,/organization/company': 64,
 '/location/country,/location': 50,
 '/person/author,/person': 35,
 '/location,/organization,/location/city,/organization/company': 22,
 '/person/artist,/person/actor,/person/author,/person': 22,
 '/person/actor,/person/author,/person': 19,
 '/location,/location/province': 17,
 '/person/politician,/person': 17,
 '/location/country,/government,/government/government,/location,/government_agency': 15,
 '/person/author,/person/actor,/person/artist,/person': 15,
 '/person/artist,/person/actor,/person/author,/person/director,/person': 15,
 '/person/politician,/person/religious_leader,/person': 15,
 '/location/county,/organization,/location/city,/location,/organization/company': 13,
 '/person/author,/person/actor,/person/artist,/person/musician,/person': 12,
 '/religion,/religion/religion': 12,
 '/person/actor,/person/politician,/person/author,/person': 11,

In [16]:
# Entity labels present in both splits.
ner_intersect = sorted_ner_count_train.keys() & sorted_ner_count_test.keys()
ner_intersect

{'/art/film,/art',
 '/award',
 '/broadcast_network,/organization',
 '/broadcast_program',
 '/building,/building/hospital,/location,/organization,/person/author,/person',
 '/event/military_conflict,/event',
 '/finance/currency,/finance',
 '/god',
 '/government,/government/government,/military,/government_agency,/organization',
 '/government,/organization,/government/political_party',
 '/government,/organization,/government/political_party,/organization/terrorist_organization',
 '/language,/location/city,/location',
 '/location',
 '/location,/building,/building/airport',
 '/location,/building/airport,/building',
 '/location,/location/body_of_water',
 '/location,/location/province',
 '/location,/organization,/location/city,/organization/company',
 '/location,/organization,/location/province,/organization/company',
 '/location,/organization,/organization/educational_institution,/organization/company',
 '/location,/organization,/rail/railway,/rail,/organization/company',
 '/location/city,/l

In [None]:
# Coarse entity types of interest (presumably for grouping the fine-grained KBP
# labels above; TODO confirm, they are not used in the visible cells).
ent_list = ['person', 'religion', 'organization', 'location']
# NOTE(review): `{'building'}` is a set literal; if a mapping of groups was
# intended, this should be a dict.
ent_groups = {'building'}

In [18]:
# Relation labels counted in the train split.
re_count_train.keys()

dict_keys(['per:country_of_death', 'per:children', 'per:parents', 'per:country_of_birth', 'per:religion', 'per:countries_of_residence'])

In [19]:
# Relation labels counted in the test split.
re_count_test.keys()

dict_keys(['per:country_of_death', 'per:countries_of_residence', 'per:country_of_birth', 'per:children', 'org:founded_by', 'org:parents', 'per:parents', 'org:shareholders', 'per:employee_or_member_of', 'org:subsidiaries', 'org:member_of', 'per:religion'])

In [17]:
# Relation labels present in both splits.
re_intersect = re_count_train.keys() & re_count_test.keys()
re_intersect

{'per:children',
 'per:countries_of_residence',
 'per:country_of_birth',
 'per:country_of_death',
 'per:parents',
 'per:religion'}

In [None]:
# Toy extraction output: for each sentence, a list of
# {"entity1", "relation", "entity2"} triplet dicts.
a = [{"entity1":"Beklss", "relation": "Located_in", "entity2": "England"}, 
     {"entity1":"Jorge", "relation": "Founder_of", "entity2": "MC.@inc"}]
b = [{"entity1":"O'ssli", "relation": "Winner_of", "entity2": "World Championship"}, 
     {"entity1":"s1kv", "relation": "Born_in", "entity2": "Mars"}]
res = [a, b]

In [None]:
import csv

# Write one CSV row per triplet, preceded by a header row. `res` holds one list
# of {"entity1", "relation", "entity2"} dicts per sentence. The original
# `writer.writerows(res)` wrote each sentence as a single row of dict reprs.
fieldnames = ["entity1", "relation", "entity2"]
with open("triplets.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(triplet for sentence in res for triplet in sentence)

In [None]:
from random import choice
from math import floor

def partition(array, p):
    """Lomuto-style in-place partition of `array` around the pivot ``array[p]``.

    Elements smaller than the pivot end up to its left; all others end up to its right.

    Returns:
        The pivot's final index.
    """
    # Park the pivot at the end.
    array[p], array[-1] = array[-1], array[p]
    pivot = array[-1]
    boundary = 0  # array[:boundary] holds the elements < pivot seen so far
    for idx in range(len(array) - 1):
        if array[idx] < pivot:
            array[idx], array[boundary] = array[boundary], array[idx]
            boundary += 1
    # Move the pivot between the two groups.
    array[-1], array[boundary] = array[boundary], array[-1]
    return boundary


def quicksort(array):
    """Sort `array` in place with random-pivot quicksort.

    Bug fix: the recursive calls received slices (copies), so only the top-level
    partition step ever reached the caller's list. The sorted halves are now
    written back with slice assignment.
    """
    if len(array) > 1:
        r = partition(array, choice(range(len(array))))
        left, right = array[:r], array[r + 1:]
        quicksort(left)
        quicksort(right)
        array[:r] = left
        array[r + 1:] = right
        
        
def merge(array, m):
    """Merge the sorted runs ``array[:m+1]`` and ``array[m+1:]`` in place.

    Bug fix: the merged list used to be bound only to a local name
    (``array = res``), so the caller never saw it. It is now copied back with
    slice assignment. Ties take from the left run first, which keeps the merge stable.
    """
    res = []
    n = len(array)
    i, j = 0, m + 1
    for _ in range(n):
        if j > n - 1:
            res.append(array[i])
            i += 1
        elif i > m:
            res.append(array[j])
            j += 1
        elif array[i] <= array[j]:
            res.append(array[i])
            i += 1
        else:
            res.append(array[j])
            j += 1
    array[:] = res


def mergesort(array):
    """Sort `array` in place with top-down merge sort.

    Bug fixes: the recursive calls sorted slice copies that were then discarded,
    and splitting at ``array[:m+1]`` with ``m = len // 2`` recursed forever on
    2-element lists. The list is now split at ``mid = len // 2``; both halves
    are sorted, copied back, and merged with the left run ending at ``mid - 1``.
    """
    if len(array) > 1:
        mid = len(array) // 2
        left, right = array[:mid], array[mid:]
        mergesort(left)
        mergesort(right)
        array[:] = left + right
        merge(array, mid - 1)

In [None]:
# Sample input for the sorting helpers; uncomment one call to run that sort on `a`.
a = [1,70,2,3,5,13,7,68,4,90,21,42,75,39]
# quicksort(a)
# mergesort(a)

## BioInfer

In [15]:
def get_ie_BIF(dataset):
    """Group tokens and count occurrences per NER / relation label.

    Each instance must provide ``'tokens'`` and ``'ner_tags'`` sequences indexed in
    parallel. Tags equal to ``'O'`` are skipped. For the rest, the first two
    characters are stripped to get the label (presumably a ``B-``/``I-`` prefix --
    TODO confirm against the dataset).

    NOTE(review): a tag is treated as a relation tag whenever it contains ``'_'``,
    so an entity label that itself contains an underscore would be counted as a
    relation. Confirm against the dataset's tag set.

    Returns:
        ``(ner_tokens, re_tokens, ner_count, re_count)``. The first two are
        ``defaultdict(list)`` mapping label -> list of tagged tokens. The last two
        are ``defaultdict(int)`` mapping label -> number of tagged tokens.
    """
    ner_tokens = defaultdict(list)
    re_tokens = defaultdict(list)
    ner_count = defaultdict(int)
    re_count = defaultdict(int)
    # Pass the dataset itself (not enumerate(...)) so tqdm can read len() and show a total.
    for instance in tqdm(dataset):
        tokens = instance['tokens']
        for j, tag in enumerate(instance['ner_tags']):
            if tag == 'O':
                continue
            label = tag[2:]
            if '_' in tag:  # relation tag
                re_tokens[label].append(tokens[j])
                re_count[label] += 1
            else:  # entity tag
                ner_tokens[label].append(tokens[j])
                ner_count[label] += 1
    return ner_tokens, re_tokens, ner_count, re_count

In [16]:
# Per-label token lists and counts for the training split (presumably BioInfer, per the section header).
# NOTE(review): `train_BIF` is not defined in this part of the notebook -- confirm an earlier cell creates it.
ner_train, re_train, ner_count_train, re_count_train = get_ie_BIF(train_BIF)

13683it [00:01, 8429.01it/s]


In [40]:
# Entity-label -> count mapping for the test split.
# NOTE(review): saved output shows protein/nucleic-acid labels while the loading cell sets DIR = 'KBP';
# execution counts are out of order, so this output is likely stale -- re-run top to bottom.
ner_dict_test

defaultdict(int,
            {'Individual_protein,Protein': 1391,
             'Protein_family_or_group,Protein': 104,
             'Gene,Nucleic_acid': 70,
             'DNA,Nucleic_acid': 1,
             'Protein_complex,Protein': 12})

In [42]:
# Entity-label -> count mapping for the training split.
# NOTE(review): this saved output (In [42]) shows medicine.* labels, but the later `ner_dict_train`
# cell (In [19]) shows /person-style labels -- the outputs come from different runs.
ner_dict_train

defaultdict(int,
            {'base.type_ontology.physically_instantiable,base.type_ontology.non_agent,medicine.anatomical_structure': 14,
             'people.cause_of_death,base.firstaid.topic,base.type_ontology.abstract,medicine.risk_factor,base.pethealth.cause,base.pethealth.pet_disease_risk_factor,base.type_ontology.inanimate,base.type_ontology.non_agent,medicine.disease_cause,medicine.disease': 14,
             'people.cause_of_death,base.tagit.concept,medicine.risk_factor,user.tfmorris.default_domain.merge_candidate,base.type_ontology.abstract,user.alexander.misc.murder_method,user.alexander.misc.topic,base.type_ontology.inanimate,fictional_universe.medical_condition_in_fiction,base.consumermedical.disease,base.consumermedical.medical_term,base.type_ontology.non_agent,base.disaster2.type_of_injury,medicine.disease_cause,medicine.disease,film.film_subject': 15,
             'organization.organization_sector': 3,
             'base.type_ontology.abstract,base.type_ontology.inanima

In [41]:
# Relation-label -> count mapping for the test split.
re_dict_test

defaultdict(int,
            {'POS_REG(-)_SUPPRESS': 2,
             'POS_REG(-)_Assembly': 5,
             'POS_ACTION_SUPPRESS': 7,
             'POS_ACTION_BIND': 305,
             'NEG_ACTION_Amount': 5,
             'POS_ACTION_MEMBER': 184,
             'POS_ACTION_Physical': 4,
             'POS_ACTION_AFFECT': 6,
             'POS_ACTION_INTERACT': 87,
             'POS_ACTION_Change': 36,
             'POS_REG(0)_POLYMERIZE': 12,
             'POS_REG(0)_LOCALIZE': 6,
             'POS_ACTION_Causal': 27,
             'POS_ACTION_MEDIATE': 8,
             'POS_ACTION_SIMILAR': 40,
             'POS_REG(+)_DEPOLYMERIZE': 4,
             'NEG_ACTION_CORELATE': 4,
             'POS_ACTION_PHOSPHORYLATE': 13,
             'POS_ACTION_FNSIMILAR': 8,
             'POS_ACTION_RELATE': 51,
             'POS_ACTION_COLOCALIZE': 63,
             'POS_ACTION_ASSEMBLE': 5,
             'POS_REG(+)_ASSEMBLE': 11,
             'POS_ACTION_UNBIND': 6,
             'POS_REG(0)_ASSEMBLE': 13,


In [43]:
# Relation-label -> count mapping for the training split.
# NOTE(review): saved output mixes medicine.* labels with POS_ACTION_* style labels, which suggests the
# dict accumulated entries across datasets in one kernel session -- re-run from a fresh kernel.
re_dict_train

defaultdict(int,
            {'people.cause_of_death.includes_causes_of_death': 89,
             'medicine.risk_factor.diseases': 91,
             'medicine.drug.active_moieties': 91,
             'medicine.disease.risk_factors': 90,
             'medicine.medical_treatment.used_to_treat': 93,
             'medicine.disease.symptoms': 95,
             'medicine.drug_ingredient.active_moiety_of_drug': 90,
             'medicine.disease.treatments': 92,
             'biology.organism_classification.lower_classifications': 78,
             'medicine.symptom.symptom_of': 94,
             'POS_ACTION_SUPPRESS': 7,
             'POS_ACTION_Causal': 31,
             'NEG_ACTION_Amount': 3,
             'POS_ACTION_CROSS-LINK': 5,
             'POS_REG(-)_POLYMERIZE': 8,
             'POS_ACTION_DISRUPT': 6,
             'POS_ACTION_BIND': 304,
             'POS_ACTION_INTERACT': 83,
             'POS_ACTION_DOWNREGULATE': 5,
             'POS_REG(+)_Change': 5,
             'POS_REG(0)_Change

In [15]:
# Relation labels that have entity-type lists recorded in re2ner_train.
re2ner_train.keys()

dict_keys(['per:country_of_death', 'per:children', 'per:parents', 'per:country_of_birth', 'per:religion', 'per:countries_of_residence'])

In [16]:
# Entity-type strings (duplicates included) recorded for the per:country_of_birth relation.
re2ner_train['per:country_of_birth']

['/person',
 '/person,/person/politician',
 '/person/actor,/person/artist,/person',
 '/person/soldier,/person',
 '/person',
 '/person,/person/author,/person/politician',
 '/person/monarch,/person',
 '/person',
 '/person',
 '/person/monarch,/person',
 '/person/monarch,/person',
 '/person/author,/person/musician,/person',
 '/person/monarch,/person',
 '/person/monarch,/person',
 '/person',
 '/person',
 '/person/monarch,/person,/person/politician',
 '/person/author,/person/artist,/person/architect,/person',
 '/person',
 '/person/architect,/person',
 '/person/artist,/person/musician,/person',
 '/person/musician,/person/artist,/person',
 '/person/musician,/person/artist,/person',
 '/person/author,/person',
 '/person/author,/person',
 '/person/politician,/person/author,/person',
 '/person',
 '/person/author,/person',
 '/person/author,/person',
 '/person/artist,/person/author,/person/musician,/person',
 '/person/author,/person/architect,/person',
 '/person',
 '/person/musician,/person/artist,/

In [17]:
# Entity-type strings (duplicates included) recorded for the per:religion relation.
re2ner_train['per:religion']

['/person/soldier,/person/actor,/person/politician,/person/author,/person',
 '/person/author,/person',
 '/person',
 '/person/author,/person',
 '/person/author,/person',
 '/person/soldier,/person,/person/author,/person/politician',
 '/person',
 '/person',
 '/person/author,/person',
 '/person/author,/person',
 '/person',
 '/person/artist,/person/actor,/person/musician,/person/author,/person',
 '/person/artist,/person/author,/person/musician,/person',
 '/person/author,/person',
 '/person/doctor,/person/politician,/person',
 '/person/author,/person/artist,/person',
 '/person/author,/person',
 '/person/actor,/person/artist,/person/politician,/person/athlete,/person/author,/person',
 '/person/author,/person',
 '/person/author,/person',
 '/person',
 '/person/politician,/person/author,/person',
 '/person/politician,/person',
 '/person/artist,/person/religious_leader,/person/author,/person',
 '/person/artist,/person/religious_leader,/person/author,/person',
 '/person,/person/politician',
 '/per

In [18]:
# Entity-type strings recorded for the per:countries_of_residence relation.
# NOTE(review): saved output has a single entry mixing /location and /person types -- worth checking.
re2ner_train['per:countries_of_residence']

['/location/city,/location,/person']

In [19]:
# Entity-label -> count mapping for the training split.
# NOTE(review): this output (In [19]) differs from the earlier `ner_dict_train` cell (In [42]);
# only the most recent run reflects the current kernel state.
ner_dict_train

defaultdict(int,
            {'/person/soldier,/person/monarch,/person': 265,
             '/person/soldier,/person/actor,/person/politician,/person/author,/person': 1486,
             '/organization/sports_team,/organization': 994,
             '/written_work': 1545,
             '/food': 120,
             '/location/country,/location/city,/location': 351,
             '/location/city,/location': 18856,
             '/chemistry': 765,
             '/organization,/location,/government_agency,/location/province': 298,
             '/person/artist,/person/musician,/person/actor,/person/author,/person': 161,
             '/organization,/organization/company': 858,
             '/title': 14246,
             '/location': 5929,
             '/person/musician,/person/artist,/person': 1289,
             '/person/monarch,/person/politician,/person': 1621,
             '/person,/person/author,/person/politician': 514,
             '/language': 1985,
             '/location,/building': 1100,
    

In [29]:
from sklearn.metrics import roc_curve, auc, average_precision_score

def common_get_auc(y_test, y_score, name=None):
    """Return the ROC AUC of ``y_score`` against ``y_test``.

    Prints the true/false positive rates and thresholds from ``roc_curve``.
    When ``name`` is given, also prints a one-line summary tagged with it.
    """
    # roc_curve yields false positive rates, true positive rates, and thresholds.
    false_pos, true_pos, cutoffs = roc_curve(y_test, y_score)
    print('tpr: {}, fpr: {}, threshold: {}'.format(true_pos, false_pos, cutoffs))
    # Area under the ROC curve.
    area = auc(false_pos, true_pos)
    if name is None:
        return area
    print(name, 'auc is ', area)
    return area

In [10]:
# Reproducible 5x10 standard-normal sample. A seeded Generator keeps the displayed
# values stable across Restart & Run All and leaves the global np.random state untouched.
demo_rng = np.random.default_rng(0)
demo_rng.standard_normal((5, 10))

array([[-0.27776034,  0.05135848, -0.06862446,  0.51427202,  1.07452298,
        -0.04925929, -0.13965543,  1.3813591 ,  1.21517014,  0.11244194],
       [ 1.567292  , -1.38216132,  0.2234964 , -1.14617894, -0.98066249,
         0.39286229,  2.02137936, -1.17637032, -1.6938079 , -0.84654488],
       [ 1.14265737, -0.40311433,  0.52785212,  0.37675819, -0.28956108,
         1.82246276,  0.21846135, -1.47325391, -0.2996324 , -1.04349187],
       [-0.76281126, -1.21143999,  0.38298783,  0.42655664,  0.39284529,
        -0.79897156, -0.38956972, -1.36110853,  1.25739664, -1.05175228],
       [ 0.01629581, -0.10530652,  1.41602314, -1.0878351 , -0.78790627,
         0.59590408, -1.00185093, -0.44997144, -0.67020579,  0.47294359]])