In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.nn import Ernie
from ark_nlp.dataset import TMDataset
from ark_nlp.factory.task import TMTask
from ark_nlp.factory.optimizer import get_default_bert_optimizer
from ark_nlp.processor.tokenizer.transfomer import PairTokenizer

In [None]:
# Dataset locations: KUAKE-QQR training and development splits.
train_data_path, dev_data_path = (
    '../data/source_datasets/KUAKE-QQR/KUAKE-QQR_train.json',
    '../data/source_datasets/KUAKE-QQR/KUAKE-QQR_dev.json',
)

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Training split: rename the two query columns to text_a / text_b and keep
# only the three columns that are passed on to the dataset wrapper.
train_data_df = (
    pd.read_json(train_data_path)
    .rename(columns={'query1': 'text_a', 'query2': 'text_b'})
    .loc[:, ['text_a', 'text_b', 'label']]
)

# Dev split: rows whose label is the string "NA" are discarded first, then
# the same rename / column selection is applied.
dev_data_df = (
    pd.read_json(dev_data_path)
    .loc[lambda frame: frame['label'] != "NA"]
    .rename(columns={'query1': 'text_a', 'query2': 'text_b'})
    .loc[:, ['text_a', 'text_b', 'label']]
)

In [None]:
# Wrap each DataFrame in a TMDataset (train is constructed first, then dev).
tm_train_dataset, tm_dev_dataset = (
    TMDataset(frame) for frame in (train_data_df, dev_data_df)
)

#### 2. 词典创建和生成分词器

In [None]:
# The vocabulary can be created first and then loaded into a tokenizer,
# or the tokenizer can be left to load it automatically.
# bert_vocab = transformers.AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')
# tokenizer = TransfomerTokenizer(bert_vocab, max_seq_len=30)

In [None]:
# Pair tokenizer built from the 'nghuyong/ernie-1.0' vocabulary name, with
# max_seq_len set to 50.
tokenizer = PairTokenizer(
    vocab='nghuyong/ernie-1.0',
    max_seq_len=50,
)

#### 3. ID化

In [None]:
# Convert the train and dev datasets to IDs using the shared tokenizer.
# The return values are not captured, so the datasets are presumably updated
# in place by the call -- TODO confirm against TMDataset.convert_to_ids.
tm_train_dataset.convert_to_ids(tokenizer)
tm_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
from transformers import BertConfig

# Load the 'nghuyong/ernie-1.0' configuration, with num_labels taken from the
# number of entries in the training dataset's cat2id mapping.
bert_config = BertConfig.from_pretrained(
    'nghuyong/ernie-1.0',
    num_labels=len(tm_train_dataset.cat2id),
)

#### 2. 模型创建

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator
# before the model is instantiated in the next cell.
torch.cuda.empty_cache()

In [None]:
# Instantiate the Ernie module from the 'nghuyong/ernie-1.0' checkpoint name,
# passing the configuration prepared above.
dl_module = Ernie.from_pretrained(
    'nghuyong/ernie-1.0',
    config=bert_config,
)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Run-count and batch-size settings for training.
num_epoches, batch_size = 10, 32

In [None]:
# Build the optimizer for the model via ark_nlp's get_default_bert_optimizer helper.
optimizer = get_default_bert_optimizer(dl_module)

#### 2. 任务创建

In [None]:
# Create the task from the module, the optimizer and the 'ce' loss name.
# The notebook was written for GPU index 1. A fixed index of 1 does not exist
# on hosts with fewer than two visible CUDA devices, so fall back to index 0
# there (torch.cuda.device_count() is 0 when CUDA is unavailable).
model = TMTask(
    dl_module,
    optimizer,
    'ce',
    cuda_device=1 if torch.cuda.device_count() > 1 else 0,
)

#### 3. 训练

In [None]:
# Train the task. The dev dataset is passed as the second argument
# (presumably used for per-epoch validation -- confirm against TMTask.fit).
# The epoch count is taken from the `num_epoches` setting defined in the
# parameter cell rather than a separate hard-coded literal, so that the
# configuration cell is the single place to tune the run length.
model.fit(
    tm_train_dataset,
    tm_dev_dataset,
    lr=2e-5,
    epochs=num_epoches,
    batch_size=batch_size,
)

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [None]:
from ark_nlp.factory.predictor import TMPredictor

In [None]:
# Predictor built from the trained module, the tokenizer and the training
# dataset's cat2id mapping.
tm_predictor_instance = TMPredictor(
    model.module,
    tokenizer,
    tm_train_dataset.cat2id,
)

In [None]:
# Single-sample check on one query pair; return_proba=True is passed so the
# call's result (displayed as the cell output) includes probabilities.
tm_predictor_instance.predict_one_sample(
    ['女生第一次后大姨妈迟到', '第一次会不会影响大姨妈推迟'],
    return_proba=True,
)

#### 2. Batch模型验证

In [None]:
# Batch-prediction input.
# NOTE(review): this reads the *training* split, so the batch check below is a
# functional smoke test, not a held-out evaluation. The path is taken from
# `train_data_path` instead of repeating the literal so the two cannot drift.
test_data_df = (
    pd.read_json(train_data_path)
    .rename(columns={'query1': 'text_a', 'query2': 'text_b'})
    .loc[:, ['text_a', 'text_b', 'label']]
)

# Reuse the training categories and flag the dataset as test data, then
# convert it to IDs with the same tokenizer used for training.
tm_test_dataset = TMDataset(
    test_data_df,
    categories=tm_train_dataset.categories,
    is_test=True,
)
tm_test_dataset.convert_to_ids(tokenizer)

In [None]:
# Run batch prediction over the test dataset and keep the result.
predict_label = tm_predictor_instance.predict_batch(tm_test_dataset)

<br>

### 五、模型测试报告

In [None]:
# Model test report (kept as comments so this cell is valid Python and does
# not raise a SyntaxError under "Restart & Run All"):
# 1. Basic functionality test: passed
# 2. One-sample predict: passed
# 3. Batch predict: passed