In [25]:
!nvidia-smi

Sun Nov  3 23:11:09 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   42C    P0    65W / 300W |  27725MiB / 32478MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   40C    P0    51W / 300W |   1920MiB / 32478MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-DGXS...  On   | 00000000:0E:00.0 Off |                    0 |
| N/A   

In [26]:
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import sys

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import logging


logger = logging.getLogger('sequence_tagger_bert')

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

logger.handlers = []

fhandler = logging.handlers.TimedRotatingFileHandler(filename='logs.txt', when='midnight')
fhandler.setFormatter(formatter)
logger.addHandler(fhandler)

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.setLevel(logging.DEBUG)

In [28]:
import torch

device = torch.device('cuda')
n_gpu = torch.cuda.device_count()

for i in range(n_gpu):
    print(torch.cuda.get_device_name(i))

Tesla V100-DGXS-32GB


In [29]:
CACHE_DIR = '../workdir/cache'
BATCH_SIZE = 16
#BATCH_SIZE = 8
PRED_BATCH_SIZE = 1000
MAX_LEN = 128
MAX_N_EPOCHS = 20
#MAX_N_EPOCHS = 100
#MAX_N_EPOCHS = 50
#MAX_N_EPOCHS = 10
REDUCE_ON_PLATEAU = False
WEIGHT_DECAY = 0.01
LEARNING_RATE = 3e-7
#LEARNING_RATE = 1e-5
#LEARNING_RATE = 2e-5

In [9]:
!ls /workspace/bert_sequence_tagger/src/data/NER/Varvara_v1/

dev_aspect.tsv	   test_aspect.tsv     train_aspect.tsv
dev_pred_full.tsv  test_pred_full.tsv  train_pred_full.tsv


In [30]:
from flair.datasets import ColumnCorpus


data_folder = '/workspace/bert_sequence_tagger/src/data/NER/Varvara_v3'
corpus = ColumnCorpus(data_folder, 
                      {0 : 'text', 1 : 'tag'},
                      train_file='train_pred_full.tsv',
                      test_file='test_pred_full.tsv',
                      dev_file='dev_pred_full.tsv')

print(corpus.obtain_statistics())

print(corpus.obtain_statistics())

2019-11-03 23:11:51,229 Reading data from /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3
2019-11-03 23:11:51,230 Train: /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3/train_pred_full.tsv
2019-11-03 23:11:51,230 Dev: /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3/dev_pred_full.tsv
2019-11-03 23:11:51,231 Test: /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3/test_pred_full.tsv
{
    "TRAIN": {
        "dataset": "TRAIN",
        "total_number_of_documents": 3077,
        "number_of_documents_per_class": {},
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
            "total": 89350,
            "min": 6,
            "max": 106,
            "avg": 29.038024049398764
        }
    },
    "TEST": {
        "dataset": "TEST",
        "total_number_of_documents": 488,
        "number_of_documents_per_class": {},
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
            "total": 13736,
            "min": 6,
       

In [31]:
a = corpus.make_tag_dictionary(tag_type = 'tag')

In [37]:
a.idx2item

[b'<unk>',
 b'O',
 b'B-OBJ',
 b'B-PREDFULL',
 b'I-PREDFULL',
 b'I-OBJ',
 b'<START>',
 b'<STOP>']

In [33]:
from bert_sequence_tagger import SequenceTaggerBert, BertForTokenClassificationCustom
from pytorch_transformers import BertTokenizer, BertForTokenClassification
import torch.nn as nn

from bert_sequence_tagger.bert_utils import make_bert_tag_dict_from_flair_corpus


bpe_tokenizer = BertTokenizer.from_pretrained('bert-base-cased', cache_dir=None, do_lower_case=False)

idx2tag, tag2idx = make_bert_tag_dict_from_flair_corpus(corpus)

model = nn.DataParallel(BertForTokenClassificationCustom.from_pretrained('bert-base-cased', cache_dir=None, num_labels=len(tag2idx))).cuda()
#model = BertForTokenClassification.from_pretrained('bert-base-cased', cache_dir=CACHE_DIR, num_labels=len(tag2idx)).cuda()

seq_tagger = SequenceTaggerBert(bert_model=model, bpe_tokenizer=bpe_tokenizer, idx2tag=idx2tag, tag2idx=tag2idx, max_len=200)

In [38]:
BATCH_SIZE = 10
PRED_BATCH_SIZE = 10
MAX_N_EPOCHS = 20
from torch.utils.data import RandomSampler, SequentialSampler

from bert_sequence_tagger.bert_utils import create_loader_from_flair_corpus, get_parameters_without_decay
from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert

from pytorch_transformers import AdamW, WarmupLinearSchedule

from bert_sequence_tagger.metrics import f1_entity_level, f1_token_level

test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

_, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')

train_dataloader = create_loader_from_flair_corpus(corpus.train, 
                                                   RandomSampler, 
                                                   batch_size=BATCH_SIZE)
val_dataloader = create_loader_from_flair_corpus(corpus.dev,
                                                 SequentialSampler,
                                                 batch_size=PRED_BATCH_SIZE)

optimizer = AdamW(get_parameters_without_decay(model), lr=LEARNING_RATE, betas=(0.9, 0.999), 
                  eps =1e-6, weight_decay=0.01, correct_bias=True)
lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0.1, 
                                    t_total=(len(corpus.train) / BATCH_SIZE)*MAX_N_EPOCHS)
trainer = ModelTrainerBert(model=seq_tagger, 
                           optimizer=optimizer, 
                           lr_scheduler=lr_scheduler,
                           train_dataloader=train_dataloader, 
                           val_dataloader=val_dataloader,
                           update_scheduler='es',
                           keep_best_model=True,
                           restore_bm_on_lr_change=False,
                           max_grad_norm=1.,
                           validation_metrics=[f1_entity_level],
                           decision_metric=lambda metrics: -metrics[1])

trainer.train(epochs=MAX_N_EPOCHS)


test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

_, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')

2019-11-04 00:15:05,500 - sequence_tagger_bert - INFO - Entity-level f1: 0.5359922178988327
2019-11-04 00:15:05,502 - sequence_tagger_bert - INFO - Token-level f1: 0.5660714285714287





Epoch:   0%|          | 0/20 [00:00<?, ?it/s][A[A[A

2019-11-04 00:16:12,843 - sequence_tagger_bert - INFO - Train loss: 0.02296537380097071
2019-11-04 00:16:15,939 - sequence_tagger_bert - INFO - Validation loss: 0.5064210319641705
2019-11-04 00:16:15,940 - sequence_tagger_bert - INFO - Validation metrics: (0.4930847865303668,)
2019-11-04 00:16:15,965 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:   5%|▌         | 1/20 [01:10<22:17, 70.38s/it][A[A[A

2019-11-04 00:17:23,664 - sequence_tagger_bert - INFO - Train loss: 0.02211489531630348
2019-11-04 00:17:26,546 - sequence_tagger_bert - INFO - Validation loss: 0.5068650401946975
2019-11-04 00:17:26,547 - sequence_tagger_bert - INFO - Validation metrics: (0.5014961101137044,)
2019-11-04 00:17:26,565 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  10%|█         | 2/20 [02:20<21:08, 70.45s/it][A[A[A

2019-11-04 00:18:33,674 - sequence_tagger_bert - INFO - Train loss: 0.022489066854083015
2019-11-04 00:18:36,804 - sequence_tagger_bert - INFO - Validation loss: 0.5107156120058967
2019-11-04 00:18:36,805 - sequence_tagger_bert - INFO - Validation metrics: (0.49489489489489497,)
2019-11-04 00:18:36,806 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  15%|█▌        | 3/20 [03:31<19:56, 70.38s/it][A[A[A

2019-11-04 00:19:43,636 - sequence_tagger_bert - INFO - Train loss: 0.021868705845894534
2019-11-04 00:19:46,608 - sequence_tagger_bert - INFO - Validation loss: 0.5176699931744668
2019-11-04 00:19:46,609 - sequence_tagger_bert - INFO - Validation metrics: (0.49063444108761334,)
2019-11-04 00:19:46,609 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  20%|██        | 4/20 [04:41<18:43, 70.21s/it][A[A[A

2019-11-04 00:20:53,377 - sequence_tagger_bert - INFO - Train loss: 0.021474948615159282
2019-11-04 00:20:56,415 - sequence_tagger_bert - INFO - Validation loss: 0.5129166986145897
2019-11-04 00:20:56,416 - sequence_tagger_bert - INFO - Validation metrics: (0.4993997599039616,)
2019-11-04 00:20:56,417 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  25%|██▌       | 5/20 [05:50<17:31, 70.09s/it][A[A[A

2019-11-04 00:22:03,707 - sequence_tagger_bert - INFO - Train loss: 0.02082470166247471
2019-11-04 00:22:06,635 - sequence_tagger_bert - INFO - Validation loss: 0.5123378931624224
2019-11-04 00:22:06,637 - sequence_tagger_bert - INFO - Validation metrics: (0.4969987995198079,)
2019-11-04 00:22:06,637 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  30%|███       | 6/20 [07:01<16:21, 70.13s/it][A[A[A

2019-11-04 00:23:13,760 - sequence_tagger_bert - INFO - Train loss: 0.020918038098047813
2019-11-04 00:23:16,734 - sequence_tagger_bert - INFO - Validation loss: 0.5116124798986698
2019-11-04 00:23:16,735 - sequence_tagger_bert - INFO - Validation metrics: (0.4994026284348865,)
2019-11-04 00:23:16,736 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  35%|███▌      | 7/20 [08:11<15:11, 70.12s/it][A[A[A

2019-11-04 00:24:23,733 - sequence_tagger_bert - INFO - Train loss: 0.020607874843642696
2019-11-04 00:24:26,704 - sequence_tagger_bert - INFO - Validation loss: 0.5216088233822275
2019-11-04 00:24:26,705 - sequence_tagger_bert - INFO - Validation metrics: (0.49849488260084285,)
2019-11-04 00:24:26,706 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  40%|████      | 8/20 [09:21<14:00, 70.07s/it][A[A[A

2019-11-04 00:25:34,201 - sequence_tagger_bert - INFO - Train loss: 0.020006767906136688
2019-11-04 00:25:37,417 - sequence_tagger_bert - INFO - Validation loss: 0.5229435328581575
2019-11-04 00:25:37,418 - sequence_tagger_bert - INFO - Validation metrics: (0.49969969969969963,)
2019-11-04 00:25:37,419 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  45%|████▌     | 9/20 [10:31<12:52, 70.27s/it][A[A[A

2019-11-04 00:26:45,107 - sequence_tagger_bert - INFO - Train loss: 0.02056950890812902
2019-11-04 00:26:48,139 - sequence_tagger_bert - INFO - Validation loss: 0.5140314608948623
2019-11-04 00:26:48,140 - sequence_tagger_bert - INFO - Validation metrics: (0.5035714285714286,)
2019-11-04 00:26:48,160 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  50%|█████     | 10/20 [11:42<11:44, 70.41s/it][A[A[A

2019-11-04 00:27:54,893 - sequence_tagger_bert - INFO - Train loss: 0.019477534391342772
2019-11-04 00:27:57,966 - sequence_tagger_bert - INFO - Validation loss: 0.5256900111548375
2019-11-04 00:27:57,967 - sequence_tagger_bert - INFO - Validation metrics: (0.49939540507859737,)
2019-11-04 00:27:57,969 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  55%|█████▌    | 11/20 [12:52<10:32, 70.23s/it][A[A[A

2019-11-04 00:29:05,567 - sequence_tagger_bert - INFO - Train loss: 0.020042150701674356
2019-11-04 00:29:08,487 - sequence_tagger_bert - INFO - Validation loss: 0.5230054469557661
2019-11-04 00:29:08,488 - sequence_tagger_bert - INFO - Validation metrics: (0.4984984984984986,)
2019-11-04 00:29:08,489 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  60%|██████    | 12/20 [14:02<09:22, 70.32s/it][A[A[A

2019-11-04 00:30:14,667 - sequence_tagger_bert - INFO - Train loss: 0.020259920838016823
2019-11-04 00:30:17,684 - sequence_tagger_bert - INFO - Validation loss: 0.524547247911758
2019-11-04 00:30:17,685 - sequence_tagger_bert - INFO - Validation metrics: (0.4969843184559712,)
2019-11-04 00:30:17,686 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  65%|██████▌   | 13/20 [15:12<08:09, 69.98s/it][A[A[A

2019-11-04 00:31:25,581 - sequence_tagger_bert - INFO - Train loss: 0.01906152820296271
2019-11-04 00:31:28,617 - sequence_tagger_bert - INFO - Validation loss: 0.5232003611688525
2019-11-04 00:31:28,618 - sequence_tagger_bert - INFO - Validation metrics: (0.5002999400119976,)
2019-11-04 00:31:28,619 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  70%|███████   | 14/20 [16:23<07:01, 70.27s/it][A[A[A

2019-11-04 00:32:36,391 - sequence_tagger_bert - INFO - Train loss: 0.019067458004098047
2019-11-04 00:32:39,509 - sequence_tagger_bert - INFO - Validation loss: 0.5221355129288873
2019-11-04 00:32:39,511 - sequence_tagger_bert - INFO - Validation metrics: (0.5044829647340108,)
2019-11-04 00:32:39,530 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  75%|███████▌  | 15/20 [17:33<05:52, 70.46s/it][A[A[A

2019-11-04 00:33:46,821 - sequence_tagger_bert - INFO - Train loss: 0.019810797411921228
2019-11-04 00:33:49,890 - sequence_tagger_bert - INFO - Validation loss: 0.5217976092769787
2019-11-04 00:33:49,891 - sequence_tagger_bert - INFO - Validation metrics: (0.502092050209205,)
2019-11-04 00:33:49,892 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  80%|████████  | 16/20 [18:44<04:41, 70.43s/it][A[A[A

2019-11-04 00:34:58,436 - sequence_tagger_bert - INFO - Train loss: 0.01922515010787779
2019-11-04 00:35:01,669 - sequence_tagger_bert - INFO - Validation loss: 0.5251691339639719
2019-11-04 00:35:01,670 - sequence_tagger_bert - INFO - Validation metrics: (0.4987980769230769,)
2019-11-04 00:35:01,671 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  85%|████████▌ | 17/20 [19:56<03:32, 70.83s/it][A[A[A

2019-11-04 00:36:11,156 - sequence_tagger_bert - INFO - Train loss: 0.019155079037005494
2019-11-04 00:36:14,257 - sequence_tagger_bert - INFO - Validation loss: 0.5244880779788307
2019-11-04 00:36:14,258 - sequence_tagger_bert - INFO - Validation metrics: (0.5026929982046678,)
2019-11-04 00:36:14,259 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  90%|█████████ | 18/20 [21:08<02:22, 71.36s/it][A[A[A

2019-11-04 00:37:24,729 - sequence_tagger_bert - INFO - Train loss: 0.019374078425409443
2019-11-04 00:37:27,703 - sequence_tagger_bert - INFO - Validation loss: 0.5228917853106041
2019-11-04 00:37:27,704 - sequence_tagger_bert - INFO - Validation metrics: (0.5023866348448688,)
2019-11-04 00:37:27,705 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch:  95%|█████████▌| 19/20 [22:22<01:11, 71.99s/it][A[A[A

2019-11-04 00:38:37,752 - sequence_tagger_bert - INFO - Train loss: 0.019024234159634897
2019-11-04 00:38:41,119 - sequence_tagger_bert - INFO - Validation loss: 0.5231161052730252
2019-11-04 00:38:41,121 - sequence_tagger_bert - INFO - Validation metrics: (0.5014925373134328,)
2019-11-04 00:38:41,122 - sequence_tagger_bert - INFO - Current learning rate: 0.0





Epoch: 100%|██████████| 20/20 [23:35<00:00, 72.42s/it][A[A[A


[A[A[A

2019-11-04 00:38:45,008 - sequence_tagger_bert - INFO - Entity-level f1: 0.5389105058365758
2019-11-04 00:38:45,010 - sequence_tagger_bert - INFO - Token-level f1: 0.5688809629959874


In [34]:
#BATCH_SIZE = 10
#PRED_BATCH_SIZE = 10
#MAX_N_EPOCHS = 20
from torch.utils.data import RandomSampler, SequentialSampler

from bert_sequence_tagger.bert_utils import create_loader_from_flair_corpus, get_parameters_without_decay
from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert

from pytorch_transformers import AdamW, WarmupLinearSchedule

from bert_sequence_tagger.metrics import f1_entity_level, f1_token_level

for LEARNING_RATE in [3e-6, 3e-7, 3e-4, 1e-5]:

        train_dataloader = create_loader_from_flair_corpus(corpus.train, 
                                                           RandomSampler, 
                                                           batch_size=BATCH_SIZE)
        val_dataloader = create_loader_from_flair_corpus(corpus.dev,
                                                         SequentialSampler,
                                                         batch_size=PRED_BATCH_SIZE)

        optimizer = AdamW(get_parameters_without_decay(model), lr=LEARNING_RATE, betas=(0.9, 0.999), 
                          eps =1e-6, weight_decay=0.01, correct_bias=True)
        lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0.1, 
                                            t_total=(len(corpus.train) / BATCH_SIZE)*MAX_N_EPOCHS)
        trainer = ModelTrainerBert(model=seq_tagger, 
                                   optimizer=optimizer, 
                                   lr_scheduler=lr_scheduler,
                                   train_dataloader=train_dataloader, 
                                   val_dataloader=val_dataloader,
                                   update_scheduler='es',
                                   keep_best_model=True,
                                   restore_bm_on_lr_change=False,
                                   max_grad_norm=1.,
                                   validation_metrics=[f1_entity_level],
                                   decision_metric=lambda metrics: -metrics[1])

        trainer.train(epochs=MAX_N_EPOCHS)


        test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                          SequentialSampler,
                                                          batch_size=PRED_BATCH_SIZE)

        _, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                                 metrics=[f1_entity_level, f1_token_level])
        logger.info(f'Entity-level f1: {test_metrics[1]}')
        logger.info(f'Token-level f1: {test_metrics[2]}')


Epoch:   0%|          | 0/20 [00:00<?, ?it/s][A

2019-11-03 23:19:57,164 - sequence_tagger_bert - INFO - Train loss: 0.4787870369504153
2019-11-03 23:20:00,145 - sequence_tagger_bert - INFO - Validation loss: 0.28253188729286194
2019-11-03 23:20:00,146 - sequence_tagger_bert - INFO - Validation metrics: (0.3241469816272966,)
2019-11-03 23:20:00,166 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:   5%|▌         | 1/20 [00:56<17:50, 56.34s/it][A

2019-11-03 23:20:53,284 - sequence_tagger_bert - INFO - Train loss: 0.10299811927225305
2019-11-03 23:20:56,507 - sequence_tagger_bert - INFO - Validation loss: 0.32408687472343445
2019-11-03 23:20:56,508 - sequence_tagger_bert - INFO - Validation metrics: (0.4444444444444445,)
2019-11-03 23:20:56,532 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  10%|█         | 2/20 [01:52<16:54, 56.35s/it][A

2019-11-03 23:21:50,103 - sequence_tagger_bert - INFO - Train loss: 0.0721234318756841
2019-11-03 23:21:53,137 - sequence_tagger_bert - INFO - Validation loss: 0.36886006593704224
2019-11-03 23:21:53,138 - sequence_tagger_bert - INFO - Validation metrics: (0.4554334554334554,)
2019-11-03 23:21:53,166 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  15%|█▌        | 3/20 [02:49<15:59, 56.44s/it][A

2019-11-03 23:22:46,781 - sequence_tagger_bert - INFO - Train loss: 0.05845014256825719
2019-11-03 23:22:49,799 - sequence_tagger_bert - INFO - Validation loss: 0.4011937975883484
2019-11-03 23:22:49,800 - sequence_tagger_bert - INFO - Validation metrics: (0.4401215805471124,)
2019-11-03 23:22:49,801 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  20%|██        | 4/20 [03:45<15:03, 56.50s/it][A

2019-11-03 23:23:43,306 - sequence_tagger_bert - INFO - Train loss: 0.051256966424883955
2019-11-03 23:23:46,332 - sequence_tagger_bert - INFO - Validation loss: 0.4138947129249573
2019-11-03 23:23:46,333 - sequence_tagger_bert - INFO - Validation metrics: (0.4666666666666667,)
2019-11-03 23:23:46,352 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  25%|██▌       | 5/20 [04:42<14:07, 56.51s/it][A

2019-11-03 23:24:40,036 - sequence_tagger_bert - INFO - Train loss: 0.04522083744097868
2019-11-03 23:24:43,204 - sequence_tagger_bert - INFO - Validation loss: 0.4281786382198334
2019-11-03 23:24:43,205 - sequence_tagger_bert - INFO - Validation metrics: (0.46200980392156865,)
2019-11-03 23:24:43,206 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  30%|███       | 6/20 [05:39<13:12, 56.61s/it][A

2019-11-03 23:25:37,547 - sequence_tagger_bert - INFO - Train loss: 0.04169752058896376
2019-11-03 23:25:40,895 - sequence_tagger_bert - INFO - Validation loss: 0.40968552231788635
2019-11-03 23:25:40,896 - sequence_tagger_bert - INFO - Validation metrics: (0.4748371817643576,)
2019-11-03 23:25:40,916 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  35%|███▌      | 7/20 [06:37<12:20, 56.94s/it][A

2019-11-03 23:26:34,479 - sequence_tagger_bert - INFO - Train loss: 0.0381972702857049
2019-11-03 23:26:37,556 - sequence_tagger_bert - INFO - Validation loss: 0.4394214153289795
2019-11-03 23:26:37,557 - sequence_tagger_bert - INFO - Validation metrics: (0.47368421052631576,)
2019-11-03 23:26:37,559 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  40%|████      | 8/20 [07:33<11:22, 56.85s/it][A

2019-11-03 23:27:31,507 - sequence_tagger_bert - INFO - Train loss: 0.035292348071662566
2019-11-03 23:27:34,871 - sequence_tagger_bert - INFO - Validation loss: 0.42749372124671936
2019-11-03 23:27:34,872 - sequence_tagger_bert - INFO - Validation metrics: (0.48184019370460046,)
2019-11-03 23:27:34,891 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  45%|████▌     | 9/20 [08:31<10:26, 57.00s/it][A

2019-11-03 23:28:28,937 - sequence_tagger_bert - INFO - Train loss: 0.03331584910442329
2019-11-03 23:28:31,938 - sequence_tagger_bert - INFO - Validation loss: 0.4554072320461273
2019-11-03 23:28:31,939 - sequence_tagger_bert - INFO - Validation metrics: (0.4718137254901961,)
2019-11-03 23:28:31,940 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  50%|█████     | 10/20 [09:28<09:30, 57.01s/it][A

2019-11-03 23:29:25,538 - sequence_tagger_bert - INFO - Train loss: 0.031775198278495065
2019-11-03 23:29:28,536 - sequence_tagger_bert - INFO - Validation loss: 0.44853276014328003
2019-11-03 23:29:28,537 - sequence_tagger_bert - INFO - Validation metrics: (0.48554216867469885,)
2019-11-03 23:29:28,562 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  55%|█████▌    | 11/20 [10:24<08:32, 56.90s/it][A

2019-11-03 23:30:22,827 - sequence_tagger_bert - INFO - Train loss: 0.029974783775565537
2019-11-03 23:30:25,918 - sequence_tagger_bert - INFO - Validation loss: 0.4577467739582062
2019-11-03 23:30:25,919 - sequence_tagger_bert - INFO - Validation metrics: (0.4829683698296838,)
2019-11-03 23:30:25,920 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  60%|██████    | 12/20 [11:22<07:36, 57.03s/it][A

2019-11-03 23:31:19,600 - sequence_tagger_bert - INFO - Train loss: 0.02827518145249761
2019-11-03 23:31:22,733 - sequence_tagger_bert - INFO - Validation loss: 0.4715800881385803
2019-11-03 23:31:22,734 - sequence_tagger_bert - INFO - Validation metrics: (0.48842874543239956,)
2019-11-03 23:31:22,757 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  65%|██████▌   | 13/20 [12:18<06:38, 56.97s/it][A

2019-11-03 23:32:17,406 - sequence_tagger_bert - INFO - Train loss: 0.027177215616551707
2019-11-03 23:32:20,518 - sequence_tagger_bert - INFO - Validation loss: 0.4620170593261719
2019-11-03 23:32:20,519 - sequence_tagger_bert - INFO - Validation metrics: (0.49432835820895515,)
2019-11-03 23:32:20,542 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  70%|███████   | 14/20 [13:16<05:43, 57.22s/it][A

2019-11-03 23:33:14,775 - sequence_tagger_bert - INFO - Train loss: 0.026193941682770155
2019-11-03 23:33:17,778 - sequence_tagger_bert - INFO - Validation loss: 0.4734239876270294
2019-11-03 23:33:17,779 - sequence_tagger_bert - INFO - Validation metrics: (0.4817518248175182,)
2019-11-03 23:33:17,779 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  75%|███████▌  | 15/20 [14:13<04:46, 57.22s/it][A

2019-11-03 23:34:11,154 - sequence_tagger_bert - INFO - Train loss: 0.024578233642764197
2019-11-03 23:34:14,366 - sequence_tagger_bert - INFO - Validation loss: 0.4755070209503174
2019-11-03 23:34:14,367 - sequence_tagger_bert - INFO - Validation metrics: (0.49186256781193494,)
2019-11-03 23:34:14,368 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  80%|████████  | 16/20 [15:10<03:48, 57.03s/it][A

2019-11-03 23:35:08,163 - sequence_tagger_bert - INFO - Train loss: 0.024713597943224593
2019-11-03 23:35:11,260 - sequence_tagger_bert - INFO - Validation loss: 0.474751353263855
2019-11-03 23:35:11,261 - sequence_tagger_bert - INFO - Validation metrics: (0.4940047961630696,)
2019-11-03 23:35:11,262 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  85%|████████▌ | 17/20 [16:07<02:50, 56.99s/it][A

2019-11-03 23:36:04,699 - sequence_tagger_bert - INFO - Train loss: 0.024434842630120603
2019-11-03 23:36:08,056 - sequence_tagger_bert - INFO - Validation loss: 0.48089444637298584
2019-11-03 23:36:08,057 - sequence_tagger_bert - INFO - Validation metrics: (0.4933973589435774,)
2019-11-03 23:36:08,058 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  90%|█████████ | 18/20 [17:04<01:53, 56.93s/it][A

2019-11-03 23:37:01,138 - sequence_tagger_bert - INFO - Train loss: 0.023629867640647246
2019-11-03 23:37:04,316 - sequence_tagger_bert - INFO - Validation loss: 0.477573037147522
2019-11-03 23:37:04,317 - sequence_tagger_bert - INFO - Validation metrics: (0.4924924924924925,)
2019-11-03 23:37:04,318 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  95%|█████████▌| 19/20 [18:00<00:56, 56.73s/it][A

2019-11-03 23:37:58,003 - sequence_tagger_bert - INFO - Train loss: 0.023345020170193263
2019-11-03 23:38:01,070 - sequence_tagger_bert - INFO - Validation loss: 0.4798295199871063
2019-11-03 23:38:01,071 - sequence_tagger_bert - INFO - Validation metrics: (0.4924924924924925,)
2019-11-03 23:38:01,072 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch: 100%|██████████| 20/20 [18:57<00:00, 56.74s/it][A
[A

2019-11-03 23:38:04,720 - sequence_tagger_bert - INFO - Entity-level f1: 0.5571428571428572
2019-11-03 23:38:04,721 - sequence_tagger_bert - INFO - Token-level f1: 0.5844269466316709



Epoch:   0%|          | 0/20 [00:00<?, ?it/s][A

2019-11-03 23:38:58,046 - sequence_tagger_bert - INFO - Train loss: 0.025655677117445926
2019-11-03 23:39:01,038 - sequence_tagger_bert - INFO - Validation loss: 0.4682237505912781
2019-11-03 23:39:01,039 - sequence_tagger_bert - INFO - Validation metrics: (0.4930681133212779,)
2019-11-03 23:39:01,065 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:   5%|▌         | 1/20 [00:56<17:48, 56.26s/it][A

KeyboardInterrupt: 

In [39]:
with open('labels_v3_lr7.txt', 'w') as f:
    for _string in _:
        #f.seek(0)
        f.write(', '.join(_string) + '\n')

## отвалидировалась на существующих метриках, лучше при lr 3e-6, 3e-7

In [23]:
from torch.utils.data import RandomSampler, SequentialSampler

from bert_sequence_tagger.bert_utils import create_loader_from_flair_corpus, get_parameters_without_decay
from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert

from pytorch_transformers import AdamW, WarmupLinearSchedule

from bert_sequence_tagger.metrics import f1_entity_level, f1_token_level

train_dataloader = create_loader_from_flair_corpus(corpus.train, 
                                                           RandomSampler, 
                                                           batch_size=BATCH_SIZE)
val_dataloader = create_loader_from_flair_corpus(corpus.dev,
                                                 SequentialSampler,
                                                 batch_size=PRED_BATCH_SIZE)

optimizer = AdamW(get_parameters_without_decay(model), lr=LEARNING_RATE, betas=(0.9, 0.999), 
                  eps =1e-6, weight_decay=0.01, correct_bias=True)
lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0.1, 
                                    t_total=(len(corpus.train) / BATCH_SIZE)*MAX_N_EPOCHS)
trainer = ModelTrainerBert(model=seq_tagger, 
                           optimizer=optimizer, 
                           lr_scheduler=lr_scheduler,
                           train_dataloader=train_dataloader, 
                           val_dataloader=val_dataloader,
                           update_scheduler='es',
                           keep_best_model=True,
                           restore_bm_on_lr_change=False,
                           max_grad_norm=1.,
                           validation_metrics=[f1_entity_level],
                           decision_metric=lambda metrics: -metrics[1])

trainer.train(epochs=MAX_N_EPOCHS)


test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

preds, probs, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')


Epoch:   0%|          | 0/20 [00:00<?, ?it/s][A

2019-10-23 17:36:07,138 - sequence_tagger_bert - INFO - Train loss: 0.44207449675415433
2019-10-23 17:36:09,913 - sequence_tagger_bert - INFO - Validation loss: 0.19390782713890076
2019-10-23 17:36:09,914 - sequence_tagger_bert - INFO - Validation metrics: (0.5417457305502846,)
2019-10-23 17:36:09,933 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:   5%|▌         | 1/20 [00:49<15:37, 49.32s/it][A

2019-10-23 17:37:00,264 - sequence_tagger_bert - INFO - Train loss: 0.09446393226471422
2019-10-23 17:37:03,435 - sequence_tagger_bert - INFO - Validation loss: 0.2510414719581604
2019-10-23 17:37:03,438 - sequence_tagger_bert - INFO - Validation metrics: (0.389370306181398,)
2019-10-23 17:37:03,438 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  10%|█         | 2/20 [01:42<15:10, 50.58s/it][A

2019-10-23 17:37:53,917 - sequence_tagger_bert - INFO - Train loss: 0.06731621946347245
2019-10-23 17:37:56,680 - sequence_tagger_bert - INFO - Validation loss: 0.26085543632507324
2019-10-23 17:37:56,683 - sequence_tagger_bert - INFO - Validation metrics: (0.4366041896361632,)
2019-10-23 17:37:56,685 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  15%|█▌        | 3/20 [02:36<14:33, 51.38s/it][A

2019-10-23 17:38:44,783 - sequence_tagger_bert - INFO - Train loss: 0.055439955626542754
2019-10-23 17:38:47,742 - sequence_tagger_bert - INFO - Validation loss: 0.27321499586105347
2019-10-23 17:38:47,749 - sequence_tagger_bert - INFO - Validation metrics: (0.44395361678630596,)
2019-10-23 17:38:47,750 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  20%|██        | 4/20 [03:27<13:40, 51.28s/it][A

2019-10-23 17:39:40,312 - sequence_tagger_bert - INFO - Train loss: 0.048552938294491615
2019-10-23 17:39:43,413 - sequence_tagger_bert - INFO - Validation loss: 0.2877540588378906
2019-10-23 17:39:43,414 - sequence_tagger_bert - INFO - Validation metrics: (0.45419637959407566,)
2019-10-23 17:39:43,414 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  25%|██▌       | 5/20 [04:22<13:08, 52.60s/it][A

2019-10-23 17:40:31,209 - sequence_tagger_bert - INFO - Train loss: 0.042401829072826316
2019-10-23 17:40:33,981 - sequence_tagger_bert - INFO - Validation loss: 0.2951381504535675
2019-10-23 17:40:33,982 - sequence_tagger_bert - INFO - Validation metrics: (0.4681528662420382,)
2019-10-23 17:40:33,982 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  30%|███       | 6/20 [05:13<12:07, 51.99s/it][A

2019-10-23 17:41:23,348 - sequence_tagger_bert - INFO - Train loss: 0.039143779795465156
2019-10-23 17:41:26,505 - sequence_tagger_bert - INFO - Validation loss: 0.3359852731227875
2019-10-23 17:41:26,506 - sequence_tagger_bert - INFO - Validation metrics: (0.4252232142857143,)
2019-10-23 17:41:26,511 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  35%|███▌      | 7/20 [06:05<11:17, 52.15s/it][A

2019-10-23 17:42:17,811 - sequence_tagger_bert - INFO - Train loss: 0.03634464747917194
2019-10-23 17:42:20,560 - sequence_tagger_bert - INFO - Validation loss: 0.3336271643638611
2019-10-23 17:42:20,561 - sequence_tagger_bert - INFO - Validation metrics: (0.45098039215686275,)
2019-10-23 17:42:20,561 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  40%|████      | 8/20 [06:59<10:32, 52.72s/it][A

2019-10-23 17:43:08,055 - sequence_tagger_bert - INFO - Train loss: 0.03377418829195572
2019-10-23 17:43:10,806 - sequence_tagger_bert - INFO - Validation loss: 0.34779462218284607
2019-10-23 17:43:10,808 - sequence_tagger_bert - INFO - Validation metrics: (0.42690383546414673,)
2019-10-23 17:43:10,809 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  45%|████▌     | 9/20 [07:50<09:31, 51.98s/it][A

2019-10-23 17:44:01,690 - sequence_tagger_bert - INFO - Train loss: 0.03119575367962343
2019-10-23 17:44:04,794 - sequence_tagger_bert - INFO - Validation loss: 0.3810552656650543
2019-10-23 17:44:04,795 - sequence_tagger_bert - INFO - Validation metrics: (0.40594625500285875,)
2019-10-23 17:44:04,796 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  50%|█████     | 10/20 [08:44<08:45, 52.58s/it][A

2019-10-23 17:44:55,029 - sequence_tagger_bert - INFO - Train loss: 0.02917186085480303
2019-10-23 17:44:57,722 - sequence_tagger_bert - INFO - Validation loss: 0.3724142014980316
2019-10-23 17:44:57,723 - sequence_tagger_bert - INFO - Validation metrics: (0.44114411441144114,)
2019-10-23 17:44:57,724 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  55%|█████▌    | 11/20 [09:37<07:54, 52.69s/it][A

2019-10-23 17:45:46,541 - sequence_tagger_bert - INFO - Train loss: 0.027264553147507198
2019-10-23 17:45:49,428 - sequence_tagger_bert - INFO - Validation loss: 0.385182648897171
2019-10-23 17:45:49,429 - sequence_tagger_bert - INFO - Validation metrics: (0.42333333333333334,)
2019-10-23 17:45:49,429 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  60%|██████    | 12/20 [10:28<06:59, 52.39s/it][A

2019-10-23 17:46:41,505 - sequence_tagger_bert - INFO - Train loss: 0.026822505367085404
2019-10-23 17:46:44,564 - sequence_tagger_bert - INFO - Validation loss: 0.3787497580051422
2019-10-23 17:46:44,566 - sequence_tagger_bert - INFO - Validation metrics: (0.4508419337316676,)
2019-10-23 17:46:44,569 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  65%|██████▌   | 13/20 [11:23<06:12, 53.22s/it][A

2019-10-23 17:47:33,097 - sequence_tagger_bert - INFO - Train loss: 0.02519219604476747
2019-10-23 17:47:36,121 - sequence_tagger_bert - INFO - Validation loss: 0.39794737100601196
2019-10-23 17:47:36,122 - sequence_tagger_bert - INFO - Validation metrics: (0.43468715697036225,)
2019-10-23 17:47:36,123 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  70%|███████   | 14/20 [12:15<05:16, 52.72s/it][A

2019-10-23 17:48:29,517 - sequence_tagger_bert - INFO - Train loss: 0.024216730564798974
2019-10-23 17:48:32,981 - sequence_tagger_bert - INFO - Validation loss: 0.39265915751457214
2019-10-23 17:48:32,982 - sequence_tagger_bert - INFO - Validation metrics: (0.44782608695652176,)
2019-10-23 17:48:32,983 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  75%|███████▌  | 15/20 [13:12<04:29, 53.96s/it][A

2019-10-23 17:49:23,983 - sequence_tagger_bert - INFO - Train loss: 0.02380452183237493
2019-10-23 17:49:26,647 - sequence_tagger_bert - INFO - Validation loss: 0.4019433856010437
2019-10-23 17:49:26,648 - sequence_tagger_bert - INFO - Validation metrics: (0.44480874316939895,)
2019-10-23 17:49:26,649 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  80%|████████  | 16/20 [14:06<03:35, 53.87s/it][A

2019-10-23 17:50:14,958 - sequence_tagger_bert - INFO - Train loss: 0.022866339103015335
2019-10-23 17:50:17,961 - sequence_tagger_bert - INFO - Validation loss: 0.4048379957675934
2019-10-23 17:50:17,962 - sequence_tagger_bert - INFO - Validation metrics: (0.4456521739130435,)
2019-10-23 17:50:17,963 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  85%|████████▌ | 17/20 [14:57<02:39, 53.10s/it][A

2019-10-23 17:51:09,676 - sequence_tagger_bert - INFO - Train loss: 0.022506936902087604
2019-10-23 17:51:12,574 - sequence_tagger_bert - INFO - Validation loss: 0.4067346751689911
2019-10-23 17:51:12,575 - sequence_tagger_bert - INFO - Validation metrics: (0.4397163120567376,)
2019-10-23 17:51:12,577 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  90%|█████████ | 18/20 [15:51<01:47, 53.56s/it][A

2019-10-23 17:52:00,266 - sequence_tagger_bert - INFO - Train loss: 0.022491128675900904
2019-10-23 17:52:02,964 - sequence_tagger_bert - INFO - Validation loss: 0.4055901765823364
2019-10-23 17:52:02,965 - sequence_tagger_bert - INFO - Validation metrics: (0.44189852700491,)
2019-10-23 17:52:02,966 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch:  95%|█████████▌| 19/20 [16:42<00:52, 52.61s/it][A

2019-10-23 17:52:52,142 - sequence_tagger_bert - INFO - Train loss: 0.022022977412353813
2019-10-23 17:52:55,384 - sequence_tagger_bert - INFO - Validation loss: 0.4082149267196655
2019-10-23 17:52:55,386 - sequence_tagger_bert - INFO - Validation metrics: (0.4378762999452655,)
2019-10-23 17:52:55,387 - sequence_tagger_bert - INFO - Current learning rate: 0.0



Epoch: 100%|██████████| 20/20 [17:34<00:00, 52.55s/it][A
[A

2019-10-23 17:52:59,515 - sequence_tagger_bert - INFO - Entity-level f1: 0.5282833251352681
2019-10-23 17:52:59,516 - sequence_tagger_bert - INFO - Token-level f1: 0.4879333040807372


In [None]:
', '.join(preds[5])

In [None]:
len(preds)

In [16]:
with open('labels_v.txt', 'w') as f:
    for _string in preds:
        #f.seek(0)
        f.write(', '.join(_string) + '\n')

In [29]:
def prepare_flair_corpus(corpus, name='tag', filter_tokens={'-DOCSTART-'}):
    result = []
    for sent in corpus[:10]:
        print ("sent", sent)
        print ("sent[0].text", sent[0].text)
        if sent[0].text in filter_tokens:
            continue
        else:
            result.append(([token.text for token in sent.tokens],
                           [token.tags[name].value for token in sent.tokens]))
    
    return result


In [17]:
test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

_, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')

2019-10-23 12:13:32,709 - sequence_tagger_bert - INFO - Entity-level f1: 0.43243243243243246
2019-10-23 12:13:32,711 - sequence_tagger_bert - INFO - Token-level f1: 0.43655723158828746


In [None]:
(0.9143007822800387, 0.9306361914074436)

In [None]:
class EvaluatorBase():
    """EvaluatorBase is abstract base class for all evaluators"""
    def get_evaluation_score_train_dev_test(self, tagger, datasets_bank, batch_size=-1):
        if batch_size == -1:
            batch_size = tagger.batch_size
        score_train, _ = self.predict_evaluation_score(tagger=tagger,
                                                       word_sequences=datasets_bank.word_sequences_train,
                                                       targets_tag_sequences=datasets_bank.tag_sequences_train,
                                                       batch_size=batch_size)
        score_dev, _ = self.predict_evaluation_score(tagger=tagger,
                                                     word_sequences=datasets_bank.word_sequences_dev,
                                                     targets_tag_sequences=datasets_bank.tag_sequences_dev,
                                                     batch_size=batch_size)
        score_test, msg_test = self.predict_evaluation_score(tagger=tagger,
                                                             word_sequences=datasets_bank.word_sequences_test,
                                                             targets_tag_sequences=datasets_bank.tag_sequences_test,
                                                             batch_size=batch_size)
        return score_train, score_dev, score_test, msg_test

    def predict_evaluation_score(self, tagger, word_sequences, targets_tag_sequences, batch_size):
        outputs_tag_sequences = tagger.predict_tags_from_words(word_sequences, batch_size)
        return self.get_evaluation_score(targets_tag_sequences, outputs_tag_sequences, word_sequences)

In [None]:
class EvaluatorF1MacroTokenLevel(EvaluatorBase):
    def __init__(self):
        self.tag_list = None
        self.tag2idx = dict()

    def __init_tag_list(self, targets_tag_sequences):
        if self.tag_list is not None:
            return
        self.tag_list = list()
        for tag_seq in targets_tag_sequences:
            for t in tag_seq:
                if t not in self.tag_list:
                    self.tag_list.append(t)
                    self.tag2idx[t] = len(self.tag_list)
        self.tag_list.sort()

    def tag_seq_2_idx_list(self, tag_seq):
        return [self.tag2idx[t] for t in tag_seq]

    def __get_zeros_tag_dict(self):
        return {tag: 0 for tag in self.tag_list}

    def __add_dict(self, dict1, dict2):
        for tag in self.tag_list:
            dict1[tag] += dict2[tag]
        return dict1

    def __div_dict(self, dict, d):
        for tag in self.tag_list:
            dict[tag] /= d
        return dict

    def __get_M_F1_msg(self, F1, precision, recall):
        msg = '\nF1 scores\n'
        msg += '-' * 24 + '\n'
        sum_M_F1 = 0
        sum_precision = 0
        sum_recall = 0
        for tag in self.tag_list:
            sum_M_F1 += F1[tag]
            sum_precision += precision[tag]
            sum_recall += recall[tag]
            msg += '%15s = f1 = %1.2f, precision = %1.2f, recall = %1.2f\n' % (tag, F1[tag], precision[tag], recall[tag])
        M_F1 = sum_M_F1 / len(F1)
        M_PR = sum_precision / len(F1)
        M_RE = sum_recall / len(F1)
        msg += '-'*24 + '\n'
        msg += 'Macro-F1 = %1.3f' % M_F1
        msg += 'Macro-Prescion = %1.3f' % M_PR
        msg += 'Macro-Recall = %1.3f' % M_RE
        return M_F1, msg

    def __add_to_dict(self, dict_in, tag, val):
        if tag in dict_in:
            dict_in[tag] += val
        else:
            dict_in[tag] = val
        return dict_in

    """EvaluatorF1MacroTagComponents is macro-F1 scores evaluator for each class of BOI-like tags."""
    def get_evaluation_score(self, targets_tag_sequences, outputs_tag_sequences, word_sequences=None):
        # Create list of tags
        self.__init_tag_list(targets_tag_sequences)
        i = 0
        for elem in zip(targets_tag_sequences, outputs_tag_sequences):
            if (i < 4):
                i = i +1
                print (elem[0])
                print (elem[1])
        # Init values
        TP = self.__get_zeros_tag_dict()
        FP = self.__get_zeros_tag_dict()
        FN = self.__get_zeros_tag_dict()
        F1 = self.__get_zeros_tag_dict()
        precision = self.__get_zeros_tag_dict()
        recall = self.__get_zeros_tag_dict()
        for targets_seq, outputs_tag_seq in zip(targets_tag_sequences, outputs_tag_sequences):
            for t, o in zip(targets_seq, outputs_tag_seq):
                if t == o:
                    TP = self.__add_to_dict(TP, t, 1)
                else:
                    FN = self.__add_to_dict(FN, t, 1)
                    FP = self.__add_to_dict(FP, o, 1)
        # Calculate F1 for each tag
        for tag in self.tag_list:
            F1[tag] = (2 * TP[tag] / max(2 * TP[tag] + FP[tag] + FN[tag], 1)) * 100
            precision[tag] = (TP[tag] / max(TP[tag] + FP[tag], 1))*100
            recall[tag] = (TP[tag] / max(TP[tag] + FN[tag], 1))*100
        # Calculate Macro-F1 score and prepare the message
        M_F1, msg = self.__get_M_F1_msg(F1,precision, recall)
        print(msg)
        #self.validate_M_F1_scikitlearn( targets_tag_sequences, outputs_tag_sequences)
        return M_F1, msg
