In [1]:
# import sys
# sys.path.append('../../..')

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim

from ark_nlp.dataset import TMDataset
from ark_nlp.processor.vocab import CharVocab
from ark_nlp.processor.tokenizer.tm import TransfomerTokenizer
from ark_nlp.neural_network import VanillaBert
from ark_nlp.factory.task import TMTask

In [23]:
# Optional one-time setup, intentionally left commented out: installs ark-nlp from a
# custom package index served over plain HTTP (hence the --trusted-host flag).
# !pip install -i http://39.99.190.185:10080 ark-nlp --trusted-host 39.99.190.185

### 一、数据读入与处理

#### 1. 数据读入

In [2]:
# The data directory is resolved relative to the notebook's working directory:
# two levels up, then into data/task_datasets/ccks2018_text_sim.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
train_data_path = project_root + '/data/task_datasets/ccks2018_text_sim/train_data.tsv'
dev_data_path = project_root + '/data/task_datasets/ccks2018_text_sim/dev_data.tsv'

# Both splits are loaded with is_retain_dataset=True.
tm_train_dataset = TMDataset(train_data_path, is_retain_dataset=True)
tm_dev_dataset = TMDataset(dev_data_path, is_retain_dataset=True)

#### 2. 词典创建

In [3]:
import transformers
from transformers import AutoTokenizer

# Despite its name, `bert_vocab` is the object returned by
# AutoTokenizer.from_pretrained for the 'bert-base-chinese' checkpoint;
# it is handed to the ark_nlp tokenizer wrapper in the next step.
bert_vocab = AutoTokenizer.from_pretrained('bert-base-chinese')

#### 3. 生成分词器

In [4]:
# Sequence-length limit passed to the tokenizer wrapper.
max_seq_length = 50

# Build the ark_nlp tokenizer from the length limit and the pretrained vocab object.
tokenizer = TransfomerTokenizer(
    max_seq_length,
    bert_vocab,
)

#### 4. ID化

In [5]:
# Convert both splits to IDs with the same tokenizer (train first, then dev).
for _dataset in (tm_train_dataset, tm_dev_dataset):
    _dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [6]:
from transformers import BertConfig

# Start from the 'bert-base-chinese' configuration and set num_labels to the
# number of entries in the training dataset's `cat2id`.
bert_config = BertConfig.from_pretrained(
    'bert-base-chinese',
    num_labels=len(tm_train_dataset.cat2id),
)

#### 2. 模型创建

In [7]:
# Ask PyTorch to release cached, currently unused GPU memory before the model is created.
torch.cuda.empty_cache()

In [8]:
# Instantiate the network from the 'bert-base-chinese' checkpoint using the
# configuration prepared above (which carries num_labels).
dl_module = VanillaBert.from_pretrained(
    'bert-base-chinese',
    config=bert_config,
)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing VanillaBert: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing VanillaBert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VanillaBert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VanillaBert were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.weight', 'cla

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [9]:
# Mini-batch size; passed to model.fit as batch_size.
batch_size = 32

# Number of training runs (epochs).
# NOTE(review): the visible model.fit call passes a literal `epochs` value
# rather than this variable — confirm which one is intended.
num_epoches = 10

In [10]:
# Parameter-name fragments whose parameters are placed in the zero-weight-decay group.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

# (name, parameter) pairs of the model, skipping any whose name contains 'pooler';
# those parameters are therefore not handed to the optimizer at all.
param_optimizer = [
    (name, param)
    for name, param in dl_module.named_parameters()
    if 'pooler' not in name
]

# Two optimizer groups: weight decay 0.01 for parameters matching none of the
# `no_decay` fragments, and 0.0 for parameters matching at least one of them.
optimizer_grouped_parameters = [
    {
        'params': [
            param
            for name, param in param_optimizer
            if all(fragment not in name for fragment in no_decay)
        ],
        'weight_decay': 0.01,
    },
    {
        'params': [
            param
            for name, param in param_optimizer
            if any(fragment in name for fragment in no_decay)
        ],
        'weight_decay': 0.0,
    },
]

#### 2. 任务创建

In [11]:
# Wrap the network in an ark_nlp TMTask. 'adamw' and 'ce' are presumably the
# optimizer and loss-function names (AdamW / cross-entropy) — TODO confirm against TMTask.
model = TMTask(dl_module, 'adamw', 'ce')

In [12]:
# Overwrite the dev dataset's `categories` attribute with the two label strings.
# NOTE(review): presumably consumed as class names during evaluation — confirm in ark_nlp.
tm_dev_dataset.categories = ['0', '1']

#### 3. 训练

In [None]:
# Train the task. The dev dataset is passed as the second positional argument
# (presumably used for validation — confirm in TMTask.fit). The grouped parameters
# built earlier are supplied via `params`.
# NOTE(review): `epochs` is hard-coded to 3 here, while `num_epoches = 10` is
# defined in the settings cell and never used — confirm which value is intended.
model.fit(
    tm_train_dataset,
    tm_dev_dataset,
    lr=5e-5,
    epochs=3,
    batch_size=batch_size,
    params=optimizer_grouped_parameters,
)

<br>

### 四、模型验证与保存

#### 1. 模型验证