In [0]:
!pip install pytorch_pretrained_bert
!pip install transformers

Collecting pytorch_pretrained_bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |██▋                             | 10kB 19.3MB/s eta 0:00:01[K     |█████▎                          | 20kB 1.8MB/s eta 0:00:01[K     |████████                        | 30kB 2.6MB/s eta 0:00:01[K     |██████████▋                     | 40kB 1.7MB/s eta 0:00:01[K     |█████████████▎                  | 51kB 2.1MB/s eta 0:00:01[K     |███████████████▉                | 61kB 2.5MB/s eta 0:00:01[K     |██████████████████▌             | 71kB 2.9MB/s eta 0:00:01[K     |█████████████████████▏          | 81kB 3.3MB/s eta 0:00:01[K     |███████████████████████▉        | 92kB 3.7MB/s eta 0:00:01[K     |██████████████████████████▌     | 102kB 2.8MB/s eta 0:00:01[K     |█████████████████████████████▏  | 112kB 2.8MB/s eta 0:00:01[K     |██████████████████████

In [0]:
import torch
import numpy as np
from pytorch_pretrained_bert import convert_tf_checkpoint_to_pytorch
from transformers import  BertModel
from pytorch_pretrained_bert import BertConfig, BertForPreTraining

In [0]:
from torch import nn
from torch.nn import CrossEntropyLoss

In [0]:
import torch.nn.functional as F

class BertForQuestionAnswering(nn.Module):
    
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        self.bert.eval()
        self.qa_outputs = nn.Sequential(nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 2))
        self.loss_fct = CrossEntropyLoss()
        
    

    def forward(self, input_ids=None, token_type_ids=None, start_positions=None, end_positions=None, mask=None):
        output = self.bert(input_ids, attention_mask=mask)

        sequence_output = output[0]

        logits = self.qa_outputs(sequence_output)
        loss = None

        if start_positions is not None and end_positions is not None:
            loss = (self.loss_fct(logits[:, :, 0].masked_fill((1 - mask).bool(), float('-inf')), start_positions) + \
                   self.loss_fct(logits[:, :, 1].masked_fill((1 - mask).bool(), float('-inf')), end_positions)) / 2

        return loss, F.softmax(logits.masked_fill((1 - mask[:, :, None]).bool(), float('-inf')), dim=1)

In [0]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)

100%|██████████| 995526/995526 [00:00<00:00, 12780761.49B/s]


In [0]:
def preprocess(text, question, answer):
    answer = answer.lower()
    if answer not in text.lower():
        return [], []
    
    firstInText = text.lower().find(answer)
    lastInText = firstInText + len(answer)
    text_tokens = tokenizer.tokenize(text[:firstInText].strip())
    first = len(text_tokens)
    text_tokens += tokenizer.tokenize(text[firstInText:lastInText].strip())
    last = len(text_tokens) - 1
    text_tokens += tokenizer.tokenize(text[lastInText:].strip())
    question_tokens = tokenizer.tokenize(question)
    
    length = MAX_TEXT_LEN - len(question_tokens) - 3
    if len(text_tokens) > length:
        part_length = length // 3
        stride = 3 * part_length
        nrow = np.ceil(len(text_tokens) / part_length) - 2
        indexes = part_length * np.arange(nrow)[:, None] + np.arange(stride)
        indexes = indexes.astype(np.int32)

        max_index = indexes.max()
        diff = max_index + 1 - len(text_tokens)
        text_tokens += diff * [tokenizer.pad_token]

        text_tokens = list(np.array(text_tokens)[indexes])
        
        tokens = []
        labels = []
        for i, ts in enumerate(text_tokens):
            while ts[-1] == tokenizer.pad_token:
                ts = ts[:-1]
                
            tokens += [ts]
                
            lfirst = first - i * part_length
            llast = last - i * part_length
            
            mask = lfirst >= 0 and lfirst < len(ts) and llast >= 0 and llast < len(ts)
            labels += [(lfirst if mask else 0, llast if mask else 0)]
    else:
        tokens = [text_tokens]
        labels = [(first, last)]
        
    for i in range(len(tokens)):
        # TODO удалять этот костыль!!!
        if str(type(tokens[i])) == "<class 'numpy.ndarray'>": 
            tokens[i] = list(tokens[i])
        tokens[i] = [tokenizer.cls_token] + \
                    question_tokens + \
                    [tokenizer.sep_token] + \
                    tokens[i] + \
                    [tokenizer.sep_token]
        labels[i] = (labels[i][0] + 2 + len(question_tokens), labels[i][1] + 2 + len(question_tokens))

    return tokens, labels

In [0]:
def pad_sequence(texts):
    max_len = max([len(text) for text in texts])
    masks = [[1] * len(text) + [0] * (max_len - len(text)) for text in texts]
    texts = [text + [tokenizer.pad_token] * (max_len - len(text)) for text in texts]
    texts = [tokenizer.convert_tokens_to_ids(text) for text in texts]
    texts = torch.LongTensor(texts)
    masks = torch.LongTensor(masks)

    return texts, masks

def collate_fn(data):
    texts, labels = zip(*data)
    texts, masks = pad_sequence(texts)
    
    labels_first, labels_last = zip(*labels)
    start_pos = labels_first
    end_pos = labels_last
    return texts, masks, torch.LongTensor(start_pos), torch.LongTensor(end_pos)

In [0]:
MAX_TEXT_LEN = 256

In [0]:
from google.colab import drive
import json

drive.mount('./gdrive')
train_dataset = './gdrive/My Drive/datasets_for_homeworks/train-v1.1.json'
dev_dataset = './gdrive/My Drive/datasets_for_homeworks/dev-v1.1.json'
with open(train_dataset, 'r') as train_json, open(dev_dataset, 'r') as dev_json:
    train_data = json.load(train_json)
    dev_data = json.load(dev_json)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at ./gdrive


In [0]:
def get_text_question_ans_dataset(squad_dataset):
    tqa_dataset = []
    for d in squad_dataset['data']:
        for p in d['paragraphs']:
            for qa in p['qas']:
                # TODO: deal with several answers
                tqa_dataset.append((p['context'], qa['question'], qa['answers'][0]['answer_start'], qa['answers'][0]['text']))
    return tqa_dataset

In [0]:
#tqa_train_dataset = get_text_question_ans_dataset(train_data)
tqa_dev_dataset = get_text_question_ans_dataset(dev_data)

In [0]:
print(len(tqa_train_dataset))
print(len(tqa_dev_dataset))
print(f'Max text len in train: {max(map(lambda x: len(x[0]), tqa_train_dataset))}')
print(f'Max text len in dev: {max(map(lambda x: len(x[0]), tqa_dev_dataset))}')
print(tqa_train_dataset[0])
print(tqa_train_dataset[-1])
print(tqa_dev_dataset[0])
print(tqa_dev_dataset[-1])

87599
10570
Max text len in train: 3706
Max text len in dev: 4063
('Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'Saint Bernadette Soubirous')
("Kathmandu Metropolitan City (KMC), in order to promote international relations has established an International Relations Secr

In [0]:
from tqdm.auto import tqdm

dataset_tokens, dataset_labels = [], []
for datapoint in tqdm(tqa_train_dataset):
    tokens, labels = preprocess(datapoint[0], datapoint[1], datapoint[2])
    dataset_tokens += tokens
    dataset_labels += labels

HBox(children=(IntProgress(value=0, max=87599), HTML(value='')))




In [0]:
print(len(dataset_tokens))

102474


In [0]:
print(dataset_tokens[0])
print(dataset_tokens[-1])

['[CLS]', 'To', 'whom', 'did', 'the', 'Virgin', 'Mary', 'allegedly', 'appear', 'in', '1858', 'in', 'Lourdes', 'France', '?', '[SEP]', 'Arch', '##ite', '##ctural', '##ly', ',', 'the', 'school', 'has', 'a', 'Catholic', 'character', '.', 'At', '##op', 'the', 'Main', 'Building', "'", 's', 'gold', 'dome', 'is', 'a', 'golden', 'statue', 'of', 'the', 'Virgin', 'Mary', '.', 'Im', '##mediate', '##ly', 'in', 'front', 'of', 'the', 'Main', 'Building', 'and', 'facing', 'it', ',', 'is', 'a', 'copper', 'statue', 'of', 'Christ', 'with', 'arms', 'up', '##rais', '##ed', 'with', 'the', 'legend', '"', 'Ve', '##nite', 'Ad', 'Me', 'Om', '##nes', '"', '.', 'Next', 'to', 'the', 'Main', 'Building', 'is', 'the', 'Basilica', 'of', 'the', 'Sacred', 'Heart', '.', 'Im', '##mediate', '##ly', 'behind', 'the', 'basilica', 'is', 'the', 'G', '##rott', '##o', ',', 'a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflect', '##ion', '.', 'It', 'is', 'a', 'replica', 'of', 'the', 'gr', '##otto', 'at', 'Lourdes', ',', 'France'

In [0]:
train_data_loader = torch.utils.data.DataLoader(list(zip(dataset_tokens, dataset_labels)), batch_size=16, shuffle=True,collate_fn=collate_fn)

In [0]:
!pip3 install wandb



In [0]:
!wandb login

[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice: 2
[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: You can find your API key in your browser here: https://app.wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 2377ef66e63c2eda02e1d83797d0cc73170988c7
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [0]:
import wandb
wandb.init(project="dul")

W&B Run: https://app.wandb.ai/ram_saw/dul/runs/r85f4nnw

In [0]:
model = BertForQuestionAnswering()
model.load_state_dict(torch.load('./gdrive/My Drive/bert.pt'))

100%|██████████| 521/521 [00:00<00:00, 188626.02B/s]
100%|██████████| 714314041/714314041 [00:13<00:00, 52209718.17B/s]


<All keys matched successfully>

In [0]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 0.00005, weight_decay=0.000001)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3)
epochs = 3
device = 'cuda'
model.to(device)

for epoch in range(epochs):
    model.train()
    for i, (texts, masks, start_pos, end_pos) in enumerate(train_data_loader):
        optimizer.zero_grad()
        loss, _ = model(texts.to(device),
                        mask=masks.to(device),
                        start_positions=torch.tensor(start_pos).to(device),
                        end_positions=torch.tensor(end_pos).to(device))
        wandb.log({'loss' : float(loss)})
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f'Model saved on {i} iteration!')
            torch.save(model.state_dict(), './gdrive/My Drive/bert.pt')

  del sys.path[0]
  


Model saved on 0 iteration!
Model saved on 100 iteration!
Model saved on 200 iteration!
Model saved on 300 iteration!
Model saved on 400 iteration!
Model saved on 500 iteration!
Model saved on 600 iteration!
Model saved on 700 iteration!
Model saved on 800 iteration!
Model saved on 900 iteration!
Model saved on 1000 iteration!
Model saved on 1100 iteration!
Model saved on 1200 iteration!
Model saved on 1300 iteration!
Model saved on 1400 iteration!
Model saved on 1500 iteration!
Model saved on 1600 iteration!
Model saved on 1700 iteration!
Model saved on 1800 iteration!
Model saved on 1900 iteration!
Model saved on 2000 iteration!
Model saved on 2100 iteration!
Model saved on 2200 iteration!
Model saved on 2300 iteration!
Model saved on 2400 iteration!
Model saved on 2500 iteration!
Model saved on 2600 iteration!
Model saved on 2700 iteration!


In [0]:
model.cuda()

In [0]:
import torch.nn.functional as F
import re

def getBestProb(probs):
    n = len(probs)
    start, end, bestProb = 0, 0, 0
    for i in range(n):
        for j in range(i, n):
            prob = probs[i, 0] * probs[j, 1]
            if bestProb < prob:
                bestProb, start, end = prob, i, j
              
    return start, end


def concat(tokens):
    tokens = [token.replace('#', '') for token in tokens]
    return ' '.join(list(filter(lambda s: s != tokenizer.unk_token, tokens))).strip()

In [0]:
from tqdm.auto import tqdm

dev_dataset_tokens, dev_dataset_labels = [], []
for datapoint in tqdm(tqa_dev_dataset):
    tokens, labels = preprocess(datapoint[0], datapoint[1], datapoint[2])
    dev_dataset_tokens += tokens
    dev_dataset_labels += labels

HBox(children=(IntProgress(value=0, max=10570), HTML(value='')))




In [0]:
dev_data_loader = torch.utils.data.DataLoader(list(zip(dev_dataset_tokens, dev_dataset_labels)), batch_size=16, shuffle=True,collate_fn=collate_fn)

In [0]:
def test_model(model):
  #TODO load test properly
  model.eval()
  total = 0
  correct = 0
  with torch.no_grad():
    for datapoint in tqdm(tqa_dev_dataset):
        answer = datapoint[3].lower()
        if answer not in datapoint[0].lower():
             continue
        total += 1
        firstInText = datapoint[0].lower().find(answer)
        lastInText = firstInText + len(answer)
        text_tokens = tokenizer.tokenize(datapoint[0][:firstInText].strip())
        start_pos = len(text_tokens)
        text_tokens += tokenizer.tokenize(datapoint[0][firstInText:lastInText].strip())
        end_pos = len(text_tokens) - 1
        text_tokens += tokenizer.tokenize(datapoint[0][lastInText:].strip())
        question_tokens = tokenizer.tokenize(datapoint[1])
        
        all_tokens = [tokenizer.cls_token] + \
                     question_tokens + \
                     [tokenizer.sep_token] + \
                     text_tokens + \
                     [tokenizer.sep_token]

        length = MAX_TEXT_LEN - len(question_tokens) - 3
        if (len(text_tokens) > length):
            part_length = length // 3
            stride = 3 * part_length
            nrow = np.ceil(len(text_tokens) / part_length) - 2
            indexes = part_length * np.arange(nrow)[:, None] + np.arange(stride)
            indexes = indexes.astype(np.int32)

            max_index = indexes.max()
            diff = max_index + 1 - len(text_tokens)
            text_tokens += diff * [tokenizer.pad_token]

            text_tokens = np.array(text_tokens)[indexes].tolist()

            start, end, prob = 0, 0, 0
            for i, ts in enumerate(text_tokens):
                while ts[-1] == tokenizer.pad_token:
                    ts = ts[:-1]

                ts = [tokenizer.cls_token] + \
                     question_tokens + \
                     [tokenizer.sep_token] + \
                     ts + \
                     [tokenizer.sep_token]

                texts, masks = pad_sequence([ts])
                texts = texts.to(device)
                masks = masks.to(device)

                probs = model(texts, mask=masks)[1]

                size = probs.shape[1]
                m = probs[:, :, 0].view(size, 1).matmul(probs[:, :, 1].view(1, size))
                m = m.reshape(size * size)
                pos = torch.argmax(m)
                if m[pos] > prob:
                  prob = m[pos]
                  start_raw, end_raw = (pos / size).view(-1, 1).cuda(), (pos % size).view(-1, 1).cuda()
                  start, end = torch.min(start_raw, end_raw), torch.max(start_raw, end_raw)

            first = (start_pos + 2 + len(question_tokens))  == start
            second = (end_pos + 2 + len(question_tokens)) == end
            correct += int(first and second)
                    
        else:
            texts, masks = pad_sequence([all_tokens])
            texts = texts.to(device)
            masks = masks.to(device)
            probs = model(texts, mask=masks)[1]
            
            size = probs.shape[1]
            m = probs[:, :, 0].view(size, 1).matmul(probs[:, :, 1].view(1, size))
            pos = torch.argmax(m.reshape(size * size))
            start_raw, end_raw = (pos / size).view(-1, 1).cuda(), (pos % size).view(-1, 1).cuda()
            start, end = torch.min(start_raw, end_raw), torch.max(start_raw, end_raw)
            first = (start_pos + 2 + len(question_tokens))  == start
            second = (end_pos + 2 + len(question_tokens)) == end
            correct += int(first and second)
  print(f'Accuracy on dev data is {correct / total}')

In [0]:
device = 'cuda'
test_model(model)

HBox(children=(IntProgress(value=0, max=10570), HTML(value='')))

Accuracy on dev data is 0.484957426679281
