<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import time
import argparse
import numpy as np
import os
import torch
from collections import defaultdict
from torch.utils.data import DataLoader
from pytorch_transformers import *
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import AffinityPropagation
from sklearn.cluster import SpectralCoclustering
from tqdm import tqdm
from random import sample
from pytorch_transformers import *
from pytorch_transformers.modeling_bert import *
from utils import *


In [2]:

def load_seed(dataset, file):
    topic_words = {}
    with open(dataset+'/result_'+file+'.txt') as f:
        data=f.readlines()
        current_topic = ''
        for line in data:
            if len(line.strip()) == 0:
                current_topic = ''
                continue
            elif len(line.split(' ')) == 1:
                current_topic = line.split(':')[0]
                continue
            elif current_topic != '':
                topic_words[current_topic] = line.strip().split(' ')
            # else:
                # print(line)

    return topic_words

def get_emb(vec_file):
    f = open(vec_file, 'r')
    contents = f.readlines()[1:]
    word_emb = {}
    vocabulary = {}
    vocabulary_inv = {}
    emb_mat = []
    for i, content in enumerate(contents):
        content = content.strip()
        tokens = content.split(' ')
        word = tokens[0]
        vec = tokens[1:]
        vec = [float(ele) for ele in vec]
        word_emb[word] = np.array(vec)
        vocabulary[word] = i
        vocabulary_inv[i] = word
        emb_mat.append(np.array(vec))
    vocab_size = len(vocabulary)
    emb_mat = np.array(emb_mat) 
    return word_emb, vocabulary, vocabulary_inv, emb_mat

def get_temb(vec_file, topic_file):
    topic2id = {}
    topic_emb = {}
    id2topic = {}
    topic_hier = {}
    i = 0
    with open(topic_file, 'r') as f:
        for line in f:
            parent = line.strip().split('\t')[0]
            temp = line.strip().split('\t')[1]          
            for topic in temp.split(' '):
                topic2id[topic] = i
                id2topic[i] = topic
                i += 1
                if parent not in topic_hier:
                    topic_hier[parent] = []
                topic_hier[parent].append(topic)
    f = open(vec_file, 'r')
    contents = f.readlines()[1:]
    for i, content in enumerate(contents):
        content = content.strip()
        tokens = content.split(' ')
        vec = tokens
        vec = [float(ele) for ele in vec]
        topic_emb[id2topic[i]] = np.array(vec)
    return topic_emb, topic2id, id2topic, topic_hier

def get_cap(vec_file, cap0_file=None):
    print(vec_file)
    f = open(vec_file, 'r')
    contents = f.readlines()[1:]
    word_cap = {}
    for i, content in enumerate(contents):
        content = content.strip()
        tokens = content.split(' ')
        word = tokens[0]
        vec = tokens[1]
        vec = float(vec)
        word_cap[word] = vec
    
    if cap0_file is not None:
        with open(cap0_file) as f:
            contents = f.readlines()[1:]
            for i, content in enumerate(contents):
                content = content.strip()
                tokens = content.split(' ')
                word = tokens[0]
                vec = tokens[1]
                vec = float(vec)
                word_cap[word] = vec 
    return word_cap

def topic_sim(query, idx2word, t_emb, w_emb):
    if query in t_emb:
        q_vec = t_emb[query]
    else:
        q_vec = w_emb[query]
    word_emb = np.zeros((len(idx2word), 100))
    for i in range(len(idx2word)):
        word_emb[i] = w_emb[idx2word[i]]
    res = np.dot(word_emb, q_vec)
    res = res/np.linalg.norm(word_emb, axis=1)
    sort_id = np.argsort(-res)

    return sort_id

def rank_cap(cap, idx2word, class_name):
    word_cap = np.zeros(len(idx2word))
    for i in range(len(idx2word)):
        if idx2word[i] in cap:
            word_cap[i] = (cap[idx2word[i]]-cap[class_name]) ** 2
        else:
            word_cap[i] = np.array([1.0])
    low2high = np.argsort(word_cap)
    return low2high

def rank_cap_customed(cap, idx2word, class_idxs):
    target_cap = np.mean([cap[idx2word[ind]] for ind in class_idxs])
    word_cap = np.zeros(len(idx2word))
#     print(f"target capacity: {target_cap}")
    for i in range(len(idx2word)):
        if idx2word[i] in cap:
            word_cap[i] = (cap[idx2word[i]]-target_cap) ** 2
#             tmp = cap[idx2word[i]]-target_cap
#             if tmp > 0:
#                 word_cap[i] = tmp
#             else:
#                 word_cap[i] = np.inf
        else:
            word_cap[i] = np.array([1.0])
    low2high = np.argsort(word_cap)
#     print("intermediate: "+ str(word_cap[vocabulary['data_mining']]))
#     print(low2high[0:5])
    return low2high, target_cap

def aggregate_ranking(sim, cap, word_cap, topic, idx2word, pretrain, target=None):
    simrank2id = np.ones(len(sim)) * np.inf
    caprank2id = np.ones(len(sim)) * np.inf
    for i, w in enumerate(sim[:]):
        simrank2id[w] = i + 1
#     print(f'topic capacity:{word_cap[topic]}')
#     print(f'target capcity: {target}')
    for i, w in enumerate(cap):
        if pretrain == 0:
            if target is not None and word_cap[idx2word[w]] > target:
                caprank2id[w] = i + 1
            if target is None:
                caprank2id[w] = i + 1
    if pretrain == 0:        
        agg_rank = simrank2id * caprank2id
        final_rank = np.argsort(agg_rank)
        final_rank_words = [idx2word[idx] for idx in final_rank[:500] if idx2word[idx] in ent_sent_index]
    else:
        agg_rank = simrank2id
        final_rank = np.argsort(agg_rank)
        final_rank_words = [idx2word[idx] for idx in final_rank[:500] if idx2word[idx] in ent_sent_index]
#     print(f'\n{topic} ranking list:')
#     print([caprank2id[idx] for idx in sim[:20]])
#     print([simrank2id[idx] for idx in final_rank[:20]])
#     print([caprank2id[idx] for idx in final_rank[:20]])
#     print(final_rank_words)
#     f = open(out_file, 'a')
#     f.write(f'\n{topic}:\n')
#     f.write(' '.join(final_rank_words) + '\n')    
    return final_rank_words

In [3]:
# processing training and test data
# Basically use entity pairs to find their co-occurred sentences in the corpus
# and then generate positive and negative training samples

def process_training_data(rep_words, topic_hier, max_seq_length):

    parent_list = [x for x in topic_hier if x != 'ROOT']
    sentences_index = []
    final_data = []

    real_rep_words = {}
    for key in rep_words:
        real_rep_words[key] = []
        
    print("collecting positive samples!")
    
    for parent in parent_list:
        for child in topic_hier[parent]:
#             print(child)
            count = 10
            for b in rep_words[child]:
                if b not in ent_sent_index:
                    continue
                for a in rep_words[parent]:
                    if a not in ent_sent_index:
                        continue
                    cooccur = ent_sent_index[a].intersection(ent_sent_index[b])
                    if len(cooccur) > 0:
                        if a not in real_rep_words[parent]:
                            real_rep_words[parent].append(a)
                        if b not in real_rep_words[child]:
                            real_rep_words[child].append(b) 
                        print(a)
                        print(b)
                        for sen in cooccur:
#                             if sen in sentences_index:
#                                 continue
                            sentences_index.append(sen)
                            s = sentences[sen]
                            s = '[CLS] '+s
                            s = s.split(' ')
                            if a not in s or b not in s:
                                continue
                            p_index = s.index(a)
                            c_index = s.index(b)
                            s[p_index] = '[MASK]'
                            s[c_index] = '[MASK]'
                            s = ' '.join(s).replace('_', ' ').replace('-lrb-','(').replace('-rrb-',')').split(' ')
                            input_id = tokenizer.encode(' '.join(s))
                            tokened_text = tokenizer.tokenize(' '.join(s))
                            mask_id = [x for x in range(len(input_id)) if input_id[x]==103]  
#                             if len(mask_id) < 2:
#                                 print(' '.join(s))
#                                 print(a)
#                                 print(b)

                            if len(input_id) > max_seq_length:
                                if mask_id[1] - mask_id[0] >= max_seq_length:
                                    continue
                                else:
                                    input_id = input_id[mask_id[0]:mask_id[1]+1]
                                    p_index = 0 if p_index<c_index else mask_id[1] - mask_id[0]
                                    c_index = 0 if c_index<p_index else mask_id[1] - mask_id[0]
                            else:
                                p_index = mask_id[0] if p_index<c_index else mask_id[1] 
                                c_index = mask_id[0] if c_index<p_index else mask_id[1]
                            
                            input_id = torch.tensor(input_id)
                            r_sentence = F.pad(torch.tensor(input_id),(0,max_seq_length-len(input_id)), "constant", 0)
                            attention_mask = torch.cat((torch.ones_like(input_id), torch.zeros(max_seq_length-len(input_id), dtype=torch.int64)),dim=0)
                            p_mask = np.zeros((max_seq_length))
                            p_mask[p_index] = 1
                            c_mask = np.zeros((max_seq_length))
                            c_mask[c_index] = 1
                                            
                            final_data.append([r_sentence, p_mask, c_mask, attention_mask, 0])
                            final_data.append([r_sentence, c_mask, p_mask, attention_mask, 1])
                            
#                             if count > 0:
#                                 print(a)
#                                 print(b)
#                                 print(r_sentence)
#                                 print(tokened_text)
#                                 print(p_index)
#                                 print(p_mask)
#                                 print(c_index)
#                                 print(c_mask)
#                                 print(attention_mask)
#                                 print('\n')
#                                 count -= 1
                                
    pos_len = len(final_data)
    print(f"positive data number: {pos_len}")
                                
    print("collecting negative samples from siblings!")  
    for parent in parent_list:
        for child in topic_hier[parent]:
#             print(child)
            count = 10
            for b in rep_words[child]:
                if b not in ent_sent_index:
                    continue
                for a in rep_words[child]:
                    if a == b:
                        continue
                    if a not in ent_sent_index:
                        continue
                    cooccur = ent_sent_index[a].intersection(ent_sent_index[b])
                    if len(cooccur) > 0:
                        for sen in cooccur:
                            if sen in sentences_index:
                                continue
                            if np.random.random(1) > 0.1:
                                continue
#                             sentences_index.append(sen)
                            s = sentences[sen]
                            s = '[CLS] '+s
                            s = s.split(' ')
                            if a not in s or b not in s:
                                continue
                            p_index = s.index(a)
                            c_index = s.index(b)
                            s[p_index] = '[MASK]'
                            s[c_index] = '[MASK]'
                            s = ' '.join(s).replace('_', ' ').replace('-lrb-','(').replace('-rrb-',')').split(' ')
                            input_id = tokenizer.encode(' '.join(s))
                            mask_id = [x for x in range(len(input_id)) if input_id[x]==103]                                                       

                            if len(input_id) > max_seq_length:
                                if mask_id[1] - mask_id[0] >= max_seq_length:
                                    continue
                                else:
                                    input_id = input_id[mask_id[0]:mask_id[1]+1]
                                    p_index = 0 if p_index<c_index else mask_id[1] - mask_id[0]
                                    c_index = 0 if c_index<p_index else mask_id[1] - mask_id[0]
                            else:
                                p_index = mask_id[0] if p_index<c_index else mask_id[1] 
                                c_index = mask_id[0] if c_index<p_index else mask_id[1]
                            
                            input_id = torch.tensor(input_id)
                            r_sentence = F.pad(torch.tensor(input_id),(0,max_seq_length-len(input_id)), "constant", 0)
                            attention_mask = torch.cat((torch.ones_like(input_id), torch.zeros(max_seq_length-len(input_id), dtype=torch.int64)),dim=0)
                            p_mask = np.zeros((max_seq_length))
                            p_mask[p_index] = 1
                            c_mask = np.zeros((max_seq_length))
                            c_mask[c_index] = 1
                                            
                            final_data.append([r_sentence, p_mask, c_mask, attention_mask, 2])
                            

    print(len(final_data))
    
    print("collecting negative samples from corpus!")
    while len(final_data) < pos_len * 2:
#         if len(final_data) % 100 == 0:
#             print(pos_len)
#             print(len(final_data))
        sen = np.random.choice(len(sentences))
#         remove positive sentences
        if sen in sentences_index:
            continue
            
        s = sentences[sen]
        s = '[CLS] '+s
        s = sentences[sen].split(' ')
        
        #randomly choose pairs
        entities = [x for x in s if "_" in x]
        if len(entities) < 2:
            continue
        p_index, c_index = np.random.choice(len(entities),2)
        while c_index == p_index:
#             continue
            c_index = np.random.choice(len(entities))
            
        s[p_index] = '[MASK]'
        s[c_index] = '[MASK]'
        s = ' '.join(s).replace('_', ' ').replace('-lrb-','(').replace('-rrb-',')').split(' ')
        input_id = tokenizer.encode(' '.join(s))
        mask_id = [x for x in range(len(input_id)) if input_id[x]==103]                                                       

        if len(input_id) > max_seq_length:
            if mask_id[1] - mask_id[0] >= max_seq_length:
                continue
            else:
                input_id = input_id[mask_id[0]:mask_id[1]+1]
                p_index = 0 if p_index<c_index else mask_id[1] - mask_id[0]
                c_index = 0 if c_index<p_index else mask_id[1] - mask_id[0]
        else:
            p_index = mask_id[0] if p_index<c_index else mask_id[1] 
            c_index = mask_id[0] if c_index<p_index else mask_id[1]

        input_id = torch.tensor(input_id)
        r_sentence = F.pad(torch.tensor(input_id),(0,max_seq_length-len(input_id)), "constant", 0)
        attention_mask = torch.cat((torch.ones_like(input_id), torch.zeros(max_seq_length-len(input_id), dtype=torch.int64)),dim=0)

        p_mask = np.zeros((max_seq_length))
        p_mask[p_index] = 1
        c_mask = np.zeros((max_seq_length))
        c_mask[c_index] = 1

        final_data.append([r_sentence, p_mask, c_mask, attention_mask, 2])
        

    
    
    return final_data, sentences_index


def process_test_data(test_topic_rep_words, test_cand, max_seq_length):

    print("collecting positive samples!")
    final_data = []
    

    count = 10
    for b in test_topic_rep_words:
        if b not in ent_sent_index or b not in ename2embed_bert:
            continue
        for a in test_cand:
            if a not in ent_sent_index or a not in ename2embed_bert:
                continue
            if a == b:
                continue
            cooccur = ent_sent_index[a].intersection(ent_sent_index[b])
            if len(cooccur) > 0:
                for sen in cooccur:
#                     if sen in sentences_index:
#                         continue
#                     sentences_index.append(sen)
                    s = sentences[sen]
                    s = '[CLS] '+s
                    s = s.split(' ')
                    if a not in s or b not in s:
                        continue
                    p_index = s.index(a)
                    c_index = s.index(b)
                    s[p_index] = '[MASK]'
                    s[c_index] = '[MASK]'
                    s = ' '.join(s).replace('_', ' ').replace('-lrb-','(').replace('-rrb-',')').split(' ')
                    input_id = tokenizer.encode(' '.join(s))
                    mask_id = [x for x in range(len(input_id)) if input_id[x]==103]
#                     if len(mask_id) < 2:
#                         print(sentences[sen])
#                         print(a)
#                         print(b)

                    if len(input_id) > max_seq_length:
                        if mask_id[1] - mask_id[0] >= max_seq_length:
                            continue
                        else:
                            input_id = input_id[mask_id[0]:mask_id[1]+1]
                            p_index = 0 if p_index<c_index else mask_id[1] - mask_id[0]
                            c_index = 0 if c_index<p_index else mask_id[1] - mask_id[0]
                    else:
                        p_index = mask_id[0] if p_index<c_index else mask_id[1] 
                        c_index = mask_id[0] if c_index<p_index else mask_id[1]

                    input_id = torch.tensor(input_id)
                    r_sentence = F.pad(torch.tensor(input_id),(0,max_seq_length-len(input_id)), "constant", 0)
                    attention_mask = torch.cat((torch.ones_like(input_id), torch.zeros(max_seq_length-len(input_id), dtype=torch.int64)),dim=0)

                    p_mask = np.zeros((max_seq_length))
                    p_mask[p_index] = 1
                    c_mask = np.zeros((max_seq_length))
                    c_mask[c_index] = 1

                    final_data.append([r_sentence, p_mask, c_mask, attention_mask, a])
                    final_data.append([r_sentence, c_mask, p_mask, attention_mask, a])
                    

#                     if count > 0:
#                         print(r_sentence)
#                         print(p_index)
#                         print(c_index)
#                         print(attention_mask)
#                         print('\n')
#                         count -= 1

    
    
    return final_data

def generate_batch(batch):
    input_ids = torch.tensor([np.array(entry[0]) for entry in batch])
    entity1_mask = torch.tensor([entry[1] for entry in batch])
    entity2_mask = torch.tensor([entry[2] for entry in batch])
    attention_mask = torch.tensor([np.array(entry[3]) for entry in batch])
    labels = torch.tensor([entry[4] for entry in batch])
    return input_ids, entity1_mask, entity2_mask, attention_mask, labels 

def generate_test_batch(batch):
    input_ids = torch.tensor([np.array(entry[0]) for entry in batch])
    entity1_mask = torch.tensor([entry[1] for entry in batch])
    entity2_mask = torch.tensor([entry[2] for entry in batch])
    attention_mask = torch.tensor([np.array(entry[3]) for entry in batch])
    entity = [entry[4] for entry in batch]
    return input_ids, entity1_mask, entity2_mask, attention_mask, entity


In [4]:
# training/validation/testing function of Bert model
def train_func(sub_train_, BATCH_SIZE, optimizer, scheduler):
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
    for i, (input_ids, entity1_mask, entity2_mask, attention_mask, labels) in enumerate(data):
        optimizer.zero_grad()
        input_ids, entity1_mask, entity2_mask, attention_mask, labels = input_ids.to(device), entity1_mask.float().to(device), entity2_mask.float().to(device), attention_mask.to(device), labels.to(device)
        output, loss = model(input_ids, attention_mask=attention_mask, entity1_mask=entity1_mask, entity2_mask=entity2_mask, labels=labels)
        train_acc += (output.argmax(1) == labels).sum().item()
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        
    scheduler.step()
    
    return train_loss/len(sub_train_), train_acc/len(sub_train_)

def valid_func(data_):
    valid_loss = 0
    valid_acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for i, (input_ids, entity1_mask, entity2_mask, attention_mask, labels) in enumerate(data):
        with torch.no_grad():
            input_ids, entity1_mask, entity2_mask, attention_mask, labels = input_ids.to(device), entity1_mask.float().to(device), entity2_mask.float().to(device), attention_mask.to(device), labels.to(device)
            output, loss = model(input_ids, attention_mask=attention_mask, entity1_mask=entity1_mask, entity2_mask=entity2_mask, labels=labels)
            valid_acc += (output.argmax(1) == labels).sum().item()
            valid_loss += loss.item()
            

    return valid_loss / len(data_), valid_acc/len(data_)

def test(data_, BATCH_SIZE):
    logits = []
    entities = []
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_test_batch)
    for i, (input_ids, entity1_mask, entity2_mask, attention_mask, entity) in enumerate(data):
        entities.extend(entity)
        input_ids, entity1_mask, entity2_mask, attention_mask = input_ids.to(device), entity1_mask.float().to(device), entity2_mask.float().to(device), attention_mask.to(device)
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask, entity1_mask=entity1_mask, entity2_mask=entity2_mask)
            logits.extend(F.softmax(output, dim=1).cpu().numpy())
    return logits, entities 
        

In [5]:
# BertModel + Linear Classifier

import torch.nn as nn
import torch.nn.functional as F
from pytorch_transformers import *
from pytorch_transformers.modeling_bert import *

class RelationClassifier(BertPreTrainedModel):

    def __init__(self, config):
        super(RelationClassifier, self).__init__(config)
#         super().__init__(config)
        self.num_labels = 3

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size * 2, self.num_labels)
        self.apply(self.init_weights)

#         self.init_weights()


    def forward(
        self, 
        input_ids, 
        token_type_ids=None, 
        attention_mask=None, 
        entity1_mask=None, 
        entity2_mask=None, 
        labels=None):
        
        with torch.no_grad():
        
            encoded_layers = self.bert(input_ids, attention_mask)[0]
            batch_size, max_seq_length = entity1_mask.shape[0],entity1_mask.shape[1]

            diag_entity1_mask_ = []
            for i in range(batch_size):
                diag_entity1_mask_.append(torch.diag(entity1_mask[i]).cpu().numpy())
            diag_entity1_mask = torch.tensor(diag_entity1_mask_).cuda()

            diag_entity2_mask_ = []
            for i in range(batch_size):
                diag_entity2_mask_.append(torch.diag(entity2_mask[i]).cpu().numpy())
            diag_entity2_mask = torch.tensor(diag_entity2_mask_).cuda()

            # Concatenate two entity embedding      
            batch_entity1_emb = torch.matmul(diag_entity1_mask, encoded_layers).permute(0,2,1)
            batch_entity2_emb = torch.matmul(diag_entity2_mask, encoded_layers).permute(0,2,1)
            batch_entity_emb = torch.cat((batch_entity1_emb, batch_entity2_emb), dim=1)


            pooling = nn.MaxPool1d(kernel_size=max_seq_length, stride=1)
            entity_emb_output = pooling( batch_entity_emb ).squeeze()
            entity_emb_output = self.dropout(entity_emb_output)
        
        # Linear layer classifier
        logits = self.classifier(entity_emb_output)



        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return logits , loss
        else:
            return logits

        return outputs  # (loss), logits, (hidden_states), (attentions)


In [6]:
def kl_divergence(p,q):
    d = 0
    num_labels = len(p)
    for i in range(num_labels):
        d += p[i] * np.log(p[i]/q[i])
    return d

In [7]:
from matplotlib import pyplot as plt

def type_consistent(topic_word_dict0, ename2embed_bert, print_cls = False):
    topic_word_dict = {}
    all_words = []

    for topic in topic_word_dict0:
        topic_word_dict[topic] = []
        for ename in topic_word_dict0[topic]:
#             ename = ename.replace('_',' ')
            if ename in ename2embed_bert:
                topic_word_dict[topic].append(ename)
                all_words.append(ename)
                
    topics = list(topic_word_dict0.keys())

    all_words.extend([x for x in topics if x in ename2embed_bert])

    # all_words.extend(['POTENTIAL_PARENT_'+x for x in potential_parents if x in ename2embed_bert])
    # all_embed.extend([ename2embed_bert[x][0] for x in potential_parents if x in ename2embed_bert])

    all_children = []
    all_embed = []
    all_words_and_their_parents = []   
    for word in all_words:
        topic_count = 0
        word0 = word
        for topic in topic_word_dict:
            if word in topic_word_dict[topic]:
                word0 = '('+topic+')'+word0
                topic_count += 1
        if topic_count > 1:
            continue
        all_words_and_their_parents.append(word0)
        all_children.append(word)
        all_embed.append(ename2embed_bert[word][0])

    # AP
    clustering = AffinityPropagation().fit(all_embed)
    n_clusters = max(clustering.labels_) + 1
    clusters = {}
    singular_words = []
    for i in range(n_clusters):
        clusters[i] = [ all_words_and_their_parents[x] for x in range(len(clustering.labels_)) if clustering.labels_[x]==i]
        if print_cls:
            print(clusters[i])
        category_count = set()
        for word0 in clusters[i]:
            tmp = word0.split('(')
            for seg in tmp:
                tmp2 = seg.split(')')
                if len(tmp2) < 2:
                    continue
                category_count.add(tmp2[0])
            if len(category_count) > 1:
                break
#         print(len(category_count))
        if len(category_count) <= 1:
            singular_words.extend([all_words[x] for x in range(len(clustering.labels_)) if clustering.labels_[x]==i])

#     print(singular_words)

    new_topic_word_dict = {}
    for topic in topic_word_dict:
        new_topic_word_dict[topic] = []
        for ename in topic_word_dict[topic]:
            if ename not in singular_words:
#                 ename = ename.replace(' ','_')
                new_topic_word_dict[topic].append(ename)

    return new_topic_word_dict

def type_consistent_for_list(l, topic_word_dict0, ename2embed_bert, print_cls = False):
    if len(l) < 60:
        return l
    tmp_cluster = {}
    tmp_cluster['0'] = l
    for topic in topic_word_dict0:
        tmp_cluster[topic] = topic_word_dict0[topic]
    tmp_cluster = type_consistent(tmp_cluster, ename2embed_bert, print_cls)    
#     print(tmp_cluster['0'])
    l = tmp_cluster['0']
    return l
    
    
    

In [8]:
def relation_inference(test_data, BATCH_SIZE,mode='child'):
    logits, cand_entities = test(test_data, BATCH_SIZE)
    
    if(len(test_data)==0):
        return {},{}

    logits = np.array(logits)
    labels = np.argmax(logits, axis=1)
    test_num = len(labels)
    entity_ratio = {}
    entity_count = {}

    for i in range(int(test_num/2)):
        ent = cand_entities[2*i]
        l1 = labels[2*i]
        kl1 = kl_divergence([1/3,1/3,1/3],logits[2*i])
        l2 = labels[2*i + 1]
        kl2 = kl_divergence([1/3,1/3,1/3],logits[2*i+1])
        if ent not in entity_ratio:
            entity_ratio[ent] = np.zeros((3))
        entity_ratio[ent] += logits[2*i]
        entity_ratio[ent] += [logits[2*i+1][1], logits[2*i+1][0], logits[2*i+1][2]]
#         print(f'{cand_entities[2*i]} {logits[2*i]} {kl1} {kl2}')
        if kl1 > 0.5 and kl2 > 0.5:
            if mode == 'child':
                if l1 == 0 and l2 == 1 and logits[2*i][0]>0.7 and logits[2*i+1][1]>0.7:
#                     print(f'parent: {cand_entities[2*i]}')
                    if ent not in entity_count:
                        entity_count[ent] = np.zeros((3))
                    entity_count[ent][0] += 1
                elif l1 == 1 and l2 == 0 and logits[2*i][1]>0.7 and logits[2*i+1][0]>0.7:
                    if ent not in entity_count:
                        entity_count[ent] = np.zeros((3))
                    entity_count[ent][1] += 1
#                     print(f'child: {cand_entities[2*i]} {logits[2*i]} {kl1} {kl2}')
                elif l1 == 2 and l2 == 2:
                    if ent not in entity_count:
                        entity_count[ent] = np.zeros((3))
                    entity_count[ent][2] += 1
            elif mode == 'parent':
                if l1 == 0 and l2 == 1 :
        #             print(f'parent: {cand_entities[2*i]}')
                    if ent not in entity_count:
                        entity_count[ent] = np.zeros((3))
                    entity_count[ent][0] += 1
                elif l1 == 1 and l2 == 0 :
                    if ent not in entity_count:
                        entity_count[ent] = np.zeros((3))
                    entity_count[ent][1] += 1
    #                 print(f'child: {cand_entities[2*i]} {logits[2*i]} {kl1} {kl2}')
                elif l1 == 2 and l2 == 2:
                    if ent not in entity_count:
                        entity_count[ent] = np.zeros((3))
                    entity_count[ent][2] += 1
    #             print(f'no relation: {cand_entities[2*i]}')
    #         print(f"kl1: {kl1} kl2: {kl2}")
    return entity_ratio, entity_count

In [9]:
def sum_all_rel(test_topics, entity_count_alltopics, mode='child'):
    child_entities_count = defaultdict(list)
    entity_confidence = defaultdict(float)
    for test_topic in test_topics:
#         print(f"test topic: {test_topic}")
        entity_count = entity_count_alltopics[test_topic]

        for ent in entity_count:            
            ratio = entity_count[ent]/np.sum(entity_count[ent])
#             print(f"{ent}: parent: {ratio[0]} child: {ratio[1]} no relation: {ratio[2]}")
            if mode == 'child' and ratio[1] > 0.7:
                child_entities_count[test_topic].append(ent)
                entity_confidence[ent] += ratio[1]
            elif mode == 'parent' and ratio[0] > 0.7:
                child_entities_count[test_topic].append(ent)
                entity_confidence[ent] += ratio[0]
    for ent in entity_confidence:
        entity_confidence[ent] /= len(test_topics)
    entity_conf = sorted(entity_confidence.items(), key = lambda x: x[1], reverse = True)
#     print(entity_conf)
    return child_entities_count


In [10]:
from collections import defaultdict

def get_common_ent_for_list(l):

    parent_cand = set()

    for test_topic in l:
        if len(parent_cand) == 0:
            parent_cand = ent_ent_index[test_topic]
        else:
            parent_cand = parent_cand.intersection(ent_ent_index[test_topic])

    return parent_cand

def get_common_ent_for_list_with_dict(l,d):
    parent_result = set()
    for test_topic in l:
        if len(parent_result) == 0:
            parent_result = set(d[test_topic])
        else:
            parent_result = parent_result.intersection(set(d[test_topic]))

    return parent_result

def get_threshold_from_dict(d, thre):
    parent_result_entities = defaultdict(int)
    for topic in d:
        for ent in d[topic]:
            parent_result_entities[ent] += 1
#     print(parent_result_entities)
    parent_result_entities = [x for x in parent_result_entities if parent_result_entities[x] >= len(d)*thre]
#     print(parent_result_entities)
    return parent_result_entities


In [11]:
# cluster by type, and delete 

from sklearn.cluster import SpectralCoclustering
def type_consistent_cocluster(topic_word_dict0, ename2embed_bert, n_cluster_min, print_cls = False, save_file=None):
    topic_word_dict = {}
    all_words = []

    for topic in topic_word_dict0:
        topic_word_dict[topic] = []
        for ename in topic_word_dict0[topic]:
            if ename in ename2embed_bert:
                topic_word_dict[topic].append(ename)
                all_words.append(ename)
                
    topics = list(topic_word_dict0.keys())
#     print("topics")
#     print(topics)

    all_children = [x for x in all_words]
#     all_words.extend([x for x in topics if x in ename2embed_bert])
    all_embed = [ename2embed_bert[x][0] for x in all_words]
#     print(all_children)


    all_words_and_their_parents = []
    for word in all_words:
        for topic in topic_word_dict:
            if word in topic_word_dict[topic]:
                word0 = (topic, word)
                break
        all_words_and_their_parents.append(word0)
#     print(all_words_and_their_parents)
        

    # AP
    clustering = AffinityPropagation().fit(all_embed)
    n_clusters = max(clustering.labels_) + 1
    clusters = {}
    col_vectors = np.zeros( (len(topic_word_dict) ,n_clusters), dtype=float)
    type_dict = {}
    for i in range(n_clusters):
        clusters[i] = [ all_words_and_their_parents[x] for x in range(len(clustering.labels_)) if clustering.labels_[x]==i]
        if print_cls:
            print(i)
            print(clusters[i])
            for ind, w in clusters[i]:
                type_dict[w] = str(i)
        for word0 in clusters[i]:
            word0_col = int(word0[0])
            col_vectors[word0_col,i] = 1
    col_vectors = np.array(col_vectors)
    col_vectors += 0.1*np.ones( (len(topic_word_dict) ,n_clusters), dtype=int)
    
    
    for n_cluster in range(n_cluster_min, n_cluster_min+10):
    
        model = SpectralCoclustering(n_clusters=n_cluster, random_state=0)
        model.fit(col_vectors)

        new_topic_word_dict = {}
#         print('row cluster result:', model.row_labels_)
#         print('col cluster result:', model.column_labels_)
        coverage_list = []
        for ind in range(n_cluster):
#             print(ind)
            small_matrix = col_vectors[[x for x in range(len(model.row_labels_)) if model.row_labels_[x] == ind]]
            small_matrix = small_matrix[:,[x for x in range(len(model.column_labels_)) if model.column_labels_[x] == ind]]
            coverage_list.append(np.sum(small_matrix)/np.sum(np.ones_like(small_matrix)))
        if max(coverage_list) >= 0.7:
            break
            
    fit_data = col_vectors[np.argsort(model.row_labels_)]
    fit_data = fit_data[:, np.argsort(model.column_labels_)]
    
    cluster_count = [sum(model.row_labels_==x) for x in range(n_cluster)]
#     print("row cluster count: ", cluster_count)
    
    cluster_count = [sum(model.column_labels_==x) for x in range(n_cluster)]
#     print("column cluster count: ", cluster_count)

    
    coverage_thre = min(max(coverage_list), 0.4)
#     print('coverage: ',coverage_list)
    
    for ind in range(n_cluster):
        if coverage_list[ind] <= coverage_thre:
#             print("del cluster ",ind)
            continue
        for topic in topic_word_dict:
            if model.row_labels_[int(topic)] == ind:
                new_topic_word_dict[topic] = [x for x in topic_word_dict[topic]]

    if print_cls:
        pass
#         if clustering.labels_[int(topic)]==0:
#             new_topic_word_dict[topic] = [x for x in topic_word_dict[topic]]
#         else:
#             print('del ',topic_word_dict[topic])

    return new_topic_word_dict



In [13]:
# select corpus
dataset = 'yelp'
file='des'
topic_file='topics_des.txt'


In [14]:

# ent_sent_index.txt: record the sentence id where each entity occurs; used for generating BERT training sample
print('loading corpus!')
ent_sent_index = dict()
with open(dataset+'/ent_sent_index.txt') as f:
    for line in f:
        ent = line.split('\t')[0]
        tmp = line.strip().split('\t')[1].split(' ')
        tmp = [int(x) for x in tmp]
        ent_sent_index[ent] = set(tmp)

# sentences_.txt: sentence id to text
sentences = dict()
with open(dataset+'/sentences_.txt') as f:
    for i,line in enumerate(f):
        sentences[i] = line
        
ent_ent_index = dict()
with open(dataset+'/ent_ent_index.txt') as f:
    for line in f:
        ent = line.split('\t')[0]
        tmp = line.strip().split('\t')[1].split(' ')
        ent_ent_index[ent] = set(tmp)
        

# ename2embed_bert = loadEnameEmbedding(os.path.join(dataset, 'BERTembed_5.txt'), 768)
# # ename2embed_bert = loadEnameEmbedding(os.path.join(dataset, 'bert_word_emb.txt'), 768)


print('finish loading corpus!')

loading corpus!
finish loading corpus!


In [15]:
pretrain=0
# load word embedding
word_emb, vocabulary, vocabulary_inv, emb_mat = get_emb(vec_file=os.path.join(dataset, 'emb_part_'+file + '_w.txt'))

# load topic embedding
topic_emb, topic2id, id2topic, topic_hier = get_temb(vec_file=os.path.join(dataset, 'emb_part_'+file+'_t.txt'), topic_file=os.path.join(dataset, topic_file))

# load word specificity
word_cap = get_cap(vec_file=os.path.join(dataset, 'emb_part_'+file+'_cap.txt'))

# load bert embedding
ename2embed_bert = loadEnameEmbedding(os.path.join(dataset, 'BERTembed.txt'), 768)

# calculate topic representative words: rep_words
rep_words = {}
for topic in topic_emb:
    print(topic)
    sim_ranking = topic_sim(topic, vocabulary_inv, topic_emb, word_emb)
    if pretrain:
        cap_ranking = np.ones((len(vocabulary)))
        word_cap1 = np.ones((len(vocabulary)))
    else:
        cap_ranking = rank_cap(word_cap, vocabulary_inv, topic)    
    rep_words[topic] = aggregate_ranking(sim_ranking, cap_ranking, word_cap, topic, vocabulary_inv, pretrain)
rep_words1 = {}
for topic in topic_emb:
    rep_words1[topic] = [x for x in rep_words[topic]]


yelp/emb_part_des_cap.txt
seafood
burger
salad
dessert
ice_cream
cake
pastries


In [16]:

for word in rep_words:
    rep_words[word] = [word]
print(rep_words)

{'seafood': ['seafood'], 'burger': ['burger'], 'salad': ['salad'], 'dessert': ['dessert'], 'ice_cream': ['ice_cream'], 'cake': ['cake'], 'pastries': ['pastries']}


In [17]:
# Relation Statement Classifier

import time
from torch.utils.data import DataLoader
from pytorch_transformers import *
import torch.nn as nn
import torch.nn.functional as F

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

BATCH_SIZE=16
TEST_BATCH_SIZE=512
EPOCHS = 5
max_seq_length = 128

print(torch.cuda.device_count())
# config = BertConfig.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model = BertForSequenceClassification(config)
# model = BertModel(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RelationClassifier.from_pretrained('bert-base-uncased')
model.float()
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.5)




1


In [18]:
# generating training data
total_data, sentences_index = process_training_data(rep_words, topic_hier, max_seq_length)
train_data = total_data[:int(len(total_data)/2*0.95)]
train_data.extend(total_data[int(len(total_data)/2):int(len(total_data)/2+len(total_data)/2*0.95)])
valid_data = total_data[int(len(total_data)/2*0.95):int(len(total_data)/2)]
valid_data.extend(total_data[int(len(total_data)/2*0.95+len(total_data)/2):])



collecting positive samples!
dessert
ice_cream
dessert
cake
dessert
pastries
positive data number: 2650
collecting negative samples from siblings!
2650
collecting negative samples from corpus!


In [19]:
print(f"training data point number: {len(train_data)}")

training data point number: 5034


In [20]:
# training the bert classifier

for epoch in range(EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train_func(train_data, BATCH_SIZE, optimizer, scheduler)
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    
    valid_loss, valid_acc = valid_func(valid_data)
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
    
    
    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs)) 

	Loss: 0.0163(train)	|	Acc: 90.4%(train)
	Loss: 0.0123(valid)	|	Acc: 91.4%(valid)
Epoch: 1  | time in 0 minutes, 45 seconds
	Loss: 0.0091(train)	|	Acc: 94.6%(train)
	Loss: 0.0094(valid)	|	Acc: 94.4%(valid)
Epoch: 2  | time in 0 minutes, 45 seconds
	Loss: 0.0076(train)	|	Acc: 95.8%(train)
	Loss: 0.0094(valid)	|	Acc: 93.2%(valid)
Epoch: 3  | time in 0 minutes, 45 seconds
	Loss: 0.0070(train)	|	Acc: 96.1%(train)
	Loss: 0.0092(valid)	|	Acc: 92.9%(valid)
Epoch: 4  | time in 0 minutes, 45 seconds
	Loss: 0.0068(train)	|	Acc: 96.2%(train)
	Loss: 0.0091(valid)	|	Acc: 93.6%(valid)
Epoch: 5  | time in 0 minutes, 45 seconds


In [21]:
# Depth Expansion: get the relation classification result of test data
# train_topic='dessert'
entity_ratio_alltopics = {}
entity_count_alltopics = {}

training_topic = {}
for topic in topic_hier['ROOT']:
    if topic in topic_hier:
        training_topic[topic] = word_cap[topic]
training_topic = sorted(training_topic.items(), key = lambda x: x[1])
train_topic = training_topic[0][0]

for test_topic in topic_hier['ROOT']:
    
    sim_ranking = topic_sim(test_topic, vocabulary_inv, topic_emb, word_emb)
    cap_ranking, target_cap = rank_cap_customed(word_cap, vocabulary_inv, [vocabulary[word] for word in topic_hier[train_topic]])
    coefficient = max(word_cap[test_topic] / word_cap[train_topic],1)
    test_cand = aggregate_ranking(sim_ranking, cap_ranking, word_cap, test_topic, vocabulary_inv, pretrain, target_cap*coefficient)
    print(f'topic: {test_topic}')
    test_data = process_test_data(rep_words[test_topic], test_cand, max_seq_length)
    print(f"test data point number: {len(test_data)}")
#     if len(test_data) > 10000:
#         test_data = test_data[:10000]

    entity_ratio, entity_count = relation_inference(test_data, TEST_BATCH_SIZE)
    entity_ratio_alltopics[test_topic] = entity_ratio
    entity_count_alltopics[test_topic] = entity_count

child_entities_count = sum_all_rel(topic_hier['ROOT'], entity_count_alltopics, mode='child')

# print(child_entities_count)
    

topic: seafood
collecting positive samples!
test data point number: 6076
topic: burger
collecting positive samples!
test data point number: 16956
topic: salad
collecting positive samples!
test data point number: 49954
topic: dessert
collecting positive samples!
test data point number: 21254


In [22]:
child_entities = type_consistent(child_entities_count, ename2embed_bert)
# child_entities=child_entities_count
from sklearn.cluster import AffinityPropagation

clusters_all = {}
k=0
start_list = [0]
for j,topic in enumerate(topic_hier['ROOT']): 
    if topic not in child_entities:
        continue
    X = []
    
    for ent in child_entities[topic]:
        if ent not in word_emb:
            continue
        X.append(word_emb[ent])
    X = np.array(X)
#     print(X)
#     if len(X) == 0:
#         print(len(X))
#         continue
    clustering = AffinityPropagation().fit(X)
    n_clusters = max(clustering.labels_) + 1
    clusters = {}
    for i in range(n_clusters):
        clusters[str(i)] = [child_entities[topic][x] for x in range(len(clustering.labels_)) if clustering.labels_[x] == i]
        
        clusters_all[str(k)] = clusters[str(i)]
        k+=1
    start_list.append(k)
#     new_clusters = type_consistent_col(clusters, ename2embed_bert)
    
#     i=0
#     for k in new_clusters: 
# #         print(clusters[k])
#         if len(new_clusters[k])>1:
#             print(i,':', end='\t')
#             print(new_clusters[k])
#             i += 1
# print(clusters_all)    
# new_clusters = type_consistent_col(clusters_all, ename2embed_bert)
# print(clusters_all)
new_clusters = type_consistent_cocluster(clusters_all, ename2embed_bert, n_cluster_min = 2, print_cls=False)

# print(start_list)

tmp = defaultdict(list)

topic_idx = 0
for k in range(len(clusters_all)):
    if k >= start_list[topic_idx]:
        print('\n',topic_hier['ROOT'][topic_idx])
        topic_idx += 1
    if str(k) in new_clusters:# and len(new_clusters[str(k)]) > 1:
        print(new_clusters[str(k)])
        print('')
        tmp[topic_hier['ROOT'][topic_idx-1]].append(new_clusters[str(k)])

child_entities = tmp
    
    

  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))



 seafood
['king_crab', 'shellfish', 'shrimps', 'oyster', 'snow_crab_legs', 'mussel', 'crabs', 'lobsters', 'crab_cakes', 'fresh_oysters', 'raw_oysters', 'crab_claws', 'rockefeller', 'chowder', 'oyster_bar', 'crabmeat', 'maine', 'shell', 'shells', 'snow', 'seafood_section', 'smoked_salmon', 'fried_oysters', 'stone_crab', 'peel', 'jumbo', 'snails', 'alligator', 'gazpacho']

['squid', 'ocean', 'wasabi', 'roe', 'miso', 'ponzu', 'spicy_sauce', 'marinade', 'nori']

['sea_bass', 'fishes', 'snapper', 'cod', 'toro', 'tuna_tartar', 'quail', 'seared', 'mackerel', 'sea_urchin', 'kobe_beef', 'carpaccio', 'hamachi', 'butterfish', 'grilled_squid', 'diver', 'sea_scallops', 'freshness', 'ahi_tuna', 'courses', 'seared_tuna', 'fin']

['crawfish', 'dungeness_crab', 'catfish', 'grits', 'etouffee', 'sampler', 'pan', 'sausages']

['fish_balls', 'saut_ed', 'pork_chops', 'veg', 'snow_peas', 'carrots', 'bean_sprouts', 'poultry']

['skewers', 'kimchi', 'bulgogi', 'galbi', 'fusion', 'overly_salty']

['sashimi', '

1


In [23]:
# Root Node Candidate Generation!


parent_cand = get_common_ent_for_list(topic_hier['ROOT'])
if len(parent_cand) > 1000:
    parent_cand = type_consistent_for_list(parent_cand, rep_words, ename2embed_bert, False)
print(len(parent_cand))
print(parent_cand)


217
['flavors', 'replacement', 'app', 'dosa', 'carne_asada', 'southwestern', 'pan', 'main_ingredient', 'bun', 'flatbread', 'portion_sizes', 'sprouts', 'cookie', 'dates', 'dipping_sauces', 'cheeseburger', 'appetizer', 'burrito', 'deserts', 'lunch_special', 'cocktail', 'combo', 'sampler', 'sake', 'roast', 'treats', 'goods', 'coconut', 'hundreds', 'soy', 'stated', 'summer', 'typical', 'charcuterie', 'shops', 'thinking', 'watering', 'ground', 'curious', 'skimp', 'fried_egg', 'seasonal', 'adam', 'spinach', 'called', 'burgers', 'bonus', 'amount', 'herb', 'hair', 'vinegar', 'essence', 'none', 'olive', 'scotch_egg', 'guests', 'chicken_wings', 'parmigiana', 'depending', 'hype', '100', 'locations', 'pretty_solid', 'thursday_night', 'fact', 'chopped_salad', 'hook', 'husband_loves', 'beans', 'sweet_sour', 'american', 'days', 'smoked_salmon', 'yonge', 'lola', 'event', 'sun', 'dreams', 'tangy', 'member', 'venetian', 'border', 'art', 'double', 'fly', 'description', 'leftovers', 'snacks', 'sour_cream'

In [24]:
parent_entity_ratio_alltopics = {}
parent_entity_count_alltopics = {}
from random import sample


for test_topic in topic_hier['ROOT']:
    print(f'test topic: {test_topic}')
    
    test_data = process_test_data([test_topic], list(parent_cand), max_seq_length)
    print(f"test data point number: {len(test_data)}")
    
#     if len(test_data) > 10000:
#         test_data = sample(test_data, 10000)
    
    
    entity_ratio, entity_count = relation_inference(test_data, TEST_BATCH_SIZE,mode='child')
    parent_entity_ratio_alltopics[test_topic] = entity_ratio
    parent_entity_count_alltopics[test_topic] = entity_count

parent_entities_count = sum_all_rel(topic_hier['ROOT'], parent_entity_count_alltopics, mode='parent')
parent_result = get_threshold_from_dict(parent_entities_count, 1/2)
# print(len(parent_result))

parent_result = type_consistent_for_list(parent_result, rep_words, ename2embed_bert, False)
print(parent_result)

    

test topic: seafood
collecting positive samples!
test data point number: 4996
test topic: burger
collecting positive samples!
test data point number: 12986
test topic: salad
collecting positive samples!
test data point number: 17558
test topic: dessert
collecting positive samples!
test data point number: 7048
['southwestern', 'lunch_special', 'roast', 'seasonal', 'guests', 'carb', 'entr_es', 'salad_bar', 'sushi_roll', 'buffet', 'recipes', 'fried_tofu', 'ambiance', 'plating', 'generous_portion', 'vegetarians', 'carne_asada', 'summer', 'lola', 'dreams', 'leftovers', 'ramen', 'barbecue']


In [25]:
# print('------------------New topic finding!------------------')

topic_cand = defaultdict(int)
for topic in parent_result:
    for ent in ent_ent_index[topic]:
        topic_cand[ent] += 1
# topic_cand = sorted(topic_cand.items(), key = lambda x: x[1], reverse=True)
topic_cand = [x for x in topic_cand if topic_cand[x] >= len(parent_result)/2]

remove_list = []
for topic in child_entities_count:
    remove_list.extend(child_entities_count[topic])
# print(remove_list)
remove_list.extend(parent_result)

tmp = []
for topic in topic_cand:
    if topic not in remove_list:
        tmp.append(topic)
topic_cand = tmp

topic_entity_ratio_alltopics = {}
topic_entity_count_alltopics = {}


for test_topic in parent_result:
    print(f'test topic: {test_topic}')
    
    test_data = process_test_data([test_topic], list(topic_cand), max_seq_length)
    print(f"test data point number: {len(test_data)}")
    if len(test_data) > 10000:
        test_data = sample(test_data, 10000)   
    
    entity_ratio, entity_count = relation_inference(test_data, TEST_BATCH_SIZE,mode='child')
    topic_entity_ratio_alltopics[test_topic] = entity_ratio
    topic_entity_count_alltopics[test_topic] = entity_count

topic_entities_count = sum_all_rel(parent_result, topic_entity_count_alltopics, mode='child')

# print(topic_entities_count)


topic_entities = get_threshold_from_dict(topic_entities_count, 1/3)
cap_list = [word_cap[x] for x in topic_hier['ROOT']]
print([(x, word_cap[x]) for x in topic_entities if x in word_cap])
topic_entities = [x for x in topic_entities if word_cap[x] < max(cap_list)*1.0 and word_cap[x] > min(cap_list)*1.0]
for t in topic_hier['ROOT']:
    if t in topic_hier:
        for t1 in topic_hier[t]:
            if t1 in topic_entities:
                topic_entities.remove(t1)
print(topic_entities)

# topic_entities = [x for x in topic_entities if word_cap[x] < max(cap_list) and word_cap[x] > min(cap_list)]
# topic_entities = type_consistent_for_list(topic_entities, rep_words, ename2embed_bert, False)
# print(topic_entities)
for t in topic_hier['ROOT']:
    if t in topic_hier:
        for t1 in topic_hier[t]:
            if t1 in topic_entities:
                topic_entities.remove(t1)
    for t1 in child_entities[t]:
        if t1 in topic_entities:
            topic_entities.remove(t1)
print(topic_entities)

    

test topic: southwestern
collecting positive samples!
test data point number: 876
test topic: lunch_special
collecting positive samples!
test data point number: 5572
test topic: roast
collecting positive samples!
test data point number: 1968
test topic: seasonal
collecting positive samples!
test data point number: 3550
test topic: guests
collecting positive samples!
test data point number: 9922
test topic: carb
collecting positive samples!
test data point number: 1138
test topic: entr_es
collecting positive samples!
test data point number: 986
test topic: salad_bar
collecting positive samples!
test data point number: 5778
test topic: sushi_roll
collecting positive samples!
test data point number: 1302
test topic: buffet
collecting positive samples!
test data point number: 32112
test topic: recipes
collecting positive samples!
test data point number: 2180
test topic: fried_tofu
collecting positive samples!
test data point number: 958
test topic: ambiance
collecting positive samples!
tes

In [26]:
topic_hier1 = {}

topic_hier1['ROOT']= topic_entities
for topic in topic_hier:
    if topic == 'ROOT':
        for t in topic_hier[topic]:
            if t not in topic_hier1[topic]:
                topic_hier1[topic].append(t)
    else:
        topic_hier1[topic] = [x for x in topic_hier[topic]]
        
entity_ratio_alltopics1 = {}
entity_count_alltopics1 = {}

for test_topic in topic_hier1['ROOT']:
    if test_topic in topic_hier['ROOT']:
        entity_ratio_alltopics1[test_topic] = entity_ratio_alltopics[test_topic]
        entity_count_alltopics1[test_topic] = entity_count_alltopics[test_topic]
        continue
    
    sim_ranking = topic_sim(test_topic, vocabulary_inv, topic_emb, word_emb)
    cap_ranking, target_cap = rank_cap_customed(word_cap, vocabulary_inv, [vocabulary[word] for word in topic_hier[train_topic]])
    coefficient = max(word_cap[test_topic] / word_cap[train_topic],1)
    test_cand = aggregate_ranking(sim_ranking, cap_ranking, word_cap, test_topic, vocabulary_inv, pretrain, target_cap*coefficient)

    test_data = process_test_data([test_topic], test_cand, max_seq_length)
    print(f"test data point number: {len(test_data)}")
    
#     if len(test_data) > 10000:
#         test_data=test_data[:10000]

    entity_ratio, entity_count = relation_inference(test_data, TEST_BATCH_SIZE)
    entity_ratio_alltopics1[test_topic] = entity_ratio
    entity_count_alltopics1[test_topic] = entity_count

    

collecting positive samples!
test data point number: 6406
collecting positive samples!
test data point number: 21444
collecting positive samples!
collecting positive samples!
test data point number: 13898
collecting positive samples!
test data point number: 21344
collecting positive samples!
test data point number: 23296
collecting positive samples!
test data point number: 29354
collecting positive samples!
test data point number: 24366
collecting positive samples!
test data point number: 5866
collecting positive samples!
test data point number: 12046
collecting positive samples!
test data point number: 23326
collecting positive samples!
test data point number: 8708
collecting positive samples!
test data point number: 15468
collecting positive samples!
test data point number: 17924
collecting positive samples!
test data point number: 19322
collecting positive samples!
test data point number: 8006
collecting positive samples!
test data point number: 5668
collecting positive samples!
tes

In [29]:
# subtopic finding for new topics!
topic_hier1 = {}

topic_hier1['ROOT']= topic_entities
for topic in topic_hier:
    if topic == 'ROOT':
        for t in topic_hier[topic]:
            if t not in topic_hier1[topic]:
                topic_hier1[topic].append(t)
    else:
        topic_hier1[topic] = [x for x in topic_hier[topic]]
# print(topic_hier)
print(topic_hier1)
save_tree_to_file(topic_hier1, 'intermediate.txt')

{'ROOT': ['red', 'fries', 'cheese', 'bacon', 'burger', 'spicy', 'steak', 'beef', 'sandwich', 'cut', 'fish', 'pizza', 'extra', 'coffee', 'bread', 'hot', 'bowl', 'white', 'shrimp', 'veggies', 'pork', 'dry', 'rice', 'seafood', 'salad', 'dessert'], 'dessert': ['ice_cream', 'cake', 'pastries']}


In [30]:
entity_ratio_alltopics1 = {}
entity_count_alltopics1 = {}

for test_topic in topic_hier1['ROOT']:
    if test_topic in topic_hier['ROOT']:
        entity_ratio_alltopics1[test_topic] = entity_ratio_alltopics[test_topic]
        entity_count_alltopics1[test_topic] = entity_count_alltopics[test_topic]
        continue
    
    sim_ranking = topic_sim(test_topic, vocabulary_inv, topic_emb, word_emb)
    cap_ranking, target_cap = rank_cap_customed(word_cap, vocabulary_inv, [vocabulary[word] for word in topic_hier[train_topic]])
    coefficient = max(word_cap[test_topic] / word_cap[train_topic],1)
    test_cand = aggregate_ranking(sim_ranking, cap_ranking, word_cap, test_topic, vocabulary_inv, pretrain, target_cap*coefficient)
    print(f'test topic: {test_topic}')  
    test_data = process_test_data([test_topic], test_cand, max_seq_length)
    print(f"test data point number: {len(test_data)}")
    
#     if len(test_data) > 10000:
#         test_data=test_data[:10000]

    entity_ratio, entity_count = relation_inference(test_data, TEST_BATCH_SIZE)
    entity_ratio_alltopics1[test_topic] = entity_ratio
    entity_count_alltopics1[test_topic] = entity_count
    
    
child_entities_count1 = sum_all_rel(topic_hier1['ROOT'], entity_count_alltopics1, mode='child')

child_entities1 = type_consistent(child_entities_count1, ename2embed_bert)

from sklearn.cluster import AffinityPropagation

clusters_all = {}
k=0
start_list = [0]
for j,topic in enumerate(topic_hier1['ROOT']):   
    
    X = []
    for ent in child_entities1[topic]:
        if ent not in word_emb:
            continue
        X.append(word_emb[ent])
    X = np.array(X)
    if len(X) == 0:
        print(topic)
        continue

    clustering = AffinityPropagation().fit(X)
    n_clusters = max(clustering.labels_) + 1
    clusters = {}
    for i in range(n_clusters):
        clusters[str(i)] = [child_entities1[topic][x] for x in range(len(clustering.labels_)) if clustering.labels_[x] == i]
        
        clusters_all[str(k)] = clusters[str(i)]
        k+=1
    start_list.append(k)
print('-----')
print(k)
print('-----')
#     new_clusters = type_consistent_col(clusters, ename2embed_bert)
    
#     i=0
#     for k in new_clusters: 
# #         print(clusters[k])
#         if len(new_clusters[k])>1:
#             print(i,':', end='\t')
#             print(new_clusters[k])
#             i += 1
# print(clusters_all)    
# new_clusters = type_consistent_col(clusters_all, ename2embed_bert)
new_clusters = type_consistent_cocluster(clusters_all, ename2embed_bert, n_cluster_min = 2, print_cls = True, save_file='dblp_field+_cls8')

print(start_list)

tmp = defaultdict(list)

topic_idx = 0
for k in range(len(clusters_all)):
    if k >= start_list[topic_idx]:
#         print('\n',topic_hier1['ROOT'][topic_idx])
        topic_idx += 1
    if str(k) in new_clusters and len(new_clusters[str(k)]) > 1:
#         print(new_clusters[str(k)])
        tmp[topic_hier1['ROOT'][topic_idx-1]].append(new_clusters[str(k)])

child_entities1 = tmp
for t in topic_hier['ROOT']:
    child_entities1[t] = child_entities[t]
    
for t in topic_hier1['ROOT']:
    print(t)
    for l in child_entities1[t]:
        print(l)
    print('')

    

test topic: red
collecting positive samples!
test data point number: 6406
test topic: fries
collecting positive samples!
test data point number: 21444
test topic: cheese
collecting positive samples!
test data point number: 45000
test topic: bacon
collecting positive samples!
test data point number: 13898
test topic: spicy
collecting positive samples!
test data point number: 21344
test topic: steak
collecting positive samples!
test data point number: 23296
test topic: beef
collecting positive samples!
test data point number: 29354
test topic: sandwich
collecting positive samples!
test data point number: 24366
test topic: cut
collecting positive samples!
test data point number: 5866
test topic: fish
collecting positive samples!
test data point number: 12046
test topic: pizza
collecting positive samples!
test data point number: 23326
test topic: extra
collecting positive samples!
test data point number: 8708
test topic: coffee
collecting positive samples!
test data point number: 15468
tes

[('101', 'lettuce'), ('101', 'lettuce'), ('101', 'lettuce'), ('101', 'lettuce'), ('101', 'lettuce'), ('230', 'romaine_lettuce'), ('230', 'romaine_lettuce'), ('101', 'lettuce'), ('230', 'romaine_lettuce'), ('311', 'romaine'), ('101', 'lettuce')]
172
[('64', 'wonton'), ('201', 'lentil'), ('64', 'wonton'), ('201', 'lentil')]
173
[('72', 'msg'), ('204', 'grease'), ('204', 'grease'), ('237', 'dust'), ('72', 'msg'), ('204', 'grease'), ('275', 'crumbs')]
174
[('13', 'sticks'), ('18', 'rings'), ('13', 'sticks')]
175
[('15', 'wraps'), ('17', 'dips'), ('34', 'benedicts'), ('45', 'grilled_cheese_sandwiches'), ('114', 'bowls'), ('15', 'wraps'), ('17', 'dips'), ('148', 'bento_boxes'), ('181', 'dressings'), ('114', 'bowls'), ('186', 'iced_coffees'), ('192', 'crepes'), ('194', 'snacks'), ('17', 'dips'), ('207', 'sandwiches'), ('207', 'salads'), ('221', 'platters'), ('271', 'burritos'), ('302', 'courses'), ('181', 'dressings'), ('192', 'crepes'), ('207', 'salads'), ('207', 'sandwiches'), ('347', 'main

  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))
  for c in range(self.n_clusters))


[0, 13, 21, 29, 39, 61, 77, 94, 108, 128, 141, 154, 176, 185, 196, 206, 215, 235, 244, 251, 263, 273, 282, 292, 311, 335, 377]
red
['bell_pepper', 'chilies', 'green_peppers', 'carrots', 'jalapeno', 'red_peppers', 'fresh_basil', 'cucumber', 'red_onion', 'green_onions', 'cashews', 'big_chunks', 'celery', 'scallions', 'cucumbers']
['olives', 'roasted_garlic', 'oregano', 'squash', 'artichokes', 'olive', 'mozzarella', 'gorgonzola', 'parsley', 'ricotta', 'pancetta']
['mole', 'burro', 'posole', 'tamales', 'ground_beef', 'chimi', 'tortilla_chips']
['tomatillo', 'chiles', 'tamarind', 'lime_juice', 'sprinkled', 'chilli', 'limes', 'peanuts', 'powder', 'chili_powder', 'chili_flakes', 'rim', 'shredded_cheese']

fries
['pickles', 'kettle_chips', 'thousand_island', 'dill', 'vinegar', 'mayonnaise', 'relish', 'strings', 'american_cheese', 'spuds', 'kraut', 'banana_peppers', 'dijon_mustard']

cheese
['caramelized_onions', 'blue_cheese', 'jalape_os', 'kraut', 'pepper_jack', 'gooey_cheese', 'pickles', 'mo

  for c in range(self.n_clusters))
  for c in range(self.n_clusters))


In [32]:
# print the keyword taxonomy, nodes in which will be enriched later by concept learning.
with open(os.path.join(dataset, 'keyword_taxonomy.txt'), 'w') as fout:
    for topic in topic_hier1['ROOT']:  
        if len(child_entities1[topic]) > 0:      
            fout.write(topic+'\n')
            for cls in child_entities1[topic]:
                fout.write(' '.join(cls)+'\n')
            fout.write('\n')

for topic in topic_hier1['ROOT']:
    if len(child_entities1[topic]) > 0:
        with open(os.path.join(dataset, 'topics_'+topic+'.txt'),'w') as fout:
            for cls in child_entities1[topic]:
                fout.write(' '.join(cls)+'\n')