<a href="https://colab.research.google.com/github/respect5716/deep-learning-paper-implementation/blob/main/03_NLP/DistilBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DistilBERT

## 0. Introduction

### Paper
* title: DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
* authors: Victor Sanh et al.
* url: https://arxiv.org/abs/1910.01108

### Reference
* https://github.com/huggingface/transformers/tree/1c06240e1b3477728129bb58e7b6c7734bb5074e/examples/research_projects/distillation

## 1. Setup

In [1]:
!pip install -q wandb transformers pytorch_lightning datasets

[K     |████████████████████████████████| 1.7 MB 14.8 MB/s 
[K     |████████████████████████████████| 3.1 MB 60.9 MB/s 
[K     |████████████████████████████████| 525 kB 61.6 MB/s 
[K     |████████████████████████████████| 298 kB 56.7 MB/s 
[K     |████████████████████████████████| 180 kB 64.9 MB/s 
[K     |████████████████████████████████| 140 kB 65.7 MB/s 
[K     |████████████████████████████████| 97 kB 8.1 MB/s 
[K     |████████████████████████████████| 63 kB 2.0 MB/s 
[K     |████████████████████████████████| 596 kB 61.9 MB/s 
[K     |████████████████████████████████| 3.3 MB 71.0 MB/s 
[K     |████████████████████████████████| 61 kB 554 kB/s 
[K     |████████████████████████████████| 895 kB 57.0 MB/s 
[K     |████████████████████████████████| 332 kB 71.8 MB/s 
[K     |████████████████████████████████| 829 kB 66.1 MB/s 
[K     |████████████████████████████████| 132 kB 63.2 MB/s 
[K     |████████████████████████████████| 1.1 MB 70.3 MB/s 
[K     |█████████████████████

In [1]:
import os
import wandb
import easydict
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from transformers import Trainer, TrainingArguments
from transformers import get_scheduler
from transformers import BatchEncoding
from transformers import DataCollatorForWholeWordMask
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, AutoModelForSequenceClassification

from datasets import load_metric, load_dataset, concatenate_datasets

In [2]:
config = easydict.EasyDict(

    data = {
        'datasets': ['namuwiki'],
        'data_dir': 'drive/Shareddrives/dataset',
        'pretrained_model_name_or_path': 'klue/bert-base',
        'batch_size': 4,
        'mlm_probability': 0.15,
        'max_seq_length': 512
    },

    teacher = {
        'model_name_or_path': 'klue/bert-base',
        'hidden_dropout_prob': 0.,
        'attention_probs_dropout_prob': 0.,
        'output_attentions': True,
        'output_hidden_states': True
    },

    student = {
        'num_hidden_layers': 3,
        'hidden_dropout_prob': 0.,
        'attention_probs_dropout_prob': 0.,
        'output_attentions': True,
        'output_hidden_states': True
    },

    optimizer = {
        'name': 'adamw',
        'lr': 6e-4,
        'betas': (0.9, 0.98),
        'weight_decay': 0.01,
    },

    scheduler = {
        'name': 'linear',
        'max_steps': 10000,
        'warmup_ratio': 0.05
    },

    distil = {
        'temperature': 2.,
        'alpha_mlm': 2.0, # mlm loss
        'alpha_ce': 5.0,  # logit distil loss
        'alpha_cos': 1.0 # hidden distil loss
    },

    trainer = {
        'gpus': -1,
        'log_every_n_steps': 10,
        'num_sanity_val_steps': 100,
        'val_check_interval': 1000,
        'limit_val_batches': 100,

        'max_steps': 10000,
        'accumulate_grad_batches': 4,
        'gradient_clip_val': 5.0,
        'precision': 32,
    }
)

## 2. Data

In [3]:
class DataModule(pl.LightningDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.pretrained_model_name_or_path)
        
    def setup(self, stage=None):
        dataset = []
        for dname in self.hparams.datasets:
            _dataset = load_dataset('text', data_files=os.path.join(self.hparams.data_dir, f'{dname}.txt'))['train']
            dataset.append(_dataset)

        self.dataset = concatenate_datasets(dataset)
        self.dataset.set_transform(lambda batch: transform(batch, self.tokenizer, self.hparams.max_seq_length))
        self.dataset = self.dataset.train_test_split(test_size=0.01)
        self.train_dataset, self.eval_dataset = self.dataset['train'], self.dataset['test']
        
        self.wwm = DataCollatorForWholeWordMask(tokenizer=self.tokenizer, mlm=True, mlm_probability=self.hparams.mlm_probability)


    def collate_fn(self, batch):
        batch = BatchEncoding(self.wwm(batch))
        batch['attention_mask'] = batch.input_ids.ne(self.tokenizer.pad_token_id).float()
        return batch

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=self.collate_fn)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.eval_dataset, batch_size=self.hparams.batch_size, shuffle=False, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return self.val_dataloader()
    
    
def transform(batch, tokenizer, max_length):
    new_batch = []
    for text in batch['text']:
        text = slice_text(text)
        new_batch.append(text)
    
    return tokenizer(new_batch, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')


def slice_text(text, max_char_length=1024):
    if len(text) > max_char_length:
        idx = np.random.randint(low=0, high=len(text)-max_char_length)
        text = text[idx : idx+max_char_length]
    return text

## 3. Model

In [4]:
EMBEDDING_PARAMS = ['word_embeddings.weight', 'LayerNorm.weight', 'LayerNorm.bias']
ENCODER_PARAMS = [
    'attention.self.query.weight', 'attention.self.query.bias', 
    'attention.self.key.weight', 'attention.self.key.bias', 
    'attention.self.value.weight', 'attention.self.value.bias',
    'attention.output.dense.weight', 'attention.output.dense.bias',
    'attention.output.LayerNorm.weight', 'attention.output.LayerNorm.bias',
    'intermediate.dense.weight', 'intermediate.dense.bias',
    'output.dense.weight', 'output.dense.bias',
    'output.LayerNorm.weight', 'output.LayerNorm.bias'
]


def get_param_names_of_layer(model_name, layer_name, idx=0):
    if layer_name == 'embeddings':
        names = [f'{model_name}.{layer_name}.{p}' for p in EMBEDDING_PARAMS]
    elif layer_name == 'encoder':
        names = [f'{model_name}.{layer_name}.layer.{idx}.{p}' for p in ENCODER_PARAMS]
    return names


def get_param_names_of_model(model_name, encoder_idx):
    names = []
    names += get_param_names_of_layer(model_name, 'embeddings')
    for idx in encoder_idx:
        names += get_param_names_of_layer(model_name, 'encoder', idx)
    return names


def init_student_from_teacher(student, teacher):
    multiplier = teacher.config.num_hidden_layers // student.config.num_hidden_layers
    teacher_encoder_idx = [i for i in range(teacher.config.num_hidden_layers) if (i) % multiplier == 0]
    student_encoder_idx = [i for i in range(student.config.num_hidden_layers)]

    teacher_param_names = get_param_names_of_model(teacher.base_model_prefix, teacher_encoder_idx)
    student_param_names = get_param_names_of_model(student.base_model_prefix, student_encoder_idx)

    teacher_params = {k:v for k,v in teacher.named_parameters()}
    student_params = {k:v for k,v in student.named_parameters()}

    with torch.no_grad():
        for t, s in zip(teacher_param_names, student_param_names):
            student_params[s].copy_(teacher_params[t])

In [5]:
optim_dict = {
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW
}

def prepare_optimizer(params, optimizer_hparams):
    name = optimizer_hparams['name']
    hparams = {k:v for k,v in optimizer_hparams.items() if k != 'name'}
    return optim_dict[name](params, **hparams)


def prepare_scheduler(optimizer, scheduler_hparams):
    num_training_steps = scheduler_hparams['max_steps']
    num_warmup_steps = int(num_training_steps * scheduler_hparams['warmup_ratio'])
    scheduler = get_scheduler(scheduler_hparams['name'], optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
    return scheduler


def select_tensor(tensor, attention_mask):
    mask = attention_mask.unsqueeze(-1).expand_as(tensor).bool()
    selected = torch.masked_select(tensor, mask)  # (bs * seq_length * voc_size)
    selected = selected.view(-1, tensor.size(-1))  # (bs * seq_length, voc_size)
    return selected

In [6]:
class Model(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.teacher, self.student, self.tokenizer = self.prepare()
    
    def prepare(self):
        teacher_kwargs = {k:v for k, v in self.hparams.teacher.items() if k not in ['model_name_or_path']}
        teacher = AutoModelForMaskedLM.from_pretrained(
            self.hparams.teacher.model_name_or_path,
            **teacher_kwargs
        )
        
        config = AutoConfig.from_pretrained(
            self.hparams.teacher.model_name_or_path,
            **self.hparams.student
        )
        
        student = AutoModelForMaskedLM.from_config(config)
   
        for param in teacher.parameters():
            param.requires_grad = False
        
        tokenizer = AutoTokenizer.from_pretrained(self.hparams.teacher.model_name_or_path)

        init_student_from_teacher(student, teacher)
        return teacher, student, tokenizer
    
    
    def student_param_groups(self):
        no_decay = ["bias", "bn", "ln", "norm"]
        param_groups = [
            {
                # apply weight decay
                "params": [p for n, p in self.student.named_parameters() if not any(nd in n.lower() for nd in no_decay)],
                "weight_decay": self.hparams.optimizer.weight_decay
            },
            {
                # not apply weight decay
                "params": [p for n, p in self.student.named_parameters() if any(nd in n.lower() for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return param_groups


    def configure_optimizers(self):
        optimizer = prepare_optimizer(self.student_param_groups(), self.hparams.optimizer)
        scheduler = prepare_scheduler(optimizer, self.hparams.scheduler)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
            }
        }


    def training_step(self, batch, batch_idx):
        loss = 0.
        log = {}

        teacher_outputs = self.teacher(**batch)
        student_outputs = self.student(**batch)
        teacher_logits = select_tensor(teacher_outputs.logits, batch.attention_mask)
        student_logits = select_tensor(student_outputs.logits, batch.attention_mask)

        if self.hparams.distil.alpha_mlm > 0:
            mlm_loss = student_outputs.loss
            loss += self.hparams.distil.alpha_mlm * mlm_loss
            log['train/mlm_loss'] = mlm_loss


        if self.hparams.distil.alpha_ce > 0:
            student_softmax = F.log_softmax(student_logits / self.hparams.distil.temperature, dim=-1)
            teacher_softmax = F.softmax(teacher_logits / self.hparams.distil.temperature, dim=-1)
            ce_loss = F.kl_div(student_softmax, teacher_softmax, reduction='batchmean') * (self.hparams.distil.temperature) ** 2
            loss += self.hparams.distil.alpha_ce * ce_loss
            log['train/ce_loss'] = ce_loss

        if self.hparams.distil.alpha_cos > 0:
            student_hidden = select_tensor(student_outputs.hidden_states[-1], batch.attention_mask)
            teacher_hidden = select_tensor(teacher_outputs.hidden_states[-1], batch.attention_mask)
            target = student_hidden.new(student_hidden.size(0)).fill_(1)
            cos_loss = F.cosine_embedding_loss(student_hidden, teacher_hidden, target, reduction='mean')
            loss += self.hparams.distil.alpha_cos * cos_loss
            log['train/cos_loss'] = cos_loss

        log['train/loss'] = loss
        self.log_dict(log, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss


    def eval_step(self, batch, phase):
        outputs = self.student(**batch)
        self.log_dict({f'{phase}/loss': outputs.loss}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, 'valid')

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, 'test')


    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        ckpt_dir = os.path.join('ckpt', f'{self.trainer.global_step:06d}')
        self.student.save_pretrained(ckpt_dir)
        self.tokenizer.save_pretrained(ckpt_dir)

## 4. Distillation

In [7]:
data_module = DataModule(**config.data)

In [8]:
model = Model(**config)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
logger = pl.loggers.WandbLogger(
    project = 'paper',
    log_model = False,
    reinit = True,
)

logger.watch(model, log='gradients')

[34m[1mwandb[0m: Currently logged in as: [33mrespect5716[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [10]:
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath = 'ckpt', 
    filename = 'step={step:06d}-valid_loss={valid/loss:.3f}', 
    monitor = 'valid/loss',
    verbose = True,
    save_top_k = 1,
    save_weights_only = True,
    auto_insert_metric_name = False
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')

In [11]:
trainer = pl.Trainer(    
    logger = logger,
    callbacks = [ckpt_callback, lr_callback],
    resume_from_checkpoint = False,
    **config.trainer
)

  "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [13]:
trainer.fit(model, data_module)

  "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
Using custom data configuration default-a5f702edb7742337
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-a5f702edb7742337/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type            | Params
--------------------------------------------
0 | teacher | BertForMaskedLM | 110 M 
1 | student | BertForMaskedLM | 46.9 M
--------------------------------------------
46.9 M    Trainable params
110 M     Non-trainable params
157 M     Total params
630.044   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 249: valid/loss reached 6.78821 (best 6.78821), saving model to "/content/ckpt/step=000249-valid_loss=6.788.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 499: valid/loss reached 6.66048 (best 6.66048), saving model to "/content/ckpt/step=000499-valid_loss=6.660.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 749: valid/loss reached 6.48521 (best 6.48521), saving model to "/content/ckpt/step=000749-valid_loss=6.485.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 999: valid/loss reached 5.37536 (best 5.37536), saving model to "/content/ckpt/step=000999-valid_loss=5.375.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1249: valid/loss reached 4.83889 (best 4.83889), saving model to "/content/ckpt/step=001249-valid_loss=4.839.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1499: valid/loss reached 4.67573 (best 4.67573), saving model to "/content/ckpt/step=001499-valid_loss=4.676.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1749: valid/loss reached 4.55913 (best 4.55913), saving model to "/content/ckpt/step=001749-valid_loss=4.559.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1999: valid/loss reached 4.43045 (best 4.43045), saving model to "/content/ckpt/step=001999-valid_loss=4.430.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2249: valid/loss reached 4.36473 (best 4.36473), saving model to "/content/ckpt/step=002249-valid_loss=4.365.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2499: valid/loss reached 4.33634 (best 4.33634), saving model to "/content/ckpt/step=002499-valid_loss=4.336.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2749: valid/loss reached 4.20605 (best 4.20605), saving model to "/content/ckpt/step=002749-valid_loss=4.206.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2999: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3249: valid/loss reached 4.17748 (best 4.17748), saving model to "/content/ckpt/step=003249-valid_loss=4.177.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3499: valid/loss reached 4.12038 (best 4.12038), saving model to "/content/ckpt/step=003499-valid_loss=4.120.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3749: valid/loss reached 4.05505 (best 4.05505), saving model to "/content/ckpt/step=003749-valid_loss=4.055.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3999: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4249: valid/loss reached 3.97073 (best 3.97073), saving model to "/content/ckpt/step=004249-valid_loss=3.971.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4499: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4749: valid/loss reached 3.89380 (best 3.89380), saving model to "/content/ckpt/step=004749-valid_loss=3.894.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4999: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5249: valid/loss reached 3.86059 (best 3.86059), saving model to "/content/ckpt/step=005249-valid_loss=3.861.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5499: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5749: valid/loss reached 3.80719 (best 3.80719), saving model to "/content/ckpt/step=005749-valid_loss=3.807.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 5999: valid/loss reached 3.80084 (best 3.80084), saving model to "/content/ckpt/step=005999-valid_loss=3.801.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6249: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6499: valid/loss reached 3.71833 (best 3.71833), saving model to "/content/ckpt/step=006499-valid_loss=3.718.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6749: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 6999: valid/loss reached 3.66755 (best 3.66755), saving model to "/content/ckpt/step=006999-valid_loss=3.668.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7249: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7499: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7749: valid/loss reached 3.63784 (best 3.63784), saving model to "/content/ckpt/step=007749-valid_loss=3.638.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 7999: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8249: valid/loss reached 3.59070 (best 3.59070), saving model to "/content/ckpt/step=008249-valid_loss=3.591.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8499: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8749: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 8999: valid/loss reached 3.56629 (best 3.56629), saving model to "/content/ckpt/step=008999-valid_loss=3.566.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9249: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9499: valid/loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9749: valid/loss reached 3.54495 (best 3.54495), saving model to "/content/ckpt/step=009749-valid_loss=3.545.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 9999: valid/loss reached 3.49080 (best 3.49080), saving model to "/content/ckpt/step=009999-valid_loss=3.491.ckpt" as top 1


In [14]:
res = trainer.test(model, data_module)

Using custom data configuration default-a5f702edb7742337
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-a5f702edb7742337/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/loss': 3.5919926166534424, 'test/loss_epoch': 3.5919926166534424}
--------------------------------------------------------------------------------


## 5. Donwstream

In [3]:
def set_example(example):
    """example -> text_a, text_b, label
    """
    return {'text_a': example['premise'], 'text_b': example['hypothesis'], 'labels': example['label']}


def convert_example_to_feature(example, tokenizer, max_length):
    """text_a, text_b, label -> input_ids, attention_mask, token_type_ids, label
    """
    feature = tokenizer(
        example['text_a'], example['text_b'], 
        max_length = max_length, 
        padding = 'max_length', 
        truncation = True
    )
    return feature

In [4]:
!ls ckpt

 000249   001749   003499   005999   009749
 000499   001999   003749   006499   009999
 000749   002249   004249   006999  'step=009999-valid_loss=3.491.ckpt'
 000999   002499   004749   007749
 001249   002749   005249   008249
 001499   003249   005749   008999


In [5]:
ckpt_dir = 'ckpt/009999'
model = AutoModelForSequenceClassification.from_pretrained(ckpt_dir, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

Some weights of the model checkpoint at ckpt/009999 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckpt/009999 and are newly initial

In [6]:
dataset = load_dataset('klue', 'nli')
dataset = dataset.map(set_example)
dataset = dataset.map(lambda example: convert_example_to_feature(example, tokenizer, 256))
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

Reusing dataset klue (/root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e/cache-09b2284a4536f50d.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e/cache-6970876e840b9a8e.arrow


  0%|          | 0/24998 [00:00<?, ?ex/s]

  0%|          | 0/3000 [00:00<?, ?ex/s]

In [8]:
training_args = TrainingArguments(
    'training_args',
    num_train_epochs = 3,
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = dataset['train'],
    eval_dataset = dataset['validation'],
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [9]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: hypothesis, premise, source, text_a, guid, text_b.
***** Running training *****
  Num examples = 24998
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 9375
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mrespect5716[0m (use `wandb login --relogin` to force relogin)


Step,Training Loss
500,0.995
1000,0.8058
1500,0.7754
2000,0.7326
2500,0.7163
3000,0.7074
3500,0.5218
4000,0.4599
4500,0.4613
5000,0.4448


Saving model checkpoint to training_args/checkpoint-500
Configuration saved in training_args/checkpoint-500/config.json
Model weights saved in training_args/checkpoint-500/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-1000
Configuration saved in training_args/checkpoint-1000/config.json
Model weights saved in training_args/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-1500
Configuration saved in training_args/checkpoint-1500/config.json
Model weights saved in training_args/checkpoint-1500/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-2000
Configuration saved in training_args/checkpoint-2000/config.json
Model weights saved in training_args/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to training_args/checkpoint-2500
Configuration saved in training_args/checkpoint-2500/config.json
Model weights saved in training_args/checkpoint-2500/pytorch_model.bin
Saving model checkpoint to training_ar

TrainOutput(global_step=9375, training_loss=0.4856636100260417, metrics={'train_runtime': 410.3444, 'train_samples_per_second': 182.759, 'train_steps_per_second': 22.847, 'total_flos': 2517846031401984.0, 'train_loss': 0.4856636100260417, 'epoch': 3.0})

In [10]:
loader = torch.utils.data.DataLoader(dataset['validation'], batch_size=8, shuffle=False)

correct = []
for batch in tqdm(loader):
    batch = {k:v.cuda() for k, v in batch.items()}
    outputs = model(**batch)
    preds = outputs.logits.argmax(dim=1)
    _corrct = (batch['labels'] == preds).cpu()
    correct.append(_corrct)

acc = torch.cat(correct).float().mean()
print(f'ACC: {acc}')

  0%|          | 0/375 [00:00<?, ?it/s]

ACC: 0.6620000004768372
