In [None]:
import os
import re
import numpy as np
import pandas as pd

import logging
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_pretrained_bert.tokenization import BertTokenizer, BasicTokenizer
from pytorch_pretrained_bert.modeling import BertModel

from helperbot import BaseBot, TriangularLR


'''
Forked and edited from:
https://www.kaggle.com/ceshine/pytorch-bert-baseline-public-score-0-54

We use this notebook to generate BERT embeddings for two mentions and the gender pronoun.
We do not remove punctuation during data pre-processing

This part can also be used as a base deep learning model.
'''

In [None]:
# gap-development.tsv is held out as the test split, while gap-test.tsv and
# gap-validation.tsv are stacked (row-wise) into a single train+validation frame.
df_test = pd.read_csv("gap-development.tsv", sep="\t")
df_train_val = pd.concat(
    [pd.read_csv(tsv_name, sep="\t") for tsv_name in ("gap-test.tsv", "gap-validation.tsv")],
    axis=0,
)

In [None]:
'''
We modify the output of Head Model, so that it will only output extracted bert embeddings
'''

class Head(nn.Module):
    """The MLP submodule, reduced in this notebook to an embedding extractor.

    The classifier stack ``fc`` is still built and initialised, but
    ``forward`` never calls it: it only gathers the BERT hidden states found
    at the requested token offsets and returns them unchanged.

    Args:
        bert_hidden_size: Size of the last dimension of the BERT outputs that
            ``forward`` will receive.
    """
    def __init__(self, bert_hidden_size: int):
        super().__init__()

        self.bert_hidden_size = bert_hidden_size
        self.fc = nn.Sequential(
            nn.BatchNorm1d(bert_hidden_size * 3),
            nn.Dropout(0.5),
            nn.Linear(bert_hidden_size * 3, 512),    # bert_hidden_size * 3
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 3)
        )
        for module in self.fc:
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                print("Initing batchnorm")
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    # Weight-normalised linear layer: init magnitude and direction separately.
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    print("Initing linear with weight normalization")
                    # Check the layer itself; the previous code indexed an
                    # unrelated global named ``model`` here.
                    assert module.weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                    print("Initing linear")
                nn.init.constant_(module.bias, 0)

    def forward(self, bert_outputs, offsets):
        """Gather the hidden states located at ``offsets``.

        Args:
            bert_outputs: Tensor whose third dimension must equal
                ``bert_hidden_size`` (checked by an assertion).
            offsets: Integer index tensor with one row per batch element;
                expected to hold three positions per row so the result can be
                viewed as ``(batch, 3, hidden)``.

        Returns:
            The gathered hidden states, viewed as ``(batch, 3, -1)``.
        """
        assert bert_outputs.size(2) == self.bert_hidden_size

        # Broadcast each offset across the hidden dimension so gather() picks
        # the whole hidden vector at that sequence position.
        gather_index = offsets.unsqueeze(2).expand(-1, -1, bert_outputs.size(2))
        extracted_outputs = bert_outputs.gather(1, gather_index).view(
            bert_outputs.size(0), 3, -1
        )

        # We modify the output of the Head model so that it only outputs the
        # extracted BERT embeddings (the ``fc`` classifier is bypassed).
        return extracted_outputs
    
class GAPModel(nn.Module):
    """The main model: a pretrained BERT encoder followed by ``Head``.

    Args:
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        device: Device that both sub-modules and all inputs are moved to.
    """
    def __init__(self, bert_model: str, device: torch.device):
        super().__init__()
        self.device = device
        # Hard-coded hidden size; Head.forward asserts the BERT output matches it.
        self.bert_hidden_size = 1024
        self.bert = BertModel.from_pretrained(bert_model).to(device)
        self.head = Head(self.bert_hidden_size).to(device)

    def forward(self, token_tensor, offsets):
        """Encode ``token_tensor`` with BERT and return what the head extracts."""
        input_ids = token_tensor.to(self.device)
        # Id 0 is the padding value, so every positive id counts as a real token.
        padding_mask = (input_ids > 0).long()
        last_layer, _pooled = self.bert(
            input_ids,
            attention_mask=padding_mask,
            token_type_ids=None,
            output_all_encoded_layers=False,
        )
        # Only the BERT embeddings are produced here (no classification logits).
        return self.head(last_layer, offsets.to(self.device))

In [None]:
# Name of the pretrained BERT checkpoint used for both the tokenizer and the model.
BERT_MODEL = 'bert-large-uncased'
# NOTE: despite its name, this flag is forwarded as `do_lower_case` when the
# tokenizer is created, so True means the input text IS lower-cased.
CASED = True


def insert_tag(row):
    """Return ``row["Text"]`` with marker tags spliced in before A, B and the pronoun.

    Tags are inserted from the largest character offset to the smallest so
    that an insertion never shifts an offset that is still to be processed.
    """
    tag_columns = (
        ("A-offset", " [THISISA] "),
        ("B-offset", " [THISISB] "),
        ("Pronoun-offset", " [THISISP] "),
    )
    insertions = [(row[column], marker) for column, marker in tag_columns]
    insertions.sort(key=lambda pair: pair[0], reverse=True)

    tagged = row["Text"]
    for position, marker in insertions:
        tagged = "".join((tagged[:position], marker, tagged[position:]))
    return tagged


def tokenize(text, tokenizer):
    """Tokenize ``text`` and strip the marker tags, remembering where they were.

    Returns:
        A pair ``(tokens, (a_pos, b_pos, p_pos))`` where ``tokens`` excludes
        the three marker tags and each position is the index, within
        ``tokens``, of the token that immediately followed the marker.

    Raises:
        KeyError: if one of the three markers never shows up in the token stream.
    """
    markers = ("[THISISA]", "[THISISB]", "[THISISP]")
    marker_positions = {}
    kept_tokens = []
    for piece in tokenizer.tokenize(text):
        if piece in markers:
            # The marker itself is dropped, so the next kept token lands at this index.
            marker_positions[piece] = len(kept_tokens)
        else:
            kept_tokens.append(piece)
    return kept_tokens, tuple(marker_positions[marker] for marker in markers)


class GAPDataset(Dataset):
    """Dataset of token ids and A/B/pronoun token positions built from a GAP frame.

    Args:
        df: Frame with ``Text``, ``A-offset``, ``B-offset`` and
            ``Pronoun-offset`` columns (plus ``A-coref`` / ``B-coref`` when
            ``labeled`` is true).
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.
        labeled: Whether to build the three-way (A, B, Neither) label matrix.
    """
    def __init__(self, df, tokenizer, labeled=True):
        self.labeled = labeled
        if labeled:
            # Third column is true exactly when neither A nor B is the referent.
            neither = ~(df["A-coref"] | df["B-coref"])
            label_frame = df[["A-coref", "B-coref"]].assign(Neither=neither)
            self.y = label_frame.values.astype("bool")

        # Token ids (wrapped in [CLS] ... [SEP]) and marker positions per example.
        self.offsets = []
        self.tokens = []
        for _, record in df.iterrows():
            pieces, positions = tokenize(insert_tag(record), tokenizer)
            self.offsets.append(positions)
            wrapped = ["[CLS]"] + pieces + ["[SEP]"]
            self.tokens.append(tokenizer.convert_tokens_to_ids(wrapped))

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        # Unlabeled datasets still return a 3-tuple, with None as the label.
        label = self.y[idx] if self.labeled else None
        return self.tokens[idx], self.offsets[idx], label

    
    
def collate_examples(batch, truncate_len=500):
    """Collate ``(token_ids, offsets, label)`` examples into batch tensors.

    Args:
        batch: Sequence of 3-tuples as produced by ``GAPDataset.__getitem__``:
            a list of token ids, a triple of token positions, and either a
            one-hot label array or ``None`` for unlabeled data.
        truncate_len: Maximum number of token ids kept per example.

    Returns:
        ``(token_tensor, offsets, labels)`` where ``token_tensor`` is a
        zero-padded int64 tensor of shape ``(batch, max_len)``, ``offsets`` is
        a long tensor shifted by one to account for the leading ``[CLS]``
        token, and ``labels`` holds the class index per example, or ``None``
        when the batch carries no labels.
    """
    token_lists, offset_triples, raw_labels = list(zip(*batch))[:3]

    max_len = min(max(len(ids) for ids in token_lists), truncate_len)
    tokens = np.zeros((len(batch), max_len), dtype=np.int64)
    for i, ids in enumerate(token_lists):
        ids = np.array(ids[:truncate_len])
        tokens[i, :len(ids)] = ids  # remaining positions stay 0 (padding)

    token_tensor = torch.from_numpy(tokens)
    # NOTE(review): offsets are not clipped to truncate_len; an example whose
    # markers fall beyond the truncation point would index past the sequence.
    offsets = torch.stack([torch.LongTensor(x) for x in offset_triples], dim=0) + 1  # Account for the [CLS] token

    # Unlabeled datasets yield None labels; previously this crashed on None.astype.
    if any(label is None for label in raw_labels):
        return token_tensor, offsets, None

    one_hot_labels = torch.stack(
        [torch.from_numpy(label.astype("uint8")) for label in raw_labels], dim=0
    )
    _, labels = one_hot_labels.max(dim=1)
    return token_tensor, offsets, labels

# Build the WordPiece tokenizer; the three marker tags are listed in
# never_split so they survive tokenization as single tokens.
tokenizer = BertTokenizer.from_pretrained(
    BERT_MODEL,
    do_lower_case=CASED,
    never_split=(
        "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]",
        "[THISISA]", "[THISISB]", "[THISISP]",
    ),
)

# Register the marker tags in the vocabulary with a dummy id of -1.
tokenizer.vocab.update({
    "[THISISA]": -1,
    "[THISISB]": -1,
    "[THISISP]": -1,
})

In [None]:
# Tokenize both splits up front.
train_val_ds = GAPDataset(df_train_val, tokenizer)
test_ds = GAPDataset(df_test, tokenizer)

# One example per batch, in file order, so the extracted embeddings line up
# with the rows of the source frames.
train_loader = DataLoader(train_val_ds, batch_size=1, shuffle=False, collate_fn=collate_examples)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_examples)

In [None]:
def children(m):
    """Return ``m`` itself if it is a list/tuple, else a list of its direct child modules."""
    if isinstance(m, (list, tuple)):
        return m
    return list(m.children())
def set_trainable_attr(m, b):
    """Flag module ``m`` as (un)trainable and switch gradients on/off for all its parameters."""
    m.trainable = b
    for param in m.parameters():
        param.requires_grad = b
def apply_leaf(m, f):
    """Call ``f`` on ``m`` (when it is a module) and then recursively on everything below it."""
    sub_items = children(m)
    if isinstance(m, nn.Module):
        f(m)
    # Iterating an empty collection is a no-op, so no explicit length check is needed.
    for sub_item in sub_items:
        apply_leaf(sub_item, f)
def set_trainable(l, b):
    """Recursively mark ``l`` and every module beneath it as trainable (``b`` true) or frozen."""
    def _toggle(module):
        set_trainable_attr(module, b)

    apply_leaf(l, _toggle)

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at model creation.
model = GAPModel(
    BERT_MODEL,
    torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)
# BERT is frozen; only the head's parameters keep requires_grad=True.
set_trainable(model.bert, False)
set_trainable(model.head, True)

In [None]:
# Extracted embeddings for the train+validation split, one entry per batch
# (the loader uses batch_size=1, so one entry per example).
bert_outputs = []

# eval() + no_grad(): pure inference, no gradient bookkeeping.
model.eval()
with torch.no_grad():
    for token_tensor, offsets, labels in train_loader:
        # The model returns the hidden states gathered at the A / B / pronoun positions.
        # NOTE(review): tensors stay on the model's device; confirm downstream
        # consumers can load them from the pickle on their hardware.
        prediction = model(token_tensor, offsets)
        bert_outputs.append(prediction)

In [None]:
import pickle

# Persist the train+validation embeddings; the context manager guarantees the
# file is flushed and closed even if pickling fails part-way.
with open('others_bert_outputs.pkl', "wb") as train_out_file:
    pickle.dump(bert_outputs, train_out_file)

In [None]:
# Extracted embeddings for the held-out test split, one entry per batch
# (the loader uses batch_size=1, so one entry per example).
test_others_bert_outputs = []

# eval() + no_grad(): pure inference, no gradient bookkeeping.
model.eval()
with torch.no_grad():
    for token_tensor, offsets, labels in test_loader:
        # The model returns the hidden states gathered at the A / B / pronoun positions.
        prediction = model(token_tensor, offsets)
        test_others_bert_outputs.append(prediction)

In [None]:
# Persist the TEST embeddings. The previous code dumped `bert_outputs` (the
# train+validation embeddings) into this file by mistake.
with open('test_others_bert_outputs.pkl', "wb") as test_out_file:
    pickle.dump(test_others_bert_outputs, test_out_file)