In [5]:
!wget -O "SST-2.zip" 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8'
!unzip "SST-2.zip"
!git clone https://username:password@github.com/yandexdataschool/lilbert.git
    
!pip install -r lilbert/requirements.txt
!mkdir ./lilbert/output



In [4]:
import sys
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2
# Add the inner directory of the cloned repository to the import path
# (presumably where the `lib` package imported below lives — confirm against
# the repository layout).
sys.path.append('lilbert/lilbert')

import numpy as np
import random
import torch
from tqdm import tqdm
from pytorch_pretrained_bert.tokenization import BertTokenizer

from lib import data_processors, tasks
from lib.bert import BertForSequenceClassification
from lib.train_eval import train, evaluate, predict

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [5]:
# Restrict CUDA for this kernel to the GPU with index 1.
# NOTE(review): on a single-GPU host this would hide the only device and the
# 'cuda'/'cpu' selection in `params` would fall back to CPU — confirm the index.
%env CUDA_VISIBLE_DEVICES=1

# Single configuration mapping consumed by the cells below.
params = dict(
    # Locations of the dataset, outputs and the pretrained-model cache.
    # NOTE(review): the setup cell creates ./lilbert/output while this points
    # at ../output — confirm which directory is intended.
    data_dir='SST-2',
    output_dir='../output',
    cache_dir='../model_cache',
    # Task identifier and pretrained checkpoint name.
    task_name='sst2',
    bert_model='bert-base-uncased',
    # Sequence length and batch sizes.
    max_seq_length=128,
    train_batch_size=32,
    eval_batch_size=8,
    # Optimisation settings.
    learning_rate=2e-5,
    warmup_proportion=0.1,
    num_train_epochs=1,
    # Seed shared by all random number generators seeded below.
    seed=1331,
    # Use CUDA when torch reports it as available, otherwise the CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

# Seed Python, NumPy and torch with the same value.
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

env: CUDA_VISIBLE_DEVICES=1


<torch._C.Generator at 0x7f38e3e15190>

In [0]:
# Shorthand for the torch device selected in `params`.
device = params['device']

In [0]:
# Look up the label inventory registered for the configured task and store it
# in the shared configuration.
params.update(
    num_labels=tasks.num_labels[params['task_name']],
    label_list=tasks.label_lists[params['task_name']],
)

# Instantiate the task's data processor and the tokenizer that matches the
# pretrained checkpoint (input is lower-cased).
processor = tasks.processors[params['task_name']]()
tokenizer = BertTokenizer.from_pretrained(params['bert_model'], do_lower_case=True)

# Read the train and dev splits from the data directory.
train_examples = processor.get_train_examples(params['data_dir'])
dev_examples = processor.get_dev_examples(params['data_dir'])

# Build the classifier from the pretrained checkpoint, then place it on the
# configured device.
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels'],
)
model = model.to(params['device'])

In [8]:
# Download the baseline checkpoint from Dropbox and store it as ./model.pt
# (-O overwrites any existing file of that name).
!wget -O "model.pt" "https://www.dropbox.com/s/2gclpuhipfovph2/model_baseline_from_parts.pt?dl=0"

--2019-04-05 20:37:28--  https://www.dropbox.com/s/2gclpuhipfovph2/model_baseline_from_parts.pt?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.82.1, 2620:100:6032:1::a27d:5201
Connecting to www.dropbox.com (www.dropbox.com)|162.125.82.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/2gclpuhipfovph2/model_baseline_from_parts.pt [following]
--2019-04-05 20:37:29--  https://www.dropbox.com/s/raw/2gclpuhipfovph2/model_baseline_from_parts.pt
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc2f2886dba38b4c2cc320f3e590.dl.dropboxusercontent.com/cd/0/inline/AeezP0fXSZA-pvB1D8L5rzBgYNxkERs757-0yXPnqMSvFE9BW4gdSvrTAS3d0WnGH1kDlozB5ZhQ__bI6NS--TkOt7Go09o_ORHglehhaaC4-Q/file# [following]
--2019-04-05 20:37:29--  https://uc2f2886dba38b4c2cc320f3e590.dl.dropboxusercontent.com/cd/0/inline/AeezP0fXSZA-pvB1D8L5rzBgYNxkERs757-0yXPnqMSvFE9BW4gdSvrTAS3d0WnGH1kDlozB5ZhQ__b

In [0]:
# Restore the downloaded weights into the model. map_location remaps tensors
# that were saved on another device (e.g. a GPU checkpoint opened on a
# CPU-only machine) onto the device selected in `params`, instead of failing
# while the file is deserialised.
model.load_state_dict(torch.load("model.pt", map_location=params['device']))

To reload the initial model:

In [0]:
# Rebuild the classifier from the pretrained checkpoint and overwrite its
# weights with the downloaded baseline.
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels']).to(params['device'])
# map_location remaps tensors saved on another device onto the configured one,
# so a GPU-saved checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load("model.pt", map_location=params['device']))

Experiment

In [0]:
from sklearn.cluster import MiniBatchKMeans
class QuantizedLayer(torch.nn.Module):
    """Linear layer whose weight matrix is stored as packed centroid indices.

    Every weight is replaced by the index of one of ``n_clusters`` scalar
    centroids. Two consecutive indices are packed into a single ``uint8``
    code (``first + radix * second``) and the dense weight matrix is rebuilt
    from the codebook on every forward pass.

    Args:
        layer: Module exposing ``weight`` and ``bias`` (for example
            ``torch.nn.Linear``) whose weights are clustered with
            ``MiniBatchKMeans``. When ``None`` a randomly initialised layer of
            shape ``size`` is created instead, e.g. as a placeholder whose
            tensors are later overwritten by ``load_state_dict``.
        n_clusters: Number of centroids in the codebook, between 1 and 16.
        size: Shape of the weight matrix, ``(out_features, in_features)``.
            Required only when ``layer`` is ``None``.

    Raises:
        ValueError: if ``n_clusters`` is outside 1..16, or if both ``layer``
            and ``size`` are ``None``.
    """

    # Two indices share one byte, so each index has at most 4 bits (16 values).
    _MAX_CLUSTERS = 16

    def __init__(self, layer=None, n_clusters=8, size=None):
        super(QuantizedLayer, self).__init__()
        if not 1 <= n_clusters <= self._MAX_CLUSTERS:
            raise ValueError(
                "n_clusters must be between 1 and %d, got %r."
                % (self._MAX_CLUSTERS, n_clusters))
        # Base 8 is kept for codebooks of up to 8 entries so that checkpoints
        # written with the original packing decode unchanged; larger codebooks
        # need base 16 to keep the two indices of a pair separable.
        self._radix = 8 if n_clusters <= 8 else 16

        if layer is None:
            if size is None:
                raise ValueError("During random init, size must be passed.")
            self.matrix_size = torch.Size(size)
            centroids = torch.randn(n_clusters).view(-1, 1)
            centroids_idx = torch.randint(low=0, high=n_clusters, size=self.matrix_size)
            self.bias = torch.nn.Parameter(torch.randn(self.matrix_size[0]))
        else:
            self.matrix_size = layer.weight.size()
            # Cluster the individual scalar weights (one 1-D point per weight).
            points = layer.weight.detach().reshape(-1, 1).cpu().numpy()
            algo = MiniBatchKMeans(n_clusters)
            algo.fit(points)
            centroids = torch.tensor(algo.cluster_centers_, dtype=torch.float32)
            centroids_idx = torch.as_tensor(algo.predict(points), dtype=torch.long)
            if layer.bias is None:
                # Keep "no bias" as no bias instead of an empty parameter.
                self.register_parameter('bias', None)
            else:
                self.bias = torch.nn.Parameter(layer.bias)

        # Stored as a non-trainable parameter so the codes travel with
        # state_dict() and with .to(device).
        self.codes = torch.nn.Parameter(self._pack(centroids_idx, self._radix),
                                        requires_grad=False)
        # Codebook: row i holds the scalar value of centroid i.
        self.emb = torch.nn.Embedding.from_pretrained(centroids)

    @staticmethod
    def _pack(indices, radix):
        """Pack consecutive index pairs into uint8 codes (first + radix * second)."""
        flat = indices.reshape(-1).long()
        if flat.numel() % 2:
            # Pad with a dummy index so the last weight still forms a pair;
            # forward() discards the padding again.
            flat = torch.cat((flat, flat.new_zeros(1)))
        pairs = flat.view(-1, 2)
        return (pairs[:, 0] + radix * pairs[:, 1]).to(torch.uint8)

    def forward(self, input_):
        """Rebuild the dense weight from the codes and apply the linear map."""
        # All tensors are derived from self.codes, so the computation follows
        # the module's own device rather than a notebook-level global.
        codes = self.codes.long()
        # Low digit = first index of each pair, high digit = second index.
        indices = torch.stack((codes % self._radix, codes // self._radix), dim=-1).reshape(-1)
        # Drop the padding index that _pack adds for odd-sized matrices.
        indices = indices[: self.matrix_size.numel()]
        weight = self.emb(indices).view(self.matrix_size)
        return torch.nn.functional.linear(input_, weight, self.bias)

In [0]:
# Start from a fresh copy of the baseline so the quantisation below is applied
# to unmodified dense layers.
model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels']).to(params['device'])
# map_location remaps tensors saved on another device onto the configured one,
# so a GPU-saved checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load("model.pt", map_location=params['device']))

In [14]:
# Replace the six dense projections of each of the 12 encoder layers with a
# QuantizedLayer fitted on the layer's current weights.
n_clusters = 8
for transformer_layer_ind in tqdm(range(12)):
    encoder_block = model.bert.encoder.layer[transformer_layer_ind]
    # (owning module, attribute name) pairs, in the order they are quantised.
    dense_slots = (
        (encoder_block.attention.self, 'query'),
        (encoder_block.attention.self, 'key'),
        (encoder_block.attention.self, 'value'),
        (encoder_block.attention.output, 'dense'),
        (encoder_block.intermediate, 'dense'),
        (encoder_block.output, 'dense'),
    )
    for owner, attr_name in dense_slots:
        setattr(owner, attr_name,
                QuantizedLayer(getattr(owner, attr_name), n_clusters))

100%|██████████| 12/12 [09:39<00:00, 49.21s/it]


In [0]:
# Move the model, including the QuantizedLayer modules that were swapped in,
# onto the configured device.
model = model.to(params['device'])

In [17]:
# Evaluate the quantized model on the dev examples; the bare `result` on the
# last line makes the notebook display the returned metrics.
result, prob_preds = evaluate(model, tokenizer, params,
                              dev_examples)
result

***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:17<00:00,  6.34it/s]


{'eval_loss': 0.20849375055802222,
 'eval_accuracy': 0.9277522935779816,
 'eval_f1_score': 0.9288135593220338,
 'eval_matthews_corrcoef': 0.8554944362755638}

In [0]:
# Write the quantized model's state dict to disk so its file size can be
# compared with model.pt.
torch.save(model.state_dict(), 'model_quant_all.pt')

In [19]:
# Long listing of the working directory, sorted by file size (largest first).
!ls -lS

total 572444
-rw-r--r-- 1 root root 437982975 Apr  5 20:37 model.pt
-rw-r--r-- 1 root root 140731981 Apr  5 20:52 model_quant_all.pt
-rw-r--r-- 1 root root   7439277 May  2  2018 SST-2.zip
drwxr-xr-x 6 root root      4096 Apr  5 19:35 lilbert
drwxr-xr-x 1 root root      4096 Mar 27 20:26 sample_data
drwxrwxr-x 3 root root      4096 May  2  2018 SST-2


In [20]:
# Compression ratio on disk: size of model.pt divided by the size of
# model_quant_all.pt, both in bytes as reported by the `ls -lS` listing.
437982975 / 140731981

3.112177998830273

Load model

In [30]:
# Build a model with the same architecture as the saved quantized model: the
# six dense projections of each of the 12 encoder layers are replaced by
# randomly initialised QuantizedLayer placeholders of matching shape.
new_model = BertForSequenceClassification.from_pretrained(
    params['bert_model'],
    cache_dir=params['cache_dir'],
    num_labels=params['num_labels'],
)
new_model = new_model.to(params['device'])
n_clusters = 8
for transformer_layer_ind in tqdm(range(12)):
    encoder_block = new_model.bert.encoder.layer[transformer_layer_ind]
    # (owning module, attribute name) pairs, in the order they are replaced.
    dense_slots = (
        (encoder_block.attention.self, 'query'),
        (encoder_block.attention.self, 'key'),
        (encoder_block.attention.self, 'value'),
        (encoder_block.attention.output, 'dense'),
        (encoder_block.intermediate, 'dense'),
        (encoder_block.output, 'dense'),
    )
    for owner, attr_name in dense_slots:
        # Only the shape of the existing dense weight is reused.
        dense_shape = getattr(owner, attr_name).weight.size()
        setattr(owner, attr_name,
                QuantizedLayer(size=dense_shape, n_clusters=n_clusters))

100%|██████████| 12/12 [00:03<00:00,  3.67it/s]


In [0]:
# Fill the placeholder model with the saved quantized weights. map_location
# remaps tensors saved on another device onto the configured one, so a
# checkpoint written on a GPU also loads on a CPU-only machine.
new_model.load_state_dict(
    torch.load('model_quant_all.pt', map_location=params['device']))

In [0]:
# Move the reloaded model, including its QuantizedLayer modules, onto the
# configured device.
new_model=new_model.to(params['device'])

In [34]:
# Evaluate the model restored from model_quant_all.pt on the same dev
# examples; the bare `result` on the last line displays the returned metrics.
result, prob_preds = evaluate(new_model, tokenizer, params,
                              dev_examples)
result

***** Running evaluation *****
Num examples:  872
Batch size:    8


Evaluating: 100%|██████████| 109/109 [00:17<00:00,  6.34it/s]


{'eval_loss': 0.20849375055802222,
 'eval_accuracy': 0.9277522935779816,
 'eval_f1_score': 0.9288135593220338,
 'eval_matthews_corrcoef': 0.8554944362755638}