# Classifier. Fine-tuning a pretrained network

The project tackles the task from an active Kaggle competition:
# Contradictory, My Dear Watson
Detecting contradiction and entailment in multilingual text

The task is classification. The input is a pair of sentences, and the goal is to decide whether the second follows from the first, contradicts it, or is unrelated to it. The three classes are: 1) entailment: the second sentence is linked to the first, complements it, and does not contradict it; 2) contradiction: the second sentence contradicts the first; 3) neutral: the sentences are unrelated.
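As a small illustration of the label scheme, the sketch below maps the integer values of the `label` column to class names. The mapping 0 = entailment, 1 = neutral, 2 = contradiction comes from the competition's data description, not from this notebook, so treat it as an assumption to verify; the names `LABEL_NAMES` and `label_name` are hypothetical.

```python
# Assumed mapping from the competition's data description (not stated in this notebook).
LABEL_NAMES = {0: "entailment", 1: "neutral", 2: "contradiction"}


def label_name(label: int) -> str:
    """Return the class name for an integer label."""
    return LABEL_NAMES[label]


print(label_name(2))
```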

DistilBertModel was chosen for training. It is a smaller, distilled version of BERT that still produces contextual embeddings, so context and the relations between words are taken into account. We train the model to distinguish 3 classes. Training such models, even heavily lightened ones, is very slow. We train for 30 epochs with a batch size of 32, the maximum that Colab resources allow.

In [26]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm
from collections import Counter

import pandas as pd
from sklearn.metrics import accuracy_score,  precision_recall_fscore_support

In [27]:
# !pip install transformers

In [28]:
from transformers import pipeline

In [29]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

In [30]:
df = pd.read_csv('train.csv')

In [31]:
df.head()

Unnamed: 0,id,premise,hypothesis,lang_abv,language,label
0,5130fd2cb5,and these comments were considered in formulat...,The rules developed in the interim were put to...,en,English,0
1,5b72532a0b,These are issues that we wrestle with in pract...,Practice groups are not permitted to work on t...,en,English,2
2,3931fbe82a,Des petites choses comme celles-là font une di...,J'essayais d'accomplir quelque chose.,fr,French,0
3,5622f0c60b,you know they can't really defend themselves l...,They can't defend themselves because of their ...,en,English,0
4,86aaa48b45,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...,เด็กสามารถเห็นได้ว่าชาติพันธุ์แตกต่างกันอย่างไร,th,Thai,1


In [32]:
# the model is BERT-based and the task is to classify the relation between two sentences, so we join the premise and the hypothesis with the [SEP] separator token that the tokenizer recognizes
df['text'] = df['premise']+' [SEP] '+df['hypothesis']

In [33]:
df['text'][0]

'and these comments were considered in formulating the interim rules. [SEP] The rules developed in the interim were put together with these comments in mind.'

In [34]:
df.head()

Unnamed: 0,id,premise,hypothesis,lang_abv,language,label,text
0,5130fd2cb5,and these comments were considered in formulat...,The rules developed in the interim were put to...,en,English,0,and these comments were considered in formulat...
1,5b72532a0b,These are issues that we wrestle with in pract...,Practice groups are not permitted to work on t...,en,English,2,These are issues that we wrestle with in pract...
2,3931fbe82a,Des petites choses comme celles-là font une di...,J'essayais d'accomplir quelque chose.,fr,French,0,Des petites choses comme celles-là font une di...
3,5622f0c60b,you know they can't really defend themselves l...,They can't defend themselves because of their ...,en,English,0,you know they can't really defend themselves l...
4,86aaa48b45,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...,เด็กสามารถเห็นได้ว่าชาติพันธุ์แตกต่างกันอย่างไร,th,Thai,1,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...


In [35]:
# split the dataset into training and validation sets
from sklearn.model_selection import train_test_split

In [36]:
df_train, df_val = train_test_split(df, test_size=0.3, random_state = 21)

In [37]:
df_train['label'].value_counts()

0    2907
2    2834
1    2743
Name: label, dtype: int64

The classes are balanced, so no rebalancing is needed.
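To make "balanced" concrete, the following sketch computes the class shares, the accuracy of always predicting the largest class, and the number of batches per epoch. The counts are copied from the `value_counts()` output above and the batch size of 32 from the notebook's introduction; the variable names are illustrative only.

```python
import math

# Class counts copied from the value_counts() output above.
counts = {0: 2907, 2: 2834, 1: 2743}
n_train = sum(counts.values())

# Share of each class and the accuracy of always predicting the largest class.
shares = {label: count / n_train for label, count in counts.items()}
majority_baseline = max(counts.values()) / n_train

# Number of batches (optimizer steps) per epoch for a batch size of 32.
steps_per_epoch = math.ceil(n_train / 32)

print(n_train, round(majority_baseline, 3), steps_per_epoch)
```

With these counts the training set has 8484 rows, a majority-class predictor reaches about 0.343 accuracy, and one epoch is 266 batches. The 0.343 figure is a useful floor when reading accuracy values.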

## Training

In [38]:
class TwitterDataset(torch.utils.data.Dataset):

    def __init__(self, txts, labels):
        self._labels = labels
        
        # use the BERT tokenizer so the token ids match the vocabulary of the pretrained BERT-family model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
        self._txts = [self.tokenizer(text, padding='max_length', max_length=100,
                                     truncation=True, return_tensors="pt")
                      for text in txts]

    def __len__(self):
        return len(self._txts)

    def __getitem__(self, index):
        return self._txts[index], self._labels[index]

In [39]:
# create the data loaders for training and validation

y_train = df_train['label'].values
y_val = df_val['label'].values

train_dataset = TwitterDataset(df_train['text'], y_train)
valid_dataset = TwitterDataset(df_val['text'], y_val)

train_loader = torch.utils.data.DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                          batch_size=32,
                          shuffle=False,
                          num_workers=1)
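The training loop further down calls `.squeeze(1)` on `input_ids`. The sketch below shows where the extra dimension comes from, assuming each dataset item holds a tensor of shape `(1, max_length)`, as a tokenizer call with `return_tensors="pt"` on a single text produces. `ToyEncodedDataset` is a hypothetical stand-in that only mimics those shapes, so no tokenizer download is needed.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyEncodedDataset(Dataset):
    """Hypothetical stand-in that mimics the shapes of per-text tokenizer output."""

    def __init__(self, n_items: int = 8, max_length: int = 100):
        # Each item holds input_ids of shape (1, max_length), like one tokenized text.
        self.items = [{"input_ids": torch.zeros(1, max_length, dtype=torch.long)}
                      for _ in range(n_items)]
        self.labels = [i % 3 for i in range(n_items)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index], self.labels[index]


loader = DataLoader(ToyEncodedDataset(), batch_size=4)
batch_inputs, batch_labels = next(iter(loader))

# Default collation stacks the per-item tensors along a new leading batch axis.
stacked_shape = tuple(batch_inputs["input_ids"].shape)
# Dropping the singleton axis gives the (batch, seq_len) layout the encoder expects.
squeezed_shape = tuple(batch_inputs["input_ids"].squeeze(1).shape)
print(stacked_shape, squeezed_shape)
```

An alternative design is to tokenize without `return_tensors` inside the dataset and convert to tensors in a collate function, which avoids the singleton axis altogether.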

In [40]:
from torch import nn
from transformers import DistilBertModel


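class BertClassifier(nn.Module):

    def __init__(self, dropout=0.1):
        super().__init__()
        # pretrained multilingual DistilBERT encoder (hidden size 768)
        self.bert = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')
        self.dropout = nn.Dropout(dropout)
        # classification head: 768 hidden features -> 3 classes
        self.linear = nn.Linear(768, 3)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        # x: token ids of shape (batch, seq_len); no attention mask is passed
        encoder_output = self.bert(input_ids=x, return_dict=None)
        # the head is applied at every position: (batch, seq_len, 768) -> (batch, seq_len, 3)
        dropout_output = self.dropout(encoder_output.last_hidden_state)
        linear_output = self.linear(dropout_output)
        # NOTE: nn.CrossEntropyLoss expects raw logits; the sigmoid squashes them into (0, 1)
        final_layer = self.sigm(linear_output)
        return final_layer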
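For comparison, here is a minimal sketch of a more conventional classification head: it reads only the first (`[CLS]`) position and returns raw logits, leaving the softmax to `nn.CrossEntropyLoss`. This is an alternative under stated assumptions, not the model trained in this notebook. `ToyEncoder` is a hypothetical stand-in for `DistilBertModel` so that the example runs without downloading weights, and `ClsLogitClassifier` is likewise an illustrative name.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for DistilBertModel: returns (batch, seq_len, hidden) states."""

    def __init__(self, vocab_size: int = 1000, hidden: int = 768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids):
        return self.embed(input_ids)


class ClsLogitClassifier(nn.Module):
    """Classify from the first position and return raw logits."""

    def __init__(self, encoder: nn.Module, hidden: int = 768,
                 n_classes: int = 3, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden, n_classes)

    def forward(self, input_ids):
        hidden_states = self.encoder(input_ids)      # (batch, seq_len, hidden)
        cls_state = hidden_states[:, 0]              # (batch, hidden), first token only
        return self.linear(self.dropout(cls_state))  # (batch, n_classes), no sigmoid

torch.manual_seed(0)
model_sketch = ClsLogitClassifier(ToyEncoder())
input_ids = torch.randint(0, 1000, (4, 100))
labels = torch.tensor([0, 1, 2, 0])

logits = model_sketch(input_ids)
loss = nn.CrossEntropyLoss()(logits, labels)  # applies log-softmax internally
print(tuple(logits.shape), loss.item())
```

With a real encoder one would also pass the tokenizer's `attention_mask`, so that padding positions do not take part in attention.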

In [41]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [42]:
model = BertClassifier().to(device)
criterion = nn.CrossEntropyLoss()

# optimizer = Adam(model.parameters(), lr=0.001)  # full fine-tuning
optimizer = Adam(model.linear.parameters(), lr=0.001)  # train only the classification head


In [43]:
print(model)
print("Parameters full train:", sum([param.nelement() for param in model.parameters()]))
print("Parameters transfer learning:", sum([param.nelement() for param in model.linear.parameters()]))

BertClassifier(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(
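The optimizer above is built from `model.linear.parameters()` only, so the encoder weights are never updated, yet their gradients are still computed on every backward pass. One common way to avoid that cost is to set `requires_grad = False` on the encoder. The sketch below shows the idea on `TinyClassifier`, a hypothetical stand-in that reuses the attribute names `bert` and `linear`; it is not the notebook's model.

```python
import torch.nn as nn
from torch.optim import Adam


class TinyClassifier(nn.Module):
    """Hypothetical stand-in with the same attribute names as BertClassifier."""

    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(16, 768)   # placeholder for the pretrained encoder
        self.linear = nn.Linear(768, 3)  # classification head, same size as in the notebook

    def forward(self, x):
        return self.linear(self.bert(x))


tiny = TinyClassifier()

# Freeze the encoder so autograd skips its gradients entirely.
for param in tiny.bert.parameters():
    param.requires_grad = False

# Hand only the parameters that still require gradients to the optimizer.
trainable = [p for p in tiny.parameters() if p.requires_grad]
head_optimizer = Adam(trainable, lr=0.001)

n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in tiny.parameters())
print(n_trainable, n_total)
```

The head has 768 * 3 + 3 = 2307 parameters, which is all that is being trained in the head-only setup.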

In [183]:
for epoch_num in range(30):
    total_acc_train = 0
    total_loss_train = 0

    model.train()
    for train_input, train_label in tqdm(train_loader):

        # (batch, 1, seq_len) -> (batch, seq_len)
        input_id = train_input['input_ids'].squeeze(1).to(device)
        train_label = train_label.to(device)

        # keep the prediction at the last sequence position: (batch, seq_len, 3) -> (batch, 3);
        # with padding='max_length' this position is [PAD] for inputs shorter than 100 tokens
        output = model(input_id)[:,-1]

        batch_loss = criterion(output, train_label)
        total_loss_train += batch_loss.item()

        acc = (output.argmax(dim=1) == train_label).sum().item()
        total_acc_train += acc

        model.zero_grad()
        batch_loss.backward()
        optimizer.step()

    model.eval()
    total_loss_val, total_acc_val = 0.0, 0.0
    with torch.no_grad():  # no gradients are needed for validation
        for val_input, val_label in valid_loader:
            val_label = val_label.to(device)
            input_id = val_input['input_ids'].squeeze(1).to(device)

            output = model(input_id)[:,-1]

            batch_loss = criterion(output, val_label)
            total_loss_val += batch_loss.item()

            acc = (output.argmax(dim=1) == val_label).sum().item()
            total_acc_val += acc

    # batch_loss is a per-batch mean, so the printed loss is the sum of
    # batch means divided by the number of samples
    print(
        f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_dataset): .3f} '
        f'| Train Accuracy: {total_acc_train / len(train_dataset): .3f} '
        f'| Val Loss: {total_loss_val / len(valid_dataset): .3f} '
        f'| Val Accuracy: {total_acc_val / len(valid_dataset): .3f}')

100%|██████████| 266/266 [01:02<00:00,  4.23it/s]
Epochs: 1 | Train Loss: 0.034 | Train Accuracy: 0.386 | Val Loss: 0.034 | Val Accuracy: 0.354
100%|██████████| 266/266 [01:06<00:00,  4.03it/s]
Epochs: 2 | Train Loss: 0.034 | Train Accuracy: 0.391 | Val Loss: 0.034 | Val Accuracy: 0.356
100%|██████████| 266/266 [01:07<00:00,  3.92it/s]
Epochs: 3 | Train Loss: 0.034 | Train Accuracy: 0.389 | Val Loss: 0.034 | Val Accuracy: 0.360
100%|██████████| 266/266 [01:08<00:00,  3.90it/s]
Epochs: 4 | Train Loss: 0.034 | Train Accuracy: 0.389 | Val Loss: 0.034 | Val Accuracy: 0.364
100%|██████████| 266/266 [01:08<00:00,  3.88it/s]
Epochs: 5 | Train Loss: 0.034 | Train Accuracy: 0.389 | Val Loss: 0.034 | Val Accuracy: 0.363
100%|██████████| 266/266 [01:09<00:00,  3.85it/s]
Epochs: 6 | Train Loss: 0.034 | Train Accuracy: 0.388 | Val Loss: 0.034 | Val Accuracy: 0.366
100%|██████████| 266/266 [01:09<00:00,  3.84it/s]
Epochs: 7 | Train Loss: 0.034 | Train Accuracy: 0.393 | Val Loss: 0.034 | Val Accuracy: 0.366
100%|██████████| 266/266 [01:09<00:00,  3.82it/s]
Epochs: 8 | Train Loss: 0.034 | Train Accuracy: 0.392 | Val Loss: 0.034 | Val Accuracy: 0.359
100%|██████████| 266/266 [01:09<00:00,  3.80it/s]
Epochs: 9 | Train Loss: 0.034 | Train Accuracy: 0.386 | Val Loss: 0.034 | Val Accuracy: 0.369
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 10 | Train Loss: 0.034 | Train Accuracy: 0.392 | Val Loss: 0.034 | Val Accuracy: 0.363
100%|██████████| 266/266 [01:10<00:00,  3.78it/s]
Epochs: 11 | Train Loss: 0.034 | Train Accuracy: 0.388 | Val Loss: 0.034 | Val Accuracy: 0.377
100%|██████████| 266/266 [01:10<00:00,  3.77it/s]
Epochs: 12 | Train Loss: 0.034 | Train Accuracy: 0.395 | Val Loss: 0.034 | Val Accuracy: 0.368
100%|██████████| 266/266 [01:10<00:00,  3.78it/s]
Epochs: 13 | Train Loss: 0.034 | Train Accuracy: 0.396 | Val Loss: 0.034 | Val Accuracy: 0.383
100%|██████████| 266/266 [01:10<00:00,  3.79it/s]
Epochs: 14 | Train Loss: 0.034 | Train Accuracy: 0.400 | Val Loss: 0.034 | Val Accuracy: 0.352
100%|██████████| 266/266 [01:09<00:00,  3.81it/s]
Epochs: 15 | Train Loss: 0.034 | Train Accuracy: 0.396 | Val Loss: 0.034 | Val Accuracy: 0.375
100%|██████████| 266/266 [01:09<00:00,  3.81it/s]
Epochs: 16 | Train Loss: 0.034 | Train Accuracy: 0.397 | Val Loss: 0.034 | Val Accuracy: 0.381
100%|██████████| 266/266 [01:09<00:00,  3.81it/s]
Epochs: 17 | Train Loss: 0.034 | Train Accuracy: 0.387 | Val Loss: 0.034 | Val Accuracy: 0.380
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 18 | Train Loss: 0.034 | Train Accuracy: 0.398 | Val Loss: 0.034 | Val Accuracy: 0.376
100%|██████████| 266/266 [01:09<00:00,  3.85it/s]
Epochs: 19 | Train Loss: 0.034 | Train Accuracy: 0.401 | Val Loss: 0.034 | Val Accuracy: 0.358
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 20 | Train Loss: 0.034 | Train Accuracy: 0.401 | Val Loss: 0.034 | Val Accuracy: 0.372
100%|██████████| 266/266 [01:09<00:00,  3.82it/s]
Epochs: 21 | Train Loss: 0.034 | Train Accuracy: 0.401 | Val Loss: 0.034 | Val Accuracy: 0.377
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 22 | Train Loss: 0.034 | Train Accuracy: 0.398 | Val Loss: 0.034 | Val Accuracy: 0.380
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 23 | Train Loss: 0.034 | Train Accuracy: 0.408 | Val Loss: 0.034 | Val Accuracy: 0.373
100%|██████████| 266/266 [01:09<00:00,  3.82it/s]
Epochs: 24 | Train Loss: 0.034 | Train Accuracy: 0.396 | Val Loss: 0.034 | Val Accuracy: 0.382
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 25 | Train Loss: 0.034 | Train Accuracy: 0.402 | Val Loss: 0.034 | Val Accuracy: 0.371
100%|██████████| 266/266 [01:09<00:00,  3.82it/s]
Epochs: 26 | Train Loss: 0.034 | Train Accuracy: 0.404 | Val Loss: 0.034 | Val Accuracy: 0.381
100%|██████████| 266/266 [01:09<00:00,  3.83it/s]
Epochs: 27 | Train Loss: 0.034 | Train Accuracy: 0.402 | Val Loss: 0.034 | Val Accuracy: 0.381
100%|██████████| 266/266 [01:08<00:00,  3.87it/s]
Epochs: 28 | Train Loss: 0.034 | Train Accuracy: 0.397 | Val Loss: 0.034 | Val Accuracy: 0.382
100%|██████████| 266/266 [01:08<00:00,  3.90it/s]
Epochs: 29 | Train Loss: 0.034 | Train Accuracy: 0.403 | Val Loss: 0.034 | Val Accuracy: 0.379
100%|██████████| 266/266 [01:08<00:00,  3.89it/s]
Epochs: 30 | Train Loss: 0.034 | Train Accuracy: 0.406 | Val Loss: 0.034 | Val Accuracy: 0.386

Conclusion: training takes a long time. Over 30 epochs, train accuracy rose from 0.386 to 0.406 and validation accuracy from 0.354 to 0.386. The trend is positive, but more training time and resources are needed. The classes are balanced, so accuracy is a reasonable measure of model quality.
To improve quality further, the model needs to be trained for considerably longer.
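The printed loss of 0.034 is hard to interpret directly, because the loop sums per-batch mean losses and divides by the number of samples. The sketch below rescales it to a per-batch mean and compares it with two reference values: the cross-entropy of a uniform guess over three classes, and the lowest value reachable when the inputs to `CrossEntropyLoss` are confined to (0, 1) by a sigmoid. The sample count and batch size are taken from the notebook. The rescaling is approximate, since the final batch is smaller than the rest and the printed value is rounded.

```python
import math

n_train, batch_size = 8484, 32               # 2907 + 2834 + 2743 samples; batch size from the loaders
n_batches = math.ceil(n_train / batch_size)  # 266, matching the progress bars

reported_loss = 0.034                        # value printed for every epoch
# The loop divides a sum of per-batch means by the sample count,
# so the mean loss per batch is reported_loss * n_train / n_batches.
mean_batch_loss = reported_loss * n_train / n_batches

# Cross-entropy of a uniform guess over 3 classes.
chance_loss = math.log(3)

# With sigmoid outputs in (0, 1) used as logits, the best case is
# a value of 1 for the true class and 0 for the other two.
sigmoid_floor = -math.log(math.e / (math.e + 2))

print(round(mean_batch_loss, 3), round(chance_loss, 3), round(sigmoid_floor, 3))
```

Under these assumptions the rescaled loss is about 1.08, just below ln 3 ≈ 1.10, while the floor imposed by the sigmoid is about 0.55. This suggests the loss has moved very little from its chance-level value, which may be worth investigating before spending more compute on longer runs.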

In [45]:
torch.save(model.state_dict(), "bert_contradictory.pth")

In [46]:
model1 = BertClassifier().to(device)
# map_location keeps the load working when the checkpoint and the current device differ
model1.load_state_dict(torch.load("bert_contradictory.pth", map_location=device))
criterion = nn.CrossEntropyLoss()

# optimizer = Adam(model1.parameters(), lr=0.001)  # full fine-tuning
optimizer = Adam(model1.linear.parameters(), lr=0.001)  # train only the classification head
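When training continues from a checkpoint, the optimizer has to be built from the parameters of the reloaded module, otherwise its updates go to a different object. The sketch below is a minimal round trip under stated assumptions: `head` is a hypothetical stand-in for the trained model, and the file name `head_sketch.pth` is made up for the example.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.optim import Adam

head = nn.Linear(768, 3)  # hypothetical stand-in for the trained model
path = os.path.join(tempfile.mkdtemp(), "head_sketch.pth")  # hypothetical file name
torch.save(head.state_dict(), path)

restored = nn.Linear(768, 3)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
restored.load_state_dict(torch.load(path, map_location="cpu"))

# The optimizer references the restored module's own parameters.
restored_optimizer = Adam(restored.parameters(), lr=0.001)

# Check that every tensor survived the save/load round trip unchanged.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(head.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```

Note that `state_dict()` of the model stores only the weights. To resume with the same Adam moment estimates, the optimizer's own `state_dict()` would need to be saved and restored as well.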