In [10]:
import sys
%load_ext autoreload
%autoreload 2
sys.path.append('..')

import numpy as np
import random
import torch

from pytorch_pretrained_bert.tokenization import BertTokenizer

from lib import data_processors, tasks
from lib.bert import BertForSequenceClassification
from lib.train_eval import train, evaluate, predict

from lib.train_student import eval_teacher_soft_targets, train_student

from pytorch_pretrained_bert.modeling import BertConfig
import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Скачиваем модель-учителя - Bert.

In [4]:
# %env CUDA_VISIBLE_DEVICES=1
params = {
    'data_dir': '../datasets/SST-2',
    'output_dir': '../output',
    'cache_dir': '../model_cache',
    'task_name': 'sst2',
    'bert_model': 'bert-base-uncased',
    'max_seq_length': 128,
    'train_batch_size': 32,
    'eval_batch_size': 8,
    'learning_rate': 2e-5,
    'warmup_proportion': 0.1,
    'num_train_epochs': 1,
    'seed': 1331,
    'device': torch.device(
        'cuda' if torch.cuda.is_available()
        else 'cpu')
}

random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

params['num_labels'] = tasks.num_labels[params['task_name']]
params['label_list'] = tasks.label_lists[params['task_name']]

processor = tasks.processors[params['task_name']]()
tokenizer = BertTokenizer.from_pretrained(
    params['bert_model'], do_lower_case=True)

train_examples = processor.get_train_examples(params['data_dir'])
dev_examples = processor.get_dev_examples(params['data_dir'])

checkpoint_files = {
    'config': 'bert_config.json',
    'model_weigths': 'model_{}_epoch_1.pth'.format(
        params['task_name'])
}

# Load a trained model and config that you have fine-tuned
config = BertConfig(os.path.join(params['output_dir'], checkpoint_files['config']))
teacher_model = BertForSequenceClassification(config, num_labels=params['num_labels'])

teacher_model.load_state_dict(torch.load(os.path.join(params['output_dir'], checkpoint_files['model_weigths'])))
teacher_model.to(params['device'])

05/14/2019 18:44:44 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/shakhrayv/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertInterme

В качестве студента возьмем, например, берта без первых 3 блоков и последних 3ех блоков. 

In [5]:
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels']).to(params['device'])

all_layers = model.bert.encoder.layer
model.bert.encoder.layer = all_layers[3:9]

05/14/2019 18:45:10 - INFO - lib.bert -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ../model_cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
05/14/2019 18:45:10 - INFO - lib.bert -   extracting archive file ../model_cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmp_p_e_foq
05/14/2019 18:45:12 - INFO - lib.bert -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

05/14/2019 18:45:14 - INFO - lib.bert -   Weights of BertForSequenceClassification not initi

Обучаем студента при помощи дистилляции. 

Функция train_student принимает те же параметры, что и обычный train, а также модель учителя teacher_model и название модели-студента для сохранения (необязательный параметр).

In [12]:
model, result = train_student(model, teacher_model, tokenizer, params, train_examples, 
                              valid_examples=dev_examples, name='bert_3_9_blocks', 
                              checkpoint_files=checkpoint_files)

***** Running training *****
Num examples: 33674
Batch size:   32
Num steps:    1052


Evaluating: 100%|██████████| 1053/1053 [01:36<00:00, 10.80it/s]


Epoch: 1


HBox(children=(IntProgress(value=0, description='Iteration', max=1053, style=ProgressStyle(description_width='…

  loss_first = KLDivLoss()(F.log_softmax(logits_model / temperature), F.softmax(teacher_logits / temperature))
  loss_first = KLDivLoss()(F.log_softmax(logits_model / temperature), F.softmax(teacher_logits / temperature))
Evaluating:   0%|          | 0/109 [00:00<?, ?it/s]


{'train_loss': 0.28954526175747486, 'train_global_step': 1053}
***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:01<00:00, 70.76it/s]


{'eval_loss': 0.28102876908287894, 'eval_accuracy': 0.8979357798165137, 'eval_f1_score': 0.8996617812852311, 'eval_matthews_corrcoef': 0.7958135025496113}


Данный код также сначала посчитает soft-logits для модели учителя. При каждом запуске он будет их считать заново. Поэтому если не хочется перезапускать обучение несколько раз, то лучше сделать так:

In [None]:
# compute soft logits from teacher
all_logits_teacher = eval_teacher_soft_targets(teacher_model, tokenizer, params, train_examples)

# train student via distillation
model, result = train_student(model, teacher_model, tokenizer, params, train_examples, 
                              valid_examples=dev_examples, name='bert_3_9_blocks', 
                              checkpoint_files=checkpoint_files, 
                              all_logits_teacher=all_logits_teacher)