<a href="https://colab.research.google.com/github/yalopez84/Goog-Negative-Sampling/blob/master/UniformSampling_Freebase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from tqdm import tqdm, trange
import random
import csv
from time import time

In [None]:
# Colab-only setup: mount the user's Google Drive under /content/drive so the
# dataset files stored there can be opened like local files.
from google.colab import drive
drive.mount('/content/drive')
# Folder on the mounted drive that holds the dataset files; main() passes this
# module-level value on as the processor's data directory.
data_dir="/content/drive/MyDrive/NegativeStrategies/GoodNegativeSampling/FB13/"
# Make the dataset folder the current working directory.
os.chdir(data_dir)  


In [None]:
class InputExample(object):
    """One textual example: an id, up to three text segments and a label.

    Attributes:
        guid: Identifier of the example.
        text_a: First text segment.
        text_b: Second text segment (defaults to None).
        text_c: Third text segment (defaults to None).
        label: Label of the example (defaults to None).
    """

    def __init__(self, guid, text_a, text_b=None, text_c=None, label=None):
        # Identification and label first, then the text segments in order.
        self.guid, self.label = guid, label
        self.text_a, self.text_b, self.text_c = text_a, text_b, text_c

In [None]:
class Triple(object):
    """A labelled (subject, predicate, obj) triple identified by ``guid``.

    Attributes:
        guid: Identifier of the triple.
        subject: First element of the triple.
        predicate: Second element of the triple.
        obj: Third element of the triple.
        label: Label attached to the triple.
    """

    def __init__(self, guid, subject , predicate , obj, label):
        # Identification and label first, then the three triple components.
        self.guid, self.label = guid, label
        self.subject, self.predicate, self.obj = subject, predicate, obj

In [None]:
class DataProcessor(object):
    """Base class for dataset processors; subclasses supply the example loading."""

    def get_train_examples(self, data_dir):
        """Return the training examples found under ``data_dir``.

        Must be overridden by subclasses; this base version always raises
        NotImplementedError.
        """
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a UTF-8 tab-separated file and return its rows.

        Args:
            input_file: Path of the file to read.
            quotechar: Forwarded to ``csv.reader`` as its ``quotechar``.

        Returns:
            A list with one list of string fields per row of the file.
        """
        with open(input_file, "r", encoding="utf-8") as handle:
            # Materialise all rows before the file is closed.
            return list(csv.reader(handle, delimiter="\t", quotechar=quotechar))

In [None]:
class KGProcessor(DataProcessor):
    """Builds training examples with one uniformly sampled negative per triple.

    For every processed training triple the processor emits the original
    triple (label "1") followed by a corrupted copy (label "0") in which
    either the head or the tail entity has been replaced by a randomly chosen
    different entity, such that the corrupted triple does not occur in the
    training file.
    """

    # Only the first MAX_TRIPLES rows of the training file are expanded into
    # examples. Set to None (e.g. in a subclass) to process every row.
    MAX_TRIPLES = 6526

    def __init__(self):
        self.labels = set()

    def get_train_examples(self, data_dir):
        """Read ``train_reduced_11082.tsv`` from ``data_dir`` and build examples.

        Args:
            data_dir: Directory holding the training file, ``entity2text.txt``
                and ``relation2text.txt``; the output files are written there
                as well.

        Returns:
            The list of ``InputExample`` objects built by ``_create_examples``.
        """
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train_reduced_11082.tsv")), "train", data_dir)

    @staticmethod
    def _sample_corrupt_entity(entities, original, prefix, suffix, known_triples):
        """Draw a random entity that yields a triple not in ``known_triples``.

        The candidate triple string is ``prefix + candidate + suffix``. Entities
        equal to ``original`` are never returned. Rejection sampling over the
        full list gives the same uniform distribution over the valid candidates
        as copying the list without ``original`` on every draw, but without the
        per-draw copy.

        Raises:
            ValueError: If ``entities`` offers no entity other than ``original``.
        """
        if not entities or (len(entities) == 1 and entities[0] == original):
            raise ValueError("no alternative entity available to corrupt %r" % (original,))
        while True:
            candidate = random.choice(entities)
            if candidate == original:
                continue
            if prefix + candidate + suffix not in known_triples:
                return candidate

    def _create_examples(self, lines, set_type, data_dir):
        """Create positive/negative example pairs and write them to disk.

        Args:
            lines: Rows of the training file; each row is a list whose first
                three items are head id, relation id and tail id.
            set_type: Prefix used for the example ids ("<set_type>-<i>" for the
                positive, "<set_type>_corrupt-<i>" for the negative).
            data_dir: Directory containing ``entity2text.txt`` and
                ``relation2text.txt`` and receiving the two output files.

        Returns:
            The list of ``InputExample`` objects, positives and negatives
            interleaved (positive first).

        Raises:
            KeyError: If a head, tail or relation id has no text entry.

        Side effects:
            Writes ``train_reduced_11082_neg_and_descrip_kgbert.tsv`` (guid and
            the three texts plus label) and ``train_reduced_11082_neg_kgbert.tsv``
            (the raw ids plus label) into ``data_dir``.
        """
        examples_file=os.path.join(data_dir, "train_reduced_11082_neg_and_descrip_kgbert.tsv")
        train_with_corrupts_file=os.path.join(data_dir, "train_reduced_11082_neg_kgbert.tsv")

        # Entity id -> text; rows that do not have exactly two fields are ignored.
        ent2text = {}
        with open(os.path.join(data_dir, "entity2text.txt"), 'r', encoding="utf-8") as f:
            ent_lines = f.readlines()
            for line in tqdm(ent_lines):
                temp = line.strip().split('\t')
                if len(temp) == 2:
                    ent2text[temp[0]] = temp[1]
        entities = list(ent2text.keys())

        # Relation id -> text; rows without a text column (e.g. blank lines)
        # are skipped instead of raising IndexError.
        rel2text = {}
        with open(os.path.join(data_dir, "relation2text.txt"), 'r', encoding="utf-8") as f:
            for line in f:
                temp = line.strip().split('\t')
                if len(temp) < 2:
                    continue
                rel2text[temp[0]] = temp[1]

        # Every known training triple, used to reject corruptions that are
        # actually present in the training file.
        lines_str_set = {'\t'.join(line) for line in lines}

        examples = []
        train_with_corrupts = []
        for (i, line) in enumerate(lines[:self.MAX_TRIPLES]):
            print("******i", i)
            head, relation, tail = line[0], line[1], line[2]
            head_ent_text = ent2text[head]
            tail_ent_text = ent2text[tail]
            relation_text = rel2text[relation]
            guidP = "%s-%s" % (set_type, i)
            guidN = "%s-%s" % (set_type + "_corrupt", i)

            # The positive example is the same whichever side gets corrupted.
            examples.append(InputExample(guid=guidP, text_a=head_ent_text, text_b=relation_text, text_c=tail_ent_text, label="1"))
            train_with_corrupts.append(Triple(guid=guidP, subject=head, predicate=relation, obj=tail, label="1"))

            if random.random() <= 0.5:
                # Corrupt the head entity.
                tmp_head = self._sample_corrupt_entity(
                    entities, head, '', '\t' + relation + '\t' + tail, lines_str_set)
                tmp_head_text = ent2text[tmp_head]
                # The tail text stays untouched; only the head is replaced.
                examples.append(InputExample(guid=guidN, text_a=tmp_head_text, text_b=relation_text, text_c=tail_ent_text, label="0"))
                train_with_corrupts.append(Triple(guid=guidN, subject=tmp_head, predicate=relation, obj=tail, label="0"))
            else:
                # Corrupt the tail entity.
                tmp_tail = self._sample_corrupt_entity(
                    entities, tail, head + '\t' + relation + '\t', '', lines_str_set)
                tmp_tail_text = ent2text[tmp_tail]
                examples.append(InputExample(guid=guidN, text_a=head_ent_text, text_b=relation_text, text_c=tmp_tail_text, label="0"))
                train_with_corrupts.append(Triple(guid=guidN, subject=head, predicate=relation, obj=tmp_tail, label="0"))

        # Textual version: guid, head text, relation text, tail text, label.
        with open(examples_file, "w", encoding="utf-8") as writer:
            for sample in examples:
                writer.write("%s\t%s\t%s\t%s\t%s\n" % (sample.guid, sample.text_a, sample.text_b, sample.text_c, sample.label))

        # Generate the examples to be used in training (raw ids plus label).
        with open(train_with_corrupts_file, "w", encoding="utf-8") as writer:
            for triple in train_with_corrupts:
                writer.write("%s\t%s\t%s\t%s\n" % (triple.subject, triple.predicate, triple.obj, triple.label))
        return examples

In [None]:
def main():
    """Build the training examples with the "kg" processor and print their count.

    Uses the module-level ``data_dir`` as the directory handed to the
    processor's ``get_train_examples``.
    """
    settings = {"task_name": "kg", "data_dir": data_dir}
    registry = {"kg": KGProcessor}

    # Look the processor class up by its lower-cased task name and instantiate it.
    processor_cls = registry[settings["task_name"].lower()]
    train_examples = processor_cls().get_train_examples(settings["data_dir"])
    print("len(train_examples)", len(train_examples))

In [None]:
# Notebook entry point: running this cell calls main() defined in the cell above.
main()