In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

# Notebook-wide plotting defaults: seaborn's "whitegrid" style and a
# 20 x 15 inch default figure size for every matplotlib figure.
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (20, 15)

import html
import io
import os
import pickle
import re
from collections import Counter

import nltk
import numpy as np
import pandas as pd
from bs4 import BeautifulSoup
from IPython.core.display import display, HTML
from nltk import word_tokenize, sent_tokenize
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook as tqdm

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

# Download nltk's "punkt" resource (presumably what word_tokenize and
# sent_tokenize rely on - confirm against the installed nltk version).
nltk.download("punkt")
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data

In [None]:
def remove_title(article):
    """
    Drop the leading block of an article (presumably its title).

    Everything up to and including the first blank-line separator
    ("\\n\\n") is discarded. If the article has no such separator the
    result is an empty string.
    """
    _head, _separator, body = article.partition("\n\n")
    return body

In [None]:
# Read a random sample of two sub-directories from the extracted dump and
# pull out the body of every <doc ...>...</doc> element found in them.
base_path = "/mnt/efs/wikipedia/dumps/text/"

# replace=False: the default samples with replacement, which can pick the same
# directory twice and silently duplicate every article in it.
paths = np.random.choice(os.listdir(base_path), size=2, replace=False)

# Collect the decoded files in a list and join once; repeated `+=` on a large
# string copies the accumulated text on every iteration.
raw_chunks = []
for path in paths:
    directory = os.path.join(base_path, path)
    for filename in tqdm(os.listdir(directory)):
        with open(os.path.join(directory, filename), "rb") as f:
            # NOTE(review): latin1 never fails to decode, but if the dump is
            # actually UTF-8 every non-ASCII character becomes mojibake - confirm.
            raw_chunks.append(f.read().decode("latin1"))
all_text = "".join(raw_chunks)

# `<doc.+>` matches the opening tag on a single line; `[\s\S]*?` lazily matches
# any text (newlines included) up to the closing tag. With a single capture
# group, re.findall returns plain strings.
pattern = r"<doc.+>([\s\S]*?)</doc>"
articles = [remove_title(article) for article in re.findall(pattern, all_text)]

In [None]:
# Sanity check: number of articles extracted from the sampled files.
len(articles)

### cleaning pipeline

In [None]:
def tokenize(sentence):
    """
    Split a sentence into a list of tokens.

    The text is tokenised with nltk's `word_tokenize`, then every
    space-delimited " n't " piece is rewritten as "n 't " (so
    "do n't" becomes "don 't") before the result is split on whitespace.
    """
    joined = " ".join(word_tokenize(sentence))
    return joined.replace(" n't ", "n 't ").split()


def kmp(sequence, sub):
    """
    Knuth–Morris–Pratt search for a subsequence inside a larger sequence.

    Parameters
    ----------
    sequence : indexable sequence (e.g. list of tokens or a string)
        The sequence to search in.
    sub : indexable sequence
        The pattern to look for.

    Returns
    -------
    list of int
        Start positions of the matches, in increasing order. Matching
        restarts from scratch after each hit, so the reported matches never
        overlap. An empty `sub` yields no matches.
    """
    # An empty pattern would otherwise raise IndexError on `sub[j]` below.
    if len(sub) == 0:
        return []

    # partial[i] = length of the longest proper prefix of sub[: i + 1] that is
    # also a suffix of it (the classic KMP failure table).
    partial = [0]
    for i in range(1, len(sub)):
        j = partial[i - 1]
        while j > 0 and sub[j] != sub[i]:
            j = partial[j - 1]
        partial.append(j + 1 if sub[j] == sub[i] else j)

    positions, j = [], 0
    for i in range(len(sequence)):
        # on a mismatch fall back to the longest prefix that still matches
        while j > 0 and sequence[i] != sub[j]:
            j = partial[j - 1]
        if sequence[i] == sub[j]:
            j += 1
        if j == len(sub):
            positions.append(i - (j - 1))
            j = 0  # restart: matches are reported without overlap

    return positions


def label(tokenised_sequences, link_tokens):
    """
    Build one 0/1 target array per tokenised sequence.

    A position is set to 1 when it lies inside an occurrence (as found by
    `kmp`) of any of the token lists in `link_tokens`; every other position
    stays 0. Returns a list of numpy arrays, each as long as its sequence.
    """

    def mark_links(sequence):
        mask = np.zeros(len(sequence))
        for link in link_tokens:
            span = len(link)
            for start in kmp(sequence, link):
                mask[start : start + span] = 1
        return mask

    return [mark_links(sequence) for sequence in tokenised_sequences]


def label_linkable_tokens(text, label_all=True):
    """
    Turn an article containing <a> tags into token and target sequences.

    The text of every <a> element is tokenised to give the link phrases; the
    article's plain text is split into sentences and each sentence is
    tokenised; `label` then marks the tokens covered by a link phrase.

    `label_all` is accepted but not used by this function.

    Returns a `(tokenised_sequences, target_sequences)` pair.
    """
    soup = BeautifulSoup(text, "html.parser")

    anchors = soup.find_all("a")
    link_tokens = [tokenize(anchor.text) for anchor in anchors]

    sentences = sent_tokenize(soup.text)
    tokenised_sequences = list(map(tokenize, sentences))

    return tokenised_sequences, label(tokenised_sequences, link_tokens)

In [None]:
# Tokenise and label every article. Labelling stays best-effort (one bad
# article must not abort the run) but failures are counted and reported
# instead of being swallowed silently.
token_sequences, target_sequences = [], []
failed_articles = 0

for article in tqdm(articles):
    try:
        tokenised_seqs, target_seqs = label_linkable_tokens(article)
    except Exception:  # not a bare except: let KeyboardInterrupt through
        failed_articles += 1
        continue
    token_sequences.extend(tokenised_seqs)
    target_sequences.extend(target_seqs)

if failed_articles:
    print(
        "skipped {} of {} articles that could not be labelled".format(
            failed_articles, len(articles)
        )
    )

# character level inputs

In [None]:
# Every distinct character seen in any token, plus the space that joins tokens.
flat_tokens = (token for seq in token_sequences for token in seq)
unique_characters = set(" ".join(flat_tokens))

In [None]:
# Marker symbols (unknown, padding, begin/end of sequence) get their own
# entries in the character vocabulary.
special_cases = ["xxunk", "xxpad", "xxbos", "xxeos"]

unique_characters.update(special_cases)

In [None]:
# Two-way lookup between characters and integer indices. Both maps come from
# a single enumeration of the set, so they are exact inverses of each other.
ix_to_char = dict(enumerate(unique_characters))
char_to_ix = {char: ix for ix, char in ix_to_char.items()}

# fasttext and a word vector embedding matrix 

In [None]:
# Load the word vectors into a {token: float32 vector} dict.
wv_path = "/mnt/efs/text/word_vectors/wiki-news-300d-1M.vec"

fasttext = {}
# Stream the file inside a context manager: the handle is closed afterwards
# and the lines are never all held in memory at once.
with io.open(wv_path, "r", encoding="utf-8", newline="\n", errors="ignore") as wv_file:
    # The first line is skipped rather than parsed as a vector; if it starts
    # with an integer, that number is used as the progress-bar total.
    header_fields = next(wv_file).split()
    n_vectors = (
        int(header_fields[0])
        if header_fields and header_fields[0].isdigit()
        else None
    )

    for line in tqdm(wv_file, total=n_vectors):
        fields = line.split()  # split once: token followed by its components
        fasttext[fields[0]] = np.array(fields[1:]).astype(np.float32)

In [None]:
# Upper bound on the vocabulary size, most frequent tokens first.
# None keeps every distinct token.
MAX_VOCAB_SIZE = None

all_tokens = [tok for seq in token_sequences for tok in seq]

# Counter is imported with the other imports at the top of the notebook.
token_counts = Counter(all_tokens)
article_vocabulary = {
    token for token, _count in token_counts.most_common(MAX_VOCAB_SIZE)
}

In [None]:
# Register the marker tokens in the vocabulary and give each one a random
# 300-component vector (uniform in [0, 1)). Re-running this cell draws new
# vectors.
article_vocabulary.update(special_cases)

for case in special_cases:
    fasttext[case] = np.random.random(300)

In [None]:
# Sanity check: vocabulary size, including the special-case tokens.
len(article_vocabulary)

In [None]:
# Freeze the vocabulary into a list so that every token has a stable position,
# then build the two-way token <-> index lookup from that list.
article_vocabulary_list = list(article_vocabulary)
ix_to_token = dict(enumerate(article_vocabulary_list))
token_to_ix = {token: index for index, token in ix_to_token.items()}

In [None]:
# Row i of the matrix must be the vector of the token whose index is i, so
# iterate the same list that token_to_ix was built from instead of relying on
# the set happening to iterate in the same order.
unk_vector = fasttext["xxunk"]

# Tokens without a pretrained vector share the "xxunk" vector. Stacking into
# one float32 ndarray first avoids the slow list-of-ndarrays tensor constructor.
word_vector_embedding_matrix = torch.from_numpy(
    np.stack(
        [fasttext.get(token, unk_vector) for token in article_vocabulary_list]
    ).astype(np.float32)
)

# dataset and dataloader

In [None]:
class SentenceDataset(Dataset):
    """
    Dataset of tokenised sentences with character-level and token-level views.

    Each item is a tuple `(char_ix_seq, token_seq, exit_ix_seq, target_seq)`:

    - `char_ix_seq`: LongTensor of character indices for
      "xxbos", " ", the space-joined tokens, " ", "xxeos";
    - `token_seq`: LongTensor of token indices;
    - `exit_ix_seq`: LongTensor with the positions of every space in
      `char_ix_seq` (one more than the number of tokens);
    - `target_seq`: the target sequence exactly as it was passed in.

    Relies on the notebook-level `char_to_ix` and `token_to_ix` lookups.
    """

    def __init__(self, token_seqs, target_seqs, min_length=3):
        """
        Parameters
        ----------
        token_seqs : sequence of token lists
        target_seqs : sequence of per-token labels, aligned with `token_seqs`
        min_length : int
            Only sequences with more than `min_length` tokens are kept.
        """
        # impose length constraint with plain lists: np.array() on ragged
        # sequences is rejected by current NumPy, and would build a 2-D array
        # if every sequence happened to have the same length
        keep = [ix for ix, seq in enumerate(token_seqs) if len(seq) > min_length]
        kept_tokens = [token_seqs[ix] for ix in keep]
        self.target_seqs = [target_seqs[ix] for ix in keep]

        # indexify
        self.char_ix_seqs = [self.indexify_chars(seq) for seq in kept_tokens]
        self.token_seqs = [self.indexify_tokens(seq) for seq in kept_tokens]

        # find prediction points for language model
        self.exit_ix_seqs = [self.find_exit_points(seq) for seq in self.char_ix_seqs]

    def __getitem__(self, ix):
        """Return `(char_ix_seq, token_seq, exit_ix_seq, target_seq)` for item `ix`."""
        char_ix_seq = self.char_ix_seqs[ix]
        token_seq = self.token_seqs[ix]
        exit_ix_seq = self.exit_ix_seqs[ix]
        target_seq = self.target_seqs[ix]
        return char_ix_seq, token_seq, exit_ix_seq, target_seq

    def __len__(self):
        """Number of sequences that passed the length filter."""
        return len(self.token_seqs)

    def indexify_tokens(self, token_seq):
        """Map tokens to indices; tokens outside the vocabulary map to "xxunk"."""
        unk_ix = token_to_ix["xxunk"]
        return torch.LongTensor([token_to_ix.get(token, unk_ix) for token in token_seq])

    def indexify_chars(self, token_seq):
        """
        Map the space-joined tokens to character indices, wrapped as
        "xxbos", " ", ..., " ", "xxeos". Characters without an index map to
        "xxunk" instead of raising KeyError.
        """
        unk_ix = char_to_ix["xxunk"]
        space_ix = char_to_ix[" "]
        ix_seq = (
            [char_to_ix["xxbos"], space_ix]
            + [char_to_ix.get(char, unk_ix) for char in " ".join(token_seq)]
            + [space_ix, char_to_ix["xxeos"]]
        )
        return torch.LongTensor(ix_seq)

    def find_exit_points(self, char_ix_seq):
        """Return the positions of every space character as a 1-D LongTensor."""
        binary = char_ix_seq == char_to_ix[" "]
        # squeeze only the trailing dimension so that a single hit still
        # yields a 1-D tensor rather than a scalar
        return binary.nonzero().squeeze(1)

In [None]:
def collate_fn(batch):
    """
    Collate dataset items into padded, batch-first tensors.

    Items are reordered by descending character-sequence length. Character
    and token sequences are padded with their respective "xxpad" index; exit
    positions and targets are padded with 0 (for targets that is the same
    value as the negative label).

    Returns
    -------
    tuple
        `(padded_char_seqs, padded_token_seqs, padded_exit_seqs,
        sorted_char_lengths, sorted_token_lengths, padded_target_seqs)`.
        The token lengths follow the character-length ordering and are
        therefore not necessarily sorted themselves.
    """
    char_lengths = torch.LongTensor([len(sample[0]) for sample in batch])
    sorted_char_lengths, order = char_lengths.sort(dim=0, descending=True)

    # apply the same permutation to every field of the batch
    ordered = [batch[i] for i in order.tolist()]
    char_seqs = [sample[0] for sample in ordered]
    token_seqs = [sample[1] for sample in ordered]
    exit_seqs = [sample[2] for sample in ordered]
    target_seqs = [torch.LongTensor(sample[3]) for sample in ordered]

    sorted_token_lengths = torch.LongTensor([len(seq) for seq in token_seqs])

    def pad(seqs, value):
        return pad_sequence(sequences=seqs, padding_value=value, batch_first=True)

    return (
        pad(char_seqs, char_to_ix["xxpad"]),
        pad(token_seqs, token_to_ix["xxpad"]),
        pad(exit_seqs, 0),
        sorted_char_lengths,
        sorted_token_lengths,
        pad(target_seqs, 0),
    )

In [None]:
# Hold out 5% of the sentences for testing; random_state fixes the shuffle so
# the same split is produced for the same input sequences.
train_tokens, test_tokens, train_targets, test_targets = train_test_split(
    token_sequences, target_sequences, test_size=0.05, random_state=42
)

In [None]:
# Training data: shuffled mini-batches of 64 sentences, collated and padded by
# collate_fn in 5 worker processes.
train_dataset = SentenceDataset(train_tokens, train_targets)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    num_workers=5,
    shuffle=True,
    collate_fn=collate_fn,
    # the training loop copies batches with non_blocking=True, which only
    # overlaps with compute when the host tensors are in pinned memory
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Held-out data, served one sentence at a time. shuffle=True means that
# iterating the loader yields the test sentences in random order.
test_dataset = SentenceDataset(test_tokens, test_targets)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    num_workers=5,
    shuffle=True,
    collate_fn=collate_fn,
)

# model

In [None]:
class CharacterLevelNetwork(nn.Module):
    """
    Bidirectional character-level LSTM that emits two feature vectors per token.

    Characters are embedded and run through a single-layer bidirectional LSTM.
    The LSTM outputs at the "exit" positions are then split into their forward
    and backward halves and passed through two separate feed-forward heads.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        """
        Parameters
        ----------
        input_dim : int
            Number of distinct character indices.
        embedding_dim : int
            Size of the character embeddings.
        hidden_dim : int
            Hidden size of each LSTM direction.
        output_dim : int
            Size of the vectors produced by each head.
        """
        super(CharacterLevelNetwork, self).__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.char_level_lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
        )

        self.head_fwd = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
        )

        self.head_bwd = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, char_seqs, exit_seqs, lengths):
        """
        Parameters
        ----------
        char_seqs : LongTensor, (batch, max_chars)
            Padded character indices, ordered by descending length.
        exit_seqs : LongTensor, (batch, n_exits)
            Padded positions inside each character sequence to read out.
        lengths : sequence or tensor of int
            Unpadded length of each character sequence.

        Returns
        -------
        (pred_fwd, pred_bwd) : tensors of shape (batch, n_exits - 1, output_dim)
            The forward half is read at every exit but the first, the backward
            half at every exit but the last.
        """
        x = self.embedding(char_seqs)

        # pack_padded_sequence needs its lengths on the CPU, wherever the
        # data itself lives
        cpu_lengths = torch.as_tensor(lengths).cpu()
        x = pack_padded_sequence(x, lengths=cpu_lengths, batch_first=True)

        x, _ = self.char_level_lstm(x)
        out, _ = pad_packed_sequence(
            x, batch_first=True, total_length=char_seqs.size(1)
        )

        # pop out the LSTM outputs at the exit positions of every sequence in
        # one indexing operation: (batch, 1) row indices broadcast against the
        # (batch, n_exits) position indices
        batch_ix = torch.arange(out.size(0), device=out.device).unsqueeze(1)
        out = out[batch_ix, exit_seqs]

        # the last dimension holds [forward | backward] hidden states
        out_fwd, out_bwd = torch.chunk(out, 2, 2)

        pred_fwd = self.head_fwd(out_fwd[:, 1:])
        pred_bwd = self.head_bwd(out_bwd[:, :-1])

        return pred_fwd, pred_bwd

In [None]:
class LinkLabeller(nn.Module):
    """
    Token-level tagger producing two logits per token.

    Each token is represented by its pretrained word vector concatenated with
    the forward and backward character-level features from
    `CharacterLevelNetwork`. The sequence is run through a two-layer
    bidirectional LSTM and a small feed-forward head.
    """

    def __init__(self, word_vector_embedding_matrix, hidden_dim=1024):
        """
        Parameters
        ----------
        word_vector_embedding_matrix : FloatTensor, (vocab_size, vector_dim)
            Pretrained word vectors; loaded with `nn.Embedding.from_pretrained`.
        hidden_dim : int
            Hidden size of each direction of the word-level LSTM.
        """
        super(LinkLabeller, self).__init__()
        self.wv_embedding = nn.Embedding.from_pretrained(word_vector_embedding_matrix)

        self.cln = CharacterLevelNetwork(
            input_dim=len(unique_characters),
            embedding_dim=50,
            hidden_dim=128,
            output_dim=50,
        )

        # word vector + forward and backward character-level features
        self.lstm_input_size = word_vector_embedding_matrix.shape[1] + (
            self.cln.output_dim * 2
        )

        self.word_level_lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=hidden_dim,
            num_layers=2,
            bidirectional=True,
            dropout=0.2,
        )

        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim // 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 32, 2),
        )

    def forward(self, char_seqs, token_seqs, exit_seqs, c_lens, t_lens):
        """
        Parameters
        ----------
        char_seqs, token_seqs, exit_seqs : LongTensor
            Padded, batch-first tensors as produced by `collate_fn`.
        c_lens, t_lens : LongTensor
            Unpadded character and token lengths of each sequence.

        Returns
        -------
        output : tensor, (batch, 2, max_tokens)
            Class logits, with the batch reordered by descending token length.
        sort_indicies : LongTensor
            The permutation that was applied to the batch; callers must apply
            it to their targets before comparing them with `output`.
        """
        wv_seqs = self.wv_embedding(token_seqs)
        # packing needs CPU lengths, wherever the data lives
        char_fwd, char_bwd = self.cln(char_seqs, exit_seqs, c_lens.cpu())

        concats = torch.cat([char_fwd, char_bwd, wv_seqs], dim=2)

        # the batch arrives sorted by character length; packing the token-level
        # sequences needs it sorted by token length instead
        sorted_lengths, sort_indicies = t_lens.sort(dim=0, descending=True)

        concats = concats[sort_indicies.to(concats.device)]

        packed = pack_padded_sequence(
            concats, lengths=sorted_lengths.cpu(), batch_first=True
        )

        packed_embedded, _ = self.word_level_lstm(packed)
        # no batch_first here: embedded is (max_tokens, batch, 2 * hidden_dim)
        embedded, _ = pad_packed_sequence(packed_embedded)

        # (max_tokens, batch, 2) -> (batch, 2, max_tokens)
        output = self.head(embedded).permute(1, 2, 0)
        return output, sort_indicies

In [None]:
# Build the tagger around the pretrained word vectors and move it to `device`.
model = LinkLabeller(word_vector_embedding_matrix).to(device)

# training

In [None]:
# Weight each class by the relative frequency of the *other* class, so the
# rarer label counts for more in the loss.
stacked = np.hstack(train_targets)
b = stacked.sum()  # number of positions labelled 1
a = len(stacked) - b  # number of positions labelled 0
class_weights = torch.Tensor([b, a]) / (b + a)

In [None]:
# Per-batch training losses; appended to by train() and plotted afterwards.
losses = []

torch.backends.cudnn.benchmark = True

# Only optimise parameters that require gradients (the pretrained embedding
# loaded with from_pretrained is frozen by default). A list, unlike a filter
# object, can be iterated more than once.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]

optimiser = optim.Adam(trainable_parameters, lr=0.001)

# Put the class weights on the same device as the model rather than assuming
# CUDA is available.
loss_function = nn.CrossEntropyLoss(weight=class_weights.to(device))

In [None]:
def train(
    model,
    train_loader,
    loss_function,
    optimiser,
    n_epochs,
    checkpoint_path="/mnt/efs/models/model_state_dict.pt",
):
    """
    Train `model` for `n_epochs` passes over `train_loader`.

    Parameters
    ----------
    model : nn.Module
        Called as `model(c_seqs, t_seqs, exit_seqs, c_lens, t_lens)`; must
        return `(predictions, sort_indicies)`.
    train_loader : iterable
        Yields `(c_seqs, t_seqs, exit_seqs, c_lens, t_lens, targets)` batches.
    loss_function : callable
        Applied to `(predictions, targets)`.
    optimiser : torch optimiser
    n_epochs : int
    checkpoint_path : str
        Where the model's state dict is written after every epoch.

    Side effects: appends every batch loss to the notebook-level `losses`
    list and overwrites `checkpoint_path` at the end of each epoch.
    """
    # run on whatever device the model already lives on
    model_device = next(model.parameters()).device

    model.train()
    for epoch in range(n_epochs):
        loop = tqdm(train_loader)
        for c_seqs, t_seqs, exit_seqs, c_lens, t_lens, targets in loop:
            c_seqs = c_seqs.to(model_device, non_blocking=True)
            t_seqs = t_seqs.to(model_device, non_blocking=True)
            exit_seqs = exit_seqs.to(model_device, non_blocking=True)
            targets = targets.to(model_device, non_blocking=True)
            # c_lens / t_lens stay on the CPU: they are only used for sorting
            # and for pack_padded_sequence, which needs CPU lengths

            optimiser.zero_grad()
            preds, sort_indicies = model(c_seqs, t_seqs, exit_seqs, c_lens, t_lens)

            # the model reorders the batch; apply the same order to the targets
            targets = targets[sort_indicies.to(targets.device)]

            loss = loss_function(preds, targets)
            loss.backward()
            optimiser.step()

            losses.append(loss.item())
            loop.set_description("Epoch {}/{}".format(epoch + 1, n_epochs))
            loop.set_postfix(loss=np.mean(losses[-100:]))

        torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Run three epochs of training; per-batch losses accumulate in `losses`.
train(
    model=model,
    train_loader=train_loader,
    loss_function=loss_function,
    optimiser=optimiser,
    n_epochs=3,
)

In [None]:
# Smoothed training curve: skip the first 20 batches, then take a 100-batch
# rolling mean of the per-batch loss.
loss_data = pd.Series(losses[20:]).rolling(window=100).mean()
ax = loss_data.plot()
ax.set_title("Training loss")
ax.set_xlabel("batch (after the first 20)")
ax.set_ylabel("loss (100-batch rolling mean)")
ax.set_ylim(0.1, 0.3);

# test the model on unseen data

In [None]:
# A single iterator over the test loader, so that repeated next() calls keep
# drawing new batches instead of restarting.
iterable_loader = iter(test_loader)

In [None]:
def format_output(tokens, targets, preds):
    """
    Render one sentence as HTML, bolding predicted and true link tokens.

    Parameters
    ----------
    tokens : iterable of 0-d tensors
        Token indices, looked up in the notebook-level `ix_to_token`.
    targets, preds : iterable of 0-d tensors
        Per-token labels; a value of 1 wraps the token in <b>...</b>.

    Returns
    -------
    str
        An HTML fragment with a "PRED" line, a "TARG" line and a separator.
    """

    def render(token, flag):
        if flag == 1:
            return "<b>" + token + "</b> "
        return token + " "

    target_parts, pred_parts = [], []

    for token_id, target, pred in zip(tokens, targets, preds):
        # tokens come straight from article text; escape them so that
        # characters such as "<" or "&" cannot break the rendered markup
        token = html.escape(ix_to_token[token_id.item()])
        target_parts.append(render(token, target.item()))
        pred_parts.append(render(token, pred.item()))

    output_string = (
        "PRED:<br>"
        + "".join(pred_parts)
        + "<br><br>TARG:<br>"
        + "".join(target_parts)
        + "<br><br>------------------------<br><br>"
    )

    return output_string

In [None]:
# Show predictions next to the true labels for a handful of test sentences.
N_SAMPLES = 10

# zip() stops cleanly if the loader runs out before N_SAMPLES batches
samples = [batch for _, batch in zip(range(N_SAMPLES), iterable_loader)]

model_device = next(model.parameters()).device

# Inference: switch off dropout and gradient tracking (train() leaves the
# model in training mode).
model.eval()

output = ""
with torch.no_grad():
    for (c_seqs, t_seqs, exit_seqs, c_lens, t_lens, targets) in samples:
        c_seqs = c_seqs.to(model_device, non_blocking=True)
        t_seqs = t_seqs.to(model_device, non_blocking=True)
        exit_seqs = exit_seqs.to(model_device, non_blocking=True)
        targets = targets.to(model_device, non_blocking=True)
        # c_lens / t_lens stay on the CPU (only used for sorting and packing)

        logits, sort_indicies = model(c_seqs, t_seqs, exit_seqs, c_lens, t_lens)
        # argmax over the class dimension; a log-softmax first would not
        # change which class wins
        preds = logits.argmax(dim=1)

        # the model reorders the batch by token length: apply the same order
        # to the targets and the tokens so all three stay aligned
        order = sort_indicies.to(model_device)
        targets = targets[order]
        t_seqs = t_seqs[order]

        output += format_output(t_seqs[0], targets[0], preds[0])

display(HTML(output))

In [None]:
# Persist the trained weights together with every lookup needed to rebuild
# the model's inputs later.
model_dir = "/mnt/efs/models/nerd"
os.makedirs(model_dir, exist_ok=True)  # saving fails if the directory is missing

torch.save(model.state_dict(), os.path.join(model_dir, "model_state_dict.pt"))

# each entry is written to <model_dir>/<name>.pkl
picklable_artifacts = {
    "token_to_ix": token_to_ix,
    "ix_to_token": ix_to_token,
    "char_to_ix": char_to_ix,
    "unique_characters": unique_characters,
    "article_vocabulary": article_vocabulary,
}

for name, artifact in picklable_artifacts.items():
    with open(os.path.join(model_dir, name + ".pkl"), "wb") as f:
        pickle.dump(artifact, f)

torch.save(
    word_vector_embedding_matrix, os.path.join(model_dir, "embedding_matrix.pt")
)