In [None]:
import os
import jieba
import torch
import pickle
import pandas as pd
import torch.nn as nn

from ark_nlp.model.tc.bert import Bert
from ark_nlp.model.tc.bert import BertConfig
from ark_nlp.model.tc.bert import Dataset
from ark_nlp.model.tc.bert import Task
from ark_nlp.model.tc.bert import get_default_model_optimizer
from ark_nlp.model.tc.bert import Tokenizer

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Load the CHIP-CDN train and dev splits from their JSON files.
train_json_path = '../data/source_datasets/CHIP-CDN/CHIP-CDN_train.json'
dev_json_path = '../data/source_datasets/CHIP-CDN/CHIP-CDN_dev.json'

train_data_df = pd.read_json(train_json_path)
dev_data_df = pd.read_json(dev_json_path)

In [None]:
# Derive the classification target for both splits in one pass.
for frame in (train_data_df, dev_data_df):
    # Number of '##'-separated entries in each `normalized_result` string.
    frame['normalized_result_num'] = frame['normalized_result'].map(
        lambda joined: len(joined.split('##'))
    )
    # Counts of 1 and 2 are used as the label directly; any count above 2
    # is folded into label 0.
    frame['normalized_result_num_label'] = frame['normalized_result_num'].map(
        lambda count: 0 if count > 2 else count
    )

In [None]:
# Keep only the input text and the derived target, exposing the target
# under the column name `label`.
# NOTE: this rebinds the frame names, so the cell cannot be re-run without
# re-running the cells above it first.
kept_columns = ['text', 'normalized_result_num_label']
label_renaming = {'normalized_result_num_label': 'label'}

train_data_df = train_data_df[kept_columns].rename(columns=label_renaming)
dev_data_df = dev_data_df[kept_columns].rename(columns=label_renaming)

In [None]:
# Wrap each split in ark_nlp's text-classification Dataset.
tc_train_dataset, tc_dev_dataset = Dataset(train_data_df), Dataset(dev_data_df)

#### 2. 词典创建和生成分词器

In [None]:
# Build the tokenizer from the 'nghuyong/ernie-1.0' vocabulary.
# NOTE(review): max_seq_len=100 presumably caps the token length per sample
# (truncating/padding) — confirm against ark_nlp's Tokenizer.
tokenizer = Tokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=100)

#### 3. ID化

In [None]:
# Convert both datasets to model-ready ids using the tokenizer built above.
# NOTE(review): presumably this mutates each dataset in place — confirm
# against ark_nlp's Dataset.convert_to_ids.
tc_train_dataset.convert_to_ids(tokenizer)
tc_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Size the model's label space from the training dataset's cat2id mapping.
num_labels = len(tc_train_dataset.cat2id)
config = BertConfig.from_pretrained('nghuyong/ernie-1.0', num_labels=num_labels)

#### 2. 模型创建

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator before
# loading the model (useful when re-running the notebook in the same kernel).
torch.cuda.empty_cache()

In [None]:
# Instantiate the classifier from the 'nghuyong/ernie-1.0' checkpoint,
# using the config (and its num_labels) prepared above.
dl_module = Bert.from_pretrained('nghuyong/ernie-1.0', 
                                 config=config)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# 设置运行次数
# Training schedule: number of epochs and mini-batch size.
num_epoches, batch_size = 5, 32

In [None]:
# Name fragments identifying parameters that are exempt from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

# All named parameters except those whose name contains 'pooler'.
param_optimizer = [
    (name, param)
    for name, param in dl_module.named_parameters()
    if 'pooler' not in name
]

# Split the remaining parameters into the decayed and non-decayed groups,
# preserving their original order.
decay_params, no_decay_params = [], []
for name, param in param_optimizer:
    if any(fragment in name for fragment in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

#### 2. 任务创建

In [None]:
# Wrap the module in an ark_nlp Task that drives training.
# NOTE(review): 'adamw' and 'lsce' presumably select the optimizer and the
# loss function by name, cuda_device=0 the GPU index, and ema_decay=0.995
# enables the `model.ema` helper used after training — confirm against
# ark_nlp's Task.
model = Task(dl_module, 'adamw', 'lsce', cuda_device=0, ema_decay=0.995)

#### 3. 训练

In [None]:
# Train on the training split, passing the dev split as the second dataset,
# with the grouped parameters defined above handed to the optimizer setup.
learning_rate = 3e-5

model.fit(
    tc_train_dataset,
    tc_dev_dataset,
    lr=learning_rate,
    epochs=num_epoches,
    batch_size=batch_size,
    params=optimizer_grouped_parameters,
)

In [None]:
# Switch the module over to its EMA weights before validation and saving.
# NOTE(review): presumably `store` backs up the current (raw) parameters and
# `copy_to` overwrites them with the EMA-averaged values — confirm against
# ark_nlp's EMA implementation. If so, running this cell twice would replace
# the backup with the EMA weights.
model.ema.store(model.module.parameters())
model.ema.copy_to(model.module.parameters())  

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [None]:
from ark_nlp.factory.predictor import TCPredictor

In [None]:
# Build a predictor from the trained module, the tokenizer and the training
# dataset's cat2id mapping.
tc_predictor_instance = TCPredictor(model.module, tokenizer, tc_train_dataset.cat2id)

In [None]:
# Sanity-check the trained model on a single example text; the result is
# displayed as the cell output.
# NOTE(review): return_proba=True presumably adds the predicted probability
# to the result — confirm against TCPredictor.predict_one_sample.
tc_predictor_instance.predict_one_sample('怀孕伴精神障碍',
                                         return_proba=True)

#### 2. 模型保存

In [None]:
import pickle

In [None]:
# torch.save does not create missing parent directories; without this the
# save fails after the whole training run if the folder is absent.
os.makedirs('../checkpoint/predict_num', exist_ok=True)

# Persist the module's weights (state dict only, not the full Task object).
torch.save(model.module.state_dict(),
           '../checkpoint/predict_num/module.pth')

In [None]:
# Make sure the checkpoint folder exists so this cell also works when it is
# run on its own; open(..., "wb") raises FileNotFoundError otherwise.
os.makedirs('../checkpoint/predict_num', exist_ok=True)

# Persist the training dataset's cat2id mapping next to the model weights;
# the predictor is constructed from this mapping.
with open('../checkpoint/predict_num/cat2id.pkl', "wb") as f:
    pickle.dump(tc_train_dataset.cat2id, f)