In [1]:
from torchtext.vocab import Vocab
from collections import Counter
import json

with open("conll04_test.json") as f:
    data = json.load(f)
    
tokens = []
for datum in data:
    tokens += datum['tokens']

vocab = Vocab(Counter(tokens), vectors="glove.6B.100d")

In [2]:
ENTITY_TO_ID = {"O":0,"Loc":1, "Org":2, "Peop":3, "Other":4}
REL_TO_ID = {"*":0, "Work_For_arg1":1, "Kill_arg1":2, "OrgBased_In_arg1":3, "Live_In_arg1":4,
             "Located_In_arg1":5, "Work_For_arg2":6, "Kill_arg2":7, "OrgBased_In_arg2":8,
             "Live_In_arg2":9, "Located_In_arg2":10}

class DataLoader(object):
    def __init__(self, filename, vocab, batch_size=1):
        self.batch_size = batch_size
        
        with open(filename) as f:
            data = json.load(f)
            
        self.data = self.preprocess(data, vocab)
        self.data = [self.data[i: i + batch_size] for i in range(0, len(self.data), batch_size)]
        
    def preprocess(self, data, vocab):
        
        processed = []
        for i, d in enumerate(data):
#             print(i)
            if i == 10:
                break
            tokens = to_ids([t.lower() for t in d['tokens']], vocab.stoi)

            entities = ['O'] * len(tokens)
            for e in d['entities']:
                entities[e['end']] = e['type']
            entities = to_ids(entities, ENTITY_TO_ID)

            relations = []
            for r in d['relations']:
                curr = ['*'] * len(tokens)
                curr[r['head']] = r['type'] + "_arg1"
                curr[r['tail']] = r['type'] + "_arg2"

                relations += [to_ids(curr, REL_TO_ID)]
                
            relations = relations[0]

            processed += [(tokens, entities, relations)]
                
        return processed
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, key):
        if not isinstance(key, int):
            raise TypeError
        
        if key < 0 or key >= self.__len__():
            raise IndexError
            
        return list(zip(*self.data[key]))
    
    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)
            
def to_ids(tokens, vocab):
    """ Maps a list of tokens to the corresponding ids given by the dictionary, vocab"""

    ids = [vocab[t] if t in vocab else 0 for t in tokens]
    return ids

In [3]:
import torch

class NER_Net(torch.nn.Module):
    '''Simple Named Entity Recognition model'''

    def __init__(self, vocab_size, num_classes, hidden_dim=50, embedding_dim=100):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        # layers
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embedding_dim)
        self.embedding.weight = torch.nn.Parameter(vocab.vectors)
        
        self.lstm = torch.nn.LSTM(self.embedding_dim, self.hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(self.hidden_dim, num_classes)

        # Initialize fully connected layer
        self.fc.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=1)

    def forward(self, s):
        s = self.embedding(s)   # dim: batch_size x batch_max_len x embedding_dim
        s, _ = self.lstm(s)     # dim: batch_size x batch_max_len x lstm_hidden_dim
        s = self.fc(s)          # dim: batch_size*batch_max_len x num_tags

        return s
    
    
class RE_Net(torch.nn.Module):
    '''Simple Relation extraction model'''

    def __init__(self, vocab_size, num_classes, hidden_dim=50, embedding_dim=100):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        # layers
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embedding_dim)
        self.embedding.weight = torch.nn.Parameter(vocab.vectors)
        
        self.lstm = torch.nn.LSTM(self.embedding_dim, self.hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(self.hidden_dim, num_classes)

        # Initialize fully connected layer
        self.fc.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=1)

    def forward(self, s):
        s = self.embedding(s)   # dim: batch_size x batch_max_len x embedding_dim
        s, _ = self.lstm(s)     # dim: batch_size x batch_max_len x lstm_hidden_dim
        s = self.fc(s)          # dim: batch_size*batch_max_len x num_tags

        return s

In [4]:
train = DataLoader("conll04_test.json", vocab)

In [5]:
datum = train[0]
tokens = torch.tensor(datum[0])
entities = torch.tensor(datum[1])
relations = torch.tensor(datum[2])

In [6]:
import torch
import itertools
import torch.nn.functional as F

# from pytorch_constraints.constraint import constraint
# from pytorch_constraints.brute_force_solver import *
# from pytorch_constraints.sampling_solver import *
# from pytorch_constraints.tnorm_solver import TNormLogicSolver

# Models
ner = NER_Net(vocab_size=len(vocab), num_classes=len(ENTITY_TO_ID))
re = RE_Net(vocab_size=len(vocab), num_classes=len(REL_TO_ID))

# Optimization
opt = torch.optim.SGD(list(ner.parameters()) + list(re.parameters()), lr=1.0)

# # Plotting
# plot = PlotHelper()
# plot_loss = PlotHelper()

num_samples = 100
for i in range(500):
    opt.zero_grad()
    
    ner_logits = ner(tokens)
    ner_logits = ner_logits.view(-1, ner_logits.shape[2])
    
    re_logits = re(tokens)
    re_logits = re_logits.view(-1, re_logits.shape[2])
    
    
    ner_loss = F.cross_entropy(ner_logits, entities.view(-1))
    re_loss = F.cross_entropy(re_logits, relations.view(-1))
    loss = ner_loss + re_loss
    
    loss.backward()
    opt.step()
    
    print(loss)
    
#     y_prob = torch.softmax(y_logit, dim=-1)
#     plot.add(y0=y_prob[0,1].data, y1=y_prob[1,1].data, y2=y_prob[2,1].data)
#     plot_loss.add(oloss=oloss.data, closs=closs.data, loss=loss.data)
#     opt.step()

# plot.show()
# plot_loss.show()

tensor(4.0911, grad_fn=<AddBackward0>)
tensor(1.3064, grad_fn=<AddBackward0>)
tensor(1.0823, grad_fn=<AddBackward0>)
tensor(1.0197, grad_fn=<AddBackward0>)
tensor(0.9641, grad_fn=<AddBackward0>)
tensor(0.9125, grad_fn=<AddBackward0>)
tensor(0.8636, grad_fn=<AddBackward0>)
tensor(0.8164, grad_fn=<AddBackward0>)
tensor(0.7703, grad_fn=<AddBackward0>)
tensor(0.7250, grad_fn=<AddBackward0>)
tensor(0.6802, grad_fn=<AddBackward0>)
tensor(0.6363, grad_fn=<AddBackward0>)
tensor(0.5933, grad_fn=<AddBackward0>)
tensor(0.5518, grad_fn=<AddBackward0>)
tensor(0.5122, grad_fn=<AddBackward0>)
tensor(0.4746, grad_fn=<AddBackward0>)
tensor(0.4393, grad_fn=<AddBackward0>)
tensor(0.4062, grad_fn=<AddBackward0>)
tensor(0.3752, grad_fn=<AddBackward0>)
tensor(0.3463, grad_fn=<AddBackward0>)
tensor(0.3194, grad_fn=<AddBackward0>)
tensor(0.2943, grad_fn=<AddBackward0>)
tensor(0.2710, grad_fn=<AddBackward0>)
tensor(0.2492, grad_fn=<AddBackward0>)
tensor(0.2289, grad_fn=<AddBackward0>)
tensor(0.2098, grad_fn=<A

tensor(0.0034, grad_fn=<AddBackward0>)
tensor(0.0034, grad_fn=<AddBackward0>)
tensor(0.0033, grad_fn=<AddBackward0>)
tensor(0.0033, grad_fn=<AddBackward0>)
tensor(0.0033, grad_fn=<AddBackward0>)
tensor(0.0033, grad_fn=<AddBackward0>)
tensor(0.0033, grad_fn=<AddBackward0>)
tensor(0.0032, grad_fn=<AddBackward0>)
tensor(0.0032, grad_fn=<AddBackward0>)
tensor(0.0032, grad_fn=<AddBackward0>)
tensor(0.0032, grad_fn=<AddBackward0>)
tensor(0.0032, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor(0.0029, grad_fn=<AddBackward0>)
tensor(0.0029, grad_fn=<A

tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0014, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<AddBackward0>)
tensor(0.0013, grad_fn=<A

In [7]:
torch.argmax(torch.softmax(re(tokens).view(-1, 11), dim=-1), dim=-1)

tensor([3, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])