In [13]:
from torch.utils.data import Dataset
import json

class CMRC2018(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
            idx = 0
            for article in json_data['data']:
                title = article['title']
                context = article['paragraphs'][0]['context']
                for question in article['paragraphs'][0]['qas']:
                    q_id = question['id']
                    ques = question['question']
                    text = [ans['text'] for ans in question['answers']]
                    answer_start = [ans['answer_start'] for ans in question['answers']]
                    Data[idx] = {
                        'id': q_id,
                        'title': title,
                        'context': context,
                        'question': ques,
                        'answers': {
                            'text': text,
                            'answer_start': answer_start
                        }
                    }
                    idx += 1
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_data = CMRC2018('data/cmrc2018/cmrc2018_train.json')
valid_data = CMRC2018('data/cmrc2018/cmrc2018_dev.json')
test_data = CMRC2018('data/cmrc2018/cmrc2018_trial.json')

In [14]:
print(f'train set size: {len(train_data)}')
print(f'valid set size: {len(valid_data)}')
print(f'test set size: {len(test_data)}')
print(next(iter(valid_data)))

train set size: 10142
valid set size: 3219
test set size: 1002
{'id': 'DEV_0_QUERY_0', 'title': '战国无双3', 'context': '《战国无双3》（）是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴，分别是以武田信玄等人为主的《关东三国志》，织田信长等人为主的《战国三杰》，石田三成等人为主的《关原的年轻武者》，丰富游戏内的剧情。此部份专门介绍角色，欲知武器情报、奥义字或擅长攻击类型等，请至战国无双系列1.由于乡里大辅先生因故去世，不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图（不含村雨城），后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多，部分地图会有兼用的状况，战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主，以下是相关介绍。（注：前方加☆者为猛将传新增关卡及地图。）合并本篇和猛将传的内容，村雨城模式剔除，战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品', 'question': '《战国无双3》是由哪两个公司合作开发的？', 'answers': {'text': ['光荣和ω-force', '光荣和ω-force', '光荣和ω-force'], 'answer_start': [11, 11, 11]}}


In [17]:
from transformers import AutoTokenizer

checkpoint = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [23]:
#we handle the context into : question + text= [CLS]question[SEP]context[SEP]
#Example:
context = [train_data.data[i]['context'] for i in range(4)]
question = [train_data.data[i]['question'] for i in range(4)]
inputs = tokenizer(
    question,
    context,
    max_length=300,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True
)

print(inputs.keys())
print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])
The 4 examples gave 14 features.
Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3].


In [25]:
answers = [train_data[idx]["answers"] for idx in range(4)]
start_positions = []
end_positions = []
print(answers)

[{'text': ['1963年'], 'answer_start': [30]}, {'text': ['1990年被擢升为天主教河内总教区宗座署理'], 'answer_start': [41]}, {'text': ['范廷颂于1919年6月15日在越南宁平省天主教发艳教区出生'], 'answer_start': [97]}, {'text': ['1994年3月23日，范廷颂被教宗若望保禄二世擢升为天主教河内总教区总主教并兼天主教谅山教区宗座署理'], 'answer_start': [548]}]


In [33]:

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

print(start_positions)
print(end_positions)

[47, 0, 0, 0, 53, 0, 0, 100, 0, 0, 0, 0, 61, 0, 47, 0, 0, 0, 53, 0, 0, 100, 0, 0, 0, 0, 61, 0, 47, 0, 0, 0, 53, 0, 0, 100, 0, 0, 0, 0, 61, 0, 47, 0, 0, 0, 53, 0, 0, 100, 0, 0, 0, 0, 61, 0, 47, 0, 0, 0, 53, 0, 0, 100, 0, 0, 0, 0, 61, 0, 47, 0, 0, 0, 53, 0, 0, 100, 0, 0, 0, 0, 61, 0]
[48, 0, 0, 0, 70, 0, 0, 124, 0, 0, 0, 0, 106, 0, 48, 0, 0, 0, 70, 0, 0, 124, 0, 0, 0, 0, 106, 0, 48, 0, 0, 0, 70, 0, 0, 124, 0, 0, 0, 0, 106, 0, 48, 0, 0, 0, 70, 0, 0, 124, 0, 0, 0, 0, 106, 0, 48, 0, 0, 0, 70, 0, 0, 124, 0, 0, 0, 0, 106, 0, 48, 0, 0, 0, 70, 0, 0, 124, 0, 0, 0, 0, 106, 0]


In [34]:
idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: 1963年, labels give: 1963 年


In [36]:
import torch
from torch.utils.data import DataLoader

max_length = 384
stride = 128

def train_collote_fn(batch_samples):
    batch_question, batch_context, batch_answers = [], [], []
    for sample in batch_samples:
        batch_question.append(sample['question'])
        batch_context.append(sample['context'])
        batch_answers.append(sample['answers'])
    batch_data = tokenizer(
        batch_question,
        batch_context,
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
        return_tensors="pt"
    )

    offset_mapping = batch_data.pop('offset_mapping')
    sample_mapping = batch_data.pop('overflow_to_sample_mapping')

    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_mapping[i]
        answer = batch_answers[sample_idx]
        start_char = answer['answer_start'][0]
        end_char = answer['answer_start'][0] + len(answer['text'][0])
        sequence_ids = batch_data.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
    return batch_data, torch.tensor(start_positions), torch.tensor(end_positions)

train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=train_collote_fn)

In [37]:
import torch

batch_X, batch_Start, batch_End = next(iter(train_dataloader))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print('batch_Start shape:', batch_Start.shape)
print('batch_End shape:', batch_End.shape)
print(batch_X)
print(batch_Start)
print(batch_End)

print('train set size: ', )
print(len(train_data), '->', sum([batch_data['input_ids'].shape[0] for batch_data, _, _ in train_dataloader]))

batch_X shape: {'input_ids': torch.Size([9, 384]), 'token_type_ids': torch.Size([9, 384]), 'attention_mask': torch.Size([9, 384])}
batch_Start shape: torch.Size([9])
batch_End shape: torch.Size([9])
{'input_ids': tensor([[ 101, 6205, 5401,  ..., 1075,  511,  102],
        [ 101, 6205, 5401,  ...,  671,  763,  102],
        [ 101, 6205, 5401,  ...,    0,    0,    0],
        ...,
        [ 101, 5276, 5432,  ...,  872, 2199,  102],
        [ 101, 5276, 5432,  ..., 4638, 1355,  102],
        [ 101, 5276, 5432,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ...

In [64]:
# for test we just care about the text of answer ; but not sart/end positiom, so handle the data loader:
def test_collote_fn(batch_samples):
    batch_id, batch_question, batch_context = [], [], []
    for sample in batch_samples:
        batch_id.append(sample['id'])
        batch_question.append(sample['question'])
        batch_context.append(sample['context'])
    batch_data = tokenizer(
        batch_question,
        batch_context,
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt"
    )

    offset_mapping = batch_data.pop('offset_mapping').numpy().tolist()
    sample_mapping = batch_data.pop('overflow_to_sample_mapping')
    example_ids = []

    for i in range(len(batch_data['input_ids'])):
        sample_idx = sample_mapping[i]
        example_ids.append(batch_id[sample_idx])

        sequence_ids = batch_data.sequence_ids(i)
        offset = offset_mapping[i]
        offset_mapping[i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]
    return batch_data, offset_mapping, example_ids

valid_dataloader = DataLoader(valid_data, batch_size=8, shuffle=False, collate_fn=test_collote_fn)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=False, collate_fn=test_collote_fn)


In [40]:
batch_X, offset_mapping, example_ids = next(iter(valid_dataloader))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print(example_ids)

print('valid set size: ')
print(len(valid_data), '->', sum([batch_data['input_ids'].shape[0] for batch_data, _, _ in valid_dataloader]))

batch_X shape: {'input_ids': torch.Size([16, 384]), 'token_type_ids': torch.Size([16, 384]), 'attention_mask': torch.Size([16, 384])}
['DEV_0_QUERY_0', 'DEV_0_QUERY_0', 'DEV_0_QUERY_1', 'DEV_0_QUERY_1', 'DEV_0_QUERY_2', 'DEV_0_QUERY_2', 'DEV_1_QUERY_0', 'DEV_1_QUERY_0', 'DEV_1_QUERY_1', 'DEV_1_QUERY_1', 'DEV_1_QUERY_2', 'DEV_1_QUERY_2', 'DEV_1_QUERY_3', 'DEV_1_QUERY_3', 'DEV_2_QUERY_0', 'DEV_2_QUERY_0']
valid set size: 
3219 -> 6254


In [41]:
# build model:
# BERT + fc : where the fc layer output the position of start position and end position
# 2 output, start: each token as start position's score
from torch import nn
from transformers import AutoConfig
from transformers import BertPreTrainedModel, BertModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

class BertForExtractiveQA(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, x):
        bert_output = self.bert(**x)
        sequence_output = bert_output.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        return start_logits, end_logits

config = AutoConfig.from_pretrained(checkpoint)
model = BertForExtractiveQA.from_pretrained(checkpoint, config=config).to(device)
print(model)

Some weights of BertForExtractiveQA were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using cuda device
BertForExtractiveQA(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,)

In [44]:

train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=train_collote_fn)

batch_X, _, _ = next(iter(train_dataloader))
start_outputs, end_outputs = model(batch_X.to(device))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print('start_outputs shape', start_outputs.shape)
print('end_outputs shape', end_outputs.shape)

batch_X shape: {'input_ids': torch.Size([7, 384]), 'token_type_ids': torch.Size([7, 384]), 'attention_mask': torch.Size([7, 384])}
start_outputs shape torch.Size([7, 384])
end_outputs shape torch.Size([7, 384])


  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [45]:
from tqdm.auto import tqdm

def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)

    model.train()
    for batch, (X, start_pos, end_pos) in enumerate(dataloader, start=1):
        X, start_pos, end_pos = X.to(device), start_pos.to(device), end_pos.to(device)
        start_pred, end_pred = model(X)
        start_loss = loss_fn(start_pred, start_pos)
        end_loss = loss_fn(end_pred, end_pos)
        loss = (start_loss + end_loss) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

In [46]:
valid_data = CMRC2018('data/cmrc2018/cmrc2018_dev.json')
small_eval_set = [valid_data[idx] for idx in range(12)]

trained_checkpoint = "uer/roberta-base-chinese-extractive-qa"
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = DataLoader(small_eval_set, batch_size=4, shuffle=False, collate_fn=test_collote_fn)

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoModelForQuestionAnswering
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(device)

tokenizer_config.json:   0%|          | 0.00/216 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/452 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/407M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/407M [00:00<?, ?B/s]

In [47]:
start_logits = []
end_logits = []

trained_model.eval()
for batch_data, _, _ in eval_set:
    batch_data = batch_data.to(device)
    with torch.no_grad():
        outputs = trained_model(**batch_data)
    start_logits.append(outputs.start_logits.cpu().numpy())
    end_logits.append(outputs.end_logits.cpu().numpy())

import numpy as np
start_logits = np.concatenate(start_logits)
end_logits = np.concatenate(end_logits)

In [48]:
all_example_ids = []
all_offset_mapping = []
for _, offset_mapping, example_ids in eval_set:
    all_example_ids += example_ids
    all_offset_mapping += offset_mapping

import collections
example_to_features = collections.defaultdict(list)
for idx, feature_id in enumerate(all_example_ids):
    example_to_features[feature_id].append(idx)

print(example_to_features)

n_best = 20
max_answer_length = 30
theoretical_answers = [
    {"id": ex["id"], "answers": ex["answers"]} for ex in small_eval_set
]
predicted_answers = []

for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = all_offset_mapping[feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                if (end_index < start_index or end_index - start_index + 1 > max_answer_length):
                    continue
                answers.append(
                    {
                        "start": offsets[start_index][0],
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )
    if len(answers) > 0:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append({
            "id": example_id,
            "prediction_text": best_answer["text"],
            "answer_start": best_answer["start"]
        })
    else:
        predicted_answers.append({
            "id": example_id,
            "prediction_text": "",
            "answer_start": 0
        })

defaultdict(<class 'list'>, {'DEV_0_QUERY_0': [0, 1], 'DEV_0_QUERY_1': [2, 3], 'DEV_0_QUERY_2': [4, 5], 'DEV_1_QUERY_0': [6, 7], 'DEV_1_QUERY_1': [8, 9], 'DEV_1_QUERY_2': [10, 11], 'DEV_1_QUERY_3': [12, 13], 'DEV_2_QUERY_0': [14, 15], 'DEV_2_QUERY_1': [16, 17], 'DEV_2_QUERY_2': [18, 19], 'DEV_3_QUERY_0': [20], 'DEV_3_QUERY_1': [21]})


In [49]:
for pred, label in zip(predicted_answers, theoretical_answers):
    print(pred['id'])
    print('pred:', pred['prediction_text'])
    print('label:', label['answers']['text'])

DEV_0_QUERY_0
pred: 光荣和ω-force
label: ['光荣和ω-force', '光荣和ω-force', '光荣和ω-force']
DEV_0_QUERY_1
pred: 任天堂游戏谜之村雨城
label: ['村雨城', '村雨城', '任天堂游戏谜之村雨城']
DEV_0_QUERY_2
pred: 「战史演武」&「争霸演武」
label: ['「战史演武」&「争霸演武」', '「战史演武」&「争霸演武」', '「战史演武」&「争霸演武」']
DEV_1_QUERY_0
pred: 锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法
label: ['大陆传统器乐及戏曲里面常用的打击乐记谱方法', '大陆传统器乐及戏曲里面常用的打击乐记谱方法', '大陆传统器乐及戏曲里面常用的打击乐记谱方法']
DEV_1_QUERY_1
pred: 「锣鼓点」
label: ['锣鼓点', '锣鼓点', '锣鼓点']
DEV_1_QUERY_2
pred: 依照角色行当的身份、性格、情绪以及环境，配合相应的锣鼓点。
label: ['依照角色行当的身份、性格、情绪以及环境，配合相应的锣鼓点。', '依照角色行当的身份、性格、情绪以及环境，配合相应的锣鼓点。', '依照角色行当的身份、性格、情绪以及环境，配合相应的锣鼓点']
DEV_1_QUERY_3
pred: 戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型
label: ['鼓、锣、钹和板', '鼓、锣、钹和板', '鼓、锣、钹和板']
DEV_2_QUERY_0
pred: 全长364.6公里
label: ['364.6公里', '364.6公里', '364.6公里']
DEV_2_QUERY_1
pred: 三茂铁路股份有限公司
label: ['三茂铁路股份有限公司', '三茂铁路股份有限公司', '三茂铁路股份有限公司']
DEV_2_QUERY_2
pred: 1903年
label: ['1903年', '1903年', '1903年']
DEV_3_QUERY_0
pred: 山东省北部环渤海地区
label: ['山东省北部环渤海地区', '山东省北部环渤海地区', '山东省北部环渤海地区']
DEV_3_QUERY_1
pred: 11.42亿元
label: ['

In [55]:
from cmrc2018_evaluate import evaluate
result = evaluate(predicted_answers, theoretical_answers)
print(f"F1: {result['f1']:>0.2f} EM: {result['em']:>0.2f} AVG: {result['avg']:>0.2f}\n")

F1: 92.63 EM: 75.00 AVG: 83.81



In [58]:
import collections
from cmrc2018_evaluate import evaluate

n_best = 20
max_answer_length = 30

def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)

    model.train()
    for batch, (X, start_pos, end_pos) in enumerate(dataloader, start=1):
        X, start_pos, end_pos = X.to(device), start_pos.to(device), end_pos.to(device)
        start_pred, end_pred = model(X)
        start_loss = loss_fn(start_pred, start_pos)
        end_loss = loss_fn(end_pred, end_pos)
        loss = (start_loss + end_loss) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

In [62]:
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from transformers import AutoTokenizer, AutoConfig
from transformers import BertPreTrainedModel, BertModel
from transformers import AdamW, get_scheduler
import json
import collections
import sys
from tqdm.auto import tqdm
sys.path.append('./')
from cmrc2018_evaluate import evaluate
def test_loop(dataloader, dataset, model):
    all_example_ids = []
    all_offset_mapping = []
    for _, offset_mapping, example_ids in dataloader:
        all_example_ids += example_ids
        all_offset_mapping += offset_mapping
    example_to_features = collections.defaultdict(list)
    for idx, feature_id in enumerate(all_example_ids):
        example_to_features[feature_id].append(idx)

    start_logits = []
    end_logits = []
    model.eval()
    for batch_data, _, _ in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            pred_start_logits, pred_end_logit = model(batch_data)
        start_logits.append(pred_start_logits.cpu().numpy())
        end_logits.append(pred_end_logit.cpu().numpy())
    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)

    theoretical_answers = [
        {"id": dataset[s_idx]["id"], "answers": dataset[s_idx]["answers"]} for s_idx in range(len(dataset))
    ]
    predicted_answers = []
    for s_idx in tqdm(range(len(dataset))):
        example_id = dataset[s_idx]["id"]
        context = dataset[s_idx]["context"]
        answers = []
        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = all_offset_mapping[feature_index]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    if (end_index < start_index or end_index-start_index+1 > max_answer_length):
                        continue
                    answers.append({
                        "start": offsets[start_index][0],
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    })
        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append({
                "id": example_id,
                "prediction_text": best_answer["text"],
                "answer_start": best_answer["start"]
            })
        else:
            predicted_answers.append({
                "id": example_id,
                "prediction_text": "",
                "answer_start": 0
            })
    result = evaluate(predicted_answers, theoretical_answers)
    print(f"F1: {result['f1']:>0.2f} EM: {result['em']:>0.2f} AVG: {result['avg']:>0.2f}\n")
    return result

loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_avg_score = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)
    valid_scores = test_loop(valid_dataloader, valid_data, model)
    avg_score = valid_scores['avg']
    if avg_score > best_avg_score:
        best_avg_score = avg_score
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_avg_{avg_score:0.4f}_model_weights.bin')
print("Done!")

Epoch 1/3
-------------------------------


  0%|          | 0/2536 [00:00<?, ?it/s]

  0%|          | 0/403 [00:00<?, ?it/s]

  0%|          | 0/3219 [00:00<?, ?it/s]

F1: 83.38 EM: 62.50 AVG: 72.94

saving new weights...

Epoch 2/3
-------------------------------


  0%|          | 0/2536 [00:00<?, ?it/s]

  0%|          | 0/403 [00:00<?, ?it/s]

  0%|          | 0/3219 [00:00<?, ?it/s]

F1: 84.88 EM: 65.58 AVG: 75.23

saving new weights...

Epoch 3/3
-------------------------------


  0%|          | 0/2536 [00:00<?, ?it/s]

  0%|          | 0/403 [00:00<?, ?it/s]

  0%|          | 0/3219 [00:00<?, ?it/s]

F1: 84.63 EM: 64.27 AVG: 74.45

Done!


In [65]:
model.load_state_dict(torch.load('epoch_2_valid_avg_75.2314_model_weights.bin'))

model.eval()
with torch.no_grad():
    print('evaluating on test set...')
    all_example_ids = []
    all_offset_mapping = []
    for _, offset_mapping, example_ids in test_dataloader:
        all_example_ids += example_ids
        all_offset_mapping += offset_mapping
    example_to_features = collections.defaultdict(list)
    for idx, feature_id in enumerate(all_example_ids):
        example_to_features[feature_id].append(idx)

    start_logits = []
    end_logits = []
    model.eval()
    for batch_data, _, _ in tqdm(test_dataloader):
        batch_data = batch_data.to(device)
        pred_start_logits, pred_end_logit = model(batch_data)
        start_logits.append(pred_start_logits.cpu().numpy())
        end_logits.append(pred_end_logit.cpu().numpy())
    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)

    theoretical_answers = [
        {"id": test_data[s_idx]["id"], "answers": test_data[s_idx]["answers"]} for s_idx in range(len(test_dataloader))
    ]
    predicted_answers = []
    save_resluts = []
    for s_idx in tqdm(range(len(test_data))):
        example_id = test_data[s_idx]["id"]
        context = test_data[s_idx]["context"]
        title = test_data[s_idx]["title"]
        question = test_data[s_idx]["question"]
        labels = test_data[s_idx]["answers"]
        answers = []
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = all_offset_mapping[feature_index]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    if (end_index < start_index or end_index-start_index+1 > max_answer_length):
                        continue
                    answers.append({
                        "start": offsets[start_index][0],
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    })
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append({
                "id": example_id,
                "prediction_text": best_answer["text"],
                "answer_start": best_answer["start"]
            })
            save_resluts.append({
                "id": example_id,
                "title": title,
                "context": context,
                "question": question,
                "answers": labels,
                "prediction_text": best_answer["text"],
                "answer_start": best_answer["start"]
            })
        else:
            predicted_answers.append({
                "id": example_id,
                "prediction_text": "",
                "answer_start": 0
            })
            save_resluts.append({
                "id": example_id,
                "title": title,
                "context": context,
                "question": question,
                "answers": labels,
                "prediction_text": "",
                "answer_start": 0
            })
    eval_result = evaluate(predicted_answers, theoretical_answers)
    print(f"F1: {eval_result['f1']:>0.2f} EM: {eval_result['em']:>0.2f} AVG: {eval_result['avg']:>0.2f}\n")
    print('saving predicted results...')
    with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
        for example_result in save_resluts:
            f.write(json.dumps(example_result, ensure_ascii=False) + '\n')

evaluating on test set...


  0%|          | 0/126 [00:00<?, ?it/s]

  0%|          | 0/1002 [00:00<?, ?it/s]

F1: 73.01 EM: 32.54 AVG: 52.77

saving predicted results...
