## 自分で書いたオリジナルのベースラインモデル

### 下準備

In [1]:
import os
import string
import unicodedata
from typing import Any, Dict, Iterator, List, Tuple, Union
import random
import pandas as pd

import datasets
import numpy as np
from pytorch_lightning import LightningModule, Trainer, seed_everything
import torch
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.optim import lr_scheduler
import transformers 
from transformers import BertJapaneseTokenizer, ElectraForMaskedLM,ElectraForQuestionAnswering
from transformers import RobertaTokenizer, RobertaModel
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import callbacks
from pytorch_lightning import loggers
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, AdamW
# Train on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [3]:
# Experiment configuration shared by every cell below.
args = {
    'random_seed': 42,  # Random Seed
    # Transformers PLM name.
    'pretrained_model': 'cl-tohoku/bert-base-japanese-whole-word-masking',
    # Optional, Transformers Tokenizer name. Overrides `pretrained_model`
    'pretrained_tokenizer': 'cl-tohoku/bert-base-japanese-whole-word-masking',
    # Unicode normalization form used when aligning tokens to context characters.
    'norm_form': 'NFKC',
    'batch_size': 8,  # <=32 for TPUv2-8
    'lr': 2e-5,  # Learning Rate
    'max_length': 384,  # Max Length input size
    'doc_stride': 128,  # The interval of the context when splitting is needed
    'epochs': 3,  # Max Epochs
    'dataset': 'SkelterLabsInc/JaQuAD',
    'huggingface_auth_token': None,
    'test_mode': False,  # Test Mode enables `fast_dev_run`
    'optimizer': 'AdamW',  # only 'AdamW' is implemented
    'weight_decay': 0.01,  # Weight decaying parameter for AdamW
    'lr_scheduler': 'warmup_lin',  # one of 'cos', 'exp', 'warmup_lin'
    'warmup_ratio': 0.1,  # fraction of total training steps used for warmup
    'fp16': False,  # Enable train on FP16 (if GPU)
    # Divides steps_per_epoch when running on a Colab TPU (COLAB_TPU_ADDR set).
    'tpu_cores': 8,
    'cpu_workers': os.cpu_count(),  # number of DataLoader worker processes
    'note': "リクルートベースライン",  # prefix of the W&B run name
}


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducible runs.

    Also makes cuDNN pick deterministic kernels and disables its autotuner,
    whose algorithm choice can vary between runs.

    Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    only affects Python processes started after this call.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(seed=args["random_seed"])

In [None]:
# ====================================================
# wandb
# ====================================================
import wandb

# `wandb.login` without parentheses only referenced the function and never
# logged in; call it explicitly.
wandb.login()

# Run name: "<note>_<pretrained model id>".
wandb_logger = WandbLogger(project="JaQuad", config=args,
                           name=args["note"] + "_" + args["pretrained_model"])

## データセットの準備

In [4]:
# Load JaQuAD and lift the nested `answers.*` fields into top-level columns.
answer_column_names = {
    'answers.text': 'answer',
    'answers.answer_start': 'answer_start',
    'answers.answer_type': 'answer_type',
}
datasetdict = datasets.load_dataset(
    args['dataset'],
    use_auth_token=args['huggingface_auth_token'],
)
datasetdict = datasetdict.flatten()
for nested_name, flat_name in answer_column_names.items():
    datasetdict = datasetdict.rename_column(nested_name, flat_name)

# train =pd.DataFrame(jaquad_dataset['train'][:].values(), index=jaquad_dataset['train'][:].keys()).T
# valid =pd.DataFrame(jaquad_dataset['validation'][:].values(), index=jaquad_dataset['validation'][:].keys()).T
# train.to_csv("JaQuAD_train.csv")
# valid.to_csv("JaQuAD_valid.csv")

Using custom data configuration default
Reusing dataset ja_qu_ad (/home/s16991/.cache/huggingface/datasets/SkelterLabsInc___ja_qu_ad/default/0.1.0/5847b2e2ab5e02de284395bb15f87f13eae8f6f6ff1f01e4ee9c5c0dcf8ef8eb)


  0%|          | 0/2 [00:00<?, ?it/s]

### データの前処理

In [6]:
def preprocess_function(examples):
    """Tokenize a batch of JaQuAD examples and split each into fixed-size windows.

    Meant for ``datasets.Dataset.map(..., batched=True)``. ``examples`` holds
    parallel lists under 'question', 'context', 'answer' and 'answer_start';
    only the first entry of each example's answer / answer_start is used.

    Uses the notebook globals ``tokenizer`` and ``args`` (``norm_form``,
    ``max_length``, ``doc_stride``) as well as ``get_offsets`` and ``make_spans``.

    Returns:
        A dict of lists (input_ids, attention_mask, token_type_ids,
        start_positions, end_positions) with one row per window yielded by
        ``make_spans``; the row count can therefore differ from the input batch.
    """
    tokenized_examples = tokenizer(
        examples['question'],
        examples['context'],
    )

    inputs = {
        'input_ids': [],
        'attention_mask': [],
        'token_type_ids': [],
        'start_positions': [],
        'end_positions': [],
    }
    for tokens, att_mask, type_ids, context, answer, start_char \
            in zip(tokenized_examples['input_ids'],
                   tokenized_examples['attention_mask'],
                   tokenized_examples['token_type_ids'],
                   examples['context'],
                   examples['answer'],
                   examples['answer_start']):
        answer = answer[0]
        start_char = start_char[0]
        # Per-token (start_char, end_char) offsets into the original context.
        offsets = get_offsets(tokens, context, tokenizer,
                              args["norm_form"])

        # The first context token comes right after the first [SEP].
        ctx_start = tokens.index(tokenizer.sep_token_id) + 1
        answer_start_index = ctx_start
        answer_end_index = len(offsets) - 1
        # First token whose start offset reaches the answer start ...
        while offsets[answer_start_index][0] < start_char:
            answer_start_index += 1
        # ... and last token whose (inclusive) end offset does not go past the
        # answer end by more than one character.
        while offsets[answer_end_index][1] > start_char + len(answer):
            answer_end_index -= 1

        span_inputs = {
            'input_ids': tokens,
            'attention_mask': att_mask,
            'token_type_ids': type_ids,
        }
        for span, answer_idx in make_spans(
                span_inputs,
                question_len=ctx_start,
                max_seq_len=args["max_length"],
                stride=args["doc_stride"],
                answer_start_position=answer_start_index,
                answer_end_position=answer_end_index):
            inputs['input_ids'].append(span['input_ids'])
            inputs['attention_mask'].append(span['attention_mask'])
            inputs['token_type_ids'].append(span['token_type_ids'])
            inputs['start_positions'].append(answer_idx[0])
            inputs['end_positions'].append(answer_idx[1])
    # (The former `print(len(inputs))` always printed 5 — the number of keys.)
    return inputs


def make_spans(
    inputs: Dict[str, Union[int, List[int]]],
    question_len: int,
    max_seq_len: int,
    stride: int,
    answer_start_position: int = -1,
    answer_end_position: int = -1
) -> Iterator[Tuple[Dict[str, List[int]], Tuple[int, int]]]:
    """Split one tokenized (question, context) pair into fixed-length windows.

    Every window contains the first ``question_len`` entries of each list,
    then a slice of the context shifted by ``stride`` tokens per window. The
    window's last filled slot is overwritten with the final entry of the
    original list (presumably the closing [SEP]), and the window is
    right-padded with 0 to ``max_seq_len``.

    Args:
        inputs: parallel per-token lists (e.g. input_ids, attention_mask,
            token_type_ids) for the full, unsplit sequence.
        question_len: number of leading entries copied into every window.
        max_seq_len: length of every produced window.
        stride: shift, in context tokens, between consecutive windows.
        answer_start_position: inclusive answer start index in the unsplit
            sequence; -1 means no answer.
        answer_end_position: inclusive answer end index; -1 means no answer.

    Yields:
        ``(span, (start, end))``. ``start``/``end`` are window-relative answer
        indices, or ``(0, 0)`` when the answer is not fully inside the
        window's context part. At least one window is always produced.
    """
    input_len = len(inputs['input_ids'])
    context_len = input_len - question_len

    def make_value(input_list, i, padding=0):
        context_end = min(max_seq_len - question_len, context_len - i)
        pad_len = max_seq_len - question_len - context_end
        val = input_list[:question_len]
        val += input_list[question_len + i:question_len + i + context_end]
        # Close the window with the sequence's final entry.
        val[-1] = input_list[-1]
        val += [padding] * pad_len
        return val

    # The bare bound `input_len - max_seq_len + stride` is <= 0 whenever
    # input_len <= max_seq_len - stride, which produced no window at all and
    # silently dropped short examples. Clamp so they yield exactly one window;
    # longer inputs get the same windows as before.
    last_offset = max(input_len - max_seq_len, 0)
    for i in range(0, last_offset + stride, stride):
        span = {key: make_value(val, i) for key, val in inputs.items()}
        answer_start = answer_start_position - i
        answer_end = answer_end_position - i
        # The answer must lie after the question part and before the final [SEP] slot.
        if answer_start < question_len or answer_end >= max_seq_len - 1:
            answer_start = answer_end = 0
        yield span, (answer_start, answer_end)
        

def get_offsets(input_ids: List[int],
                context: str,
                tokenizer: AutoTokenizer,
                norm_form='NFKC') -> List[Tuple[int, int]]:
    """Map every context token in ``input_ids`` to a character span of ``context``.

    The context tokens are the ids between the first and the second
    ``tokenizer.sep_token_id``. They are aligned against the
    ``norm_form``-normalized context, and the aligned positions are mapped
    back to indices in the original (un-normalized) ``context``.

    Returns:
        A list of inclusive ``(start_char, end_char)`` pairs into ``context``.
        The first ``cxt_start`` entries (everything up to and including the
        first [SEP]) are ``(-1, -1)``; one pair follows per context token.
        Nothing is emitted for the second [SEP] or anything after it, so the
        list length equals the index of the second [SEP]. Whitespace that
        follows a token is folded into that token's end offset.

    Raises:
        AssertionError: if per-character normalization differs from
            whole-string normalization (e.g. a base character followed by a
            combining mark), or if the tokens cannot be aligned with the
            normalized context. A mismatch before any [UNK] has been seen
            (presumably e.g. leading whitespace in ``context``) also fails.
    """
    # Context token range: between the first and the second [SEP].
    cxt_start = input_ids.index(tokenizer.sep_token_id) + 1
    cxt_end = cxt_start + input_ids[cxt_start:].index(tokenizer.sep_token_id)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[cxt_start:cxt_end])
    # Strip WordPiece continuation markers so tokens can be matched as raw text.
    tokens = [tok[2:] if tok.startswith('##') else tok for tok in tokens]
    # Includes the ideographic (full-width) space used in Japanese text.
    whitespace = string.whitespace + '\u3000'

    # 1. Make offsets of normalized context within the original context.
    #    offsets_norm_context[j] = index of the original character that
    #    produced normalized character j (one original char may expand to
    #    several normalized chars).
    offsets_norm_context = []
    norm_context = ''
    for idx, char in enumerate(context):
        norm_char = unicodedata.normalize(norm_form, char)
        norm_context += norm_char
        offsets_norm_context.extend([idx] * len(norm_char))
    norm_context_org = unicodedata.normalize(norm_form, context)
    assert norm_context == norm_context_org, \
        'Normalized contexts are not the same: ' \
        + f'{norm_context} != {norm_context_org}'
    assert len(norm_context) == len(offsets_norm_context), \
        'Normalized contexts have different numbers of tokens: ' \
        + f'{len(norm_context)} != {len(offsets_norm_context)}'

    # 2. Make offsets of tokens (input_ids) within the normalized context.
    #    offsets_token holds inclusive [start, end] pairs into norm_context;
    #    cid is the current character cursor, tid the current token index.
    offsets_token = []
    unk_pointer = None  # index of the most recent [UNK] token, if any
    cid = 0
    tid = 0
    while tid < len(tokens):
        cur_token = tokens[tid]
        if cur_token == tokenizer.unk_token:
            # Tentatively assume the [UNK] covers exactly one character; it is
            # widened below if the following token does not match.
            unk_pointer = tid
            offsets_token.append([cid, cid])
            cid += 1
        elif norm_context[cid:cid + len(cur_token)] != cur_token:
            # Wrong offsets of the previous UNK token:
            # widen the last [UNK] so it ends just before the next occurrence
            # of the token that follows it. The search starts two chars past
            # its current end, so the [UNK] grows by at least one char per
            # retry. Then drop every offset after the [UNK] and re-align from
            # the token right after it.
            assert unk_pointer is not None, \
                'Normalized context and tokens are not matched'
            prev_unk_expected = offsets_token[unk_pointer]
            prev_unk_expected[1] += norm_context[prev_unk_expected[1] + 2:]\
                .index(tokens[unk_pointer + 1]) + 1
            tid = unk_pointer
            offsets_token = offsets_token[:tid] + [prev_unk_expected]
            cid = prev_unk_expected[1] + 1
        else:
            # NOTE(review): the elif above guarantees cur_token matches at
            # cid, so start_pos is always 0 here and the `start_pos > 0`
            # branch appears unreachable.
            start_pos = norm_context[cid:].index(cur_token)
            if start_pos > 0 and tokens[tid - 1] == tokenizer.unk_token:
                offsets_token[-1][1] += start_pos
                cid += start_pos
                start_pos = 0
            assert start_pos == 0, f'{start_pos} != 0 (cur: {cur_token}'
            offsets_token.append([cid, cid + len(cur_token) - 1])
            cid += len(cur_token)
            # Fold following whitespace into this token's span.
            while cid < len(norm_context) and norm_context[cid] in whitespace:
                offsets_token[-1][1] += 1
                cid += 1
        tid += 1
    # A trailing [UNK] absorbs whatever is left of the context; otherwise the
    # tokens must have consumed every character exactly.
    # (Presumably the context is non-empty: tokens[-1] raises on no tokens.)
    if tokens[-1] == tokenizer.unk_token:
        offsets_token[-1][1] = len(norm_context) - 1
    else:
        assert cid == len(norm_context) == offsets_token[-1][1] + 1, \
            'Offsets do not include all characters'
    assert len(offsets_token) == len(tokens), \
        'The numbers of tokens and offsets are different'

    # Translate normalized-context positions back to original-context indices
    # and prepend placeholders for the question part.
    offsets_mapping = [(offsets_norm_context[start], offsets_norm_context[end])
                       for start, end in offsets_token]
    return [(-1, -1)] * cxt_start + offsets_mapping

In [None]:
# Tokenizer: a dedicated tokenizer name wins, otherwise use the model's own.
tokenizer_name = args["pretrained_tokenizer"] or args["pretrained_model"]
tokenizer = BertJapaneseTokenizer.from_pretrained(
    tokenizer_name,
    use_auth_token=args["huggingface_auth_token"],
)

# One example can expand into several windows, so the raw columns are removed
# and only the model inputs produced by preprocess_function remain.
tokenized_dataset = datasetdict.map(
    preprocess_function,
    batched=True,
    remove_columns=datasetdict['train'].column_names,
)
# Used for DistilBERT (no token_type_ids input).
# NOTE(review): this also drops segment ids for BERT/ELECTRA runs — confirm intended.
tokenized_dataset = tokenized_dataset.map(remove_columns=['token_type_ids'])
tokenized_dataset.set_format(type='torch')
print(len(tokenized_dataset["train"]))

# Both loaders share batch size and worker count; only shuffling differs.
loader_options = {
    'batch_size': args["batch_size"],
    'num_workers': args["cpu_workers"],
}
train_dataloader = DataLoader(tokenized_dataset["train"], shuffle=True, **loader_options)
valid_dataloader = DataLoader(tokenized_dataset['validation'], shuffle=False, **loader_options)

### モデル

In [None]:
class WarmupLinearLR(lr_scheduler.LambdaLR):
    '''The learning rate is linearly increased for the first `warmup_steps`
    and linearly decreased to zero afterward.

    The factor applied to each param group's base lr at scheduler step ``s``
    is ``s / warmup_steps`` while ``s < warmup_steps``, and then
    ``1 - (s - warmup_steps) / (max_steps - warmup_steps)`` clipped to [0, 1].
    The decay length is floored at 1 step, so ``max_steps <= warmup_steps``
    no longer raises ``ZeroDivisionError``.
    '''

    def __init__(self, optimizer, warmup_steps, max_steps, last_epoch=-1):
        # Length of the linear decay phase (at least one step).
        decay_steps = max(1, max_steps - warmup_steps)

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1.0, warmup_steps))
            ratio = 1 - float(step - warmup_steps) / float(decay_steps)
            return max(0.0, min(1.0, ratio))

        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)

In [None]:
# Extractive QA model built on a pretrained BERT-style encoder.
class QAModel(LightningModule):
    """LightningModule wrapping ``AutoModelForQuestionAnswering`` for span extraction.

    All keyword arguments (the notebook's ``args`` dict) are stored in
    ``self.hparams``. Optimization is manual: ``training_step`` runs
    zero_grad -> backward -> optimizer step -> scheduler step itself. Data
    comes from the notebook-level ``train_dataloader`` / ``valid_dataloader``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()  # kwargs are saved in self.hparams
        self.automatic_optimization = False

        self.question_answerer = AutoModelForQuestionAnswering.from_pretrained(
            self.hparams.pretrained_model,
            use_auth_token=self.hparams.huggingface_auth_token)
        # Falls back to the model id when no separate tokenizer is configured.
        # Not used by the training code in this class.
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(
            self.hparams.pretrained_tokenizer
            if self.hparams.pretrained_tokenizer else
            self.hparams.pretrained_model,
            use_auth_token=self.hparams.huggingface_auth_token)

    def forward(self, **kwargs):
        """Pass all keyword inputs straight to the Hugging Face QA model."""
        return self.question_answerer(**kwargs)

    def step(self, batch, batch_idx):
        """Shared train/val step: loss plus CPU copies of predicted and gold span indices."""
        # `batch` carries start_positions/end_positions, which makes the HF
        # model return `loss` as well.
        outputs = self(**batch)

        loss = outputs.loss
        start_preds = outputs.start_logits.argmax(dim=-1).cpu().detach()
        end_preds = outputs.end_logits.argmax(dim=-1).cpu().detach()
        start_positions = batch['start_positions'].cpu().detach()
        end_positions = batch['end_positions'].cpu().detach()

        return {
            'loss': loss,
            'start_preds': start_preds,
            'end_preds': end_preds,
            'start_positions': start_positions,
            'end_positions': end_positions,
        }

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        outputs = self.step(batch, batch_idx)
        self.manual_backward(outputs['loss'])
        opt.step()

        # single scheduler, stepped once per batch to match the per-batch
        # total_steps computed in configure_optimizers
        sch = self.lr_schedulers()
        sch.step()
        # Backward is done; keep only the value so stored epoch outputs do not
        # reference the autograd graph.
        outputs['loss'] = outputs['loss'].detach()
        return outputs

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx)

    @staticmethod
    def calculate_f1_score(start_positions, end_positions, start_preds,
                           end_preds):
        """Token-level precision/recall/F1 per example between gold and predicted spans.

        Spans are inclusive ``[start, end]`` token indices. Undefined ratios
        (0/0, e.g. an empty predicted span) are reported as 0.
        """
        # Native torch ops instead of NumPy ufuncs, which only returned tensors
        # by way of Tensor.__array_wrap__.
        start_positions = torch.as_tensor(start_positions)
        end_positions = torch.as_tensor(end_positions)
        start_preds = torch.as_tensor(start_preds)
        end_preds = torch.as_tensor(end_preds)

        start_overlap = torch.maximum(start_positions, start_preds)
        end_overlap = torch.minimum(end_positions, end_preds)
        overlap = torch.clamp(end_overlap - start_overlap + 1, min=0)

        pred_token_count = torch.clamp(end_preds - start_preds + 1, min=0)
        ground_token_count = torch.clamp(end_positions - start_positions + 1, min=0)

        precision = torch.nan_to_num(overlap / pred_token_count, nan=0.)
        recall = torch.nan_to_num(overlap / ground_token_count, nan=0.)
        f1 = torch.nan_to_num(
            2 * precision * recall / (precision + recall), nan=0.)
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }

    @staticmethod
    def calculate_exact_match(start_positions, end_positions, start_preds,
                              end_preds):
        """1.0 per example when both predicted start and end equal the gold ones, else 0.0."""
        equal_start = (torch.as_tensor(start_preds) == torch.as_tensor(start_positions))
        equal_end = (torch.as_tensor(end_preds) == torch.as_tensor(end_positions))
        return (equal_start * equal_end).type(torch.float)

    def epoch_end(self, outputs, state='train'):
        """Aggregate one epoch of step outputs, log ``{state}_*`` metrics and return them.

        Loss is the mean of the per-batch losses. Precision/recall/F1/EM are
        averaged over all examples: predictions are concatenated first, so a
        smaller final batch is not over-weighted.
        """
        if not outputs:
            # No batches were run (e.g. an empty dataloader); nothing to average.
            return {}

        loss = torch.stack(
            [o['loss'].detach().cpu().float() for o in outputs]).mean()
        start_positions = torch.cat([o['start_positions'] for o in outputs])
        end_positions = torch.cat([o['end_positions'] for o in outputs])
        start_preds = torch.cat([o['start_preds'] for o in outputs])
        end_preds = torch.cat([o['end_preds'] for o in outputs])

        f1_metrics = self.calculate_f1_score(start_positions, end_positions,
                                             start_preds, end_preds)
        em = self.calculate_exact_match(start_positions, end_positions,
                                        start_preds, end_preds)
        metrics = {
            state + '_loss': float(loss),
            state + '_precision': f1_metrics['precision'].mean(),
            state + '_recall': f1_metrics['recall'].mean(),
            state + '_f1': f1_metrics['f1'].mean(),
            state + '_em': em.mean(),
        }

        self.log_dict(metrics, on_epoch=True)

        return metrics

    def training_epoch_end(self, outputs):
        self.epoch_end(outputs, state='train')

    def validation_epoch_end(self, outputs):
        self.epoch_end(outputs, state='val')

    def configure_optimizers(self):
        """AdamW plus one of the 'cos', 'exp' or 'warmup_lin' schedulers from hparams."""
        if self.hparams.optimizer == 'AdamW':
            optimizer = AdamW(
                self.parameters(),
                lr=self.hparams.lr,
                weight_decay=self.hparams.weight_decay)
        else:
            raise NotImplementedError('Only AdamW is Supported!')

        if self.hparams.lr_scheduler == 'cos':
            scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=1, T_mult=2)
        elif self.hparams.lr_scheduler == 'exp':
            scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
        elif self.hparams.lr_scheduler == 'warmup_lin':
            steps_per_epoch = len(self.train_dataloader())
            if os.environ.get('COLAB_TPU_ADDR'):
                steps_per_epoch = steps_per_epoch // self.hparams.tpu_cores
            total_steps = steps_per_epoch * self.hparams.epochs
            warmup_steps = int(total_steps * self.hparams.warmup_ratio)
            scheduler = WarmupLinearLR(
                optimizer, warmup_steps=warmup_steps, max_steps=total_steps)
        else:
            raise NotImplementedError(
                'Only cos, exp, and warmup_lin lr scheduler is Supported!')
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return valid_dataloader

In [None]:
# Keep the five checkpoints with the lowest validation loss. The file name
# embeds val_loss, val_f1 and the epoch (metric names are written into the
# template by hand, hence auto_insert_metric_name=False).
checkpoint_callback = callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=5,
    filename='val_loss{val_loss:.4f}-val_f1{val_f1:.4f}-epoch{epoch}',
    auto_insert_metric_name=False,
)
# Record the learning rate per step rather than per epoch.
lr_callback = callbacks.LearningRateMonitor(logging_interval='step')

In [None]:
print('Using PyTorch Ver', torch.__version__)
print('Fix Seed:', args['random_seed'])
seed_everything(args['random_seed'])

model = QAModel(**args)

# GPU-dependent settings: single GPU (index 0), deterministic kernels and
# optional FP16 are only enabled when CUDA is available.
has_gpu = torch.cuda.is_available()
trainer = Trainer(
    max_epochs=args['epochs'],
    fast_dev_run=args['test_mode'],
    log_every_n_steps=16,  # Logging frequency of **learning rate**
    callbacks=[lr_callback, checkpoint_callback],
    # W&B logging is disabled in this cell; pass logger=wandb_logger to enable it.
    num_sanity_val_steps=None if args['test_mode'] else 0,
    deterministic=has_gpu,
    gpus=[0] if has_gpu else None,
    precision=16 if args['fp16'] and has_gpu else 32,
)

In [None]:
# Model compression (dynamic quantization)
def quantize_transform(model, dtype=torch.qint8):
    """Return a dynamically quantized copy of ``model`` (its ``nn.Linear`` layers).

    ``torch.quantization.quantize_dynamic`` is not in-place: ``model`` is left
    unchanged and the quantized copy is returned, so callers must use the
    return value. PyTorch's quantized kernels target CPU inference, so apply
    this after training.

    Args:
        model: module whose ``torch.nn.Linear`` submodules are quantized.
        dtype: quantized weight dtype (default ``torch.qint8``).
    """
    return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)


# The original cell called `quantize_transform(model)` here and discarded the
# result, so it never affected training. Quantize only after training, e.g.
# `quantized_model = quantize_transform(model.cpu())`.

In [None]:
# Fine-tune; QAModel supplies its own dataloaders, so only the model is passed.
print(':: Start Training ::')
trainer.fit(model=model)

In [None]:
def experiment():
    """Fine-tune ``QAModel`` on JaQuAD with the current global ``args`` and log to W&B.

    Rebuilds the tokenizer, the tokenized dataset and both dataloaders for
    ``args['pretrained_tokenizer']`` (falling back to ``args['pretrained_model']``),
    then trains for ``args['epochs']`` epochs. The W&B run is finished even
    if training raises, so the next experiment starts a fresh run.
    """
    # preprocess_function reads the module-level `tokenizer`, and
    # QAModel.train_dataloader()/val_dataloader() return the module-level
    # dataloaders. As plain locals these were ignored, so every experiment
    # silently reused the data tokenized by the notebook's first tokenizer.
    global tokenizer, train_dataloader, valid_dataloader

    # Was `wandb.login` (no call), which never logged in.
    wandb.login()
    wandb_logger = WandbLogger(project="JaQuad", config=args,
                               name=args["note"] + "_" + args["pretrained_model"])
    try:
        datasetdict = datasets.load_dataset(
            args['dataset'], use_auth_token=args['huggingface_auth_token'])
        datasetdict = datasetdict.flatten()\
            .rename_column('answers.text', 'answer')\
            .rename_column('answers.answer_start', 'answer_start')\
            .rename_column('answers.answer_type', 'answer_type')
        tokenizer = AutoTokenizer.from_pretrained(
            args["pretrained_tokenizer"]
            if args["pretrained_tokenizer"] else
            args["pretrained_model"],
            use_auth_token=args["huggingface_auth_token"])
        tokenized_dataset = datasetdict.map(
            preprocess_function,
            batched=True,
            remove_columns=datasetdict['train'].column_names)
        # Used for DistilBERT (no token_type_ids input).
        tokenized_dataset = tokenized_dataset.map(remove_columns=['token_type_ids'])
        tokenized_dataset.set_format(type='torch')
        train_dataloader = DataLoader(tokenized_dataset["train"], shuffle=True,
                                      batch_size=args["batch_size"],
                                      num_workers=args["cpu_workers"])
        valid_dataloader = DataLoader(tokenized_dataset['validation'], shuffle=False,
                                      batch_size=args["batch_size"],
                                      num_workers=args["cpu_workers"])
        checkpoint_callback = callbacks.ModelCheckpoint(
            filename='val_loss{val_loss:.4f}-val_f1{val_f1:.4f}-epoch{epoch}',
            monitor='val_loss',
            mode='min',
            save_top_k=5,
            auto_insert_metric_name=False,
        )
        lr_callback = callbacks.LearningRateMonitor(logging_interval='step')
        model = QAModel(**args)
        trainer = Trainer(
            callbacks=[lr_callback,
                       checkpoint_callback],
            log_every_n_steps=16,  # Logging frequency of **learning rate**
            max_epochs=args['epochs'],
            fast_dev_run=args['test_mode'],
            logger=wandb_logger,
            num_sanity_val_steps=None if args['test_mode'] else 0,
            # For GPU Setup
            deterministic=torch.cuda.is_available(),
            gpus=[0] if torch.cuda.is_available() else None,  # Use one GPU (idx 0)
            precision=16 if args['fp16'] and torch.cuda.is_available() else 32,
        )
        # A discarded `quantize_transform(model)` call used to sit here; it
        # returned a new model that was never used, so it had no effect.
        trainer.fit(model)
    finally:
        wandb.finish()

In [None]:
# Second run: the small Japanese ELECTRA generator for both model and tokenizer.
electra_name = 'Cinnamon/electra-small-japanese-generator'
args.update(pretrained_model=electra_name, pretrained_tokenizer=electra_name)
experiment()

## Inference

In [None]:
# Load the JaQuAD-fine-tuned BERT. The tokenizer must be reset as well: the
# ELECTRA cell above sets `pretrained_tokenizer`, and QAModel prefers it over
# `pretrained_model` whenever it is set, which would pair this BERT with the
# wrong vocabulary.
args['pretrained_model'] = 'SkelterLabsInc/bert-base-japanese-jaquad'
args['pretrained_tokenizer'] = 'SkelterLabsInc/bert-base-japanese-jaquad'
model = QAModel(**args)

## ゴミ箱