In [1]:
# Приближение CLS-эмбеддингов (после линейной проекции 256 -> 768) к эмбеддингам sbert_large_mt_nlu_ru
# Предсказание замаскированных токенов (full word masks)
# Дистилляция распределения всех токенов из sbert_large_mt_nlu_ru

In [2]:
# pip install transformers sentencepiece --trusted-host pypi.org --trusted-host files.pythonhosted.org

In [3]:
# pip install torch torchvision --trusted-host pypi.org --trusted-host files.pythonhosted.org

In [4]:
# this is needed for imac to work with BertForPreTraining
import os

os.environ['KMP_DUPLICATE_LIB_OK']='True'

### Tokenizer

In [5]:
from transformers import BertTokenizerFast, BertForPreTraining, BertModel, BertConfig

In [6]:
import pandas as pd

In [7]:
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

In [8]:
X_train=list(train['text'])
y_train=list(train['label'])
X_test=list(test['text'])
y_test=list(test['label'])

In [9]:
tokenizer = BertTokenizerFast.from_pretrained('./rubert-tiny')

In [10]:
from collections import Counter
from tqdm.auto import tqdm, trange

In [11]:
cnt = Counter()
for text in tqdm(X_train):
    cnt.update(tokenizer(str(text))['input_ids'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10088.0), HTML(value='')))




In [12]:
resulting_vocab = {
    tokenizer.vocab[k] for k in tokenizer.special_tokens_map.values()
}

In [13]:
for k, v in cnt.items():
    if v > 3:
        resulting_vocab.add(k)

In [14]:
resulting_vocab = sorted(resulting_vocab)

In [15]:
tokenizer.save_pretrained('./bert_distill2');

In [16]:
inv_voc = {idx: word for word, idx in tokenizer.vocab.items()}

In [17]:
with open('./bert_distill2/vocab.txt', 'w', encoding='utf-8') as f:
    for idx in resulting_vocab:
        f.write(inv_voc[idx] + '\n')

In [18]:
# remove tokenizer.json and resave
! rm -rf bert_distill2/tokenizer.json

In [19]:
tokenizer = BertTokenizerFast.from_pretrained('./bert_distill2')

In [20]:
tokenizer.save_pretrained('./bert_distill2');

### Model

https://huggingface.co/docs/transformers/model_doc/bert

In [21]:
tokenizer_distill = BertTokenizerFast.from_pretrained('./bert_distill2')

In [22]:
config = BertConfig(
    emb_size=256,
    hidden_size=256,
    intermediate_size=256,
    max_position_embeddings=512,
    num_attention_heads=8,
    num_hidden_layers=3,
    vocab_size=tokenizer_distill.vocab_size
)

In [23]:
model = BertForPreTraining(config)

In [24]:
model.save_pretrained('./bert_distill2')

In [25]:
from transformers import BertModel
# load model without CLS-head
teacher = BertForPreTraining.from_pretrained('./rubert-tiny')

In [26]:
tokenizer_teacher = BertTokenizerFast.from_pretrained('./rubert-tiny')

In [27]:
# copy input embeddings accordingly with resulting_vocab
model.bert.embeddings.word_embeddings.weight.data = teacher.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :256].clone()
model.bert.embeddings.position_embeddings.weight.data = teacher.bert.embeddings.position_embeddings.weight.data[:, :256].clone()

In [28]:
# copy output embeddings
model.cls.predictions.decoder.weight.data = teacher.cls.predictions.decoder.weight.data[resulting_vocab, :256].clone()

In [29]:
model.save_pretrained('./bert_distill2')

### Adapter

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [31]:
adapter_emb = torch.nn.Linear(256, 312)

### Train loop for embeddings approximation

In [32]:
from transformers import BertModel

teacher_mse = BertModel.from_pretrained('./rubert-tiny')

Some weights of the model checkpoint at ./rubert-tiny were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [33]:
from itertools import chain

optimizer = torch.optim.Adam(
    params=[p for p in chain(
        model.parameters(), 
        adapter_emb.parameters(),
        ) if p.requires_grad], 
    lr=1e-5
)

In [34]:
batch_size = 16
epochs = 50

In [35]:
from tqdm.auto import tqdm, trange

### MLM

In [36]:
from transformers import DataCollatorForWholeWordMask

data_collator = DataCollatorForWholeWordMask(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.2
)

In [37]:
def get_mask_labels(input_ids, tokenizer, collator):
    mask_labels = []
    for e in input_ids:
        ref_tokens = []
        for idx in e:
            token = tokenizer._convert_id_to_token(idx)
            ref_tokens.append(token)
        mask_labels.append(collator._whole_word_mask(ref_tokens))
    ml = torch.tensor(mask_labels)
    inputs, labels = collator.torch_mask_tokens(input_ids, ml)
    return inputs, labels

In [38]:
def preprocess_inputs(inputs, tokenizer, collator):
    inputs['input_ids'], inputs['labels'] = get_mask_labels(inputs['input_ids'], tokenizer, collator)
    return {k: v.to(model.device) for k, v in inputs.items()}

### Token distribution loss

In [39]:
# vocab mapping
teacher_vocab = tokenizer_teacher.vocab

In [40]:
vocab_mapping = sorted(teacher_vocab[w] for w in tokenizer.vocab)

In [41]:
def loss_kl(inputs, outputs, model, teacher, vocab_mapping, temperature=1.0):
    new_inputs = torch.tensor(
        [[vocab_mapping[i] for i in row] for row in inputs['input_ids']]
    ).to(inputs['input_ids'].device)
    with torch.no_grad():
        teacher_out = teacher(
            input_ids=new_inputs, 
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask']
        )
    # the whole batch, all tokens after the [cls], the whole dimension
    kd_loss = torch.nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(outputs.prediction_logits[:, 1:, :] / temperature, dim=1), 
        F.softmax(teacher_out.prediction_logits[:, 1:, vocab_mapping] / temperature, dim=1)
    ) / outputs.prediction_logits.shape[-1]
    return kd_loss

### Train loop

In [42]:
model.train()
loss_emb = torch.nn.MSELoss()
loss_mlm = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for _ in trange(int(train.shape[0] / batch_size)):
        texts = list(train.text.sample(batch_size))
        texts = [str(text) for text in texts]
        
        # MSE-loss
        input_teacher = {k: v for k, v in tokenizer_teacher(
                texts,
                return_tensors='pt',
                padding=True,
                max_length=64,
                truncation=True
            ).items()}

        with torch.no_grad():
            out_teacher = teacher_mse(**input_teacher)

        embeddings_teacher_norm = torch.nn.functional.normalize(out_teacher.pooler_output)

        input_distill = {k: v for k, v in tokenizer_distill(
                texts,
                return_tensors='pt',
                padding=True,
                max_length=64,
                truncation=True
            ).items()}

        out = model(**input_distill, output_hidden_states=True)
        embeddings = model.bert.pooler(out.hidden_states[-1])
        embeddings_norm = torch.nn.functional.normalize(adapter_emb(embeddings))
        loss = loss_emb(embeddings_norm, embeddings_teacher_norm)
        
        # KL-loss
        loss += loss_kl(input_distill, out, model, teacher, vocab_mapping)
        
        # MLM-loss
        inputs = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=64)
        inputs = preprocess_inputs(inputs, tokenizer_distill, data_collator)
        outputs = model(**inputs, output_hidden_states=True)
        loss += loss_mlm(
                outputs.prediction_logits.view(-1, model.config.vocab_size),
                inputs['labels'].view(-1)
            )
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()

    # evaluate loop
    test_loss = 0    
    for text in test.text:
        input_teacher = {k: v for k, v in tokenizer_teacher(
                str(text),
                return_tensors='pt',
                padding=True,
                max_length=64,
                truncation=True
            ).items()}
        
        with torch.no_grad():
            out_teacher = teacher_mse(**input_teacher)

        embeddings_teacher_norm = torch.nn.functional.normalize(out_teacher.pooler_output)

        input_distill = {k: v for k, v in tokenizer_distill(
                str(text),
                return_tensors='pt',
                padding=True,
                max_length=64,
                truncation=True
            ).items()}
        with torch.no_grad():
            out = model(**input_distill, output_hidden_states=True)
        embeddings = model.bert.pooler(out.hidden_states[-1])
        embeddings_norm = torch.nn.functional.normalize(adapter_emb(embeddings))

        test_loss += loss_emb(embeddings_norm, embeddings_teacher_norm)
    print(epoch, test_loss)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


0 tensor(8.2812, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


1 tensor(8.2388, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


2 tensor(8.2116, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


3 tensor(8.1689, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


4 tensor(8.0650, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


5 tensor(7.8349, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


6 tensor(7.4079, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


7 tensor(6.9094, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


8 tensor(6.5352, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


9 tensor(6.3031, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


10 tensor(6.1642, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


11 tensor(6.0702, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


12 tensor(5.9999, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


13 tensor(5.9080, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


14 tensor(5.8427, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


15 tensor(5.7685, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


16 tensor(5.6715, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


17 tensor(5.6287, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


18 tensor(5.5734, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


19 tensor(5.5358, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


20 tensor(5.4648, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


21 tensor(5.4481, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


22 tensor(5.3889, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


23 tensor(5.3594, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


24 tensor(5.3150, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


25 tensor(5.2651, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


26 tensor(5.2025, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


27 tensor(5.1772, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


28 tensor(5.1118, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


29 tensor(5.0526, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


30 tensor(5.0073, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


31 tensor(4.9525, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


32 tensor(4.9061, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


33 tensor(4.8714, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


34 tensor(4.8589, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


35 tensor(4.7980, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


36 tensor(4.7578, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


37 tensor(4.7487, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


38 tensor(4.7171, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


39 tensor(4.6677, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


40 tensor(4.6599, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


41 tensor(4.6346, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


42 tensor(4.6297, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


43 tensor(4.5930, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


44 tensor(4.5689, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


45 tensor(4.5434, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


46 tensor(4.5285, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


47 tensor(4.5332, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


48 tensor(4.4793, grad_fn=<AddBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=630.0), HTML(value='')))


49 tensor(4.4840, grad_fn=<AddBackward0>)


### Save model

In [43]:
model.save_pretrained('./bert_distill2')

### Train classifier

In [44]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):

  def __init__(self, texts, targets, tokenizer, max_len=512):
    self.texts = texts
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    text = str(self.texts[idx])
    target = self.targets[idx]

    encoding = self.tokenizer.encode_plus(
        str(text),
        add_special_tokens=True,
        max_length=self.max_len,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
        truncation=True
    )

    return {
      'text': text,
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long)
    }

In [45]:
from tqdm import tqdm
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_recall_fscore_support


class BertClassifier:

    def __init__(self, path, n_classes=2):
        self.path = path
        self.model = BertForSequenceClassification.from_pretrained(path, ignore_mismatched_sizes=True)
        self.tokenizer = BertTokenizer.from_pretrained(path)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        self.out_features = self.model.bert.encoder.layer[1].output.dense.out_features
        self.model.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.model.to(self.device)

    
    def preparation(self, X_train, y_train, epochs):
        # create datasets
        self.train_set = CustomDataset(X_train, y_train, self.tokenizer)
        # create data loaders
        self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)
        # helpers initialization
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=2e-5,
            weight_decay=0.005,
            correct_bias=True
            )
        self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=500,
                num_training_steps=len(self.train_loader) * epochs
            )
        self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)


    def fit(self):
        self.model = self.model.train()
        losses = []
        correct_predictions = 0

        for data in tqdm(self.train_loader):
            input_ids = data["input_ids"].to(self.device)
            attention_mask = data["attention_mask"].to(self.device)
            targets = data["targets"].to(self.device)

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
                )

            preds = torch.argmax(outputs.logits, dim=1)
            loss = self.loss_fn(outputs.logits, targets)

            correct_predictions += torch.sum(preds == targets)

            losses.append(loss.item())
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        train_acc = correct_predictions.double() / len(self.train_set)
        train_loss = np.mean(losses)
        return train_acc, train_loss
    

    def train(self, X_train, y_train, X_test, y_test, epochs=1):
        print('*' * 10)
        print(f'Model: {self.path}')
        self.preparation(X_train, y_train, epochs)
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            train_acc, train_loss = self.fit()
            print(f'Train loss {train_loss} accuracy {train_acc}')
            predictions_test = [self.predict(x) for x in X_test]
            precision, recall, f1score = precision_recall_fscore_support(y_test, predictions_test, average='macro')[:3]
            print('Test:')
            print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')
        print('*' * 10)
    
    def predict(self, text):
        self.model = self.model.eval()
        encoding = self.tokenizer.encode_plus(
            str(text),
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        out = {
              'text': text,
              'input_ids': encoding['input_ids'].flatten(),
              'attention_mask': encoding['attention_mask'].flatten()
          }
        
        input_ids = out["input_ids"].to(self.device)
        attention_mask = out["attention_mask"].to(self.device)
        
        outputs = self.model(
            input_ids=input_ids.unsqueeze(0),
            attention_mask=attention_mask.unsqueeze(0)
        )
        
        prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]

        return prediction

In [46]:
import random
torch.manual_seed(42)
random.seed(42)

In [47]:
classifier = BertClassifier(
    path='./bert_distill2',
    n_classes=2
)

Some weights of the model checkpoint at ./bert_distill2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialize

In [48]:
EPOCHS = 20

In [49]:
classifier.train(
        X_train=list(train['text']),
        y_train=list(train['label']),
        X_test=list(test['text']),
        y_test=list(test['label']),
        epochs=EPOCHS
)

  0%|          | 0/5044 [00:00<?, ?it/s]

**********
Model: ./bert_distill2
Epoch 1/20


100%|██████████| 5044/5044 [17:30<00:00,  4.80it/s]


Train loss 0.6683227617041009 accuracy 0.6601903251387787


  _warn_prf(average, modifier, msg_start, len(result))
  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.3353376503237743, recall: 0.5, f1score: 0.4014396456256921
Epoch 2/20


100%|██████████| 5044/5044 [17:49<00:00,  4.71it/s]


Train loss 0.729729297435451 accuracy 0.7191712926249009


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.7681729016406436, recall: 0.7621881053855095, f1score: 0.7649781364306316
Epoch 3/20


100%|██████████| 5044/5044 [17:17<00:00,  4.86it/s]


Train loss 0.7564310771382682 accuracy 0.7555511498810468


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8141613260863687, recall: 0.7813347539713289, f1score: 0.7936747275399385
Epoch 4/20


100%|██████████| 5044/5044 [17:13<00:00,  4.88it/s]


Train loss 0.7300101639184079 accuracy 0.7771609833465504


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8405535613973284, recall: 0.7771638899651299, f1score: 0.7967529998309955
Epoch 5/20


100%|██████████| 5044/5044 [17:15<00:00,  4.87it/s]


Train loss 0.6956161043266408 accuracy 0.799167327517843


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8277556555863342, recall: 0.830066834560248, f1score: 0.8288871837191683
Epoch 6/20


100%|██████████| 5044/5044 [17:17<00:00,  4.86it/s]


Train loss 0.6624808436074341 accuracy 0.8150277557494052


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8351799475323823, recall: 0.8168171251452925, f1score: 0.8247337978504892
Epoch 7/20


100%|██████████| 5044/5044 [17:17<00:00,  4.86it/s]


Train loss 0.6560842920759709 accuracy 0.8194885011895321


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8601009433130565, recall: 0.8165149166989538, f1score: 0.8325768928193025
Epoch 8/20


100%|██████████| 5044/5044 [17:17<00:00,  4.86it/s]


Train loss 0.6049729652261197 accuracy 0.8394131641554322


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8631576266870384, recall: 0.7762727624951569, f1score: 0.8000584083718667
Epoch 9/20


100%|██████████| 5044/5044 [17:42<00:00,  4.75it/s]


Train loss 0.6021924147772983 accuracy 0.8432791435368755


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8621837236374508, recall: 0.842649166989539, f1score: 0.8511372745926852
Epoch 10/20


100%|██████████| 5044/5044 [17:45<00:00,  4.74it/s]


Train loss 0.57517073573522 accuracy 0.8544805709754163


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8685523408806984, recall: 0.8517657884540876, f1score: 0.8592223142409016
Epoch 11/20


100%|██████████| 5044/5044 [17:41<00:00,  4.75it/s]


Train loss 0.5847537535920411 accuracy 0.8546788263283108


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8656447832630203, recall: 0.8588889965129795, f1score: 0.8621037004453731
Epoch 12/20


100%|██████████| 5044/5044 [17:52<00:00,  4.70it/s]


Train loss 0.5427798054291916 accuracy 0.8647898493259318


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8710243499870183, recall: 0.8100561797752809, f1score: 0.8303941257687963
Epoch 13/20


100%|██████████| 5044/5044 [17:44<00:00,  4.74it/s]


Train loss 0.536667596598106 accuracy 0.8679619349722443


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8830348901811591, recall: 0.8332303370786517, f1score: 0.8513734695564252
Epoch 14/20


100%|██████████| 5044/5044 [17:43<00:00,  4.74it/s]


Train loss 0.5272144959559664 accuracy 0.869746233148295


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.874359849425002, recall: 0.8506305695466874, f1score: 0.8607292665388303
Epoch 15/20


100%|██████████| 5044/5044 [17:57<00:00,  4.68it/s]


Train loss 0.5116586265004787 accuracy 0.8755947660586836


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.861475475227245, recall: 0.8642754746222394, f1score: 0.8628448892410597
Epoch 16/20


100%|██████████| 5044/5044 [17:55<00:00,  4.69it/s]


Train loss 0.5090549799665146 accuracy 0.8786677240285488


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8832916807509656, recall: 0.8349922510654786, f1score: 0.8527557948187865
Epoch 17/20


100%|██████████| 5044/5044 [17:47<00:00,  4.72it/s]


Train loss 0.5044575620025782 accuracy 0.8795598731165741


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8767755326597979, recall: 0.8502479659046881, f1score: 0.8613535980248338
Epoch 18/20


100%|██████████| 5044/5044 [17:57<00:00,  4.68it/s]


Train loss 0.496105151217204 accuracy 0.8831284694686756


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8834819125412914, recall: 0.8448740798140255, f1score: 0.8599443191595897
Epoch 19/20


100%|██████████| 5044/5044 [17:45<00:00,  4.73it/s]


Train loss 0.5021169058862165 accuracy 0.880650277557494


  0%|          | 0/5044 [00:00<?, ?it/s]

Test:
precision: 0.8853192350285373, recall: 0.8505046493607129, f1score: 0.8644356457672802
Epoch 20/20


100%|██████████| 5044/5044 [17:46<00:00,  4.73it/s]


Train loss 0.4818639386245342 accuracy 0.8856066613798572
Test:
precision: 0.8811343031836391, recall: 0.8501975978302982, f1score: 0.862833836858006
**********
