In [1]:
from concept_extractor import get_nlp_and_matcher
nlp, matcher = get_nlp_and_matcher()

Adding patterns to Matcher.: 100%|██████████| 781267/781267 [00:18<00:00, 42620.67it/s]


In [9]:
import configparser
import json
import spacy
from spacy.matcher import Matcher
import sys
import timeit
from tqdm import tqdm
import numpy as np

config = configparser.ConfigParser()
config.read("paths.cfg")
with open(config["paths"]["concept_vocab"], "r", encoding="utf8") as f:
    cpnet_vocab = [l.strip() for l in list(f.readlines())]
cpnet_vocab = [c.replace("_", " ") for c in cpnet_vocab]


blacklist = set(["-PRON-", "actually", "likely", "possibly", "want",
                 "make", "my", "someone", "sometimes_people", "sometimes","would", "want_to",
                 "one", "something", "sometimes", "everybody", "somebody", "could", "could_be"
                 ])

def lemmatize(nlp, concept):

    doc = nlp(concept.replace("_"," "))
    lcs = set()
    lcs.add("_".join([token.lemma_ for token in doc])) # all lemma
    return lcs

def load_concept_vocab():
    vocab = []
    with open(config["paths"]["concept_vocab"], "r", encoding="utf8") as f:
        vocab = [l.strip() for l in list(f.readlines())]
    concept2id = {}
    id2concept = {}
    for indice, cp in enumerate(vocab):
        concept2id[cp] = indice
        id2concept[indice] = cp
    return concept2id, id2concept

concept2id, id2concept = load_concept_vocab()


def lemmatize(nlp, concept):

    doc = nlp(concept.replace("_"," "))
    lcs = set()
    lcs.add("_".join([token.lemma_ for token in doc])) # all lemma
    return lcs

def load_matcher(nlp):
    config = configparser.ConfigParser()
    config.read("paths.cfg")
    with open(config["paths"]["matcher_patterns"], "r", encoding="utf8") as f:
        all_patterns = json.load(f)

    matcher = Matcher(nlp.vocab)
    for concept, pattern in tqdm(all_patterns.items(), desc="Adding patterns to Matcher."):
        matcher.add(concept, None, pattern)
    return matcher

def ground_mentioned_concepts(nlp, matcher, s, ans = ""):
    global concept2id
    s = s.lower()
    doc = nlp(s)
    matches = matcher(doc)

    mentioned_concepts = {}
    span_to_concepts = {}
    for match_id, start, end in matches:
        span = doc[start:end].text  # the matched span
        if len(set(span.split(" ")).intersection(set(ans.split(" ")))) > 0:
            continue
        original_concept = nlp.vocab.strings[match_id]
        # print("Matched '" + span + "' to the rule '" + string_id)

        if len(original_concept.split("_")) == 1:
            original_concept = list(lemmatize(nlp, original_concept))[0]

        if span not in span_to_concepts:
            span_to_concepts[span] = set()

        span_to_concepts[span].add(original_concept)

    for span, concepts in span_to_concepts.items():
        concepts_sorted = list(concepts)
        concepts_sorted.sort(key=len)

        # mentioned_concepts.update(concepts_sorted[0:2])

        shortest = concepts_sorted[0:3] #
        for c in shortest:
            if c in blacklist:
                continue
            lcs = lemmatize(nlp, c)
            intersect = lcs.intersection(shortest)
            if len(intersect)>0:
                c = list(intersect)[0]
                if c in concept2id:
                    mentioned_concepts[span] = c
                    break
            else:
                if c in concept2id:
                    mentioned_concepts[span] = c
                    break

    
    mentioned_concepts_with_indices = []
    for match_id, start, end in matches:
        span = doc[start:end].text
        if span in mentioned_concepts:
            concept = mentioned_concepts[span]
            concept_id = concept2id[concept]
            mentioned_concepts_with_indices.append([start, end, span, concept, concept_id])

    mentioned_concepts_with_indices = sorted(mentioned_concepts_with_indices, key=lambda x: (x[1],-x[0])) # sort based on end then start
    
    # mentioned_concepts_with_indice with filtered intersection
    res = []
    for mc in reversed(mentioned_concepts_with_indices):
        if len(res) == 0:
            res.append(mc)
        elif mc[1] <= res[-1][0]: # no intersection between current concept, and last included concepts 
            res.append(mc)
    
    res.reverse()
    
    return res

def hard_ground(nlp, sent):
    global cpnet_vocab
    sent = sent.lower()
    doc = nlp(sent)
    res = []
    for idx, t in enumerate(doc):
        if t.lemma_ in cpnet_vocab and t.lemma_ in concept2id:
            concept_id = concept2id[t.lemma_]
            res.append([idx, idx + 1, str(t), str(t.lemma_), concept_id])
    return res

def match_mentioned_concepts(nlp, matcher, sent):
    # print("Begin matching concepts.")
    all_concepts = ground_mentioned_concepts(nlp, matcher, sent)
    if len(all_concepts)==0:
        all_concepts = hard_ground(nlp, sent) # not very possible
        print('hard ground', sent)

    return all_concepts



In [10]:
def hard_ground(nlp, sent):
    global cpnet_vocab
    sent = sent.lower()
    doc = nlp(sent)
    res = []
    for idx, t in enumerate(doc):
        if t.lemma_ in cpnet_vocab and t.lemma_ in concept2id:
            concept_id = concept2id[t.lemma_]
            res.append([idx, idx + 1, t, str(t.lemma_), concept_id])
    return res

st = 'Telemundo is owned by ESPN.'
hard_ground(nlp, st)

[[1, 2, is, 'be', 1452], [2, 3, owned, 'own', 395], [3, 4, by, 'by', 2749]]

In [11]:
s = "Debonding Abaddon debonding abaddon remained the team 's starting quarterback for the rest of the season and went on to lead the 49ers to their first Super Bowl appearance since 1994 , losing to the Baltimore Ravens ."
print(nlp(s))
concepts = match_mentioned_concepts(nlp, matcher, sent=s)
print(concepts)

Debonding Abaddon debonding abaddon remained the team 's starting quarterback for the rest of the season and went on to lead the 49ers to their first Super Bowl appearance since 1994 , losing to the Baltimore Ravens .
[[1, 2, 'abaddon', 'abaddon', 79949], [3, 4, 'abaddon', 'abaddon', 79949], [4, 5, 'remained', 'remain', 3085], [6, 7, 'team', 'team', 1736], [8, 9, 'starting', 'start', 13762], [9, 10, 'quarterback', 'quarterback', 48514], [12, 13, 'rest', 'rest', 309], [15, 16, 'season', 'season', 6760], [17, 19, 'went on', 'go_on', 9784], [20, 21, 'lead', 'lead', 7235], [24, 25, 'their', 'mine', 19735], [25, 26, 'first', 'first', 3945], [26, 28, 'super bowl', 'super_bowl', 199537], [28, 29, 'appearance', 'appearance', 1263], [29, 30, 'since', 'since', 14293], [32, 33, 'losing', 'lose', 4705], [35, 36, 'baltimore', 'baltimore', 18538]]


In [12]:
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('../../bert_base/', do_lower_case=False)

In [13]:
s = "Debonding Abaddon debonding abaddon remained the team 's starting quarterback for the rest of the season and went on to lead the 49ers to their first Super Bowl appearance since."
t = "Debonding Abaddon debonding"
b = "Debonding Abaddon debonding abaddon remained"

tokens_a = tokenizer.tokenize(s)
tokens_t = tokenizer.tokenize(t)
tokens_b = tokenizer.tokenize(b)
tokens =  ["[CLS]"] + tokens_a + ["[SEP]"]
tokens = tokens + tokens_t + ["[SEP]"] + tokens_b + ["[SEP]"]

c_start_id, c_end_id = 1, 1 + len(tokens_a) #exclusive
e_start_id, e_end_id = 1 + len(tokens_a) + 1 + len(tokens_t) + 1, 1 + len(tokens_a) + 1 + len(tokens_t) + 1 + len(tokens_b)

print('claim', tokens[c_start_id:c_end_id])
print('evi', tokens[e_start_id:e_end_id])

input_ids = tokenizer.convert_tokens_to_ids(tokens)

print(tokens, len(tokens))
print(input_ids, len(input_ids))

claim ['De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', 'a', '##bad', '##don', 'remained', 'the', 'team', "'", 's', 'starting', 'quarterback', 'for', 'the', 'rest', 'of', 'the', 'season', 'and', 'went', 'on', 'to', 'lead', 'the', '49ers', 'to', 'their', 'first', 'Super', 'Bowl', 'appearance', 'since', '.']
evi ['De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', 'a', '##bad', '##don', 'remained']
['[CLS]', 'De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', 'a', '##bad', '##don', 'remained', 'the', 'team', "'", 's', 'starting', 'quarterback', 'for', 'the', 'rest', 'of', 'the', 'season', 'and', 'went', 'on', 'to', 'lead', 'the', '49ers', 'to', 'their', 'first', 'Super', 'Bowl', 'appearance', 'since', '.', '[SEP]', 'De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', '[SEP]', 'De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', 'a', '##bad', '##don', 'remained', '[SEP]'] 66
[101, 317

In [14]:
def bert_concept_alignment(tok2id, words, concepts):
    word_id2concept_id = [0] * len(words)
    for i in range(len(words)):
        word_id2concept_id[i] = i
    
    concept_id = 0
    s_idx = 0
    for word_id, word in enumerate(words):
        if concept_id == len(concepts):
            word_id2concept_id[word_id] = -1
            continue
        word_concept = concepts[concept_id][2]
        word_id2concept_id[word_id] = concept_id
        cur_span = ' '.join(words[s_idx:(word_id + 1)]).lower()
        
#         print('concept_idx={}, concept={}, word_idx={}, word={}, cur_span={}, head_i={}'.format(concept_id, word_concept, word_id, word, cur_span, word_id2concept_id[word_id]))
        if cur_span.lower() == word_concept.lower():
            concept_id += 1
            s_idx = word_id + 1
        elif not word_concept.lower().startswith((cur_span+' ')): #current word does not belong to any concept
            s_idx = word_id + 1
            word_id2concept_id[word_id] = -1

    for word_id, word in enumerate(words):
        concept_id = word_id2concept_id[word_id]
        concept = concepts[concept_id] if concept_id != -1 else 'NONE'
        print('word = {}, concept = {}'.format(word, concept))
    pass

def combine_berttoken_concept(tokens, concepts):
    
    tok2id = [0] * len(tokens)
    cur_indice = -1
    
    words = []
    cur_word = ''
    for idx, token in enumerate(tokens):
        if not token.startswith('##'):
            if cur_indice != -1:
                words.append(cur_word)
            cur_indice += 1
            cur_word = token
        else:
            cur_word += token[2:]
        tok2id[idx] = cur_indice
    words.append(cur_word)
    
    print(tokens, len(tokens))
    print(tok2id, len(tok2id))
    print(words, len(words))
    print(concepts, len(concepts))
    
    tok2id = bert_concept_alignment(tok2id, words, concepts)
    
combine_berttoken_concept(tokens, concepts)

['[CLS]', 'De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', 'a', '##bad', '##don', 'remained', 'the', 'team', "'", 's', 'starting', 'quarterback', 'for', 'the', 'rest', 'of', 'the', 'season', 'and', 'went', 'on', 'to', 'lead', 'the', '49ers', 'to', 'their', 'first', 'Super', 'Bowl', 'appearance', 'since', '.', '[SEP]', 'De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', '[SEP]', 'De', '##bon', '##ding', 'A', '##bad', '##don', 'de', '##bon', '##ding', 'a', '##bad', '##don', 'remained', '[SEP]'] 66
[0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 43] 66
['[CLS]', 'Debonding', 'Abaddon', 'debonding', 'abaddon', 'remained', 'the', 'team', "'", 's', 'starting', 'quarterback', 'for', 'the', 'rest', 'of', 'the', 'season', 'and', 'went', 'on', 'to', 'lead', 't

In [15]:
def do_need_append_token(span, concept):
    span = str(span).lower()
    word_concept = concept[2].lower() # get the ori word, not the concept
    
    if span == word_concept: # exact match, no need to append more token
        return False
    if not word_concept.startswith(span+' '): # current span is completly different with the concept. no need to append
        return False
    return True

def match_span_concept(span, concept):
    span = str(span).lower()
    word_concept = concept[2].lower() # get the ori word, not the concept
    
    if span == word_concept:
        return True
    return False

def inside_bound(idx, bound):
    return idx >= bound[0] and idx < bound[1]

def at_the_end_of_bound(idx, bound):
    return idx == bound[1] - 1

def bert_concept_alignment(tokens, c_concepts, e_concepts, claim_bound, evi_bound):
    n_token = len(tokens)
    concepts = c_concepts
    n_concept = len(concepts)
    
    token_id2concept_id = np.zeros((n_token), dtype=int)
    tok2id = np.zeros((n_token), dtype=int)
    merged_tokens = []
    for i in range(n_token):
        token_id2concept_id[i] = i
        tok2id[i] = i
        
    s_id = 0
    e_id = 0
    concept_id = 0
    current_span = ''
    
    tok_cur_id = 0
    while s_id != n_token:
        
        # if not in the claim/evi boundary
        if (not(inside_bound(s_id, claim_bound)) and
            not(inside_bound(s_id, evi_bound))):
            tok2id[s_id] = tok_cur_id
            tok_cur_id += 1
            
            # add token to merged_tokens if it is either claim, title, or evidence
            if s_id <= evi_bound[1]:
                merged_tokens.append(tokens[s_id])
            
            token_id2concept_id[s_id] = -1
            s_id += 1
            continue
        
        # reset concepts based on the claim/evi boundary
        if (s_id == claim_bound[0]):
            concept_id = 0
            current_span = ''
            concepts = c_concepts
            n_concept = len(concepts)
        elif (s_id == evi_bound[0]):
            concept_id = 0
            current_span = ''
            concepts = e_concepts
            n_concept = len(concepts)
        
        
        # process sub-word level
        next_token_id = s_id + 1
        current_span = tokens[s_id]
        while next_token_id < n_token and str(tokens[next_token_id]).startswith('##'):
            current_span += tokens[next_token_id][2:] # remove ##
            next_token_id += 1
        
        # let's see if combining next token will form a concept
        e_id = next_token_id
        while concept_id < n_concept and do_need_append_token(current_span, concepts[concept_id]) \
        and (at_the_end_of_bound(e_id, claim_bound) or at_the_end_of_bound(e_id, evi_bound)):
            current_span += (' ' + tokens[e_id])
            e_id += 1
        
        # if current span match with current concept
        if concept_id < n_concept and match_span_concept(current_span, concepts[concept_id]):
            token_id2concept_id[s_id:e_id] = concepts[concept_id][4]
            concept_id += 1
            
            tok2id[s_id:e_id] = tok_cur_id
            tok_cur_id += 1
            
            merged_tokens.append(current_span)
        else:
            token_id2concept_id[s_id:e_id] = -1
            
            tok2id[s_id:e_id] = tok_cur_id
            tok_cur_id += 1
            
            merged_tokens.append(current_span)
            
        s_id = e_id

    # merge same token into span
    final_tok2id = list(range(0, tok2id[-1] + 1))
    final_token_id2concept_id = []
    
    last_SEP = len(tokens) - 1 - tokens[::-1].index('[SEP]')
    input_masks = []
    segment_ids = []

    for idx in range(n_token):
        if idx == 0 or tok2id[idx] == -1 or tok2id[idx] != tok2id[idx - 1] and idx <= last_SEP:
            final_token_id2concept_id.append(token_id2concept_id[idx])
            input_masks.append(1)
            if inside_bound(idx, evi_bound):
                segment_ids.append(1)
            else:
                segment_ids.append(0)
    
    assert len(input_masks) == len(final_token_id2concept_id)
    assert len(segment_ids) == len(final_token_id2concept_id)
    if len(merged_tokens) != len(final_token_id2concept_id):
        import pdb
        pdb.set_trace()
    assert len(merged_tokens) == len(final_token_id2concept_id)
    
    # create token pooling mask to pool subword level to span level. we use average pooling
    token_pooling_mask = np.zeros((n_token, n_token), dtype=float)
    s_id = 0
    e_id = 0
    cur_id = 0
    while s_id != n_token:
        e_id = s_id + 1
        
        while e_id < n_token and tok2id[s_id] != -1 and tok2id[s_id] == tok2id[e_id]:
            e_id += 1

        n = e_id - s_id
        token_pooling_mask[s_id:e_id, cur_id] = (1/n)
        s_id = e_id
        cur_id += 1
    
    assert len(token_pooling_mask) == n_token

    return merged_tokens, final_token_id2concept_id, input_masks, segment_ids, token_pooling_mask

# merged_tokens, tok2concept, comb_input_mask, comb_segment_ids, tok_pool_mask = bert_concept_alignment(tokens, concepts, concepts, (c_start_id, c_end_id), (e_start_id, e_end_id))
asd = bert_concept_alignment(
    ['[CLS]', 'Tim', 'Rice', 'wrote', 'Joseph', 'and', 'the', 'Amazing', 'Tech', '##nic', '##olo', '##r', 'Dream', '##coat', 'with', 'David', 'Gil', '##mour', '.', '[SEP]', 'Tim', 'Rice', '[SEP]', 'He', 'is', 'best', 'known', 'for', 'his', 'collaborations', 'with', 'Andrew', 'Lloyd', 'Webber', ',', 'with', 'whom', 'he', 'wrote', 'Joseph', 'and', 'the', 'Amazing', 'Tech', '##nic', '##olo', '##r', 'Dream', '##coat', ',', 'Jesus', 'Christ', 'Super', '##star', ',', 'and', 'E', '##vi', '##ta', ';', 'with', 'B', '##j', '##ö', '##rn', 'U', '##l', '##va', '##eus', 'and', 'Benny', 'Anders', '##son', 'of', 'AB', '##BA', ',', 'with', 'whom', 'he', 'wrote', 'Chess', ';', 'for', 'additional', 'songs', 'for', 'the', '2011', 'West', 'End', 'revival', 'of', 'The', 'Wizard', 'of', 'Oz', ';', 'and', 'for', 'his', 'work', 'for', 'Walt', 'Disney', 'Studios', 'with', 'Alan', 'Men', '##ken', '(', 'Al', '##ad', '##din', ',', 'Beauty', 'and', 'the', 'Beast', ',', 'King', 'David', ')', ',', 'Elton', 'John', '(', 'The', 'Lion', '[SEP]'],
    [[0, 1, 'tim', 'tim', 52087], [1, 2, 'rice', 'rice', 13726], [2, 3, 'wrote', 'write', 13090], [3, 4, 'joseph', 'joseph', 145146], [6, 7, 'amazing', 'amazing', 237165], [7, 8, 'technicolor', 'technicolor', 202185], [10, 11, 'david', 'david',36445]],
    [[0, 1, 'he', 'mine', 19735], [1, 3, 'is best', 'be_well', 363724], [3, 4, 'known', 'known', 6796], [5, 6, 'his', 'mine', 19735], [6, 7, 'collaborations', 'collaboration', 55678], [8, 11, 'andrew lloyd webber', 'andrew_lloyd_webber', 261703], [14, 15, 'he', 'mine', 19735], [15, 16, 'wrote', 'write', 13090], [16, 17, 'joseph', 'joseph', 145146], [19, 20, 'amazing','amazing', 237165], [20, 21, 'technicolor', 'technicolor', 202185], [23, 25, 'jesus christ', 'jesus_christ', 69747], [25, 26,'superstar', 'superstar', 425501], [28, 29, 'evita', 'evita', 294555], [34, 35, 'benny', 'benny', 93883], [37, 38, 'abba', 'abba', 79992], [41, 42, 'he', 'mine', 19735], [42, 43, 'wrote', 'write', 13090], [43, 44, 'chess', 'chess', 19408], [46, 47, 'additional', 'additional', 383256], [47, 48, 'songs', 'song', 14447], [51, 53, 'west end', 'west_end', 360833], [53, 54, 'revival', 'revival', 339905], [56, 58, 'wizard of', 'wizard_of', 701408], [58, 59, 'oz', 'oz', 23965], [62, 63, 'his', 'mine', 19735], [63, 65, 'work for', 'work_for', 6120], [65, 67, 'walt disney', 'walt_disney', 360081], [69, 70, 'alan', 'alan', 31140], [72, 73, 'aladdin', 'aladdin', 255430], [74, 76, 'beauty and', 'beauty_and', 412667], [77, 78, 'beast', 'beast', 2035], [79, 81, 'king david', 'king_david', 552273], [83, 85, 'elton john', 'elton_john', 37620], [87, 89, 'lion king', 'lion_king', 316665], [90, 91, 'aida', 'aida', 259016], [93, 94, 'road', 'road', 5578], [95, 97, 'el dorado', 'el_dorado', 292117]],
    (1, 19),
    (23, 129)
)
print(asd[0])

['[CLS]', 'Tim', 'Rice', 'wrote', 'Joseph', 'and', 'the', 'Amazing', 'Technicolor', 'Dreamcoat', 'with', 'David', 'Gilmour', '.', '[SEP]', 'Tim', 'Rice', '[SEP]', 'He', 'is', 'best', 'known', 'for', 'his', 'collaborations', 'with', 'Andrew', 'Lloyd', 'Webber', ',', 'with', 'whom', 'he', 'wrote', 'Joseph', 'and', 'the', 'Amazing', 'Technicolor', 'Dreamcoat', ',', 'Jesus', 'Christ', 'Superstar', ',', 'and', 'Evita', ';', 'with', 'Björn', 'Ulvaeus', 'and', 'Benny', 'Andersson', 'of', 'ABBA', ',', 'with', 'whom', 'he', 'wrote', 'Chess', ';', 'for', 'additional', 'songs', 'for', 'the', '2011', 'West', 'End', 'revival', 'of', 'The', 'Wizard', 'of', 'Oz', ';', 'and', 'for', 'his', 'work', 'for', 'Walt', 'Disney', 'Studios', 'with', 'Alan', 'Menken', '(', 'Aladdin', ',', 'Beauty', 'and', 'the', 'Beast', ',', 'King', 'David', ')', ',', 'Elton', 'John', '(', 'The', 'Lion', '[SEP]']


In [16]:
asd = bert_concept_alignment(
    ['[CLS]', 'Nikola', '##j', 'Co', '##ster', '-', 'W', '##ald', '##au', 'worked', 'with', 'the', 'Fox', 'Broadcasting', 'Company', '.', '[SEP]', 'Fox', 'Broadcasting', 'Company', '[SEP]', 'The', 'Fox', 'Broadcasting', 'Company', '(', 'often', 'shortened', 'to', 'Fox', 'and', 'stylized', 'as', 'F', '##OX', ')', 'is', 'an', 'American', 'English', 'language', 'commercial', 'broadcast', 'television', 'network', 'that', 'is', 'owned', 'by', 'the', 'Fox', 'Entertainment', 'Group', 'subsidiary', 'of', '21st', 'Century', 'Fox', '.', '[SEP]', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [[1, 2, 'coster', 'coster', 458327], [4, 6, 'worked with', 'work_with', 281334], [7, 8, 'fox', 'fox', 13377], [8, 9, 'broadcasting', 'broadcasting', 82979], [9, 10, 'company', 'company', 3678]],
    [[1, 2, 'fox', 'fox', 13377], [2, 3, 'broadcasting', 'broadcasting', 82979], [3, 4, 'company', 'company', 3678], [5, 6, 'often', 'often', 9238], [6, 7, 'shortened', 'shorten', 6607], [8, 9, 'fox', 'fox', 13377], [10, 11, 'stylized', 'stylize', 198110], [12, 13, 'fox', 'fox', 13377], [16, 17, 'american', 'american', 7289], [17, 19, 'english language', 'english_language', 21226], [19, 20, 'commercial', 'commercial', 3636], [20, 21, 'broadcast', 'broadcast', 10260], [21, 23, 'television network', 'television_network', 51851], [28, 29, 'fox', 'fox', 13377], [29, 30, 'entertainment', 'entertainment', 21235], [30, 31, 'group', 'group', 4363], [31, 32, 'subsidiary', 'subsidiary', 198796], [34, 35, 'century', 'century', 3046], [35, 36, 'fox', 'fox', 13377]],
    (1,16),
    (21, 59)
)

In [52]:
import csv

def load_concept_vocab():
    vocab = []
    with open(config["paths"]["concept_vocab"], "r", encoding="utf8") as f:
        vocab = [l.strip() for l in list(f.readlines())]
    concept2id = {}
    id2concept = {}
    for indice, cp in enumerate(vocab):
        concept2id[cp] = indice
        id2concept[indice] = cp
    return concept2id, id2concept

def load_relation_vocab():
    vocab = []
    with open(config["paths"]["relation_vocab"], "r", encoding="utf8") as f:
        vocab = [l.strip() for l in list(f.readlines())]
    rel2id = {}
    id2rel = {}
    for indice, rel in enumerate(vocab):
        rel2id[rel] = indice
        id2rel[indice] = rel
    return rel2id, id2rel

def load_c2r_vocab():
    connections = {}
    direct_connections = {}
    reverse_direct_connections = {}
    with open(config["paths"]["c2r_vocab"], 'r') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter='\t')
        for indice, row in enumerate(csv_reader):
            rel, head, tail, weight = row
            connections[head] = tail
            if (head, tail) not in connections:
                connections[(head, tail)] = []
            connections[(head, tail)].append(rel)
            
            if head not in direct_connections:
                direct_connections[head] = set()
            direct_connections[head].add(tail)
            if tail not in reverse_direct_connections:
                reverse_direct_connections[tail] = set()
            reverse_direct_connections[tail].add(head)
    return connections, direct_connections, reverse_direct_connections

concept2id, id2concept = load_concept_vocab()
rel2id, id2rel = load_relation_vocab()
connections, direct_connections, reverse_direct_connections = load_c2r_vocab()

In [62]:
def _extract_concept_relation(c_source, c_target):
    if (c_source, c_target) in connections:
        return True, connections[(c_source, c_target)][0]
    
    source_neighbor_concepts = set()
    target_neighbor_concepts = set()
    
    if c_source in direct_connections:
        source_neighbor_concepts = set(direct_connections[c_source])
    if c_target in reverse_direct_connections:
        target_neighbor_concepts = set(reverse_direct_connections[c_target])
    
    for cs in source_neighbor_concepts:
        if cs in target_neighbor_concepts:
            rel1 = connections[(c_source, cs)][-1]
            rel2 = connections[(cs, c_target)][-1]
            
            if rel1 == 'antonym' or rel2 == 'antonym':
                return True, 'antonym'
            return True, rel2

    return False, None


def _make_connection(head_info, tail_info, bounds):
    """
    check if head and tail create an edge to the graph
    3 criteria:
        1. has relation w.r.t their concepts
        2. order of occurence in the sentences
        3. concept exact matching (TODO: can employ better matching)
    """
    blacklisted = ['[SEP]', '[CLS]']

    idx_head, head, c_head = head_info
    idx_tail, tail, c_tail = tail_info
    claim_bound, title_bound, evi_bound = bounds

    if head in blacklisted or tail in blacklisted:
        return False, None

    # based on concepts
    is_connect, connection = _extract_concept_relation(c_head, c_tail) 
    if is_connect:
        print(c_head, c_tail, connection)
        return is_connect, connection

    # based on occurence
    all_claims = inside_bound(idx_head, claim_bound) and inside_bound(idx_tail, claim_bound)
    all_titles = inside_bound(idx_head, title_bound) and inside_bound(idx_tail, title_bound)
    all_evis = inside_bound(idx_head, evi_bound) and inside_bound(idx_tail, evi_bound)
    if (idx_head == idx_tail - 1) and (all_claims or all_titles or all_evis):
        return True, 'occurence'

    if c_head and c_head == c_tail:
        return True, 'exact'

    return False, None

def construct_concept_graph(tokens, concepts, input_mask, concept_vocab, rel_vocab, c2r):
    concept2id, id2concept = concept_vocab
    rel2id, id2rel = rel_vocab
    n_token = len(tokens)
    
    bounds = [] # supposed to create 3 items. claim_bound, title_bound, evi_bound 
    s_id = 1
    e_id = 1
    while(e_id != n_token):
        if tokens[e_id] == '[SEP]':
            bounds.append((s_id, e_id))
            s_id = e_id + 1
        e_id += 1
    
    assert len(bounds) == 3
            
    head_indices = []
    tail_indices = []
    rel_ids = []
    for idx_head, head in enumerate(tokens):
        for idx_tail, tail in enumerate(tokens):
            if idx_head == idx_tail:
                continue
            c_head = None if concepts[idx_head] == -1 else id2concept[concepts[idx_head]]
            c_tail = None if concepts[idx_tail] == -1 else id2concept[concepts[idx_tail]]
            is_connect, conn = _make_connection((idx_head, head, c_head), (idx_tail, tail, c_tail), bounds)
            if is_connect:
                head_indices.append(idx_head)
                tail_indices.append(idx_tail)
                rel_ids.append(rel2id[conn] if conn in rel2id else -1)
                print(conn, rel_ids[-1])
    
    return head_indices, tail_indices, rel_ids

# merged_tokens, tok2concept, comb_input_mask, comb_segment_ids, tok_pool_mask = bert_concept_alignment(tokens, concepts, concepts, (c_start_id, c_end_id), (e_start_id, e_end_id))

# print(merged_tokens)
# print(tok2concept)
# print(comb_segment_ids)
# print(comb_input_mask)
# for token, inp in zip(merged_tokens, comb_segment_ids):
#     print(token, inp)

construct_concept_graph(merged_tokens, tok2concept, comb_input_mask, (concept2id, id2concept), (rel2id, id2rel), c2r)

occurence -1
occurence -1
abaddon abaddon relatedto
abaddon abaddon relatedto
relatedto 15
abaddon abaddon relatedto
abaddon abaddon relatedto
relatedto 15
abaddon abaddon relatedto
abaddon abaddon relatedto
relatedto 15
occurence -1
abaddon abaddon relatedto
abaddon abaddon relatedto
relatedto 15
occurence -1
abaddon abaddon relatedto
abaddon abaddon relatedto
relatedto 15
abaddon abaddon relatedto
abaddon abaddon relatedto
relatedto 15
occurence -1
start start antonym
remain start antonym
antonym 0
rest rest relatedto
remain rest relatedto
relatedto 15
remain remain relatedto
remain remain relatedto
relatedto 15
occurence -1
occurence -1
rest rest antonym
team rest antonym
antonym 0
occurence -1
occurence -1
occurence -1
rest rest antonym
start rest antonym
antonym 0
season season relatedto
start season relatedto
relatedto 15
team team isa
quarterback team isa
isa 5
occurence -1
rest rest antonym
quarterback rest antonym
antonym 0
occurence -1
occurence -1
remain remain relatedto
res

([1,
  2,
  2,
  2,
  2,
  3,
  4,
  4,
  4,
  4,
  5,
  5,
  5,
  5,
  6,
  7,
  7,
  8,
  9,
  10,
  10,
  10,
  11,
  11,
  11,
  12,
  13,
  14,
  14,
  14,
  14,
  14,
  15,
  16,
  17,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  44,
  45,
  45,
  45,
  45,
  46,
  47,
  47,
  47,
  47,
  48,
  48,
  48],
 [2,
  3,
  4,
  45,
  47,
  4,
  2,
  5,
  45,
  47,
  6,
  10,
  14,
  48,
  7,
  8,
  14,
  9,
  10,
  11,
  14,
  17,
  7,
  12,
  14,
  13,
  14,
  5,
  7,
  10,
  15,
  48,
  16,
  17,
  7,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  45,
  2,
  4,
  46,
  47,
  47,
  2,
  4,
  45,
  48,
  5,
  10,
  14],
 [-1,
  -1,
  15,
  15,
  15,
  -1,
  15,
  -1,
  15,
  15,
  -1,
  0,
  15,
  15,
  -1,
  -1,
  0,
  -1,
  -1,
  -1,
  0,
  15,
  5,
  -1,
  0,
  -1,
  -1,
  15,
  15,
  15,
  -1,
  1

In [532]:
tokens = ['[CLS]', 'Keith', 'Urban', 'is', 'a', 'person', 'who', 'sings', '.', '[SEP]', 'Keith', 'Urban', '[SEP]', 'The', 'album', "'", 's', 'fourth', 'single', ',', '"', 'You', "'", 'll', 'Think', 'of', 'Me', '"', ',', 'earned', 'him', 'his', 'first', 'Grammy', '.', '[SEP]', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
tok2id = [  0,   1,   2,   2,   3,   4,   5,   5,   6,   7,   8,   9,  10,
        11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  22,
        23,  24,  25,  26,  27,  27,  28,  29,  30,  31,  32,  33,  34,
        35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
        48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,
        61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,
        74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,
        87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99,
       100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
       113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]

merged_tokens = ['[CLS]', 'Keith', 'Urban is', 'a', 'person', 'who sings', '.', '[SEP]', 'Keith', 'Urban', '[SEP]', 'The', 'album', "'", 's','fourth', 'single', ',', '"', 'You', "'", 'll', 'Think of', 'Me', '"', ',', 'earned', 'him', 'his', 'first', 'Grammy', '.', '[SEP]']

token_id2concept_id = [    -1, 554573, 446492, 446492,     -1,    934, 672893, 672893,
           -1,     -1,     -1,     -1,     -1,     -1,  14296,     -1,
           -1,   4883,   3211,     -1,     -1,  19735,     -1,     -1,
       242208, 242208,  19735,     -1,     -1,   5893,  19735,  19735,
         3945, 132705,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1,     -1,     -1,     -1,     -1,     -1,     -1,
           -1,     -1]

n_token = 130
last_SEP = 35
prev = 0
i = 0
for idx in range(n_token):
#     print(prev, idx, tokens[idx], tok2id[idx])
    if (idx == 0 or tok2id[idx] == -1 or tok2id[idx] != tok2id[idx - 1]) and idx <= last_SEP:
#         print(' '.join(tokens[prev:idx+1]), merged_tokens[i])
        prev = idx+1
        i += 1
#         final_token_id2concept_id.append(token_id2concept_id[idx])
#         input_masks.append(1)
#         if inside_bound(idx, evi_bound):
#             segment_ids.append(1)
#         else:
#             segment_ids.append(0)

prev = 0
i = 0
for idx in range(n_token):
    if idx == n_token - 1 or tok2id[idx] != tok2id[idx + 1]:
        print(' '.join(tokens[prev:idx + 1]), ':', merged_tokens[i])
        prev = idx + 1
        i += 1

[CLS] : [CLS]
Keith : Keith
Urban is : Urban is
a : a
person : person
who sings : who sings
. : .
[SEP] : [SEP]
Keith : Keith
Urban : Urban
[SEP] : [SEP]
The : The
album : album
' : '
s : s
fourth : fourth
single : single
, : ,
" : "
You : You
' : '
ll : ll
Think of : Think of
Me : Me
" : "
, : ,
earned : earned
him his : him
first : his
Grammy : first
. : Grammy
[SEP] : .


TypeError: sequence item 0: expected str instance, int found

In [540]:
print(tokens[30:32])
print(tok2id[30:32])
print(token_id2concept_id[30:32])

['him', 'his']
[27, 27]
[19735, 19735]


In [547]:
asd = bert_concept_alignment(
    ['[CLS]', 'Keith', 'Urban', 'is', 'a', 'person', 'who', 'sings', '.', '[SEP]', 'Keith', 'Urban', '[SEP]', 'The', 'album', "'", 's', 'fourth', 'single', ',', '"', 'You', "'", 'll', 'Think', 'of', 'Me', '"', ',', 'earned', 'him', 'his', 'first', 'Grammy', '.', '[SEP]', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [[0, 1, 'keith', 'keith', 554573], [1, 3, 'urban is', 'urban_are', 446492], [4, 5, 'person', 'person', 934], [5, 7, 'who sings', 'who_sings', 672893]],
    [[1, 2, 'album', 'album', 14296], [3, 4, 'fourth', 'fourth', 4883], [4, 5, 'single', 'single', 3211], [8, 9, 'you', 'mine', 19735], [10, 12, 'think of', 'think_of', 242208], [12, 13, 'me', 'mine', 19735], [15, 16, 'earned', 'earn', 5893], [16, 17, 'him', 'mine', 19735], [17, 18, 'his', 'mine', 19735], [18, 19, 'first', 'first', 3945], [19, 20, 'grammy', 'grammy', 132705]],
    (1, 9),
    (13, 35)
)
print(asd[0])

[['[CLS]'], ['Keith'], ['Urban'], ['is', 'a'], ['person'], ['who'], ['sings', '.'], ['[SEP]'], ['Keith'], ['Urban'], ['[SEP]'], ['The'], ['album'], ["'"], ['s'], ['fourth'], ['single'], [','], ['"'], ['You'], ["'"], ['ll'], ['Think'], ['of', 'Me'], ['"'], [','], ['earned'], ['him'], ['his'], ['first'], ['Grammy'], ['.'], ['[SEP]']]
