# Storyline

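This is a step-by-step walkthrough of how the model works. Two example mentions are embedded with ELMo and passed through a sentence (context) encoder and a mention encoder. The combined representation is then decoded into category scores. The last section repeats the same pipeline with the packaged `et4el` modules.

Apart from the pretrained ELMo weights, every network here is freshly initialised and no trained weights are loaded. The predicted categories at the end are therefore effectively random; the point is to follow the tensor shapes through each stage.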

## Setup

In [1]:
from collections import defaultdict
from os import path

from allennlp.commands.elmo import ElmoEmbedder
import torch
import torch.nn as nn

from et4el.models import ELMoWeightedSum, BiLSTM, CNN, SelfAttentiveSum
from et4el.utils import Mention, Mentions, MentionHandler

In [2]:
# Every network built below is randomly initialised and several apply dropout,
# so seed PyTorch's RNGs up front to make the printed values reproducible on a
# fresh kernel (torch.manual_seed seeds all devices).
SEED = 42
torch.manual_seed(SEED)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Pretrained ELMo (2x4096_512_2048cnn_2xhighway) options and weights hosted by AllenNLP.
ELMO_OPTIONS_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
ELMO_WEIGHT_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

# The embedder takes an integer device id: -1 selects the CPU, otherwise a CUDA index.
if device.type == "cpu":
    cuda_device = -1
else:
    # torch.device("cuda") carries no explicit index (None), so default to GPU 0.
    cuda_device = device.index if device.index else 0
embedder = ElmoEmbedder(ELMO_OPTIONS_FILE, ELMO_WEIGHT_FILE, cuda_device=cuda_device)

In [4]:
# Two example mentions, each given as: the mention text, the text to its left,
# and the text to its right.
mention_a = Mention(
    "Ant",
    "Tired of dealing with a growing jumble of build difficulties, developer James Davidson created",
    ", a build tool for Java projects.",
)
mention_b = Mention(
    "Washington",
    "In the northwestern US state of",
    ", there are typically two harvests: one from late April to May and another from late June into July",
)
# Batch and trim them; 10 and 50 mirror max_mention_length / max_context_length
# defined in the hyperparameter cell below.
mentions = MentionHandler(10, 50).prepare_mentions(Mentions([mention_a, mention_b]))
bsz = len(mentions)
mentions

Mentions(Ant, Washington)

In [5]:
# Hyperparameters: dropout rates, layer widths, label-space size,
# length limits and the decision threshold used when decoding.
mention_dropout_rate = 0.5
input_dropout_rate = 0.5

rnn_dim = 50
cnn_dim = 50
mask_dim = 50
attention_dim = 100

answer_num = 60000
max_mention_length = 10
max_context_length = 50
threshold = 0.58

# Width of one ELMo vector (last axis of the embedder output).
embeddings_dim = 1024
# Final representation width: sentence rep (2 * rnn_dim) + mention rep (embeddings_dim + cnn_dim).
output_dim = embeddings_dim + cnn_dim + 2 * rnn_dim
# Per-token input to the sentence BiLSTM: ELMo vector + location mask.
combined_dim = mask_dim + embeddings_dim

## Embedder

In [6]:
# Run ELMo over each prepared mention's tokens; the result is one array per mention.
embs = embedder.embed_batch(mentions.tokens)
# Shape of each array: (3, tokens in that sentence, 1024).
tuple(emb_array.shape for emb_array in embs[:2])

((3, 22, 1024), (3, 26, 1024))

In [7]:
# Sentence Embeddings
# Stack the per-mention arrays into one zero-padded batch tensor of shape
# (bsz, 3, longest sentence, embeddings_dim); shorter sentences keep trailing zeros.
max_tokens_length = max(mentions.tokens_lengths)
sentence_embeddings = torch.zeros(
    [bsz, 3, max_tokens_length, embeddings_dim], device=device
)
for row, emb_array in enumerate(embs):
    _, n_tokens, _ = emb_array.shape
    sentence_embeddings[row, :, :n_tokens, :] = torch.from_numpy(emb_array)
sentence_embeddings.shape

torch.Size([2, 3, 26, 1024])

In [8]:
# Mention Embeddings
# Pad to the longest mention span as given by `mention.borders`. These are the
# same indices used to slice `emb` below, so every slice is guaranteed to fit.
# Counting spaces in the raw mention string can disagree with that span
# whenever tokenization is not plain single-space splitting. A span longer
# than the count would make the slice assignment fail with a shape mismatch.
mention_spans = [mention.borders for mention in mentions]
max_mention_tokens_length = max(end_ind - start_ind for start_ind, end_ind in mention_spans)
mention_embeddings = torch.zeros([bsz, 3, max_mention_tokens_length, embeddings_dim], device=device)
for i, (emb, (start_ind, end_ind)) in enumerate(zip(embs, mention_spans)):
    mention_length = end_ind - start_ind
    # Copy the mention's token slice for all 3 entries of the leading axis;
    # positions past mention_length stay zero padding.
    mention_embeddings[i, :, :mention_length, :] = torch.from_numpy(emb[:, start_ind:end_ind, :])
mention_embeddings.shape

torch.Size([2, 3, 1, 1024])

## Sentence encoder

In [9]:
# Sentence-encoder modules. Construction order is kept as-is so each module
# draws the same initial weights from the RNG.
weighted_sum = ELMoWeightedSum().to(device)  # collapses the size-3 axis of the ELMo tensor
location_LNN = nn.Linear(in_features=4, out_features=mask_dim).to(device)  # 4 position flags -> mask_dim
input_dropout = nn.Dropout(p=input_dropout_rate).to(device)
bi_lstm = BiLSTM(combined_dim, rnn_dim).to(device)  # input: ELMo vector + location mask
attentive_sum = SelfAttentiveSum(rnn_dim * 2, attention_dim).to(device)  # one vector per example

In [10]:
# Collapse the size-3 axis: (bsz, 3, seq, 1024) -> (bsz, seq, 1024).
weighted_embeddings = weighted_sum(sentence_embeddings)
weighted_embeddings.size()

torch.Size([2, 26, 1024])

In [11]:
# One-hot position features per token, 4 columns:
#   0 = before the mention, 1 = first mention token,
#   2 = remaining mention tokens, 3 = after the mention (up to the real sentence end).
# Padding positions beyond tokens_length stay all-zero.
bsz = len(mentions)
max_seq_length = max(mentions.tokens_lengths)
location_tokens = torch.zeros([bsz, max_seq_length, 4], device=device)
for row, item in enumerate(mentions):
    left, right = item.borders
    location_tokens[row, :left, 0] = 1.0
    location_tokens[row, left, 1] = 1.0
    location_tokens[row, left + 1:right, 2] = 1.0
    location_tokens[row, right:item.tokens_length, 3] = 1.0
# Flatten to (bsz * max_seq_length, 4): one row per token for the linear layer.
location_tokens = location_tokens.reshape(-1, 4)
location_tokens.shape

torch.Size([52, 4])

In [12]:
# Project each token's 4 position flags to mask_dim, then restore (batch, seq, mask_dim).
location_mask = location_LNN(location_tokens).view(weighted_embeddings.size(0), -1, mask_dim)
location_mask.shape

torch.Size([2, 26, 50])

In [13]:
# Append the location features to every token's ELMo vector, then apply input dropout.
# NOTE: this rebinds `weighted_embeddings` from itself, so running the cell twice
# would append the mask a second time; re-run `weighted_sum(sentence_embeddings)` first.
weighted_embeddings = input_dropout(torch.cat([weighted_embeddings, location_mask], dim=2))
weighted_embeddings.shape

torch.Size([2, 26, 1074])

In [14]:
# Encode the context with the BiLSTM, passing each sentence's true token count
# (presumably so padding is ignored), then pool to one vector per example.
sequence_lengths = torch.tensor(mentions.tokens_lengths, device=device)
sequence_rep = attentive_sum(bi_lstm(weighted_embeddings, sequence_lengths))
sequence_rep.shape

torch.Size([2, 100])

## Mention encoder

In [15]:
# Mention-encoder modules. NOTE: these rebind the names used by the sentence
# encoder above, whose modules are no longer reachable afterwards.
# Construction order is kept so the RNG draws are unchanged.
weighted_sum = ELMoWeightedSum().to(device)
input_dropout = nn.Dropout(mention_dropout_rate).to(device)
bi_lstm = BiLSTM(embeddings_dim, embeddings_dim // 2).to(device)
attentive_sum = SelfAttentiveSum(embeddings_dim, attention_dim).to(device)
cnn = CNN(cnn_dim).to(device)

# Character vocabulary: id 0 is "<unk>", then one entry per line of the file.
# Because char_dict is a defaultdict(int), unknown characters map to 0.
PATH_TO_CHARDICT = path.normpath("../et4el/ontology/char_vocab.english.txt")
char_vocab = [u"<unk>"]
with open(PATH_TO_CHARDICT, encoding="utf-8") as f:
    char_vocab.extend(line.strip() for line in f)
char_dict = defaultdict(int)
char_dict.update((char, idx) for idx, char in enumerate(char_vocab))

In [16]:
# Collapse the size-3 axis of the mention tensor and apply mention dropout.
weighted_embeddings = input_dropout(weighted_sum(mention_embeddings))
weighted_embeddings.shape

torch.Size([2, 1, 1024])

In [17]:
# Token length of each mention, taken from the same `borders` span used to fill
# `mention_embeddings`. The lengths passed to the BiLSTM then match the data in
# each row; a space count of the raw mention string can disagree with that span.
mention_lengths = torch.tensor(
    [mention.borders[1] - mention.borders[0] for mention in mentions], device=device
)
mention_word = bi_lstm(weighted_embeddings, mention_lengths)
# Pool the per-token states to one vector per mention.
mention_word = attentive_sum(mention_word)
mention_word.shape

torch.Size([2, 1024])

In [18]:
def pad_slice(seq, seq_length, pad_token="<none>"):
    """Return a new list holding ``seq`` padded or truncated to ``seq_length`` items.

    Args:
        seq: List to fit; it is not modified.
        seq_length: Exact length of the returned list (must be >= 0).
        pad_token: Filler appended when ``seq`` is shorter than ``seq_length``.

    Returns:
        ``seq`` followed by ``pad_token`` fillers, or the first ``seq_length``
        items of ``seq`` when it is longer.

    Raises:
        ValueError: If ``seq_length`` is negative.
    """
    if seq_length < 0:
        raise ValueError(f"seq_length must be non-negative, got {seq_length}")
    if len(seq) >= seq_length:
        # Slice as well as pad, so the result length is always seq_length.
        return seq[:seq_length]
    return seq + [pad_token] * (seq_length - len(seq))

In [19]:
# Map each mention's characters to vocabulary ids. `.get(ch, 0)` sends unknown
# characters to id 0 ("<unk>") without inserting them into `char_dict`; indexing
# the defaultdict directly would silently grow the vocabulary on every lookup.
mentions_characters = [[char_dict.get(ch, 0) for ch in mention.mention] for mention in mentions]
# Pad every row to the longest mention, but to at least 5 characters.
# NOTE(review): the floor of 5 presumably guarantees the character CNN's
# window fits; confirm against CNN.
max_span_chars = max(max(len(characters) for characters in mentions_characters), 5)
mentions_characters = [
    pad_slice(characters, max_span_chars, pad_token=0) for characters in mentions_characters
]
mention_chars = torch.tensor(mentions_characters, dtype=torch.int64, device=device)
mention_chars = cnn(mention_chars)
mention_chars.shape

torch.Size([2, 50])

In [20]:
# Mention representation: pooled ELMo word vector followed by the character-CNN features.
mention_rep = torch.cat([mention_word, mention_chars], dim=1)
mention_rep.shape

torch.Size([2, 1074])

## Decoder

In [21]:
# Bias-free projection from the combined representation to one logit per category.
linear = nn.Linear(in_features=output_dim, out_features=answer_num, bias=False).to(device)

In [22]:
# Join the context (sentence) and mention views, then score every category.
representation = torch.cat([sequence_rep, mention_rep], dim=1)
logits = linear(representation)
logits.shape

torch.Size([2, 60000])

In [23]:
# Category vocabulary: line i of the file names the category scored by logit i.
PATH_TO_VOCAB = path.normpath("../et4el/ontology/conll_categories.txt")
with open(PATH_TO_VOCAB, encoding="utf-8") as f:
    categories = [line.strip() for line in f][:answer_num]
# Build id2answer from positions directly. Every logit index then keeps a name
# even if a category string repeats; inverting answer2id would drop the earlier id.
id2answer = dict(enumerate(categories))
# For repeated strings the last id wins.
answer2id = {category: idx for idx, category in enumerate(categories)}
# The decoder emits `answer_num` logits. Fail here with a clear message rather
# than later with a KeyError while decoding.
assert len(id2answer) == answer_num, (
    f"{PATH_TO_VOCAB} has {len(id2answer)} categories, but the decoder emits {answer_num} logits"
)

In [24]:
# Decode predictions to categories. Keep every label whose sigmoid score exceeds
# `threshold`; if none does, fall back to the single highest-scoring label.
outputs = torch.sigmoid(logits)
predictions = []
for scores in outputs:
    label_ids = torch.nonzero(scores > threshold, as_tuple=True)[0]
    if label_ids.numel() == 0:
        label_ids = torch.argmax(scores, dim=0, keepdim=True)
    predictions.append(
        [(id2answer[label_id], scores[label_id].item()) for label_id in label_ids.tolist()]
    )
predictions

[[('hamlets in devon', 0.5857832431793213)],
 [('sport in birmingham, west midlands', 0.5891037583351135),
  ('former new york central railroad stations', 0.5822668671607971),
  ('freedesktop.org', 0.5833342671394348),
  ('by universal television', 0.5883475542068481),
  ('populated coastal places in new zealand', 0.5829136967658997),
  ('of the amazon', 0.5838868618011475),
  ('economy of the midwestern united states', 0.6019372940063477),
  ('1st-century establishments', 0.5869855880737305),
  ('towns in boone county, arkansas', 0.5844034552574158),
  ('bayfield-class attack transports', 0.5825484991073608),
  ('historiography of england', 0.5862821936607361),
  ('economic nationalism', 0.583835780620575),
  ('buildings and structures in sydney', 0.5802591443061829),
  ('cities in wayne county, michigan', 0.5895113348960876),
  ('19th-century in mexico', 0.587650716304779),
  ('czech-american culture in kansas', 0.5931905508041382),
  ('sc freiburg', 0.6039355397224426),
  ('in pales

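## With our own modules

The same pipeline again, now using the packaged `et4el` embedder, encoders and decoder instead of the hand-built steps above.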

In [25]:
import torch
from et4el.encoder import MentionEncoder, SentenceEncoder, SimpleDecoder
from et4el.embedder import ELMoPretrainedEmbedder
from et4el.utils import Mentions, Mention, MentionHandler

In [26]:
# Packaged et4el components, presumably wrapping the steps built by hand above.
# NOTE: `embedder` is rebound here, so the raw ElmoEmbedder from the setup cell
# is no longer reachable by that name. Construction order is unchanged.
pre_handler = MentionHandler(max_mention_length, max_context_length)
embedder = ELMoPretrainedEmbedder(device).to(device)
sentence_encoder = SentenceEncoder(
    input_dropout_rate, rnn_dim, embeddings_dim, mask_dim, attention_dim
).to(device)
mention_encoder = MentionEncoder(
    mention_dropout_rate, cnn_dim, embeddings_dim, attention_dim
).to(device)
decoder = SimpleDecoder(output_dim, answer_num).to(device)

In [27]:
# Run the same pipeline end to end with the packaged modules. The prepared
# batch gets a new name instead of rebinding `mentions`, so re-running this
# cell never re-prepares an already-prepared batch.
pipeline_mentions = pre_handler.prepare_mentions(mentions)
embeddings = embedder.embed(pipeline_mentions)
sentence_rep = sentence_encoder(pipeline_mentions, embeddings)
mention_rep = mention_encoder(pipeline_mentions, embeddings)

# Context view first, then mention view: the same order as the hand-built decoder input.
combined_rep = torch.cat((sentence_rep, mention_rep), 1)
logits = decoder(combined_rep)
logits.shape

torch.Size([2, 60000])