# Robertaのノートブック

## 設定

In [None]:
import os
import random
import numpy as np
import pandas as pd
import string
import unicodedata
from typing import Any, Dict, Iterator, List, Tuple, Union
import torch
import datasets
from datasets import load_metric


import transformers 
from transformers import AutoTokenizer,AlbertTokenizer,BertJapaneseTokenizer,T5Tokenizer
from transformers import AutoModelForQuestionAnswering,TrainingArguments,Trainer,ElectraForMaskedLM,RobertaForMaskedLM,AutoModel
from transformers import default_data_collator
import os



In [None]:
# Log in to the Hugging Face Hub from inside the notebook; the training
# arguments further down set push_to_hub=True.
from huggingface_hub import notebook_login
notebook_login()

### 変数

In [None]:
# Change these values to configure the run.
args = {
    # Seed handed to set_seed().
    'random_seed': 42,
    # Checkpoint loaded with AutoModelForQuestionAnswering; its last path
    # component is also used to name the output directory and the run.
    'pretrained_model': 'rinna/japanese-roberta-base',
    # Checkpoint for the tokenizer; when falsy, 'pretrained_model' is used.
    'pretrained_tokenizer': 'rinna/japanese-roberta-base',
    # Per-device batch sizes for training and evaluation.
    'batch_size': 8, 
    "eval_batch_size":8,
    # Learning rate passed to TrainingArguments.
    'lr': 2e-5, 
    # Padding length for the tokenizer and length of every span built by
    # make_spans().
    'max_length': 384,  
    # Step between the start offsets of consecutive spans in make_spans().
    'doc_stride': 128,  
    # Number of training epochs.
    'epochs': 1,  
    # Dataset identifier given to datasets.load_dataset().
    'dataset': 'SkelterLabsInc/JaQuAD',
    # NOTE(review): not referenced anywhere in the visible code.
    'optimizer': 'AdamW',
    # Unicode normalization form handed to get_offsets().
    'norm_form': 'NFKC',
    # Weight decay passed to TrainingArguments.
    'weight_decay': 0.01,  
    # NOTE(review): 'lr_scheduler' and 'warmup_ratio' are not referenced
    # anywhere in the visible code.
    'lr_scheduler': 'warmup_lin',
    'warmup_ratio': 0.1,
    # Passed to TrainingArguments as eval_accumulation_steps.
    "eval_accumulation_steps":10,
    # NOTE(review): not referenced anywhere in the visible code.
    'cpu_workers': os.cpu_count(),
    # Free-form suffix appended to the output directory and run name.
    'note':"same_prepare",
}
# Bare expression: left in place so the notebook can display the settings.
args

# Prefer the dedicated tokenizer checkpoint; fall back to the model checkpoint
# when no tokenizer name is configured.
tokenizer_checkpoint = args["pretrained_tokenizer"] or args["pretrained_model"]
tokenizer = T5Tokenizer.from_pretrained(tokenizer_checkpoint)
# Deliberately left disabled.
# NOTE(review): confirm whether lower-casing should be switched on here.
# tokenizer.do_lower_case = True

# Load the checkpoint with a question-answering head on top.
model = AutoModelForQuestionAnswering.from_pretrained(args['pretrained_model'])


In [None]:
# Fix the random seeds.
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
# Apply the configured seed before anything random happens.
set_seed(seed=args["random_seed"])

# Weights & Biases: log in and choose the project that runs are reported to.
import wandb
wandb.login()
os.environ["WANDB_PROJECT"] = "JaQuad"

# Miscellaneous settings
#!sudo apt install git-lfs
# True when the tokenizer pads on the right; preprocess_function uses this to
# decide whether the question or the context is passed first.
pad_on_right = tokenizer.padding_side == "right"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# First CUDA device when available, otherwise the CPU.
# NOTE(review): `device` is only printed; nothing else in the visible code uses it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Make the model lighter.
def quantize_transform(model):
    """Return the result of dynamically quantizing ``model``.

    ``torch.nn.Linear`` modules are targeted and quantized to ``torch.qint8``
    through ``torch.quantization.quantize_dynamic``.
    """
    return torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8)
# NOTE(review): the return value is not assigned and quantize_dynamic is called
# without inplace=True, so the global `model` that is trained later is
# presumably left unquantized — confirm that this is intended.
quantize_transform(model)

## データセットの準備

### datasetdict

In [None]:
# Load the dataset, flatten the nested `answers` feature into top-level
# columns and give those columns shorter names.
datasetdict = datasets.load_dataset(args['dataset']).flatten()
for nested_column, short_column in (
        ('answers.text', 'answer'),
        ('answers.answer_start', 'answer_start'),
        ('answers.answer_type', 'answer_type'),
):
    datasetdict = datasetdict.rename_column(nested_column, short_column)

### functions

In [None]:
# Adapted from the official baseline.
def preprocess_function(examples):
    """Tokenize a batch of QA examples and cut it into fixed-length spans.

    Every example is tokenized as a question/context pair, the answer's
    character range is converted to token indices with ``get_offsets`` and the
    sequence is split into windows of ``args['max_length']`` tokens by
    ``make_spans``.

    Args:
        examples: Batch mapping column names to lists. The columns
            ``question``, ``context``, ``answer`` and ``answer_start`` are
            read; for the last two only the first entry of each example is
            used. The batch itself is not modified.

    Returns:
        Dict with the keys ``input_ids``, ``attention_mask``,
        ``start_positions`` and ``end_positions``, holding one entry per
        generated span (so usually more entries than input examples).
    """
    # "[CLS]" is prepended by hand (presumably because the tokenizer does not
    # add it itself — confirm). A new list is built so that the caller's batch
    # is left untouched.
    questions = ["[CLS]" + question for question in examples["question"]]
    contexts = examples["context"]
    tokenized_examples = tokenizer(
        questions if pad_on_right else contexts,
        contexts if pad_on_right else questions,
        padding="max_length",
        max_length=args['max_length'],
    )

    inputs = {
        'input_ids': [],
        'attention_mask': [],
        'start_positions': [],
        'end_positions': [],
    }
    for tokens, att_mask, context, answers, answer_starts in zip(
            tokenized_examples['input_ids'],
            tokenized_examples['attention_mask'],
            contexts,
            examples['answer'],
            examples['answer_start']):
        # Only the first annotated answer is used for training.
        answer = answers[0]
        start_char = answer_starts[0]
        offsets = get_offsets(tokens, context, tokenizer, args["norm_form"])

        # Id 2 separates question and context (presumably the tokenizer's
        # separator/EOS id — confirm); the context starts right after it.
        ctx_start = tokens.index(2) + 1

        # Walk inwards from both ends of the offsets until the answer's
        # character range is bracketed.
        answer_start_index = 0
        answer_end_index = len(offsets) - 2
        while offsets[answer_start_index][0] < start_char:
            answer_start_index += 1
        while offsets[answer_end_index][1] > start_char + len(answer):
            answer_end_index -= 1
        # NOTE(review): get_offsets() puts a dummy (0, 0) entry in front, so
        # index k >= 1 of `offsets` describes the (k-1)-th context token, yet
        # ctx_start is added here without subtracting 1. This looks like it
        # moves the labels one token to the right whenever start_char > 0 —
        # confirm against decoded spans before changing it.
        answer_start_index += ctx_start
        answer_end_index += ctx_start

        span_inputs = {
            'input_ids': tokens,
            'attention_mask': att_mask,
        }
        for span, answer_idx in make_spans(
                span_inputs,
                question_len=ctx_start,
                max_seq_len=args["max_length"],
                stride=args["doc_stride"],
                answer_start_position=answer_start_index,
                answer_end_position=answer_end_index):
            inputs['input_ids'].append(span['input_ids'])
            inputs['attention_mask'].append(span['attention_mask'])
            inputs['start_positions'].append(answer_idx[0])
            inputs['end_positions'].append(answer_idx[1])

    return inputs
   
def make_spans(
    inputs: Dict[str, Union[int, List[int]]],
    question_len: int,
    max_seq_len: int,
    stride: int,
    answer_start_position: int = -1,
    answer_end_position: int = -1,
) -> Iterator[Tuple[Dict[str, List[int]], Tuple[int, int]]]:
    """Split one tokenized question/context pair into fixed-length spans.

    Every span consists of the first ``question_len`` entries followed by a
    window of the remaining entries; consecutive windows start ``stride``
    entries apart. Spans shorter than ``max_seq_len`` are padded with 0.

    Args:
        inputs: Mapping from feature name to a per-token list. It must contain
            ``'input_ids'``, whose length defines the sequence length; every
            value is windowed in the same way.
        question_len: Number of leading entries repeated in every span.
        max_seq_len: Length of every emitted span.
        stride: Distance between the starts of two consecutive windows.
        answer_start_position: Index of the first answer token in the
            unsplit sequence.
        answer_end_position: Index of the last answer token in the unsplit
            sequence.

    Yields:
        ``(span, (answer_start, answer_end))`` where ``span`` has the same
        keys as ``inputs`` and the answer indices are relative to the span.
        Both indices are 0 when the answer does not lie inside the window.
    """
    input_len = len(inputs['input_ids'])
    context_len = input_len - question_len
    window_len = max_seq_len - question_len

    def make_value(input_list, i, padding=0):
        # Number of context entries that actually fit into this window.
        context_end = min(window_len, context_len - i)
        pad_len = window_len - context_end
        val = input_list[:question_len]
        val += input_list[question_len + i:question_len + i + context_end]
        # Close the window with the final entry of the full sequence.
        val[-1] = input_list[-1]
        val += [padding] * pad_len
        return val

    # Inputs shorter than ``max_seq_len - stride`` used to give an empty range
    # and therefore no span at all; always emit at least the first window.
    stop = max(input_len - max_seq_len + stride, 1)
    for i in range(0, stop, stride):
        span = {key: make_value(val, i) for key, val in inputs.items()}
        answer_start = answer_start_position - i
        answer_end = answer_end_position - i
        if answer_start < question_len or answer_end >= max_seq_len - 1:
            # The answer is not fully contained in this window.
            answer_start = answer_end = 0
        yield span, (answer_start, answer_end)

def get_offsets(input_ids: List[int],
                context: str,
                tokenizer: AutoTokenizer,
                norm_form='NFKC') -> List[Tuple[int, int]]:
    """Map every context token of ``input_ids`` to a character range.

    The context tokens are those between the first and the second occurrence
    of id 2 in ``input_ids``. Tokens are located in the Unicode-normalized
    context and the positions are then translated back to indices of the
    original ``context`` string.

    Args:
        input_ids: Token ids of one question/context pair.
        context: Original (un-normalized) context string.
        tokenizer: Tokenizer used to turn ids back into token strings; its
            ``unk_token`` marks tokens that cannot be matched literally.
        norm_form: Unicode normalization form applied to the context.

    Returns:
        One ``(start, end)`` pair per context token with inclusive character
        indices into ``context``, preceded and followed by a dummy ``(0, 0)``.

    Raises:
        ValueError: If id 2 does not occur twice in ``input_ids`` or a token
            cannot be found in the remaining normalized context.
        AssertionError: If the per-character normalization disagrees with the
            normalization of the whole string, or the bookkeeping lengths
            differ.
    """
    cxt_start = input_ids.index(2) + 1
    cxt_end = cxt_start + input_ids[cxt_start:].index(2)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[cxt_start:cxt_end])
    # Remove only the one-character word-boundary marker. Slicing two
    # characters (as needed for a '##' prefix) would also drop the first real
    # character of the token and shift every following offset.
    # NOTE(review): a token that consists of the marker alone becomes '' and
    # gets an end offset one before its start — confirm this is acceptable.
    marker = '▁'
    tokens = [tok[len(marker):] if tok.startswith(marker) else tok
              for tok in tokens]
    whitespace = string.whitespace + '\u3000'

    # 1. Make offsets of normalized context within the original context.
    #    A single character may normalize to several characters; each of them
    #    points back to the index of the original character.
    offsets_norm_context = []
    norm_pieces = []
    for idx, char in enumerate(context):
        norm_char = unicodedata.normalize(norm_form, char)
        norm_pieces.append(norm_char)
        offsets_norm_context.extend([idx] * len(norm_char))
    norm_context = ''.join(norm_pieces)
    norm_context_org = unicodedata.normalize(norm_form, context)
    assert norm_context == norm_context_org, \
        'Normalized contexts are not the same: ' \
        + f'{norm_context} != {norm_context_org}'
    assert len(norm_context) == len(offsets_norm_context), \
        'Normalized contexts have different numbers of tokens: ' \
        + f'{len(norm_context)} != {len(offsets_norm_context)}'

    # 2. Make offsets of tokens (input_ids) within the normalized context.
    offsets_token = []
    cid = 0  # Current character position in the normalized context.
    for tid, cur_token in enumerate(tokens):
        if cur_token == tokenizer.unk_token:
            # Provisionally give the unknown token a single character; the
            # range is widened below once the next known token is located.
            offsets_token.append([cid, cid])
            cid += 1
        else:
            start_pos = norm_context[cid:].index(cur_token)
            if start_pos > 0 and tokens[tid - 1] == tokenizer.unk_token:
                # The preceding unknown token covered more than one
                # character: stretch it up to the start of this token.
                offsets_token[-1][1] += start_pos
                cid += start_pos
            offsets_token.append([cid, cid + len(cur_token) - 1])
            cid += len(cur_token)
            # Whitespace following a token is attributed to that token.
            while cid < len(norm_context) and norm_context[cid] in whitespace:
                offsets_token[-1][1] += 1
                cid += 1
    if tokens[-1] == tokenizer.unk_token:
        # A trailing unknown token extends to the end of the context.
        offsets_token[-1][1] = len(norm_context) - 1
    assert len(offsets_token) == len(tokens), \
        'The numbers of tokens and offsets are different'

    # 3. Translate normalized-context positions back to the original context.
    offsets_mapping = [(offsets_norm_context[start], offsets_norm_context[end])
                       for start, end in offsets_token]
    return [(0, 0)] + offsets_mapping + [(0, 0)]

In [None]:
#テスト
features = preprocess_function(datasetdict['train'][:10])
print(features)

In [None]:
#all data
tokenized_train_dataset = datasetdict["train"].map(preprocess_function, batched=True, remove_columns=datasetdict["train"].column_names)
tokenized_valid_dataset = datasetdict["validation"].map(preprocess_function, batched=True, remove_columns=datasetdict["train"].column_names)
tokenized_train_dataset.set_format(type='torch')
tokenized_valid_dataset.set_format(type='torch')

## Fine-tuning the model

In [None]:
# The checkpoint's short name plus the free-form note label both the output
# directory and the Weights & Biases run.
model_name = args['pretrained_model'].split("/")[-1]
note = args["note"]
run_label = f"{model_name}_{note}"
data_collator = default_data_collator

training_options = dict(
    output_dir=f"./model/{run_label}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=args["lr"],
    per_device_train_batch_size=args["batch_size"],
    per_device_eval_batch_size=args["eval_batch_size"],
    num_train_epochs=args["epochs"],
    weight_decay=args["weight_decay"],
    push_to_hub=True,
    report_to="wandb",
    run_name=run_label,
    load_best_model_at_end=True,
    # Memory mitigation? (comment carried over from the original author)
    eval_accumulation_steps=args["eval_accumulation_steps"],
)
train_args = TrainingArguments(**training_options)

# No compute_metrics callback is installed; F1 / exact match are computed by
# hand after training.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=None,
)

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

## Evaluation

### function

In [None]:
def calculate_f1_score(start_positions, end_positions, start_preds,
                        end_preds):
    """Compute token-overlap precision, recall and F1 per example.

    Spans are given as inclusive token indices. The computation is done with
    torch operations only, so tensors, NumPy arrays and plain sequences are
    all accepted.

    Args:
        start_positions: Gold start indices, one per example.
        end_positions: Gold end indices, one per example.
        start_preds: Predicted start indices, one per example.
        end_preds: Predicted end indices, one per example.

    Returns:
        Dict with the keys ``'precision'``, ``'recall'`` and ``'f1'``, each a
        tensor with one value per example. Undefined ratios (0 / 0, e.g. an
        empty predicted span) are reported as 0.
    """
    start_positions, end_positions, start_preds, end_preds = (
        torch.as_tensor(values)
        for values in (start_positions, end_positions, start_preds, end_preds)
    )

    # Intersection of the gold span and the predicted span.
    start_overlap = torch.maximum(start_positions, start_preds)
    end_overlap = torch.minimum(end_positions, end_preds)
    # Inverted spans (end before start) count as empty rather than negative.
    overlap = torch.clamp(end_overlap - start_overlap + 1, min=0)
    pred_token_count = torch.clamp(end_preds - start_preds + 1, min=0)
    ground_token_count = torch.clamp(
        end_positions - start_positions + 1, min=0)

    precision = torch.nan_to_num(overlap / pred_token_count, nan=0.)
    recall = torch.nan_to_num(overlap / ground_token_count, nan=0.)
    f1 = torch.nan_to_num(
        2 * precision * recall / (precision + recall), nan=0.)
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }
    
def calculate_exact_match(start_positions, end_positions, start_preds,
                            end_preds):
    """Return 1.0 where both predicted boundaries equal the gold ones, else 0.0.

    The result is a NumPy ``float32`` array with one value per example.
    """
    starts_match = start_preds == start_positions
    ends_match = end_preds == end_positions
    both_match = starts_match & ends_match
    hits = both_match.to('cpu').detach().numpy()
    return hits.astype(np.float32)

def get_result(dataset = tokenized_valid_dataset):
    """Predict on ``dataset`` and return predicted and gold span boundaries.

    The default dataset is whatever ``tokenized_valid_dataset`` referred to
    when this function was defined.

    Returns:
        Tuple ``(start_pred, end_pred, start_position, end_position)`` of CPU
        tensors; the predictions are the argmax over the last axis of the
        start and end outputs returned by ``trainer.predict``.
    """
    def best_index(scores):
        # Highest-scoring position along the last axis.
        return torch.from_numpy(scores).argmax(dim=-1).cpu().detach()

    def as_cpu_tensor(labels):
        return torch.from_numpy(labels).cpu().detach()

    prediction_output = trainer.predict(dataset)
    start_pred = best_index(prediction_output.predictions[0])
    end_pred = best_index(prediction_output.predictions[1])
    start_position = as_cpu_tensor(prediction_output.label_ids[0])
    end_position = as_cpu_tensor(prediction_output.label_ids[1])
    return start_pred, end_pred, start_position, end_position

In [None]:
# Score the validation split: token-overlap F1 and exact match, averaged over
# all spans.
start_pred, end_pred, start_position, end_position = get_result(
    tokenized_valid_dataset)
f1_metrics = calculate_f1_score(
    start_position, end_position, start_pred, end_pred)

precision, recall, f1 = (
    f1_metrics[metric_name].mean()
    for metric_name in ('precision', 'recall', 'f1'))
em = calculate_exact_match(
    start_position, end_position, start_pred, end_pred).mean()

print("========== CV ==========")
print("f1:", f1.item())
print("em:", em.item())

# Only F1 and exact match are reported; precision and recall stay local.
wandb.log({"f1": f1, "em": em})

## Inference

In [None]:
# Sample question for a manual check (roughly: "When was 『三つ目がとおる』
# published?") and the passage that should contain the answer.
question = '『三つ目がとおる』を発表したのはいつ？'
# NOTE(review): the context literal begins with a double-quote character that
# has no closing counterpart — confirm whether it is intended.
context = '"大阪帝国大学附属医学専門部在学中の1946年1月1日に4コマ漫画『マアチャンの日記帳』(『少国民新聞』連載)で漫画家としてデビューした。1947年、酒井七馬原案の描き下ろし単行本『新寶島』がベストセラーとなり、大阪に赤本ブームを引き起こす。1950年より漫画雑誌に登場、『鉄腕アトム』『ジャングル大帝』『リボンの騎士』といったヒット作を次々と手がけた。1963年、自作をもとに日本初となる30分枠のテレビアニメシリーズ『鉄腕アトム』を制作、現代につながる日本のテレビアニメ制作に多大な影響を及ぼした。1970年代には『ブラック・ジャック』『三つ目がとおる』『ブッダ』などのヒット作を発表。また晩年にも『陽だまりの樹』『アドルフに告ぐ』など青年漫画においても傑作を生み出す。デビューから1989年の死去まで第一線で作品を発表し続け、存命中から「マンガの神様」と評された。藤子不二雄(藤子・F・不二雄、藤子不二雄A)、石ノ森章太郎、赤塚不二夫、横山光輝、水野英子、矢代まさこ、萩尾望都などをはじめ数多くの人間が手塚に影響を受け、接触し漫画家を志した。'

In [None]:
# Encode the pair the same way the training data was encoded:
# preprocess_function prepends "[CLS]" to every question, so do it here too.
inputs = tokenizer(
    "[CLS]" + question, context, add_special_tokens=True, return_tensors="pt")
# After training the model may live on the GPU; the inputs must be on the
# same device or the forward pass fails.
inputs = inputs.to(model.device)
input_ids = inputs["input_ids"].tolist()[0]
# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
answer_start_scores = outputs.start_logits
answer_end_scores = outputs.end_logits

In [None]:
# Get the most likely beginning of answer with the argmax of the score.
answer_start = torch.argmax(answer_start_scores)

answer_end = torch.argmax(answer_end_scores) -1
# Get the most likely end of answer with the argmax of the score.
# NOTE(review): 1 is subtracted from the argmax here and the slice below
# excludes `answer_end` itself, so the decoded text stops two tokens before
# the highest-scoring end position. Confirm that this matches how the end
# labels are built during preprocessing.

# Turn the selected token ids back into text.
answer = tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
# Bare expression: left in place so the notebook can display the answer.
answer