In [1]:
import sys

In [2]:
# Put the parent directory on the import path so the `src.*` imports in the
# next cell can be resolved (assumes the notebook is run from its own folder).
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate '..' entries in sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
import os
from src.training_utils import train_evaluate, predict_metrics
from src.BertCLF import BertCLF
from transformers import AutoModel
from transformers import AutoTokenizer
import torch
import torch.optim as optim
import torch.nn as nn
import json
from src.preparing_data_utils import prepare_data, prepare_data_notebook, prepare_dataset
import pandas as pd

In [17]:
# Central experiment configuration. The whole dict is handed to the project
# helpers (`prepare_data`, `prepare_data_notebook`, `prepare_dataset`), and
# individual values are read directly by the cells below.
config = dict(
    transformer_model=dict(
        # Checkpoint name passed to AutoTokenizer / AutoModel.from_pretrained.
        model="cointegrated/rubert-tiny",
        path_to_state_dict=False,
        # Use the GPU when one is available, otherwise fall back to the CPU so
        # the notebook still runs on machines without CUDA.
        device='cuda' if torch.cuda.is_available() else 'cpu',
        # Forwarded to BertCLF as `dropout` and `tiny` respectively.
        dropout=0.2,
        tiny_bert=True,
        # Learning rate for the Adam optimizer.
        learning_rate=1e-6,
        # Not read directly in this notebook; presumably consumed by
        # `prepare_dataset` via `config` -- TODO confirm.
        batch_size=8,
        shuffle=True,
        maxlen=512,
    ),
    data=dict(
        train_data_path="../../gvk_dnie_one_list.xlsx",
        test_data_path=None,
        text_column="Комментарий",
        target_column="target",
        # Presumably control the train/validation split inside the data
        # helpers -- TODO confirm.
        random_state=42,
        test_size=0.3,
    ),
    training=dict(
        # Passed to `train_evaluate` as the number of epochs and F1 averaging.
        num_epochs=20,
        average_f1='macro',
        # Directory where the trained weights and the label mapper are saved.
        output_dir="../results/",
    ),
)

In [19]:
# Create the output directory up front; `exist_ok=True` makes this a no-op
# when the directory is already there, so the cell can be re-run safely.
os.makedirs(
    name=config['training']['output_dir'],
    exist_ok=True,
)

In [6]:
# Resolve the torch device named in the config.
device = torch.device(config['transformer_model']['device'])

# Load the tokenizer and the encoder from the same pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(config['transformer_model']["model"])
model_bert = AutoModel.from_pretrained(config['transformer_model']["model"])

# Move the encoder to the selected device.
model_bert = model_bert.to(device)

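There are two ways to prepare the data. Use the `prepare_data` function if you only want to pass
the path to your data (it is taken from `config`). Use the `prepare_data_notebook` function if you
already have the data loaded as a pandas DataFrame and want to pass it directly. You only need to
run one of the two cells below; if you run both, the second overwrites the results of the first.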

In [7]:
# Option 1: hand over only the config and let the helper produce the label
# mapping plus the train/validation texts and targets.
(
    id2label,
    train_texts,
    valid_texts,
    train_targets,
    valid_targets,
) = prepare_data(config=config)

In [8]:
# Option 2: load the spreadsheet ourselves and pass the DataFrame in.
# The path is read from the config rather than repeated as a literal, so the
# two places cannot drift apart.
df = pd.read_excel(config['data']['train_data_path'])

# Rebinds the same five names that the `prepare_data` cell produces.
id2label, train_texts, valid_texts, train_targets, valid_targets = prepare_data_notebook(
    config=config,
    train_df=df
)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return func(*args, **kwargs)


In [9]:
# Build the project's classifier from the pretrained encoder, its tokenizer
# and the label mapping; dropout and the "tiny" flag come from the config.
model = BertCLF(
    device=device,
    tiny=config['transformer_model']['tiny_bert'],
    dropout=config['transformer_model']['dropout'],
    id2label=id2label,
    tokenizer=tokenizer,
    pretrained_model=model_bert,
)

In [10]:
# Move the classifier to the configured device; the result is rebound to
# `model`, which is the object the training and saving cells use.
model = model.to(device)

In [11]:
# Adam over all classifier parameters, with the learning rate from the config.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=float(config['transformer_model']['learning_rate']),
)

# Negative log-likelihood loss.
# NOTE(review): NLLLoss expects log-probabilities as input -- confirm that
# BertCLF's forward pass ends with a log-softmax.
criterion = torch.nn.NLLLoss()

# Build the training and validation iterators from the prepared splits.
training_generator, valid_generator = prepare_dataset(
    config=config,
    tokenizer=tokenizer,
    train_texts=train_texts,
    valid_texts=valid_texts,
    train_targets=train_targets,
    valid_targets=valid_targets,
)

In [13]:
# Run the train/evaluate loop and keep the returned model.
# Note: `model` is rebound here, so re-running this cell continues from the
# already-trained model rather than from a fresh one.
model = train_evaluate(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    training_generator=training_generator,
    valid_generator=valid_generator,
    # The keyword spelling `num_epocs` is kept exactly as in the original call.
    num_epocs=config['training']['num_epochs'],
    average=config['training']['average_f1'],
)

==== Epoch 1 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 34.50it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 286.20it/s]


Train F1: 0.06051076238576238
Eval F1: 0.07499860763018658

==== Epoch 2 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 34.83it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 259.53it/s]


Train F1: 0.1008680054513388
Eval F1: 0.12719808781212288

==== Epoch 3 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 34.96it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 289.61it/s]


Train F1: 0.1406148607284971
Eval F1: 0.21539699115513652

==== Epoch 4 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 34.60it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 279.32it/s]


Train F1: 0.19148884875345915
Eval F1: 0.30956836535783905

==== Epoch 5 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 35.03it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 288.45it/s]


Train F1: 0.25343331271902697
Eval F1: 0.3424190104891859

==== Epoch 6 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 34.97it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 289.86it/s]


Train F1: 0.2739632604648839
Eval F1: 0.3136508780368429

==== Epoch 7 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 34.97it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 289.01it/s]


Train F1: 0.30947806892882657
Eval F1: 0.3212433862433863

==== Epoch 8 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 35.01it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 289.53it/s]


Train F1: 0.305180595218474
Eval F1: 0.34720709602288546

==== Epoch 9 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:03<00:00, 33.77it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 107.68it/s]


Train F1: 0.30203076945501195
Eval F1: 0.3508163928339367

==== Epoch 10 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.69it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 91.31it/s]


Train F1: 0.3215376437535528
Eval F1: 0.3395225648295824

==== Epoch 11 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.65it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 94.08it/s] 


Train F1: 0.33528779862113195
Eval F1: 0.32657455929385754

==== Epoch 12 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.66it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 94.60it/s] 


Train F1: 0.33495664408543196
Eval F1: 0.3454079643553328

==== Epoch 13 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.66it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 91.35it/s]


Train F1: 0.3423952730770913
Eval F1: 0.33387705536828344

==== Epoch 14 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.69it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 90.84it/s]


Train F1: 0.34802808022504994
Eval F1: 0.35937617613056205

==== Epoch 15 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.70it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 90.64it/s]


Train F1: 0.3458022850365923
Eval F1: 0.3747466941765188

==== Epoch 16 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.69it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 91.66it/s]


Train F1: 0.3565292080254201
Eval F1: 0.3842119608786275

==== Epoch 17 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.70it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 94.72it/s]


Train F1: 0.360635495456924
Eval F1: 0.40286249534369833

==== Epoch 18 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.66it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 90.64it/s]


Train F1: 0.386904197093591
Eval F1: 0.39037813388690584

==== Epoch 19 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.66it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 99.12it/s] 


Train F1: 0.37262078690108996
Eval F1: 0.4055525598508055

==== Epoch 20 out of 20 ====


Training loop: 100%|██████████| 132/132 [00:09<00:00, 13.68it/s]
Evaluating loop: 100%|██████████| 57/57 [00:00<00:00, 93.48it/s] 


Train F1: 0.40402298599268294
Eval F1: 0.41354706209969366




Computing final metrics...: 100%|██████████| 57/57 [00:00<00:00, 149.03it/s]

                                          precision    recall  f1-score   support

                             КОМПЕНСАЦИЯ       0.62      0.84      0.71       128
                      ПРОБЛЕМЫ С ЛИМИТОМ       0.00      0.00      0.00        37
                      Платформа СберДруг       0.55      0.52      0.53        93
                          Положительный        0.00      0.00      0.00         7
Процесс неизвестен, не дали использовать       0.00      0.00      0.00        30
                            СОГЛАСОВАНИЕ       0.00      0.00      0.00        33
                 ТАКСИ, ВОДИТЕЛИ, МАШИНЫ       0.54      0.82      0.65       125

                                accuracy                           0.57       453
                               macro avg       0.24      0.31      0.27       453
                            weighted avg       0.44      0.57      0.49       453




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [14]:
# Compute the final metrics report on the validation iterator.
# The training cell's output already ends with an identical report.
predict_metrics(iterator=valid_generator, model=model)

Computing final metrics...: 100%|██████████| 57/57 [00:00<00:00, 150.58it/s]

                                          precision    recall  f1-score   support

                             КОМПЕНСАЦИЯ       0.62      0.84      0.71       128
                      ПРОБЛЕМЫ С ЛИМИТОМ       0.00      0.00      0.00        37
                      Платформа СберДруг       0.55      0.52      0.53        93
                          Положительный        0.00      0.00      0.00         7
Процесс неизвестен, не дали использовать       0.00      0.00      0.00        30
                            СОГЛАСОВАНИЕ       0.00      0.00      0.00        33
                 ТАКСИ, ВОДИТЕЛИ, МАШИНЫ       0.54      0.82      0.65       125

                                accuracy                           0.57       453
                               macro avg       0.24      0.31      0.27       453
                            weighted avg       0.44      0.57      0.49       453




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [20]:
# Persist the trained weights into the configured output directory.
torch.save(
    obj=model.state_dict(),
    f=os.path.join(config['training']['output_dir'], "model"),
)

# Store the model's label mapper next to the weights as JSON;
# `ensure_ascii=False` keeps non-ASCII (Cyrillic) labels readable in the file.
with open(
    os.path.join(config['training']['output_dir'], 'label_mapper.json'),
    mode='w',
    encoding='utf-8',
) as mapper_file:
    json.dump(model.mapper, mapper_file, indent=4, ensure_ascii=False)