<a href="https://colab.research.google.com/github/yalopez84/GAN_study/blob/master/Estudiando_KBGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [88]:
#Libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.optim import Adam, SGD, Adagrad
from torch.autograd import Variable
from random import randint
from collections import defaultdict
from numpy.random import choice, randint
import numpy as n
import datetime
import yaml
import sys
import logging
import subprocess
from collections import namedtuple
from itertools import count

In [89]:
#Directories and devices
from google.colab import drive
drive.mount('/content/drive')
data_dir="/content/drive/MyDrive/NegativeStrategies/OAGAN-NS/data/"
os.chdir(data_dir)
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
print("Device",device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device cuda


In [90]:
import pdb
#base_model.py
class BaseModule(nn.Module):
    def __init__(self):
        super(BaseModule,self).__init__()
    def score(self,src,rel,dst):
        raise NotImplementedError
    def dist(self,src,rel,dst):
        raise NotImplementedError
    def prob_logit(self,src,rel,dst):
        raise NotImplementedError
    def prob(self,src,rel,dst):
        return f.softmax(self.prob_logit(src,rel,dst))
    def constraint(self):
        pass
    def pair_loss(self,src,rel,dst, src_bad,dst_bad):
        d_good=self.dist(src,rel,dst)
        d_bad=self.dist(src_bad,rel,dst_bad)
        return f.relu(self.margin + d_good - d_bad)

    def softmax_loss(self, src, rel, dst, truth):
        probs=self.prob(src,rel,dst)
        n=probs.size(0)
        truth_probs=torch.log(probs[torch.arange(0,n).type(torch.LongTensor).cuda(),truth]+1e-30)
        return -truth_probs

class BaseModel(object):
    def __init__(self):
        self.mdl= None
        self.weight_decay = 0
    def save(self,filename):
        torch.save(self.mdl.state_dict(),filename)

    def load(self,filename):
        self.mdl.load_state_dict(torch.load(filename,map_location=lambda storage, location:storage.cuda()))

    def gen_step(self,src,rel,dst,n_sample=1,temperature=1.0,train=True):
        if not hasattr(self,'opt'):
            self.opt=Adam(self.mdl.parameters(), weight_decay=self.weight_decay)
        n,m=dst.size()
        rel_var=Variable(rel.cuda())
        src_var = Variable(src.cuda())
        dst_var = Variable(dst.cuda())
        logits=self.mdl.prob_logit(src_var, rel_var, dst_var)/temperature
        probs=f.softmax(logits)
        row_idx=torch.arange(0,n).type(torch.LongTensor).unsqueeze(1).expand(n, n_sample)
        sample_idx=torch.multinomial(probs,n_sample, replacement=True)
        sample_srcs = src[row_idx, sample_idx.data.cpu()]
        sample_dsts = dst[row_idx, sample_idx.data.cpu()]
        rewards = yield sample_srcs, sample_dsts
        if train:
            self.mdl.zero_grad()
            log_probs = f.log_softmax(logits)
            reinforce_loss = -torch.sum(Variable(rewards) * log_probs[row_idx.cuda(), sample_idx.data])
            reinforce_loss.backward()
            self.opt.step()
            self.mdl.constraint()
        yield None

    def dis_step(self,src, rel, dst, src_fake, dst_fake, train=True):
        if not hasattr(self,'opt'):
            self.opt = Adam(self.mdl.parameters(), weight_decay=self.weight_decay)
        src_var = Variable(src.cuda())
        rel_var = Variable(rel.cuda())
        dst_var = Variable(dst.cuda())
        src_fake_var = Variable(src_fake.cuda())
        dst_fake_var = Variable(dst_fake.cuda())
        losses = self.mdl.pair_loss(src_var, rel_var, dst_var, src_fake_var, dst_fake_var)
        fake_scores = self.mdl.score(src_fake_var, rel_var, dst_fake_var)
        if train:
            self.mdl.zero_grad()
            torch.sum(losses).backward()
            self.opt.step()
            self.mdl.constraint()
        return losses.data, -fake_scores.data

    def test_link(self, test_data, n_ent, heads, tails, filt=True):
        mrr_tot = 0
        mr_tot = 0
        hit10_tot = 0
        count = 0
        for batch_s, batch_r, batch_t in batch_by_size(config().test_batch_size, *test_data):
            batch_size = batch_s.size(0)
            rel_var = Variable(batch_r.unsqueeze(1).expand(batch_size, n_ent).cuda())
            src_var = Variable(batch_s.unsqueeze(1).expand(batch_size, n_ent).cuda())
            dst_var = Variable(batch_t.unsqueeze(1).expand(batch_size, n_ent).cuda())
            all_var = Variable(torch.arange(0, n_ent).unsqueeze(0).expand(batch_size, n_ent)
                               .type(torch.LongTensor).cuda(), volatile=True)
            with torch.no_grad():
                batch_dst_scores = self.mdl.score(src_var, rel_var, all_var).data
                batch_src_scores = self.mdl.score(all_var, rel_var, dst_var).data

            for s, r, t, dst_scores, src_scores in zip(batch_s, batch_r, batch_t, batch_dst_scores, batch_src_scores):

                #if filt:
                    #if tails[(s.item(), r.item())]._nnz() > 1:
                        #tmp = dst_scores[t]
                        #dst_scores += tails[(s.item(), r.item())].cuda() * 1e30
                        #dst_scores[t] = tmp

                    #if heads[(t.item(), r.item())]._nnz() > 1:
                       # tmp = src_scores[s]
                        #src_scores += heads[(t.item(), r.item())].cuda() * 1e30
                        #src_scores[s] = tmp
                mrr, mr, hit10 =mrr_mr_hitk(dst_scores, t)
                mrr_tot += mrr
                mr_tot += mr
                hit10_tot += hit10
                mrr, mr, hit10 = mrr_mr_hitk(src_scores, s)
                mrr_tot += mrr
                mr_tot += mr
                hit10_tot += hit10
                count += 2
        logging.info('Test_MRR=%f, Test_MR=%f, Test_H@10=%f', mrr_tot / count, mr_tot / count, hit10_tot / count)
        return mrr_tot / count



In [91]:
#data_utils.py
def heads_tails(n_ent, train_data, valid_data=None, test_data=None):
    train_src, train_rel, train_dst = train_data
    if valid_data:
        valid_src, valid_rel, valid_dst = valid_data
    else:
        valid_src = valid_rel = valid_dst = []
    if test_data:
        test_src, test_rel, test_dst = test_data
    else:
        test_src = test_rel = test_dst = []
    all_src = train_src + valid_src + test_src
    all_rel = train_rel + valid_rel + test_rel
    all_dst = train_dst + valid_dst + test_dst
    heads = defaultdict(lambda: set())
    tails = defaultdict(lambda: set())
    for s, r, t in zip(all_src, all_rel, all_dst):
        tails[(s, r)].add(t)
        heads[(t, r)].add(s)
    heads_sp = {}
    tails_sp = {}
    for k in tails.keys():
        tails_sp[k] = torch.sparse.FloatTensor(torch.LongTensor([list(tails[k])]),
                                               torch.ones(len(tails[k])), torch.Size([n_ent]))
    for k in heads.keys():
        heads_sp[k] = torch.sparse.FloatTensor(torch.LongTensor([list(heads[k])]),
                                               torch.ones(len(heads[k])), torch.Size([n_ent]))
    return heads_sp, tails_sp


def inplace_shuffle(*lists):
    idx = []
    for i in range(len(lists[0])):
        a=i
        if a==0:
            a=1
        idx.append(randint(0, a))
    for ls in lists:
        for i, item in enumerate(ls):
            j = idx[i]
            ls[i], ls[j] = ls[j], ls[i]


def batch_by_num(n_batch, *lists, n_sample=None):
    if n_sample is None:
        n_sample = len(lists[0])
    for i in range(n_batch):
        head = int(n_sample * i / n_batch)
        tail = int(n_sample * (i + 1) / n_batch)
        ret = [ls[head:tail] for ls in lists]
        if len(ret) > 1:
            yield ret
        else:
            yield ret[0]

def batch_by_size(batch_size, *lists, n_sample=None):
    if n_sample is None:
        n_sample = len(lists[0])
    head = 0
    while head < n_sample:
        tail = min(n_sample, head + batch_size)
        ret = [ls[head:tail] for ls in lists]
        head += batch_size
        if len(ret) > 1:
            yield ret
        else:
            yield ret[0]

In [92]:
#config.py
class ConfigDict(dict):
    __getattr__ = dict.__getitem__

def _make_config_dict(obj):
    if isinstance(obj, dict):
        return ConfigDict({k: _make_config_dict(v) for k, v in obj.items()})
    elif isinstance(obj, list):
        return [_make_config_dict(x) for x in obj]
    else:
        return obj

_config = None

def config():
    arg_dict ={
        '--config':'config_fb15k237.yaml',
        '--pretrain_config':'<model_name>'
    }
    global _config
    if _config is None:
        config_path='config_fb15k237.yaml'
        with open(os.path.join(data_dir, config_path)) as f:
            _config = _make_config_dict(yaml.full_load(f))
    return _config

def _dump_config(obj, prefix):
    if isinstance(obj, dict):
        for k, v in obj.items():
            _dump_config(v, prefix + (k,))
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            _dump_config(v, prefix + (str(i),))
    else:
        if isinstance(obj, str):
            rep = obj
        else:
            rep = repr(obj)
        logging.debug('%s=%s', '.'.join(prefix), rep)

def dump_config():
    return _dump_config(_config, tuple())

In [93]:
#metrics.py
def mrr_mr_hitk(scores, target, k=10):
    values, sorted_idx = torch.sort(scores, descending=True)
    find_target = sorted_idx == target
    target_rank = torch.nonzero(find_target)[0, 0] + 1
    return 1 / target_rank, target_rank, int(target_rank <= k)

In [94]:
#corrupter.py
def get_bern_prob(data, n_ent, n_rel):
    src, rel, dst = data
    edges = defaultdict(lambda: defaultdict(lambda: set()))
    rev_edges = defaultdict(lambda: defaultdict(lambda: set()))
    for s, r, t in zip(src, rel, dst):
        edges[r][s].add(t)
        rev_edges[r][t].add(s)
    bern_prob = torch.zeros(n_rel)
    i=0
    for r in edges.keys():
        tph = sum(len(tails) for tails in edges[r].values()) / len(edges[r])
        htp = sum(len(heads) for heads in rev_edges[r].values()) / len(rev_edges[r])
        bern_prob[r] = tph / (tph + htp)

    return bern_prob

class BernCorrupter(object):
    def __init__(self, data, n_ent, n_rel):
        self.bern_prob = get_bern_prob(data, n_ent, n_rel)
        self.n_ent = n_ent

    def corrupt(self, src, rel, dst):
        prob = self.bern_prob[rel]
        selection = torch.bernoulli(prob).numpy().astype('int64')
        ent_random = choice(self.n_ent, len(src))
        src_out = (1 - selection) * src.numpy() + selection * ent_random
        dst_out = selection * dst.numpy() + (1 - selection) * ent_random
        return torch.from_numpy(src_out), torch.from_numpy(dst_out)

class BernCorrupterMulti(object):
    def __init__(self, data, n_ent, n_rel, n_sample):
        self.bern_prob = get_bern_prob(data, n_ent, n_rel)
        self.n_ent = n_ent
        self.n_sample = n_sample

    def corrupt(self, src, rel, dst, keep_truth=True):
        n = len(src)
        prob = self.bern_prob[rel]
        selection = torch.bernoulli(prob).numpy().astype('bool')
        src_out = n.tile(src.numpy(), (self.n_sample, 1)).transpose()
        dst_out = n.tile(dst.numpy(), (self.n_sample, 1)).transpose()
        rel_out = rel.unsqueeze(1).expand(n, self.n_sample)
        if keep_truth:
            ent_random = choice(self.n_ent, (n, self.n_sample - 1))
            src_out[selection, 1:] = ent_random[selection]
            dst_out[~selection, 1:] = ent_random[~selection]
        else:
            ent_random = choice(self.n_ent, (n, self.n_sample))
            src_out[selection, :] = ent_random[selection]
            dst_out[~selection, :] = ent_random[~selection]
        return torch.from_numpy(src_out), rel_out, torch.from_numpy(dst_out)

In [95]:
#read_data.py
KBIndex = namedtuple('KBIndex', ['ent_list', 'rel_list', 'ent_id', 'rel_id'])

def index_ent_rel(*filenames):
    ent_set = set()
    rel_set = set()
    for filename in filenames:
        with open(filename) as f:
            for ln in f:
                s, r, t = ln.strip().split('\t')[:3]
                ent_set.add(s)
                ent_set.add(t)
                rel_set.add(r)
    ent_list = sorted(list(ent_set))
    rel_list = sorted(list(rel_set))
    ent_id = dict(zip(ent_list, count()))
    rel_id = dict(zip(rel_list, count()))
    return KBIndex(ent_list, rel_list, ent_id, rel_id)


def graph_size(kb_index):
    return len(kb_index.ent_id), len(kb_index.rel_id)


def read_data(filename, kb_index):
    src = []
    rel = []
    dst = []
    with open(filename) as f:
        for ln in f:
            s, r, t = ln.strip().split('\t')
            src.append(kb_index.ent_id[s])
            rel.append(kb_index.rel_id[r])
            dst.append(kb_index.ent_id[t])
    return src, rel, dst

In [96]:
#trans_e.py
class TransEModule(BaseModule):
    def __init__(self, n_ent, n_rel, config):
        super(TransEModule, self).__init__()
        self.p = config.p
        self.margin = config.margin
        self.temp = config.get('temp', 1)
        self.rel_embed = nn.Embedding(n_rel, config.dim)
        self.ent_embed = nn.Embedding(n_ent, config.dim)
        self.init_weight()

    def init_weight(self):
        for param in self.parameters():
            param.data.normal_(1 / param.size(1) ** 0.5)
            param.data.renorm_(2, 0, 1)

    def forward(self, src, rel, dst):
        return torch.norm(self.ent_embed(dst) - self.ent_embed(src) - self.rel_embed(rel) + 1e-30, p=self.p, dim=-1)

    def dist(self, src, rel, dst):
        return self.forward(src, rel, dst)

    def score(self, src, rel, dst):
        return self.forward(src, rel, dst)

    def prob_logit(self, src, rel, dst):
        return -self.forward(src, rel ,dst) / self.temp

    def constraint(self):
        self.ent_embed.weight.data.renorm_(2, 0, 1)
        self.rel_embed.weight.data.renorm_(2, 0, 1)

class TransE(BaseModel):
    def __init__(self, n_ent, n_rel, config):
        super(TransE, self).__init__()
        self.mdl = TransEModule(n_ent, n_rel, config)
        self.mdl.cuda()
        self.config = config

    def pretrain(self, train_data, corrupter, tester):
        src, rel, dst = train_data
        n_train = len(src)
        optimizer = Adam(self.mdl.parameters())
        #optimizer = SGD(self.mdl.parameters(), lr=1e-4)
        n_epoch = self.config.n_epoch
        n_batch = self.config.n_batch
        best_perf = 0
        for epoch in range(n_epoch):
            print('---------epoch_number---', epoch)
            epoch_loss = 0
            rand_idx = torch.randperm(n_train)
            src = src[rand_idx]
            rel = rel[rand_idx]
            dst = dst[rand_idx]
            src_corrupted, dst_corrupted = corrupter.corrupt(src, rel, dst)
            src_cuda = src.cuda()
            rel_cuda = rel.cuda()
            dst_cuda = dst.cuda()
            src_corrupted = src_corrupted.cuda()
            dst_corrupted = dst_corrupted.cuda()
            for s0, r, t0, s1, t1 in batch_by_num(n_batch, src_cuda, rel_cuda, dst_cuda, src_corrupted, dst_corrupted,
                                                  n_sample=n_train):
                self.mdl.zero_grad()
                loss = torch.sum(self.mdl.pair_loss(Variable(s0), Variable(r), Variable(t0), Variable(s1), Variable(t1)))
                loss.backward()
                optimizer.step()
                self.mdl.constraint()
                epoch_loss += loss.item()
                print('epoch_loss', epoch_loss)
            logging.info('Epoch %d/%d, Loss=%f', epoch + 1, n_epoch, epoch_loss / n_train)
            if (epoch + 1) % self.config.epoch_per_test == 0:
                print('probando tester')
                test_perf = tester()
                print('El MRR en epoch ',epoch,'fue de :', test_perf)
                if test_perf > best_perf:
                    task_dir = config().task.dir
                    direccion=os.path.join(data_dir, task_dir)
                    self.save(os.path.join(direccion, self.config.model_file))
                    #self.save(os.path.join(direccion, self.config.save_to))
                    best_perf = test_perf
        return best_perf

In [97]:
#pretrain.py
def pretrain():
    task_dir = config().task.dir
    direccion=os.path.join(data_dir, task_dir)
    kb_index = index_ent_rel(os.path.join(direccion, 'train.txt'),
                        os.path.join(direccion, 'valid.txt'),
                        os.path.join(direccion, 'test.txt'))
    n_ent, n_rel = graph_size(kb_index)
    train_data = read_data(os.path.join(direccion, 'train.txt'), kb_index)
    inplace_shuffle(*train_data)
    valid_data = read_data(os.path.join(direccion, 'valid.txt'), kb_index)
    test_data = read_data(os.path.join(direccion, 'test.txt'), kb_index)

    heads, tails = heads_tails(n_ent, train_data, valid_data, test_data)
    valid_data = [torch.LongTensor(vec) for vec in valid_data]
    test_data = [torch.LongTensor(vec) for vec in test_data]
    tester = lambda: gen.test_link(valid_data, n_ent, heads, tails)
    train_data = [torch.LongTensor(vec) for vec in train_data]
    mdl_type = config().pretrain_config
    gen_config = config()[mdl_type]
    if mdl_type == 'TransE':
        corrupter = BernCorrupter(train_data, n_ent, n_rel)
        gen = TransE(n_ent, n_rel, gen_config)
    elif mdl_type == 'TransD':
        corrupter = BernCorrupter(train_data, n_ent, n_rel)
        #gen = TransD(n_ent, n_rel, gen_config)
    elif mdl_type == 'DistMult':
        corrupter = BernCorrupterMulti(train_data, n_ent, n_rel, gen_config.n_sample)
        #gen = DistMult(n_ent, n_rel, gen_config)
    elif mdl_type == 'ComplEx':
        corrupter = BernCorrupterMulti(train_data, n_ent, n_rel, gen_config.n_sample)
        #gen = ComplEx(n_ent, n_rel, gen_config)
    result=gen.pretrain(train_data, corrupter, tester)
    print('El best MRR fue de :', result)
    print('Cargando modelo preentrenado')
    gen.load(os.path.join(direccion, gen_config.model_file))
    print('Ejecutando pruebas')
    result=gen.test_link(test_data, n_ent, heads, tails)
    print('El MRR para el conjunto de pruebas en el modelo entrenado con :', gen_config.n_epoch,'epochs es de: ',result)

In [98]:
if __name__ == "__main__":
    pretrain()
    print('Fin')


---------epoch_number--- 0
epoch_loss 8169.421875
epoch_loss 16308.8974609375
epoch_loss 24489.46728515625
epoch_loss 32658.96923828125
epoch_loss 40818.22216796875
epoch_loss 48873.00634765625
epoch_loss 56976.3564453125
epoch_loss 65152.31884765625
epoch_loss 73290.87255859375
epoch_loss 81449.64453125
epoch_loss 89490.61669921875
epoch_loss 97768.07763671875
epoch_loss 105860.23974609375
epoch_loss 114013.388671875
epoch_loss 122114.90869140625
epoch_loss 130228.6474609375
epoch_loss 138379.607421875
epoch_loss 146450.353515625
epoch_loss 154596.80908203125
epoch_loss 162757.85693359375
epoch_loss 170800.66162109375
epoch_loss 178903.13232421875
epoch_loss 186991.892578125
epoch_loss 195178.490234375
epoch_loss 203267.900390625
epoch_loss 211333.6240234375
epoch_loss 219596.8798828125
epoch_loss 227707.64404296875
epoch_loss 235776.927734375
epoch_loss 243848.08251953125
epoch_loss 251878.78369140625
epoch_loss 259932.53076171875
epoch_loss 268120.94287109375
epoch_loss 276308.06298

  all_var = Variable(torch.arange(0, n_ent).unsqueeze(0).expand(batch_size, n_ent)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
epoch_loss 27594.15380859375
epoch_loss 34509.810546875
epoch_loss 41434.50537109375
epoch_loss 48379.1748046875
epoch_loss 55279.9013671875
epoch_loss 62106.9921875
epoch_loss 69121.48779296875
epoch_loss 76051.90283203125
epoch_loss 82989.689453125
epoch_loss 89847.240234375
epoch_loss 96745.24462890625
epoch_loss 103556.99365234375
epoch_loss 110521.89794921875
epoch_loss 117422.25048828125
epoch_loss 124269.17333984375
epoch_loss 131190.17578125
epoch_loss 138182.9228515625
epoch_loss 145057.2041015625
epoch_loss 151824.68505859375
epoch_loss 158656.31201171875
epoch_loss 165612.82373046875
epoch_loss 172495.25732421875
epoch_loss 179342.11328125
epoch_loss 186283.6572265625
epoch_loss 193147.83203125
epoch_loss 200012.25244140625
epoch_loss 206912.84423828125
epoch_loss 213759.2099609375
epoch_loss 220631.865234375
epoch_loss 227438.49609375
epoch_loss 234290.6953125
epoch_loss 241247.15283203125
epoch_loss 248124.15