In [32]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import sys

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
import logging


logger = logging.getLogger('sequence_tagger_bert')

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

logger.handlers = []

fhandler = logging.handlers.TimedRotatingFileHandler(filename='logs.txt', when='midnight')
fhandler.setFormatter(formatter)
logger.addHandler(fhandler)

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.setLevel(logging.DEBUG)

In [36]:
import torch

device = torch.device('cuda')
n_gpu = torch.cuda.device_count()

for i in range(n_gpu):
    print(torch.cuda.get_device_name(i))

Tesla V100-DGXS-32GB


In [22]:
CACHE_DIR = '../workdir/cache'
BATCH_SIZE = 16
#BATCH_SIZE = 8
PRED_BATCH_SIZE = 1000
MAX_LEN = 128
MAX_N_EPOCHS = 10
#MAX_N_EPOCHS = 100
#MAX_N_EPOCHS = 50
#MAX_N_EPOCHS = 10
REDUCE_ON_PLATEAU = False
WEIGHT_DECAY = 0.01
LEARNING_RATE = 3e-6
#LEARNING_RATE = 1e-5
#LEARNING_RATE = 2e-5

In [9]:
!pwd

/workspace/bert_sequence_tagger/src


In [51]:
from flair.datasets import ColumnCorpus


data_folder = '/workspace/bert_sequence_tagger/src/data/NER/Varvara_v3'
corpus = ColumnCorpus(data_folder, 
                      {0 : 'text', 1 : 'tag'},
                      train_file='train_pred_full.tsv',
                      test_file='test_manual_predfull_seq_labelling.tsv',
                      dev_file='dev_pred_full.tsv')

print(corpus.obtain_statistics())

2019-11-20 00:11:51,072 Reading data from /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3
2019-11-20 00:11:51,073 Train: /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3/train_pred_full.tsv
2019-11-20 00:11:51,074 Dev: /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3/dev_pred_full.tsv
2019-11-20 00:11:51,075 Test: /workspace/bert_sequence_tagger/src/data/NER/Varvara_v3/test_manual_predfull_seq_labelling.tsv
{
    "TRAIN": {
        "dataset": "TRAIN",
        "total_number_of_documents": 3077,
        "number_of_documents_per_class": {},
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
            "total": 89350,
            "min": 6,
            "max": 106,
            "avg": 29.038024049398764
        }
    },
    "TEST": {
        "dataset": "TEST",
        "total_number_of_documents": 488,
        "number_of_documents_per_class": {},
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
            "total": 13736,
         

In [52]:
a = corpus.make_tag_dictionary(tag_type = 'tag')

In [53]:
a.idx2item

[b'<unk>',
 b'O',
 b'B-OBJ',
 b'B-PREDFULL',
 b'I-PREDFULL',
 b'I-OBJ',
 b'<START>',
 b'<STOP>']

In [54]:
from bert_sequence_tagger import SequenceTaggerBert, BertForTokenClassificationCustom
from pytorch_transformers import BertTokenizer, BertForTokenClassification
import torch.nn as nn

from bert_sequence_tagger.bert_utils import make_bert_tag_dict_from_flair_corpus


bpe_tokenizer = BertTokenizer.from_pretrained('bert-base-cased', cache_dir=None, do_lower_case=False)

idx2tag, tag2idx = make_bert_tag_dict_from_flair_corpus(corpus)

model = nn.DataParallel(BertForTokenClassificationCustom.from_pretrained('bert-base-cased', cache_dir=None, num_labels=len(tag2idx))).cuda()
#model = BertForTokenClassification.from_pretrained('bert-base-cased', cache_dir=CACHE_DIR, num_labels=len(tag2idx)).cuda()

seq_tagger = SequenceTaggerBert(bert_model=model, bpe_tokenizer=bpe_tokenizer, idx2tag=idx2tag, tag2idx=tag2idx, max_len=200)

In [55]:
bpe_tokenizer

<pytorch_transformers.tokenization_bert.BertTokenizer at 0x7f6250ad0390>

In [56]:
BATCH_SIZE = 10
PRED_BATCH_SIZE = 10
MAX_N_EPOCHS = 20
from torch.utils.data import RandomSampler, SequentialSampler

from bert_sequence_tagger.bert_utils import create_loader_from_flair_corpus, get_parameters_without_decay
from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert

from pytorch_transformers import AdamW, WarmupLinearSchedule

from bert_sequence_tagger.metrics import f1_entity_level, f1_token_level

test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

_, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')

train_dataloader = create_loader_from_flair_corpus(corpus.train, 
                                                   RandomSampler, 
                                                   batch_size=BATCH_SIZE)
val_dataloader = create_loader_from_flair_corpus(corpus.dev,
                                                 SequentialSampler,
                                                 batch_size=PRED_BATCH_SIZE)

optimizer = AdamW(get_parameters_without_decay(model), lr=LEARNING_RATE, betas=(0.9, 0.999), 
                  eps =1e-6, weight_decay=0.01, correct_bias=True)
lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0.1, 
                                    t_total=(len(corpus.train) / BATCH_SIZE)*MAX_N_EPOCHS)
trainer = ModelTrainerBert(model=seq_tagger, 
                           optimizer=optimizer, 
                           lr_scheduler=lr_scheduler,
                           train_dataloader=train_dataloader, 
                           val_dataloader=val_dataloader,
                           update_scheduler='es',
                           keep_best_model=True,
                           restore_bm_on_lr_change=False,
                           max_grad_norm=1.,
                           validation_metrics=[f1_entity_level],
                           decision_metric=lambda metrics: -metrics[1])

trainer.train(epochs=MAX_N_EPOCHS)


test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

_, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')

2019-11-20 00:18:43,851 - sequence_tagger_bert - INFO - Entity-level f1: 0.032533889468196034
2019-11-20 00:18:43,852 - sequence_tagger_bert - INFO - Token-level f1: 0.06207475067573866


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

2019-11-20 00:19:36,479 - sequence_tagger_bert - INFO - Train loss: 0.3619607299953312
2019-11-20 00:19:39,953 - sequence_tagger_bert - INFO - Validation loss: 0.3088280222401386
2019-11-20 00:19:39,954 - sequence_tagger_bert - INFO - Validation metrics: (0.3985239852398525,)
2019-11-20 00:19:39,973 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:   5%|▌         | 1/20 [00:56<17:44, 56.04s/it]

2019-11-20 00:20:32,119 - sequence_tagger_bert - INFO - Train loss: 0.08740380431174652
2019-11-20 00:20:34,944 - sequence_tagger_bert - INFO - Validation loss: 0.39300361003090695
2019-11-20 00:20:34,945 - sequence_tagger_bert - INFO - Validation metrics: (0.4347826086956522,)
2019-11-20 00:20:34,964 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  10%|█         | 2/20 [01:51<16:43, 55.72s/it]

2019-11-20 00:21:26,956 - sequence_tagger_bert - INFO - Train loss: 0.06447363679309363
2019-11-20 00:21:29,632 - sequence_tagger_bert - INFO - Validation loss: 0.42421448148968743
2019-11-20 00:21:29,633 - sequence_tagger_bert - INFO - Validation metrics: (0.4388078630310716,)
2019-11-20 00:21:29,650 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  15%|█▌        | 3/20 [02:45<15:42, 55.41s/it]

2019-11-20 00:22:22,670 - sequence_tagger_bert - INFO - Train loss: 0.05221280366992699
2019-11-20 00:22:25,633 - sequence_tagger_bert - INFO - Validation loss: 0.43800021946521067
2019-11-20 00:22:25,635 - sequence_tagger_bert - INFO - Validation metrics: (0.44360428481411474,)
2019-11-20 00:22:25,659 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  20%|██        | 4/20 [03:41<14:49, 55.59s/it]

2019-11-20 00:23:17,594 - sequence_tagger_bert - INFO - Train loss: 0.04596274611902992
2019-11-20 00:23:20,634 - sequence_tagger_bert - INFO - Validation loss: 0.4470692996344552
2019-11-20 00:23:20,635 - sequence_tagger_bert - INFO - Validation metrics: (0.45316455696202534,)
2019-11-20 00:23:20,654 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  25%|██▌       | 5/20 [04:36<13:51, 55.41s/it]

2019-11-20 00:24:11,892 - sequence_tagger_bert - INFO - Train loss: 0.04020843446532917
2019-11-20 00:24:14,878 - sequence_tagger_bert - INFO - Validation loss: 0.48331384674259803
2019-11-20 00:24:14,879 - sequence_tagger_bert - INFO - Validation metrics: (0.45218492716909436,)
2019-11-20 00:24:14,880 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  30%|███       | 6/20 [05:30<12:50, 55.06s/it]

2019-11-20 00:25:06,428 - sequence_tagger_bert - INFO - Train loss: 0.0358084871920503
2019-11-20 00:25:09,618 - sequence_tagger_bert - INFO - Validation loss: 0.48216514689166373
2019-11-20 00:25:09,619 - sequence_tagger_bert - INFO - Validation metrics: (0.4474829086389062,)
2019-11-20 00:25:09,621 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  35%|███▌      | 7/20 [06:25<11:54, 54.96s/it]

2019-11-20 00:26:01,430 - sequence_tagger_bert - INFO - Train loss: 0.0332889553392306
2019-11-20 00:26:04,259 - sequence_tagger_bert - INFO - Validation loss: 0.5220233976909119
2019-11-20 00:26:04,261 - sequence_tagger_bert - INFO - Validation metrics: (0.46222222222222226,)
2019-11-20 00:26:04,281 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  40%|████      | 8/20 [07:20<10:58, 54.87s/it]

2019-11-20 00:26:57,505 - sequence_tagger_bert - INFO - Train loss: 0.03205127487421737
2019-11-20 00:27:00,557 - sequence_tagger_bert - INFO - Validation loss: 0.5088469949096623
2019-11-20 00:27:00,558 - sequence_tagger_bert - INFO - Validation metrics: (0.46134347275031684,)
2019-11-20 00:27:00,559 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  45%|████▌     | 9/20 [08:16<10:08, 55.29s/it]

2019-11-20 00:27:52,452 - sequence_tagger_bert - INFO - Train loss: 0.029160591297444294
2019-11-20 00:27:55,382 - sequence_tagger_bert - INFO - Validation loss: 0.48546486625039
2019-11-20 00:27:55,383 - sequence_tagger_bert - INFO - Validation metrics: (0.46892307692307694,)
2019-11-20 00:27:55,403 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  50%|█████     | 10/20 [09:11<09:11, 55.16s/it]

2019-11-20 00:28:47,639 - sequence_tagger_bert - INFO - Train loss: 0.026940659996327738
2019-11-20 00:28:50,424 - sequence_tagger_bert - INFO - Validation loss: 0.5129956489360732
2019-11-20 00:28:50,425 - sequence_tagger_bert - INFO - Validation metrics: (0.47160493827160493,)
2019-11-20 00:28:50,444 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  55%|█████▌    | 11/20 [10:06<08:16, 55.12s/it]

2019-11-20 00:29:41,321 - sequence_tagger_bert - INFO - Train loss: 0.025348623533424763
2019-11-20 00:29:44,479 - sequence_tagger_bert - INFO - Validation loss: 0.5219396179905389
2019-11-20 00:29:44,480 - sequence_tagger_bert - INFO - Validation metrics: (0.471100062150404,)
2019-11-20 00:29:44,482 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  60%|██████    | 12/20 [11:00<07:18, 54.80s/it]

2019-11-20 00:30:36,152 - sequence_tagger_bert - INFO - Train loss: 0.02420229448667764
2019-11-20 00:30:38,889 - sequence_tagger_bert - INFO - Validation loss: 0.5160302322176171
2019-11-20 00:30:38,890 - sequence_tagger_bert - INFO - Validation metrics: (0.47095179233621753,)
2019-11-20 00:30:38,891 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  65%|██████▌   | 13/20 [11:54<06:22, 54.68s/it]

2019-11-20 00:31:31,216 - sequence_tagger_bert - INFO - Train loss: 0.0222365621215809
2019-11-20 00:31:33,956 - sequence_tagger_bert - INFO - Validation loss: 0.5339932873924603
2019-11-20 00:31:33,957 - sequence_tagger_bert - INFO - Validation metrics: (0.47051520794537555,)
2019-11-20 00:31:33,958 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  70%|███████   | 14/20 [12:50<05:28, 54.80s/it]

2019-11-20 00:32:25,028 - sequence_tagger_bert - INFO - Train loss: 0.02142074331414414
2019-11-20 00:32:28,010 - sequence_tagger_bert - INFO - Validation loss: 0.5332436034365035
2019-11-20 00:32:28,011 - sequence_tagger_bert - INFO - Validation metrics: (0.4747225647348952,)
2019-11-20 00:32:28,029 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  75%|███████▌  | 15/20 [13:44<04:32, 54.58s/it]

2019-11-20 00:33:19,202 - sequence_tagger_bert - INFO - Train loss: 0.0203145507642963
2019-11-20 00:33:22,066 - sequence_tagger_bert - INFO - Validation loss: 0.5433915181187685
2019-11-20 00:33:22,067 - sequence_tagger_bert - INFO - Validation metrics: (0.47211895910780666,)
2019-11-20 00:33:22,067 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  80%|████████  | 16/20 [14:38<03:37, 54.42s/it]

2019-11-20 00:34:12,922 - sequence_tagger_bert - INFO - Train loss: 0.01986186942214134
2019-11-20 00:34:15,760 - sequence_tagger_bert - INFO - Validation loss: 0.5410953026049111
2019-11-20 00:34:15,761 - sequence_tagger_bert - INFO - Validation metrics: (0.4740740740740741,)
2019-11-20 00:34:15,761 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  85%|████████▌ | 17/20 [15:31<02:42, 54.20s/it]

2019-11-20 00:35:06,510 - sequence_tagger_bert - INFO - Train loss: 0.01940659679341668
2019-11-20 00:35:09,299 - sequence_tagger_bert - INFO - Validation loss: 0.5366039614920027
2019-11-20 00:35:09,300 - sequence_tagger_bert - INFO - Validation metrics: (0.4763076923076923,)
2019-11-20 00:35:09,324 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  90%|█████████ | 18/20 [16:25<01:48, 54.01s/it]

2019-11-20 00:36:01,028 - sequence_tagger_bert - INFO - Train loss: 0.018303427412839874
2019-11-20 00:36:03,860 - sequence_tagger_bert - INFO - Validation loss: 0.5431420835098479
2019-11-20 00:36:03,991 - sequence_tagger_bert - INFO - Validation metrics: (0.47466007416563666,)
2019-11-20 00:36:03,992 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch:  95%|█████████▌| 19/20 [17:20<00:54, 54.21s/it]

2019-11-20 00:36:54,500 - sequence_tagger_bert - INFO - Train loss: 0.018760964634991554
2019-11-20 00:36:57,514 - sequence_tagger_bert - INFO - Validation loss: 0.5384372797817355
2019-11-20 00:36:57,515 - sequence_tagger_bert - INFO - Validation metrics: (0.4773006134969324,)
2019-11-20 00:36:57,534 - sequence_tagger_bert - INFO - Current learning rate: 0.0


Epoch: 100%|██████████| 20/20 [18:13<00:00, 54.01s/it]


2019-11-20 00:37:00,991 - sequence_tagger_bert - INFO - Entity-level f1: 0.34105653382761814
2019-11-20 00:37:00,992 - sequence_tagger_bert - INFO - Token-level f1: 0.3320134023758757


In [57]:
with open('labels_v3_predful_manual.txt', 'w') as f:
    for _string in _:
        #f.seek(0)
        f.write(', '.join(_string) + '\n')

In [None]:
def prepare_flair_corpus(corpus, name='tag', filter_tokens={'-DOCSTART-'}):
    result = []
    for sent in corpus[:10]:
        print ("sent", sent)
        print ("sent[0].text", sent[0].text)
        if sent[0].text in filter_tokens:
            continue
        else:
            result.append(([token.text for token in sent.tokens],
                           [token.tags[name].value for token in sent.tokens]))
    
    return result


In [None]:
test_dataloader = create_loader_from_flair_corpus(corpus.test,
                                                  SequentialSampler,
                                                  batch_size=PRED_BATCH_SIZE)

_, __, test_metrics = seq_tagger.predict(test_dataloader, evaluate=True, 
                                         metrics=[f1_entity_level, f1_token_level])
logger.info(f'Entity-level f1: {test_metrics[1]}')
logger.info(f'Token-level f1: {test_metrics[2]}')

In [None]:
(0.9143007822800387, 0.9306361914074436)