In [1]:
import torch
# Sanity check: report whether PyTorch can see a CUDA device in this runtime.
print(torch.cuda.is_available())

True


In [0]:
!wget -O "SST-2.zip" 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8'
!unzip "SST-2.zip"
!git clone https://username:password@github.com/yandexdataschool/lilbert.git
    
!pip install -r lilbert/requirements.txt
!mkdir ./lilbert/output

In [0]:
import sys
%load_ext autoreload
%autoreload 2
sys.path.append('lilbert/lilbert')

import numpy as np
import random
import torch
from tqdm import tqdm
from pytorch_pretrained_bert.tokenization import BertTokenizer

from lib import data_processors, tasks
from lib.bert import BertForSequenceClassification
from lib.train_eval import train, evaluate, predict

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [0]:
# Expose only the CUDA device with index 1 to this process.
# NOTE(review): torch is imported and torch.cuda.is_available() is called in
# earlier cells; the variable may be read too late to take effect, and on a
# single-GPU machine index 1 would hide the only device — TODO confirm.
%env CUDA_VISIBLE_DEVICES=1

# Run configuration read by the later cells (data paths, model name,
# batch sizes, optimisation settings, seed and target device).
params = dict(
    data_dir='SST-2',
    output_dir='../output',
    cache_dir='../model_cache',
    task_name='sst2',
    bert_model='bert-base-uncased',
    max_seq_length=128,
    train_batch_size=32,
    eval_batch_size=8,
    learning_rate=2e-5,
    warmup_proportion=0.1,
    num_train_epochs=1,
    seed=1331,
    # Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

# Seed the NumPy, Python and PyTorch generators from the configured seed.
np.random.seed(params['seed'])
random.seed(params['seed'])
torch.manual_seed(params['seed'])

env: CUDA_VISIBLE_DEVICES=1


<torch._C.Generator at 0x7fdc0307b130>

In [0]:
# Short alias for the configured torch device; later cells pass it to `.to(...)`.
device = params['device']

In [0]:
# Attach the label metadata registered for the configured task to the run config.
params.update(
    num_labels=tasks.num_labels[params['task_name']],
    label_list=tasks.label_lists[params['task_name']],
)

# Task-specific data processor and the lower-casing tokenizer for the chosen BERT model.
processor = tasks.processors[params['task_name']]()
tokenizer = BertTokenizer.from_pretrained(params['bert_model'], do_lower_case=True)

# Read the train and dev examples from the data directory (train first, as before).
train_examples, dev_examples = (
    processor.get_train_examples(params['data_dir']),
    processor.get_dev_examples(params['data_dir']),
)

100%|██████████| 231508/231508 [00:00<00:00, 2597699.64B/s]


In [0]:
# Download the checkpoint and store it as model.pt, the file the model-loading cells read.
!wget -O "model.pt" "https://www.dropbox.com/s/2gclpuhipfovph2/model_baseline_from_parts.pt?dl=0"

In [0]:
# Build the BERT classifier for the configured task and move it to the target device.
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels']).to(params['device'])
# Restore the downloaded weights. `map_location` remaps the stored tensors onto
# the configured device, so a checkpoint saved on a GPU also loads on a
# CPU-only runtime instead of raising.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load("model.pt", map_location=params['device']))

100%|██████████| 407873900/407873900 [00:12<00:00, 32256152.88B/s]


In [0]:
# Save the unquantized encoder weights; the size of this file is the baseline
# used in the compression-rate computation at the end of the notebook.
torch.save(model.bert.encoder.state_dict(), "model-all-enc.pt")

In [0]:
from sklearn.cluster import MiniBatchKMeans
import math
class QuantizedLayer(torch.nn.Module):
    """Linear layer whose weight matrix is stored as a packed scalar codebook.

    Each weight is replaced by the index of one of ``n_clusters`` scalar
    centroids. An index occupies ``n_bits`` bits and ``code_size`` indices are
    packed into every int64 word of ``codes``; the centroid values are kept in
    the ``codes_emb`` embedding table. ``forward`` unpacks the indices, rebuilds
    the dense weight matrix and applies an ordinary linear transform.

    ``codes_emb`` is created with ``Embedding.from_pretrained`` and its default
    arguments, which keep the centroid table frozen; ``codes`` is registered
    with ``requires_grad=False``. Only ``bias`` (when present) is trainable.

    Args:
        layer: module exposing ``weight`` (and optionally ``bias``) whose
            weights are clustered with ``MiniBatchKMeans``. When ``None`` a
            random codebook and random indices are generated and ``size`` is
            required.
        n_clusters: number of centroids in the codebook.
        size: ``(out_features, in_features)`` of the weight matrix; only used
            when ``layer`` is ``None``.

    Raises:
        ValueError: if both ``layer`` and ``size`` are ``None``.
    """

    def __init__(self, layer=None, n_clusters=8,  size=None):
        super(QuantizedLayer, self).__init__()
        # Bits needed per index. Clamped to 1 so a single-centroid codebook
        # does not make the division below fail.
        self.n_bits = max(1, math.ceil(math.log2(n_clusters)))
        # Indices per int64 word; 63 usable bits keeps every packed word non-negative.
        self.code_size = 63 // self.n_bits

        if layer is None:
            if size is None:
                raise ValueError("During random init, size must be passed.")
            # Normalise to torch.Size so forward() can call .numel() on it.
            self.matrix_size = torch.Size(size)

            centroids = torch.randn(n_clusters).view(-1, 1)
            centroids_idx = torch.randint(
                low=0, high=n_clusters, size=self.matrix_size).view(-1)
            self.bias = torch.nn.Parameter(torch.randn(self.matrix_size[0]))
        else:
            self.matrix_size = layer.weight.size()
            # Cluster the weights as 1-D points: one scalar centroid per cluster.
            algo = MiniBatchKMeans(n_clusters)
            points = layer.weight.detach().reshape(-1, 1).cpu().numpy()
            algo.fit(points)

            centroids = torch.Tensor(algo.cluster_centers_)
            centroids_idx = torch.LongTensor(algo.predict(points))

            # A layer built with bias=False still has a `bias` attribute set to
            # None, so test the value rather than the attribute's existence.
            source_bias = getattr(layer, 'bias', None)
            if source_bias is not None:
                self.bias = torch.nn.Parameter(source_bias)
            else:
                self.register_parameter('bias', None)

        # Pad the index stream with zeros up to a whole number of words, then
        # pack each row of `code_size` indices into one integer.
        pad = torch.zeros(-len(centroids_idx) % self.code_size).long()
        to_code = torch.cat([centroids_idx, pad]).view(-1, self.code_size)

        self.codes = torch.nn.Parameter(
            torch.sum(to_code.long() * self._place_values(), dim=-1),
            requires_grad=False)
        self.codes_emb = torch.nn.Embedding.from_pretrained(centroids)

    def _place_values(self, device=None):
        """Return ``(2 ** n_bits) ** i`` for every slot ``i`` of a packed word (int64)."""
        base = 2 ** self.n_bits
        return torch.tensor([base ** i for i in range(self.code_size)],
                            dtype=torch.long, device=device)

    def forward(self, input_):
        """Rebuild the dense weight from the packed codes and apply the linear map."""
        base = 2 ** self.n_bits
        place_values = self._place_values(self.codes.device)
        # Slot i of each word is (word // base**i) % base.
        decoded = (self.codes.view(-1, 1) // place_values) % base
        # Drop the zero padding appended when packing.
        decoded = decoded.reshape(-1)[:self.matrix_size.numel()]
        weight = self.codes_emb(decoded).view(self.matrix_size)
        return torch.nn.functional.linear(input_, weight, self.bias)

Without training:

In [0]:
device = params['device']
n_clusters = 4
# Replace the six dense projections of each of the 12 encoder layers with a
# quantized copy, in the order query, key, value, attention output,
# intermediate, output.
for layer_idx in tqdm(range(12)):
    encoder_block = model.bert.encoder.layer[layer_idx]
    dense_slots = (
        (encoder_block.attention.self, 'query'),
        (encoder_block.attention.self, 'key'),
        (encoder_block.attention.self, 'value'),
        (encoder_block.attention.output, 'dense'),
        (encoder_block.intermediate, 'dense'),
        (encoder_block.output, 'dense'),
    )
    for owner, slot in dense_slots:
        quantized = QuantizedLayer(getattr(owner, slot), n_clusters)
        setattr(owner, slot, quantized.to(device))

100%|██████████| 12/12 [10:29<00:00, 51.52s/it]


In [0]:
# Evaluate the current model on the dev examples; the bare `result` on the
# last line displays the returned metrics as the cell output.
result, prob_preds = evaluate(model, tokenizer, params,
                              dev_examples)
result

***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:19<00:00,  5.69it/s]


{'eval_accuracy': 0.8555045871559633,
 'eval_f1_score': 0.8545034642032333,
 'eval_loss': 0.3484584675960824,
 'eval_matthews_corrcoef': 0.7120853635082497}

With training:

In [0]:
# Start again from a fresh, unquantized classifier on the target device.
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels']).to(params['device'])
# Restore the downloaded weights. `map_location` remaps the stored tensors onto
# the configured device, so a checkpoint saved on a GPU also loads on a
# CPU-only runtime instead of raising.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load("model.pt", map_location=params['device']))

In [0]:
# Encoder layer indices (0-11) split into three groups of four. The training
# cell quantizes one group at a time and runs a training pass after each group.
# NOTE(review): the rationale for this particular ordering is not stated in the notebook.
blocks = [
    [6, 3, 7,  8],
    [4, 5, 0, 11],
    [9, 2, 10, 1]
]

In [0]:
device = params['device']
n_clusters = 4
for stage_idx, stage_layers in enumerate(blocks):
    # Quantize the six dense projections of every encoder layer in this stage,
    # in the order query, key, value, attention output, intermediate, output.
    for layer_idx in tqdm(stage_layers):
        encoder_block = model.bert.encoder.layer[layer_idx]
        dense_slots = (
            (encoder_block.attention.self, 'query'),
            (encoder_block.attention.self, 'key'),
            (encoder_block.attention.self, 'value'),
            (encoder_block.attention.output, 'dense'),
            (encoder_block.intermediate, 'dense'),
            (encoder_block.output, 'dense'),
        )
        for owner, slot in dense_slots:
            quantized = QuantizedLayer(getattr(owner, slot), n_clusters)
            setattr(owner, slot, quantized.to(device))

    # Train for one epoch after each stage; the stage index names the checkpoint.
    EPOCH_NUM = stage_idx
    params['num_train_epochs'] = 1
    checkpoint_name = 'model_{}_epoch_{}.pth'.format(
        params['task_name'], EPOCH_NUM)
    checkpoint_files = dict(
        config='bert_config.json',
        file_to_save=checkpoint_name,
    )

    model, result = train(
        model, tokenizer, params,
        train_examples,
        valid_examples=dev_examples,
        checkpoint_files=checkpoint_files,
    )
    print(result)

100%|██████████| 4/4 [03:25<00:00, 51.27s/it]


***** Running training *****
Num examples: 67349
Batch size:   32
Num steps:    2104


Iteration:   0%|          | 0/2105 [00:00<?, ?it/s]


Epoch: 1


Iteration: 100%|██████████| 2105/2105 [49:54<00:00,  1.29s/it]


{'train_loss': 0.1246670034977291, 'train_global_step': 2105}
***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:16<00:00,  6.74it/s]


{'eval_loss': 0.2155188312978989, 'eval_accuracy': 0.9151376146788991, 'eval_f1_score': 0.9174107142857143, 'eval_matthews_corrcoef': 0.8303008389339931}


  0%|          | 0/4 [00:00<?, ?it/s]

{'train_loss': 0.1246670034977291, 'train_global_step': 2105}


100%|██████████| 4/4 [03:25<00:00, 50.94s/it]


***** Running training *****
Num examples: 67349
Batch size:   32
Num steps:    2104


Iteration:   0%|          | 0/2105 [00:00<?, ?it/s]


Epoch: 1


Iteration: 100%|██████████| 2105/2105 [46:07<00:00,  1.20s/it]


{'train_loss': 0.07454932219837737, 'train_global_step': 2105}
***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:17<00:00,  6.22it/s]


{'eval_loss': 0.24157569218345373, 'eval_accuracy': 0.9197247706422018, 'eval_f1_score': 0.9220489977728284, 'eval_matthews_corrcoef': 0.8395490166354784}


  0%|          | 0/4 [00:00<?, ?it/s]

{'train_loss': 0.07454932219837737, 'train_global_step': 2105}


100%|██████████| 4/4 [03:19<00:00, 50.01s/it]


***** Running training *****
Num examples: 67349
Batch size:   32
Num steps:    2104


Iteration:   0%|          | 0/2105 [00:00<?, ?it/s]


Epoch: 1


Iteration: 100%|██████████| 2105/2105 [42:15<00:00,  1.10s/it]


{'train_loss': 0.07110275359483625, 'train_global_step': 2105}
***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:19<00:00,  5.65it/s]


{'eval_loss': 0.25014348505286876, 'eval_accuracy': 0.9151376146788991, 'eval_f1_score': 0.9185022026431717, 'eval_matthews_corrcoef': 0.8309517687830787}
{'train_loss': 0.07110275359483625, 'train_global_step': 2105}


In [0]:
# Evaluate the model after the staged quantize-and-train loop; the bare
# `result` on the last line displays the returned metrics as the cell output.
result, prob_preds = evaluate(model, tokenizer, params,
                              dev_examples)
result

***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:19<00:00,  5.72it/s]


{'eval_accuracy': 0.9151376146788991,
 'eval_f1_score': 0.9185022026431717,
 'eval_loss': 0.25014348505286876,
 'eval_matthews_corrcoef': 0.8309517687830787}

In [0]:
# Save the encoder weights of the 4-cluster model after training.
# NOTE(review): the `ls -lS` listing further down shows model4_enc.pt / model4.pt
# rather than this file name — confirm which file the reported sizes refer to.
torch.save(model.bert.encoder.state_dict(), 'model-4-tr-enc.pt')

8 clusters:

In [0]:
# Start again from a fresh, unquantized classifier for the 8-cluster run.
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels']).to(params['device'])
# Restore the downloaded weights. `map_location` remaps the stored tensors onto
# the configured device, so a checkpoint saved on a GPU also loads on a
# CPU-only runtime instead of raising.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load("model.pt", map_location=params['device']))

In [0]:
device = params['device']
n_clusters = 8
# Replace the six dense projections of each of the 12 encoder layers with a
# quantized copy, in the order query, key, value, attention output,
# intermediate, output.
for layer_idx in tqdm(range(12)):
    encoder_block = model.bert.encoder.layer[layer_idx]
    dense_slots = (
        (encoder_block.attention.self, 'query'),
        (encoder_block.attention.self, 'key'),
        (encoder_block.attention.self, 'value'),
        (encoder_block.attention.output, 'dense'),
        (encoder_block.intermediate, 'dense'),
        (encoder_block.output, 'dense'),
    )
    for owner, slot in dense_slots:
        quantized = QuantizedLayer(getattr(owner, slot), n_clusters)
        setattr(owner, slot, quantized.to(device))

100%|██████████| 12/12 [10:31<00:00, 53.45s/it]


In [0]:
# Save the encoder weights of the 8-cluster quantized model (no training pass was run on it).
torch.save(model.bert.encoder.state_dict(), 'model8_enc.pt')

In [0]:
# List the working directory sorted by file size to compare the saved checkpoints.
!ls -lS

total 843292
-rw-r--r-- 1 root root 437982975 Apr  9 14:21 model.pt
-rw-r--r-- 1 root root 340259226 Apr  9 16:55 model-all-enc.pt
-rw-r--r-- 1 root root  32895719 Apr  9 17:11 model8_enc.pt
-rw-r--r-- 1 root root  22456619 Apr  9 16:54 model4_enc.pt
-rw-r--r-- 1 root root  22456619 Apr  9 16:54 model4.pt
-rw-r--r-- 1 root root   7439277 May  2  2018 SST-2.zip
drwxr-xr-x 6 root root      4096 Apr  9 14:20 lilbert
drwxr-xr-x 1 root root      4096 Apr  4 20:20 sample_data
drwxrwxr-x 3 root root      4096 May  2  2018 SST-2


In [3]:
# File sizes in bytes, copied by hand from the `ls -lS` listing:
# the full encoder checkpoint versus each quantized encoder checkpoint.
baseline_bytes = 340259226
report = [
    ("Compression rate of encoder, 8 clusters: {}", 32895719),
    ("Compression rate of encoder, 4 clusters: {}", 22456619),
]
for template, quantized_bytes in report:
    print(template.format(baseline_bytes / quantized_bytes))

Compression rate of encoder, 8 clusters: 10.343571636175517
Compression rate of encoder, 4 clusters: 15.151845698588911
