<a href="https://colab.research.google.com/github/respect5716/deep-learning-paper-implementation/blob/main/03_NLP/TinyBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TinyBERT

## 0. Introduction

### Paper
* title: TInyBERT: Distilling BERT for Natural Language Understanding
* authors: Xiaoqi Jiao et al.
* url: https://arxiv.org/abs/1909.10351

### Reference
* https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT

## 1. Setup

In [None]:
!pip install -q wandb transformers pytorch_lightning datasets

In [1]:
import os
import math
import wandb
import easydict
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from transformers import Trainer, TrainingArguments
from transformers import get_scheduler
from transformers import BatchEncoding
from transformers import DataCollatorForLanguageModeling
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, AutoModelForSequenceClassification

from datasets import load_metric, load_dataset, concatenate_datasets

In [2]:
config = easydict.EasyDict(

    data = {
        'datasets': ['namuwiki'],
        'data_dir': 'drive/Shareddrives/dataset',
        'pretrained_model_name_or_path': 'klue/bert-base',
        'batch_size': 4,
        'mlm_probability': 0.15,
        'max_seq_length': 512
    },

    teacher = {
        'model_name_or_path': 'klue/bert-base',
        'hidden_dropout_prob': 0.,
        'attention_probs_dropout_prob': 0.,
        'output_attentions': True,
        'output_hidden_states': True
    },

    student = {
        'num_hidden_layers': 6,
        'hidden_size': 384,
        'intermediate_size': 1536,
        'output_attentions': True,
        'output_hidden_states': True
    },

    optimizer = {
        'name': 'adamw',
        'lr': 6e-4,
        'betas': (0.9, 0.98),
        'weight_decay': 0.01,
    },

    scheduler = {
        'name': 'linear',
        'max_steps': 10000,
        'warmup_ratio': 0.05
    },

    distil = {
        'temperature': 2.,
        'alpha_hidden': 2.0, # hidden states mse loss
        'alpha_attn': 2.0,  # attn mse loss
        'alpha_pred': 0. # logits kl-div loss
    },

    trainer = {
        'gpus': -1,
        'log_every_n_steps': 10,
        'num_sanity_val_steps': 100,
        'val_check_interval': 1000,
        'limit_val_batches': 100,

        'max_steps': 10000,
        'accumulate_grad_batches': 4,
        'gradient_clip_val': 5.0,
        'precision': 32,
    }
)

## 2. Data

In [3]:
class DataModule(pl.LightningDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.pretrained_model_name_or_path)
        
    def setup(self, stage=None):
        dataset = []
        for dname in self.hparams.datasets:
            _dataset = load_dataset('text', data_files=os.path.join(self.hparams.data_dir, f'{dname}.txt'))['train']
            dataset.append(_dataset)

        self.dataset = concatenate_datasets(dataset)
        self.dataset.set_transform(lambda batch: transform(batch, self.tokenizer, self.hparams.max_seq_length))
        self.dataset = self.dataset.train_test_split(test_size=0.01)
        self.train_dataset, self.eval_dataset = self.dataset['train'], self.dataset['test']
        
        self.collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)


    def collate_fn(self, batch):
        batch = BatchEncoding(self.collator(batch))
        batch['attention_mask'] = batch.input_ids.ne(self.tokenizer.pad_token_id).float()
        return batch

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=self.collate_fn)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.eval_dataset, batch_size=self.hparams.batch_size, shuffle=False, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return self.val_dataloader()
    
    
def transform(batch, tokenizer, max_length):
    new_batch = []
    for text in batch['text']:
        text = slice_text(text)
        new_batch.append(text)
    
    return tokenizer(new_batch, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')


def slice_text(text, max_char_length=1024):
    if len(text) > max_char_length:
        idx = np.random.randint(low=0, high=len(text)-max_char_length)
        text = text[idx : idx+max_char_length]
    return text

## 3. Model

In [4]:
def bert_self_attention_forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    mixed_query_layer = self.query(hidden_states)
    mixed_key_layer = self.key(hidden_states)
    mixed_value_layer = self.value(hidden_states)

    self.q = mixed_query_layer # (bs, sq, dim)
    self.k = mixed_key_layer # (bs, sq, dim)
    self.v = mixed_value_layer # (bs, sq, dim)
    
    query_layer = self.transpose_for_scores(mixed_query_layer) # (bs, nh, sq, dim)
    key_layer = self.transpose_for_scores(mixed_key_layer) # (bs, nh, sq, dim)
    value_layer = self.transpose_for_scores(mixed_value_layer) # (bs, nh, sq, dim)
    
    if self.is_decoder:
        past_key_value = (key_layer, value_layer)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # (bs, nh, sq, sq)
    self.attn = attention_scores

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        seq_length = hidden_states.size()[1]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = nn.Softmax(dim=-1)(attention_scores)
    attention_probs = self.dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)
    self.o = context_layer

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs


def to_distill(model):
    model.base_model.encoder.layer[0].attention.self.__class__._forward = bert_self_attention_forward
    for layer in model.base_model.encoder.layer:
        layer.attention.self.forward = layer.attention.self._forward
    return model

def get_layer_mapper(student_num_layers, teacher_num_layers):
    share = teacher_num_layers // student_num_layers
    layer_mapper = {0:0}
    for s in range(1, student_num_layers+1):
        layer_mapper[s] = s * share
    return layer_mapper

def get_attns(model):
    attns = [l.attention.self.attn for l in model.base_model.encoder.layer]
    return attns

def kl_div_loss(s, t, temperature):
    if len(s.size()) != 2:
        s = s.view(-1, s.size(-1))
        t = t.view(-1, t.size(-1))

    s = F.log_softmax(s / temperature, dim=-1)
    t = F.softmax(t / temperature, dim=-1)
    return F.kl_div(s, t, reduction='batchmean') * (temperature ** 2)

In [5]:
optim_dict = {
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW
}

def prepare_optimizer(params, optimizer_hparams):
    name = optimizer_hparams['name']
    hparams = {k:v for k,v in optimizer_hparams.items() if k != 'name'}
    return optim_dict[name](params, **hparams)


def prepare_scheduler(optimizer, scheduler_hparams):
    num_training_steps = scheduler_hparams['max_steps']
    num_warmup_steps = int(num_training_steps * scheduler_hparams['warmup_ratio'])
    scheduler = get_scheduler(scheduler_hparams['name'], optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    return scheduler


def select_tensor(tensor, attention_mask):
    mask = attention_mask.unsqueeze(-1).expand_as(tensor).bool()
    selected = torch.masked_select(tensor, mask)  # (bs * seq_length * voc_size)
    selected = selected.view(-1, tensor.size(-1))  # (bs * seq_length, voc_size)
    return selected

In [6]:
class Model(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.teacher, self.student, self.tokenizer = self.prepare()
        self.layer_mapper = get_layer_mapper(self.student.config.num_hidden_layers, self.teacher.config.num_hidden_layers)
    
    def prepare(self):
        teacher_kwargs = {k:v for k, v in self.hparams.teacher.items() if k not in ['model_name_or_path']}
        teacher = AutoModelForMaskedLM.from_pretrained(
            self.hparams.teacher.model_name_or_path,
            **teacher_kwargs
        )
        for param in teacher.parameters():
            param.requires_grad = False

        config = AutoConfig.from_pretrained(
            self.hparams.teacher.model_name_or_path,
            **self.hparams.student
        )
        
        student = AutoModelForMaskedLM.from_config(config)
        student.upsampler = nn.ModuleList([nn.Linear(student.config.hidden_size, teacher.config.hidden_size) for _ in range(student.config.num_hidden_layers+1)])

        teacher = to_distill(teacher)
        student = to_distill(student)

        teacher.eval()
        student.train()

        tokenizer = AutoTokenizer.from_pretrained(self.hparams.teacher.model_name_or_path)
        return teacher, student, tokenizer
    
    
    def student_param_groups(self):
        no_decay = ["bias", "bn", "ln", "norm"]
        param_groups = [
            {
                # apply weight decay
                "params": [p for n, p in self.student.named_parameters() if not any(nd in n.lower() for nd in no_decay)],
                "weight_decay": self.hparams.optimizer.weight_decay
            },
            {
                # not apply weight decay
                "params": [p for n, p in self.student.named_parameters() if any(nd in n.lower() for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return param_groups


    def configure_optimizers(self):
        optimizer = prepare_optimizer(self.student_param_groups(), self.hparams.optimizer)
        scheduler = prepare_scheduler(optimizer, self.hparams.scheduler)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
            }
        }


    def training_step(self, batch, batch_idx):
        loss = 0.
        log = {}

        to = self.teacher(**batch)
        so = self.student(**batch)
        tattns = get_attns(self.teacher)
        sattns = get_attns(self.student)
        
        if self.hparams.distil.alpha_hidden > 0:
            hidden_loss = 0.
            for si, ti in self.layer_mapper.items():
                th, sh = to.hidden_states[ti], so.hidden_states[si]
                sh = self.student.upsampler[si](sh)
                hidden_loss += F.mse_loss(sh, th)
            loss += self.hparams.distil.alpha_hidden * hidden_loss
            log['train/hidden_loss'] = hidden_loss.item()

        if self.hparams.distil.alpha_attn > 0:
            attn_loss = 0.
            for ta, sa in zip(tattns, sattns):
                attn_loss += F.mse_loss(sa, ta) / (ta.size(0) * ta.size(1))
            loss += self.hparams.distil.alpha_attn * attn_loss
            log['train/attn_loss'] = attn_loss.item()

        if self.hparams.distil.alpha_pred > 0:
            pred_loss = kl_div_loss(so.logits, to.logits, self.hparams.distil.temperature)
            loss += self.hparams.distil.alpha_pred * pred_loss
            log['train/pred_loss'] = pred_loss.item()

        log['train/loss'] = loss
        self.log_dict(log, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        ckpt_dir = os.path.join('ckpt', f'{self.trainer.global_step:06d}')
        self.student.save_pretrained(ckpt_dir)
        self.tokenizer.save_pretrained(ckpt_dir)

## 4. Distillation

In [7]:
data_module = DataModule(**config.data)

In [8]:
model = Model(**config)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
logger = pl.loggers.WandbLogger(
    project = 'paper',
    log_model = False,
    reinit = True,
)

logger.watch(model, log='gradients')

[34m[1mwandb[0m: Currently logged in as: [33mrespect5716[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [10]:
lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')

In [11]:
trainer = pl.Trainer(    
    logger = logger,
    callbacks = [lr_callback],
    resume_from_checkpoint = False,
    **config.trainer
)

  "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [12]:
data_module.setup()

Using custom data configuration default-a5f702edb7742337
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-a5f702edb7742337/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)


  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
trainer.fit(model, data_module)

  "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type            | Params
--------------------------------------------
0 | teacher | BertForMaskedLM | 110 M 
1 | student | BertForMaskedLM | 25.4 M
--------------------------------------------
25.4 M    Trainable params
110 M     Non-trainable params
136 M     Total params
544.137   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [14]:
model.student.save_pretrained('transformers')
model.tokenizer.save_pretrained('transformers')

('transformers/tokenizer_config.json',
 'transformers/special_tokens_map.json',
 'transformers/vocab.txt',
 'transformers/added_tokens.json',
 'transformers/tokenizer.json')

## 5. Donwstream

In [15]:
def set_example(example):
    """example -> text_a, text_b, label
    """
    return {'text_a': example['premise'], 'text_b': example['hypothesis'], 'labels': example['label']}


def convert_example_to_feature(example, tokenizer, max_length):
    """text_a, text_b, label -> input_ids, attention_mask, token_type_ids, label
    """
    feature = tokenizer(
        example['text_a'], example['text_b'], 
        max_length = max_length, 
        padding = 'max_length', 
        truncation = True
    )
    return feature

In [16]:
ckpt_dir = 'transformers'
model = AutoModelForSequenceClassification.from_pretrained(ckpt_dir, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

Some weights of the model checkpoint at transformers were not used when initializing BertForSequenceClassification: ['upsampler.2.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'upsampler.3.bias', 'upsampler.4.bias', 'cls.predictions.decoder.weight', 'upsampler.3.weight', 'upsampler.5.bias', 'cls.predictions.transform.dense.weight', 'upsampler.2.weight', 'upsampler.1.bias', 'cls.predictions.bias', 'upsampler.6.bias', 'upsampler.5.weight', 'upsampler.4.weight', 'upsampler.0.bias', 'upsampler.1.weight', 'cls.predictions.transform.LayerNorm.bias', 'upsampler.0.weight', 'cls.predictions.decoder.bias', 'upsampler.6.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification fr

In [17]:
dataset = load_dataset('klue', 'nli')
dataset = dataset.map(set_example)
dataset = dataset.map(lambda example: convert_example_to_feature(example, tokenizer, 256))
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

Reusing dataset klue (/root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e/cache-d887b1d0663cd4a2.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e/cache-11c31ebdcbd1555a.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e/cache-cd8b36c12f24e315.arrow


  0%|          | 0/3000 [00:00<?, ?ex/s]

In [18]:
training_args = TrainingArguments(
    'training_args',
    num_train_epochs = 3,
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = dataset['train'],
    eval_dataset = dataset['validation'],
)

In [19]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: premise, hypothesis, text_b, guid, text_a, source. If premise, hypothesis, text_b, guid, text_a, source are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 24998
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 9375
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss
500,1.0781
1000,0.916
1500,0.8553
2000,0.8071
2500,0.7916
3000,0.7668
3500,0.6779
4000,0.6435
4500,0.6572
5000,0.6279


Saving model checkpoint to training_args/checkpoint-500
Configuration saved in training_args/checkpoint-500/config.json
Model weights saved in training_args/checkpoint-500/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-1000
Configuration saved in training_args/checkpoint-1000/config.json
Model weights saved in training_args/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-1500
Configuration saved in training_args/checkpoint-1500/config.json
Model weights saved in training_args/checkpoint-1500/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-2000
Configuration saved in training_args/checkpoint-2000/config.json
Model weights saved in training_args/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-2500
Configuration saved in training_args/checkpoint-2500/config.json
Model weights saved in training_args/checkpoint-2500/pytorch_model.bin
Saving model checkpoint to training_ar

TrainOutput(global_step=9375, training_loss=0.6687557552083333, metrics={'train_runtime': 414.0877, 'train_samples_per_second': 181.107, 'train_steps_per_second': 22.64, 'total_flos': 1243662713422848.0, 'train_loss': 0.6687557552083333, 'epoch': 3.0})

In [20]:
loader = torch.utils.data.DataLoader(dataset['validation'], batch_size=8, shuffle=False)

correct = []
for batch in tqdm(loader):
    batch = {k:v.cuda() for k, v in batch.items()}
    outputs = model(**batch)
    preds = outputs.logits.argmax(dim=1)
    _corrct = (batch['labels'] == preds).cpu()
    correct.append(_corrct)

acc = torch.cat(correct).float().mean() * 100
print(f'ACC: {acc:.2f}')

  0%|          | 0/375 [00:00<?, ?it/s]

ACC: 63.47
