# Дистилляция BERT

## Загрузка необходимых модулей

In [2]:
import os
os.listdir('/kaggle/input')

['model-bert1', 'norm-sep']

In [None]:
cd /content/drive/MyDrive/cmc/classif_wp

In [3]:
!pip3 install transformers catboost

[0m

In [3]:
import os
import random
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

from catboost import Pool, CatBoostRegressor

from sklearn.metrics import classification_report, f1_score
from tqdm import tqdm

In [4]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(13)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

device(type='cuda')

## Загрузка токенизатора, модели, конфигурации BERT

In [6]:
# config
config = AutoConfig.from_pretrained('../input/model-bert1/')
# tokenizer
tokenizer = AutoTokenizer.from_pretrained('../input/model-bert1/', pad_to_max_length=True)
# model
model = AutoModelForSequenceClassification.from_pretrained('../input/model-bert1/', config=config)
     

## Подготовка данных

In [7]:
category_index = {'WEB': 0,
 'SECURITY': 1,
 'INTROS': 2,
 'DATA_SCIENCE': 3,
 'ORG': 4,
 'KnD': 5,
 'OPENING': 6,
 'DEVOPS': 7,
 'GAMES': 8,
 'SPECIAL': 9,
 'USER': 10,
 'EVENTS': 11,
 'DEV': 12,
 'MOBILE': 13,
 'SYSADM': 14,
 'MULTIMEDIA': 15,
 'DATABASES': 16,
 'MESSENGERS': 17,
 'DIY': 18,
 'MANAGEMENT': 19,
 'HARDWARE': 20,
 'MISC': 21,
 'LAW': 22,
 'EDUCATION': 23,
 'HISTORY': 24}
reverse_category_index = dict(map(reversed, category_index.items()))

In [8]:
df = pd.read_csv('../input/norm-sep/norm_sep_max_25.csv', index_col=0)
texts = df.document.values
labels = [category_index[i] for i in tqdm(df.category.values)]

100%|██████████| 17018/17018 [00:00<00:00, 2641599.70it/s]


In [9]:
train_texts, test_texts, train_labels, test_labels = train_test_split(list(texts), labels, random_state = 42, train_size = 0.9, stratify = labels)

len(train_texts), len(test_texts)

(15316, 1702)

In [10]:
tokens = [tokenizer.encode(
                      text,                            
                      add_special_tokens=True,       
                      max_length=500,                
                      padding='max_length',         
                      truncation=True,               
                      return_attention_mask=True,  
                      pad_to_max_length='right') for text in tqdm(train_texts)]

100%|██████████| 15316/15316 [00:37<00:00, 407.30it/s]


In [12]:
tokens_tensor = torch.tensor(tokens)

In [13]:
batch_size = 16
sampler = SequentialSampler(tokens_tensor)
dataloader = DataLoader(tokens_tensor, sampler=sampler, batch_size=batch_size)

## Получение логитов BERT

In [14]:
train_logits = []
with torch.no_grad():
    model.to(device)
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        outputs = model(batch)
        logits = outputs[0].detach().cpu().numpy()
        train_logits.extend(logits)

100%|██████████| 958/958 [04:15<00:00,  3.75it/s]


## Обучение ученика

In [15]:
data_pool = Pool(tokens, train_logits)

In [16]:
distilled_model = CatBoostRegressor(iterations=3000, 
                          depth=8, 
                          learning_rate=0.03, 
                          loss_function='MultiRMSE',
                          verbose=100)

In [17]:
distilled_model.fit(data_pool)

0:	learn: 8.8494332	total: 11.6s	remaining: 9h 37m 41s
100:	learn: 8.4593724	total: 18m 19s	remaining: 8h 45m 56s
200:	learn: 8.1159379	total: 36m 18s	remaining: 8h 25m 43s
300:	learn: 7.7984303	total: 54m 30s	remaining: 8h 8m 43s
400:	learn: 7.5210958	total: 1h 12m 12s	remaining: 7h 48m 1s
500:	learn: 7.2618525	total: 1h 29m 49s	remaining: 7h 28m 2s
600:	learn: 7.0104185	total: 1h 47m 37s	remaining: 7h 9m 34s
700:	learn: 6.7676619	total: 2h 5m 21s	remaining: 6h 51m 8s
800:	learn: 6.5483457	total: 2h 22m 55s	remaining: 6h 32m 22s
900:	learn: 6.3381784	total: 2h 40m 32s	remaining: 6h 13m 59s
1000:	learn: 6.1295712	total: 2h 58m 2s	remaining: 5h 55m 33s
1100:	learn: 5.9296991	total: 3h 15m 25s	remaining: 5h 37m 3s
1200:	learn: 5.7513463	total: 3h 32m 52s	remaining: 5h 18m 52s
1300:	learn: 5.5773290	total: 3h 50m 19s	remaining: 5h 47s
1400:	learn: 5.4066444	total: 4h 7m 46s	remaining: 4h 42m 47s
1500:	learn: 5.2462546	total: 4h 25m 3s	remaining: 4h 24m 42s
1600:	learn: 5.0946798	total: 4h

<catboost.core.CatBoostRegressor at 0x75d483ecb310>

## Сравнение качества моделей

In [19]:
test_tokens = [tokenizer.encode(
                      text,                            
                      add_special_tokens=True,       
                      max_length=500,                
                      padding='max_length',         
                      truncation=True,               
                      return_attention_mask=True,  
                      pad_to_max_length='right') for text in tqdm(test_texts)]

100%|██████████| 1702/1702 [00:04<00:00, 414.56it/s]


In [20]:
test_attention_masks = [[float(i>0) for i in sequence] for sequence in tqdm(test_tokens, total=len(test_tokens))]

100%|██████████| 1702/1702 [00:00<00:00, 11125.98it/s]


In [22]:
test_inputs = torch.tensor(test_tokens)

test_labels_ = torch.tensor(test_labels)

test_masks = torch.tensor(test_attention_masks)

In [23]:
batch_size = 16

test_data = TensorDataset(test_inputs, test_masks, test_labels_)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [24]:
model.eval()

predictions, true_labels = [], []

for batch in test_dataloader:
    
    batch = tuple(t.to(device) for t in batch)
  
    b_input_ids, b_input_mask, b_labels = batch

    with torch.no_grad():

        outputs = model(b_input_ids, token_type_ids=None, 
                        attention_mask=b_input_mask)

    logits = outputs[0].detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    batch_preds = np.argmax(logits, axis=1)

    predictions.extend(batch_preds) 
    true_labels.extend(label_ids)

In [25]:
# результаты модели BERT
print("F1-micro           ", f1_score(true_labels, predictions, average="micro"))
print(classification_report(true_labels, predictions, target_names=reverse_category_index.values(), zero_division=True))

F1-micro            0.6692126909518213
              precision    recall  f1-score   support

         WEB       0.70      0.83      0.76       133
    SECURITY       0.73      0.71      0.72       123
      INTROS       0.67      0.42      0.52        19
DATA_SCIENCE       0.78      0.84      0.81        50
         ORG       0.74      0.53      0.62       165
         KnD       0.63      0.81      0.71       230
     OPENING       0.27      0.40      0.32        20
      DEVOPS       0.78      0.71      0.75       119
       GAMES       0.77      0.85      0.81        20
     SPECIAL       0.46      0.70      0.55        91
        USER       0.66      0.67      0.67       109
      EVENTS       0.57      0.65      0.60        20
         DEV       0.78      0.69      0.73       210
      MOBILE       0.77      0.71      0.74        38
      SYSADM       0.63      0.62      0.63       109
  MULTIMEDIA       0.68      0.73      0.70        44
   DATABASES       0.82      0.55      0.6

In [26]:
# результаты модели-ученика
tokens_pool = Pool(test_tokens)

distilled_predicted_logits = distilled_model.predict(tokens_pool, prediction_type='RawFormulaVal') 

In [28]:
pred = np.argmax(distilled_predicted_logits, axis=1)
print("F1-micro           ", f1_score(test_labels, pred, average="micro"))
print(classification_report(test_labels, pred, target_names=reverse_category_index.values()))

F1-micro            0.16862514688601646
              precision    recall  f1-score   support

         WEB       0.23      0.17      0.19       133
    SECURITY       0.00      0.00      0.00       123
      INTROS       0.00      0.00      0.00        19
DATA_SCIENCE       0.20      0.04      0.07        50
         ORG       0.39      0.12      0.19       165
         KnD       0.18      0.84      0.30       230
     OPENING       0.00      0.00      0.00        20
      DEVOPS       0.00      0.00      0.00       119
       GAMES       0.00      0.00      0.00        20
     SPECIAL       0.04      0.11      0.06        91
        USER       0.33      0.01      0.02       109
      EVENTS       0.00      0.00      0.00        20
         DEV       0.17      0.16      0.16       210
      MOBILE       0.00      0.00      0.00        38
      SYSADM       0.17      0.04      0.06       109
  MULTIMEDIA       0.00      0.00      0.00        44
   DATABASES       0.00      0.00      0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
