<a href="https://colab.research.google.com/github/vladimirrim/QA_DL/blob/develop/ConvLSTM_with_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install pytorch_pretrained_bert
!pip install transformers

In [2]:
import torch
import numpy as np
from pytorch_pretrained_bert import convert_tf_checkpoint_to_pytorch
from transformers import  BertModel
from pytorch_pretrained_bert import BertConfig, BertForPreTraining

In [0]:
from torch import nn
from torch.nn import CrossEntropyLoss

In [0]:
import torch.nn as nn
from torch.autograd import Variable
import torch


class ConvLSTMCell(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        
        Parameters
        ----------
        input_size: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.height, self.width = input_size
        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding     = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias        = bias
        
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        
        h_cur, c_cur = cur_state
        
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        
        return h_next, c_next

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(),
                Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda())


class ConvLSTM(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim  = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.height, self.width = input_size

        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]

            cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        
        Parameters
        ----------
        input_tensor: todo 
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: todo
            None. todo implement stateful
            
        Returns
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0))

        layer_output_list = []
        last_state_list   = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):

                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list   = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                    (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [0]:
def testConvLSTM():
  seq_len = 256
  inp = torch.zeros(16,seq_len,768).unsqueeze(2).unsqueeze(3)
  inp = inp.reshape(16, seq_len, 3, 16, 16)
  print(inp.shape)
  inp = inp.cuda()
  lst = ConvLSTM(input_size=(16, 16), input_dim=3, hidden_dim=[64, 128, 4],
                 kernel_size=(3, 3),
                 num_layers=3,
                 batch_first=True,
                 bias=True,
                 return_all_layers=False)
  lst = lst.cuda()
  out = lst(inp)
  print(out[0][0].shape)
  print(nn.Flatten(start_dim=2)(out[0][0]).shape)

In [0]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)

In [0]:
def preprocess(text, question, answer):
    answer = answer.lower()
    if answer not in text.lower():
        return [], []
    
    firstInText = text.lower().find(answer)
    lastInText = firstInText + len(answer)
    text_tokens = tokenizer.tokenize(text[:firstInText].strip())
    first = len(text_tokens)
    text_tokens += tokenizer.tokenize(text[firstInText:lastInText].strip())
    last = len(text_tokens) - 1
    text_tokens += tokenizer.tokenize(text[lastInText:].strip())
    question_tokens = tokenizer.tokenize(question)
    
    length = MAX_TEXT_LEN - len(question_tokens) - 3
    if len(text_tokens) > length:
        part_length = length // 3
        stride = 3 * part_length
        nrow = np.ceil(len(text_tokens) / part_length) - 2
        indexes = part_length * np.arange(nrow)[:, None] + np.arange(stride)
        indexes = indexes.astype(np.int32)

        max_index = indexes.max()
        diff = max_index + 1 - len(text_tokens)
        text_tokens += diff * [tokenizer.pad_token]

        text_tokens = list(np.array(text_tokens)[indexes])
        
        tokens = []
        labels = []
        for i, ts in enumerate(text_tokens):
            while ts[-1] == tokenizer.pad_token:
                ts = ts[:-1]
                
            tokens += [ts]
                
            lfirst = first - i * part_length
            llast = last - i * part_length
            
            mask = lfirst >= 0 and lfirst < len(ts) and llast >= 0 and llast < len(ts)
            labels += [(lfirst if mask else 0, llast if mask else 0)]
    else:
        tokens = [text_tokens]
        labels = [(first, last)]
        
    for i in range(len(tokens)):
        # TODO удалять этот костыль!!!
        if str(type(tokens[i])) == "<class 'numpy.ndarray'>": 
            tokens[i] = list(tokens[i])
        tokens[i] = [tokenizer.cls_token] + \
                    question_tokens + \
                    [tokenizer.sep_token] + \
                    tokens[i] + \
                    [tokenizer.sep_token]
        labels[i] = (labels[i][0] + 2 + len(question_tokens), labels[i][1] + 2 + len(question_tokens))

    return tokens, labels

In [0]:
def pad_sequence(texts):
    max_len = max([len(text) for text in texts])
    masks = [[1] * len(text) + [0] * (max_len - len(text)) for text in texts]
    texts = [text + [tokenizer.pad_token] * (max_len - len(text)) for text in texts]
    texts = [tokenizer.convert_tokens_to_ids(text) for text in texts]
    texts = torch.LongTensor(texts)
    masks = torch.LongTensor(masks)

    return texts, masks

def collate_fn(data):
    texts, labels = zip(*data)
    texts, masks = pad_sequence(texts)
    
    labels_first, labels_last = zip(*labels)
    start_pos = labels_first
    end_pos = labels_last
    return texts, masks, torch.LongTensor(start_pos), torch.LongTensor(end_pos)

In [0]:
MAX_TEXT_LEN = 256

In [10]:
from google.colab import drive
import json

drive.mount('./gdrive')
train_dataset = './gdrive/My Drive/datasets_for_homeworks/train-v1.1.json'
dev_dataset = './gdrive/My Drive/datasets_for_homeworks/dev-v1.1.json'
with open(train_dataset, 'r') as train_json, open(dev_dataset, 'r') as dev_json:
    train_data = json.load(train_json)
    dev_data = json.load(dev_json)

Drive already mounted at ./gdrive; to attempt to forcibly remount, call drive.mount("./gdrive", force_remount=True).


In [0]:
def get_text_question_ans_dataset(squad_dataset):
    tqa_dataset = []
    for d in squad_dataset['data']:
        for p in d['paragraphs']:
            for qa in p['qas']:
                # TODO: deal with several answers
                tqa_dataset.append((p['context'], qa['question'], qa['answers'][0]['answer_start'], qa['answers'][0]['text']))
    return tqa_dataset

In [0]:
tqa_train_dataset = get_text_question_ans_dataset(train_data)
tqa_dev_dataset = get_text_question_ans_dataset(dev_data)

In [13]:
print(len(tqa_train_dataset))
print(len(tqa_dev_dataset))
print(f'Max text len in train: {max(map(lambda x: len(x[0]), tqa_train_dataset))}')
print(f'Max text len in dev: {max(map(lambda x: len(x[0]), tqa_dev_dataset))}')
print(tqa_train_dataset[0])
print(tqa_train_dataset[-1])
print(tqa_dev_dataset[0])
print(tqa_dev_dataset[-1])

87599
10570
Max text len in train: 3706
Max text len in dev: 4063
('Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 515, 'Saint Bernadette Soubirous')
("Kathmandu Metropolitan City (KMC), in order to promote international relations has established an International Relations

In [14]:
from tqdm.auto import tqdm

dataset_tokens, dataset_labels = [], []
for datapoint in tqdm(tqa_train_dataset):
    tokens, labels = preprocess(datapoint[0], datapoint[1], datapoint[3])
    dataset_tokens += tokens
    dataset_labels += labels

HBox(children=(IntProgress(value=0), HTML(value='')))




In [15]:
print(len(dataset_tokens))

140


In [16]:
print(dataset_tokens[0])
print(dataset_tokens[-1])

['[CLS]', 'To', 'whom', 'did', 'the', 'Virgin', 'Mary', 'allegedly', 'appear', 'in', '1858', 'in', 'Lourdes', 'France', '?', '[SEP]', 'Arch', '##ite', '##ctural', '##ly', ',', 'the', 'school', 'has', 'a', 'Catholic', 'character', '.', 'At', '##op', 'the', 'Main', 'Building', "'", 's', 'gold', 'dome', 'is', 'a', 'golden', 'statue', 'of', 'the', 'Virgin', 'Mary', '.', 'Im', '##mediate', '##ly', 'in', 'front', 'of', 'the', 'Main', 'Building', 'and', 'facing', 'it', ',', 'is', 'a', 'copper', 'statue', 'of', 'Christ', 'with', 'arms', 'up', '##rais', '##ed', 'with', 'the', 'legend', '"', 'Ve', '##nite', 'Ad', 'Me', 'Om', '##nes', '"', '.', 'Next', 'to', 'the', 'Main', 'Building', 'is', 'the', 'Basilica', 'of', 'the', 'Sacred', 'Heart', '.', 'Im', '##mediate', '##ly', 'behind', 'the', 'basilica', 'is', 'the', 'G', '##rott', '##o', ',', 'a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflect', '##ion', '.', 'It', 'is', 'a', 'replica', 'of', 'the', 'gr', '##otto', 'at', 'Lourdes', ',', 'France'

In [0]:
train_data_loader = torch.utils.data.DataLoader(list(zip(dataset_tokens, dataset_labels)), batch_size=16, shuffle=True,collate_fn=collate_fn)

In [18]:
!pip3 install wandb



In [19]:
!wandb login

[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice: 2
[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: You can find your API key in your browser here: https://app.wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 2377ef66e63c2eda02e1d83797d0cc73170988c7
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [20]:
import wandb
wandb.init(project="dul")

W&B Run: https://app.wandb.ai/ram_saw/dul/runs/n53rfftc

In [0]:
BERT = BertModel.from_pretrained('bert-base-multilingual-cased')

In [0]:
import torch.nn.functional as F

class BertForQuestionAnswering(nn.Module):
    
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        self.bert.eval()
        self.qa_outputs = nn.Sequential(nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 2))
        self.loss_fct = CrossEntropyLoss()
        
    

    def forward(self, input_ids=None, token_type_ids=None, start_positions=None, end_positions=None, mask=None):
        output = self.bert(input_ids, attention_mask=mask)

        sequence_output = output[0]

        logits = self.qa_outputs(sequence_output)
        loss = None

        if start_positions is not None and end_positions is not None:
            loss = (self.loss_fct(logits[:, :, 0].masked_fill((1 - mask).bool(), float('-inf')), start_positions) + \
                   self.loss_fct(logits[:, :, 1].masked_fill((1 - mask).bool(), float('-inf')), end_positions)) / 2

        return loss, F.softmax(logits.masked_fill((1 - mask[:, :, None]).bool(), float('-inf')), dim=1)

In [0]:
class BertForQuestionAnsweringConvLSTM(nn.Module):
    
    def __init__(self, bert, conv_hidden_dims=None):
        super().__init__()
        if conv_hidden_dims is None:
            conv_hidden_dims = [8, 16, 4]
        self.bert = bert
        self.bert.eval()
        self.convLstm = lst = ConvLSTM(input_size=(16, 16), input_dim=3, hidden_dim=conv_hidden_dims,
                                       kernel_size=(3, 3),
                                       num_layers=3,
                                       batch_first=True,
                                       bias=True,
                                       return_all_layers=False)
        self.flatten_for_qa = nn.Flatten(start_dim=2)
        self.qa_outputs = nn.Sequential(
            nn.Linear(conv_hidden_dims[-1] * 16 * 16, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 2))
        self.loss_fct = CrossEntropyLoss()
    

    def forward(self, input_ids=None, token_type_ids=None, start_positions=None, end_positions=None, mask=None):
        output = self.bert(input_ids, attention_mask=mask)

        output = output[0]
        output = output.reshape(output.shape[0], output.shape[1], 3, 16, 16)
        output = self.convLstm(output)
        output = output[0][0]
        sequence_output = self.flatten_for_qa(output)
        logits = self.qa_outputs(sequence_output)
        loss = None

        if start_positions is not None and end_positions is not None:
            loss = (self.loss_fct(logits[:, :, 0].masked_fill((1 - mask).bool(), float('-inf')), start_positions) + \
                   self.loss_fct(logits[:, :, 1].masked_fill((1 - mask).bool(), float('-inf')), end_positions)) / 2

        return loss, F.softmax(logits.masked_fill((1 - mask[:, :, None]).bool(), float('-inf')), dim=1)

In [0]:
model = BertForQuestionAnsweringConvLSTM(BERT)
#model.load_state_dict(torch.load('./gdrive/My Drive/bert.pt'))

In [25]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 0.00005, weight_decay=0.000001)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3)
epochs = 3
device = 'cuda'
model.to(device)

for epoch in range(epochs):
    model.train()
    for i, (texts, masks, start_pos, end_pos) in enumerate(train_data_loader):
        optimizer.zero_grad()
        loss, _ = model(texts.to(device),
                        mask=masks.to(device),
                        start_positions=torch.tensor(start_pos).to(device),
                        end_positions=torch.tensor(end_pos).to(device))
        wandb.log({'loss' : float(loss)})
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f'Model saved on {i} iteration!')
            torch.save(model.state_dict(), './gdrive/My Drive/bert.pt')

  del sys.path[0]
  


Model saved on 0 iteration!
Model saved on 0 iteration!
Model saved on 0 iteration!


In [26]:
model.cuda()

BertForQuestionAnsweringConvLSTM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, ele

In [0]:
import torch.nn.functional as F
import re

def getBestProb(probs):
    n = len(probs)
    start, end, bestProb = 0, 0, 0
    for i in range(n):
        for j in range(i, n):
            prob = probs[i, 0] * probs[j, 1]
            if bestProb < prob:
                bestProb, start, end = prob, i, j
              
    return start, end


def concat(tokens):
    tokens = [token.replace('#', '') for token in tokens]
    return ' '.join(list(filter(lambda s: s != tokenizer.unk_token, tokens))).strip()

In [29]:
from tqdm.auto import tqdm

dev_dataset_tokens, dev_dataset_labels = [], []
for datapoint in tqdm(tqa_dev_dataset):
    tokens, labels = preprocess(datapoint[0], datapoint[1], datapoint[3])
    dev_dataset_tokens += tokens
    dev_dataset_labels += labels

HBox(children=(IntProgress(value=0, max=10570), HTML(value='')))




In [0]:
dev_data_loader = torch.utils.data.DataLoader(list(zip(dev_dataset_tokens, dev_dataset_labels)), batch_size=16, shuffle=True,collate_fn=collate_fn)

In [0]:
def test_model(model):
  #TODO load test properly
  model.eval()
  total = 0
  correct = 0
  with torch.no_grad():
    for datapoint in tqdm(tqa_dev_dataset):
        answer = datapoint[3].lower()
        if answer not in datapoint[0].lower():
             continue
        total += 1
        firstInText = datapoint[0].lower().find(answer)
        lastInText = firstInText + len(answer)
        text_tokens = tokenizer.tokenize(datapoint[0][:firstInText].strip())
        start_pos = len(text_tokens)
        text_tokens += tokenizer.tokenize(datapoint[0][firstInText:lastInText].strip())
        end_pos = len(text_tokens) - 1
        text_tokens += tokenizer.tokenize(datapoint[0][lastInText:].strip())
        question_tokens = tokenizer.tokenize(datapoint[1])
        
        all_tokens = [tokenizer.cls_token] + \
                     question_tokens + \
                     [tokenizer.sep_token] + \
                     text_tokens + \
                     [tokenizer.sep_token]

        length = MAX_TEXT_LEN - len(question_tokens) - 3
        if (len(text_tokens) > length):
            part_length = length // 3
            stride = 3 * part_length
            nrow = np.ceil(len(text_tokens) / part_length) - 2
            indexes = part_length * np.arange(nrow)[:, None] + np.arange(stride)
            indexes = indexes.astype(np.int32)

            max_index = indexes.max()
            diff = max_index + 1 - len(text_tokens)
            text_tokens += diff * [tokenizer.pad_token]

            text_tokens = np.array(text_tokens)[indexes].tolist()

            start, end, prob = 0, 0, 0
            for i, ts in enumerate(text_tokens):
                while ts[-1] == tokenizer.pad_token:
                    ts = ts[:-1]

                ts = [tokenizer.cls_token] + \
                     question_tokens + \
                     [tokenizer.sep_token] + \
                     ts + \
                     [tokenizer.sep_token]

                texts, masks = pad_sequence([ts])
                texts = texts.to(device)
                masks = masks.to(device)

                probs = model(texts, mask=masks)[1]

                size = probs.shape[1]
                m = probs[:, :, 0].view(size, 1).matmul(probs[:, :, 1].view(1, size))
                m = m.reshape(size * size)
                pos = torch.argmax(m)
                if m[pos] > prob:
                  prob = m[pos]
                  start_raw, end_raw = (pos / size).view(-1, 1).cuda(), (pos % size).view(-1, 1).cuda()
                  start, end = torch.min(start_raw, end_raw), torch.max(start_raw, end_raw)

            first = (start_pos + 2 + len(question_tokens))  == start
            second = (end_pos + 2 + len(question_tokens)) == end
            correct += int(first and second)
                    
        else:
            texts, masks = pad_sequence([all_tokens])
            texts = texts.to(device)
            masks = masks.to(device)
            probs = model(texts, mask=masks)[1]
            
            size = probs.shape[1]
            m = probs[:, :, 0].view(size, 1).matmul(probs[:, :, 1].view(1, size))
            pos = torch.argmax(m.reshape(size * size))
            start_raw, end_raw = (pos / size).view(-1, 1).cuda(), (pos % size).view(-1, 1).cuda()
            start, end = torch.min(start_raw, end_raw), torch.max(start_raw, end_raw)
            first = (start_pos + 2 + len(question_tokens))  == start
            second = (end_pos + 2 + len(question_tokens)) == end
            correct += int(first and second)
  print(f'Accuracy on dev data is {correct / total}')

In [32]:
device = 'cuda'
test_model(model)

HBox(children=(IntProgress(value=0), HTML(value='')))


Accuracy on dev data is 0.03
