In [1]:
import evaluate
from datasets import DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForMultipleChoice,
    TrainingArguments,
    Trainer,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the C3 dataset (Chinese multiple-choice reading comprehension) saved locally as a DatasetDict.
c3 = DatasetDict.load_from_disk("./c3/")
c3

DatasetDict({
    test: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer'],
        num_rows: 1625
    })
    train: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer'],
        num_rows: 11869
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer'],
        num_rows: 3816
    })
})

In [3]:
# Inspect one raw training example: context lines, question, candidate choices and answer.
c3["train"][0]

{'id': 0,
 'context': ['男：你今天晚上有时间吗?我们一起去看电影吧?', '女：你喜欢恐怖片和爱情片，但是我喜欢喜剧片，科幻片一般。所以……'],
 'question': '女的最喜欢哪种电影?',
 'choice': ['恐怖片', '爱情片', '喜剧片', '科幻片'],
 'answer': '喜剧片'}

In [4]:
# Set the test split aside: only train/validation are used below.
# The membership guard keeps this cell re-runnable -- a bare c3.pop("test")
# raises KeyError the second time it is executed.
if "test" in c3:
    c3_test = c3.pop("test")
c3_test

Dataset({
    features: ['id', 'context', 'question', 'choice', 'answer'],
    num_rows: 1625
})

In [5]:
# Display the remaining splits: after the test split is popped above, only train and validation are left.
c3

DatasetDict({
    train: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer'],
        num_rows: 11869
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer'],
        num_rows: 3816
    })
})

In [6]:
# Load the tokenizer that belongs to the local chinese-macbert-base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="../chinese-macbert-base/"
)
tokenizer

BertTokenizerFast(name_or_path='../chinese-macbert-base/', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

In [7]:
# Peek at the raw dialogue lines of the second training example.
# Indexing the row first (instead of c3["train"]["context"][1]) avoids
# materialising the entire "context" column just to read a single entry.
c3["train"][1]["context"]

['男：足球比赛是明天上午八点开始吧?', '女：因为天气不好，比赛改到后天下午三点了。']

In [8]:
def process_function(examples, num_choices=4, max_length=256, tok=None):
    """Tokenize a batch of C3 examples for multiple-choice training.

    Each example is expanded into ``num_choices`` (context, "question choice")
    pairs. Examples with fewer candidates are padded with the placeholder
    answer "不知道" ("don't know") so every example yields the same number of
    rows, which is what allows the flat tokenizer output to be regrouped into
    one block of ``num_choices`` encodings per example.

    Args:
        examples: A batched slice of the dataset, as passed by
            ``Dataset.map(..., batched=True)``, with list-valued keys
            "context" (list of text lines per example), "question",
            "choice" (list of candidate answers) and "answer".
        num_choices: Number of candidates per example after padding.
        max_length: Token length every pair is truncated/padded to.
        tok: Tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        A dict mapping each tokenizer output key to a list (one entry per
        example) of ``num_choices`` encodings, plus "labels": the index of the
        correct answer within each example's original choice list.

    Raises:
        ValueError: if an example has more than ``num_choices`` candidates
            (the regrouping would otherwise silently misalign examples), or if
            an answer is not among its choices (raised by ``list.index``).
    """
    if tok is None:
        tok = tokenizer  # tokenizer loaded in an earlier notebook cell
    contexts = []
    question_choices = []
    labels = []
    for ctx_lines, question, choices, answer in zip(
        examples["context"], examples["question"], examples["choice"], examples["answer"]
    ):
        if len(choices) > num_choices:
            raise ValueError(
                f"example has {len(choices)} choices but num_choices={num_choices}"
            )
        ctx = "\n".join(ctx_lines)
        # Pad short candidate lists so every example contributes exactly num_choices rows.
        padded_choices = list(choices) + ["不知道"] * (num_choices - len(choices))
        for choice in padded_choices:
            contexts.append(ctx)
            question_choices.append(question + " " + choice)
        labels.append(choices.index(answer))
    tokenized = tok(
        contexts,
        question_choices,
        truncation="only_first",  # only the (long) context is ever truncated
        max_length=max_length,
        padding="max_length",
    )
    # Regroup the flat per-pair outputs into one list of num_choices encodings per example.
    grouped = {
        key: [values[i : i + num_choices] for i in range(0, len(values), num_choices)]
        for key, values in tokenized.items()
    }
    grouped["labels"] = labels
    return grouped

In [9]:
# Tokenize every split; process_function consumes whole batches of examples at once.
tokenized_c3 = c3.map(function=process_function, batched=True)
tokenized_c3

Map:   0%|          | 0/3816 [00:00<?, ? examples/s]

Map: 100%|██████████| 3816/3816 [00:02<00:00, 1861.69 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 11869
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3816
    })
})

In [10]:
# Load the same local checkpoint under a multiple-choice head; the warning
# printed below lists unused pre-training weights and newly initialised ones.
model = AutoModelForMultipleChoice.from_pretrained(
    pretrained_model_name_or_path="../chinese-macbert-base/"
)

Some weights of the model checkpoint at ../chinese-macbert-base/ were not used when initializing BertForMultipleChoice: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkp

In [11]:
import numpy as np

# Accuracy metric used by compute_metric during evaluation.
accuracy = evaluate.load(path="accuracy")


def compute_metric(pred):
    """Compute accuracy for ``Trainer(compute_metrics=...)``.

    ``pred`` unpacks into (logits, labels); the predicted choice is the
    argmax over the last axis of the logits.
    """
    logits, references = pred
    predicted_ids = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predicted_ids, references=references)

In [12]:
# NOTE(review): "mutiple" looks like a typo, but it is the checkpoint directory,
# so it is left unchanged.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version before upgrading.
args = TrainingArguments(
    output_dir="./mutiple_choice",
    num_train_epochs=1,                # number of training epochs
    # --- memory-saving setup ---
    per_device_train_batch_size=1,     # training batch size per device
    gradient_accumulation_steps=16,    # accumulate gradients: effective train batch = 1 * 16 per device
    gradient_checkpointing=True,       # recompute activations during backward to save memory
    optim="adafactor",                 # Adafactor optimizer instead of the default
    fp16=True,                         # mixed-precision training
    per_device_eval_batch_size=1,      # evaluation batch size per device
    # --- logging / evaluation / checkpointing ---
    logging_steps=10,                  # log every 10 optimizer steps
    evaluation_strategy="epoch",       # evaluate at the end of each epoch
    save_strategy="epoch",             # checkpoint at the end of each epoch
    save_total_limit=3,                # keep at most 3 checkpoints
    load_best_model_at_end=True,
)

In [13]:
# Wire model, training configuration, tokenized splits and the accuracy metric together.
trainer = Trainer(
    model=model,
    args=args,
    compute_metrics=compute_metric,
    train_dataset=tokenized_c3["train"],
    eval_dataset=tokenized_c3["validation"],
)

In [14]:
import torch

# Release cached, unused GPU memory before the memory-constrained training run.
# `hasattr(torch.cuda, "empty_cache")` is always true and so guarded nothing;
# the meaningful condition is whether a CUDA device is actually available.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
trainer.train()

  1%|▏         | 10/741 [00:15<17:04,  1.40s/it]

{'loss': 1.3899, 'learning_rate': 4.939271255060729e-05, 'epoch': 0.01}


  3%|▎         | 20/741 [00:29<17:01,  1.42s/it]

{'loss': 1.4213, 'learning_rate': 4.878542510121458e-05, 'epoch': 0.03}


  4%|▍         | 30/741 [00:44<16:45,  1.41s/it]

{'loss': 1.4022, 'learning_rate': 4.8110661268556004e-05, 'epoch': 0.04}


  5%|▌         | 40/741 [00:58<16:23,  1.40s/it]

{'loss': 1.4053, 'learning_rate': 4.7435897435897435e-05, 'epoch': 0.05}


  7%|▋         | 50/741 [01:12<16:10,  1.40s/it]

{'loss': 1.3795, 'learning_rate': 4.676113360323887e-05, 'epoch': 0.07}


  8%|▊         | 60/741 [01:27<16:45,  1.48s/it]

{'loss': 1.4011, 'learning_rate': 4.6086369770580304e-05, 'epoch': 0.08}


  9%|▉         | 70/741 [01:42<16:30,  1.48s/it]

{'loss': 1.406, 'learning_rate': 4.541160593792173e-05, 'epoch': 0.09}


 11%|█         | 80/741 [01:57<16:20,  1.48s/it]

{'loss': 1.3972, 'learning_rate': 4.473684210526316e-05, 'epoch': 0.11}


 12%|█▏        | 90/741 [02:12<15:24,  1.42s/it]

{'loss': 1.4368, 'learning_rate': 4.406207827260459e-05, 'epoch': 0.12}


 13%|█▎        | 100/741 [02:29<17:37,  1.65s/it]

{'loss': 1.3938, 'learning_rate': 4.338731443994602e-05, 'epoch': 0.13}


 15%|█▍        | 110/741 [02:45<16:48,  1.60s/it]

{'loss': 1.3895, 'learning_rate': 4.271255060728745e-05, 'epoch': 0.15}


 16%|█▌        | 120/741 [03:00<16:01,  1.55s/it]

{'loss': 1.3825, 'learning_rate': 4.2037786774628884e-05, 'epoch': 0.16}


 18%|█▊        | 130/741 [03:16<15:48,  1.55s/it]

{'loss': 1.4182, 'learning_rate': 4.1363022941970315e-05, 'epoch': 0.18}


 19%|█▉        | 140/741 [03:32<16:05,  1.61s/it]

{'loss': 1.4153, 'learning_rate': 4.0688259109311746e-05, 'epoch': 0.19}


 20%|██        | 150/741 [03:48<16:01,  1.63s/it]

{'loss': 1.4036, 'learning_rate': 4.001349527665318e-05, 'epoch': 0.2}


 22%|██▏       | 160/741 [04:05<15:56,  1.65s/it]

{'loss': 1.3994, 'learning_rate': 3.93387314439946e-05, 'epoch': 0.22}


 23%|██▎       | 170/741 [04:21<15:37,  1.64s/it]

{'loss': 1.3929, 'learning_rate': 3.866396761133603e-05, 'epoch': 0.23}


 24%|██▍       | 180/741 [04:38<15:17,  1.64s/it]

{'loss': 1.4061, 'learning_rate': 3.798920377867746e-05, 'epoch': 0.24}


 26%|██▌       | 190/741 [04:54<14:56,  1.63s/it]

{'loss': 1.3875, 'learning_rate': 3.7314439946018894e-05, 'epoch': 0.26}


 27%|██▋       | 200/741 [05:11<14:43,  1.63s/it]

{'loss': 1.4207, 'learning_rate': 3.6639676113360325e-05, 'epoch': 0.27}


 28%|██▊       | 210/741 [05:27<14:28,  1.64s/it]

{'loss': 1.3836, 'learning_rate': 3.5964912280701756e-05, 'epoch': 0.28}


 30%|██▉       | 220/741 [05:43<13:38,  1.57s/it]

{'loss': 1.3892, 'learning_rate': 3.529014844804319e-05, 'epoch': 0.3}


 31%|███       | 230/741 [05:59<13:13,  1.55s/it]

{'loss': 1.3737, 'learning_rate': 3.461538461538462e-05, 'epoch': 0.31}


 32%|███▏      | 240/741 [06:15<12:50,  1.54s/it]

{'loss': 1.393, 'learning_rate': 3.394062078272604e-05, 'epoch': 0.32}


 34%|███▎      | 250/741 [06:29<11:57,  1.46s/it]

{'loss': 1.4013, 'learning_rate': 3.3265856950067474e-05, 'epoch': 0.34}


 35%|███▌      | 260/741 [06:44<11:36,  1.45s/it]

{'loss': 1.391, 'learning_rate': 3.259109311740891e-05, 'epoch': 0.35}


 36%|███▋      | 270/741 [06:59<11:22,  1.45s/it]

{'loss': 1.3857, 'learning_rate': 3.191632928475034e-05, 'epoch': 0.36}


 38%|███▊      | 280/741 [07:14<11:10,  1.45s/it]

{'loss': 1.3952, 'learning_rate': 3.124156545209177e-05, 'epoch': 0.38}


 39%|███▉      | 290/741 [07:28<10:54,  1.45s/it]

{'loss': 1.3977, 'learning_rate': 3.05668016194332e-05, 'epoch': 0.39}


 40%|████      | 300/741 [07:43<10:40,  1.45s/it]

{'loss': 1.3995, 'learning_rate': 2.989203778677463e-05, 'epoch': 0.4}


 42%|████▏     | 310/741 [07:58<10:25,  1.45s/it]

{'loss': 1.4016, 'learning_rate': 2.921727395411606e-05, 'epoch': 0.42}


 43%|████▎     | 320/741 [08:12<10:11,  1.45s/it]

{'loss': 1.41, 'learning_rate': 2.8542510121457488e-05, 'epoch': 0.43}


 45%|████▍     | 330/741 [08:27<09:59,  1.46s/it]

{'loss': 1.3738, 'learning_rate': 2.7867746288798923e-05, 'epoch': 0.44}


 46%|████▌     | 340/741 [08:42<09:45,  1.46s/it]

{'loss': 1.3874, 'learning_rate': 2.7192982456140354e-05, 'epoch': 0.46}


 47%|████▋     | 350/741 [08:57<10:02,  1.54s/it]

{'loss': 1.4129, 'learning_rate': 2.6518218623481785e-05, 'epoch': 0.47}


 49%|████▊     | 360/741 [09:13<09:49,  1.55s/it]

{'loss': 1.4121, 'learning_rate': 2.5843454790823212e-05, 'epoch': 0.49}


 50%|████▉     | 370/741 [09:28<09:22,  1.52s/it]

{'loss': 1.3979, 'learning_rate': 2.5168690958164643e-05, 'epoch': 0.5}


 51%|█████▏    | 380/741 [09:44<09:47,  1.63s/it]

{'loss': 1.4038, 'learning_rate': 2.4493927125506075e-05, 'epoch': 0.51}


 53%|█████▎    | 390/741 [09:59<08:24,  1.44s/it]

{'loss': 1.3871, 'learning_rate': 2.3819163292847506e-05, 'epoch': 0.53}


 54%|█████▍    | 400/741 [10:14<08:13,  1.45s/it]

{'loss': 1.3813, 'learning_rate': 2.3144399460188933e-05, 'epoch': 0.54}


 55%|█████▌    | 410/741 [10:28<08:02,  1.46s/it]

{'loss': 1.3986, 'learning_rate': 2.2469635627530368e-05, 'epoch': 0.55}


 57%|█████▋    | 420/741 [10:43<07:36,  1.42s/it]

{'loss': 1.4083, 'learning_rate': 2.1794871794871795e-05, 'epoch': 0.57}


 58%|█████▊    | 430/741 [10:57<07:29,  1.45s/it]

{'loss': 1.3766, 'learning_rate': 2.1120107962213226e-05, 'epoch': 0.58}


 59%|█████▉    | 440/741 [11:12<07:17,  1.45s/it]

{'loss': 1.4168, 'learning_rate': 2.0445344129554654e-05, 'epoch': 0.59}


 61%|██████    | 450/741 [11:26<07:02,  1.45s/it]

{'loss': 1.3942, 'learning_rate': 1.977058029689609e-05, 'epoch': 0.61}


 62%|██████▏   | 460/741 [11:41<06:38,  1.42s/it]

{'loss': 1.3758, 'learning_rate': 1.9095816464237516e-05, 'epoch': 0.62}


 63%|██████▎   | 470/741 [11:55<06:31,  1.44s/it]

{'loss': 1.3516, 'learning_rate': 1.8421052631578947e-05, 'epoch': 0.63}


 65%|██████▍   | 480/741 [12:10<06:19,  1.45s/it]

{'loss': 1.3174, 'learning_rate': 1.774628879892038e-05, 'epoch': 0.65}


 66%|██████▌   | 490/741 [12:24<06:04,  1.45s/it]

{'loss': 1.3461, 'learning_rate': 1.707152496626181e-05, 'epoch': 0.66}


 67%|██████▋   | 500/741 [12:39<05:43,  1.43s/it]

{'loss': 1.317, 'learning_rate': 1.639676113360324e-05, 'epoch': 0.67}


 69%|██████▉   | 510/741 [12:53<05:37,  1.46s/it]

{'loss': 1.3299, 'learning_rate': 1.572199730094467e-05, 'epoch': 0.69}


 70%|███████   | 520/741 [13:08<05:17,  1.44s/it]

{'loss': 1.3873, 'learning_rate': 1.5114709851551959e-05, 'epoch': 0.7}


 72%|███████▏  | 530/741 [13:22<05:04,  1.45s/it]

{'loss': 1.3392, 'learning_rate': 1.4507422402159246e-05, 'epoch': 0.71}


 73%|███████▎  | 540/741 [13:37<04:46,  1.42s/it]

{'loss': 1.3472, 'learning_rate': 1.3832658569500675e-05, 'epoch': 0.73}


 74%|███████▍  | 550/741 [13:51<04:37,  1.46s/it]

{'loss': 1.3664, 'learning_rate': 1.3157894736842106e-05, 'epoch': 0.74}


 76%|███████▌  | 560/741 [14:06<04:22,  1.45s/it]

{'loss': 1.4005, 'learning_rate': 1.2483130904183535e-05, 'epoch': 0.75}


 77%|███████▋  | 570/741 [14:20<04:01,  1.41s/it]

{'loss': 1.3573, 'learning_rate': 1.1808367071524966e-05, 'epoch': 0.77}


 78%|███████▊  | 580/741 [14:34<03:48,  1.42s/it]

{'loss': 1.3314, 'learning_rate': 1.1133603238866398e-05, 'epoch': 0.78}


 80%|███████▉  | 590/741 [14:49<03:37,  1.44s/it]

{'loss': 1.319, 'learning_rate': 1.0458839406207829e-05, 'epoch': 0.8}


 81%|████████  | 600/741 [15:03<03:19,  1.41s/it]

{'loss': 1.2926, 'learning_rate': 9.784075573549258e-06, 'epoch': 0.81}


 82%|████████▏ | 610/741 [15:17<03:05,  1.42s/it]

{'loss': 1.3309, 'learning_rate': 9.109311740890689e-06, 'epoch': 0.82}


 84%|████████▎ | 620/741 [15:32<02:55,  1.45s/it]

{'loss': 1.2818, 'learning_rate': 8.43454790823212e-06, 'epoch': 0.84}


 85%|████████▌ | 630/741 [15:46<02:40,  1.45s/it]

{'loss': 1.3321, 'learning_rate': 7.75978407557355e-06, 'epoch': 0.85}


 86%|████████▋ | 640/741 [16:01<02:22,  1.41s/it]

{'loss': 1.3479, 'learning_rate': 7.0850202429149805e-06, 'epoch': 0.86}


 88%|████████▊ | 650/741 [16:15<02:09,  1.42s/it]

{'loss': 1.3946, 'learning_rate': 6.41025641025641e-06, 'epoch': 0.88}


 89%|████████▉ | 660/741 [16:29<01:54,  1.41s/it]

{'loss': 1.3247, 'learning_rate': 5.735492577597841e-06, 'epoch': 0.89}


 90%|█████████ | 670/741 [16:44<01:40,  1.41s/it]

{'loss': 1.3778, 'learning_rate': 5.060728744939271e-06, 'epoch': 0.9}


 92%|█████████▏| 680/741 [16:58<01:26,  1.43s/it]

{'loss': 1.3408, 'learning_rate': 4.3859649122807014e-06, 'epoch': 0.92}


 93%|█████████▎| 690/741 [17:13<01:14,  1.46s/it]

{'loss': 1.3401, 'learning_rate': 3.711201079622133e-06, 'epoch': 0.93}


 94%|█████████▍| 700/741 [17:27<00:58,  1.43s/it]

{'loss': 1.3509, 'learning_rate': 3.0364372469635627e-06, 'epoch': 0.94}


 96%|█████████▌| 710/741 [17:42<00:43,  1.42s/it]

{'loss': 1.333, 'learning_rate': 2.3616734143049934e-06, 'epoch': 0.96}


 97%|█████████▋| 720/741 [17:56<00:29,  1.42s/it]

{'loss': 1.3376, 'learning_rate': 1.6869095816464238e-06, 'epoch': 0.97}


 99%|█████████▊| 730/741 [18:10<00:15,  1.41s/it]

{'loss': 1.3089, 'learning_rate': 1.0121457489878542e-06, 'epoch': 0.98}


100%|█████████▉| 740/741 [18:25<00:01,  1.44s/it]

{'loss': 1.3184, 'learning_rate': 3.373819163292848e-07, 'epoch': 1.0}


100%|██████████| 741/741 [18:26<00:00,  1.45s/it]
100%|██████████| 741/741 [19:57<00:00,  1.45s/it]  

{'eval_loss': 1.3249152898788452, 'eval_accuracy': 0.30660377358490565, 'eval_runtime': 90.3156, 'eval_samples_per_second': 42.252, 'eval_steps_per_second': 42.252, 'epoch': 1.0}


100%|██████████| 741/741 [19:57<00:00,  1.62s/it]

{'train_runtime': 1197.581, 'train_samples_per_second': 9.911, 'train_steps_per_second': 0.619, 'train_loss': 1.375907692027639, 'epoch': 1.0}





TrainOutput(global_step=741, training_loss=1.375907692027639, metrics={'train_runtime': 1197.581, 'train_samples_per_second': 9.911, 'train_steps_per_second': 0.619, 'train_loss': 1.375907692027639, 'epoch': 1.0})

In [16]:
from typing import Any


class MultipleChoicePipeline:
    """Minimal inference wrapper around a fine-tuned multiple-choice model.

    Calling the pipeline with a context, a question and a list of candidate
    answers returns the candidate the model scores highest.
    """

    def __init__(self, model, tokenizer) -> None:
        """Store the model/tokenizer pair and remember the model's device."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device

    def preprocess(self, context, question, choices, max_length=256):
        """Tokenize one (context, "question choice") pair per candidate.

        Uses ``self.tokenizer`` (not a notebook global). ``padding=True`` pads
        all pairs to a common length; without it the tokenizer cannot stack
        candidates of different token lengths into a single tensor.
        """
        contexts = [context] * len(choices)
        question_choices = [question + " " + choice for choice in choices]
        return self.tokenizer(
            contexts,
            question_choices,
            truncation="only_first",
            max_length=max_length,
            padding=True,
            return_tensors="pt",
        )

    def predict(self, inputs):
        """Run one forward pass over all candidates and return the logits."""
        # Add a leading batch dimension of size 1 so all candidates form one
        # multiple-choice example (presumably (num_choices, seq_len) ->
        # (1, num_choices, seq_len)).
        batch = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}
        # Inference only: disable dropout and skip autograd bookkeeping.
        self.model.eval()
        with torch.inference_mode():
            return self.model(**batch).logits

    def postprocess(self, logits, choices):
        """Map the highest-scoring logit back to its candidate string."""
        prediction = torch.argmax(logits, dim=-1).cpu().item()
        return choices[prediction]

    def __call__(self, context, question, choices) -> Any:
        """Return the element of ``choices`` the model considers most likely.

        Raises:
            ValueError: if ``choices`` is empty.
        """
        if not choices:
            raise ValueError("choices must contain at least one candidate")
        inputs = self.preprocess(context, question, choices)
        logits = self.predict(inputs)
        return self.postprocess(logits, choices)

In [17]:
# Wrap the fine-tuned model and its tokenizer for single-question inference.
pipe = MultipleChoicePipeline(model=model, tokenizer=tokenizer)

In [18]:
# Smoke test: the correct answer ("北京") appears verbatim in the context.
# NOTE(review): "河北" and "海南" are listed twice -- likely a copy-paste slip; confirm intent.
pipe(
    context="小明在北京上班",
    question="小明在哪里上班？",
    choices=["北京", "上海", "河北", "海南", "河北", "海南"],
)

'北京'