#### Скачаем датасет `SWAG`

In [None]:
!git clone https://github.com/rowanz/swagaf.git
!mv swagaf/data/ ../datasets/SWAG
!rm -fr swagaf

In [1]:
import sys
%load_ext autoreload
%autoreload 2
sys.path.append('..')

import numpy as np
import random
import torch
import os
from pytorch_pretrained_bert.tokenization import BertTokenizer

from lib import data_processors, tasks
from pytorch_pretrained_bert import BertForMultipleChoice
from lib.train_eval import train, evaluate, predict

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:

# Run configuration; this dict is handed as-is to train() and evaluate().
params = dict(
    # Paths, relative to the notebook's working directory.
    data_dir='../datasets/SWAG',
    output_dir='../output',
    cache_dir='../model_cache',
    # Task key and pretrained checkpoint name.
    task_name='swag',
    bert_model='bert-base-uncased',
    # Sequence length and batch sizes.
    max_seq_length=128,
    train_batch_size=12,
    eval_batch_size=8,
    # Optimisation settings.
    learning_rate=2e-5,
    warmup_proportion=0.1,
    num_train_epochs=1,
    seed=1331,
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

# Seed the Python, NumPy and PyTorch random number generators with the same value.
# torch.manual_seed is kept last so the cell still displays its Generator.
np.random.seed(params['seed'])
random.seed(params['seed'])
torch.manual_seed(params['seed'])

<torch._C.Generator at 0x7fe090000730>

In [3]:
# Instantiate the data processor registered under the configured task name.
processor = tasks.processors[params['task_name']]()

# Load the tokenizer for the configured checkpoint. cache_dir is passed here
# as well, so the vocabulary is stored in the same configured cache directory
# as the model weights instead of the library's default cache location.
tokenizer = BertTokenizer.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    do_lower_case=True)

# Read the train and dev examples from the dataset directory.
train_examples = processor.get_train_examples(params['data_dir'])
dev_examples = processor.get_dev_examples(params['data_dir'])

# Load the pretrained multiple-choice model and move it to the selected device.
# NOTE(review): num_choices=4 presumably matches the number of candidate
# endings per SWAG example; confirm against the task processor.
model = BertForMultipleChoice.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'], num_choices=4).to(params['device'])

In [5]:
# Epoch label; in this cell it is only used to build the checkpoint file name.
EPOCH_NUM = 1

# Train for one epoch per execution of this cell.
params.update(num_train_epochs=1)

# File names handed to train() via the checkpoint_files argument.
# NOTE(review): the key 'model_weigths' is misspelled, but train() presumably
# looks it up under this exact spelling; rename only together with
# lib.train_eval.
checkpoint_files = dict(
    config='bert_config.json',
    model_weigths='model_{}_epoch_{}.pth'.format(
        params['task_name'], EPOCH_NUM)
)

# Run training; train() returns the model together with a result object.
model, result = train(
    model, tokenizer, params, train_examples,
    valid_examples=dev_examples,
    checkpoint_files=checkpoint_files
)


converting examples: 100%|██████████| 73546/73546 [01:18<00:00, 937.95it/s] 


***** Running training *****
Num examples: 73546
Batch size:   12
Num steps:    6128

Epoch: 1


Iteration:   3%|▎         | 202/6129 [00:33<17:18,  5.71it/s]

KeyboardInterrupt: 

In [None]:
# Evaluate the current model on the dev examples.
# NOTE(review): prob_preds presumably holds the predicted probabilities;
# confirm against lib.train_eval.evaluate.
(result, prob_preds) = evaluate(
    model, tokenizer, params, dev_examples)
# Bare expression so the notebook displays the evaluation result.
result

converting examples:   0%|          | 67/20006 [00:00<00:30, 660.73it/s]

***** Running evaluation *****
Num examples:  20006
Batch size:    8


converting examples: 100%|██████████| 20006/20006 [00:23<00:00, 844.60it/s]
Evaluating:  28%|██▊       | 702/2501 [01:44<04:29,  6.69it/s]