In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.prompt.prompt_ernie_ctm_nptag import Module
from ark_nlp.model.prompt.prompt_ernie_ctm_nptag import ModuleConfig
from ark_nlp.model.prompt.prompt_ernie_ctm_nptag import Dataset
from ark_nlp.model.prompt.prompt_ernie_ctm_nptag import Task
from ark_nlp.model.prompt.prompt_ernie_ctm_nptag import get_default_model_optimizer
from ark_nlp.model.prompt.prompt_ernie_ctm_nptag import Tokenizer
from ark_nlp.factory.utils.seed import set_seed

In [None]:
# Fix the random seed to 42 through ark_nlp's set_seed helper.
# NOTE(review): presumably seeds Python/NumPy/torch RNGs — confirm in ark_nlp.
set_seed(42)

In [None]:
# Dataset file locations (relative to the notebook's working directory).
# Dataset download URL: https://bj.bcebos.com/paddlenlp/paddlenlp/datasets/nptag_dataset.tar.gz

train_data_path = '../data/source_datasets/nptag_dataset/train.txt'
dev_data_path = '../data/source_datasets/nptag_dataset/dev.txt'
name_category_map_path  = '../data/source_datasets/nptag_dataset/name_category_map.json'

In [None]:
# Pretrained model identifier; passed to the Tokenizer, ModuleConfig and Module
# loaders in the cells below.
module_path = 'freedomking/ernie-ctm-nptag'

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Read the tab-separated train/dev files. Because explicit column names are
# supplied, pandas treats the files as header-less and labels the two columns
# 'text' and 'label'.
train_data_df = pd.read_csv(train_data_path, sep='\t', names=['text', 'label'])
dev_data_df = pd.read_csv(dev_data_path, sep='\t', names=['text', 'label'])

In [None]:
# Load the name -> category mapping from JSON. The context manager guarantees the
# file handle is closed (the bare json.load(open(...)) form left it open).
with open(name_category_map_path, 'r', encoding='utf-8') as _f:
    name_category_map = json.load(_f)

In [None]:
# Build the prompt template: the character '是' followed by five [MASK] slots.
mask_tokens = ["[MASK]" for _ in range(5)]
prompt = ['是', *mask_tokens]

#### 2. 词典创建和生成分词器

In [None]:
# Build the tokenizer from the pretrained model identifier.
# NOTE(review): 100 is presumably the maximum sequence length — confirm against
# the ark_nlp Tokenizer signature.
tokenizer = Tokenizer(module_path, 100)

#### 3. 对齐label

In [None]:
# The prompt holds a fixed number of [MASK] slots, so every label is rewritten as
# its tokens right-padded with '[PAD]' up to that slot count.

label2newlabel = dict()

for _k in name_category_map:
    # Tokenize once per label (previously each label was tokenized twice).
    _tokens = tokenizer.tokenize(_k)
    # Pad up to the number of [MASK] slots in the prompt rather than a literal 5,
    # so the two cannot drift apart.
    # NOTE(review): a label with more tokens than slots receives no padding and
    # stays longer than the mask span — confirm no such label exists in the map.
    label2newlabel[_k] = ''.join(_tokens + ['[PAD]'] * (len(mask_tokens) - len(_tokens)))

# Manual override: four characters plus one '[PAD]' (five slots in total).
# NOTE(review): presumably the tokenizer does not split this label into four
# single-character tokens — confirm why the override is needed.
label2newlabel['海绵蛋糕'] = '海绵蛋糕[PAD]'

# Overwrite the raw labels with their aligned form. An unknown label raises
# KeyError. This mutates the frames in place, so re-running this cell without
# reloading the data will fail on the already-aligned labels.
train_data_df['label'] = train_data_df['label'].apply(lambda x: label2newlabel[x])
dev_data_df['label'] = dev_data_df['label'].apply(lambda x: label2newlabel[x])

# Aligned label strings, in the mapping's insertion order.
categories = list(label2newlabel.values())

In [None]:
# Wrap the train/dev frames in prompt datasets that share the same prompt
# template and the same list of aligned category strings.
prompt_train_dataset = Dataset(train_data_df, prompt=prompt, categories=categories)
prompt_dev_dataset = Dataset(dev_data_df, prompt=prompt, categories=categories)

#### 4. ID化

In [None]:
# Convert both datasets to ids using the tokenizer.
# NOTE(review): the return value is discarded, so this presumably updates the
# dataset objects in place — confirm in ark_nlp.
prompt_train_dataset.convert_to_ids(tokenizer)
prompt_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Load the pretrained model configuration, setting num_labels to the tokenizer's
# vocabulary size.
# NOTE(review): presumably because each [MASK] position is predicted as a
# vocabulary token — confirm against the model implementation.
config = ModuleConfig.from_pretrained(
    module_path,
    num_labels=tokenizer.vocab.vocab_size
)

#### 2. 模型创建

In [None]:
# Release unused memory held by PyTorch's CUDA caching allocator before the
# model is loaded.
torch.cuda.empty_cache()

In [None]:
# Instantiate the model from the pretrained weights using the configuration
# built above.
dl_module = Module.from_pretrained(
    module_path,
    config=config
)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Training hyperparameters: number of epochs and mini-batch size.
num_epoches = 10
batch_size = 32

In [None]:
# Build the library's default optimizer for this module.
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [None]:
# Bundle the module, optimizer and tokenizer into a training Task on CUDA device 0.
# NOTE(review): 'ce' presumably selects a cross-entropy loss — confirm in ark_nlp.
model = Task(dl_module, optimizer, 'ce', cuda_device=0, tokenizer=tokenizer)

#### 3. 训练

In [None]:
# Fine-tune on the training dataset, passing the dev dataset as the second argument.
# The epoch count comes from the num_epoches setting in the hyperparameter cell
# (previously a literal 10 was repeated here and num_epoches was never used), so
# editing that cell now actually changes the training length.
model.fit(
    prompt_train_dataset,
    prompt_dev_dataset,
    lr=2e-5,
    epochs=num_epoches,
    batch_size=batch_size
)

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [None]:
from ark_nlp.model.prompt.prompt_bert import Predictor

In [None]:
# Build a predictor from the trained module, the tokenizer, the training
# dataset's category-to-id mapping and the same prompt template used in training.
prompt_instance = Predictor(model.module, tokenizer, prompt_train_dataset.cat2id, prompt=prompt)

In [None]:
# Run prediction for a single input string, requesting 15 candidates (topk=15)
# together with their probabilities (return_proba=True). As the cell's last
# expression, the result is displayed by the notebook.
prompt_instance.predict_one_sample('美国队长3', topk=15, return_proba=True)