# Truncate

## 1. Setup

In [18]:
from typing import List

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModel, AutoTokenizer

## 2. Truncate

In [64]:
def truncate_layers(model: torch.nn.Module, indices: List[int]):
    """Keep only the encoder layers whose positions are listed in ``indices``.

    The surviving layers keep their original relative order and are wrapped
    in a fresh ``nn.ModuleList`` that replaces ``model.encoder.layer``. The
    model is modified in place.

    Args:
        model: Module exposing an iterable ``encoder.layer`` attribute.
        indices: Zero-based positions of the layers to retain.

    Returns:
        The same ``model`` object.
    """
    kept = nn.ModuleList()
    for position, layer in enumerate(model.encoder.layer):
        if position in indices:
            kept.append(layer)
    model.encoder.layer = kept
    return model

def truncate_weight(weight: torch.Tensor, dim: int, size: int = 384):
    """Return the leading ``size`` entries of ``weight`` along ``dim``.

    Only ``dim`` 0 (rows) and ``dim`` 1 (columns) are handled; for any other
    value the tensor is returned untouched. The result is a slice of the
    input, so no data is copied here.

    Args:
        weight: Tensor to cut down.
        dim: Axis to truncate, 0 or 1.
        size: Number of leading entries to keep along ``dim``.

    Returns:
        The sliced tensor, or ``weight`` itself when ``dim`` is not 0 or 1.
    """
    if dim == 0:
        return weight[:size]
    if dim == 1:
        return weight[:, :size]
    return weight

In [85]:
def truncate_embedding(embedding, size=384):
    """Shrink an embedding table to its first ``size`` feature columns, in place.

    Args:
        embedding: Module with a 2-D ``weight`` parameter and an
            ``embedding_dim`` attribute (e.g. ``nn.Embedding``).
        size: Number of leading embedding dimensions to keep.

    Returns:
        The same ``embedding`` object.
    """
    # Forward ``size`` explicitly: relying on truncate_weight's own default
    # would always cut to 384 regardless of what the caller asked for.
    embedding.weight.data = truncate_weight(embedding.weight.data, dim=1, size=size)
    # Read the width back from the tensor so the attribute stays truthful even
    # when the table was already narrower than ``size``.
    embedding.embedding_dim = embedding.weight.data.size(1)
    return embedding
def truncate_layernorm(layernorm, size=384):
    """Shrink a layer-norm's weight and bias to their first ``size`` entries, in place.

    Args:
        layernorm: Module with 1-D ``weight`` and ``bias`` parameters and a
            ``normalized_shape`` attribute (e.g. ``nn.LayerNorm``).
        size: Number of leading features to keep.

    Returns:
        The same ``layernorm`` object.
    """
    # Forward ``size`` explicitly; otherwise truncate_weight falls back to its
    # own default of 384 and ignores the requested width.
    layernorm.weight.data = truncate_weight(layernorm.weight.data, dim=0, size=size)
    layernorm.bias.data = truncate_weight(layernorm.bias.data, dim=0, size=size)
    # Derive the shape from the tensor so it matches the parameters even when
    # they were already shorter than ``size``.
    layernorm.normalized_shape = (layernorm.weight.data.size(0),)
    return layernorm

def truncate_linear(linear, dims, size=384):
    """Shrink a linear layer along the requested weight axes, in place.

    Args:
        linear: Module with a 2-D ``weight``, an optional ``bias`` and
            ``in_features`` / ``out_features`` attributes (e.g. ``nn.Linear``).
        dims: Weight axes to truncate. ``0`` cuts the output features (and
            the bias), ``1`` cuts the input features.
        size: Number of leading entries to keep along each listed axis.

    Returns:
        The same ``linear`` object.
    """
    for d in dims:
        linear.weight.data = truncate_weight(linear.weight.data, dim=d, size=size)
    if 0 in dims:
        # A layer built without a bias stores ``None`` there; skip it instead
        # of failing on ``None.data``.
        if linear.bias is not None:
            linear.bias.data = truncate_weight(linear.bias.data, dim=0, size=size)
        # Report the real shape, which may be smaller than ``size`` if the
        # layer was already narrower.
        linear.out_features = linear.weight.data.size(0)
    if 1 in dims:
        linear.in_features = linear.weight.data.size(1)
    return linear

In [101]:
def truncate_bert_embeddings(bert_embeddings, size=384):
    """Truncate every sub-module of a BERT embeddings block to ``size`` features.

    The word, position and token-type embedding tables and the trailing
    LayerNorm are all modified in place.

    Args:
        bert_embeddings: Module exposing ``word_embeddings``,
            ``position_embeddings``, ``token_type_embeddings`` and
            ``LayerNorm`` attributes.
        size: Number of leading hidden features to keep. The default is a
            literal 384, matching the other truncate helpers, rather than a
            reference to a module-level ``size`` variable that may not exist.
    """
    truncate_embedding(bert_embeddings.word_embeddings, size=size)
    truncate_embedding(bert_embeddings.position_embeddings, size=size)
    truncate_embedding(bert_embeddings.token_type_embeddings, size=size)
    truncate_layernorm(bert_embeddings.LayerNorm, size=size)
    
    
def truncate_bert_layer(bert_layer, size=384):
    """Truncate one BERT encoder layer to a hidden width of ``size``, in place.

    The number of attention heads is left unchanged; each head's width
    becomes ``size // num_attention_heads``. The feed-forward inner width is
    also left unchanged: only the input side of ``intermediate.dense`` and
    the output side of ``output.dense`` are cut.

    Args:
        bert_layer: Module exposing ``attention.self`` (with ``query``,
            ``key``, ``value``, ``num_attention_heads``),
            ``attention.output``, ``intermediate`` and ``output``.
        size: Target hidden width.

    Raises:
        ValueError: If ``size`` is not a multiple of the number of attention
            heads, which would leave ``all_head_size`` different from
            ``num_attention_heads * attention_head_size``.
    """
    self_attention = bert_layer.attention.self
    num_heads = self_attention.num_attention_heads
    # Validate before touching anything so a bad size leaves the layer intact.
    if size % num_heads != 0:
        raise ValueError(
            f"size ({size}) must be divisible by the number of "
            f"attention heads ({num_heads})"
        )

    self_attention.all_head_size = size
    self_attention.attention_head_size = size // num_heads

    # Query/key/value projections map hidden -> hidden, so both axes shrink.
    for projection in (self_attention.query, self_attention.key, self_attention.value):
        truncate_linear(projection, dims=[0, 1], size=size)

    truncate_linear(bert_layer.attention.output.dense, dims=[0, 1], size=size)
    truncate_layernorm(bert_layer.attention.output.LayerNorm, size=size)

    # Feed-forward block: cut only the side of each projection that faces the
    # hidden state (input of the first, output of the second).
    truncate_linear(bert_layer.intermediate.dense, dims=[1], size=size)
    truncate_linear(bert_layer.output.dense, dims=[0], size=size)
    truncate_layernorm(bert_layer.output.LayerNorm, size=size)
    
def truncate_bert_pooler(bert_pooler, size=384):
    """Truncate the pooler's dense projection to ``size`` on both axes, in place.

    Args:
        bert_pooler: Module exposing a ``dense`` linear layer.
        size: Target hidden width. Defaults to 384 like the other truncate
            helpers so the function can be called without an explicit size.
    """
    truncate_linear(bert_pooler.dense, dims=[0, 1], size=size)
    

def truncate_bert_model(bert_model, size=384):
    """Halve the depth of a BERT model and shrink its hidden width, in place.

    Every second encoder layer is kept (the odd positions 1, 3, 5, ...), then
    the embeddings, each remaining layer and the pooler are truncated to
    ``size`` hidden features. This function does not modify
    ``bert_model.config``; the caller has to update it before saving.

    Args:
        bert_model: Model exposing ``embeddings``, ``encoder.layer`` and
            optionally ``pooler``.
        size: Target hidden width.
    """
    # Use the model's real depth instead of assuming 12 layers.
    num_layers = len(bert_model.encoder.layer)
    layer_indices = [i for i in range(num_layers) if i % 2 != 0]
    bert_model = truncate_layers(bert_model, layer_indices)
    truncate_bert_embeddings(bert_model.embeddings, size=size)

    for bert_layer in bert_model.encoder.layer:
        truncate_bert_layer(bert_layer, size=size)

    # The pooler can be absent (attribute missing or set to None); skip it
    # rather than failing on ``None.dense``.
    pooler = getattr(bert_model, "pooler", None)
    if pooler is not None:
        truncate_bert_pooler(pooler, size=size)

In [102]:
# Load the tokenizer of the `klue/bert-base` checkpoint. It is not modified
# here; it is only saved next to the truncated model further down.
tokenizer = AutoTokenizer.from_pretrained('klue/bert-base')

In [103]:
# Load the pretrained encoder. The two output_* flags ask the forward pass to
# also return per-layer hidden states and attention maps, which are inspected
# below to check the layer count after truncation.
model = AutoModel.from_pretrained('klue/bert-base', output_hidden_states=True, output_attentions=True)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [104]:
# Truncate `model` in place using the default hidden size (384).
truncate_bert_model(model)

In [105]:
# Dummy batch for a smoke test: integer tensor of shape (4, 256) with values
# drawn from [0, 30000).
# NOTE(review): 30000 is presumably chosen to stay below the vocabulary size — confirm.
input_ids = torch.randint(30000, size=(4, 256))

In [106]:
# Forward pass through the truncated model, to confirm the resized weights
# still fit together.
out = model(input_ids)

In [109]:
# The last dimension should equal the truncated hidden size (384).
out.last_hidden_state.size()

torch.Size([4, 256, 384])

In [107]:
# Number of hidden-state tensors returned by the forward pass.
# NOTE(review): the reported 7 is presumably the embedding output plus the 6 kept layers — confirm.
len(out.hidden_states)

7

In [110]:
# The truncate_* helpers only edit module weights and attributes, so bring the
# config in line with the new shape (384 hidden features, 6 layers) before saving.
model.config.hidden_size = 384
model.config.num_hidden_layers = 6

# Write the truncated model and the unchanged tokenizer to the same directory.
model.save_pretrained('../init/transformers')
tokenizer.save_pretrained('../init/transformers')

('../init/transformers/tokenizer_config.json',
 '../init/transformers/special_tokens_map.json',
 '../init/transformers/vocab.txt',
 '../init/transformers/added_tokens.json',
 '../init/transformers/tokenizer.json')

In [111]:
!ls ../init/transformers

config.json	   special_tokens_map.json  tokenizer_config.json
pytorch_model.bin  tokenizer.json	    vocab.txt
