In [1]:
import functools
import os
import pickle

from nltk.corpus import wordnet as wn
from nltk.stem.wordnet import WordNetLemmatizer

from run_bert_link_prediction import *

## Setup directories

## Parameters

In [2]:
# Basic parameters for evaluating a fine-tuned KG-BERT checkpoint on WordNet triples.
data_path = "./data/wn"                 # dataset directory (entity2text.txt, relation2text.txt, splits)
data_saved_path = "./output_wn_result"  # checkpoint directory passed to from_pretrained
bert_model = "bert-base-cased"          # name of the pretrained tokenizer vocabulary
task_name = "kg"
max_seq_length = 50
eval_batch_size = 1500

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Instantiate the data processor registered for this task.
processors = {"kg": KGProcessor}
processor = processors[task_name]()

# Binary classification labels and their count.
label_list = ["0", "1"]
num_labels = len(label_list)

# Entity IDs known to the dataset, as returned by the processor for data_path.
entity_list = processor.get_entities(data_path)

# Tokenizer of the base BERT model plus the fine-tuned classifier checkpoint.
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=False)
model = BertForSequenceClassification.from_pretrained(
    data_saved_path,
    num_labels=num_labels,
)
location_detail = model.to(device)

In [4]:
# Every known triple from the train / dev / test splits.
train_triples = processor.get_train_triples(data_path)
dev_triples = processor.get_dev_triples(data_path)
test_triples = processor.get_test_triples(data_path)
all_triples = train_triples + dev_triples + test_triples

# Tab-joined string form of each triple, for fast membership tests.
all_triples_str_set = {'\t'.join(triple) for triple in all_triples}

## KG-BERT Link Prediction

In [5]:
@functools.lru_cache(maxsize=None)
def _load_text_maps(data_dir):
    """Read the entity and relation description files under ``data_dir``.

    Returns:
        (ent2text, rel2text): dicts mapping entity IDs / relation IDs to their text.

    The result is memoized per ``data_dir`` because ``_create_new_examples`` is
    called twice per evaluated triple, and re-reading the files each time is pure
    overhead. The returned dicts are shared between calls and must be treated as
    read-only; call ``_load_text_maps.cache_clear()`` if the files change on disk.
    """
    ent2text = {}
    with open(os.path.join(data_dir, "entity2text.txt"), "r", encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split("\t")
            # Lines that are not exactly "<id>\t<text>" are skipped.
            if len(fields) == 2:
                ent2text[fields[0]] = fields[1]

    # FB15k-style datasets also ship longer descriptions that override the short ones.
    if "FB15" in data_dir:
        with open(os.path.join(data_dir, "entity2textlong.txt"), "r", encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split("\t")
                ent2text[fields[0]] = fields[1]

    rel2text = {}
    with open(os.path.join(data_dir, "relation2text.txt"), "r", encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split("\t")
            rel2text[fields[0]] = fields[1]

    return ent2text, rel2text


def _wordnet_corruptions(entity, entity_set):
    """Yield ``(corrupted_id, definition)`` for the other WordNet senses of ``entity``'s lemma.

    ``entity`` is an ID of the form ``"<prefix>:<synset name>"``. The lemma looked up
    in WordNet is the synset name up to its first ``"."``. The true synset itself is
    skipped. Corrupted IDs are emitted as ``"wn:" + synset_name``.

    When ``entity_set`` is not None, only candidates found in it are kept. Both the
    emitted ``"wn:"``-prefixed ID and the bare synset name count as a match.
    """
    synset_name = entity.split(":")[1]
    lemma = synset_name.split(".")[0]
    for synset in wn.synsets(lemma):
        name = synset.name()
        if name == synset_name:
            continue
        corrupted_id = "wn:" + name
        if entity_set is not None and corrupted_id not in entity_set and name not in entity_set:
            continue
        yield corrupted_id, synset.definition()


def _create_new_examples(lines, corup_loc="head", set_type="test", data_dir=data_path,
                         remove_entity=False, entity_list=entity_list):
    """Build KG-BERT examples for each triple plus WordNet-based corruptions of one side.

    For every ``(head, relation, tail)`` in ``lines`` the true triple is emitted first
    (label ``"1"``). It is followed by one negative example (label ``"0"``) per other
    WordNet sense that shares the corrupted entity's lemma. The corrupted entity's
    text is that sense's WordNet definition. All other texts come from
    ``entity2text.txt`` / ``relation2text.txt`` in ``data_dir``.

    Args:
        lines: iterable of ``[head_id, relation_id, tail_id]`` triples.
        corup_loc: ``"head"`` or ``"tail"``, the entity to corrupt.
        set_type: prefix of the example guids (``"<set_type>-<i>"``).
        data_dir: directory containing the description files.
        remove_entity: if True, keep only corruptions whose entity is in ``entity_list``.
        entity_list: known entity IDs, consulted only when ``remove_entity`` is True.

    Returns:
        (examples, corrupt_list): parallel lists of ``InputExample`` objects and the
        ``[head, relation, tail]`` triple each example was built from.

    Raises:
        ValueError: if ``corup_loc`` is neither ``"head"`` nor ``"tail"``.
        KeyError: if an entity or relation ID has no description text.
    """
    if corup_loc not in ("head", "tail"):
        raise ValueError("corup_loc must be 'head' or 'tail', got %r" % (corup_loc,))

    ent2text, rel2text = _load_text_maps(data_dir)
    # Set lookup is O(1); it is only built when filtering is requested.
    entity_set = set(entity_list) if remove_entity else None

    examples = []
    corrupt_list = []
    for i, line in enumerate(lines):
        head, relation, tail = line[0], line[1], line[2]
        head_text = ent2text[head]
        tail_text = ent2text[tail]
        relation_text = rel2text[relation]
        guid = "%s-%s" % (set_type, i)

        # The true triple precedes its own corruptions (index 0 when lines has one triple).
        corrupt_list.append(line)
        examples.append(InputExample(guid=guid, text_a=head_text, text_b=relation_text,
                                     text_c=tail_text, label="1"))

        if corup_loc == "head":
            for corrupted_head, definition in _wordnet_corruptions(head, entity_set):
                corrupt_list.append([corrupted_head, relation, tail])
                examples.append(InputExample(guid=guid, text_a=definition, text_b=relation_text,
                                             text_c=tail_text, label="0"))
        else:
            for corrupted_tail, definition in _wordnet_corruptions(tail, entity_set):
                corrupt_list.append([head, relation, corrupted_tail])
                examples.append(InputExample(guid=guid, text_a=head_text, text_b=relation_text,
                                             text_c=definition, label="0"))

    return examples, corrupt_list

def rank_accuracy(ranks, max_thresholds=500):
    """Cumulative hit rate of the true triple for each rank threshold.

    Args:
        ranks: iterable of 0-based ranks of the true triple (0 = ranked first).
        max_thresholds: maximum number of thresholds reported (thresholds 0 .. max_thresholds-1).

    Returns:
        dict mapping ``threshold`` to the fraction of ``ranks`` that are ``<= threshold``,
        for ``threshold`` in ``0 .. min(max(1, max(ranks)), max_thresholds - 1)``.
        ``result[k]`` is therefore Hits@(k+1). Thresholds with no hits map to 0.0.
        An empty ``ranks`` yields ``{}``.
    """
    ordered = sorted(ranks)
    total = len(ordered)
    if total == 0:
        return {}

    # Report thresholds up to the worst observed rank (at least 1), capped.
    last_threshold = min(max(1, int(ordered[-1])), max_thresholds - 1)

    accuracy_dict = {}
    hits = 0
    for threshold in range(last_threshold + 1):
        # Single sweep over the sorted ranks: count everything within the threshold.
        while hits < total and ordered[hits] <= threshold:
            hits += 1
        accuracy_dict[threshold] = hits / total
    return accuracy_dict

In [6]:
# load test data
def _rank_true_triple(triple, corup_loc):
    """Score ``triple`` and its ``corup_loc`` corruptions with ``model``; return the true triple's rank.

    The candidates come from ``_create_new_examples([triple], corup_loc=...)``, where
    the true triple is example 0. The rank is the 0-based position of example 0 when
    all candidates are sorted by descending score.
    """
    examples, _ = _create_new_examples([triple], corup_loc=corup_loc)
    features = convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, print_info=False)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    eval_dataloader = DataLoader(eval_data, sampler=SequentialSampler(eval_data), batch_size=eval_batch_size)

    batch_logits = []
    with torch.no_grad():
        for input_ids, input_mask, segment_ids in eval_dataloader:
            logits = model(input_ids.to(device), segment_ids.to(device), input_mask.to(device), labels=None)
            batch_logits.append(logits.detach().cpu().numpy())
    # One concatenation instead of np.append per batch (which copies the whole array each time).
    preds = np.concatenate(batch_logits, axis=0)

    # Score every candidate by the logit column selected by the true triple's label id.
    scores = torch.tensor(preds[:, int(all_label_ids[0])])
    # NOTE(review): tied scores are ordered by torch.sort, which is not guaranteed stable.
    _, order = torch.sort(scores, descending=True)
    return np.where(order.cpu().numpy() == 0)[0][0]


def kg_bert_prediction(triples):
    """Rank each triple against its WordNet-corrupted head and tail variants.

    Uses the global ``model``, ``tokenizer`` and settings. For every triple, the true
    triple is scored together with its head corruptions and, separately, with its
    tail corruptions.

    Args:
        triples: iterable of ``[head_id, relation_id, tail_id]`` triples.

    Returns:
        (ranks_left, ranks_right): per-triple 0-based ranks of the true triple among
        its head corruptions and among its tail corruptions (0 = ranked first).
    """
    # Put the model in evaluation mode once; it is only used for inference here.
    model.eval()
    ranks_left = []
    ranks_right = []
    for test_triple in tqdm(triples):
        ranks_left.append(_rank_true_triple(test_triple, "head"))
        ranks_right.append(_rank_true_triple(test_triple, "tail"))
    return ranks_left, ranks_right

**WordNet NLTK Train Prediction**

In [23]:
# Head ("left") and tail ("right") corruption ranks of every training triple under model I.
ranks_left_train1, rank_right_train1 = kg_bert_prediction(triples=train_triples)

100%|██████████| 8000/8000 [11:26<00:00, 11.65it/s]


In [24]:
accuracy_dict_left = rank_accuracy(ranks=ranks_left_train1)
# Hits@1 for head corruption: share of triples whose true head is ranked first.
accuracy_dict_left[0]

0.901125

In [25]:
accuracy_dict_right = rank_accuracy(ranks=rank_right_train1)
# Hits@1 for tail corruption: share of triples whose true tail is ranked first.
accuracy_dict_right[0]

0.8045

**WordNet NLTK DEV Prediction**

In [26]:
# Head and tail corruption ranks of every dev triple under model I.
ranks_left_dev1, rank_right_dev1 = kg_bert_prediction(triples=dev_triples)

100%|██████████| 1000/1000 [01:23<00:00, 12.03it/s]


In [27]:
accuracy_dict_left = rank_accuracy(ranks=ranks_left_dev1)
# Hits@1 for head corruption: share of triples whose true head is ranked first.
accuracy_dict_left[0]

0.858

In [28]:
accuracy_dict_right = rank_accuracy(ranks=rank_right_dev1)
# Hits@1 for tail corruption: share of triples whose true tail is ranked first.
accuracy_dict_right[0]

0.775

**WordNet NLTK TEST Prediction**

In [29]:
# Head and tail corruption ranks of every test triple under model I.
ranks_left_test1, rank_right_test1 = kg_bert_prediction(triples=test_triples)

100%|██████████| 1000/1000 [01:32<00:00, 10.85it/s]


In [30]:
accuracy_dict_left = rank_accuracy(ranks=ranks_left_test1)
# Hits@1 for head corruption: share of triples whose true head is ranked first.
accuracy_dict_left[0]

0.884

In [31]:
accuracy_dict_right = rank_accuracy(ranks=rank_right_test1)
# Hits@1 for tail corruption: share of triples whose true tail is ranked first.
accuracy_dict_right[0]

0.772

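## KG-BERT II Link Prediction

Repeat the same link-prediction evaluation with the second checkpoint (`./output_wn_result2`), which was trained on a filtered version of the training set.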

In [36]:
# Switch to the second fine-tuned checkpoint; the cells below reuse the same prediction helpers.
data_saved_path = "./output_wn_result2"
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=False)
model = BertForSequenceClassification.from_pretrained(
    data_saved_path,
    num_labels=num_labels,
)
location_detail = model.to(device)

**WordNet NLTK Train Prediction**

In [48]:
# Head and tail corruption ranks of every training triple under model II.
ranks_left_train2, rank_right_train2 = kg_bert_prediction(triples=train_triples)

100%|██████████| 8000/8000 [15:33<00:00,  8.57it/s]  


In [49]:
accuracy_dict_left = rank_accuracy(ranks=ranks_left_train2)
# Hits@1 for head corruption: share of triples whose true head is ranked first.
accuracy_dict_left[0]

0.9045

In [50]:
accuracy_dict_right = rank_accuracy(ranks=rank_right_train2)
# Hits@1 for tail corruption: share of triples whose true tail is ranked first.
accuracy_dict_right[0]

0.821125

**WordNet NLTK DEV Prediction**

In [37]:
# Head and tail corruption ranks of every dev triple under model II.
ranks_left_dev2, rank_right_dev2 = kg_bert_prediction(triples=dev_triples)

100%|██████████| 1000/1000 [01:32<00:00, 10.82it/s]


In [38]:
accuracy_dict_left = rank_accuracy(ranks=ranks_left_dev2)
# Hits@1 for head corruption: share of triples whose true head is ranked first.
accuracy_dict_left[0]

0.866

In [39]:
accuracy_dict_right = rank_accuracy(ranks=rank_right_dev2)
# Hits@1 for tail corruption: share of triples whose true tail is ranked first.
accuracy_dict_right[0]

0.778

**WordNet NLTK TEST Prediction**

In [40]:
# Head and tail corruption ranks of every test triple under model II.
ranks_left_test2, rank_right_test2 = kg_bert_prediction(triples=test_triples)

100%|██████████| 1000/1000 [01:30<00:00, 11.05it/s]


In [41]:
accuracy_dict_left = rank_accuracy(ranks=ranks_left_test2)
# Hits@1 for head corruption: share of triples whose true head is ranked first.
accuracy_dict_left[0]

0.879

In [42]:
accuracy_dict_right = rank_accuracy(ranks=rank_right_test2)
# Hits@1 for tail corruption: share of triples whose true tail is ranked first.
accuracy_dict_right[0]

0.78

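## Inspect model inputs and outputs on the test set

For each test triple, every scored candidate is written to `wn_TestResult_head.txt` (head corruptions) and `wn_TestResult_tail.txt` (tail corruptions).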

In [9]:
# Write a human-readable report of every scored candidate for each test triple.
# Head corruptions go to wn_TestResult_head.txt, tail corruptions to wn_TestResult_tail.txt.
triples = test_triples
ranks_left = []   # head-corruption rank of each test triple
ranks_right = []  # tail-corruption rank of each test triple
res = []          # raw logits per candidate list, alternating head / tail for each triple


def _write_prediction_report(out, test_triple, corup_loc):
    """Score ``test_triple`` against its ``corup_loc`` corruptions and write a report to ``out``.

    Args:
        out: writable text file receiving the report.
        test_triple: ``[head_id, relation_id, tail_id]`` triple to evaluate.
        corup_loc: ``"head"`` or ``"tail"``, the entity to corrupt.

    Returns:
        (corrupt_list, preds, rank): the candidate triples (true triple first), the
        model's logits for them, and the 0-based rank of the true triple.
    """
    out.write("The Actual Triple to Test: " + "\t".join(test_triple) + "\n")
    out.write("\n")

    examples, corrupt_list = _create_new_examples([test_triple], corup_loc=corup_loc)
    features = convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, print_info=False)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    eval_dataloader = DataLoader(eval_data, sampler=SequentialSampler(eval_data), batch_size=eval_batch_size)
    model.eval()

    batch_logits = []
    with torch.no_grad():
        for input_ids, input_mask, segment_ids in eval_dataloader:
            logits = model(input_ids.to(device), segment_ids.to(device), input_mask.to(device), labels=None)
            batch_logits.append(logits.detach().cpu().numpy())
    preds = np.concatenate(batch_logits, axis=0)
    # Score every candidate by the logit column selected by the true triple's label id.
    rel_values = torch.tensor(preds[:, int(all_label_ids[0])])

    out.write("-----Prediction Detail-----\n")
    for idx in range(len(rel_values)):
        example_ = examples[idx]
        input_ids = all_input_ids[idx].cpu().numpy()
        input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        out.write("Triple_to_Predict: {}\n".format("\t".join(corrupt_list[idx])))

        # text of the triple's three parts
        out.write("Text of Triples:{a}\t{b}\t{c}\n".format(a=example_.text_a,
                                                       b=example_.text_b,
                                                       c=example_.text_c))

        # input tokens & ids
        out.write("Input_Tokens: {}\n".format(",".join(input_tokens)))
        out.write("Input_Ids: {}\n".format(",".join([str(token_id) for token_id in input_ids])))
        # This candidate's own label id (indexed over all candidates, not a single batch).
        out.write("Label_id: {}\n".format(int(all_label_ids[idx])))
        out.write("\n")
        # output score
        out.write("Score of This Triple: {}\n".format(rel_values[idx]))
        out.write("\n")

    out.write("-----Prediction Final Result-----\n")
    _, order = torch.sort(rel_values, descending=True)
    order = order.cpu().numpy()
    rank = np.where(order == 0)[0][0]

    out.write("The rank of correct answer in the candidates list: {}\n".format(rank))
    out.write("The Predicted Triples:{}".format(corrupt_list[order[0]]))
    out.write("\n")
    out.write("############################################################\n")
    return corrupt_list, preds, rank


with open("wn_TestResult_head.txt", "w", encoding="utf-8") as f1, \
        open("wn_TestResult_tail.txt", "w", encoding="utf-8") as f2:
    for test_triple in tqdm(triples):
        head_corrupt_list, head_preds, head_rank = _write_prediction_report(f1, test_triple, "head")
        res.append(head_preds)
        ranks_left.append(head_rank)

        tail_corrupt_list, tail_preds, tail_rank = _write_prediction_report(f2, test_triple, "tail")
        res.append(tail_preds)
        # Tail ranks belong in ranks_right so the head/tail metrics stay separate.
        ranks_right.append(tail_rank)

100%|██████████| 1000/1000 [01:27<00:00, 11.41it/s]


In [None]:
# Scratch check: rebuild tail-corruption examples from the last triple's candidate list
# that the report cell above left in tail_corrupt_list.
_create_new_examples(lines=tail_corrupt_list, corup_loc="tail")