# Data

## 1. Setup

In [191]:
import os

import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from datasets import load_dataset, concatenate_datasets
from omegaconf import OmegaConf
from transformers import AutoConfig, AutoModel, AutoTokenizer, get_scheduler

In [217]:
# Compose the primary `config` from ../configs with a `working_dir` override.
# return_hydra_config=True keeps the `hydra` node in the result so its run
# settings (e.g. `hydra.run.dir`) can be inspected afterwards.
overrides = ['working_dir=/workspace/language-model-distillation']
with hydra.initialize(config_path='../configs'):
    config = hydra.compose(
        config_name='config',
        overrides=overrides,
        return_hydra_config=True,
    )

In [218]:
# Output directory Hydra resolves for this run (interpolations are resolved on access).
config['hydra']['run']['dir']

'outputs/2021-11-25/04-58-45'

In [202]:
class BaseModel(pl.LightningModule):
    """Frozen pretrained teacher plus a randomly initialised student for distillation.

    All constructor arguments are captured by ``save_hyperparameters()`` and read back
    through ``self.hparams``. The code below reads:

    * ``teacher.name_or_model_path`` - checkpoint for the teacher, also used as the
      base config for the student;
    * ``student`` - mapping of config overrides applied to the student config;
    * ``optim.{lr, betas, weight_decay, adam_epsilon, accumulate_grad_batches}``;
    * ``max_step``, ``warmup_ratio`` and ``scheduler`` for the learning-rate schedule.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.teacher, self.student = self.prepare_models()

    def prepare_models(self):
        """Build the (teacher, student) pair.

        The teacher is loaded with pretrained weights and frozen. The student uses the
        teacher's config, overridden by ``hparams.student``, and is built with
        ``from_config`` (no pretrained weights are loaded).

        Returns:
            tuple: ``(teacher, student)``, both configured to return attentions and
            hidden states.
        """
        teacher = AutoModel.from_pretrained(
            self.hparams.teacher.name_or_model_path,
            output_attentions=True,
            output_hidden_states=True,
        )

        # Must be `output_attentions` (plural): an unknown kwarg such as
        # `output_attention` is not a config attribute and is silently ignored,
        # so the student would never return attention maps.
        # Named `student_config` to avoid shadowing the notebook-level `config`.
        student_config = AutoConfig.from_pretrained(
            self.hparams.teacher.name_or_model_path,
            output_attentions=True,
            output_hidden_states=True,
            **self.hparams.student,
        )
        student = AutoModel.from_config(student_config)

        # The teacher only supplies targets; it is never updated.
        for param in teacher.parameters():
            param.requires_grad = False

        return teacher, student

    def train(self, mode: bool = True):
        """Set train/eval mode on the module while keeping the frozen teacher in eval.

        ``nn.Module.train`` recurses into every submodule, so putting this module in
        training mode would otherwise enable dropout inside the teacher and make its
        distillation targets stochastic. ``eval()`` also routes through here.
        """
        super().train(mode)
        teacher = getattr(self, 'teacher', None)
        if teacher is not None:
            teacher.eval()
        return self

    def student_param_groups(self):
        """Split student parameters into weight-decay and no-weight-decay groups.

        A parameter whose lower-cased name contains ``bias``, ``bn``, ``ln`` or
        ``norm`` gets no weight decay. All other parameters use
        ``hparams.optim.weight_decay``.

        Returns:
            list[dict]: two AdamW param groups, the decayed group first.
        """
        no_decay = ("bias", "bn", "ln", "norm")
        decay_params, no_decay_params = [], []
        for name, param in self.student.named_parameters():
            lowered = name.lower()
            if any(marker in lowered for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        return [
            # apply weight decay
            {"params": decay_params, "weight_decay": self.hparams.optim.weight_decay},
            # do not apply weight decay
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

    def configure_optimizers(self):
        """Create AdamW over the student only, with a warmup LR schedule stepped per step.

        Returns:
            dict: Lightning ``optimizer`` / ``lr_scheduler`` configuration.
        """
        optim_cfg = self.hparams.optim
        optimizer = torch.optim.AdamW(
            self.student_param_groups(),
            lr=optim_cfg.lr,
            betas=optim_cfg.betas,
            # Default only: each param group above sets its own weight_decay.
            weight_decay=optim_cfg.weight_decay,
            eps=optim_cfg.adam_epsilon,
        )

        num_training_steps = self.hparams.max_step
        num_warmup_steps = int(num_training_steps * self.hparams.warmup_ratio)
        scheduler = get_scheduler(
            self.hparams.scheduler,
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                # NOTE(review): how Lightning counts `frequency` under gradient
                # accumulation (batch index vs optimizer step) differs between
                # versions; confirm the scheduler advances once per optimizer step.
                'frequency': optim_cfg.accumulate_grad_batches,
            },
        }


In [203]:
# Instantiate the teacher/student pair from the `model` section of the composed config.
model_cfg = config.model
m = BaseModel(**model_cfg)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class Model(BaseModel):
    """MiniLM-style Q/K/V relation distillation built on ``BaseModel``.

    The teacher/student construction, hyperparameter capture and optimiser setup are
    inherited from ``BaseModel``. This class adds the per-batch loss.

    Additional hparams read here: ``teacher_layer_index``, ``student_layer_index``,
    ``num_relation_heads``.
    """

    def step(self, batch, phase):
        """Compute and log the summed Q/K/V relation loss for one batch.

        Args:
            batch: ``(input_ids, labels)``; ``labels`` is not used by this loss.
            phase: prefix for the logged metric name (e.g. ``'train'``).

        Returns:
            The loss ``loss_q + loss_k + loss_v``.
        """
        input_ids, labels = batch
        # 1.0 for real tokens, 0.0 for padding.
        attention_mask = input_ids.ne(self.teacher.config.pad_token_id).float()

        # Forward passes are run for their side effects only: the code below reads
        # per-layer Q/K/V from `.q` / `.k` / `.v` on each model.
        # TODO(review): Hugging Face `AutoModel`s expose no such attributes; they
        # presumably need forward hooks on the self-attention modules to populate
        # them, which are not implemented yet.
        with torch.no_grad():  # the teacher is frozen; no autograd graph needed
            self.teacher(input_ids, attention_mask=attention_mask)
        self.student(input_ids, attention_mask=attention_mask)

        t_layer = self.hparams.teacher_layer_index
        s_layer = self.hparams.student_layer_index
        # Each name now matches its source (q->*q, k->*k, v->*v).
        # Assumed shape: (batch, head, seq, head_dim) - TODO confirm once hooks exist.
        tq, tk, tv = self.teacher.q[t_layer], self.teacher.k[t_layer], self.teacher.v[t_layer]
        sq, sk, sv = self.student.q[s_layer], self.student.k[s_layer], self.student.v[s_layer]

        # TODO(review): `minilm_loss` is not defined in the visible notebook.
        loss_q = minilm_loss(tq, sq, self.hparams.num_relation_heads, attention_mask=attention_mask)
        loss_k = minilm_loss(tk, sk, self.hparams.num_relation_heads, attention_mask=attention_mask)
        loss_v = minilm_loss(tv, sv, self.hparams.num_relation_heads, attention_mask=attention_mask)
        loss = loss_q + loss_k + loss_v

        self.log_dict({f'{phase}/loss': loss}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, 'valid')

    def test_step(self, batch, batch_idx):
        return self.step(batch, 'test')