In [1]:
import json
import torch
import torch.nn as nn
from transformers import AutoModelForTokenClassification

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, file_path, tokenizer):
        with open(file_path, 'r') as f:
            self.data = json.load(f)

        self.tokenizer = tokenizer
        self.contexts = []
        self.questions = []
        self.answers = []
        for group in self.data['data']:
            for passage in group['paragraphs']:
                context = passage['context']
                for qa in passage['qas']:
                    question = qa['question']
                    for answer in qa['answers']:
                        self.answers.append(answer)
                        self.questions.append(question)
                        self.contexts.append(context)

    def convert_to_token_start_end_pos(self, encodings, answer):
        start = encodings.char_to_token(0, answer['answer_start'])
        end = encodings.char_to_token(0, answer['answer_end'])

        # if start = None, the answers have been truncated
        if start == None:
            start = self.tokenizer.model_max_length

        # if end == None, the 'char_to_token' function points to the space after the correct token, so add - 1
        if end == None:
            end = encodings.char_to_token(0, answer['answer_end'] - 1)
            # if end is still None, the answers have been truncated
            if end == None:
                end = self.tokenizer.model_max_length

        encodings['start_positions'] = start
        encodings['end_positions'] = end

        return encodings

    def __getitem__(self, index):
        context = self.contexts[index]
        question = self.questions[index]
        answer = self.answers[index]

        real_answer = answer['text']
        start_idx = answer['answer_start']
        # Get the real end index
        end_idx = start_idx + len(real_answer)

        # Deal with the problem of 1 or 2 more characters 
        if context[start_idx:end_idx] == real_answer:
            answer['answer_end'] = end_idx
        # When the real answer is more by one character
        elif context[start_idx-1:end_idx-1] == real_answer:
            answer['answer_start'] = start_idx - 1
            answer['answer_end'] = end_idx - 1  
        # When the real answer is more by two characters  
        elif context[start_idx-2:end_idx-2] == real_answer:
            answer['answer_start'] = start_idx - 2
            answer['answer_end'] = end_idx - 2

        encodings = self.tokenizer(context, question, truncation=True, padding=True)
        encodings = self.convert_to_token_start_end_pos(encodings, answer)

        return {key: torch.tensor(val) for key, val in encodings.items()}

    def __len__(self):
        return len(self.questions)

In [4]:
class CustomCollator:
    def __init__(self, tokenizer):
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, samples):
        batch_size = len(samples)
        assert batch_size == 1, f'Only batch_size=1 supported, got batch_size={batch_size}.'

        sample = samples[0]

        max_seq_length = tokenizer.model_max_length
        padded_length = max_seq_length

        input_shape = (1, padded_length)
        input_ids = torch.full(input_shape,
                               self.pad_token_id,
                               dtype=torch.long)
        attention_mask = torch.zeros(input_shape, dtype=torch.long)

        seq_length = len(sample['input_ids'])
        input_ids[0, :seq_length] = sample['input_ids']
        attention_mask[0, :seq_length] = sample['attention_mask']

        start_positions = sample['start_positions']
        end_positions = sample['end_positions']

        return dict(input_ids=input_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions)

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

In [6]:
# train_data = SquadDataset('../input/squad-20/train-v2.0.json', tokenizer)
train_data = SquadDataset('data/train-v2.0.json', tokenizer)

In [7]:
from torch.utils.data import DataLoader

collate = CustomCollator(tokenizer)
train_loader = DataLoader(train_data,
                      batch_size=1,
                      shuffle=True,
                      collate_fn=collate)

In [8]:
class QASpanDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForTokenClassification.from_pretrained('roberta-base', num_labels=2)

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask)
        return outputs

In [9]:
import torch.nn.functional as F

In [10]:
model = QASpanDetector()
model = model.to(device)

Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForTokenClassification: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able

In [12]:
def iou_loss(reg_pred, reg_target):
    batch_size = reg_pred.shape[0]
    intersection = torch.minimum(reg_pred, reg_target)
    union = torch.maximum(reg_pred, reg_target)
    iou = intersection / union
    reg_loss = -iou.log().sum() / batch_size
    return reg_loss

In [14]:
from torch.optim import AdamW

torch.manual_seed(0)
epochs = 4
print_every = 100
optim = AdamW(model.parameters(), lr=1e-5)
obj_criterion = nn.CrossEntropyLoss(reduction='sum')
for epoch in range(epochs):
    # Set model in train mode
    model.train()
    loss_of_epoch = 0

    print("############Train############")
    for batch_idx, batch in enumerate(train_loader):
        optim.zero_grad()
        sentence_length = batch['input_ids'].size(1)
        batch_size = batch['input_ids'].size(0)

        answer_start = batch['start_positions']
        attention_mask = batch['attention_mask']
        answer_length = batch['end_positions'] - batch['start_positions'] + 1
        answer_length = answer_length.to(device)

        outputs = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        obj_pred = outputs['logits'][:, :, 0]
        reg_pred = outputs['logits'][:, :, 1]
        if int(answer_start) >= 512:
            continue
        else:
            reg_pred = reg_pred[:, int(answer_start)].exp()
            reg_loss = iou_loss(reg_pred, answer_length)
            obj_target = torch.zeros(obj_pred.shape)
            obj_target[:, int(answer_start)] = 1
            obj_target = obj_target.to(device)
            obj_loss = obj_criterion(obj_pred, obj_target) / batch_size
            loss = reg_loss + obj_loss
            
        loss.backward()
        optim.step()
        loss_of_epoch += loss.item()
        if (batch_idx + 1) % print_every == 0:
            print("Batch {:} / {:}".format(batch_idx + 1, len(train_loader)))
            print("Reg Loss:", round(reg_loss.item(), 2))
            print("Obj Loss:", round(obj_loss.item(), 2))
            print("Loss:", round(loss.item(), 2))
            # torch.save(model, "/kaggle/working/yolo_qa.pth")
            torch.save(model, "models/yolo_qa.pth")
    loss_of_epoch /= len(train_loader)
    print("\n-------Epoch ", epoch + 1,
        "-------"
        "\nTraining Loss:", loss_of_epoch,
        "\n-----------------------",
        "\n\n")

############Train############
Batch 100 / 86821
Reg Loss: 0.14
Obj Loss: 4.57
Loss: 4.71
Batch 200 / 86821
Reg Loss: 1.07
Obj Loss: 3.8
Loss: 4.87
Batch 300 / 86821
Reg Loss: 0.36
Obj Loss: 0.59
Loss: 0.95
Batch 400 / 86821
Reg Loss: 0.65
Obj Loss: 2.25
Loss: 2.9
Batch 500 / 86821
Reg Loss: 0.62
Obj Loss: 0.42
Loss: 1.04
Batch 600 / 86821
Reg Loss: 0.24
Obj Loss: 2.68
Loss: 2.92
Batch 700 / 86821
Reg Loss: 1.52
Obj Loss: 0.41
Loss: 1.93
Batch 800 / 86821
Reg Loss: 0.27
Obj Loss: 0.7
Loss: 0.97
Batch 900 / 86821
Reg Loss: 1.14
Obj Loss: 0.29
Loss: 1.44
Batch 1000 / 86821
Reg Loss: 0.12
Obj Loss: 0.89
Loss: 1.01
Batch 1100 / 86821
Reg Loss: 0.21
Obj Loss: 1.3
Loss: 1.51
Batch 1200 / 86821
Reg Loss: 0.53
Obj Loss: 2.1
Loss: 2.63
Batch 1300 / 86821
Reg Loss: 0.08
Obj Loss: 0.21
Loss: 0.28
Batch 1400 / 86821
Reg Loss: 0.75
Obj Loss: 1.17
Loss: 1.92
Batch 1500 / 86821
Reg Loss: 0.02
Obj Loss: 2.84
Loss: 2.86
Batch 1600 / 86821
Reg Loss: 0.63
Obj Loss: 0.12
Loss: 0.75
Batch 1700 / 86821
Reg L

In [15]:
torch.save(model, "models/yolo_qa.pth")