In [None]:
import os

# Work from the project root so the relative data/... paths used below resolve.
# The default location can be overridden with the NESTED_NER_DIR environment variable
# instead of editing this cell on every machine.
PROJECT_DIR = os.environ.get('NESTED_NER_DIR', '/home/nested_ner')
os.chdir(PROJECT_DIR)

In [2]:
from model_scripts.data_utils.dataset_process import get_ind_sequence, dataset
from model_scripts.data_utils.label_helper import labels2ids, ids2labels
from transformers import get_linear_schedule_with_warmup
from model_scripts.utils.other_utils import set_seed
from model_scripts.task_trainers import GeneralTrainer
from model_scripts.span_classifier import NNERModel
from torch.utils.data import DataLoader
from torch import cuda, optim
import pandas as pd
import numpy as np
import random
import torch
import re

In [3]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [4]:
# Seed the random number generators before any data loading or model creation.
# NOTE(review): uses set_seed()'s default seed from model_scripts.utils.other_utils;
# which libraries it seeds (random / numpy / torch) is not visible here — confirm.
set_seed()

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

##### Загрузка данных для обучения, валидации и тестирования

train

In [6]:
# Training texts
ner_ds_path = 'data/all_data/train_texts.csv'

train_texts = pd.read_csv(ner_ds_path, sep=';')
# Drop rows whose text is missing (read_csv turns empty fields into NaN),
# empty or whitespace-only, then renumber rows 0..n-1 without keeping the
# old index as an extra column.
train_contents = train_texts['Contents'].fillna('').astype(str)
train_texts = train_texts[train_contents.str.strip() != ''].reset_index(drop=True)
# Word boundaries within each text sequence
train_inds = get_ind_sequence(train_texts)

# Training span labels
ner_ds_path = 'data/all_data/train_spans.csv'
train_spans = pd.read_csv(ner_ds_path, sep=';')

In [7]:
# Wrap the cleaned training texts in the project dataset; max_len=284 matches
# the max_seq_len passed to NNERModel below.
train_texts_ds = dataset(
    train_texts,
    max_len=284,
)

val

In [8]:
# Validation texts
ner_ds_path = 'data/all_data/dev_texts.csv'

val_texts = pd.read_csv(ner_ds_path, sep=';')
# Drop rows whose text is missing (NaN from read_csv), empty or whitespace-only,
# and renumber rows 0..n-1 without keeping the old index as a column.
val_contents = val_texts['Contents'].fillna('').astype(str)
val_texts = val_texts[val_contents.str.strip() != ''].reset_index(drop=True)
# Word boundaries within each text sequence
val_inds = get_ind_sequence(val_texts)

# Validation span labels
ner_ds_path = 'data/all_data/dev_spans.csv'
val_spans = pd.read_csv(ner_ds_path, sep=';')

In [9]:
# Validation dataset, built with the same maximum length as the training set.
val_texts_ds = dataset(
    val_texts,
    max_len=284,
)

test

In [6]:
# Test texts
ner_ds_path = 'data/all_data/test_texts.csv'

test_texts = pd.read_csv(ner_ds_path, sep=';')
# Drop rows whose text is missing (NaN from read_csv), empty or whitespace-only,
# and renumber rows 0..n-1 without keeping the old index as a column.
test_contents = test_texts['Contents'].fillna('').astype(str)
test_texts = test_texts[test_contents.str.strip() != ''].reset_index(drop=True)
# Word boundaries within each text sequence
test_inds = get_ind_sequence(test_texts)

# Test span labels
ner_ds_path = 'data/all_data/test_spans.csv'
test_spans = pd.read_csv(ner_ds_path, sep=';')

In [7]:
# Test dataset, built with the same maximum length as the training set.
test_texts_ds = dataset(
    test_texts,
    max_len=284,
)

##### Обучение

In [None]:
TRAIN_BATCH_SIZE = 1
VALID_BATCH_SIZE = 1

# DataLoader keyword arguments per split; only the training split is shuffled,
# and num_workers=0 loads batches in the main process.
train_params = dict(batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
val_params = dict(batch_size=VALID_BATCH_SIZE, shuffle=False, num_workers=0)
test_params = dict(batch_size=1, shuffle=False, num_workers=0)

training_loader, validation_loader, testing_loader = (
    DataLoader(train_texts_ds, **train_params),
    DataLoader(val_texts_ds, **val_params),
    DataLoader(test_texts_ds, **test_params),
)

In [13]:
# Number of passes over the training set and the Adam learning rate.
EPOCHS = 50
LEARNING_RATE = 1e-5

NNERModel - модель для основной задачи классификации вложенных сущностей


Важные параметры:

* extractor_type - тип используемого преобразования для получения представлений отрезков

    * 'weightedpooling' - простой взвешенный пулинг BERT-представлений слов
    * 'lstmattention' - модуль многоголового внимания со слоем LSTM для снижения размерности
    * 'linearattention' - модуль многоголового внимания со снижением размерности вектора с помощью линейного слоя (используется для всех экстракторов)
    * 'biaffine' - использование биаффинного метода кодирования слов в последовательности в дополнение к семантическим представлениям от BERT.
    * 'selfbiaffine' - использование биаффинного метода, а также механизма внимания. После конкатенации используется линейный слой для уменьшения размерности.

* num_heads - используется для экстракторов, включающих механизм внимания
* mode ('classification') - основная задача классификации
* extractor_use_gcn (Bool) - использование модуля для получения синтаксической информации с помощью синтаксического парсера из библиотеки Natasha и графового модуля AGGCN

In [14]:
# Span-classifier configuration (see the parameter notes above).
model_kwargs = dict(
    num_labels=len(ids2labels),
    device=device,
    max_seq_len=284,
    max_span_len=20,
    num_heads=None,
    extractor_type='biaffine',
    mode='classification',
    extractor_use_gcn=True,
)
model = NNERModel(**model_kwargs)
_ = model.to(device)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
# Only needed for training: skip this cell when just evaluating a saved model.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Linear warmup over the first 10% of all optimisation steps, then linear decay.
total_steps = EPOCHS * len(training_loader)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [16]:
# Name prefix of the checkpoint file that is saved / loaded:
# data/all_data/general_checkpoints/{safe_prefix}_checkpoint.pth.tar
safe_prefix = 'biaffine_gcn'

In [17]:
# True  - train the model from scratch.
# False - load the saved checkpoint and continue training from it.
from_start = True

Перед запуском уточните путь к каталогу сохранения чекпоинтов.

Здесь загружаются модель, оптимизатор и планировщик из чекпоинта (на случай прерывания обучения); сохраняется чекпоинт в цикле обучения ниже.

In [18]:
# Restore model, optimizer and scheduler state when resuming an interrupted run.
if not from_start:
    check_path = f'data/all_data/general_checkpoints/{safe_prefix}_checkpoint.pth.tar'
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load(check_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    resume_epoch = checkpoint['epoch']    # 1-based epoch at which the checkpoint was saved
    last_best = checkpoint['f_score']     # best validation F-score reached so far
else:
    last_best = 0
    resume_epoch = 0

In [19]:
# Trainer used for both training and per-epoch validation.
nner_trainer = GeneralTrainer(
    model=model,
    device=device,
    optimizer=optimizer,
    scheduler=scheduler,
    ids2labels=ids2labels,
    labels2ids=labels2ids,
    val_mode='comb',
    logger_path=None,  # set to a file path to write the log into that file
)

Весь процесс обучения

In [None]:
import os

# Set to True to also report test-set scores after every epoch (monitoring only;
# the checkpoint is still selected by the validation F-score).
EVAL_TEST_EACH_EPOCH = False

check_path = f'data/all_data/general_checkpoints/{safe_prefix}_checkpoint.pth.tar'
# Create the checkpoint directory up front so torch.save cannot fail on a missing folder.
os.makedirs(os.path.dirname(check_path), exist_ok=True)

# Epochs are numbered from 1; when resuming, continue right after the checkpoint's epoch.
for epoch in range(resume_epoch, EPOCHS):
    current_epoch = epoch + 1
    nner_trainer.logger.info(f'EPOCH {current_epoch}/{EPOCHS}')
    _ = nner_trainer(training_loader, train_inds, train_spans, mode='train')

    # NOTE(review): validation uses mode='test' here, while the final test-set
    # evaluation uses mode='val' — confirm the intended mode names in GeneralTrainer.
    f_score = nner_trainer(validation_loader, val_inds, val_spans, mode='test')

    if EVAL_TEST_EACH_EPOCH:
        # Same argument order as the other calls: (loader, inds, spans).
        _ = nner_trainer(testing_loader, test_inds, test_spans, mode='test')

    nner_trainer.logger.info('\n')

    # Keep only the checkpoint with the best validation F-score.
    if f_score > last_best:
        last_best = f_score
        torch.save({'epoch': current_epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'f_score': last_best}, check_path)

##### Оценка на тестовой выборке с загруженной моделью

In [8]:
# Rebuild the test loader so this section can be run without the training section.
test_params = dict(batch_size=1, shuffle=False, num_workers=0)
testing_loader = DataLoader(test_texts_ds, **test_params)

Описание гиперпараметров модели см. выше; они должны совпадать с теми, с которыми модель обучалась, иначе веса из чекпоинта не загрузятся.

In [9]:
# Must be built with the same configuration the checkpoint was trained with,
# otherwise load_state_dict below will not match.
eval_model_kwargs = dict(
    num_labels=len(ids2labels),
    device=device,
    max_seq_len=284,
    max_span_len=20,
    num_heads=None,
    extractor_type='biaffine',
    mode='classification',
    extractor_use_gcn=True,
)
model = NNERModel(**eval_model_kwargs)
_ = model.to(device)

Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# Best checkpoint written by the training loop; the path is spelled out so this
# section does not depend on cells from the training section.
check_path = 'data/all_data/general_checkpoints/biaffine_gcn_checkpoint.pth.tar'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(check_path, map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [11]:
# Trainer constructed without an optimizer or scheduler; it is only used
# to score the restored model on the test set below.
nner_trainer = GeneralTrainer(
    model=model,
    device=device,
    ids2labels=ids2labels,
    labels2ids=labels2ids,
)

In [12]:
# Score the restored model on the test set; per-label F1 / recall / precision
# appear in the output below.
# NOTE(review): mode='val' here vs mode='test' for validation during training —
# confirm which mode names GeneralTrainer expects.
f_score = nner_trainer(testing_loader, test_inds, test_spans, mode='val')

935it [00:59, 15.58it/s]
RESULTS FOR MODE VAL
NATIONALITY F1: 76.54321%
NATIONALITY Recall: 93.93939%
NATIONALITY Precision: 64.58333%
CITY F1: 89.91597%
CITY Recall: 89.16667%
CITY Precision: 90.67797%
TIME F1: 36.94268%
TIME Recall: 60.41667%
TIME Precision: 26.60550%
DATE F1: 91.42857%
DATE Recall: 91.60305%
DATE Precision: 91.25475%
ORGANIZATION F1: 82.94737%
ORGANIZATION Recall: 84.54936%
ORGANIZATION Precision: 81.40496%
COUNTRY F1: 95.07830%
COUNTRY Recall: 93.40659%
COUNTRY Precision: 96.81093%
EVENT F1: 65.78947%
EVENT Recall: 64.80218%
EVENT Precision: 66.80731%
AGE F1: 87.82288%
AGE Recall: 86.23188%
AGE Precision: 89.47368%
NUMBER F1: 85.05747%
NUMBER Recall: 83.33333%
NUMBER Precision: 86.85446%
PRODUCT F1: 63.86555%
PRODUCT Recall: 79.16667%
PRODUCT Precision: 53.52113%
PROFESSION F1: 83.00804%
PROFESSION Recall: 84.26573%
PROFESSION Precision: 81.78733%
FACILITY F1: 61.11111%
FACILITY Recall: 64.70588%
FACILITY Precision: 57.89474%
PERSON F1: 95.61177%
PERSON Recall: 96.