# train_bilstm_crf

## 代码解释

### 加载库

In [1]:
import torch
import warnings
from torch import optim
from pyner.train.metrics import F1_score
from pyner.train.trainer import Trainer
from pyner.io.data_loader import DataLoader
from pyner.io.data_transformer import DataTransformer
from pyner.utils.logginger import init_logger
from pyner.utils.utils import seed_everything
from pyner.config.basic_config import configs as config
from pyner.callback.lrscheduler import ReduceLROnPlateau
from pyner.model.nn.bilstm_crf import Model
from pyner.callback.modelcheckpoint import ModelCheckpoint
from pyner.callback.trainingmonitor import TrainingMonitor
import sys
# Silence every warning category from here on for the whole kernel session.
# NOTE(review): this also hides deprecation notices — consider narrowing the filter.
warnings.filterwarnings("ignore")

### 变量设定

In [2]:
# Architecture key: names the logger here and is reused by later cells
# (model config lookup, checkpoint and training-monitor naming).
arch = 'bilstm_crf'
logger = init_logger(log_name=arch, log_dir=config['log_dir'])

# Record the configured seed in the log, then apply it.
logger.info("seed is %d" % config['seed'])
seed_everything(seed=config['seed'])

# Use the first configured GPU id when any are listed; otherwise stay on CPU.
if len(config['n_gpus']):
    device = 'cuda:%d' % config['n_gpus'][0]
else:
    device = 'cpu'

[2020-12-23 11:00:01]: bilstm_crf <ipython-input-2-09a7f4c9d8bc>[line:3] INFO  seed is 2018


### 数据预处理模块

#### 数据预处理类定义

In [3]:
# Preprocessing helper; every path and threshold is taken from the shared
# `config` mapping, and it is created in training mode.
data_transformer = DataTransformer(
    logger=logger,
    is_train_mode=True,
    all_data_path=config['all_data_path'],
    vocab_path=config['vocab_path'],
    max_features=config['max_features'],
    label_to_id=config['label_to_id'],
    train_file=config['train_file_path'],
    valid_file=config['valid_file_path'],
    valid_size=config['valid_size'],
    min_freq=config['min_freq'],
    seed=config['seed'],
)

#### 生成词典

In [4]:
# Build the vocabulary on the transformer.
# Presumably this populates `data_transformer.vocab`, whose length the model
# cell uses as vocab_size — confirm in DataTransformer.
data_transformer.build_vocab()

#### 将句子转化为id形式

In [5]:
# Convert the raw sentences and targets to id form; the call is the last
# expression of the cell, so its return value is shown as the cell output.
data_transformer.sentence2id(
    raw_data_path=config['raw_train_path'],
    raw_target_path=config['raw_target_path'],
    x_var=config['x_var'],
    y_var=config['y_var'],
)

[2020-12-23 11:10:55]: bilstm_crf data_transformer.py[line:123] INFO  sentence to id


True

#### 建立词向量矩阵

In [11]:
# Build the embedding matrix from the pretrained vectors at the configured path.
# Both names are bound to the same object: `embedding_weight` is what the model
# cell consumes, `word2vec_embedding_weight` is kept as an alias.
embedding_weight = data_transformer.build_embedding_matrix(
    embedding_path=config['embedding_weight_path']
)
word2vec_embedding_weight = embedding_weight

[2020-12-23 11:48:39]: bilstm_crf data_transformer.py[line:173] INFO  initializer embedding matrix
[2020-12-23 11:48:39]: bilstm_crf data_transformer.py[line:196] INFO   load emebedding weights
[2020-12-23 11:48:39]: bilstm_crf data_transformer.py[line:212] INFO  Total 16115 word vectors.


### 构建数据迭代器

#### 加载训练数据集

In [7]:
# Loader over the training split file (config['train_file_path']).
train_loader = DataLoader(
    logger=logger,
    is_train_mode=True,
    x_var=config['x_var'],
    y_var=config['y_var'],
    skip_header=False,
    data_path=config['train_file_path'],
    batch_size=config['batch_size'],
    max_sentence_length=config['max_length'],
    device=device,
)

#### 验证数据集

In [8]:
# Loader over the validation split file (config['valid_file_path']); all other
# settings mirror the training loader.
# NOTE(review): is_train_mode=True is used for validation as well — confirm
# this is intended.
val_loader = DataLoader(
    logger=logger,
    is_train_mode=True,
    x_var=config['x_var'],
    y_var=config['y_var'],
    skip_header=False,
    data_path=config['valid_file_path'],
    batch_size=config['batch_size'],
    max_sentence_length=config['max_length'],
    device=device,
)

#### 产生数据迭代器

In [9]:
# Create the batch iterators that are handed to the Trainer; the training
# iterator is still built first, then the validation one.
train_iter, val_iter = train_loader.make_iter(), val_loader.make_iter()

### 模型和优化器

In [12]:
# BiLSTM-CRF network; architecture-specific settings come from
# config['models'][arch], and the vocabulary size from the data transformer.
model = Model(
    num_classes=config['num_classes'],
    embedding_dim=config['embedding_dim'],
    model_config=config['models'][arch],
    embedding_weight=embedding_weight,
    vocab_size=len(data_transformer.vocab),
    device=device,
)

# Adam over all model parameters, with learning rate and weight decay from config.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)

### callbacks

In [13]:
# Log marker: the cells that follow construct the training callbacks.
logger.info("initializing callbacks")

[2020-12-23 11:49:57]: bilstm_crf <ipython-input-13-964d892987af>[line:1] INFO  initializing callbacks


#### 模型保存

In [14]:
# Checkpoint callback; the monitored metric, comparison mode and file names
# all come from config.
model_checkpoint = ModelCheckpoint(
    checkpoint_dir=config['checkpoint_dir'],
    mode=config['mode'],
    monitor=config['monitor'],
    save_best_only=config['save_best_only'],
    best_model_name=config['best_model_name'],
    epoch_model_name=config['epoch_model_name'],
    arch=arch,
    logger=logger,
)

#### 监控训练过程

In [15]:
# Training monitor; figures go to config['figure_dir'] and its JSON output
# shares the log directory.
train_monitor = TrainingMonitor(
    fig_dir=config['figure_dir'],
    json_dir=config['log_dir'],
    arch=arch,
)

#### 学习率机制

In [17]:
# Plateau-based learning-rate scheduler bound to the optimizer. Patience and
# mode come from config; factor, min_lr and epsilon are fixed in this cell.
lr_scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    factor=0.5,
    patience=config['lr_patience'],
    min_lr=1e-9,
    epsilon=1e-5,
    verbose=1,
    mode=config['mode'],
)

### 模型训练

In [18]:
# Log marker: the cells that follow build the Trainer and start training.
logger.info('training model....')

[2020-12-23 14:34:51]: bilstm_crf <ipython-input-18-f1bb7e4c4970>[line:1] INFO  training model....


#### 模型 Trainer 定义

In [20]:
# Wire the model, data iterators, optimizer and callbacks into one Trainer.
# The evaluation metric is an F1 score built for config['num_classes'] classes.
trainer = Trainer(
    model=model,
    train_data=train_iter,
    val_data=val_iter,
    optimizer=optimizer,
    epochs=config['epochs'],
    label_to_id=config['label_to_id'],
    evaluate=F1_score(num_classes=config['num_classes']),
    logger=logger,
    model_checkpoint=model_checkpoint,
    training_monitor=train_monitor,
    resume=config['resume'],
    lr_scheduler=lr_scheduler,
    n_gpu=config['n_gpus'],
    avg_batch_loss=True,
)

#### 查看模型结构

In [21]:
# Log the trainable-parameter count and the module structure of the model
# (see the cell output).
trainer.summary()

[2020-12-23 14:37:30]: bilstm_crf trainer.py[line:82] INFO  trainable parameters: 0.9659099999999999M
[2020-12-23 14:37:30]: bilstm_crf trainer.py[line:84] INFO  Model(
  (embedding): Embed_Layer(
    (encoder): Embedding(4769, 100)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (lstm): BILSTM(
    (lstm): LSTM(100, 200, batch_first=True, dropout=0.5, bidirectional=True)
    (linear): Linear(in_features=400, out_features=14, bias=True)
  )
  (crf): CRF()
)


#### 拟合模型   

In [None]:
# Start training; a progress bar and per-epoch loss/acc/f1 (train and
# validation) are printed and logged as it runs, as shown in the cell output.
trainer.train()

----------------- training start -----------------------
Epoch 1/100......
[training] 179/179 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] -4.0s/step- loss: 5.5657- acc: 0.9312 - f1: 0.4581
training result:


[2020-12-23 14:46:25]: bilstm_crf trainer.py[line:194] INFO  
Epoch: 1 - loss: 6.1080 acc: 0.9339 - f1: 0.4025 val_loss: 3.1100 - val_acc: 0.9375 - val_f1: 0.5131
[2020-12-23 14:46:25]: bilstm_crf modelcheckpoint.py[line:47] INFO  
Epoch 1: val_loss improved from inf to 3.11004


----------- Train entity score:
----------- valid entity score:
----------------- training start -----------------------
Epoch 2/100......
[training] 179/179 [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] -3.8s/step- loss: 2.7980- acc: 0.9544 - f1: 0.4923
training result:


[2020-12-23 14:54:52]: bilstm_crf trainer.py[line:194] INFO  
Epoch: 2 - loss: 2.2388 acc: 0.9593 - f1: 0.4895 val_loss: 1.9590 - val_acc: 0.9499 - val_f1: 0.5418


----------- Train entity score:
----------- valid entity score:


[2020-12-23 14:54:52]: bilstm_crf modelcheckpoint.py[line:47] INFO  
Epoch 2: val_loss improved from 3.11004 to 1.95901


----------------- training start -----------------------
Epoch 3/100......
[training] 14/179 [>>                            ] -1.3s/step- loss: 1.0951- acc: 0.9736 - f1: 0.5117

#### 释放显存

In [None]:
 # Release cached CUDA memory, but only when GPU ids were configured.
 if len(config['n_gpus']):
     torch.cuda.empty_cache()