In [1]:
import torch
import torch.nn as nn

import io

import json

In [2]:
import re
import io

LOWER = False
DIGIT_0 = False
UNK_TOKEN = "#UNK#"

BRACKETS = {"-LCB-": "{", "-LRB-": "(", "-LSB-": "[", "-RCB-": "}", "-RRB-": ")", "-RSB-": "]"}


class Vocabulary:
    unk_token = UNK_TOKEN

    def __init__(self):
        self.word2id = {}
        self.id2word = []
        self.counts = []
        self.unk_id = 0

    @staticmethod
    def normalize(token, lower=LOWER, digit_0=DIGIT_0):
        if token in [Vocabulary.unk_token, "<s>", "</s>"]:
            return token
        elif token in BRACKETS:
            token = BRACKETS[token]
        else:
            if digit_0:
                token = re.sub("[0-9]", "0", token)

        if lower:
            return token.lower()
        else:
            return token

    @staticmethod
    def load(path):
        voca = Vocabulary()
        voca.load_from_file(path)
        return voca

    def load_from_file(self, path):
        self.word2id = {}
        self.id2word = []
        self.counts = []

        f = io.open(path, "r", encoding='utf-8', errors='ignore')
        for line in f:
            line = line.strip()
            comps = line.split('\t')
            if len(comps) == 0 or len(comps) > 2: 
                raise Exception('sthing wrong')

            token = Vocabulary.normalize(comps[0].strip())
            self.id2word.append(token)
            self.word2id[token] = len(self.id2word) - 1

            if len(comps) == 2:
                self.counts.append(float(comps[1]))
            else: 
                self.counts.append(1)

        f.close()

        if Vocabulary.unk_token not in self.word2id:
            self.id2word.append(Vocabulary.unk_token)
            self.word2id[Vocabulary.unk_token] = len(self.id2word) - 1
            self.counts.append(1)
            
        self.unk_id = self.word2id[Vocabulary.unk_token]

    def size(self):
        return len(self.id2word)

    def get_id(self, token):
        tok = Vocabulary.normalize(token)
        return self.word2id.get(tok, self.unk_id)


In [3]:
## utils
import numpy as np


############################## removing stopwords #######################

STOPWORDS = {'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'all',
             'almost', 'alone', 'along', 'already', 'also', 'although', 'always', 'am', 'among',
             'amongst', 'amoungst', 'amount', 'an', 'and', 'another', 'any', 'anyhow', 'anyone',
             'anything', 'anyway', 'anywhere', 'are', 'around', 'as', 'at', 'back', 'be',
             'became', 'because', 'become', 'becomes', 'becoming', 'been', 'before', 'beforehand',
             'behind', 'being', 'below', 'beside', 'besides', 'between', 'beyond', 'both', 'bottom',
             'but', 'by', 'call', 'can', 'cannot', 'cant', 'dont', 'co', 'con', 'could', 'couldnt',
             'cry', 'de', 'describe', 'detail', 'do', 'done', 'down', 'due', 'during', 'each', 'eg',
             'eight', 'either', 'eleven', 'else', 'elsewhere', 'empty', 'enough', 'etc', 'even',
             'ever', 'every', 'everyone', 'everything', 'everywhere', 'except', 'few', 'fifteen',
             'fify', 'fill', 'find', 'fire', 'first', 'five', 'for', 'former', 'formerly', 'forty',
             'found', 'four', 'from', 'front', 'full', 'further', 'get', 'give', 'go', 'had',
             'has', 'hasnt', 'have', 'he', 'hence', 'her', 'here', 'hereafter', 'hereby', 'herein',
             'hereupon', 'hers', 'herself', 'him', 'himself', 'his', 'how', 'however', 'hundred',
             'i', 'ie', 'if', 'in', 'inc', 'indeed', 'interest', 'into', 'is', 'it', 'its', 'itself',
             'keep', 'last', 'latter', 'latterly', 'least', 'less', 'ltd', 'made', 'many', 'may',
             'me', 'meanwhile', 'might', 'mill', 'mine', 'more', 'moreover', 'most', 'mostly',
             'move', 'much', 'must', 'my', 'myself', 'name', 'namely', 'neither', 'never', 'nevertheless',
             'next', 'nine', 'no', 'nobody', 'none', 'noone', 'nor', 'not', 'nothing', 'now',
             'nowhere', 'of', 'off', 'often', 'on', 'once', 'one', 'only', 'onto', 'or', 'other',
             'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'own', 'part', 'per',
             'perhaps', 'please', 'put', 'rather', 're', 'same', 'see', 'seem', 'seemed', 'seeming',
             'seems', 'serious', 'several', 'she', 'should', 'show', 'side', 'since', 'sincere', 'six',
             'sixty', 'so', 'some', 'somehow', 'someone', 'something', 'sometime', 'sometimes',
             'somewhere', 'still', 'such', 'system', 'take', 'ten', 'than', 'that', 'the', 'their',
             'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore',
             'therein', 'thereupon', 'these', 'they', 'thick', 'thin', 'third', 'this', 'those', 'though',
             'three', 'through', 'throughout', 'thru', 'thus', 'to', 'together', 'too', 'top', 'toward',
             'towards', 'twelve', 'twenty', 'two', 'un', 'under', 'until', 'up', 'upon', 'us', 'very',
             'via', 'was', 'we', 'well', 'were', 'what', 'whatever', 'when', 'whence', 'whenever',
             'where', 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether',
             'which', 'while', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'will',
             'with', 'within', 'without', 'would', 'yet', 'you', 'your', 'yours', 'yourself', 'yourselves',
             'st', 'years', 'yourselves', 'new', 'used', 'known', 'year', 'later', 'including', 'used',
             'end', 'did', 'just', 'best', 'using'}


def is_important_word(s):
    """
    an important word is not a stopword, a number, or len == 1
    """
    try:
        if len(s) <= 1 or s.lower() in STOPWORDS:
            return False
        float(s)
        return False
    except:
        return True


def is_stopword(s):
    return s.lower() in STOPWORDS


############################### coloring ###########################

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def tokgreen(s):
    return bcolors.OKGREEN + s + bcolors.ENDC


def tfail(s):
    return bcolors.FAIL + s + bcolors.ENDC


def tokblue(s):
    return bcolors.OKBLUE + s + bcolors.ENDC


############################ process list of lists ###################

def flatten_list_of_lists(list_of_lists):
    """
    making inputs to torch.nn.EmbeddingBag
    """
    list_of_lists = [[]] + list_of_lists
    offsets = np.cumsum([len(x) for x in list_of_lists])[:-1]
    flatten = sum(list_of_lists[1:], [])
    return flatten, offsets


def load_voca_embs(voca_path, embs_path):
    voca = Vocabulary.load(voca_path)
    embs = np.load(embs_path)

    # check if sizes are matched
    if embs.shape[0] == voca.size() - 1:
        unk_emb = np.mean(embs, axis=0, keepdims=True)
        embs = np.append(embs, unk_emb, axis=0)
    elif embs.shape[0] != voca.size():
        print(embs.shape, voca.size())
        raise Exception("embeddings and vocabulary have differnt number of items ")

    return voca, embs


def make_equal_len(lists, fill_in=0, to_right=True):
    lens = [len(l) for l in lists]
    max_len = max(1, max(lens))
    if to_right:
        eq_lists = [l + [fill_in] * (max_len - len(l)) for l in lists]
        mask = [[1.] * l + [0.] * (max_len - l) for l in lens]
    else:
        eq_lists = [[fill_in] * (max_len - len(l)) + l for l in lists]
        mask = [[0.] * (max_len - l) + [1.] * l for l in lens]
    return eq_lists, mask

In [4]:
def load(path, model_class, suffix=''):
    with io.open(path + '.config', 'r', encoding='utf8') as f:
        config = json.load(f)

    word_voca = Vocabulary()
    word_voca.__dict__ = config['word_voca']
    config['word_voca'] = word_voca
    entity_voca = Vocabulary()
    entity_voca.__dict__ = config['entity_voca']
    config['entity_voca'] = entity_voca

    if 'snd_word_voca' in config:
        snd_word_voca = Vocabulary()
        snd_word_voca.__dict__ = config['snd_word_voca']
        config['snd_word_voca'] = snd_word_voca

    model = model_class(config)
    model.load_state_dict(torch.load(path + '.state_dict' + suffix))
    return model


class AbstractWordEntity(nn.Module):
    """
    abstract class containing word and entity embeddings and vocabulary
    """

    def __init__(self, config=None):
        super(AbstractWordEntity, self).__init__()
        if config is None:
            return

        self.emb_dims = config['emb_dims']
        self.word_voca = config['word_voca']
        self.entity_voca = config['entity_voca']
        self.freeze_embs = config['freeze_embs']

        self.word_embeddings = config['word_embeddings_class'](self.word_voca.size(), self.emb_dims)
        self.entity_embeddings = config['entity_embeddings_class'](self.entity_voca.size(), self.emb_dims)

        if 'word_embeddings' in config:
            self.word_embeddings.weight = nn.Parameter(torch.Tensor(config['word_embeddings']))
        if 'entity_embeddings' in config:
            self.entity_embeddings.weight = nn.Parameter(torch.Tensor(config['entity_embeddings']))

        if 'snd_word_voca' in config:
            self.snd_word_voca = config['snd_word_voca']
            self.snd_word_embeddings = config['word_embeddings_class'](self.snd_word_voca.size(), self.emb_dims)
        if 'snd_word_embeddings' in config:
            self.snd_word_embeddings.weight = nn.Parameter(torch.Tensor(config['snd_word_embeddings']))

        if self.freeze_embs:
            self.word_embeddings.weight.requires_grad = False
            self.entity_embeddings.weight.requires_grad = False
            if 'snd_word_voca' in config:
                self.snd_word_embeddings.weight.requires_grad = False

    def print_weight_norm(self):
        pass

    def save(self, path, suffix='', save_config=True):
        torch.save(self.state_dict(), path + '.state_dict' + suffix)

        if save_config:
            config = {'word_voca': self.word_voca.__dict__,
                      'entity_voca': self.entity_voca.__dict__}
            if 'snd_word_voca' in self.__dict__:
                config['snd_word_voca'] = self.snd_word_voca.__dict__

            for k, v in self.__dict__.items():
                if not hasattr(v, '__dict__'):
                    config[k] = v

            with io.open(path + '.config', 'w', encoding='utf8') as f:
                json.dump(config, f)

    def load_params(self, path, param_names):
        params = torch.load(path)
        for pname in param_names:
            self._parameters[pname].data = params[pname]

    def loss(self, scores, grth):
        pass


In [5]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from nel.local_ctx_att_ranker import LocalCtxAttRanker
import numpy as np


class STArgmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores):
        max_values, _ = scores.max(dim=-1, keepdim=True)
        return (scores >= max_values).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class MulRelRanker(LocalCtxAttRanker):
    """
    multi-relational global model with context token attention, using loopy belief propagation
    """

    def __init__(self, config):
        super(MulRelRanker, self).__init__(config)
        self.df = config['df']  # damping factor
        self.n_loops = config['n_loops']
        self.n_rels = config['n_rels']
        self.dr = config['dr']
        self.ew_hid_dims = self.emb_dims

        self.max_dist = 1000
        self.ent_top_n = 1000

        self.oracle = config.get('oracle', False)
        self.ent_ent_comp = config.get('ent_ent_comp', 'bilinear')  # bilinear, trans_e, fbilinear
        self.ctx_comp = config.get('ctx_comp', 'bow')  # bow or rnn

        self.mode = config.get('mulrel_type', 'ment-norm')  # ment-norm, rel-norm

        # options for ment-norm
        self.first_head_uniform = config.get('first_head_uniform', False)
        self.use_pad_ent = config.get('use_pad_ent', False)

        # options for rel-norm
        self.use_stargmax = config.get('use_stargmax', False)

        self.use_local = config.get('use_local', False)
        self.use_local_only = config.get('use_local_only', False)
        self.freeze_local = config.get('freeze_local', False)

        if self.freeze_local:
            self.att_mat_diag.requires_grad = False
            self.tok_score_mat_diag.requires_grad = False

        if self.use_local:
            self.ent_localctx_comp = torch.nn.Parameter(torch.ones(self.emb_dims))

        if self.use_pad_ent:
            self.pad_ent_emb = torch.nn.Parameter(torch.randn(1, self.emb_dims) * 0.1)
            self.pad_ctx_vec = torch.nn.Parameter(torch.randn(1, self.emb_dims) * 0.1)

        self.ctx_layer = torch.nn.Sequential(
                torch.nn.Linear(self.emb_dims * 3, self.ew_hid_dims),
                torch.nn.Tanh(),
                torch.nn.Dropout(p=self.dr))

        self.rel_embs = torch.randn(self.n_rels, self.emb_dims) * 0.01
        if self.ent_ent_comp == 'bilinear':
            self.rel_embs[0] = 1 + torch.randn(self.emb_dims) * 0.01
            if self.mode == 'ment-norm' and self.n_rels > 1 and self.first_head_uniform:
                self.rel_embs[1] = 1
            if self.mode == 'rel-norm':
                self.rel_embs.fill_(0).add_(torch.randn(self.n_rels, self.emb_dims) * 0.1)

        self.rel_embs = torch.nn.Parameter(self.rel_embs)

        self.ew_embs = torch.nn.Parameter(torch.randn(self.n_rels, self.ew_hid_dims) *
                                          (0.01 if self.mode == 'ment-norm' else 0.1))

        self._coh_ctx_vecs = None

        self.score_combine = torch.nn.Sequential(
                torch.nn.Linear(2, self.hid_dims),
                torch.nn.ReLU(),
                torch.nn.Linear(self.hid_dims, 1))

        print('---------------- model config -----------------')
        for k, v in self.__dict__.items():
            if not hasattr(v, '__dict__'):
                print(k, v)
        print('-----------------------------------------------')

    def print_weight_norm(self):
        LocalCtxAttRanker.print_weight_norm(self)
        print(self.ctx_layer[0].weight.data.norm(), self.ctx_layer[0].bias.data.norm())
        print('relations', self.rel_embs.data.norm(p=2, dim=1))
        X = F.normalize(self.rel_embs)
        diff = (X.view(self.n_rels, 1, -1) - X.view(1, self.n_rels, -1)).pow(2).sum(dim=2).sqrt()
        print(diff)

        print('ew_embs', self.ew_embs.data.norm(p=2, dim=1))
        X = F.normalize(self.ew_embs)
        diff = (X.view(self.n_rels, 1, -1) - X.view(1, self.n_rels, -1)).pow(2).sum(dim=2).sqrt()
        print(diff)

    def forward(self, token_ids, tok_mask, entity_ids, entity_mask, p_e_m, gold=None):
        n_ments, n_cands = entity_ids.size()
        n_rels = self.n_rels

        if self.mode == 'ment-norm' and self.first_head_uniform:
            self.ew_embs.data[0] = 0

        if not self.oracle:
            gold = None

        if self.use_local:
            local_ent_scores = super(MulRelRanker, self).forward(token_ids, tok_mask,
                                                                 entity_ids, entity_mask,
                                                                 p_e_m=None)
            ent_vecs = self._entity_vecs
        else:
            ent_vecs = self.entity_embeddings(entity_ids)
            local_ent_scores = Variable(torch.zeros(n_ments, n_cands).cuda(), requires_grad=False)

        # compute context vectors
        ltok_vecs = self.snd_word_embeddings(self.s_ltoken_ids) * self.s_ltoken_mask.view(n_ments, -1, 1)
        local_lctx_vecs = torch.sum(ltok_vecs, dim=1) / torch.sum(self.s_ltoken_mask, dim=1, keepdim=True).add_(1e-5)
        rtok_vecs = self.snd_word_embeddings(self.s_rtoken_ids) * self.s_rtoken_mask.view(n_ments, -1, 1)
        local_rctx_vecs = torch.sum(rtok_vecs, dim=1) / torch.sum(self.s_rtoken_mask, dim=1, keepdim=True).add_(1e-5)
        mtok_vecs = self.snd_word_embeddings(self.s_mtoken_ids) * self.s_mtoken_mask.view(n_ments, -1, 1)
        ment_vecs = torch.sum(mtok_vecs, dim=1) / torch.sum(self.s_mtoken_mask, dim=1, keepdim=True).add_(1e-5)
        bow_ctx_vecs = torch.cat([local_lctx_vecs, ment_vecs, local_rctx_vecs], dim=1)

        if self.use_pad_ent:
            ent_vecs = torch.cat([ent_vecs, self.pad_ent_emb.view(1, 1, -1).repeat(1, n_cands, 1)], dim=0)
            tmp = torch.zeros(1, n_cands)
            tmp[0, 0] = 1
            tmp = Variable(tmp.cuda())
            entity_mask = torch.cat([entity_mask, tmp], dim=0)
            p_e_m = torch.cat([p_e_m, tmp], dim=0)
            local_ent_scores = torch.cat([local_ent_scores,
                                          Variable(torch.zeros(1, n_cands).cuda(), requires_grad=False)],
                                         dim=0)
            n_ments += 1

            if self.oracle:
                tmp = Variable(torch.zeros(1, 1).cuda().long())
                gold = torch.cat([gold, tmp], dim=0)

        if self.use_local_only:
            inputs = torch.cat([Variable(torch.zeros(n_ments * n_cands, 1).cuda()),
                                local_ent_scores.view(n_ments * n_cands, -1),
                                torch.log(p_e_m + 1e-20).view(n_ments * n_cands, -1)], dim=1)
            scores = self.score_combine(inputs).view(n_ments, n_cands)
            return scores

        if n_ments == 1:
            ent_scores = local_ent_scores

        else:
            # distance - to consider only neighbor mentions
            ment_pos = torch.arange(0, n_ments).long().cuda()
            dist = (ment_pos.view(n_ments, 1) - ment_pos.view(1, n_ments)).abs()
            dist.masked_fill_(dist == 1, -1)
            dist.masked_fill_((dist > 1) & (dist <= self.max_dist), -1)
            dist.masked_fill_(dist > self.max_dist, 0)
            dist.mul_(-1)

            ctx_vecs = self.ctx_layer(bow_ctx_vecs)
            if self.use_pad_ent:
                ctx_vecs = torch.cat([ctx_vecs, self.pad_ctx_vec], dim=0)

            m1_ctx_vecs, m2_ctx_vecs = ctx_vecs, ctx_vecs
            rel_ctx_vecs = m1_ctx_vecs.view(1, n_ments, -1) * self.ew_embs.view(n_rels, 1, -1)
            rel_ctx_ctx_scores = torch.matmul(rel_ctx_vecs, m2_ctx_vecs.view(1, n_ments, -1).permute(0, 2, 1))  # n_rels x n_ments x n_ments

            rel_ctx_ctx_scores = rel_ctx_ctx_scores.add_((1 - Variable(dist.float().cuda())).mul_(-1e10))
            eye = Variable(torch.eye(n_ments).cuda()).view(1, n_ments, n_ments)
            rel_ctx_ctx_scores.add_(eye.mul_(-1e10))
            rel_ctx_ctx_scores.mul_(1 / np.sqrt(self.ew_hid_dims))  # scaling proposed by "attention is all you need"

            # get top_n neighbour
            if self.ent_top_n < n_ments:
                topk_values, _ = torch.topk(rel_ctx_ctx_scores, k=min(self.ent_top_n, n_ments), dim=2)
                threshold = topk_values[:, :, -1:]
                mask = 1 - (rel_ctx_ctx_scores >= threshold).float()
                rel_ctx_ctx_scores.add_(mask.mul_(-1e10))

            if self.mode == 'ment-norm':
                rel_ctx_ctx_probs = F.softmax(rel_ctx_ctx_scores, dim=2)
                rel_ctx_ctx_weights = rel_ctx_ctx_probs + rel_ctx_ctx_probs.permute(0, 2, 1)
                self._rel_ctx_ctx_weights = rel_ctx_ctx_probs
            elif self.mode == 'rel-norm':
                ctx_ctx_rel_scores = rel_ctx_ctx_scores.permute(1, 2, 0).contiguous()
                if not self.use_stargmax:
                    ctx_ctx_rel_probs = F.softmax(ctx_ctx_rel_scores.view(n_ments * n_ments, n_rels))\
                                        .view(n_ments, n_ments, n_rels)
                else:
                    ctx_ctx_rel_probs = STArgmax.apply(ctx_ctx_rel_scores)
                self._rel_ctx_ctx_weights = ctx_ctx_rel_probs.permute(2, 0, 1).contiguous()

            # compute phi(ei, ej)
            if self.mode == 'ment-norm':
                if self.ent_ent_comp == 'bilinear':
                    if self.ent_ent_comp == 'bilinear':
                        rel_ent_vecs = ent_vecs.view(1, n_ments, n_cands, -1) * self.rel_embs.view(n_rels, 1, 1, -1)
                    elif self.ent_ent_comp == 'trans_e':
                        rel_ent_vecs = ent_vecs.view(1, n_ments, n_cands, -1) - self.rel_embs.view(n_rels, 1, 1, -1)
                    else:
                        raise Exception("unknown ent_ent_comp")

                    rel_ent_ent_scores = torch.matmul(rel_ent_vecs.view(n_rels, n_ments, 1, n_cands, -1),
                                                      ent_vecs.view(1, 1, n_ments, n_cands, -1).permute(0, 1, 2, 4, 3))
                    # n_rels x n_ments x n_ments x n_cands x n_cands

                rel_ent_ent_scores = rel_ent_ent_scores.permute(0, 1, 3, 2, 4)  # n_rel x n_ments x n_cands x n_ments x n_cands
                rel_ent_ent_scores = (rel_ent_ent_scores * entity_mask).add_((entity_mask - 1).mul_(1e10))
                ent_ent_scores = torch.sum(rel_ent_ent_scores *
                                           rel_ctx_ctx_weights.view(n_rels, n_ments, 1, n_ments, 1), dim=0)\
                                 .mul(1. / n_rels)  # n_ments x n_cands x n_ments x n_cands

            elif self.mode == 'rel-norm':
                rel_vecs = torch.matmul(ctx_ctx_rel_probs.view(n_ments, n_ments, 1, n_rels),
                                        self.rel_embs.view(1, 1, n_rels, -1))\
                           .view(n_ments, n_ments, -1)
                ent_rel_vecs = ent_vecs.view(n_ments, 1, n_cands, -1) * rel_vecs.view(n_ments, n_ments, 1, -1)  # n_ments x n_ments x n_cands x dims
                ent_ent_scores = torch.matmul(ent_rel_vecs,
                                              ent_vecs.view(1, n_ments, n_cands, -1).permute(0, 1, 3, 2))\
                                 .permute(0, 2, 1, 3)

            if gold is None:
                # LBP
                prev_msgs = Variable(torch.zeros(n_ments, n_cands, n_ments).cuda())

                for _ in range(self.n_loops):
                    mask = 1 - Variable(torch.eye(n_ments).cuda())
                    ent_ent_votes = ent_ent_scores + local_ent_scores * 1 + \
                                    torch.sum(prev_msgs.view(1, n_ments, n_cands, n_ments) *
                                              mask.view(n_ments, 1, 1, n_ments), dim=3)\
                                    .view(n_ments, 1, n_ments, n_cands)
                    msgs, _ = torch.max(ent_ent_votes, dim=3)
                    msgs = (F.softmax(msgs, dim=1).mul(self.dr) +
                            prev_msgs.exp().mul(1 - self.dr)).log()
                    prev_msgs = msgs

                # compute marginal belief
                mask = 1 - Variable(torch.eye(n_ments).cuda())
                ent_scores = local_ent_scores * 1 + torch.sum(msgs * mask.view(n_ments, 1, n_ments), dim=2)
                ent_scores = F.softmax(ent_scores, dim=1)
            else:
                onehot_gold = Variable(torch.zeros(n_ments, n_cands).cuda()).scatter_(1, gold, 1)
                ent_scores = torch.sum(torch.sum(ent_ent_scores * onehot_gold, dim=3), dim=2)

        # combine with p_e_m
        inputs = torch.cat([ent_scores.view(n_ments * n_cands, -1),
                            torch.log(p_e_m + 1e-20).view(n_ments * n_cands, -1)], dim=1)
        scores = self.score_combine(inputs).view(n_ments, n_cands)

        if self.use_pad_ent:
            scores = scores[:-1]
        return scores

    def regularize(self, max_norm=1):
        super(MulRelRanker, self).regularize(max_norm)

    def loss(self, scores, true_pos, lamb=1e-7):
        loss = F.multi_margin_loss(scores, true_pos, margin=self.margin)
        if self.use_local_only:
            return loss

        # regularization
        X = F.normalize(self.rel_embs)
        diff = (X.view(self.n_rels, 1, -1) - X.view(1, self.n_rels, -1)).pow(2).sum(dim=2).add_(1e-5).sqrt()
        diff = diff * (diff < 1).float()
        loss -= torch.sum(diff).mul(lamb)

        X = F.normalize(self.ew_embs)
        diff = (X.view(self.n_rels, 1, -1) - X.view(1, self.n_rels, -1)).pow(2).sum(dim=2).add_(1e-5).sqrt()
        diff = diff * (diff < 1).float()
        loss -= torch.sum(diff).mul(lamb)
        return loss


In [6]:
import re

wiki_link_prefix = 'http://en.wikipedia.org/wiki/'


def read_csv_file(path):
    data = {}
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            comps = line.strip().split('\t')
            doc_name = comps[0] + ' ' + comps[1]
            mention = comps[2]
            lctx = comps[3]
            rctx = comps[4]

            if comps[6] != 'EMPTYCAND':
                cands = [c.split(',') for c in comps[6:-2]]
                cands = [(','.join(c[2:]).replace('"', '%22').replace(' ', '_'), float(c[1])) for c in cands]
            else:
                cands = []

            gold = comps[-1].split(',')
            if gold[0] == '-1':
                gold = (','.join(gold[2:]).replace('"', '%22').replace(' ', '_'), 1e-5, -1)
            else:
                gold = (','.join(gold[3:]).replace('"', '%22').replace(' ', '_'), 1e-5, -1)

            if doc_name not in data:
                data[doc_name] = []
            data[doc_name].append(
                {
                    'mention': mention,
                    'context': (lctx, rctx),
                    'candidates': cands,
                    'gold': gold
                })
    return data


def read_conll_file(data, path):
    conll = {}
    with open(path, 'r', encoding='utf8') as f:
        cur_sent = None
        cur_doc = None

        for line in f:
            line = line.strip()
            if line.startswith('-DOCSTART-'):
                docname = line.split()[1][1:]
                conll[docname] = {'sentences': [], 'mentions': []}
                cur_doc = conll[docname]
                cur_sent = []

            else:
                if line == '':
                    cur_doc['sentences'].append(cur_sent)
                    cur_sent = []

                else:
                    comps = line.split('\t')
                    tok = comps[0]
                    cur_sent.append(tok)

                    if len(comps) >=6 :
                        bi = comps[1]
                        wikilink = comps[4]
                        if bi == 'I':
                            cur_doc['mentions'][-1]['end'] += 1
                        else:
                            new_ment = {'sent_id': len(cur_doc['sentences']),
                                        'start': len(cur_sent) - 1,
                                        'end': len(cur_sent),
                                        'wikilink': wikilink}
                            cur_doc['mentions'].append(new_ment)

    # merge with data
    rmpunc = re.compile('[\W]+')
    for doc_name, content in data.items():
        conll_doc = conll[doc_name.split()[0]]
        content[0]['conll_doc'] = conll_doc

        cur_conll_m_id = 0
        for m in content:
            mention = m['mention']
            gold = m['gold']

            while True:
                cur_conll_m = conll_doc['mentions'][cur_conll_m_id]
                cur_conll_mention = ' '.join(conll_doc['sentences'][cur_conll_m['sent_id']][cur_conll_m['start']:cur_conll_m['end']])
                if rmpunc.sub('', cur_conll_mention.lower()) == rmpunc.sub('', mention.lower()):
                    m['conll_m'] = cur_conll_m
                    cur_conll_m_id += 1
                    break
                else:
                    cur_conll_m_id += 1

    return data


def load_person_names(path):
    data = []
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            data.append(line.strip().replace(' ', '_'))
    return set(data)


def find_coref(ment, mentlist, person_names):
    cur_m = ment['mention'].lower()
    coref = []
    for m in mentlist:
        if len(m['candidates']) == 0 or m['candidates'][0][0] not in person_names:
            continue

        mention = m['mention'].lower()
        start_pos = mention.find(cur_m)
        if start_pos == -1 or mention == cur_m:
            continue

        end_pos = start_pos + len(cur_m) - 1
        if (start_pos == 0 or mention[start_pos-1] == ' ') and \
                (end_pos == len(mention) - 1 or mention[end_pos + 1] == ' '):
            coref.append(m)

    return coref


def with_coref(dataset, person_names):
    for data_name, content in dataset.items():
        for cur_m in content:
            coref = find_coref(cur_m, content, person_names)
            if coref is not None and len(coref) > 0:
                cur_cands = {}
                for m in coref:
                    for c, p in m['candidates']:
                        cur_cands[c] = cur_cands.get(c, 0) + p
                for c in cur_cands.keys():
                    cur_cands[c] /= len(coref)
                cur_m['candidates'] = sorted(list(cur_cands.items()), key=lambda x: x[1])[::-1]


def dataset_eval(testset, system_pred):
    gold = []
    pred = []

    for doc_name, content in testset.items():
        gold += [c['gold'][0] for c in content]
        pred += [c['pred'][0] for c in system_pred[doc_name]]

    true_pos = 0
    for g, p in zip(gold, pred):
        if g == p and p != 'NIL':
            true_pos += 1

    precision = true_pos / len([p for p in pred if p != 'NIL'])
    recall = true_pos / len(gold)
    f1 = 2 * precision * recall / (precision + recall)
    return f1


class CoNLLDataset:
    """
    reading dataset from CoNLL dataset, extracted by https://github.com/dalab/deep-ed/
    """

    def __init__(self, path, person_path, conll_path):
        print('load csv')
        self.train = read_csv_file(path + '/aida_train.csv')
        self.testA = read_csv_file(path + '/aida_testA.csv')
        self.testB = read_csv_file(path + '/aida_testB.csv')
        self.ace2004 = read_csv_file(path + '/wned-ace2004.csv')
        self.aquaint = read_csv_file(path + '/wned-aquaint.csv')
        self.clueweb = read_csv_file(path + '/wned-clueweb.csv')
        self.msnbc = read_csv_file(path + '/wned-msnbc.csv')
        self.wikipedia = read_csv_file(path + '/wned-wikipedia.csv')
        self.wikipedia.pop('Jiří_Třanovský Jiří_Třanovský', None)

        print('process coref')
        person_names = load_person_names(person_path)
        with_coref(self.train, person_names)
        with_coref(self.testA, person_names)
        with_coref(self.testB, person_names)
        with_coref(self.ace2004, person_names)
        with_coref(self.aquaint, person_names)
        with_coref(self.clueweb, person_names)
        with_coref(self.msnbc, person_names)
        with_coref(self.wikipedia, person_names)

        print('load conll')
        read_conll_file(self.train, conll_path + '/AIDA/aida_train.txt')
        read_conll_file(self.testA, conll_path + '/AIDA/testa_testb_aggregate_original')
        read_conll_file(self.testB, conll_path + '/AIDA/testa_testb_aggregate_original')
        read_conll_file(self.ace2004, conll_path + '/wned-datasets/ace2004/ace2004.conll')
        read_conll_file(self.aquaint, conll_path + '/wned-datasets/aquaint/aquaint.conll')
        read_conll_file(self.msnbc, conll_path + '/wned-datasets/msnbc/msnbc.conll')
        read_conll_file(self.clueweb, conll_path + '/wned-datasets/clueweb/clueweb.conll')
        read_conll_file(self.wikipedia, conll_path + '/wned-datasets/wikipedia/wikipedia.conll')


#if __name__ == "__main__":
#    path = '/datastore/ple/workspace/nel/preprocess_data/data/generated/test_train_data/'
#    conll_path = '/datastore/ple/workspace/nel/preprocess_data/data/basic_data/test_datasets/'
#    person_path = '/datastore/ple/workspace/nel/preprocess_data/data/basic_data/p_e_m_data/persons.txt'

#    dataset = CoNLLDataset(path, person_path, conll_path)
    # from pprint import pprint
    # pprint(dataset.ace2004, width=200)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class NTEE(AbstractWordEntity):
    """
    NTEE model, proposed in Yamada et al. "Learning Distributed Representations of Texts and Entities from Knowledge Base"
    """

    def __init__(self, config):
        config['word_embeddings_class'] = nn.EmbeddingBag
        config['entity_embeddings_class'] = nn.Embedding
        super(NTEE, self).__init__(config)
        self.linear = nn.Linear(self.emb_dims, self.emb_dims)

    def compute_sent_vecs(self, token_ids, token_offsets, use_sum=False):
        sum_vecs = self.word_embeddings(token_ids, token_offsets)
        if use_sum:
            return sum_vecs

        sum_vecs = F.normalize(sum_vecs)
        sent_vecs = self.linear(sum_vecs)
        return sent_vecs

    def forward(self, token_ids, token_offsets, entity_ids, use_sum=False):
        sent_vecs = self.compute_sent_vecs(token_ids, token_offsets, use_sum)
        entity_vecs = self.entity_embeddings(entity_ids)

        # compute scores
        batchsize, dims = sent_vecs.size()
        n_entities = entity_vecs.size(1)
        scores = torch.bmm(entity_vecs, sent_vecs.view(batchsize, dims, 1)).view(batchsize, n_entities)

        log_probs = F.log_softmax(scores, dim=1)
        return log_probs

    def predict(self, token_ids, token_offsets, entity_ids, gold_entity_ids=None):
        log_probs = self.forward(token_ids, token_offsets, entity_ids)
        _, pred_entity_ids = torch.max(log_probs, dim=1)

        acc = None
        if gold_entity_ids is not None:
            acc = torch.eq(gold_entity_ids, pred_entity_ids).sum()
        return pred_entity_ids, acc

    def loss(self, log_probs, true_pos):
        return F.nll_loss(log_probs, true_pos)


def create_ntee_from_components(dir_path):
    word_dict_path = dir_path + '/dict.word'
    word_embs_path = dir_path + '/word_embeddings.npy'
    entity_dict_path = dir_path + '/dict.entity'
    entity_embs_path = dir_path + '/entity_embeddings.npy'
    W_path = dir_path + '/W.npy'
    b_path = dir_path + '/b.npy'

    print('load voca and embeddings')
    word_voca, word_embs = load_voca_embs(word_dict_path, word_embs_path)
    entity_voca, entity_embs = load_voca_embs(entity_dict_path, entity_embs_path)
    config = {'word_embeddings': word_embs,
              'entity_embeddings': entity_embs,
              'word_voca':  word_voca,
              'entity_voca': entity_voca,
              'emb_dims': word_embs.shape[1]}
    print(word_embs.shape, entity_embs.shape)

    # create model
    print('create model')
    model = NTEE(config)

    W = np.load(W_path)
    b = np.load(b_path)
    model.linear.weight = nn.Parameter(torch.FloatTensor(W).t())
    model.linear.bias = nn.Parameter(torch.FloatTensor(b))

    return model


In [8]:
## ed ranker

import torch
from torch.autograd import Variable
import numpy as np

from random import shuffle
import torch.optim as optim



ModelClass = MulRelRanker
wiki_prefix = 'en.wikipedia.org/wiki/'


class EDRanker:
    """
    ranking candidates
    """

    def __init__(self, config):
        print('--- create model ---')

        config['entity_embeddings'] = config['entity_embeddings'] / \
                                      np.maximum(np.linalg.norm(config['entity_embeddings'],
                                                                axis=1, keepdims=True), 1e-12)
        config['entity_embeddings'][config['entity_voca'].unk_id] = 1e-10
        config['word_embeddings'] = config['word_embeddings'] / \
                                    np.maximum(np.linalg.norm(config['word_embeddings'],
                                                              axis=1, keepdims=True), 1e-12)
        config['word_embeddings'][config['word_voca'].unk_id] = 1e-10

        print('prerank model')
        self.prerank_model = NTEE(config)
        self.args = config['args']

        print('main model')
        if self.args.mode == 'eval':
            print('try loading model from', self.args.model_path)
            self.model = load_model(self.args.model_path, ModelClass)
        else:
            print('create new model')
            if config['mulrel_type'] == 'rel-norm':
                config['use_stargmax'] = False
            if config['mulrel_type'] == 'ment-norm':
                config['first_head_uniform'] = False
                config['use_pad_ent'] = True

            config['use_local'] = True
            config['use_local_only'] = False
            config['oracle'] = False
            self.model = ModelClass(config)

        self.prerank_model.cuda()
        self.model.cuda()

    def prerank(self, dataset, predict=False):
        new_dataset = []
        has_gold = 0
        total = 0

        for content in dataset:
            items = []

            if self.args.keep_ctx_ent > 0:
                # rank the candidates by ntee scores
                lctx_ids = [m['context'][0][max(len(m['context'][0]) - self.args.prerank_ctx_window // 2, 0):]
                            for m in content]
                rctx_ids = [m['context'][1][:min(len(m['context'][1]), self.args.prerank_ctx_window // 2)]
                            for m in content]
                ment_ids = [[] for m in content]
                token_ids = [l + m + r if len(l) + len(r) > 0 else [self.prerank_model.word_voca.unk_id]
                             for l, m, r in zip(lctx_ids, ment_ids, rctx_ids)]

                entity_ids = [m['cands'] for m in content]
                entity_ids = Variable(torch.LongTensor(entity_ids).cuda())

                entity_mask = [m['mask'] for m in content]
                entity_mask = Variable(torch.FloatTensor(entity_mask).cuda())

                token_ids, token_offsets = flatten_list_of_lists(token_ids)
                token_offsets = Variable(torch.LongTensor(token_offsets).cuda())
                token_ids = Variable(torch.LongTensor(token_ids).cuda())

                log_probs = self.prerank_model.forward(token_ids, token_offsets, entity_ids, use_sum=True)
                log_probs = (log_probs * entity_mask).add_((entity_mask - 1).mul_(1e10))
                _, top_pos = torch.topk(log_probs, dim=1, k=self.args.keep_ctx_ent)
                top_pos = top_pos.data.cpu().numpy()

            else:
                top_pos = [[]] * len(content)

            # select candidats: mix between keep_ctx_ent best candidates (ntee scores) with
            # keep_p_e_m best candidates (p_e_m scores)
            for i, m in enumerate(content):
                sm = {'cands': [],
                      'named_cands': [],
                      'p_e_m': [],
                      'mask': [],
                      'true_pos': -1}
                m['selected_cands'] = sm

                selected = set(top_pos[i])
                idx = 0
                while len(selected) < self.args.keep_ctx_ent + self.args.keep_p_e_m:
                    if idx not in selected:
                        selected.add(idx)
                    idx += 1

                selected = sorted(list(selected))
                for idx in selected:
                    sm['cands'].append(m['cands'][idx])
                    sm['named_cands'].append(m['named_cands'][idx])
                    sm['p_e_m'].append(m['p_e_m'][idx])
                    sm['mask'].append(m['mask'][idx])
                    if idx == m['true_pos']:
                        sm['true_pos'] = len(sm['cands']) - 1

                if not predict:
                    if sm['true_pos'] == -1:
                        continue

                items.append(m)
                if sm['true_pos'] >= 0:
                    has_gold += 1
                total += 1

                if predict:
                    # only for oracle model, not used for eval
                    if sm['true_pos'] == -1:
                        sm['true_pos'] = 0  # a fake gold, happens only 2%, but avoid the non-gold

            if len(items) > 0:
                new_dataset.append(items)

        print('recall', has_gold / total)
        return new_dataset

    def get_data_items(self, dataset, predict=False):
        data = []
        cand_source = 'candidates'

        for doc_name, content in dataset.items():
            items = []
            conll_doc = content[0].get('conll_doc', None)

            for m in content:
                try:
                    named_cands = [c[0] for c in m[cand_source]]
                    p_e_m = [min(1., max(1e-3, c[1])) for c in m[cand_source]]
                except:
                    named_cands = [c[0] for c in m['candidates']]
                    p_e_m = [min(1., max(1e-3, c[1])) for c in m['candidates']]

                try:
                    true_pos = named_cands.index(m['gold'][0])
                    p = p_e_m[true_pos]
                except:
                    true_pos = -1

                named_cands = named_cands[:min(self.args.n_cands_before_rank, len(named_cands))]
                p_e_m = p_e_m[:min(self.args.n_cands_before_rank, len(p_e_m))]

                if true_pos >= len(named_cands):
                    if not predict:
                        true_pos = len(named_cands) - 1
                        p_e_m[-1] = p
                        named_cands[-1] = m['gold'][0]
                    else:
                        true_pos = -1

                cands = [self.model.entity_voca.get_id(wiki_prefix + c) for c in named_cands]
                mask = [1.] * len(cands)
                if len(cands) == 0 and not predict:
                    continue
                elif len(cands) < self.args.n_cands_before_rank:
                    cands += [self.model.entity_voca.unk_id] * (self.args.n_cands_before_rank - len(cands))
                    named_cands += [Vocabulary.unk_token] * (self.args.n_cands_before_rank - len(named_cands))
                    p_e_m += [1e-8] * (self.args.n_cands_before_rank - len(p_e_m))
                    mask += [0.] * (self.args.n_cands_before_rank - len(mask))

                lctx = m['context'][0].strip().split()
                lctx_ids = [self.prerank_model.word_voca.get_id(t) for t in lctx if is_important_word(t)]
                lctx_ids = [tid for tid in lctx_ids if tid != self.prerank_model.word_voca.unk_id]
                lctx_ids = lctx_ids[max(0, len(lctx_ids) - self.args.ctx_window//2):]

                rctx = m['context'][1].strip().split()
                rctx_ids = [self.prerank_model.word_voca.get_id(t) for t in rctx if is_important_word(t)]
                rctx_ids = [tid for tid in rctx_ids if tid != self.prerank_model.word_voca.unk_id]
                rctx_ids = rctx_ids[:min(len(rctx_ids), self.args.ctx_window//2)]

                ment = m['mention'].strip().split()
                ment_ids = [self.prerank_model.word_voca.get_id(t) for t in ment if is_important_word(t)]
                ment_ids = [tid for tid in ment_ids if tid != self.prerank_model.word_voca.unk_id]

                m['sent'] = ' '.join(lctx + rctx)

                # secondary local context (for computing relation scores)
                if conll_doc is not None:
                    conll_m = m['conll_m']
                    sent = conll_doc['sentences'][conll_m['sent_id']]
                    start = conll_m['start']
                    end = conll_m['end']

                    snd_lctx = [self.model.snd_word_voca.get_id(t)
                                for t in sent[max(0, start - self.args.snd_local_ctx_window//2):start]]
                    snd_rctx = [self.model.snd_word_voca.get_id(t)
                                for t in sent[end:min(len(sent), end + self.args.snd_local_ctx_window//2)]]
                    snd_ment = [self.model.snd_word_voca.get_id(t)
                                for t in sent[start:end]]

                    if len(snd_lctx) == 0:
                        snd_lctx = [self.model.snd_word_voca.unk_id]
                    if len(snd_rctx) == 0:
                        snd_rctx = [self.model.snd_word_voca.unk_id]
                    if len(snd_ment) == 0:
                        snd_ment = [self.model.snd_word_voca.unk_id]
                else:
                    snd_lctx = [self.model.snd_word_voca.unk_id]
                    snd_rctx = [self.model.snd_word_voca.unk_id]
                    snd_ment = [self.model.snd_word_voca.unk_id]

                items.append({'context': (lctx_ids, rctx_ids),
                              'snd_ctx': (snd_lctx, snd_rctx),
                              'ment_ids': ment_ids,
                              'snd_ment': snd_ment,
                              'cands': cands,
                              'named_cands': named_cands,
                              'p_e_m': p_e_m,
                              'mask': mask,
                              'true_pos': true_pos,
                              'doc_name': doc_name,
                              'raw': m
                              })

            if len(items) > 0:
                # note: this shouldn't affect the order of prediction because we use doc_name to add predicted entities,
                # and we don't shuffle the data for prediction
                if len(items) > 100:
                    print(len(items))
                    for k in range(0, len(items), 100):
                        data.append(items[k:min(len(items), k + 100)])
                else:
                    data.append(items)

        return self.prerank(data, predict)

    def train(self, org_train_dataset, org_dev_datasets, config):
        print('extracting training data')
        train_dataset = self.get_data_items(org_train_dataset, predict=False)
        print('#train docs', len(train_dataset))

        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append((dname, self.get_data_items(data, predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        print('creating optimizer')
        optimizer = optim.Adam([p for p in self.model.parameters() if p.requires_grad], lr=config['lr'])
        best_f1 = -1
        not_better_count = 0
        is_counting = False
        eval_after_n_epochs = self.args.eval_after_n_epochs

        for e in range(config['n_epochs']):
            shuffle(train_dataset)

            total_loss = 0
            for dc, batch in enumerate(train_dataset):  # each document is a minibatch
                self.model.train()
                optimizer.zero_grad()

                # convert data items to pytorch inputs
                token_ids = [m['context'][0] + m['context'][1]
                             if len(m['context'][0]) + len(m['context'][1]) > 0
                             else [self.model.word_voca.unk_id]
                             for m in batch]
                s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
                s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
                s_mtoken_ids = [m['snd_ment'] for m in batch]

                entity_ids = Variable(torch.LongTensor([m['selected_cands']['cands'] for m in batch]).cuda())
                true_pos = Variable(torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).cuda())
                p_e_m = Variable(torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).cuda())
                entity_mask = Variable(torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).cuda())

                token_ids, token_mask = make_equal_len(token_ids, self.model.word_voca.unk_id)
                s_ltoken_ids, s_ltoken_mask = make_equal_len(s_ltoken_ids, self.model.snd_word_voca.unk_id,
                                                                   to_right=False)
                s_rtoken_ids, s_rtoken_mask = make_equal_len(s_rtoken_ids, self.model.snd_word_voca.unk_id)
                s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
                s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
                s_mtoken_ids, s_mtoken_mask = make_equal_len(s_mtoken_ids, self.model.snd_word_voca.unk_id)

                token_ids = Variable(torch.LongTensor(token_ids).cuda())
                token_mask = Variable(torch.FloatTensor(token_mask).cuda())
                # too ugly but too lazy to fix it
                self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
                self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
                self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
                self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
                self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
                self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

                scores = self.model.forward(token_ids, token_mask, entity_ids, entity_mask, p_e_m,
                                            gold=true_pos.view(-1, 1))
                loss = self.model.loss(scores, true_pos)

                loss.backward()
                optimizer.step()
                self.model.regularize(max_norm=100)

                loss = loss.cpu().data.numpy()
                total_loss += loss
                print('epoch', e, "%0.2f%%" % (dc/len(train_dataset) * 100), loss, end='\r')

            print('epoch', e, 'total loss', total_loss, total_loss / len(train_dataset))

            if (e + 1) % eval_after_n_epochs == 0:
                dev_f1 = 0
                for di, (dname, data) in enumerate(dev_datasets):
                    predictions = self.predict(data)
                    f1 = dataset_eval(org_dev_datasets[di][1], predictions)
                    print(dname, tokgreen('micro F1: ' + str(f1)))

                    if dname == 'aida-A':
                        dev_f1 = f1

                if config['lr'] == 1e-4 and dev_f1 >= self.args.dev_f1_change_lr:
                    eval_after_n_epochs = 2
                    is_counting = True
                    best_f1 = dev_f1
                    not_better_count = 0

                    config['lr'] = 1e-5
                    print('change learning rate to', config['lr'])
                    if self.args.mulrel_type == 'rel-norm':
                        optimizer = optim.Adam([p for p in self.model.parameters() if p.requires_grad], lr=config['lr'])
                    elif self.args.mulrel_type == 'ment-norm':
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = config['lr']

                if is_counting:
                    if dev_f1 < best_f1:
                        not_better_count += 1
                    else:
                        not_better_count = 0
                        best_f1 = dev_f1
                        print('save model to', self.args.model_path)
                        self.model.save(self.args.model_path)

                if not_better_count == self.args.n_not_inc:
                    break

                self.model.print_weight_norm()

    def predict(self, data):
        predictions = {items[0]['doc_name']: [] for items in data}
        self.model.eval()

        for batch in data:  # each document is a minibatch
            token_ids = [m['context'][0] + m['context'][1]
                         if len(m['context'][0]) + len(m['context'][1]) > 0
                         else [self.model.word_voca.unk_id]
                         for m in batch]
            s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
            s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
            s_mtoken_ids = [m['snd_ment'] for m in batch]

            lctx_ids = s_ltoken_ids
            rctx_ids = s_rtoken_ids
            m_ids = s_mtoken_ids

            entity_ids = Variable(torch.LongTensor([m['selected_cands']['cands'] for m in batch]).cuda())
            p_e_m = Variable(torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).cuda())
            entity_mask = Variable(torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).cuda())
            true_pos = Variable(torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).cuda())

            token_ids, token_mask = make_equal_len(token_ids, self.model.word_voca.unk_id)
            s_ltoken_ids, s_ltoken_mask = make_equal_len(s_ltoken_ids, self.model.snd_word_voca.unk_id,
                                                               to_right=False)
            s_rtoken_ids, s_rtoken_mask = make_equal_len(s_rtoken_ids, self.model.snd_word_voca.unk_id)
            s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
            s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = make_equal_len(s_mtoken_ids, self.model.snd_word_voca.unk_id)

            token_ids = Variable(torch.LongTensor(token_ids).cuda())
            token_mask = Variable(torch.FloatTensor(token_mask).cuda())
            # too ugly, but too lazy to fix it
            self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
            self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
            self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
            self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
            self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
            self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

            scores = self.model.forward(token_ids, token_mask, entity_ids, entity_mask, p_e_m,
                                        gold=true_pos.view(-1, 1))
            scores = scores.cpu().data.numpy()

            # print out relation weights
            if self.args.mode == 'eval' and self.args.print_rel:
                print('================================')
                weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
                voca = self.model.snd_word_voca
                for i in range(len(batch)):
                    print(' '.join([voca.id2word[id] for id in lctx_ids[i]]),
                          tokgreen(' '.join([voca.id2word[id] for id in m_ids[i]])),
                          ' '.join([voca.id2word[id] for id in rctx_ids[i]]))
                    for j in range(len(batch)):
                        if i == j:
                            continue
                        np.set_printoptions(precision=2)
                        print('\t', weights[:, i, j], '\t',
                              ' '.join([voca.id2word[id] for id in lctx_ids[j]]),
                              tokgreen(' '.join([voca.id2word[id] for id in m_ids[j]])),
                              ' '.join([voca.id2word[id] for id in rctx_ids[j]]))

            pred_ids = np.argmax(scores, axis=1)
            pred_entities = [m['selected_cands']['named_cands'][i] if m['selected_cands']['mask'][i] == 1
                             else (m['selected_cands']['named_cands'][0] if m['selected_cands']['mask'][0] == 1 else 'NIL')
                             for (i, m) in zip(pred_ids, batch)]
            doc_names = [m['doc_name'] for m in batch]

            if self.args.mode == 'eval' and self.args.print_incorrect:
                gold = [item['selected_cands']['named_cands'][item['selected_cands']['true_pos']]
                        if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN' for item in batch]
                pred = pred_entities
                for i in range(len(gold)):
                    if gold[i] != pred[i]:
                        print('--------------------------------------------')
                        pprint(batch[i]['raw'])
                        print(gold[i], pred[i])

            for dname, entity in zip(doc_names, pred_entities):
                predictions[dname].append({'pred': (entity, 0.)})

        return predictions


In [9]:
##filter word2vec

import sys
import numpy as np 



core_voca_path = sys.argv[1]
word_embs_dir = sys.argv[2]

print('load core voca from', core_voca_path)
core_voca = Vocabulary.load(core_voca_path)

print('load full voca and embs')
full_voca, full_embs = load_voca_embs(
    word_embs_dir + '/all_dict.word', word_embs_dir + '/all_word_embeddings.npy')

print('select word ids')
selected = []
for word in core_voca.id2word: 
    word_id = full_voca.word2id.get(word, -1)
    if word_id >= 0: 
        selected.append(word_id)

print('save...')
selected_embs = full_embs[selected, :]
np.save(word_embs_dir + '/word_embeddings', selected_embs)

with open(word_embs_dir + '/dict.word', 'w', encoding='utf8') as f:
    for i in selected: 
        f.write(full_voca.id2word[i] + '\t1000\n')


load core voca from -f


FileNotFoundError: [Errno 2] No such file or directory: '-f'

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocalCtxAttRanker(AbstractWordEntity):
    """
    local model with context token attention (from G&H's EMNLP paper)
    """

    def __init__(self, config):
        config['word_embeddings_class'] = nn.Embedding
        config['entity_embeddings_class'] = nn.Embedding
        super(LocalCtxAttRanker, self).__init__(config)

        self.hid_dims = config['hid_dims']
        self.tok_top_n = config['tok_top_n']
        self.margin = config['margin']

        self.att_mat_diag = nn.Parameter(torch.ones(self.emb_dims))
        self.tok_score_mat_diag = nn.Parameter(torch.ones(self.emb_dims))
        self.local_ctx_dr = nn.Dropout(p=0)

        self.score_combine_linear_1 = nn.Linear(2, self.hid_dims)
        self.score_combine_act_1 = nn.ReLU()
        self.score_combine_linear_2 = nn.Linear(self.hid_dims, 1)

    def print_weight_norm(self):
        print('att_mat_diag', self.att_mat_diag.data.norm())
        print('tok_score_mat_diag', self.tok_score_mat_diag.data.norm())
        print('f - l1.w, b', self.score_combine_linear_1.weight.data.norm(),  self.score_combine_linear_1.bias.data.norm())
        print('f - l2.w, b', self.score_combine_linear_2.weight.data.norm(),  self.score_combine_linear_2.bias.data.norm())

    def print_attention(self, gold_pos):
        token_ids = self._token_ids.data.cpu().numpy()
        entity_ids = self._entity_ids.data.cpu().numpy()
        att_probs = self._att_probs.data.cpu().numpy()
        top_tok_att_ids = self._top_tok_att_ids.data.cpu().numpy()
        gold_pos = gold_pos.data.cpu().numpy()
        scores = self._scores.data.cpu().numpy()

        print('===========================================')
        for tids, eids, ap, aids, gpos, ss in zip(token_ids, entity_ids, att_probs, top_tok_att_ids, gold_pos, scores):
            selected_tids = tids[aids]
            print('-------------------------------')
            print(tokgreen(repr([(self.entity_voca.id2word[e], s) for e, s in zip(eids, ss)])),
                  tokblue(repr(self.entity_voca.id2word[eids[gpos]] if gpos > -1 else 'UNKNOWN')))
            print([(self.word_voca.id2word[t], a[0]) for t, a in zip(selected_tids, ap)])

    def forward(self, token_ids, tok_mask, entity_ids, entity_mask, p_e_m=None):
        batchsize, n_words = token_ids.size()
        n_entities = entity_ids.size(1)
        tok_mask = tok_mask.view(batchsize, 1, -1)

        tok_vecs = self.word_embeddings(token_ids)
        entity_vecs = self.entity_embeddings(entity_ids)

        # att
        ent_tok_att_scores = torch.bmm(entity_vecs * self.att_mat_diag, tok_vecs.permute(0, 2, 1))
        ent_tok_att_scores = (ent_tok_att_scores * tok_mask).add_((tok_mask - 1).mul_(1e10))
        tok_att_scores, _ = torch.max(ent_tok_att_scores, dim=1)
        top_tok_att_scores, top_tok_att_ids = torch.topk(tok_att_scores, dim=1, k=min(self.tok_top_n, n_words))
        att_probs = F.softmax(top_tok_att_scores, dim=1).view(batchsize, -1, 1)
        att_probs = att_probs / torch.sum(att_probs, dim=1, keepdim=True)

        selected_tok_vecs = torch.gather(tok_vecs, dim=1,
                                         index=top_tok_att_ids.view(batchsize, -1, 1).repeat(1, 1, tok_vecs.size(2)))
        ctx_vecs = torch.sum((selected_tok_vecs * self.tok_score_mat_diag) * att_probs, dim=1, keepdim=True)
        ctx_vecs = self.local_ctx_dr(ctx_vecs)
        ent_ctx_scores = torch.bmm(entity_vecs, ctx_vecs.permute(0, 2, 1)).view(batchsize, n_entities)

        # combine with p(e|m) if p_e_m is not None
        if p_e_m is not None:
            inputs = torch.cat([ent_ctx_scores.view(batchsize * n_entities, -1),
                                torch.log(p_e_m + 1e-20).view(batchsize * n_entities, -1)], dim=1)
            hidden = self.score_combine_linear_1(inputs)
            hidden = self.score_combine_act_1(hidden)
            scores = self.score_combine_linear_2(hidden).view(batchsize, n_entities)
        else:
            scores = ent_ctx_scores

        scores = (scores * entity_mask).add_((entity_mask - 1).mul_(1e10))

        # printing attention (debugging)
        self._token_ids = token_ids
        self._entity_ids = entity_ids
        self._att_probs = att_probs
        self._top_tok_att_ids = top_tok_att_ids
        self._scores = scores

        self._entity_vecs = entity_vecs
        self._local_ctx_vecs = ctx_vecs

        return scores

    def regularize(self, max_norm=1):
        l1_w_norm = self.score_combine_linear_1.weight.norm()
        l1_b_norm = self.score_combine_linear_1.bias.norm()
        l2_w_norm = self.score_combine_linear_2.weight.norm()
        l2_b_norm = self.score_combine_linear_2.bias.norm()

        if (l1_w_norm > max_norm).data.all():
            self.score_combine_linear_1.weight.data = self.score_combine_linear_1.weight.data * max_norm / l1_w_norm.data
        if (l1_b_norm > max_norm).data.all():
            self.score_combine_linear_1.bias.data = self.score_combine_linear_1.bias.data *  max_norm / l1_b_norm.data
        if (l2_w_norm > max_norm).data.all():
            self.score_combine_linear_2.weight.data = self.score_combine_linear_2.weight.data * max_norm / l2_w_norm.data
        if (l2_b_norm > max_norm).data.all():
            self.score_combine_linear_2.bias.data = self.score_combine_linear_2.bias.data *  max_norm / l2_b_norm.data

    def loss(self, scores, true_pos):
        loss = F.multi_margin_loss(scores, true_pos, margin=self.margin)
        return loss


In [11]:
import time

totaltime = {}
start_at = {}


def tik(name):
    start_at[name] = int(round(time.time() * 1000))


def tok(name):
    if name not in start_at:
        raise Exception("not tik yet")
    if name not in totaltime:
        totaltime[name] = 0.
    totaltime[name] += int(round(time.time() * 1000)) - start_at[name]


def print_time(name=None):
    print('------- running time -------')
    if name is not None:
        print(name, totaltime[name])
    else:
        for name,t in totaltime.items():
            print('---', name, t)
    print('---------------------------')


def reset():
    global totaltime
    global start_at
    totaltime = {}
    start_at = {}


In [12]:
from pprint import pprint

import argparse

parser = argparse.ArgumentParser()


datadir = 'data/generated/test_train_data/'
conll_path = 'data/basic_data/test_datasets/'
person_path = 'data/basic_data/p_e_m_data/persons.txt'
voca_emb_dir = 'data/generated/embeddings/word_ent_embs/'

ModelClass = MulRelRanker


# general args
parser.add_argument("--mode", type=str,
                    help="train or eval",
                    default='train')
parser.add_argument("--model_path", type=str,
                    help="model path to save/load",
                    default='')

# args for preranking (i.e. 2-step candidate selection)
parser.add_argument("--n_cands_before_rank", type=int,
                    help="number of candidates",
                    default=30)
parser.add_argument("--prerank_ctx_window", type=int,
                    help="size of context window for the preranking model",
                    default=50)
parser.add_argument("--keep_p_e_m", type=int,
                    help="number of top candidates to keep w.r.t p(e|m)",
                    default=4)
parser.add_argument("--keep_ctx_ent", type=int,
                    help="number of top candidates to keep w.r.t using context",
                    default=4)

# args for local model
parser.add_argument("--ctx_window", type=int,
                    help="size of context window for the local model",
                    default=100)
parser.add_argument("--tok_top_n", type=int,
                    help="number of top contextual words for the local model",
                    default=25)


# args for global model
parser.add_argument("--mulrel_type", type=str,
                    help="type for multi relation (rel-norm or ment-norm)",
                    default='ment-norm')
parser.add_argument("--n_rels", type=int,
                    help="number of relations",
                    default=5)
parser.add_argument("--hid_dims", type=int,
                    help="number of hidden neurons",
                    default=100)
parser.add_argument("--snd_local_ctx_window", type=int,
                    help="local ctx window size for relation scores",
                    default=6)
parser.add_argument("--dropout_rate", type=float,
                    help="dropout rate for relation scores",
                    default=0.3)


# args for training
parser.add_argument("--n_epochs", type=int,
                    help="max number of epochs",
                    default=200)
parser.add_argument("--dev_f1_change_lr", type=float,
                    help="dev f1 to change learning rate",
                    default=0.915)
parser.add_argument("--n_not_inc", type=int,
                    help="number of evals after dev f1 not increase",
                    default=10)
parser.add_argument("--eval_after_n_epochs", type=int,
                    help="number of epochs to eval",
                    default=5)
parser.add_argument("--learning_rate", type=float,
                    help="learning rate",
                    default=1e-4)
parser.add_argument("--margin", type=float,
                    help="margin",
                    default=0.01)

# args for LBP
parser.add_argument("--df", type=float,
                    help="dumpling factor (for LBP)",
                    default=0.5)
parser.add_argument("--n_loops", type=int,
                    help="number of LBP loops",
                    default=10)

# args for debugging
parser.add_argument("--print_rel", action='store_true')
parser.add_argument("--print_incorrect", action='store_true')

args, unknown = parser.parse_known_args()


In [13]:
if __name__ == "__main__":
    print('load conll at', datadir)
    conll = CoNLLDataset(datadir, person_path, conll_path)

    print('create model')
    word_voca, word_embeddings = load_voca_embs(voca_emb_dir + 'dict.word',
                                                      voca_emb_dir + 'word_embeddings.npy')
    print('word voca size', word_voca.size())
    snd_word_voca, snd_word_embeddings = load_voca_embs(voca_emb_dir + '/glove/dict.word',
                                                              voca_emb_dir + '/glove/word_embeddings.npy')
    print('snd word voca size', snd_word_voca.size())

    entity_voca, entity_embeddings = load_voca_embs(voca_emb_dir + 'dict.entity',
                                                          voca_emb_dir + 'entity_embeddings.npy')
    config = {'hid_dims': args.hid_dims,
              'emb_dims': entity_embeddings.shape[1],
              'freeze_embs': True,
              'tok_top_n': args.tok_top_n,
              'margin': args.margin,
              'word_voca': word_voca,
              'entity_voca': entity_voca,
              'word_embeddings': word_embeddings,
              'entity_embeddings': entity_embeddings,
              'snd_word_voca': snd_word_voca,
              'snd_word_embeddings': snd_word_embeddings,
              'dr': args.dropout_rate,
              'args': args}

    if ModelClass == MulRelRanker:
        config['df'] = args.df
        config['n_loops'] = args.n_loops
        config['n_rels'] = args.n_rels
        config['mulrel_type'] = args.mulrel_type
    else:
        raise Exception('unknown model class')

    pprint(config)
    ranker = EDRanker(config=config)

    dev_datasets = [('aida-A', conll.testA),
                    ('aida-B', conll.testB),
                    ('msnbc', conll.msnbc),
                    ('aquaint', conll.aquaint),
                    ('ace2004', conll.ace2004),
                    ('clueweb', conll.clueweb),
                    ('wikipedia', conll.wikipedia)
                    ]

    print('training...')
    config = {'lr': args.learning_rate, 'n_epochs': args.n_epochs}
    pprint(config)
    ranker.train(conll.train, dev_datasets, config)

    print("\nEvaluating...")
    org_dev_datasets = dev_datasets  # + [('aida-train', conll.train)]
    dev_datasets = []
    for dname, data in org_dev_datasets:
        dev_datasets.append((dname, ranker.get_data_items(data, predict=True)))
        print(dname, '#dev docs', len(dev_datasets[-1][1]))

    vecs = ranker.model.rel_embs.cpu().data.numpy()

    for di, (dname, data) in enumerate(dev_datasets):
        ranker.model._coh_ctx_vecs = []
        predictions = ranker.predict(data)
        print(dname, tokgreen('micro F1: ' + str(dataset_eval(org_dev_datasets[di][1], predictions))))

load conll at data/generated/test_train_data/
load csv
process coref
load conll
create model
word voca size 492408
snd word voca size 60862
{'args': Namespace(mode='train', model_path='', n_cands_before_rank=30, prerank_ctx_window=50, keep_p_e_m=4, keep_ctx_ent=4, ctx_window=100, tok_top_n=25, mulrel_type='ment-norm', n_rels=5, hid_dims=100, snd_local_ctx_window=6, dropout_rate=0.3, n_epochs=200, dev_f1_change_lr=0.915, n_not_inc=10, eval_after_n_epochs=5, learning_rate=0.0001, margin=0.01, df=0.5, n_loops=10, print_rel=False, print_incorrect=False),
 'df': 0.5,
 'dr': 0.3,
 'emb_dims': 300,
 'entity_embeddings': array([[ 0.06      , -0.075     ,  0.014     , ...,  0.083     ,
        -0.02      , -0.031     ],
       [-0.1       ,  0.058     ,  0.041     , ..., -0.075     ,
         0.089     , -0.047     ],
       [ 0.014     ,  0.062     ,  0.028     , ..., -0.026     ,
         0.06      , -0.068     ],
       ...,
       [ 0.014     ,  0.018     , -0.025     , ...,  0.019     ,
  

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.r

epoch 0 0.63% 0.009599868epoch 0 0.73% 0.030653555epoch 0 0.84% -3.1622776e-09epoch 0 0.94% -3.1622776e-09epoch 0 1.05% -3.1622776e-09epoch 0 1.15% 0.0090118535epoch 0 1.26% 0.010226317epoch 0 1.36% -3.1622776e-09epoch 0 1.47% 0.011400476epoch 0 1.57% 0.01960607

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 3.88% 0.003629722409

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.r

epoch 0 6.82% 0.017629059049

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 6.93% 0.015093019epoch 0 7.03% 0.0060781203epoch 0 7.14% 0.0023291365epoch 0 7.24% 0.002380739epoch 0 7.35% -3.1622776e-09epoch 0 7.45% -3.1622776e-09epoch 0 7.56% 0.009906992epoch 0 7.66% -3.1622776e-09epoch 0 7.76% 0.0021020146epoch 0 7.87% -3.1622776e-09epoch 0 7.97% -3.1622776e-09epoch 0 8.08% 0.0017806442epoch 0 8.18% 0.0014314814

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 10.91% -3.1622776e-09

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 11.02% 0.00017395771epoch 0 11.12% -3.1622776e-09epoch 0 11.23% -3.1622776e-09epoch 0 11.33% 0.03140534epoch 0 11.44% 0.001059558epoch 0 11.54% 0.0013248508epoch 0 11.65% 0.00068837014epoch 0 11.75% 0.03612699epoch 0 11.86% 0.0003128293epoch 0 11.96% 0.0012280152epoch 0 12.07% 0.006698726epoch 0 12.17% -3.1622776e-09epoch 0 12.28% -3.1622776e-09

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 12.38% 0.0005999954epoch 0 12.49% 0.015689798epoch 0 12.59% 0.019660551epoch 0 12.70% 0.0799461epoch 0 12.80% -3.1622776e-09epoch 0 12.91% 0.02307387epoch 0 13.01% 0.001462177epoch 0 13.12% 0.030512244epoch 0 13.22% -3.1622776e-09

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 16.79% 0.002718228609

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 16.89% 0.014491152epoch 0 17.00% 0.0036770222epoch 0 17.10% 7.5256845e-05epoch 0 17.21% 0.00042263296epoch 0 17.31% -3.1622776e-09epoch 0 17.42% 0.0010779361epoch 0 17.52% 0.0011833672epoch 0 17.63% 0.005364581epoch 0 17.73% 0.007966773epoch 0 17.84% -3.1622776e-09epoch 0 17.94% 0.0085604265epoch 0 18.05% 0.005338028epoch 0 18.15% 0.0018571042epoch 0 18.26% 0.007395204epoch 0 18.36% 0.0006858181epoch 0 18.47% 0.00096230523epoch 0 18.57% 0.006143187

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 21.93% 0.003088459879

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 24.24% 0.001858393859

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 24.34% 0.0028254683epoch 0 24.45% 0.00025767146epoch 0 24.55% 0.0018945585epoch 0 24.66% 0.009097236epoch 0 24.76% 0.0038019507epoch 0 24.87% 0.007903037epoch 0 24.97% -3.1622776e-09epoch 0 25.08% 0.0024352176epoch 0 25.18% 0.0017438522epoch 0 25.29% 0.002868364epoch 0 25.39% 0.00075482274epoch 0 25.50% 0.00030977637epoch 0 25.60% 0.0002740437

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 25.71% -3.1622776e-09epoch 0 25.81% -3.1622776e-09epoch 0 25.92% -3.1622776e-09epoch 0 26.02% 0.0033034747epoch 0 26.13% 0.0001305126epoch 0 26.23% 0.007344491epoch 0 26.34% 0.006622865epoch 0 26.44% 0.0007318565epoch 0 26.55% 0.006056784epoch 0 26.65% 0.0009135384epoch 0 26.76% 0.0006879401epoch 0 26.86% 0.006352129epoch 0 26.97% 0.003747736epoch 0 27.07% 0.0018585352epoch 0 27.18% 0.00030010176

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 27.28% 0.00014817672epoch 0 27.39% 0.0028833249epoch 0 27.49% 0.0014763328epoch 0 27.60% 0.0022900258epoch 0 27.70% 0.0018417381epoch 0 27.81% 0.0010929675epoch 0 27.91% 0.0036643317epoch 0 28.02% -3.1622776e-09epoch 0 28.12% 0.0021401579epoch 0 28.23% -3.1622776e-09epoch 0 28.33% 0.0046823435epoch 0 28.44% -3.1622776e-09epoch 0 28.54% 0.0018657987epoch 0 28.65% 9.85078e-05epoch 0 28.75% 0.0006174272epoch 0 28.86% 0.0023694762

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 35.05% 0.001964189359

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 38.09% 0.005838713177

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 44.18% 0.000889902446

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 44.28% 0.0018033917epoch 0 44.39% -1.1030456e-06epoch 0 44.49% 0.0002222654epoch 0 44.60% 0.00021747801epoch 0 44.70% 0.0009104105epoch 0 44.81% 0.0017208899epoch 0 44.91% 0.000648185epoch 0 45.02% -1.101469e-06epoch 0 45.12% 0.0012527276epoch 0 45.23% 0.0015180184epoch 0 45.33% -1.1013333e-06epoch 0 45.44% 0.0011546082epoch 0 45.54% -1.1013256e-06epoch 0 45.65% -1.1013847e-06epoch 0 45.75% 0.0012679126epoch 0 45.86% -1.1016201e-06

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 49.00% 0.000136833176

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 52.47% 0.004899784836

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 61.59% -1.8426231e-06

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 67.79% 0.000154406666

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 72.30% 0.001382788506

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 82.79% 3.257003e-0556

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 85.41% 0.000547564606

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 89.40% 0.003921978676

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 89.51% 0.0020178678epoch 0 89.61% -1.5907999e-06epoch 0 89.72% 0.00025309564epoch 0 89.82% 0.0025684347epoch 0 89.93% 0.0004264762epoch 0 90.03% -1.5856754e-06epoch 0 90.14% 0.0006410343epoch 0 90.24% 0.00044036497epoch 0 90.35% 0.0031858848epoch 0 90.45% 0.003083799epoch 0 90.56% 0.00035074138epoch 0 90.66% 0.00011732842epoch 0 90.77% 0.0005510752

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


epoch 0 total loss 3.2029768098186224 0.0033609410386344412
epoch 1 total loss 0.8445715113322194 0.0008862240412720036
epoch 2 total loss 0.6542167932067287 0.0006864814199441014
epoch 3 total loss 0.5834956527050963 0.0006122724582424935
epoch 4 total loss 0.49960529303098156 0.0005242447985634644
aida-A [92mmicro F1: 0.8364471349545977[0m
aida-B [92mmicro F1: 0.8234113712374583[0m
msnbc [92mmicro F1: 0.9349655700076511[0m
aquaint [92mmicro F1: 0.8881118881118881[0m
ace2004 [92mmicro F1: 0.8853118712273641[0m
clueweb [92mmicro F1: 0.7614983830398849[0m
wikipedia [92mmicro F1: 0.7357443976037276[0m
att_mat_diag tensor(17.4242, device='cuda:0')
tok_score_mat_diag tensor(17.8641, device='cuda:0')
f - l1.w, b tensor(5.3760, device='cuda:0') tensor(4.0614, device='cuda:0')
f - l2.w, b tensor(0.5742, device='cuda:0') tensor(0.0088, device='cuda:0')
tensor(13.1788, device='cuda:0') tensor(1.0102, device='cuda:0')
relations tensor([18.8949,  1.6499,  1.6459,  1.6533,  1.6140], 

epoch 25 total loss 0.2687396475818673 0.0002819933342936698
epoch 26 total loss 0.26674645857889345 0.0002799018453083877
epoch 27 total loss 0.2626942512661685 0.0002756497914650247
epoch 28 total loss 0.2787444858000754 0.0002924915905562176
epoch 29 total loss 0.2471340813709162 0.0002593222259925668
aida-A [92mmicro F1: 0.8913474585116377[0m
aida-B [92mmicro F1: 0.8925306577480491[0m
msnbc [92mmicro F1: 0.9349655700076511[0m
aquaint [92mmicro F1: 0.9034965034965036[0m
ace2004 [92mmicro F1: 0.9014084507042254[0m
clueweb [92mmicro F1: 0.775601868487244[0m
wikipedia [92mmicro F1: 0.7645884180164189[0m
att_mat_diag tensor(17.4062, device='cuda:0')
tok_score_mat_diag tensor(17.6131, device='cuda:0')
f - l1.w, b tensor(5.3760, device='cuda:0') tensor(4.0614, device='cuda:0')
f - l2.w, b tensor(0.5742, device='cuda:0') tensor(0.0088, device='cuda:0')
tensor(24.8261, device='cuda:0') tensor(1.3374, device='cuda:0')
relations tensor([20.2224,  3.3582,  3.3225,  3.3127,  3.289

epoch 50 total loss 0.1687417117666996 0.00017706370594616959
epoch 51 total loss 0.1641886795111418 0.0001722861275038214
epoch 52 total loss 0.1604826781656925 0.00016839735379401102
epoch 53 total loss 0.16437056556389962 0.00017247698380262288
epoch 54 total loss 0.16477095293407729 0.00017289711745443578
aida-A [92mmicro F1: 0.9180670076192464[0m
aida-B [92mmicro F1: 0.9297658862876255[0m
msnbc [92mmicro F1: 0.9410864575363428[0m
aquaint [92mmicro F1: 0.8811188811188811[0m
ace2004 [92mmicro F1: 0.8611670020120724[0m
clueweb [92mmicro F1: 0.7754222062522458[0m
wikipedia [92mmicro F1: 0.7850011093854005[0m
change learning rate to 1e-05
save model to 


RuntimeError: [enforce fail at C:\cb\pytorch_1000000000000\work\caffe2\serialize\inline_container.cc:365] . invalid file name: .state_dict