In [None]:
import os
import pickle

import pandas as pd
import torch

from ark_nlp.model.tm.bert import Bert
from ark_nlp.model.tm.bert import BertConfig
from ark_nlp.model.tm.bert import Dataset
from ark_nlp.model.tm.bert import Task
from ark_nlp.model.tm.bert import get_default_model_optimizer
from ark_nlp.model.tm.bert import Tokenizer

### 一、数据读入与处理

#### 1. 召回模型

In [None]:
import math
import copy
import logging
import numpy as np

from six import iteritems


logger = logging.getLogger(__name__)


class BM25(object):
    """
    Okapi BM25 retrieval model.

    Args:
        corpus (:obj:`list`):
            Documents to search over.
        k1 (:obj:`float`, optional, defaults to 1.5):
            Positive tuning parameter that scales the contribution of in-document
            term frequency.
        b (:obj:`float`, optional, defaults to 0.75):
            Value in [0, 1] controlling document-length normalisation: b=1 fully
            scales term weights by document length, b=0 ignores document length.
        epsilon (:obj:`float`, optional, defaults to 0.25):
            Lower-bound factor for idf: any negative idf is replaced by
            ``epsilon * average_idf``.
        tokenizer (:obj:`object`, optional, defaults to None):
            Object with a ``tokenize(text)`` method used to split documents (and
            string queries) into terms. When None, text is split into single
            characters.
        is_retain_docs (:obj:`bool`, optional, defaults to False):
            Whether to keep a deep copy of the raw documents so that ``recall``
            returns documents instead of corpus indices.

    Raises:
        ValueError: if ``corpus`` is empty.

    Reference:
        [1] https://github.com/RaRe-Technologies/gensim/blob/3.8.3/gensim/summarization/bm25.py
    """

    def __init__(
        self,
        corpus,
        k1=1.5,
        b=0.75,
        epsilon=0.25,
        tokenizer=None,
        is_retain_docs=False
    ):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        # Stored so documents and string queries are tokenized the same way.
        # (Previously ``self.tokenizer`` was read here without ever being assigned,
        # so passing a tokenizer raised AttributeError.)
        self.tokenizer = tokenizer

        # Attribute names are kept stable: pickled instances are loaded elsewhere.
        self.docs = None
        self.corpus_size = 0
        self.avgdl = 0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []

        if is_retain_docs:
            self.docs = copy.deepcopy(corpus)

        self._initialize([self._tokenize(document) for document in corpus])

    def _tokenize(self, text):
        """Split ``text`` with the configured tokenizer, or into characters if none."""
        # getattr: instances unpickled from before ``tokenizer`` was stored lack it.
        tokenizer = getattr(self, 'tokenizer', None)
        if tokenizer:
            return tokenizer.tokenize(text)
        return list(text)

    def _prepare_query(self, query):
        """Tokenize a string query; pass an already-tokenized query through unchanged."""
        if isinstance(query, str):
            return self._tokenize(query)
        return query

    def _initialize(self, corpus):
        """Compute per-document term frequencies, document lengths and idf values.

        Raises:
            ValueError: if ``corpus`` contains no documents.
        """
        nd = {}  # word -> number of documents containing the word
        num_doc = 0  # total number of terms across all documents
        for document in corpus:
            self.corpus_size += 1
            self.doc_len.append(len(document))
            num_doc += len(document)

            frequencies = {}
            for word in document:
                frequencies[word] = frequencies.get(word, 0) + 1
            self.doc_freqs.append(frequencies)

            for word in frequencies:
                nd[word] = nd.get(word, 0) + 1

        if self.corpus_size == 0:
            raise ValueError('BM25 requires a non-empty corpus')

        self.avgdl = float(num_doc) / self.corpus_size

        idf_sum = 0
        negative_idfs = []
        for word, freq in nd.items():
            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        # A corpus made only of empty documents has no terms at all.
        self.average_idf = float(idf_sum) / len(self.idf) if self.idf else 0.0

        if self.average_idf < 0:
            logger.warning(
                'Average inverse document frequency is less than zero. Your corpus of {} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.'.format(self.corpus_size)
            )

        # Terms appearing in more than half the documents get negative idf; floor them.
        eps = self.epsilon * self.average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    def get_score(self, query, index):
        """BM25 score of document ``index`` for ``query``.

        ``query`` may be a string (tokenized like the corpus) or a sequence of terms.
        Repeated query terms contribute once per occurrence.
        """
        query = self._prepare_query(query)
        doc_freqs = self.doc_freqs[index]
        numerator_constant = self.k1 + 1
        if self.avgdl:
            denominator_constant = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
        else:
            # avgdl == 0 only when every document is empty, so no term can match anyway.
            denominator_constant = self.k1 * (1 - self.b)
        score = 0.0
        for word in query:
            if word in doc_freqs:
                tf = doc_freqs[word]  # term frequency of ``word`` in this document
                idf = self.idf[word]
                score += (idf * tf * numerator_constant) / (tf + denominator_constant)
        return score

    def get_scores(self, query):
        """BM25 score of ``query`` against every document, in corpus order."""
        query = self._prepare_query(query)
        return [self.get_score(query, index) for index in range(self.corpus_size)]

    def recall(self, query, topk=5):
        """Return the ``topk`` best-scoring documents for ``query``.

        Returns:
            list of ``[doc_or_index, score]`` pairs ordered by descending score. The
            first element is the raw document when ``is_retain_docs`` was True,
            otherwise the document's index in the corpus.
        """
        scores = self.get_scores(query)
        indexs = np.argsort(scores)[::-1][:topk]

        if self.docs is None:
            return [[i, scores[i]] for i in indexs]
        return [[self.docs[i], scores[i]] for i in indexs]

# Load the pre-built BM25 recall model and its lookup table. `with` closes the file
# handles, which the bare `pickle.load(open(...))` calls left open.
# NOTE(review): pickle.load can execute arbitrary code — only load checkpoints
# produced by this project. Unpickling bm25_model presumably needs the BM25 class
# defined above in this kernel (pickle resolves classes by module and name) — confirm.
with open('../checkpoint/recall/bm25_model.pkl', 'rb') as f:
    bm25_model = pickle.load(f)
# map_dict is indexed with the first element of each recall hit and yields
# candidate standard terms (see the pair-generation cells).
with open('../checkpoint/recall/map_dict.pkl', 'rb') as f:
    map_dict = pickle.load(f)

#### 2. 数据生成

In [None]:
# Read the raw CHIP-CDN train and dev splits (one frame per split).
train_data_df, dev_data_df = (
    pd.read_json(json_path)
    for json_path in (
        '../data/source_datasets/CHIP-CDN/CHIP-CDN_train.json',
        '../data/source_datasets/CHIP-CDN/CHIP-CDN_dev.json',
    )
)

In [None]:
# Build training pairs [text_a = raw mention, text_b = candidate standard term, label].
# For each mention: up to 20 hard negatives (label '0') mined from BM25 recall,
# skipping gold terms and duplicates, then each gold standard term repeated 10 times
# (label '1'; presumably to balance against the negatives).
pair_dataset = []
for _raw_word, _normalized_result in zip(train_data_df['text'], train_data_df['normalized_result']):
    # dict.fromkeys de-duplicates like a set but keeps a deterministic order
    # (set iteration over str depends on PYTHONHASHSEED).
    normalized_words = dict.fromkeys(_normalized_result.split('##'))
    search_result_ = set()
    train_pair_dataset = []
    # Lazily flatten recall hits into candidate terms so we stop once enough are found.
    candidates = (
        _result
        for _results in bm25_model.recall(_raw_word, topk=1000)
        for _result in map_dict[_results[0]]
    )
    for _search_word in candidates:
        if _search_word in normalized_words or _search_word in search_result_:
            continue
        train_pair_dataset.append([_raw_word, _search_word, '0'])
        search_result_.add(_search_word)
        if len(train_pair_dataset) == 20:
            break
    # Keep whatever negatives were found. Previously they were only added when exactly
    # 20 were collected, so mentions with fewer candidates silently got no negatives.
    pair_dataset.extend(train_pair_dataset)

    for _st_word in normalized_words:
        for _ in range(10):
            pair_dataset.append([_raw_word, _st_word, '1'])

In [None]:
# Build validation pairs from the *dev* split. The original code iterated over
# train_data_df here, so the "dev" set duplicated training mentions and validation
# scores were leaked.
# For each mention: 1 hard negative (label '0') mined from BM25 recall plus each gold
# standard term once (label '1').
# NOTE(review): relies on dev_data_df still being the raw CHIP-CDN dev frame; a later
# cell rebinds dev_data_df to the pair frame, so re-run from the data-loading cell.
pair_dev_dataset = []
for _raw_word, _normalized_result in zip(dev_data_df['text'], dev_data_df['normalized_result']):
    # Ordered de-duplication: deterministic, unlike iterating a set of str.
    normalized_words = dict.fromkeys(_normalized_result.split('##'))
    search_result_ = set()
    dev_pair_dataset = []
    candidates = (
        _result
        for _results in bm25_model.recall(_raw_word, topk=1000)
        for _result in map_dict[_results[0]]
    )
    for _search_word in candidates:
        if _search_word in normalized_words or _search_word in search_result_:
            continue
        dev_pair_dataset.append([_raw_word, _search_word, '0'])
        search_result_.add(_search_word)
        if len(dev_pair_dataset) == 1:
            break
    pair_dev_dataset.extend(dev_pair_dataset)

    for _st_word in normalized_words:
        pair_dev_dataset.append([_raw_word, _st_word, '1'])

In [None]:
# Turn the pair lists into frames with the text-matching schema (text_a, text_b, label).
# Note: this rebinds the raw train/dev frame names.
train_data_df, dev_data_df = (
    pd.DataFrame(rows, columns=['text_a', 'text_b', 'label'])
    for rows in (pair_dataset, pair_dev_dataset)
)

In [None]:
# Wrap each pair frame in the ark_nlp text-matching Dataset (train first, then dev).
tm_train_dataset, tm_dev_dataset = (Dataset(frame) for frame in (train_data_df, dev_data_df))

#### 2. 词典创建

In [None]:
# ERNIE 1.0 vocabulary; max_seq_len presumably caps the encoded length of each text pair.
tokenizer = Tokenizer(max_seq_len=50, vocab='nghuyong/ernie-1.0')

#### 3. 生成分词器

#### 4. ID化

In [None]:
# Encode both datasets with the same tokenizer (train, then dev).
for _tm_dataset in (tm_train_dataset, tm_dev_dataset):
    _tm_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Pretrained ERNIE config with one output per label in the training label map.
config = BertConfig.from_pretrained(
    'nghuyong/ernie-1.0',
    num_labels=len(tm_train_dataset.cat2id),
)

#### 2. 模型创建

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before loading the model.
torch.cuda.empty_cache()

In [None]:
# Pretrained ERNIE 1.0 weights with the label-count config built above.
dl_module = Bert.from_pretrained('nghuyong/ernie-1.0', config=config)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Training schedule: number of epochs and mini-batch size.
num_epoches, batch_size = 2, 32

In [None]:
# Drop pooler weights from optimisation, then split the rest into two groups:
# parameters whose name contains a `no_decay` marker get weight_decay 0.0,
# everything else gets 0.01.
# NOTE(review): if the ark_nlp Bert head uses the pooler output, excluding the
# 'pooler' parameters leaves them frozen at pretrained values — confirm intended.
param_optimizer = [
    named_param
    for named_param in dl_module.named_parameters()
    if 'pooler' not in named_param[0]
]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

decayed_params, undecayed_params = [], []
for param_name, param_tensor in param_optimizer:
    if any(marker in param_name for marker in no_decay):
        undecayed_params.append(param_tensor)
    else:
        decayed_params.append(param_tensor)

optimizer_grouped_parameters = [
    {'params': decayed_params, 'weight_decay': 0.01},
    {'params': undecayed_params, 'weight_decay': 0.0},
]

#### 2. 任务创建

In [None]:
# Text-matching task: AdamW optimizer, 'lsce' loss (presumably label-smoothing
# cross-entropy), GPU 0, and an EMA of weights with decay 0.995.
model = Task(
    dl_module,
    'adamw',
    'lsce',
    cuda_device=0,
    ema_decay=0.995,
)

#### 3. 训练

In [None]:
# Train on the pair dataset and evaluate on the dev pairs, using the custom
# parameter groups defined above.
model.fit(
    tm_train_dataset,
    tm_dev_dataset,
    lr=3e-5,
    epochs=num_epoches,
    batch_size=batch_size,
    params=optimizer_grouped_parameters,
)

In [None]:
# Swap the model to its EMA weights. Order matters: first back up the current
# weights (store), then overwrite them with the EMA averages (copy_to). The
# backup is never restored in this notebook, so later cells see the EMA weights.
model.ema.store(model.module.parameters())
model.ema.copy_to(model.module.parameters())  

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [None]:
from ark_nlp.factory.predictor import TMPredictor

In [None]:
# Predictor bound to the trained module, its tokenizer and the training label map.
tm_predictor_instance = TMPredictor(
    model.module,
    tokenizer,
    tm_train_dataset.cat2id,
)

In [None]:
# Smoke test on one pair, in the training order [raw mention, candidate standard term].
tm_predictor_instance.predict_one_sample(
    ['胸部皮肤破裂伤', '胸部开放性损伤'],
    return_proba=True,
)

#### 2. 模型保存

In [None]:
import pickle

In [None]:
# Save the (EMA) model weights; create the checkpoint directory first so the save
# does not fail on a fresh checkout.
os.makedirs('../checkpoint/textsim', exist_ok=True)
torch.save(model.module.state_dict(), '../checkpoint/textsim/module.pth')

In [None]:
# Persist the label map next to the weights so the predictor can be rebuilt later;
# make sure the checkpoint directory exists even if this cell runs on its own.
os.makedirs('../checkpoint/textsim', exist_ok=True)
with open('../checkpoint/textsim/cat2id.pkl', "wb") as f:
    pickle.dump(tm_train_dataset.cat2id, f)