In [None]:
import json

import numpy as np
import pandas as pd

In [None]:
# Paths to the TNEWS data files: the training split and the public test split
# (the latter is loaded below as the dev set).
train_data_path = '../data/source_datasets/tnews/train_few_all.json'
dev_data_path = '../data/source_datasets/tnews/test_public.json'

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Maps each dataset `label_desc` value to the Chinese word used as the label text.
label_en2zh = dict([
    ('news_tech', '科技'),
    ('news_entertainment', '娱乐'),
    ('news_car', '汽车'),
    ('news_travel', '旅游'),
    ('news_finance', '财经'),
    ('news_edu', '教育'),
    ('news_world', '国际'),
    ('news_house', '房产'),
    ('news_game', '电竞'),
    ('news_military', '军事'),
    ('news_story', '故事'),
    ('news_culture', '文化'),
    ('news_sports', '体育'),
    ('news_agriculture', '农业'),
    ('news_stock', '股票'),
])

In [None]:
def load_data(filename, label_map=None):
    """Load a JSON-lines file into a DataFrame and attach a mapped label column.

    Args:
        filename: Path to a UTF-8 text file holding one JSON object per line.
            Every object must contain a ``label_desc`` key.
        label_map: Optional mapping from ``label_desc`` values to the label
            text stored in the new ``label_zh`` column. When omitted, the
            module-level ``label_en2zh`` mapping is used.

    Returns:
        pandas.DataFrame: One row per JSON record, with an extra ``label_zh``
        column.

    Raises:
        KeyError: If a record has no ``label_desc`` or its value is not in
            the mapping.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if label_map is None:
        label_map = label_en2zh

    records = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. extra empty lines at the end of the file)
            # are not records; skip them instead of failing in json.loads.
            if not line.strip():
                continue
            record = json.loads(line)
            record['label_zh'] = label_map[record['label_desc']]
            records.append(record)

    return pd.DataFrame(records)


# Load the training split and keep only the sentence and its Chinese label,
# renamed to the `text` / `label` column names.
train_data_df = (
    load_data(train_data_path)[['sentence', 'label_zh']]
    .rename(columns={'sentence': 'text', 'label_zh': 'label'})
)

In [None]:
# Load the dev split and keep only the sentence and its Chinese label,
# renamed to the `text` / `label` column names.
dev_data_df = (
    load_data(dev_data_path)[['sentence', 'label_zh']]
    .rename(columns={'sentence': 'text', 'label_zh': 'label'})
)

In [None]:
from ark_nlp.dataset import SentenceClassificationDataset


class PTurningDataset(SentenceClassificationDataset):
    """
    Recast single-sentence classification as a masked-language-model task.

    Each sample is encoded as ``prompt + text`` and the label text is
    tokenized into the ids expected at the prompt's ``[MASK]`` positions.

    Args:
        *args: Positional arguments forwarded to the parent dataset.
        prompt (:obj:`list`, optional): Template tokens placed in front of the
            text. The ``[MASK]`` tokens are assumed to be contiguous, since
            only the first one is located and the rest are offset from it.
        mask_lm_label_size (:obj:`int`, optional): Number of mask slots.
            Defaults to the number of ``[MASK]`` tokens in ``prompt``.
        **kwargs: Keyword arguments forwarded to the parent dataset.

    Raises:
        ValueError: In test mode when the number of mask slots cannot be
            determined (neither ``prompt`` nor ``mask_lm_label_size`` given).
    """

    def __init__(
        self,
        *args,
        prompt=None,
        mask_lm_label_size=None,
        **kwargs
    ):
        super(PTurningDataset, self).__init__(*args, **kwargs)

        # Always define the attribute so a missing prompt is reported with a
        # clear error at conversion time instead of an AttributeError.
        self.prompt = prompt

        if mask_lm_label_size is None and prompt is not None:
            mask_lm_label_size = prompt.count("[MASK]")

        if self.is_test and mask_lm_label_size is None:
            raise ValueError("The mask_lm_label_size must not be None")

        self.mask_lm_label_size = mask_lm_label_size

    def _convert_to_transfomer_ids(self, bert_tokenizer):
        """Encode every sample into model inputs plus mask positions.

        Args:
            bert_tokenizer: Tokenizer exposing ``sequence_to_ids(text, prompt)``,
                ``tokenize`` and ``vocab.convert_tokens_to_ids``.

        Returns:
            list: One feature dict per sample with ``input_ids``,
            ``attention_mask``, ``token_type_ids``, ``mask_position`` and,
            outside test mode, ``label_ids`` (token ids of the label text).

        Raises:
            ValueError: If no prompt is set, the prompt has no ``[MASK]``, or
                a label does not tokenize into exactly one token per mask slot.
            KeyError: If a label is not a known category.
        """
        if self.prompt is None:
            raise ValueError("A prompt is required to convert the dataset to ids")

        mask_size = self.mask_lm_label_size
        if mask_size is None:
            mask_size = self.prompt.count("[MASK]")

        # `index` raises ValueError when the prompt contains no [MASK].
        start_mask_position = self.prompt.index("[MASK]")
        mask_position = [start_mask_position + offset for offset in range(mask_size)]

        features = []

        for row_ in self.dataset:

            input_ids, input_mask, segment_ids = bert_tokenizer.sequence_to_ids(
                row_['text'], self.prompt)

            feature = {
                'input_ids': input_ids,
                'attention_mask': input_mask,
                'token_type_ids': segment_ids,
                'mask_position': np.array(mask_position)
            }

            if not self.is_test:
                # Fail fast on labels outside the known category set.
                if row_['label'] not in self.cat2id:
                    raise KeyError(row_['label'])

                label_tokens = bert_tokenizer.tokenize(row_['label'])
                # Each mask slot predicts exactly one label token; a mismatch
                # would otherwise surface later as a shape error in the loss.
                if len(label_tokens) != mask_size:
                    raise ValueError(
                        "Label {!r} tokenizes into {} tokens but the prompt has {} mask slots".format(
                            row_['label'], len(label_tokens), mask_size))

                mask_lm_label = bert_tokenizer.vocab.convert_tokens_to_ids(label_tokens)
                feature['label_ids'] = np.array(mask_lm_label)

            features.append(feature)

        return features

In [None]:
# Template layout: five [unusedN] placeholder tokens, then [CLS], then two
# [MASK] slots that the model fills with the label text.
p_tokens = [f"[unused{slot}]" for slot in range(5)]
mask_tokens = ["[MASK]" for _ in range(2)]
prompt = [*p_tokens, '[CLS]', *mask_tokens]

In [None]:
# Wrap the train and dev frames as prompt datasets that share the same template.
p_turning_train_dataset = PTurningDataset(train_data_df, prompt=prompt)
p_turning_dev_dataset = PTurningDataset(dev_data_df, prompt=prompt)

In [None]:
import abc
import torch
import random
import numpy as np

from torch.utils.data import Dataset
from ark_nlp.processor.tokenizer.transfomer import TransfomerTokenizer
from torchvision import datasets, transforms as T

class PromptMLMTransformerTokenizer(TransfomerTokenizer):
    """
    Transformer text encoder for prompt learning: tokenizes text, prepends the
    prompt template, converts tokens to ids and pads to a fixed length.

    Args:
        vocab: transformers vocabulary object, vocabulary path or vocabulary
            name used for tokenization and id conversion
        max_seq_len (:obj:`int`): preset maximum sequence length
    """
    def sequence_to_ids(
        self,
        sequence,
        prompt,
        return_sequence_length=False,
        **kwargs
    ):
        """
        Convert a sequence to padded id arrays of length ``max_seq_len``.

        The encoded layout is ``prompt + sequence + ['[SEP]']``; the text is
        truncated so that the prompt and the trailing ``[SEP]`` always fit.

        Args:
            sequence (:obj:`str` or :obj:`list`): input text, or a list of tokens
            prompt (:obj:`list`): template tokens placed before the text
            return_sequence_length (:obj:`bool`, optional, defaults to False):
                whether to also return the token count of the text before truncation

        Returns:
            tuple: ``(input_ids, attention_mask, token_type_ids)`` as int64
            numpy arrays, followed by the sequence length when requested.

        Raises:
            ValueError: If the prompt plus ``[SEP]`` does not fit in ``max_seq_len``.
        """
        if isinstance(sequence, str):
            sequence = self.tokenize(sequence)

        # Tokens left for the text once the prompt and [SEP] are accounted for.
        max_text_len = self.max_seq_len - 1 - len(prompt)
        if max_text_len < 0:
            # A negative bound would turn the truncation slice into "drop the
            # last k tokens" and yield arrays longer than max_seq_len.
            raise ValueError(
                "The prompt ({} tokens) plus [SEP] does not fit in max_seq_len={}".format(
                    len(prompt), self.max_seq_len))

        if return_sequence_length:
            sequence_length = len(sequence)

        # Truncate over-long text
        if len(sequence) > max_text_len:
            sequence = sequence[:max_text_len]

        # Prompt first, then the text, then the closing separator
        sequence = prompt + sequence + ['[SEP]']

        # Convert tokens to ids
        sequence = self.vocab.convert_tokens_to_ids(sequence)

        real_len = len(sequence)

        # Padding needed to reach max_seq_len
        padding = [0] * (self.max_seq_len - real_len)
        # Attention mask: 1 for real tokens, 0 for padding
        sequence_mask = [1] * real_len + padding
        # Single-segment input: every position gets segment id 0
        segment_ids = [0] * real_len + padding
        # Pad the ids themselves
        sequence += padding

        sequence = np.asarray(sequence, dtype='int64')
        sequence_mask = np.asarray(sequence_mask, dtype='int64')
        segment_ids = np.asarray(segment_ids, dtype='int64')

        if return_sequence_length:
            return (sequence, sequence_mask, segment_ids, sequence_length)

        return (sequence, sequence_mask, segment_ids)

In [None]:
import transformers
# Load the pretrained tokenizer for the 'nghuyong/ernie-1.0' checkpoint; it is
# passed below as the vocabulary object of the prompt tokenizer.
vocab = transformers.AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')

In [None]:
# Register the [unused0]..[unused4] prompt placeholders as additional special tokens.
# NOTE(review): if any of these are not already in the checkpoint's vocabulary, the
# model's embedding matrix would need to be resized to match — confirm.
vocab.add_special_tokens({'additional_special_tokens': ["[unused{}]".format(i) for i in range(5)]})

In [None]:
# Build the prompt-aware tokenizer on top of the vocabulary
# (second argument is presumably max_seq_len=100 — see the class docstring).
tokenizer = PromptMLMTransformerTokenizer(vocab, 100)

In [None]:
# Encode both datasets with the prompt tokenizer
# (presumably dispatches to PTurningDataset._convert_to_transfomer_ids — TODO confirm).
p_turning_train_dataset.convert_to_ids(tokenizer)
p_turning_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
from transformers import BertConfig

# num_labels is set to the vocabulary size because the prediction head below
# emits one score per vocabulary token.
# NOTE(review): vocab.vocab_size may not include tokens added through
# add_special_tokens — confirm it matches the ids the tokenizer can produce.
bert_config = BertConfig.from_pretrained(
    'nghuyong/ernie-1.0',
    num_labels=vocab.vocab_size
)

#### 2. 模型创建

In [None]:
# Release cached GPU memory before the model is created.
torch.cuda.empty_cache()

In [None]:
import torch

from torch import nn
from torch import Tensor
from transformers import BertModel
from transformers import BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertPredictionHeadTransform

from ark_nlp.nn.base.bert import Bert


class BertLMPredictionHead(nn.Module):
    """Prediction head mapping hidden states to per-token scores.

    Applies ``BertPredictionHeadTransform`` followed by a linear decoder with
    ``config.num_labels`` outputs.

    Args:
        config: Model configuration; ``hidden_size`` and ``num_labels`` are read
            from it (not from any module-level config object).
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate output bias
        # parameter is attached below. No weight tying with the input
        # embeddings is performed in this class.
        self.decoder = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.num_labels))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return decoder scores for ``hidden_states`` (last dim: ``hidden_size``)."""
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertForMaskedLM(Bert):

    """
    BERT-based masked-language-model head applied at given mask positions.

    :param config: (object) model configuration object
    :param encoder_trained: (bool) whether the BERT encoder parameters are
        trainable; defaults to True

    :returns:

    Reference:
        [1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

    """
    def __init__(
            self,
            config,
            encoder_trained=True
    ):
        super(BertForMaskedLM, self).__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)

        # Honour `encoder_trained`: the encoder is re-created above, so the
        # flag has to be applied to this instance's parameters.
        for param in self.bert.parameters():
            param.requires_grad = encoder_trained

        self.classifier = BertLMPredictionHead(config)

        self.init_weights()

    @staticmethod
    def _batch_gather(data: torch.Tensor, index: torch.Tensor):
        """
        Per-sample gather along the sequence axis (similar to tf.batch_gather).

        :param data: (bs, max_seq_len, hidden)
        :param index: (bs, n)
        :return: a tensor which shape is (bs, n, hidden)
        """
        index = index.unsqueeze(-1).repeat_interleave(data.size()[-1], dim=-1)  # (bs, n, hidden)
        return torch.gather(data, 1, index)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        mask_position=None,
        **kwargs
    ):
        """
        Score the vocabulary at every requested mask position.

        :param input_ids: token ids passed to the encoder
        :param attention_mask: attention mask passed to the encoder
        :param token_type_ids: segment ids passed to the encoder
        :param mask_position: (bs, n) positions whose hidden states are scored
        :return: scores of shape (bs * n, num_labels)
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        sequence_output = outputs[0]

        # Keep only the hidden states at the mask positions: (bs, n, hidden)
        sequence_output = self._batch_gather(sequence_output, mask_position)

        # Flatten batch and mask dimensions so every row is one mask slot
        hidden_size = sequence_output.shape[-1]
        sequence_output = sequence_output.reshape(-1, hidden_size)

        out = self.classifier(sequence_output)

        return out

In [None]:
# Instantiate the model from the 'nghuyong/ernie-1.0' checkpoint with the config above.
# NOTE(review): the prediction head is stored under `classifier`; verify whether any
# pretrained head weights in the checkpoint are actually loaded into it.
dl_module = BertForMaskedLM.from_pretrained('nghuyong/ernie-1.0', 
                                        config=bert_config)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# 设置运行次数
num_epoches = 10
batch_size = 32

In [None]:
from transformers import AdamW


def get_ptuning_bert_optimizer(
    module,
    lr: float = 3e-5,
    eps: float = 1e-6,
    correct_bias: bool = True,
    weight_decay: float = 1e-3,
):
    """Create an AdamW optimizer with weight decay disabled for biases and LayerNorm weights.

    Args:
        module: Model whose ``named_parameters()`` are optimized.
        lr: Learning rate.
        eps: AdamW epsilon.
        correct_bias: Forwarded to ``AdamW``.
        weight_decay: Decay applied to every parameter whose name contains
            neither ``"bias"`` nor ``"LayerNorm.weight"``.

    Returns:
        The configured ``AdamW`` optimizer with two parameter groups.
    """
    exempt_markers = ("bias", "LayerNorm.weight")

    decayed_params = []
    exempt_params = []
    for param_name, param in module.named_parameters():
        # A parameter is exempt from decay when its name contains any marker.
        if any(marker in param_name for marker in exempt_markers):
            exempt_params.append(param)
        else:
            decayed_params.append(param)

    param_groups = [
        {"params": decayed_params, "weight_decay": weight_decay},
        {"params": exempt_params, "weight_decay": 0.0},
    ]

    return AdamW(
        param_groups,
        lr=lr,
        eps=eps,
        correct_bias=correct_bias,
        weight_decay=weight_decay,
    )

In [None]:
# Build the optimizer for the model with the helper's default hyper-parameters.
optimizer = get_ptuning_bert_optimizer(dl_module)

#### 2. 任务创建

In [None]:
import torch
import numpy as np
import sklearn.metrics as sklearn_metrics

from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask


class PromptMLMTask(SequenceClassificationTask):
    """Classification task whose labels are predicted as tokens at mask positions.

    Args:
        *args: Positional arguments forwarded to ``SequenceClassificationTask``.
        tokenizer: Tokenizer used during evaluation to convert between label
            text and token ids.
        **kwargs: Keyword arguments forwarded to ``SequenceClassificationTask``.
    """

    def __init__(self, *args, tokenizer=None, **kwargs):
        super(PromptMLMTask, self).__init__(*args, **kwargs)
        self.tokenizer = tokenizer

    def _compute_loss(
        self,
        inputs,
        logits,
        verbose=True,
        **kwargs
    ):
        """Compute the loss between per-slot logits and flattened label token ids."""
        # Flatten to 1-D; unlike squeeze(), this stays 1-D even when there is
        # only a single label token in the batch.
        labels = inputs['label_ids'].reshape(-1)

        loss = self.loss_function(logits, labels)

        return loss

    def _on_evaluate_begin_record(self, **kwargs):
        """Reset the accumulators used during one evaluation pass."""
        self.evaluate_logs['eval_loss'] = 0
        self.evaluate_logs['eval_acc'] = 0
        self.evaluate_logs['eval_step'] = 0
        self.evaluate_logs['eval_example'] = 0

        self.evaluate_logs['labels'] = []
        self.evaluate_logs['logits'] = []

    def _on_evaluate_step_end(self, inputs, outputs, **kwargs):
        """Score every category for one batch and record predictions and gold labels.

        A category's score is the sum over mask slots of the log-probability of
        the category's token at that slot. Log-probabilities are used instead of
        raw logits because products of raw logits are not a valid ranking (two
        negative logits multiply into a large positive score).

        NOTE(review): predictions are positions in ``self.cat2id`` iteration
        order and are compared against ``self.cat2id`` values; this assumes the
        mapping enumerates ids 0..n-1 in order — confirm.
        """
        with torch.no_grad():
            # compute loss
            logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)
            self.evaluate_logs['eval_loss'] += loss.item()

            labels = inputs['label_ids'].cpu()
            batch_size, label_length = labels.shape

            # log_probs: [batch_size, label_length, vocab_size]
            log_probs = torch.log_softmax(logits.float().cpu(), dim=-1)
            vocab_size = log_probs.shape[-1]
            log_probs = log_probs.reshape([batch_size, -1, vocab_size]).numpy()

            # [label_num, label_length]
            labels_ids = np.array(
                [self.tokenizer.vocab.convert_tokens_to_ids(
                    self.tokenizer.tokenize(_cat)) for _cat in self.cat2id])

            scores = np.zeros(shape=[batch_size, len(labels_ids)])

            for index in range(label_length):
                scores += log_probs[:, index, labels_ids[:, index]]

            preds = np.argmax(scores, axis=-1)

            # Decode gold token ids back to the label text, then to its category id.
            # Uses the task's own tokenizer rather than a notebook-level global.
            label_indexs = []
            for _label in labels.numpy():
                _label = "".join(
                    self.tokenizer.vocab.convert_ids_to_tokens(_label.tolist()))

                label_indexs.append(self.cat2id[_label])

            label_indexs = np.array(label_indexs)

        self.evaluate_logs['labels'].append(label_indexs)
        self.evaluate_logs['logits'].append(preds)

        self.evaluate_logs['eval_example'] += len(label_indexs)
        self.evaluate_logs['eval_step'] += 1
        self.evaluate_logs['eval_acc'] += (label_indexs == preds).sum()

    def _on_evaluate_epoch_end(
        self,
        validation_data,
        epoch=1,
        is_evaluate_print=True,
        **kwargs
    ):
        """Aggregate the recorded predictions and optionally print the metrics.

        Args:
            validation_data: Dataset whose ``categories`` supply the report's class names.
            epoch: Unused here; kept for interface compatibility.
            is_evaluate_print: Whether to print the report, confusion matrix and summary line.
        """
        _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)
        _preds = np.concatenate(self.evaluate_logs['logits'], axis=0)

        f1_score = sklearn_metrics.f1_score(_labels, _preds, average='macro')

        report_ = sklearn_metrics.classification_report(
            _labels,
            _preds,
            target_names=[str(_category) for _category in validation_data.categories]
        )

        confusion_matrix_ = sklearn_metrics.confusion_matrix(_labels, _preds)

        if is_evaluate_print:
            print('classification_report: \n', report_)
            print('confusion_matrix_: \n', confusion_matrix_)
            print('test loss is:{:.6f},test acc is:{:.6f},f1_score is:{:.6f}'.format(
                self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step'],
                self.evaluate_logs['eval_acc'] / self.evaluate_logs['eval_example'],
                f1_score
                )
            )

In [None]:
# Assemble the task from the model, the optimizer, the 'ce' loss identifier
# (presumably cross-entropy) and GPU index 0; the tokenizer is kept on the task
# so evaluation can convert between label text and token ids.
model = PromptMLMTask(dl_module, optimizer, 'ce', cuda_device=0, tokenizer=tokenizer)

#### 3. 训练

In [None]:
# Train on the prompt-formatted training set, validating on the dev set.
# NOTE(review): `epochs=20` and `lr=2e-5` are hard-coded here, while `num_epoches = 10`
# defined earlier is never used and the optimizer was created with its default
# lr=3e-5 — confirm which values are intended.
model.fit(
    p_turning_train_dataset,
    p_turning_dev_dataset,
    lr=2e-5,
    epochs=20,
    batch_size=batch_size
)


<br>

### 四、模型验证与保存

#### 1. 模型验证

In [None]:
import torch

from torch.utils.data import DataLoader

from ark_nlp.factory.predictor.base._predictor import Predictor


class PromptMLMPredictor(Predictor):
    """Single-sample predictor for the prompt/MLM classifier.

    Args:
        *args: Positional arguments forwarded to ``Predictor``.
        prompt (:obj:`list`): Template tokens used at training time; the
            ``[MASK]`` tokens are assumed to be contiguous.
        **kwargs: Keyword arguments forwarded to ``Predictor``.
    """

    def __init__(self, *args, prompt, **kwargs):
        super(PromptMLMPredictor, self).__init__(*args, **kwargs)
        self.prompt = prompt

    def _convert_to_transfomer_ids(
        self,
        text
    ):
        """Encode ``text`` with the prompt and locate the mask slots.

        Returns:
            dict: ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``mask_position`` for one sample.
        """
        start_mask_position = self.prompt.index("[MASK]")

        # The number of slots comes from the prompt itself (as in the dataset
        # class), not from the character length of a category name.
        mask_position = [
            start_mask_position + index
            for index in range(self.prompt.count("[MASK]"))
        ]

        input_ids = self.tokenizer.sequence_to_ids(text, self.prompt)
        input_ids, input_mask, segment_ids = input_ids

        features = {
            'input_ids': input_ids,
            'attention_mask': input_mask,
            'token_type_ids': segment_ids,
            'mask_position': np.array(mask_position)
        }

        return features

    def predict_one_sample(
        self,
        text='',
        topk=1,
        return_label_name=True,
        return_proba=False
    ):
        """Predict the most likely categories for one text.

        A category's score is the product over mask slots of the softmax
        probability of the category's token at that slot. Probabilities are
        used instead of raw logits because products of raw logits are not a
        valid ranking (two negative logits multiply into a large positive score).

        Args:
            text: Input text.
            topk: Number of categories to return; ``None`` means all categories
                (or 1 when there are at most two). Values larger than the
                number of categories are clamped.
            return_label_name: Return category names instead of category ids.
            return_proba: Also return each category's score.

        Returns:
            list: Categories sorted by descending score, or ``(category, score)``
            pairs when ``return_proba`` is True.

        NOTE(review): a predicted position in ``self.cat2id`` iteration order is
        looked up in ``self.id2cat``; this assumes ids 0..n-1 in order — confirm.
        """
        if topk is None:
            topk = len(self.cat2id) if len(self.cat2id) > 2 else 1
        # torch.topk raises when k exceeds the number of candidates.
        topk = min(topk, len(self.cat2id))

        features = self._get_input_ids(text)
        self.module.eval()

        with torch.no_grad():
            inputs = self._get_module_one_sample_inputs(features)
            logit = self.module(**inputs)
            # [label_length, vocab_size] probabilities, one row per mask slot
            slot_probs = torch.softmax(logit.float(), dim=-1).cpu().numpy()

        # [label_num, label_length]
        labels_ids = np.array(
            [self.tokenizer.vocab.convert_tokens_to_ids(
                self.tokenizer.tokenize(_cat)) for _cat in self.cat2id])

        preds = np.ones(shape=[len(labels_ids)])

        label_length = labels_ids.shape[1]

        for index in range(label_length):
            preds *= slot_probs[index, labels_ids[:, index]]

        preds = torch.Tensor(preds)
        preds = preds.reshape(1, -1)

        probs, indices = preds.topk(topk, dim=1, sorted=True)

        preds = []
        probas = []
        for pred_, proba_ in zip(indices.cpu().numpy()[0], probs.cpu().numpy()[0].tolist()):

            if return_label_name:
                pred_ = self.id2cat[pred_]

            preds.append(pred_)

            if return_proba:
                probas.append(proba_)

        if return_proba:
            return list(zip(preds, probas))

        return preds
    

In [None]:
# Build a predictor from the trained module, reusing the training set's label
# mapping and the same prompt template.
p_turning_instance = PromptMLMPredictor(model.module, tokenizer, p_turning_train_dataset.cat2id, prompt=prompt)

In [None]:
# Ask for the top 15 labels of one sample headline together with their scores.
p_turning_instance.predict_one_sample('小米就要港股上市了，那么为什么选择香港而没有选择上海？', topk=15, return_proba=True)