<a href="https://colab.research.google.com/github/respect5716/deep-learning-paper-implementation/blob/main/03_NLP/MiniLMv2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MiniLMv2

## 0. Introduction

### Paper
* title: MiniLMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers
* author: Wenhui Wang et al.
* url: https://arxiv.org/abs/2012.15828

### Reference
* https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/model_compression/minilmv2

## 1. Setup

In [1]:
!pip install -q wandb transformers pytorch_lightning datasets

[K     |████████████████████████████████| 1.7 MB 15.0 MB/s 
[K     |████████████████████████████████| 3.3 MB 55.5 MB/s 
[K     |████████████████████████████████| 525 kB 70.0 MB/s 
[K     |████████████████████████████████| 298 kB 58.7 MB/s 
[K     |████████████████████████████████| 97 kB 8.4 MB/s 
[K     |████████████████████████████████| 140 kB 66.1 MB/s 
[K     |████████████████████████████████| 180 kB 68.6 MB/s 
[K     |████████████████████████████████| 63 kB 2.1 MB/s 
[K     |████████████████████████████████| 596 kB 62.0 MB/s 
[K     |████████████████████████████████| 895 kB 57.1 MB/s 
[K     |████████████████████████████████| 3.3 MB 56.6 MB/s 
[K     |████████████████████████████████| 61 kB 629 kB/s 
[K     |████████████████████████████████| 332 kB 65.6 MB/s 
[K     |████████████████████████████████| 829 kB 57.2 MB/s 
[K     |████████████████████████████████| 132 kB 64.6 MB/s 
[K     |████████████████████████████████| 1.1 MB 51.4 MB/s 
[K     |█████████████████████

In [1]:
import os
import math
import wandb
import easydict
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from transformers import Trainer, TrainingArguments
from transformers import get_scheduler
from transformers import BatchEncoding
from transformers import DataCollatorForWholeWordMask
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification

from datasets import load_metric, load_dataset, concatenate_datasets

In [2]:
config = easydict.EasyDict(

    data = {
        'datasets': ['namuwiki'],
        'data_dir': 'drive/Shareddrives/dataset',
        'pretrained_model_name_or_path': 'klue/bert-base',
        'batch_size': 4,
        'mlm_probability': 0.15,
        'max_seq_length': 512
    },

    teacher = {
        'model_name_or_path': 'klue/bert-base',
        'hidden_dropout_prob': 0.,
        'attention_probs_dropout_prob': 0.,
        'output_attentions': True,
        'output_hidden_states': True
    },

    student = {
        'num_hidden_layers': 3,
        'hidden_size': 384,
        'hidden_dropout_prob': 0.,
        'attention_probs_dropout_prob': 0.,
        'output_attentions': True,
        'output_hidden_states': True
    },

    optimizer = {
        'name': 'adamw',
        'lr': 6e-4,
        'betas': (0.9, 0.98),
        'weight_decay': 0.01,
    },

    scheduler = {
        'name': 'linear',
        'max_steps': 10000,
        'warmup_ratio': 0.05
    },

    distil = {
        'temperature': 2.,
        'teacher_layer_index': -1,
        'student_layer_index': -1,
        'num_relation_heads': 48,
    },

    trainer = {
        'gpus': -1,
        'log_every_n_steps': 10,
        'num_sanity_val_steps': 100,
        'val_check_interval': 1000,
        'limit_val_batches': 100,

        'max_steps': 10000,
        'accumulate_grad_batches': 4,
        'gradient_clip_val': 5.0,
        'precision': 32,
    }
)

## 2. Data

In [3]:
class DataModule(pl.LightningDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.pretrained_model_name_or_path)
        
    def setup(self, stage=None):
        dataset = []
        for dname in self.hparams.datasets:
            _dataset = load_dataset('text', data_files=os.path.join(self.hparams.data_dir, f'{dname}.txt'))['train']
            dataset.append(_dataset)

        self.dataset = concatenate_datasets(dataset)
        self.dataset.set_transform(lambda batch: transform(batch, self.tokenizer, self.hparams.max_seq_length))
        self.dataset = self.dataset.train_test_split(test_size=0.01)
        self.train_dataset, self.eval_dataset = self.dataset['train'], self.dataset['test']

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=collate_fn)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.eval_dataset, batch_size=self.hparams.batch_size, shuffle=False, collate_fn=collate_fn)

    def test_dataloader(self):
        return self.val_dataloader()
    
    
def transform(batch, tokenizer, max_length):
    new_batch = []
    for text in batch['text']:
        text = slice_text(text)
        new_batch.append(text)
    
    return tokenizer(new_batch, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')


def slice_text(text, max_char_length=1024):
    if len(text) > max_char_length:
        idx = np.random.randint(low=0, high=len(text)-max_char_length)
        text = text[idx : idx+max_char_length]
    return text


def collate_fn(batch):
    new_batch = {}
    
    keys = batch[0].keys()
    for k in keys:
        v = torch.stack([b[k] for b in batch], dim=0)
        new_batch[k] = v
    
    return BatchEncoding(new_batch)

## 3. Model

In [4]:
class BaseModel(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.teacher, self.student, self.tokenizer = self.prepare()
    
    def prepare(self):
        teacher_kwargs = {k:v for k, v in self.hparams.teacher.items() if k not in ['model_name_or_path']}
        teacher = AutoModel.from_pretrained(
            self.hparams.teacher.model_name_or_path,
            **teacher_kwargs
        )
        
        self.hparams.student.model_name_or_path = self.hparams.teacher.model_name_or_path        
        student_kwargs = {k:v for k, v in self.hparams.student.items() if k not in ['model_name_or_path']}
        config = AutoConfig.from_pretrained(
            self.hparams.student.model_name_or_path,
            **student_kwargs
        )
        
        student = AutoModel.from_config(config)
   
        for param in teacher.parameters():
            param.requires_grad = False
        
        tokenizer = AutoTokenizer.from_pretrained(self.hparams.teacher.model_name_or_path)
        return teacher, student, tokenizer
    
    
    def student_param_groups(self):
        no_decay = ["bias", "bn", "ln", "norm"]
        param_groups = [
            {
                # apply weight decay
                "params": [p for n, p in self.student.named_parameters() if not any(nd in n.lower() for nd in no_decay)],
                "weight_decay": self.hparams.optimizer.weight_decay
            },
            {
                # not apply weight decay
                "params": [p for n, p in self.student.named_parameters() if any(nd in n.lower() for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return param_groups


    def configure_optimizers(self):
        optimizer = prepare_optimizer(self.student_param_groups(), self.hparams.optimizer)
        scheduler = prepare_scheduler(optimizer, self.hparams.scheduler)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
            }
        }


    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        ckpt_dir = f'ckpt/{self.trainer.global_step:06d}'
        self.student.save_pretrained(ckpt_dir)
        self.tokenizer.save_pretrained(ckpt_dir)


optim_dict = {
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW
}

def prepare_optimizer(params, optimizer_hparams):
    name = optimizer_hparams['name']
    hparams = {k:v for k,v in optimizer_hparams.items() if k != 'name'}
    return optim_dict[name](params, **hparams)


def prepare_scheduler(optimizer, scheduler_hparams):
    num_training_steps = scheduler_hparams['max_steps']
    num_warmup_steps = int(num_training_steps * scheduler_hparams['warmup_ratio'])
    scheduler = get_scheduler(scheduler_hparams['name'], optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    return scheduler


In [5]:
class Model(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    

    def prepare(self):
        teacher, student, tokenizer = super().prepare()
        teacher = to_distill(teacher)
        student = to_distill(student)
        return teacher, student, tokenizer


    def step(self, batch, phase):
        teacher_outputs = self.teacher(**batch)
        student_outputs = self.student(**batch)

        teacher_qkv = get_qkvs(self.teacher)[self.hparams.distil.teacher_layer_index] # (batch, head, seq, head_dim)
        student_qkv = get_qkvs(self.student)[self.hparams.distil.student_layer_index] # (batch, head, seq, head_dim)

        loss_q = minilm_loss(teacher_qkv['q'], student_qkv['q'], self.hparams.distil.num_relation_heads, batch.attention_mask)
        loss_k = minilm_loss(teacher_qkv['k'], student_qkv['k'], self.hparams.distil.num_relation_heads, batch.attention_mask)
        loss_v = minilm_loss(teacher_qkv['v'], student_qkv['v'], self.hparams.distil.num_relation_heads, batch.attention_mask)
        loss = loss_q + loss_k + loss_v

        log = {f'{phase}/loss': loss, f'{phase}/loss_q': loss_q, f'{phase}/loss_k': loss_k, f'{phase}/loss_v': loss_v}
        self.log_dict(log, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, 'valid')

    def test_step(self, batch, batch_idx):
        return self.step(batch, 'test')


def to_distill(model):
    model.base_model.encoder.layer[0].attention.self.__class__._forward = bert_self_attention_forward
    for layer in model.base_model.encoder.layer:
        layer.attention.self.forward = layer.attention.self._forward
    return model


def get_qkvs(model):
    attns = [l.attention.self for l in model.base_model.encoder.layer]
    qkvs = [{'q': a.q, 'k': a.k, 'v': a.v} for a in attns]    
    return qkvs


def transpose_for_scores(h, num_heads):
    batch_size, seq_length, dim = h.size()
    head_size = dim // num_heads
    h = h.view(batch_size, seq_length, num_heads, head_size)
    return h.permute(0, 2, 1, 3) # (batch, num_heads, seq_length, head_size)


def attention(h1, h2, num_heads, attention_mask=None):
    assert h1.size() == h2.size()
    head_size = h1.size(-1) // num_heads
    h1 = transpose_for_scores(h1, num_heads) # (batch, num_heads, seq_length, head_size)
    h2 = transpose_for_scores(h2, num_heads) # (batch, num_heads, seq_length, head_size)

    attn = torch.matmul(h1, h2.transpose(-1, -2)) # (batch_size, num_heads, seq_length, seq_length)
    attn = attn / math.sqrt(head_size)
    if attention_mask is not None:
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = (1 - attention_mask) * -10000.0
        attn = attn + attention_mask

    return attn


def kl_div_loss(s, t, temperature):
    if len(s.size()) != 2:
        s = s.view(-1, s.size(-1))
        t = t.view(-1, t.size(-1))

    s = F.log_softmax(s / temperature, dim=-1)
    t = F.softmax(t / temperature, dim=-1)
    return F.kl_div(s, t, reduction='batchmean')


def minilm_loss(t, s, num_relation_heads, attention_mask=None, temperature=1.0):
    attn_t = attention(t, t, num_relation_heads, attention_mask)
    attn_s = attention(s, s, num_relation_heads, attention_mask)
    loss = kl_div_loss(attn_s, attn_t, temperature=temperature)
    return loss


def bert_self_attention_forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    mixed_query_layer = self.query(hidden_states)
    mixed_key_layer = self.key(hidden_states)
    mixed_value_layer = self.value(hidden_states)
    
    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)
    
    self.q = mixed_query_layer # (Batch, Seq, Dim)
    self.k = mixed_key_layer # (Batch, Seq, Dim)
    self.v = mixed_value_layer # (Batch, Seq, Dim)

    if self.is_decoder:
        past_key_value = (key_layer, value_layer)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        seq_length = hidden_states.size()[1]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = nn.Softmax(dim=-1)(attention_scores)
    attention_probs = self.dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs

## 4. Distillation

In [6]:
data_module = DataModule(**config.data)

In [8]:
model = Model(**config)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
logger = pl.loggers.WandbLogger(
    project = 'paper',
    log_model = False,
    reinit = True,
)

logger.watch(model, log='gradients')

[34m[1mwandb[0m: Currently logged in as: [33mrespect5716[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [10]:
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath = 'ckpt', 
    filename = 'step={step:06d}-valid_loss={valid/loss:.3f}', 
    monitor = 'valid/loss',
    verbose = True,
    save_top_k = 1,
    save_weights_only = True,
    auto_insert_metric_name = False
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')

In [11]:
trainer = pl.Trainer(    
    logger = logger,
    callbacks = [ckpt_callback, lr_callback],
    **config.trainer
)

  "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [12]:
trainer.fit(model, data_module)

  "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
Using custom data configuration default-a5f702edb7742337
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-a5f702edb7742337/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | teacher | BertModel | 110 M 
1 | student | BertModel | 21.5 M
--------------------------------------
21.5 M    Trainable params
110 M     Non-trainable params
132 M     Total params
528.473   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 249: valid/loss reached 2.74710 (best 2.74710), saving model to "/content/ckpt/step=000249-valid_loss=2.747.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 499: valid/loss reached 2.54490 (best 2.54490), saving model to "/content/ckpt/step=000499-valid_loss=2.545.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 749: valid/loss reached 2.42122 (best 2.42122), saving model to "/content/ckpt/step=000749-valid_loss=2.421.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 999: valid/loss reached 2.34586 (best 2.34586), saving model to "/content/ckpt/step=000999-valid_loss=2.346.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1249: valid/loss reached 2.22789 (best 2.22789), saving model to "/content/ckpt/step=001249-valid_loss=2.228.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1499: valid/loss reached 2.17042 (best 2.17042), saving model to "/content/ckpt/step=001499-valid_loss=2.170.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1749: valid/loss reached 2.10276 (best 2.10276), saving model to "/content/ckpt/step=001749-valid_loss=2.103.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1999: valid/loss reached 2.00416 (best 2.00416), saving model to "/content/ckpt/step=001999-valid_loss=2.004.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2249: valid/loss reached 1.93929 (best 1.93929), saving model to "/content/ckpt/step=002249-valid_loss=1.939.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2499: valid/loss reached 1.88054 (best 1.88054), saving model to "/content/ckpt/step=002499-valid_loss=1.881.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2749: valid/loss reached 1.84415 (best 1.84415), saving model to "/content/ckpt/step=002749-valid_loss=1.844.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2999: valid/loss reached 1.81906 (best 1.81906), saving model to "/content/ckpt/step=002999-valid_loss=1.819.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3249: valid/loss reached 1.78651 (best 1.78651), saving model to "/content/ckpt/step=003249-valid_loss=1.787.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3499: valid/loss reached 1.75411 (best 1.75411), saving model to "/content/ckpt/step=003499-valid_loss=1.754.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3749: valid/loss reached 1.73828 (best 1.73828), saving model to "/content/ckpt/step=003749-valid_loss=1.738.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3999: valid/loss reached 1.71542 (best 1.71542), saving model to "/content/ckpt/step=003999-valid_loss=1.715.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4249: valid/loss reached 1.69884 (best 1.69884), saving model to "/content/ckpt/step=004249-valid_loss=1.699.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4499: valid/loss reached 1.68060 (best 1.68060), saving model to "/content/ckpt/step=004499-valid_loss=1.681.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4749: valid/loss reached 1.66400 (best 1.66400), saving model to "/content/ckpt/step=004749-valid_loss=1.664.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4999: valid/loss reached 1.65372 (best 1.65372), saving model to "/content/ckpt/step=004999-valid_loss=1.654.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5249: valid/loss reached 1.63758 (best 1.63758), saving model to "/content/ckpt/step=005249-valid_loss=1.638.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5499: valid/loss reached 1.62368 (best 1.62368), saving model to "/content/ckpt/step=005499-valid_loss=1.624.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5749: valid/loss reached 1.61775 (best 1.61775), saving model to "/content/ckpt/step=005749-valid_loss=1.618.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5999: valid/loss reached 1.60317 (best 1.60317), saving model to "/content/ckpt/step=005999-valid_loss=1.603.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6249: valid/loss reached 1.58406 (best 1.58406), saving model to "/content/ckpt/step=006249-valid_loss=1.584.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6499: valid/loss reached 1.57889 (best 1.57889), saving model to "/content/ckpt/step=006499-valid_loss=1.579.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6749: valid/loss reached 1.56683 (best 1.56683), saving model to "/content/ckpt/step=006749-valid_loss=1.567.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6999: valid/loss reached 1.55136 (best 1.55136), saving model to "/content/ckpt/step=006999-valid_loss=1.551.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7249: valid/loss reached 1.53473 (best 1.53473), saving model to "/content/ckpt/step=007249-valid_loss=1.535.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7499: valid/loss reached 1.52294 (best 1.52294), saving model to "/content/ckpt/step=007499-valid_loss=1.523.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7749: valid/loss reached 1.51142 (best 1.51142), saving model to "/content/ckpt/step=007749-valid_loss=1.511.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7999: valid/loss reached 1.50291 (best 1.50291), saving model to "/content/ckpt/step=007999-valid_loss=1.503.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8249: valid/loss reached 1.49520 (best 1.49520), saving model to "/content/ckpt/step=008249-valid_loss=1.495.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8499: valid/loss reached 1.49004 (best 1.49004), saving model to "/content/ckpt/step=008499-valid_loss=1.490.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8749: valid/loss reached 1.48452 (best 1.48452), saving model to "/content/ckpt/step=008749-valid_loss=1.485.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8999: valid/loss reached 1.47197 (best 1.47197), saving model to "/content/ckpt/step=008999-valid_loss=1.472.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9249: valid/loss reached 1.46819 (best 1.46819), saving model to "/content/ckpt/step=009249-valid_loss=1.468.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9499: valid/loss reached 1.46417 (best 1.46417), saving model to "/content/ckpt/step=009499-valid_loss=1.464.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9749: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9999: valid/loss reached 1.45712 (best 1.45712), saving model to "/content/ckpt/step=009999-valid_loss=1.457.ckpt" as top 1


In [None]:
res = trainer.test(model, data_module)

Using custom data configuration default-a5f702edb7742337
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-a5f702edb7742337/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

## 5. Downstream

In [2]:
def set_example(example):
    """example -> text_a, text_b, label
    """
    return {'text_a': example['premise'], 'text_b': example['hypothesis'], 'labels': example['label']}


def convert_example_to_feature(example, tokenizer, max_length):
    """text_a, text_b, label -> input_ids, attention_mask, token_type_ids, label
    """
    feature = tokenizer(
        example['text_a'], example['text_b'], 
        max_length = max_length, 
        padding = 'max_length', 
        truncation = True
    )
    return feature

In [3]:
ckpt_dir = 'ckpt/009999'
model = AutoModelForSequenceClassification.from_pretrained(ckpt_dir, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckpt/009999 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
dataset = load_dataset('klue', 'nli')
dataset = dataset.map(set_example)
dataset = dataset.map(lambda example: convert_example_to_feature(example, tokenizer, 256))
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

Downloading:   0%|          | 0.00/5.21k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.93k [00:00<?, ?B/s]

Downloading and preparing dataset klue/nli (download: 1.20 MiB, generated: 6.10 MiB, post-processed: Unknown size, total: 7.30 MiB) to /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e...


Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset klue downloaded and prepared to /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/24998 [00:00<?, ?ex/s]

  0%|          | 0/3000 [00:00<?, ?ex/s]

  0%|          | 0/24998 [00:00<?, ?ex/s]

  0%|          | 0/3000 [00:00<?, ?ex/s]

In [5]:
training_args = TrainingArguments(
    'training_args',
    num_train_epochs = 3,
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = dataset['train'],
    eval_dataset = dataset['validation'],
)

In [6]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: hypothesis, text_b, text_a, premise, source, guid.
***** Running training *****
  Num examples = 24998
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 9375
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mrespect5716[0m (use `wandb login --relogin` to force relogin)


Step,Training Loss
500,1.0546
1000,0.8741
1500,0.826
2000,0.7723
2500,0.7653
3000,0.7488
3500,0.6353
4000,0.5937
4500,0.6109
5000,0.5626


Saving model checkpoint to training_args/checkpoint-500
Configuration saved in training_args/checkpoint-500/config.json
Model weights saved in training_args/checkpoint-500/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-1000
Configuration saved in training_args/checkpoint-1000/config.json
Model weights saved in training_args/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-1500
Configuration saved in training_args/checkpoint-1500/config.json
Model weights saved in training_args/checkpoint-1500/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-2000
Configuration saved in training_args/checkpoint-2000/config.json
Model weights saved in training_args/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-2500
Configuration saved in training_args/checkpoint-2500/config.json
Model weights saved in training_args/checkpoint-2500/pytorch_model.bin
Saving model checkpoint to training_ar

TrainOutput(global_step=9375, training_loss=0.6084553922526041, metrics={'train_runtime': 275.3627, 'train_samples_per_second': 272.346, 'train_steps_per_second': 34.046, 'total_flos': 1038641548428288.0, 'train_loss': 0.6084553922526041, 'epoch': 3.0})

In [7]:
loader = torch.utils.data.DataLoader(dataset['validation'], batch_size=8, shuffle=False)

correct = []
for batch in tqdm(loader):
    batch = {k:v.cuda() for k, v in batch.items()}
    outputs = model(**batch)
    preds = outputs.logits.argmax(dim=1)
    _corrct = (batch['labels'] == preds).cpu()
    correct.append(_corrct)

acc = torch.cat(correct).float().mean()
print(f'ACC: {acc}')

  0%|          | 0/375 [00:00<?, ?it/s]

ACC: 0.6363333463668823
