# Imports

In [104]:
import random
import os
import itertools
import json
import numpy as np
import pandas as pd
import spacy

import requests
import json

# Large English spaCy pipeline; its named-entity recogniser supplies the
# PERSON / ORG spans that the helper functions filter on.
_spacy_model_name = "en_core_web_lg"
nlp = spacy.load(_spacy_model_name)

# Config

In [101]:
# Directory roots shared by the paths below.
output_dirname = "output/"
_data_dirname = "data/"
_wiki80_dirname = output_dirname + "wiki80/"

# Raw sentence lists (one sentence per line), split by period and polarity.
input_past_pos_filepath = _data_dirname + "past_positively_associated.txt"
input_past_neg_filepath = _data_dirname + "past_negatively_associated.txt"
input_present_pos_filepath = _data_dirname + "present_positively_associated.txt"
input_present_neg_filepath = _data_dirname + "present_negatively_associated.txt"

# Per-relation records built from the raw sentence lists.
output_past_pos_filepath = output_dirname + "past_positively_associated.txt"
output_past_neg_filepath = output_dirname + "past_negatively_associated.txt"
output_present_pos_filepath = output_dirname + "present_positively_associated.txt"
output_present_neg_filepath = output_dirname + "present_negatively_associated.txt"

# Final wiki80-style train / validation files and the relation mapping.
train_data_filepath = _wiki80_dirname + "wiki80_train.txt"
val_data_filepath = _wiki80_dirname + "wiki80_val.txt"
relations_filepath = _wiki80_dirname + "wiki80_rel2id.json"

# Helper Functions

In [93]:
def get_files_in_dir(folder_name):
    """Return the paths of the regular files directly inside *folder_name*.

    Subdirectories are skipped and the listing is not recursive. Each path is
    built with ``os.path.join(folder_name, name)``.

    Args:
        folder_name: Directory to list.

    Returns:
        list[str]: File paths, sorted so the result does not depend on the
        arbitrary order in which ``os.listdir`` reports entries.
    """
    candidates = (os.path.join(folder_name, name) for name in os.listdir(folder_name))
    return sorted(path for path in candidates if os.path.isfile(path))

def read_file(filepath):
    """Read a UTF-8 text file and return its unique, non-empty lines.

    Every line is stripped of surrounding whitespace. Blank lines are dropped,
    and duplicates are removed while keeping the order of first occurrence.

    Args:
        filepath: Path of the text file to read.

    Returns:
        list[str]: The stripped, de-duplicated, non-empty lines.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        text = f.read()
    stripped = (line.strip() for line in text.split("\n"))
    # dict.fromkeys de-duplicates while preserving first-seen order; the
    # previous list(set(...)) produced a different order on every run because
    # of string hash randomisation. Blank lines are skipped so callers never
    # receive "" entries (e.g. from a trailing newline).
    return list(dict.fromkeys(line for line in stripped if line))

def write_file(data, filepath):
    """Write each item of *data* on its own line of *filepath* and report it.

    Items are converted with ``str()`` and joined with newlines, with no
    trailing newline. The file is written as UTF-8 and overwritten if it
    exists. A record count and the destination path are printed afterwards.

    Args:
        data: Sized iterable of items to write (``len`` is taken after writing).
        filepath: Destination path.
    """
    # Explicit UTF-8 so non-ASCII entity names round-trip regardless of the
    # platform's default locale encoding.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write("\n".join(str(item) for item in data))
    print("Number of records captured:", len(data))
    print("Data saved at:", filepath)

def get_wikicode(word, timeout=10):
    """Search Wikidata for *word* and return the id of the first hit.

    Uses the ``wbsearchentities`` action of the Wikidata API, searching in
    English.

    Args:
        word: Free-text label to search for.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        str | None: The ``id`` of the first search result, or ``None`` when
        the request fails, the response is not usable JSON, or there are no
        results.
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbsearchentities",
        "search": word,
        "language": "en",
        "format": "json",
    }
    try:
        # Without a timeout, requests can block indefinitely on a stalled connection.
        response = requests.get(url, params=params, timeout=timeout)
        response.raise_for_status()
        data = json.loads(response.text)
        return data["search"][0]["id"]
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError):
        # Best effort: network/HTTP errors, malformed JSON (ValueError) and an
        # empty or unexpected result all mean "no id". Unlike a bare except,
        # this no longer swallows KeyboardInterrupt / SystemExit.
        return None

def get_ents(doc):
    """Collect the PERSON and ORG entities found in a spaCy ``Doc``.

    Args:
        doc: Object whose ``ents`` are iterated (the result of ``nlp(text)``).

    Returns:
        list[dict]: One dict per PERSON/ORG entity, in ``doc.ents`` order, with
        keys ``entity`` (the entity object itself), ``entity_type`` (its
        ``label_``), and ``start`` / ``end`` (its ``start`` and ``end``
        attributes).
    """
    wanted_labels = ("PERSON", "ORG")
    matches = []
    for span in doc.ents:
        if span.label_ not in wanted_labels:
            continue
        matches.append({
            "entity": span,
            "entity_type": span.label_,
            "start": span.start,
            "end": span.end,
        })
    return matches

def get_single_input(text, relation=None):
    """Convert one sentence into a head/tail relation-extraction record.

    The sentence is kept only if ``get_ents`` finds exactly two entities, one
    PERSON and one ORG, and both names resolve to a Wikidata id through
    ``get_wikicode``. The PERSON becomes the head (``"h"``) and the ORG the
    tail (``"t"``).

    Args:
        text: The raw sentence.
        relation: Optional relation label, stored under ``"relation"`` when
            not ``None``.

    Returns:
        dict | None: A record with ``"token"`` (the token texts) and ``"h"`` /
        ``"t"`` dicts holding the lower-cased entity name, its Wikidata id and
        its ``[start, end]`` offsets, plus ``"relation"`` if given; ``None``
        when the sentence is rejected.
    """
    doc = nlp(text)
    nodes = get_ents(doc)  # computed once and reused for both checks

    # Require exactly one PERSON and exactly one ORG entity.
    if len(nodes) != 2 or {node["entity_type"] for node in nodes} != {"ORG", "PERSON"}:
        return None

    by_type = {node["entity_type"]: node for node in nodes}
    head = by_type["PERSON"]
    tail = by_type["ORG"]

    # Each lookup is a network request, so stop at the first failure instead
    # of issuing a second request whose result would be discarded anyway.
    source = str(head["entity"])
    source_wikicode = get_wikicode(source)
    if source_wikicode is None:
        return None

    target = str(tail["entity"])
    target_wikicode = get_wikicode(target)
    if target_wikicode is None:
        return None

    res = {
        "token": [token.text for token in doc],
        "h": {"name": source.lower(), "id": source_wikicode, "pos": [head["start"], head["end"]]},
        "t": {"name": target.lower(), "id": target_wikicode, "pos": [tail["start"], tail["end"]]},
    }
    if relation is not None:
        res["relation"] = relation
    return res

def get_input_data(sent_list, relation):
    """Build a record for every sentence, dropping the rejected ones.

    Args:
        sent_list: Sentences to convert with ``get_single_input``.
        relation: Relation label attached to every record.

    Returns:
        list[dict]: The non-``None`` results, in input order.
    """
    records = []
    for sentence in sent_list:
        record = get_single_input(sentence, relation)
        if record is not None:
            records.append(record)
    return records

# Data loading, preprocessing and dumping

In [12]:
# Past / positive: raw sentences -> filtered records -> output file.
past_positively_associated = read_file(input_past_pos_filepath)
inp_data_past_pos_associated = get_input_data(
    past_positively_associated,
    "positively associated in the past",
)
print(len(inp_data_past_pos_associated))
write_file(inp_data_past_pos_associated, output_past_pos_filepath)

Number of records captured: 132
Data saved at: output/past_positively_associated.txt


In [13]:
# Past / negative: raw sentences -> filtered records -> output file.
past_negatively_associated = read_file(input_past_neg_filepath)
inp_data_past_neg_associated = get_input_data(
    past_negatively_associated,
    "negatively associated in the past",
)
print(len(inp_data_past_neg_associated))
write_file(inp_data_past_neg_associated, output_past_neg_filepath)

141
Number of records captured: 141
Data saved at: output/past_negatively_associated.txt


In [14]:
# Present / positive: raw sentences -> filtered records -> output file.
present_positively_associated = read_file(input_present_pos_filepath)
inp_data_present_pos_associated = get_input_data(
    present_positively_associated,
    "positively associated in the present",
)
print(len(inp_data_present_pos_associated))
write_file(inp_data_present_pos_associated, output_present_pos_filepath)

126
Number of records captured: 126
Data saved at: output/present_positively_associated.txt


In [15]:
# Present / negative: raw sentences -> filtered records -> output file.
present_negatively_associated = read_file(input_present_neg_filepath)
inp_data_present_neg_associated = get_input_data(
    present_negatively_associated,
    "negatively associated in the present",
)
print(len(inp_data_present_neg_associated))
write_file(inp_data_present_neg_associated, output_present_neg_filepath)

147
Number of records captured: 147
Data saved at: output/present_negatively_associated.txt


# Train/Test Split

In [17]:
# Ask for the per-relation sample count and the test fraction, validating both.
n = int(input("Total number of samples per relation: "))
if n <= 0:
    raise ValueError("Number of samples per relation must be positive, got {}".format(n))
print("Total number of samples per relation:", n)
test_size = float(input("Test size (0 to 1, only in floating points): "))
if not 0.0 <= test_size <= 1.0:
    raise ValueError("Test size must be between 0 and 1, got {}".format(test_size))
print("Test size:", test_size)
# Round the test count and give the remainder to train so the two always sum
# to n; truncating both products with int() can lose a sample to floating-point
# error (e.g. int(100 * 0.29) == 28).
n_test = round(n * test_size)
n_train = n - n_test
print("Number of records in the train data:", n_train)
print("Number of records in the test data:", n_test)

Total number of samples per relation: 125
Test size: 0.2
Number of records in the train data: 100
Number of records in the test data: 25


In [107]:
import ast

# Split sizes per relation. Rounding the test count and giving the remainder to
# train keeps n_train + n_test == n; truncating both products with int() can
# lose a sample to floating-point error.
n_test = round(n * test_size)
n_train = n - n_test

relation_filepaths = [
    output_past_pos_filepath,
    output_past_neg_filepath,
    output_present_pos_filepath,
    output_present_neg_filepath,
]

train_data = []
test_data = []
relation_names = set()

for filepath in relation_filepaths:
    # Skip blank lines so they neither reach the output files nor break parsing.
    records = [record for record in read_file(filepath) if record]
    random.shuffle(records)
    if len(records) < n_train + n_test:
        print("Warning:", filepath, "has only", len(records),
              "records; requested", n_train + n_test)
    # Disjoint slices: train takes the first n_train shuffled records and test
    # the next n_test, so no record appears in both sets.
    train_data.extend(records[:n_train])
    test_data.extend(records[n_train:n_train + n_test])
    # Records are dicts serialised by write_file with str(); literal_eval parses
    # them without executing arbitrary code the way eval would.
    relation_names.update(ast.literal_eval(record)["relation"] for record in records)

# rel2id maps relation name -> integer id, the direction the file name
# describes. Sorting the names keeps the ids stable across runs.
relations = {name: idx for idx, name in enumerate(sorted(relation_names))}

random.shuffle(train_data)
random.shuffle(test_data)

train_data = "\n".join(train_data)
test_data = "\n".join(test_data)

# Make sure the target directories exist before writing.
for dirname in {os.path.dirname(p) for p in (train_data_filepath, val_data_filepath, relations_filepath)}:
    if dirname:
        os.makedirs(dirname, exist_ok=True)

with open(train_data_filepath, "w", encoding="utf-8") as f:
    f.write(train_data)

with open(val_data_filepath, "w", encoding="utf-8") as f:
    f.write(test_data)

with open(relations_filepath, "w", encoding="utf-8") as f:
    json.dump(relations, f)