In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import pickle

import jieba
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from ark_nlp.model.tm.bert import Bert
from ark_nlp.model.tm.bert import BertConfig
from ark_nlp.dataset import TMDataset
from ark_nlp.model.tm.bert import Task
from ark_nlp.model.tm.bert import get_default_model_optimizer
from ark_nlp.model.tm.bert import Tokenizer

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
class CondTMDataset(TMDataset):
    """TMDataset whose rows also carry a categorical ``condition`` value.

    Every feature produced for a transformer tokenizer gets an extra
    ``condition_ids`` entry: the integer id of the row's ``condition``.

    Args:
        data_path: data source forwarded unchanged to ``TMDataset``; in this
            notebook a DataFrame with ``text_a``, ``text_b``, ``condition`` and
            ``label`` columns.
        categories: optional label list, forwarded to ``TMDataset``.
        is_retain_dataset: forwarded to ``TMDataset``.
        conditions: optional ordered list of condition names defining the
            condition -> id mapping. Pass the training set's ``conditions``
            when building dev/test datasets so the ids agree with the ones the
            model was trained on. When ``None`` (the default, and the original
            behaviour) the mapping is the sorted set of conditions found in
            this dataset's own rows.
    """

    def __init__(
        self,
        data_path,
        categories=None,
        is_retain_dataset=False,
        conditions=None
    ):
        super(CondTMDataset, self).__init__(data_path, categories, is_retain_dataset)

        if conditions is None:
            # Sorting makes the mapping independent of row order.
            conditions = sorted({data['condition'] for data in self.dataset})
        self.conditions = list(conditions)
        self.condition2id = {cond: idx for idx, cond in enumerate(self.conditions)}

    def _convert_to_transfomer_ids(self, bert_tokenizer):
        """Convert every row of ``self.dataset`` into a transformer feature dict.

        Args:
            bert_tokenizer: tokenizer whose ``sequence_to_ids(text_a, text_b)``
                returns ``(input_ids, input_mask, segment_ids)``.

        Returns:
            list of dicts with keys ``input_ids``, ``attention_mask``,
            ``token_type_ids``, ``condition_ids`` and ``label_ids``.

        Raises:
            KeyError: if a row's ``condition`` is missing from
                ``condition2id`` or its ``label`` is missing from ``cat2id``.
        """
        features = []
        for row_ in self.dataset:
            input_ids, input_mask, segment_ids = bert_tokenizer.sequence_to_ids(
                row_['text_a'], row_['text_b']
            )

            features.append({
                'input_ids': input_ids,
                'attention_mask': input_mask,
                'token_type_ids': segment_ids,
                'condition_ids': self.condition2id[row_['condition']],
                'label_ids': self.cat2id[row_['label']]
            })

        return features


In [None]:
CHIP_STS_DIR = '../data/source_datasets/CHIP-STS'


def load_chip_sts(path, drop_unlabeled=False):
    """Read one CHIP-STS JSON split and map it onto the TMDataset column schema.

    Args:
        path: path to a CHIP-STS ``*.json`` file readable by ``pd.read_json``.
        drop_unlabeled: if True, drop rows whose ``label`` is the string "NA"
            (such rows cannot be used for training/evaluation).

    Returns:
        DataFrame with exactly the columns ``text_a``, ``text_b``,
        ``condition`` and ``label`` (renamed from ``text1``, ``text2``,
        ``category``).
    """
    raw_df = pd.read_json(path)
    if drop_unlabeled:
        raw_df = raw_df[raw_df['label'] != "NA"]
    return (raw_df
            .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})
            .loc[:, ['text_a', 'text_b', 'condition', 'label']])


train_data_df = load_chip_sts(f'{CHIP_STS_DIR}/CHIP-STS_train.json')
dev_data_df = load_chip_sts(f'{CHIP_STS_DIR}/CHIP-STS_dev.json', drop_unlabeled=True)

In [None]:
tm_train_dataset = CondTMDataset(train_data_df)
tm_dev_dataset = CondTMDataset(dev_data_df)

# condition2id is built per dataset from its own rows. The model embeds these
# ids, so dev must use the *training* mapping; otherwise a dev set missing one
# condition would silently shift every later id. Fail loudly on conditions the
# model has never seen, then share the training mapping.
assert set(tm_dev_dataset.conditions) <= set(tm_train_dataset.conditions), (
    f"dev conditions not in train: {set(tm_dev_dataset.conditions) - set(tm_train_dataset.conditions)}"
)
tm_dev_dataset.conditions = tm_train_dataset.conditions
tm_dev_dataset.condition2id = tm_train_dataset.condition2id

#### 2. 词典创建和生成分词器

In [None]:
# ERNIE 1.0 vocabulary (same checkpoint as the model below); inputs capped at 50 tokens.
tokenizer = Tokenizer(
    max_seq_len=50,
    vocab='nghuyong/ernie-1.0',
)

#### 3. ID化

In [None]:
# Turn both splits into model-ready id features with the shared tokenizer.
for split_dataset in (tm_train_dataset, tm_dev_dataset):
    split_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Seed torch before the model (including its uniformly initialised condition
# embedding) is created and before fit(), so runs are repeatable. GPU kernels
# may still introduce small nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Pretrained ERNIE 1.0 config, with one output per training label.
config = BertConfig.from_pretrained('nghuyong/ernie-1.0',
                                    num_labels=len(tm_train_dataset.cat2id))

In [None]:
# Size of the condition vocabulary; CondBert sizes its condition embedding table from it.
config.num_conditions = len(tm_train_dataset.conditions)

#### 2. 模型创建

In [None]:
from ark_nlp.nn.layer.layer_norm_block import CondLayerNormLayer

In [None]:
class CondBert(Bert):
    """ark_nlp ``Bert`` text-matching model with a Conditional Layer Norm head.

    Each sample's condition id is embedded and passed, together with the
    encoder's final-layer token states, to ``CondLayerNormLayer``. The result
    is pooled with ``self.mask_pooling`` (using ``attention_mask``), passed
    through ``self.dropout`` and classified by ``self.classifier``.

    Args:
        config: model config; besides ``hidden_size`` it must define
            ``num_conditions`` (number of distinct condition ids).
        encoder_trained: forwarded to ``Bert.__init__``.
        pooling: forwarded to ``Bert.__init__``. NOTE(review): ``forward``
            always uses ``self.mask_pooling``, whatever value is passed here.
    """

    def __init__(
        self,
        config,
        encoder_trained=True,
        pooling='cls_with_pooler'
    ):
        super(CondBert, self).__init__(config, encoder_trained, pooling)

        # One learned vector per condition, same width as the encoder states.
        self.condition_embed = nn.Embedding(config.num_conditions, config.hidden_size)
        # nn.init.uniform_ with default bounds draws from U(0, 1).
        nn.init.uniform_(self.condition_embed.weight.data)

        self.cond_layer_normal = CondLayerNormLayer(config.hidden_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        condition_ids=None,
        **kwargs
    ):
        """Return classification logits for a batch of sentence pairs.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: encoder
                inputs, forwarded to ``self.bert``.
            condition_ids: integer condition ids looked up in
                ``self.condition_embed``.
            **kwargs: accepted and ignored.

        Returns:
            output of ``self.classifier`` on the pooled, dropped-out features.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True
        )

        condition_feature = self.condition_embed(condition_ids)

        # last_hidden_state is the same tensor as hidden_states[-1], but reading
        # it does not require output_hidden_states=True, which would keep every
        # layer's output alive just to use the last one.
        encoder_feature = self.cond_layer_normal(outputs.last_hidden_state, condition_feature)

        encoder_feature = self.mask_pooling(encoder_feature, attention_mask)

        encoder_feature = self.dropout(encoder_feature)
        out = self.classifier(encoder_feature)

        return out

In [None]:
# Load the ERNIE 1.0 checkpoint into CondBert using the config extended above.
dl_module = CondBert.from_pretrained(
    'nghuyong/ernie-1.0',
    config=config,
)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Training hyper-parameters: number of epochs and mini-batch size.
num_epoches, batch_size = 5, 32

In [None]:
# ark_nlp's default optimizer for dl_module (see get_default_model_optimizer for its settings).
# NOTE(review): lr=2e-5 is passed to model.fit below; presumably that overrides any default lr here — confirm in ark_nlp.
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [None]:
# Wrap the model and optimizer in ark_nlp's training Task.
# 'ce' selects the loss (presumably cross-entropy — confirm in ark_nlp);
# cuda_device=0 presumably selects the first GPU.
model = Task(dl_module, optimizer, 'ce', cuda_device=0)

#### 3. 训练

In [None]:
# Fine-tune on the training split, evaluating on the dev split.
model.fit(tm_train_dataset, tm_dev_dataset,
          batch_size=batch_size, epochs=num_epoches, lr=2e-5)

<br>

### 四、生成提交数据

In [None]:
import torch

from torch.utils.data import DataLoader
from ark_nlp.factory.predictor.text_match import TMPredictor

class CondTMPredictor(TMPredictor):
    """Predictor for the conditional text-matching model (``CondBert``).

    Behaves like ``TMPredictor`` except that every sample also carries a
    ``condition`` string, which is mapped to ``condition_ids`` with the same
    ``condition2id`` table the training dataset used.

    Args:
        module: trained model; it is called with the feature dict as kwargs.
        tokernizer: transformer tokenizer, forwarded to ``TMPredictor``
            (argument name kept for backward compatibility).
        cat2id: label -> id mapping from the training dataset.
        condition2id: condition name -> id mapping from the training dataset.
    """

    def __init__(
        self,
        module,
        tokernizer,
        cat2id,
        condition2id
    ):
        super(CondTMPredictor, self).__init__(module, tokernizer, cat2id)
        self.condition2id = condition2id

    def _convert_to_transfomer_ids(
        self,
        text_a,
        text_b,
        condition
    ):
        """Tokenize one sentence pair and attach its condition id.

        Raises:
            KeyError: if ``condition`` is not in ``condition2id``.
        """
        input_ids, input_mask, segment_ids = self.tokenizer.sequence_to_ids(text_a, text_b)

        return {
            'input_ids': input_ids,
            'attention_mask': input_mask,
            'token_type_ids': segment_ids,
            # A 1-element array rather than a bare int, presumably so the base
            # class can tensorise it like the other features. Requires numpy
            # imported as ``np`` at the top of the notebook.
            'condition_ids': np.array([self.condition2id[condition]])
        }

    def _get_input_ids(
        self,
        text_a,
        text_b,
        condition
    ):
        """Dispatch feature building on the tokenizer type (only transformer is supported)."""
        if self.tokenizer.tokenizer_type == 'transfomer':
            return self._convert_to_transfomer_ids(text_a, text_b, condition)
        else:
            raise ValueError("The tokenizer type does not exist")

    def predict_one_sample(
        self,
        text,
        condition,
        topk=None,
        return_label_name=True,
        return_proba=False
    ):
        """Predict the top-k labels for a single sentence pair.

        Args:
            text: ``(text_a, text_b)`` pair.
            condition: condition name, looked up in ``condition2id``.
            topk: number of labels to return; defaults to all labels when there
                are more than two classes, otherwise 1.
            return_label_name: map predicted ids back to label names.
            return_proba: also return the softmax probability of each label.

        Returns:
            list of labels, or of ``(label, probability)`` pairs when
            ``return_proba`` is True, ordered by decreasing probability.
        """
        if topk is None:
            topk = len(self.cat2id) if len(self.cat2id) > 2 else 1
        text_a, text_b = text
        features = self._get_input_ids(text_a, text_b, condition)
        self.module.eval()

        with torch.no_grad():
            inputs = self._get_module_one_sample_inputs(features)
            logit = self.module(**inputs)
            logit = torch.nn.functional.softmax(logit, dim=1)

        probs, indices = logit.topk(topk, dim=1, sorted=True)

        preds = []
        probas = []
        for pred_, proba_ in zip(indices.cpu().numpy()[0], probs.cpu().numpy()[0].tolist()):

            if return_label_name:
                pred_ = self.id2cat[pred_]

            preds.append(pred_)

            if return_proba:
                probas.append(proba_)

        if return_proba:
            return list(zip(preds, probas))

        return preds

    def predict_batch(
        self,
        test_data,
        batch_size=16,
        shuffle=False,
        return_label_name=True,
        return_proba=False
    ):
        """Predict the arg-max label for every sample of an id-converted dataset.

        Args:
            test_data: dataset exposing ``dataset_cols``, iterable by a DataLoader.
            batch_size: DataLoader batch size.
            shuffle: DataLoader shuffle flag (keep False to preserve order).
            return_label_name: map predicted ids back to label names.
            return_proba: also return the softmax probability of the top label.

        Returns:
            list of labels, or of ``(label, probability)`` pairs.
        """
        self.inputs_cols = test_data.dataset_cols

        preds = []
        probas = []

        self.module.eval()
        generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)

        with torch.no_grad():
            for inputs in generator:
                inputs = self._get_module_batch_inputs(inputs)

                logits = self.module(**inputs)

                preds.extend(torch.max(logits, 1)[1].cpu().numpy())
                if return_proba:
                    logits = torch.nn.functional.softmax(logits, dim=1)
                    probas.extend(logits.max(dim=1).values.cpu().detach().numpy())

        if return_label_name:
            preds = [self.id2cat[pred_] for pred_ in preds]

        if return_proba:
            return list(zip(preds, probas))

        return preds


In [None]:
# Reuse the training label and condition vocabularies so ids match what the model learned.
tm_predictor_instance = CondTMPredictor(
    module=model.module,
    tokernizer=tokenizer,
    cat2id=tm_train_dataset.cat2id,
    condition2id=tm_train_dataset.condition2id,
)

In [None]:
import pandas as pd
test_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_test.json')

# Misspelled test categories mapped onto the name used during training.
condition_typos = {'daibetes': 'diabetes'}

test_rows = zip(test_df['id'], test_df['text1'], test_df['text2'], test_df['category'])

submit = []
for sample_id, text_a, text_b, raw_condition in test_rows:
    condition = condition_typos.get(raw_condition, raw_condition)
    # predict_one_sample returns a list of labels; with 2 classes it holds one element.
    label = tm_predictor_instance.predict_one_sample([text_a, text_b], condition)[0]

    submit.append(dict(
        id=str(sample_id),
        text1=text_a,
        text2=text_b,
        label=label,
        category=condition,
    ))

In [None]:
import json

output_path = '../data/output_datasets/CHIP-STS_test.json'

# Create the output directory if needed instead of failing inside open().
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# ensure_ascii=False keeps the Chinese text readable in the UTF-8 file.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(submit, f, ensure_ascii=False)