In [None]:
!pip install pytorch_pretrained_bert
!pip install transformers

In [0]:
import torch
import numpy as np
from pytorch_pretrained_bert import convert_tf_checkpoint_to_pytorch
from transformers import  BertModel
from pytorch_pretrained_bert import BertConfig, BertForPreTraining

In [0]:
from torch import nn
from torch.nn import CrossEntropyLoss

In [0]:
import torch.nn.functional as F

class BertForQuestionAnswering(nn.Module):
    """Extractive question-answering model: BERT encoder + per-token MLP head.

    The head maps every token's final hidden state to two logits (span start,
    span end). ``forward`` returns ``(loss, probs)`` where ``probs`` has shape
    (batch, seq_len, 2) and holds softmax probabilities over the sequence
    dimension, with padding positions (``mask == 0``) given probability 0.
    """

    def __init__(self, bert=None):
        """Build the model.

        Args:
            bert: optional pre-built encoder module. When omitted,
                ``BertModel.from_pretrained('bert-base-multilingual-cased')``
                is loaded. Indexing the encoder's output with ``[0]`` must give
                the sequence output of shape (batch, seq_len, hidden).
        """
        super().__init__()

        if bert is None:
            bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        self.bert = bert
        # No self.bert.eval() here: the encoder is fine-tuned together with the
        # head, and model.train()/model.eval() on the whole module decides
        # whether its dropout is active.
        hidden_size = getattr(getattr(bert, 'config', None), 'hidden_size', 768)

        self.qa_outputs = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 2),
        )
        self.loss_fct = CrossEntropyLoss()

    def forward(self, input_ids=None, token_type_ids=None, start_positions=None, end_positions=None, mask=None):
        """Score every token as answer start / answer end.

        Args:
            input_ids: LongTensor (batch, seq_len) of token ids.
            token_type_ids: optional segment ids; forwarded to the encoder
                when given.
            start_positions, end_positions: optional LongTensors (batch,) of
                gold span indices; the loss is computed only when both are set.
            mask: attention mask (batch, seq_len), 1/True for real tokens and
                0/False for padding. Defaults to all ones.

        Returns:
            ``(loss, probs)``: ``loss`` is the mean of the start and end
            cross-entropy losses, or None without gold positions; ``probs`` is
            (batch, seq_len, 2), softmax-normalised over seq_len.
        """
        if mask is None:
            mask = torch.ones_like(input_ids)

        bert_kwargs = {'attention_mask': mask}
        if token_type_ids is not None:
            bert_kwargs['token_type_ids'] = token_type_ids
        sequence_output = self.bert(input_ids, **bert_kwargs)[0]

        # Padding must never be picked as start/end: -inf logits give it zero
        # probability and drop it from the cross-entropy normaliser.
        # `mask == 0` works for both integer and bool masks.
        padding = (mask == 0).unsqueeze(-1)
        logits = self.qa_outputs(sequence_output).masked_fill(padding, float('-inf'))

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = (self.loss_fct(logits[:, :, 0], start_positions)
                    + self.loss_fct(logits[:, :, 1], end_positions)) / 2

        return loss, F.softmax(logits, dim=1)

In [0]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    do_lower_case=False,  # cased checkpoint: keep the original capitalisation
)

In [0]:
def preprocess(text, question, ans, max_len=None, tok=None):
    """Tokenize one (context, question, answer) example into BERT QA inputs.

    The answer is located case-insensitively in ``text``. The context is
    tokenized in three pieces (before / answer / after) so the answer's token
    span is known exactly. A context that does not fit next to the question
    is split into overlapping windows of ``3 * part`` tokens taken every
    ``part`` tokens, where ``part = room // 3``.

    Args:
        text: context paragraph.
        question: question string.
        ans: answer string; must occur in ``text`` (ignoring case).
        max_len: maximum tokens per model input, including [CLS], both [SEP]s
            and the question. Defaults to the global ``MAX_TEXT_LEN``.
        tok: tokenizer with ``tokenize`` and ``cls_token``/``sep_token``.
            Defaults to the global ``tokenizer``.

    Returns:
        ``(tokens, labels)``: parallel lists with one entry per window.
        ``tokens[i]`` is ``[CLS] question [SEP] window [SEP]`` (a list of
        token strings). ``labels[i]`` is the inclusive ``(start, end)`` answer
        position inside ``tokens[i]``, or ``(0, 0)`` (the [CLS] position) when
        the window does not contain the whole answer. Returns ``([], [])``
        when the answer cannot be located or the question leaves no room.
    """
    import re

    if tok is None:
        tok = tokenizer
    if max_len is None:
        max_len = MAX_TEXT_LEN

    # Search the ORIGINAL string case-insensitively so the match offsets are
    # valid for slicing ``text``. text.lower() can change the length of some
    # Unicode strings, which would shift every later index.
    match = re.search(re.escape(ans), text, flags=re.IGNORECASE) if ans else None
    if match is None:
        return [], []
    ans_begin, ans_end = match.span()

    # Tokenize before / inside / after the answer separately so the answer's
    # token span [first, last] is exact.
    context = tok.tokenize(text[:ans_begin].strip())
    first = len(context)
    context += tok.tokenize(text[ans_begin:ans_end].strip())
    last = len(context) - 1
    if last < first:  # answer tokenized to nothing (e.g. whitespace only)
        return [], []
    context += tok.tokenize(text[ans_end:].strip())
    question_tokens = tok.tokenize(question)

    offset = 2 + len(question_tokens)          # [CLS] + question + [SEP]
    room = max_len - len(question_tokens) - 3  # context tokens that fit

    if len(context) <= room:
        windows = [context]
        labels = [(first + offset, last + offset)]
    else:
        part_length = room // 3
        if part_length <= 0:  # question leaves no room for context
            return [], []
        window_size = 3 * part_length
        n_windows = -(-len(context) // part_length) - 2  # ceil(n / part) - 2
        windows, labels = [], []
        for i in range(n_windows):
            begin = i * part_length
            window = context[begin:begin + window_size]
            lfirst, llast = first - begin, last - begin
            windows.append(window)
            if lfirst >= 0 and llast < len(window):
                labels.append((lfirst + offset, llast + offset))
            else:
                # Answer not (fully) inside this window: target [CLS].
                labels.append((0, 0))

    tokens = [[tok.cls_token] + question_tokens + [tok.sep_token] + win + [tok.sep_token]
              for win in windows]
    return tokens, labels

In [0]:
def pad_sequence(texts):
    """Right-pad token sequences to a common length and convert them to ids.

    Args:
        texts: non-empty collection of token-string lists.

    Returns:
        ``(ids, masks)``: two LongTensors of shape (len(texts), longest).
        ``ids`` holds token ids, padded with the tokenizer's pad token.
        ``masks`` holds 1 for real tokens and 0 for padding.
    """
    longest = max(len(seq) for seq in texts)
    id_rows, mask_rows = [], []
    for seq in texts:
        n_pad = longest - len(seq)
        mask_rows.append([1] * len(seq) + [0] * n_pad)
        id_rows.append(tokenizer.convert_tokens_to_ids(seq + [tokenizer.pad_token] * n_pad))
    return torch.LongTensor(id_rows), torch.LongTensor(mask_rows)

def collate_fn(data):
    """Collate ``(tokens, (start, end))`` examples into padded model inputs.

    Args:
        data: sequence of ``(tokens, label)`` pairs. ``tokens`` is a list of
            token strings. ``label`` is an inclusive ``(start, end)`` pair of
            integer token indices into ``tokens``.

    Returns:
        ``(texts, masks, start_positions, end_positions)``. ``texts`` and
        ``masks`` are padded as by ``pad_sequence``. The positions are 1-D
        LongTensors of length ``len(data)``.
    """
    texts, labels = zip(*data)
    texts, masks = pad_sequence(list(texts))

    # The original zip(*labels_first) produced zip objects, which
    # torch.LongTensor cannot consume. Unpack the (start, end) pairs directly.
    start_pos, end_pos = zip(*labels)
    return (texts, masks,
            torch.tensor(start_pos, dtype=torch.long),
            torch.tensor(end_pos, dtype=torch.long))

In [0]:
# Maximum number of tokens in one model input
# ([CLS] + question + [SEP] + context window + [SEP]).
# Contexts that do not fit are split into overlapping windows.
MAX_TEXT_LEN = 256

In [0]:
import csv
from tqdm import tqdm

# Parallel lists, zipped into (tokens, label) training pairs for the DataLoader.
# NOTE(review): presumably filled by calling preprocess() on every training
# example and extending both lists with its outputs — confirm when the
# loader below is written.
dataset_tokens, dataset_labels = [], []
#TODO load dataset

50365it [01:53, 445.48it/s]


In [0]:
# Number of training inputs collected. If the lists are built from
# preprocess() output this counts context windows, not questions — TODO confirm.
print(len(dataset_tokens))

50363


In [0]:
train_data_loader = torch.utils.data.DataLoader(
    list(zip(dataset_tokens, dataset_labels)),  # (tokens, label) pairs
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
!pip3 install wandb

In [0]:
!wandb login

In [None]:
import wandb
# Start a Weights & Biases run; the training loop logs per-batch loss to it.
wandb.init(project="dul")


In [0]:
# `bert` was never defined in this notebook (NameError). The model loads its
# own pretrained encoder when constructed without arguments.
model = BertForQuestionAnswering()

In [None]:
# Fine-tune encoder + QA head. Move the model before building the optimizer so
# the optimizer tracks parameters that already live on the target device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=5e-5,
    weight_decay=1e-6,
)
epochs = 3

for epoch in range(epochs):
    model.train()
    for texts, masks, start_pos, end_pos in train_data_loader:
        optimizer.zero_grad()
        loss, _ = model(
            texts.to(device),
            mask=masks.to(device),
            # as_tensor accepts existing tensors without the copy-construct
            # warning that torch.tensor(tensor) emits
            start_positions=torch.as_tensor(start_pos, dtype=torch.long, device=device),
            end_positions=torch.as_tensor(end_pos, dtype=torch.long, device=device),
        )
        wandb.log({'loss': loss.item()})
        loss.backward()
        optimizer.step()

In [0]:
# Use the device chosen for training instead of assuming a GPU exists.
model.to(device)

In [0]:
import torch.nn.functional as F
import re

def getBestProb(probs):
    """Return the span ``(start, end)``, start <= end, maximising
    ``probs[start, 0] * probs[end, 1]``.

    Runs in O(n). For each candidate end it pairs the largest start score seen
    so far, which is exact because the scores are non-negative probabilities.
    On exact ties it keeps the span with the smallest end, then the smallest
    start.

    Args:
        probs: array of shape (n, 2). Column 0 holds start probabilities and
            column 1 holds end probabilities.

    Returns:
        ``(start, end)`` indices. ``(0, 0)`` when ``probs`` is empty or every
        product is 0.
    """
    best_start, best_end, best_score = 0, 0, 0
    arg_start = 0  # index of the largest start score among positions <= j
    for j in range(len(probs)):
        if probs[j, 0] > probs[arg_start, 0]:
            arg_start = j
        score = probs[arg_start, 0] * probs[j, 1]
        if best_score < score:
            best_score, best_start, best_end = score, arg_start, j
    return best_start, best_end


def concat(tokens, unk_token=None):
    """Detokenize WordPiece tokens into a space-separated string.

    Continuation pieces (``##xx``) are glued onto the preceding word instead of
    being emitted as separate words. Unknown tokens are dropped. A leading
    continuation piece (a span that starts mid-word) has its ``##`` removed.

    Args:
        tokens: list of WordPiece token strings.
        unk_token: token to drop. Defaults to the global ``tokenizer.unk_token``.

    Returns:
        The detokenized text, stripped of surrounding whitespace.
    """
    if unk_token is None:
        unk_token = tokenizer.unk_token
    words = []
    for token in tokens:
        if token == unk_token:
            continue
        if token.startswith('##'):
            piece = token[2:]
            if words:
                words[-1] += piece
                continue
            token = piece
        words.append(token)
    return ' '.join(words).strip()
              


def _context_windows(text_tokens, length):
    """Split context tokens into the same overlapping windows used for training.

    Returns a list of ``(begin, window)`` pairs. ``begin`` is the index of the
    window's first token in ``text_tokens``. A context of at most ``length``
    tokens is one window. A longer one is cut into windows of
    ``3 * (length // 3)`` tokens taken every ``length // 3`` tokens. The
    result is empty when the question leaves no room for context.
    """
    if len(text_tokens) <= length:
        return [(0, text_tokens)]
    part_length = length // 3
    if part_length <= 0:
        return []
    window_size = 3 * part_length
    n_windows = -(-len(text_tokens) // part_length) - 2  # ceil(n / part) - 2
    return [(i * part_length, text_tokens[i * part_length:i * part_length + window_size])
            for i in range(n_windows)]


def _best_span_in_window(question_tokens, window):
    """Run the model on ``[CLS] question [SEP] window [SEP]``.

    Returns ``(score, start, end)``. ``start``/``end`` (inclusive) index into
    ``window``, and ``score`` is the product of the model's start and end
    probabilities for that span.
    """
    tokens = ([tokenizer.cls_token] + question_tokens + [tokenizer.sep_token]
              + window + [tokenizer.sep_token])
    texts, masks = pad_sequence([tokens])
    _, probs = model(texts.to(device), mask=masks.to(device))
    # The model already returns softmax probabilities, so no second softmax.
    # Keep only the context positions: skip [CLS] + question + [SEP] and the
    # trailing [SEP]. Raw (not renormalised) probabilities keep windows
    # comparable, since a window without the answer puts its mass on [CLS].
    offset = 2 + len(question_tokens)
    context_probs = probs.squeeze(0)[offset:offset + len(window)].cpu().numpy()
    start, end = getBestProb(context_probs)
    return context_probs[start, 0] * context_probs[end, 1], start, end


#TODO load test properly
model.eval()
with torch.no_grad(), \
        open('output', 'w') as out, \
        open('test.txt', newline='') as csvfile:
    for row in tqdm(csv.reader(csvfile, delimiter='\t')):
        question_id, paragraph, question = row[-3], row[-2], row[-1]

        question_tokens = tokenizer.tokenize(question)
        text_tokens = tokenizer.tokenize(paragraph)
        length = MAX_TEXT_LEN - len(question_tokens) - 3

        # Best (score, start, end) over all windows, in text_tokens indices.
        # Each window is scored on its own; concatenating overlapping windows'
        # probabilities would break the mapping back to token positions.
        best = None
        for begin, window in _context_windows(text_tokens, length):
            if not window:
                continue
            score, start, end = _best_span_in_window(question_tokens, window)
            if best is None or score > best[0]:
                best = (score, begin + start, begin + end)

        answer = concat(text_tokens[best[1]:best[2] + 1]) if best is not None else ''
        out.write(question_id + '\t' + answer + '\n')

1001it [00:23, 47.22it/s]
