# Hugging Face - Question Answering in Japanese

This source code builds a fine-tuned model that performs extractive question answering (extractive QA) in Japanese.<br>
For more background and details, see [here](https://tsmatz.wordpress.com/2022/12/12/huggingface-japanese-question-answering/).

For better training performance, I use [rinna/japanese-roberta-base](https://huggingface.co/rinna/japanese-roberta-base) as the pre-trained transformer, which is well trained and optimized for Japanese corpora.<br>
You can also use [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) for other languages.

*back to [index](https://github.com/tsmatz/huggingface-finetune-japanese/)*

## Install required packages

In order to install core components, see [Readme](https://github.com/tsmatz/huggingface-finetune-japanese/).<br>
Install additional packages for running this notebook as follows.

In [None]:
!pip install numpy

## Check device

Check whether GPU is available.

In [1]:
import torch

if torch.cuda.is_available():
    print("GPU is enabled.")
    print("device count: {}, current device: {}".format(torch.cuda.device_count(), torch.cuda.current_device()))
else:
    print("GPU is not enabled.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

GPU is enabled.
device count: 1, current device: 0


## Prepare data

In this example, we use [JaQuAD](https://huggingface.co/datasets/SkelterLabsInc/JaQuAD) (Japanese Question Answering Dataset) from Hugging Face, an annotated extractive question answering dataset similar to the well-known SQuAD.<br>
This dataset has over 30,000 samples for training.

In [2]:
from datasets import load_dataset

ds = load_dataset("SkelterLabsInc/JaQuAD")
ds



DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'question_type', 'answers'],
        num_rows: 31748
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'question_type', 'answers'],
        num_rows: 3939
    })
})

In extractive QA, both ```context``` and ```question``` are given as inputs, and the model predicts the span (start position and end position) of the answer in the ```context``` text.

In [3]:
ds["train"][0]

{'id': 'tr-000-00-000',
 'title': '手塚治虫',
 'context': '手塚治虫(てづかおさむ、本名:手塚治(読み同じ)、1928年(昭和3年)11月3日-1989年(平成元年)2月9日)は、日本の漫画家、アニメーター、アニメ監督である。\n戦後日本においてストーリー漫画の第一人者として、漫画表現の開拓者的な存在として活躍した。\n\n兵庫県宝塚市出身(出生は大阪府豊能郡豊中町、現在の豊中市)同市名誉市民である。\n大阪帝国大学附属医学専門部を卒業。\n医師免許取得のち医学博士(奈良県立医科大学・1961年)。',
 'question': '戦後日本のストーリー漫画の第一人者で、医学博士の一面もある漫画家は誰?',
 'question_type': 'Multiple sentence reasoning',
 'answers': {'text': ['手塚治虫'], 'answer_start': [0], 'answer_type': ['Person']}}
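
To make the span idea concrete, here is a minimal sketch that recovers the answer by slicing the context with ```answer_start``` and the answer length. The ```sample``` dictionary is a shortened, hand-written stand-in for the record above, and ```answer_span``` is a hypothetical helper, not part of the dataset API.

```python
def answer_span(context, answers):
    """Return (start, end) character offsets of the first answer, end exclusive."""
    start = answers["answer_start"][0]
    end = start + len(answers["text"][0])
    # Sanity check: the annotated offset must really point at the answer text.
    if context[start:end] != answers["text"][0]:
        raise ValueError("answer_start does not match the answer text")
    return start, end


# Shortened stand-in for the JaQuAD record (assumption: same field layout).
sample = {
    "context": "手塚治虫は、日本の漫画家、アニメーター、アニメ監督である。",
    "answers": {"text": ["手塚治虫"], "answer_start": [0]},
}

start, end = answer_span(sample["context"], sample["answers"])
print(start, end, sample["context"][start:end])
```

These character offsets are what later get converted into token positions for training.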

To generate inputs for fine-tuning, I now tokenize each text and convert it into token ids.

First, load the tokenizer of the pre-trained ```rinna/japanese-roberta-base``` model.

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-roberta-base")

Before tokenizing the dataset, let's see how the input is formed.<br>
As you can see below, the input takes the format :

```question text </s> context text </s> pad```

In this task, the question must be complete, but the context can be incomplete. (If the answer happens not to be included in the truncated context, the answer will be empty.)<br>
By setting ```max_length``` and ```truncation="only_second"``` as follows, the first sequence (i.e., the question) won't be truncated, but the second sequence (i.e., the context) is truncated to the maximum number of tokens.

In [5]:
features = tokenizer(
    ds["train"][0]["question"],
    ds["train"][0]["context"],
    max_length = 384,
    truncation="only_second",
    padding = "max_length",
)
print("".join(tokenizer.batch_decode(features["input_ids"])))

戦後日本のストーリー漫画の第一人者で、医学博士の一面もある漫画家は誰?</s>手塚治虫(てづかおさむ、本名:手塚治(読み同じ)、1928年(昭和3年)11月3日-1989年(平成元年)2月9日)は、日本の漫画家、アニメーター、アニメ監督である。戦後日本においてストーリー漫画の第一人者として、漫画表現の開拓者的な存在として活躍した。兵庫県宝塚市出身(出生は大阪府豊能郡豊中町、現在の豊中市)同市名誉市民である。大阪帝国大学附属医学専門部を卒業。医師免許取得のち医学博士(奈良県立医科大学・1961年)。</s>[PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]

Now let's tokenize and convert dataset for fine-tuning.

The question answering model in Hugging Face expects the answer's ```start_positions``` and ```end_positions```, which indicate positions in the input sequence. (These positions should be indices into ```input_ids```.)

When the context is longer than the maximum sequence length (here, 384), it is truncated, and the overflowing tokens are returned as the next sequence by setting ```return_overflowing_tokens=True```.<br>
However, the answer tokens might then be split across multiple sequences. To mitigate this, consecutive windows can overlap, and the size of the overlap is controlled by the ```stride``` property. For instance, if the last n tokens do not fit and overflow, then m + n tokens go into the next sequence when ```stride=m```. These m tokens are the overlapping tokens between the two windows.<br>
The answer will then lie in the first sequence, in the second sequence, or in both.
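
The overlap can be illustrated without a tokenizer. The following is only a sketch of the windowing idea on a plain list of token ids; ```sliding_windows``` is a hypothetical helper, and the real tokenizer additionally prepends the question and adds special tokens and padding to each window.

```python
def sliding_windows(token_ids, window_size, stride):
    """Split token_ids into windows of window_size that overlap by stride tokens."""
    if not 0 <= stride < window_size:
        raise ValueError("stride must satisfy 0 <= stride < window_size")
    step = window_size - stride  # how far each new window advances
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + window_size])
        if start + window_size >= len(token_ids):
            break  # this window already reaches the end of the context
        start += step
    return windows


tokens = list(range(10))  # toy context of 10 tokens
windows = sliding_windows(tokens, window_size=6, stride=2)
for w in windows:
    print(w)
```

Here n = 4 tokens overflow from the first window, so with ```stride=2``` the second window holds m + n = 6 tokens, and tokens 4 and 5 appear in both windows.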

In this example, I also remove the sequences in which the answer doesn't exist.

> Note : To get the token index for each character, you can also use the ```char_to_token()``` method on the tokenizer's output, instead of using the ```return_offsets_mapping``` argument.

In [6]:
# Modified the following source code for supporting the case of overflowing
# https://huggingface.co/docs/transformers/tasks/question_answering

def tokenize_sample_data(data):
    # tokenize
    tokenized_feature = tokenizer(
        data["question"],
        data["context"],
        max_length = 384,
        return_overflowing_tokens=True,
        stride=128,
        truncation="only_second",
        padding = "max_length",
        return_offsets_mapping=True,
    )

    # When the context overflows, multiple rows are returned for a single example.
    # overflow_to_sample_mapping gives, for each returned row, the index of its original sample.
    sample_mapping = tokenized_feature.pop("overflow_to_sample_mapping")
    # Get the array of [start_char, end_char + 1] for each token.
    # The shape is [returned_row_size, max_length]
    offset_mapping = tokenized_feature.pop("offset_mapping")

    start_positions = []
    end_positions = []
    for i, offset in enumerate(offset_mapping):
        sample_index = sample_mapping[i]
        answers = data["answers"][sample_index]
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0]) - 1
        # The format of sequence_ids is [None, 0, ..., 0, None, None, 1, ..., 1, None, None, ...]
        # in which the question's tokens are 0 and the context's tokens are 1
        sequence_ids = tokenized_feature.sequence_ids(i)
        # find the start and end index of context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1
        # Set start positions and end positions in inputs_ids
        # Note: The second element in offset is end_char + 1
        #if offset[context_start][0] > end_char or offset[context_end][1] <= start_char:
        if not (offset[context_start][0] <= start_char and end_char < offset[context_end][1]):
            # The case in which the answer is not inside the context
            ## Note : Some tokenizers (such as the tokenizer in the rinna model) don't place CLS
            ## at the first token in the sequence, so I set -1 as the positions.
            ## (Later I'll remove rows with start_positions=-1.)
            start_positions.append(-1)
            end_positions.append(-1)
            #start_positions.append(0)
            #end_positions.append(0)
        else:
            # The case that answer is found in the context

            # Set start position
            idx = context_start
            while offset[idx][0] < start_char:
                idx += 1
            if offset[idx][0] == start_char:
                start_positions.append(idx)
            else:
                start_positions.append(idx - 1)

            # Set end position
            idx = context_end
            while offset[idx][1] > end_char + 1:
                idx -= 1
            if offset[idx][1] == end_char + 1:
                end_positions.append(idx)
            else:
                end_positions.append(idx + 1)

    # build result
    tokenized_feature["start_positions"] = start_positions
    tokenized_feature["end_positions"] = end_positions   
    return tokenized_feature
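
The core of the function above is the conversion from a character span to a token span through the offset mapping. The sketch below isolates that step on hand-written offsets; both the ```offsets``` values and the ```char_span_to_token_span``` helper are hypothetical and only illustrate the idea, not the exact logic used above.

```python
def char_span_to_token_span(offsets, start_char, end_char):
    """Map an inclusive character span to an inclusive token span.

    offsets: list of (start, end) character offsets per token, end exclusive.
    Returns None when either end of the span falls outside every token.
    """
    start_token = None
    end_token = None
    for i, (tok_start, tok_end) in enumerate(offsets):
        if tok_start <= start_char < tok_end:
            start_token = i
        if tok_start <= end_char < tok_end:
            end_token = i
    if start_token is None or end_token is None:
        return None
    return start_token, end_token


# Hypothetical tokenization of "手塚治虫は漫画家" into 手塚 / 治虫 / は / 漫画家
offsets = [(0, 2), (2, 4), (4, 5), (5, 8)]

print(char_span_to_token_span(offsets, 5, 7))   # span of "漫画家"
print(char_span_to_token_span(offsets, 2, 4))   # span of "治虫は"
print(char_span_to_token_span(offsets, 8, 9))   # outside the text
```

When an answer boundary falls in the middle of a token, the whole token is included, which is why a predicted answer can contain a few extra characters.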

In [7]:
# Run conversion
tokenized_ds = ds.map(
    tokenize_sample_data,
    remove_columns=["id", "title", "context", "question", "question_type", "answers"],
    batched=True,
    batch_size=128)
# Remove rows with start_positions=-1 (see above)
tokenized_ds = tokenized_ds.filter(lambda x: x["start_positions"] != -1)

print("********** input_ids **********")
print(tokenized_ds["train"]["input_ids"][0])
print("********** start_positions **********")
print(tokenized_ds["train"]["start_positions"][0])
print("********** end_positions **********")
print(tokenized_ds["train"]["end_positions"][0])

********** input_ids **********
[5242, 216, 1879, 1302, 11091, 63, 147, 19, 7, 27730, 10, 13134, 553, 6768, 11, 5943, 3017, 2, 9, 28889, 15, 58, 16154, 13726, 561, 7, 3875, 76, 25649, 687, 15, 3009, 851, 14, 15481, 16, 15, 112, 31, 16, 3824, 22, 31, 33, 61, 4485, 16, 15, 16980, 14, 25, 22, 52, 33, 14, 11, 7, 14122, 7, 21244, 7, 1047, 464, 27, 8, 5242, 15045, 1879, 1302, 11091, 63, 5905, 7, 1302, 1030, 10, 19365, 132, 1763, 12963, 8, 9, 3656, 6227, 3202, 15, 8967, 11, 2882, 1311, 1624, 139, 1311, 82, 87, 7, 373, 1311, 82, 69, 14, 11661, 2957, 1296, 27, 8, 8773, 1744, 22585, 2966, 1673, 126, 8236, 8, 9, 2800, 4459, 3113, 2089, 27730, 15, 7907, 356, 14689, 13, 9117, 16, 14, 8, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

## Fine-tune

In this example, we use the ```AutoModelForQuestionAnswering``` class with the pre-trained RobertaModel.

Like the [token classification example](./01-named-entity.ipynb), this model consists of the pre-trained RobertaModel and a classification head.<br>
However, the final output (whose shape is ```[batch_size, sequence_length, 2]```) is split into 2 parts, each of which has the shape ```[batch_size, sequence_length]```. These two tensors are used as the start logits and the end logits, and the classification loss (cross-entropy over token positions) between these logits and the true labels (```start_positions``` and ```end_positions```, respectively) is computed for optimization. (See [here](https://tsmatz.wordpress.com/2022/12/12/huggingface-japanese-question-answering/) for the model architecture.)
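
As a rough illustration of this loss for a single sequence, the sketch below splits per-token ```[start_logit, end_logit]``` pairs and averages the two cross-entropy terms. It is written in plain Python for readability; ```cross_entropy``` and ```qa_loss``` are hypothetical helpers, and the averaging of the two terms is my reading of how the Hugging Face QA head combines them.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one softmax distribution against a target index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def qa_loss(token_logits, start_position, end_position):
    """token_logits has shape [sequence_length, 2]: (start_logit, end_logit) per token."""
    start_logits = [pair[0] for pair in token_logits]
    end_logits = [pair[1] for pair in token_logits]
    start_loss = cross_entropy(start_logits, start_position)
    end_loss = cross_entropy(end_logits, end_position)
    return (start_loss + end_loss) / 2


# Toy example: 4 tokens, true answer spans tokens 1..2
uninformed = [[0.0, 0.0]] * 4
confident = [[0.0, 0.0], [5.0, 0.0], [0.0, 5.0], [0.0, 0.0]]
print(qa_loss(uninformed, 1, 2))
print(qa_loss(confident, 1, 2))
```

With uniform logits the loss equals log(4), and it drops as the logits concentrate on the true start and end tokens.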

> Note : The following ```num_labels``` and ```hidden_size``` are the default values in ```AutoModelForQuestionAnswering```, so you can skip these config settings.

In [8]:
from transformers import AutoConfig, AutoModelForQuestionAnswering

# see https://huggingface.co/docs/transformers/main_classes/configuration
config = AutoConfig.from_pretrained(
    "rinna/japanese-roberta-base",
    num_labels=2,
    hidden_size=768,
)
model = (AutoModelForQuestionAnswering
         .from_pretrained("rinna/japanese-roberta-base", config=config)
         .to(device))

Some weights of the model checkpoint at rinna/japanese-roberta-base were not used when initializing RobertaForQuestionAnswering: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.bias', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForQuestionAnswering were not initialized from the model checkpoint at rinna/japanese-roberta-base and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRA

We prepare a data collator, which preprocesses the data in each batch.<br>
Unlike other examples, here we use the default data collator, which does no extra work (such as filling -100 into padded tokens), because we don't need to skip loss or evaluation computation on padded tokens.

In [9]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

Next, we prepare for fine-tuning.<br>
In this example, we use the Trainer class in Hugging Face Transformers, with which you can run training without manually writing a training loop.

First, we prepare the trainer's arguments.<br>
With the following settings, the checkpoint files are saved every 500 steps (```save_steps```) in the folder named ```rinna-roberta-qa-ja```, and the training loss is logged every 50 steps (```logging_steps```).

> Note : In general, the checkpoints saved during training can become very large.<br>
> Set the ```save_total_limit``` property (which limits the total number of checkpoints by deleting the older ones) to save disk space, or expand the disks in your Azure VM. (See [here](https://learn.microsoft.com/en-us/azure/virtual-machines/linux/expand-disks) to expand disks in Azure.)

In [10]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir = "rinna-roberta-qa-ja",
    log_level = "error",
    num_train_epochs = 3,
    learning_rate = 7e-5,
    lr_scheduler_type = "linear",
    warmup_steps = 100,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 1,
    gradient_accumulation_steps = 16,
    evaluation_strategy = "steps",
    eval_steps = 150,
    save_steps = 500,
    logging_steps = 50,
    push_to_hub = False
)
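
To see what these numbers imply, here is a small sketch of the arithmetic behind the effective batch size and the linear schedule with warmup. It assumes a single GPU (as in the device check above), and ```total_steps=3045``` is taken from the ```global_step``` that this notebook's run reports, so it will differ in other setups. Both helper functions are hypothetical and only mirror my understanding of the schedule.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples that contribute to one optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def linear_schedule_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Values copied from the training arguments above
batch = effective_batch_size(per_device_batch_size=2, gradient_accumulation_steps=16)
print("effective batch size:", batch)

for step in (0, 50, 100, 1500, 3045):
    lr = linear_schedule_lr(step, base_lr=7e-5, warmup_steps=100, total_steps=3045)
    print(step, lr)
```

Gradient accumulation is what lets a per-device batch of 2 behave like a batch of 32 on a GPU with limited memory.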

Build trainer. (Put it all together.)

In [11]:
from transformers import Trainer

trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = tokenized_ds["train"],
    eval_dataset = tokenized_ds["validation"].select(range(100)),
    tokenizer = tokenizer,
)

Now let's run training.<br>
As I have mentioned above, make sure that you have enough disk space.

In [12]:
trainer.train()



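| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 150 | 2.1526 | 1.098189 |
| 300 | 1.1573 | 0.63552 |
| 450 | 1.0778 | 0.613432 |
| 600 | 0.9466 | 0.531952 |
| 750 | 0.8172 | 0.610254 |
| 900 | 0.8197 | 0.478506 |
| 1050 | 0.6442 | 0.522829 |
| 1200 | 0.514 | 0.495377 |
| 1350 | 0.5367 | 0.46102 |
| 1500 | 0.5377 | 0.475981 |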


TrainOutput(global_step=3045, training_loss=0.7215310818651822, metrics={'train_runtime': 14091.9134, 'train_samples_per_second': 6.918, 'train_steps_per_second': 0.216, 'total_flos': 1.910223004956365e+16, 'train_loss': 0.7215310818651822, 'epoch': 3.0})

In order to use it later, you can save the trained model.

In [13]:
import os

os.makedirs("./trained_for_qa_jp", exist_ok=True)
if hasattr(trainer.model, "module"):
    trainer.model.module.save_pretrained("./trained_for_qa_jp")
else:
    trainer.model.save_pretrained("./trained_for_qa_jp")

Load the fine-tuned model from the local folder.

In [14]:
from transformers import AutoModelForQuestionAnswering

model = (AutoModelForQuestionAnswering
         .from_pretrained("./trained_for_qa_jp")
         .to(device))

## Perform Question Answering

Now let's predict the answer for a given context and question (which were not seen in the training set) with the fine-tuned model.

Instead of manually running the forward pass, here I use a dedicated pipeline, in which preprocessing and postprocessing (such as skipping padded tokens) are wrapped.<br>
For Asian languages (such as Chinese, Korean, and Japanese) which don't separate words with explicit white space, specify ```align_to_words=False```.

As you can see below, this also returns scores.

> Note : Simply taking the argmax of the start logits and the end logits independently can fail to give a correct answer. For instance, if the spans (9, 11), (5, 7), and (3, 7) are the top 3 candidates for the answer, the 9th token might be picked as the start index while the 7th token is picked as the end index, which yields the invalid span (9, 7).<br>
> The QA pipeline in Hugging Face automatically picks the best valid combination to avoid these mistakes.
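
The idea of picking the best valid combination can be sketched as follows. This is not the pipeline's actual implementation; ```best_span``` is a hypothetical helper that scores every span with start <= end (up to an assumed maximum answer length) by the sum of its start and end logits, and the logits are made up to reproduce the situation in the note.

```python
def best_span(start_logits, end_logits, max_answer_len=15):
    """Return the valid (start, end) span with the highest summed logit score."""
    best = None
    best_score = float("-inf")
    for s, s_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within max_answer_len tokens
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_logit + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best, best_score


# Made-up logits for a 12-token sequence
start_logits = [0.0] * 12
end_logits = [0.0] * 12
start_logits[9], start_logits[5], start_logits[3] = 5.0, 4.0, 3.0
end_logits[7], end_logits[11] = 5.0, 4.5

# Independent argmax gives start=9, end=7, which is not a valid span
naive = (start_logits.index(max(start_logits)), end_logits.index(max(end_logits)))
span, score = best_span(start_logits, end_logits)
print("naive:", naive)
print("best valid span:", span, "score:", score)
```

Searching over pairs costs more than two argmax calls, but it can never return an end position that precedes the start.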

In [15]:
from transformers import pipeline

qa_pipeline = pipeline(
    "question-answering",
    model=model,
    tokenizer=tokenizer,
    device=0)

idx = 0
print("***** context *****")
print(ds["validation"]["context"][idx])
print("")
print("***** question *****")
print(ds["validation"]["question"][idx])
print("")
print("***** true answer *****")
print(ds["validation"]["answers"][idx]["text"][0])
print("")
print("***** predicted top3 answer *****")
qa_pipeline(
    question = ds["validation"]["question"][idx],
    context = ds["validation"]["context"][idx],
    align_to_words = False,
    top_k=3,
)

***** context *****
本項東大寺の仏像では、奈良県奈良市にある聖武天皇ゆかりの寺院・東大寺に伝来する仏像について説明する。

8世紀に日本の首都であった奈良を代表する寺院である東大寺は、「古都奈良の文化財」の一部として世界遺産に登録されている。東大寺には、「奈良の大仏」として知られる、高さ約15メートルの盧舎那仏像をはじめ、日本仏教美術史を代表する著名作品が多く所蔵されている。

本項では東大寺に所在する仏像彫刻について概観する。なお、東大寺の概要については「東大寺」の項を、大仏については「東大寺盧舎那仏像」の項を参照のこと。

***** question *****
8世紀に日本の首都はどこでしたか。

***** true answer *****
奈良

***** predicted top3 answer *****


[{'score': 0.9808037877082825, 'start': 65, 'end': 67, 'answer': '奈良'},
 {'score': 0.011853429488837719,
  'start': 65,
  'end': 80,
  'answer': '奈良を代表する寺院である東大寺'},
 {'score': 0.00010151458263862878,
  'start': 65,
  'end': 74,
  'answer': '奈良を代表する寺院'}]

Run prediction without the pipeline.

In [16]:
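import torch
import numpy as np

def inference_answer(question, context):
    # Tokenize the question/context pair as a single unpadded sequence
    test_feature = tokenizer(
        question,
        context,
        max_length=318,
        truncation="only_second",
    )
    # Run the forward pass without tracking gradients
    with torch.no_grad():
        outputs = model(torch.tensor([test_feature["input_ids"]]).to(device))
    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()
    # Naive decoding: take the argmax of the start and end logits independently
    # (the batch size is 1, so the flattened argmax equals the token index)
    answer_ids = test_feature["input_ids"][np.argmax(start_logits):np.argmax(end_logits)+1]
    return "".join(tokenizer.batch_decode(answer_ids))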

idx = 0
question = ds["validation"]["question"][idx]
context = ds["validation"]["context"][idx]
answer_pred = inference_answer(question, context)

print("***** question *****")
print(question)
print("")
print("***** context *****")
print(context)
print("")
print("***** true answer *****")
print(ds["validation"]["answers"][idx]["text"][0])
print("")
print("***** predicted answer *****")
print(answer_pred)

***** question *****
8世紀に日本の首都はどこでしたか。

***** context *****
本項東大寺の仏像では、奈良県奈良市にある聖武天皇ゆかりの寺院・東大寺に伝来する仏像について説明する。

8世紀に日本の首都であった奈良を代表する寺院である東大寺は、「古都奈良の文化財」の一部として世界遺産に登録されている。東大寺には、「奈良の大仏」として知られる、高さ約15メートルの盧舎那仏像をはじめ、日本仏教美術史を代表する著名作品が多く所蔵されている。

本項では東大寺に所在する仏像彫刻について概観する。なお、東大寺の概要については「東大寺」の項を、大仏については「東大寺盧舎那仏像」の項を参照のこと。

***** true answer *****
奈良

***** predicted answer *****
奈良


In [17]:
idx = 1
question = ds["validation"]["question"][idx]
context = ds["validation"]["context"][idx]
answer_pred = inference_answer(question, context)

print("***** question *****")
print(question)
print("")
print("***** context *****")
print(context)
print("")
print("***** true answer *****")
print(ds["validation"]["answers"][idx]["text"][0])
print("")
print("***** predicted answer *****")
print(answer_pred)

***** question *****
「奈良の大仏」の高さは何メートルなの?

***** context *****
本項東大寺の仏像では、奈良県奈良市にある聖武天皇ゆかりの寺院・東大寺に伝来する仏像について説明する。

8世紀に日本の首都であった奈良を代表する寺院である東大寺は、「古都奈良の文化財」の一部として世界遺産に登録されている。東大寺には、「奈良の大仏」として知られる、高さ約15メートルの盧舎那仏像をはじめ、日本仏教美術史を代表する著名作品が多く所蔵されている。

本項では東大寺に所在する仏像彫刻について概観する。なお、東大寺の概要については「東大寺」の項を、大仏については「東大寺盧舎那仏像」の項を参照のこと。

***** true answer *****
約15メートル

***** predicted answer *****
約15メートルの
