In [35]:
import os
import json
import unidecode
import numpy as np
import random
from tqdm import tqdm
import spacy
from spacy import displacy
from spacy.tokenizer import Tokenizer
from spacy.util import compile_infix_regex
from spacy.lang.char_classes import ALPHA, ALPHA_LOWER, ALPHA_UPPER, HYPHENS
from spacy.lang.char_classes import CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS
from spacy.util import compile_infix_regex
from spacy.matcher import Matcher
from datasets import load_dataset, concatenate_datasets
from collections import defaultdict, Counter
import matplotlib
import matplotlib.pyplot as plt
import transformers
import pandas as pd
import torch
import torch.nn as nn

In [3]:
# Load the spaCy English pipeline and replace its tokenizer's infix (intra-token split) rules.
nlp = spacy.load("en_core_web_lg")

# Ellipses and icons split as usual, plus three hand-picked rules. Deliberately NOT split
# points: hyphens between letters, arithmetic operators between digits, and "/".
infixes = LIST_ELLIPSES + LIST_ICONS + [
    # "." between a lowercase letter (or quote) and an uppercase letter (or quote)
    r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES),
    # "," between two letters
    r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
    # ":", "<", ">" or "=" after a letter/digit and before a letter
    r"(?<=[{a}0-9])[:<>=](?=[{a}])".format(a=ALPHA),
]

infix_re = compile_infix_regex(infixes)
nlp.tokenizer.infix_finditer = infix_re.finditer

In [27]:
def _dedup_keep_last(items):
    """Return ``items`` with duplicates removed, keeping each item's last occurrence, in order.

    Items may be unhashable dicts, so duplicates are detected by equality.
    """
    return [item for idx, item in enumerate(items) if item not in items[idx + 1:]]


def _tokenize(text):
    """Return the token strings of ``text`` under the customised tokenizer, minus CRLF tokens.

    Only token boundaries are used, so ``nlp.tokenizer`` is called directly: ``nlp(text)``
    tokenizes the same way but then also runs every pipeline component on the result, which
    is wasted work here. (Assumes no retokenizing component is added to ``nlp``.)
    """
    return [tok.text for tok in nlp.tokenizer(text) if tok.text != '\r\n']


def _candidate_starts(sent_tokens, span_tokens):
    """Candidate start indices of ``span_tokens`` inside ``sent_tokens`` (already unidecoded).

    Positions where the whole span matches contiguously come first. Positions where only the
    first token matches follow as a fallback, because a mention tokenized on its own can split
    slightly differently than it does in context. An empty span has no candidates.
    """
    needle = [unidecode.unidecode(tok) for tok in span_tokens]
    if not needle:
        return []
    width = len(needle)
    heads = [i for i, tok in enumerate(sent_tokens) if tok == needle[0]]
    full = [i for i in heads if sent_tokens[i:i + width] == needle]
    return full + [i for i in heads if i not in full]


def process(exp, res, sentid, sent2query, only_ent, ner_dict, re_dict, re2ner, hard_rule):
    """Turn one raw example into token-tagging instances, one per locatable entity mention.

    For each (deduplicated) entity mention found in the sentence, an instance is built whose
    ``ner_tags`` mark the mention itself with ``B-/I-<entity label>``. For every relation in
    which that mention is ``em1`` and which passes ``hard_rule``, the ``em2`` target is marked
    with ``B-/I-<relation label>``.

    Args:
        exp: example dict with ``sentText``, ``entityMentions`` (``text``, ``label``) and
            ``relationMentions`` (``em1Text``, ``em2Text``, ``label``).
        res: output list; this sentence's instances are appended to it in place.
        sentid: sentence index, used as the key into ``sent2query``.
        sent2query: sentence id -> token indices already used as query starts (mutated).
        only_ent: if True, the sentence's instances are kept only when no relation target
            was tagged in any of them.
        ner_dict: entity label -> count of located mentions; label 'None' is not counted (mutated).
        re_dict: relation label -> count of relations passing the hard rule (mutated).
        re2ner: relation label -> list of query entity labels, one per counted relation (mutated).
        hard_rule: relation label -> entity label the query (em1) mention must have.

    Each instance dict has ``tokens``, ``ner_tags``, ``query_ids`` (query start index),
    ``sentID`` and ``instanceID`` (the index it gets in ``res`` when appended).
    """
    exp_tokens = [unidecode.unidecode(tok) for tok in _tokenize(exp['sentText'])]
    ent_mentions = _dedup_keep_last(exp['entityMentions'])
    re_mentions = _dedup_keep_last(exp['relationMentions'])
    used_starts = sent2query[sentid]  # same list object that sent2query holds; appended below
    re_count = 0
    sent_res = []

    for ent in ent_mentions:
        query, ner_tag = ent['text'], ent['label']
        query_tokens = _tokenize(query)
        # First candidate position not already claimed as a query start in this sentence.
        query_id = next(
            (i for i in _candidate_starts(exp_tokens, query_tokens) if i not in used_starts),
            None,
        )
        if query_id is None:
            continue  # mention cannot be located in the sentence: no instance for it
        query_end = query_id + len(query_tokens)

        if ner_tag != 'None':
            ner_dict[ner_tag] = ner_dict.get(ner_tag, 0) + 1
        tags = np.array(["O"] * len(exp_tokens), dtype='object')
        tags[query_id] = "B-" + ner_tag
        tags[query_id + 1:query_end] = "I-" + ner_tag
        used_starts.append(query_id)

        for rel in re_mentions:
            target, re_tag = rel['em2Text'], rel['label']
            # Hard rule: use the relation only if this mention is its em1 and the mention's
            # entity label is the one hard_rule requires for that relation. `.get` avoids
            # inserting '' entries when hard_rule is a defaultdict(str).
            if rel['em1Text'] != query or re_tag == 'None' or ner_tag != hard_rule.get(re_tag):
                continue
            re_dict[re_tag] = re_dict.get(re_tag, 0) + 1
            re2ner[re_tag].append(ner_tag)

            target_tokens = _tokenize(target)
            width = len(target_tokens)
            # The target must not overlap the query span, otherwise it would overwrite the
            # query's own tags.
            target_id = next(
                (i for i in _candidate_starts(exp_tokens, target_tokens)
                 if i + width <= query_id or i >= query_end),
                None,
            )
            if target_id is None:
                continue
            re_count += 1
            # A later relation with the same target position overwrites these tags.
            tags[target_id] = "B-" + re_tag
            tags[target_id + 1:target_id + width] = "I-" + re_tag

        sent_res.append({
            "tokens": exp_tokens, "ner_tags": tags.tolist(),
            "query_ids": query_id, "sentID": sentid, "instanceID": len(res) + len(sent_res),
        })

    # With only_ent, keep the sentence only if no relation target was tagged anywhere in it.
    if not (only_ent and re_count):
        res.extend(sent_res)
        

def process_data(dataset, hard_rule, only_ent=False):
    """Run ``process`` over every example and collect the instances plus running statistics.

    Args:
        dataset: iterable of raw example dicts (``sentText``, ``entityMentions``,
            ``relationMentions``); its position in the iteration is used as the sentence id.
        hard_rule: relation label -> entity label required of the relation's em1 mention.
        only_ent: forwarded to ``process``; when True, sentences in which any relation
            target was tagged are dropped.

    Returns:
        ``(res, ner_dict, re_dict, re2ner)``: the list of instances, entity-label counts,
        relation-label counts, and relation label -> list of query entity labels.
    """
    res = []
    sent2query = defaultdict(list)
    re_dict = defaultdict(int)
    ner_dict = defaultdict(int)
    re2ner = defaultdict(list)
    # Wrap the dataset itself (not enumerate(dataset)) so tqdm can read len() and show a
    # total and ETA for the long run.
    for i, instance in enumerate(tqdm(dataset)):
        process(instance, res, i, sent2query, only_ent, ner_dict, re_dict, re2ner, hard_rule)

    return res, ner_dict, re_dict, re2ner


def saving(f_path, res):
    """Write each record of ``res`` to ``f_path`` as one JSON object per line (JSON Lines).

    The file is created or overwritten.
    """
    with open(f_path, 'w') as out:
        out.writelines(json.dumps(record) + '\n' for record in res)

In [28]:
# Dataset directory and the JSON Lines file names of its train / test splits.
DIR, train_json, test_json = 'NYT', "train.json", "test.json"

def load_data(src_dir=".", file_name="train.json"):
    """Read a JSON Lines file ``src_dir/file_name`` and return its records as a list.

    The file is read as UTF-8 and closed before returning. Blank lines (e.g. a trailing
    empty line) are skipped instead of crashing ``json.loads``.
    """
    raw_json = os.path.join(src_dir, file_name)
    with open(raw_json, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

# Load both splits (train first, then test) and report their sizes.
train_data, test_data = (load_data(DIR, split_file) for split_file in (train_json, test_json))
print(len(train_data), len(test_data))

# Build the hard-rule-filtered training instances; about an hour on the full train split
# (see the tqdm output).
# NOTE(review): `hard_rule` is only defined in a later cell (derived from `re2ner_train_count`,
# whose producing run is not in this notebook), so this cell fails under Restart & Run All
# unless that cell has already been executed.
(
    train_processed_data_filtered,
    ner_dict_train_filtered,
    re_dict_train_filtered,
    re2ner_train_filtered,
) = process_data(train_data, hard_rule)

6it [00:00, 56.33it/s]

235982 395


235982it [1:05:55, 59.65it/s]


In [30]:
# Persist the filtered training instances under DIR, one JSON object per line.
train_file = os.path.join(DIR, "train_NYT_preprocessed.json")
saving(f_path=train_file, res=train_processed_data_filtered)

In [7]:
# Display relation-label counts from a training run.
# NOTE(review): `re_dict_train` is never assigned in this notebook (only
# `re_dict_train_filtered` is), so this cell depends on state from a deleted/rewritten cell
# and fails under Restart & Run All.
re_dict_train

defaultdict(int,
            {'/people/person/nationality': 7767,
             '/people/deceased_person/place_of_death': 1965,
             '/location/country/capital': 7933,
             '/location/location/contains': 52950,
             '/people/person/children': 529,
             '/people/person/place_of_birth': 3218,
             '/people/person/place_lived': 7359,
             '/location/administrative_division/country': 6276,
             '/location/country/administrative_divisions': 6621,
             '/business/person/company': 5623,
             '/location/neighborhood/neighborhood_of': 5546,
             '/business/company/place_founded': 424,
             '/business/company/founders': 849,
             '/sports/sports_team/location': 226,
             '/sports/sports_team_location/teams': 225,
             '/business/company_shareholder/major_shareholder_of': 296,
             '/business/company/major_shareholders': 308,
             '/people/person/ethnicity': 21,
         

In [8]:
# Display entity-label counts from a training run.
# NOTE(review): `ner_dict_train` is never assigned in this notebook (only
# `ner_dict_train_filtered` is); this relies on state from an earlier kernel session.
ner_dict_train

defaultdict(int, {'PERSON': 303739, 'LOCATION': 403820, 'ORGANIZATION': 93434})

In [14]:
# Per relation label, how often each entity label occurred on the query (em1) mention,
# presumably as collected by process/process_data.
# NOTE(review): `re2ner_train` is never assigned in this notebook (only `re2ner_train_filtered`
# is), so this relies on state from an earlier kernel session.
# The loop variable is `rel`, not `re`, so the stdlib `re` module name is not clobbered.
re2ner_train_count = defaultdict(dict)

for rel, ner_labels in re2ner_train.items():
    re2ner_train_count[rel] = Counter(ner_labels)
    print("{}: {}".format(rel, dict(re2ner_train_count[rel])))

/people/person/nationality: {'PERSON': 7654, 'LOCATION': 83, 'ORGANIZATION': 30}
/people/deceased_person/place_of_death: {'PERSON': 1932, 'ORGANIZATION': 23, 'LOCATION': 10}
/location/country/capital: {'LOCATION': 7912, 'PERSON': 20, 'ORGANIZATION': 1}
/location/location/contains: {'LOCATION': 52676, 'ORGANIZATION': 184, 'PERSON': 90}
/people/person/children: {'PERSON': 525, 'ORGANIZATION': 4}
/people/person/place_of_birth: {'PERSON': 3149, 'ORGANIZATION': 53, 'LOCATION': 16}
/people/person/place_lived: {'PERSON': 7274, 'ORGANIZATION': 70, 'LOCATION': 15}
/location/administrative_division/country: {'LOCATION': 6185, 'ORGANIZATION': 60, 'PERSON': 31}
/location/country/administrative_divisions: {'LOCATION': 6619, 'ORGANIZATION': 2}
/business/person/company: {'PERSON': 5604, 'ORGANIZATION': 14, 'LOCATION': 5}
/location/neighborhood/neighborhood_of: {'LOCATION': 5191, 'ORGANIZATION': 164, 'PERSON': 191}
/business/company/place_founded: {'ORGANIZATION': 388, 'PERSON': 21, 'LOCATION': 15}
/b

In [29]:
# Same tally for the hard-rule-filtered training run; the printed output shows a single
# entity label per relation.
# The loop variable is `rel`, not `re`, so the stdlib `re` module name is not clobbered.
re2ner_train_count_filtered = defaultdict(dict)

for rel, ner_labels in re2ner_train_filtered.items():
    re2ner_train_count_filtered[rel] = Counter(ner_labels)
    print("{}: {}".format(rel, dict(re2ner_train_count_filtered[rel])))

/people/person/nationality: {'PERSON': 7654}
/people/deceased_person/place_of_death: {'PERSON': 1932}
/location/country/capital: {'LOCATION': 7912}
/location/location/contains: {'LOCATION': 52676}
/people/person/children: {'PERSON': 525}
/people/person/place_of_birth: {'PERSON': 3149}
/people/person/place_lived: {'PERSON': 7274}
/location/administrative_division/country: {'LOCATION': 6185}
/location/country/administrative_divisions: {'LOCATION': 6619}
/business/person/company: {'PERSON': 5604}
/location/neighborhood/neighborhood_of: {'LOCATION': 5191}
/business/company/place_founded: {'ORGANIZATION': 388}
/sports/sports_team/location: {'ORGANIZATION': 225}
/sports/sports_team_location/teams: {'LOCATION': 217}
/business/company_shareholder/major_shareholder_of: {'PERSON': 271}
/business/company/major_shareholders: {'ORGANIZATION': 305}
/business/company/founders: {'ORGANIZATION': 822}
/people/person/ethnicity: {'PERSON': 21}
/people/ethnicity/people: {'LOCATION': 11}
/business/company/a

In [22]:
# Hard relation -> entity-label rule: for every relation, the entity label seen most often on
# its query (em1) mention in `re2ner_train_count`. `process` only tags a relation's target
# when the query mention carries this label.
# Ties resolve to the label counted first (same as most_common()[0]).
# The loop variable is `rel`, not `re`, so the stdlib `re` module name is not clobbered.
hard_rule = defaultdict(str)

for rel, label_counts in re2ner_train_count.items():
    hard_rule[rel] = label_counts.most_common(1)[0][0]

hard_rule

defaultdict(str,
            {'/people/person/nationality': 'PERSON',
             '/people/deceased_person/place_of_death': 'PERSON',
             '/location/country/capital': 'LOCATION',
             '/location/location/contains': 'LOCATION',
             '/people/person/children': 'PERSON',
             '/people/person/place_of_birth': 'PERSON',
             '/people/person/place_lived': 'PERSON',
             '/location/administrative_division/country': 'LOCATION',
             '/location/country/administrative_divisions': 'LOCATION',
             '/business/person/company': 'PERSON',
             '/location/neighborhood/neighborhood_of': 'LOCATION',
             '/business/company/place_founded': 'ORGANIZATION',
             '/business/company/founders': 'ORGANIZATION',
             '/sports/sports_team/location': 'ORGANIZATION',
             '/sports/sports_team_location/teams': 'LOCATION',
             '/business/company_shareholder/major_shareholder_of': 'PERSON',
             

In [15]:
# Display relation-label counts from a test-split run.
# NOTE(review): `re_dict_test` is assigned only in a commented-out line of the data-loading
# cell, so this relies on state from an earlier kernel session.
re_dict_test

defaultdict(int,
            {'/people/person/place_lived': 40,
             '/location/country/capital': 2,
             '/location/location/contains': 179,
             '/location/administrative_division/country': 108,
             '/business/person/company': 37,
             '/people/person/nationality': 23,
             '/people/person/children': 2,
             '/business/company/founders': 5,
             '/location/neighborhood/neighborhood_of': 1,
             '/location/country/administrative_divisions': 3,
             '/people/person/place_of_birth': 1,
             '/people/deceased_person/place_of_death': 2})

In [16]:
# Display entity-label counts from a test-split run.
# NOTE(review): `ner_dict_test` is assigned only in a commented-out line of the data-loading
# cell, so this relies on state from an earlier kernel session.
ner_dict_test

defaultdict(int, {'PERSON': 280, 'LOCATION': 940, 'ORGANIZATION': 141})

In [17]:
# Per relation label, how often each entity label occurred on the query (em1) mention in
# the test split.
# NOTE(review): `re2ner_test` is assigned only in a commented-out line of the data-loading
# cell, so this relies on state from an earlier kernel session.
# The loop variable is `rel`, not `re`, so the stdlib `re` module name is not clobbered.
re2ner_test_count = defaultdict(dict)

for rel, ner_labels in re2ner_test.items():
    re2ner_test_count[rel] = Counter(ner_labels)
    print("{}: {}".format(rel, dict(re2ner_test_count[rel])))

/people/person/place_lived: {'PERSON': 40}
/location/country/capital: {'LOCATION': 2}
/location/location/contains: {'LOCATION': 178, 'PERSON': 1}
/location/administrative_division/country: {'LOCATION': 108}
/business/person/company: {'PERSON': 37}
/people/person/nationality: {'PERSON': 23}
/people/person/children: {'PERSON': 2}
/business/company/founders: {'ORGANIZATION': 5}
/location/neighborhood/neighborhood_of: {'LOCATION': 1}
/location/country/administrative_divisions: {'LOCATION': 3}
/people/person/place_of_birth: {'PERSON': 1}
/people/deceased_person/place_of_death: {'PERSON': 2}


In [51]:
# Tag vocabulary for the B-/I-/O token tags: ids follow the list order below (0..50).
# NOTE(review): there is no I- tag for /people/ethnicity/geographic_distribution or
# /people/person/ethnicity, while `process` emits "I-<relation>" for multi-token targets;
# confirm no multi-token target of those relations occurs.
label2id = {label: idx for idx, label in enumerate([
    "B-/business/company/advisors",
    "B-/business/company/founders",
    "B-/business/company/industry",
    "B-/business/company/major_shareholders",
    "B-/business/company/place_founded",
    "B-/business/company_shareholder/major_shareholder_of",
    "B-/business/person/company",
    "B-/location/administrative_division/country",
    "B-/location/country/administrative_divisions",
    "B-/location/country/capital",
    "B-/location/location/contains",
    "B-/location/neighborhood/neighborhood_of",
    "B-/people/deceased_person/place_of_death",
    "B-/people/ethnicity/geographic_distribution",
    "B-/people/ethnicity/people",
    "B-/people/person/children",
    "B-/people/person/ethnicity",
    "B-/people/person/nationality",
    "B-/people/person/place_lived",
    "B-/people/person/place_of_birth",
    "B-/people/person/religion",
    "B-/sports/sports_team/location",
    "B-/sports/sports_team_location/teams",
    "B-LOCATION",
    "B-ORGANIZATION",
    "B-PERSON",
    "I-/business/company/advisors",
    "I-/business/company/founders",
    "I-/business/company/industry",
    "I-/business/company/major_shareholders",
    "I-/business/company/place_founded",
    "I-/business/company_shareholder/major_shareholder_of",
    "I-/business/person/company",
    "I-/location/administrative_division/country",
    "I-/location/country/administrative_divisions",
    "I-/location/country/capital",
    "I-/location/location/contains",
    "I-/location/neighborhood/neighborhood_of",
    "I-/people/deceased_person/place_of_death",
    "I-/people/ethnicity/people",
    "I-/people/person/children",
    "I-/people/person/nationality",
    "I-/people/person/place_lived",
    "I-/people/person/place_of_birth",
    "I-/people/person/religion",
    "I-/sports/sports_team/location",
    "I-/sports/sports_team_location/teams",
    "I-LOCATION",
    "I-ORGANIZATION",
    "I-PERSON",
    "O",
])}

# Exact inverse of label2id, keyed by the *string* form of each id. Derived rather than
# hand-maintained so the two maps cannot drift apart.
id2label = {str(idx): label for label, idx in label2id.items()}

In [113]:
def logic_dist(probs, labels, query_ids, label2id, id2label, hard_rule):
    """Differentiable penalty for violating the hard relation -> entity-label rule.

    For every token whose gold tag is a relation begin tag ``B-/<relation>``, the rule reads
    "this token is a <relation> target => the query token is a ``B-<hard_rule[relation]>``".
    Its violation is ``max(p(B-/<relation> at token) - p(B-<label> at query), 0)``. Per sample,
    the minimum over such tokens is taken, starting from 1 (0 if the sample has no relation
    tag); the per-sample values are summed.

    Args:
        probs: per-token label probabilities, indexed ``probs[b][t][label_id]`` (B x T x H tensor).
        labels: gold label ids, ``labels[b][t]`` (ints or integer tensors). Ids not present in
            ``id2label`` (e.g. a -100 padding id) are ignored.
        query_ids: ``query_ids[b][0]`` is the token index of sample ``b``'s query entity.
        label2id: label -> id.
        id2label: stringified id -> label.
        hard_rule: relation label (without the ``B-`` prefix) -> required entity label.

    Returns:
        0-dim tensor with the summed penalty. Unlike rebuilding a tensor from Python scalars,
        it stays attached to ``probs`` in the autograd graph, so it can be used as a loss.
    """
    per_sample = []
    for prob, label, query_id in zip(probs, labels, query_ids):
        q = int(query_id[0])
        dist, grounded = prob.new_ones(()), False
        for idx, label_id in enumerate(label):
            iob_tag = id2label.get(str(int(label_id)))  # int() so tensor ids map correctly
            if iob_tag is None or "B-/" not in iob_tag:  # only relation begin tags
                continue
            re_id = label2id[iob_tag]
            ent_id = label2id['B-' + hard_rule[iob_tag[2:]]]
            dt = torch.clamp(prob[idx][re_id] - prob[q][ent_id], min=0)
            dist = torch.minimum(dist, dt)
            grounded = True
        # No relation tag in the sample -> the rule is not grounded -> no penalty.
        per_sample.append(dist if grounded else prob.new_zeros(()))

    if not per_sample:
        return torch.tensor(0.0)
    return torch.stack(per_sample).sum()

In [133]:
# Toy sanity check for logic_dist: random per-token label distributions plus a hand-built
# gold-label grid. Seeded so the printed loss is reproducible under Restart & Run All.
torch.manual_seed(0)
B, T, H = 10, 8, 51
logits = torch.rand(B, T, H)
m = nn.Softmax(dim=-1)
probs = m(logits)  # B X T X H

# Every token starts as "O" (id H-1 == 50 == label2id["O"]); then tags are placed per sample.
labels = [[H - 1] * T for _ in range(B)]
labels[0][1] = 23  # B-LOCATION (query of sample 0)
labels[0][2] = 47  # I-LOCATION
labels[0][6] = 10  # B-/location/location/contains
labels[1][1] = 24  # B-ORGANIZATION (query)
labels[1][2] = 48  # I-ORGANIZATION
labels[1][3] = 48  # I-ORGANIZATION
labels[2][6] = 25  # B-PERSON (query)
labels[3][1] = 15  # B-/people/person/children
labels[3][4] = 25  # B-PERSON (query)
labels[3][6] = 7   # B-/location/administrative_division/country
labels[3][7] = 33  # I-/location/administrative_division/country
labels[4][3] = 24  # B-ORGANIZATION (query)
labels[5][6] = 23  # B-LOCATION (query)
labels[5][7] = 47  # I-LOCATION
labels[6][1] = 24  # B-ORGANIZATION (query)
labels[7][7] = 25  # B-PERSON (query)
labels[8][2] = 23  # B-LOCATION (query)
labels[9][0] = 24  # B-ORGANIZATION (query)
print(labels)

# Only samples 0 and 3 carry relation tags, so only they can contribute to the loss.
query_ids = [[1], [1], [6], [4], [3], [6], [1], [7], [2], [0]]
print(query_ids)

logic_dist(probs, labels, query_ids, label2id, id2label, hard_rule)

[[50, 23, 47, 50, 50, 50, 10, 50], [50, 24, 48, 48, 50, 50, 50, 50], [50, 50, 50, 50, 50, 50, 25, 50], [50, 15, 50, 50, 25, 50, 7, 33], [50, 50, 50, 24, 50, 50, 50, 50], [50, 50, 50, 50, 50, 50, 23, 47], [50, 24, 50, 50, 50, 50, 50, 50], [50, 50, 50, 50, 50, 50, 50, 25], [50, 50, 23, 50, 50, 50, 50, 50], [24, 50, 50, 50, 50, 50, 50, 50]]
[[1], [1], [6], [4], [3], [6], [1], [7], [2], [0]]


tensor(0.0167)

In [135]:
# Peek at the first raw training example.
train_data[0]

{'sentText': 'But that spasm of irritation by a master intimidator was minor compared with what Bobby Fischer , the erratic former world chess champion , dished out in March at a news conference in Reykjavik , Iceland .',
 'articleId': '/m/vinci8/data1/riedel/projects/relation/kb/nyt1/docstore/nyt-2005-2006.backup/1677367.xml.pb',
 'relationMentions': [{'em1Text': 'Bobby Fischer',
   'em2Text': 'Iceland',
   'label': '/people/person/nationality'},
  {'em1Text': 'Iceland',
   'em2Text': 'Reykjavik',
   'label': '/location/country/capital'},
  {'em1Text': 'Iceland',
   'em2Text': 'Reykjavik',
   'label': '/location/location/contains'},
  {'em1Text': 'Bobby Fischer',
   'em2Text': 'Reykjavik',
   'label': '/people/deceased_person/place_of_death'}],
 'entityMentions': [{'start': 0, 'label': 'PERSON', 'text': 'Bobby Fischer'},
  {'start': 1, 'label': 'LOCATION', 'text': 'Reykjavik'},
  {'start': 2, 'label': 'LOCATION', 'text': 'Iceland'}],
 'sentId': '1'}

In [146]:
# Load the joint test set (JSON Lines) for inspection. `with` closes the file afterwards, and
# blank lines are skipped instead of crashing json.loads.
with open(os.path.join(DIR, 'joint_test_NYT.json'), 'r') as f:
    data = [json.loads(line) for line in f if line.strip()]

In [152]:
# Scratch data for trying out filtering and sorting of dicts by list length.
a = [
    dict(d=[1, 2, 6, 7], b=[5, 0, 3, 1, 2, 3, 4, 3, 5]),
    dict(d=[4, 2, 9], b=[1, 8, 2]),
]

In [153]:
# Entries of `a` whose 'b' list has at most 5 elements.
list(filter(lambda entry: len(entry['b']) <= 5, a))

[{'d': [4, 2, 9], 'b': [1, 8, 2]}]

In [155]:
# Reorder `a` in place (same list object) by the length of each entry's 'd' list; stable.
a[:] = sorted(a, key=lambda entry: len(entry['d']))

In [156]:
# Display `a` after the in-place sort by len(entry['d']).
a

[{'d': [4, 2, 9], 'b': [1, 8, 2]},
 {'d': [1, 2, 6, 7], 'b': [5, 0, 3, 1, 2, 3, 4, 3, 5]}]