In [1]:
import json
import sys,os
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Make the project's local packages importable. The membership guard keeps this
# cell idempotent: re-running it no longer appends duplicate sys.path entries.
if '/root/xiaoda/query_topic/' not in sys.path:
    sys.path.append('/root/xiaoda/query_topic/')

In [3]:
import torch
from torch.nn import functional as F
import numpy as np
import random
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

"""
https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py
"""

class ECE(nn.Module):
    """Expected Calibration Error (ECE) over equal-width confidence bins.

    Adapted from
    https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py
    """

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of equal-width confidence bins spanning [0, 1].
        """
        super().__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers, self.bin_uppers = edges[:-1], edges[1:]

    def forward(self, logits, labels, mode='logits'):
        """Return the ECE as a 1-element tensor on ``logits.device``.

        logits: scores with classes along dim 1. With ``mode='logits'`` a
            softmax is applied first; any other mode treats them as
            probabilities already.
        labels: class indices compared against the arg-max prediction.

        ECE = sum over non-empty bins (lower, upper] of
              |mean confidence - accuracy| * fraction of samples in the bin.
        """
        probs = F.softmax(logits, dim=1) if mode == 'logits' else logits
        confidences, predictions = probs.max(dim=1)
        correct = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for lower, upper in zip(self.bin_lowers, self.bin_uppers):
            mask = (confidences > lower.item()) & (confidences <= upper.item())
            weight = mask.float().mean()
            # Empty bins contribute nothing (and would produce NaN means).
            if not weight.item() > 0:
                continue
            gap = confidences[mask].mean() - correct[mask].float().mean()
            ece += gap.abs() * weight
        return ece

In [4]:
import torch
import json
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast
import transformers
from datetime import timedelta

import os, sys
cur_dir_path = '/root/xiaoda/query_topic/'

# Project root: TopicInfer resolves its config/label paths against it, and it
# must be on sys.path for the `nets` import that follows. The guard keeps the
# cell idempotent so re-runs don't append duplicate entries.
if cur_dir_path not in sys.path:
    sys.path.append(cur_dir_path)

from nets.them_classifier import MyBaseModel, RobertaClassifier

import configparser
from tqdm import tqdm

class TopicInfer(object):
    """Topic-classifier inference wrapper.

    Builds a RoFormer encoder plus classification head from an INI config,
    loads the label vocabulary, and exposes single-text and batched prediction
    returning ``[label, probability]`` pairs sorted by descending probability.
    Weights are not loaded by the constructor; call :meth:`reload`.
    """

    def __init__(self, config_path):
        """Build the (untrained) network described by ``config_path``.

        Args:
            config_path: INI file path, joined onto ``cur_dir_path``. The
                ``[paths]`` and ``[para]`` sections are read; keys used here
                are ``model_path``, ``label_path``, ``out_dropout_rate`` and
                ``dropout_type``.
        """
        con = configparser.ConfigParser()
        con.read(os.path.join(cur_dir_path, config_path), encoding='utf8')

        # Flatten both sections into one dict; keys in [para] override [paths].
        args_path = dict(dict(con.items('paths')), **dict(con.items("para")))
        self.tokenizer = BertTokenizerFast.from_pretrained(args_path["model_path"], do_lower_case=True)

        # One label per line; the line index is the class id of the classifier head.
        # Read as UTF-8 explicitly: the labels are Chinese and the platform
        # default encoding is not guaranteed to handle them.
        label_path = os.path.join(cur_dir_path, args_path['label_path'])
        with open(label_path, 'r', encoding='utf8') as frobj:
            label_list = [line.strip() for line in frobj]
        n_classes = len(label_list)

        self.label2id = {label: idx for idx, label in enumerate(label_list)}
        self.id2label = {idx: label for idx, label in enumerate(label_list)}
        print(self.label2id, '===', self.id2label)

        from roformer import RoFormerModel, RoFormerConfig

        config = RoFormerConfig.from_pretrained(args_path["model_path"])
        encoder_net = MyBaseModel(RoFormerModel(config=config), config)
        classify_net = RobertaClassifier(
            hidden_size=config.hidden_size,
            dropout_prob=con.getfloat('para', 'out_dropout_rate'),
            num_labels=n_classes,
            dropout_type=con.get('para', 'dropout_type'))

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        class TopicClassifier(nn.Module):
            """Encoder followed by a classifier head; returns (logits, hidden_states)."""

            def __init__(self, transformer, classifier):
                super().__init__()
                self.transformer = transformer
                self.classifier = classifier

            def forward(self, input_ids, input_mask,
                        segment_ids=None, transformer_mode='mean_pooling'):
                hidden_states = self.transformer(input_ids=input_ids,
                                                 input_mask=input_mask,
                                                 segment_ids=segment_ids,
                                                 return_mode=transformer_mode)
                ce_logits = self.classifier(hidden_states)
                return ce_logits, hidden_states

        self.net = TopicClassifier(encoder_net, classify_net).to(self.device)

    def reload(self, model_path):
        """Load a state dict from ``model_path`` onto ``self.device`` and switch to eval mode."""
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        ckpt = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(ckpt)
        self.net.eval()

    def _rank_labels(self, probs, top_n):
        """Pair class probabilities with label names; return the ``top_n`` highest, descending."""
        ranked = sorted(
            ([self.id2label[idx], float(prob)] for idx, prob in enumerate(probs)),
            key=lambda pair: pair[1], reverse=True)
        return ranked[:top_n]

    def predict(self, text, top_n=5, max_length=256):
        """Classify a single text.

        Args:
            text: input string.
            top_n: number of highest-probability labels to return.
            max_length: token limit; longer inputs are truncated.

        Returns:
            list of ``[label, probability]`` pairs, highest probability first.
        """
        encoded = self.tokenizer.encode_plus(text, max_length=max_length, truncation=True)
        input_ids = torch.tensor(encoded["input_ids"]).long().unsqueeze(0).to(self.device)
        token_type_ids = torch.tensor(encoded["token_type_ids"]).unsqueeze(0).to(self.device)
        attention_mask = torch.tensor(encoded["attention_mask"]).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits, _ = self.net(input_ids, attention_mask, token_type_ids,
                                 transformer_mode='cls')
            probs = F.softmax(logits, dim=1)[0].cpu().numpy()
        return self._rank_labels(probs, top_n)

    def predict_batch(self, text, top_n=5, max_length=None):
        """Classify one text or a list of texts in a single padded forward pass.

        Args:
            text: a string or a list of strings.
            top_n: number of highest-probability labels to keep per text.
            max_length: optional token limit; when given, inputs are truncated
                to it. ``None`` (default) keeps the previous no-truncation behavior.

        Returns:
            one list of ``[label, probability]`` pairs per input text,
            highest probability first; ``[]`` for an empty list.
        """
        text_list = text if isinstance(text, list) else [text]
        if not text_list:
            return []

        tokenize_kwargs = {'return_tensors': 'pt', 'padding': True}
        if max_length is not None:
            tokenize_kwargs.update(truncation=True, max_length=max_length)
        encoded = self.tokenizer(text_list, **tokenize_kwargs)
        model_input = {key: value.to(self.device) for key, value in encoded.items()}

        with torch.no_grad():
            logits, _ = self.net(model_input['input_ids'],
                                 model_input['attention_mask'],
                                 model_input['token_type_ids'],
                                 transformer_mode='cls')
            probs = F.softmax(logits, dim=1).cpu().numpy()
        return [self._rank_labels(row, top_n) for row in probs]

# Build the topic classifier from the v4 config; weights are loaded later via reload().
topic_api = TopicInfer(config_path='./topic_data_v4/config.ini')

# Quick smoke test:
# text = '王二今天打车去了哪里，从哪里出发，到哪里了'
# print(topic_api.predict(text), text)


{'基金': 0, '汽车': 1, '网络安全': 2, '情商': 3, '财务税务': 4, '旅游': 5, '建筑': 6, '炫富': 7, '家电': 8, '商业': 9, '性生活': 10, '风水': 11, '人物': 12, 'VPN': 13, '吸烟': 14, '市场营销': 15, '游戏': 16, '算命': 17, '编程': 18, '死亡': 19, '公司': 20, '国家': 21, '美食/烹饪': 22, '明星': 23, '城市': 24, '银行': 25, '期货': 26, '宗教': 27, '学习': 28, '电子数码': 29, '网络暴力': 30, 'LGBT': 31, '其他': 32, '故事': 33, '社会': 34, '二手': 35, '动漫': 36, '歧视': 37, '常识': 38, '星座': 39, '冷知识': 40, '职场职业': 41, '食品': 42, '心理健康': 43, '电子商务': 44, '道德伦理': 45, '商业/理财': 46, '赚钱': 47, '神话': 48, '校园生活': 49, '色情': 50, '婚姻': 51, '家居装修': 52, '生活': 53, '灵异灵修': 54, '股票': 55, '娱乐': 56, '女性': 57, '体验': 58, '广告': 59, '天气': 60, '女权': 61, '潜规则': 62, '人类': 63, '马克思主义': 64, '历史': 65, '音乐': 66, '毒品': 67, '摄影': 68, '金融': 69, '影视': 70, '语言': 71, '环境': 72, '高铁': 73, '人际交往': 74, '夜店': 75, '价值观': 76, '恋爱': 77, '相貌': 78, 'BDSM': 79, '恐怖主义': 80, '中医': 81, '性侵犯': 82, '阅读': 83, '时尚': 84, '体育/运动': 85, '资本主义': 86, '灾害意外': 87, '博彩': 88, '成长': 89, '校园暴力': 90, '移民': 91, '美容/塑身': 92, '经济': 93, '睡眠': 94, 

12/27/2022 08:17:43 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++


In [5]:
# Load trained weights (cls.pth.9 — presumably the epoch-9 checkpoint) and switch to eval mode.
# Other checkpoints tried:
#   /data/albert.xht/xiaodao/topic_classification_v4_label_smoothing/them/cls.pth.9
#   /data/albert.xht/xiaodao/topic_classification_v4/them_3epoch/cls.pth.2
topic_api.reload(model_path='/data/albert.xht/xiaodao/topic_classification_v4/them/cls.pth.9')

In [8]:
# Sanity check on one query; predict_batch returns one ranked list per input text.
topic_api.predict_batch(['路口确定没车，可以闯红灯吗，为什么'])[0]

[['法律', 0.5339043736457825],
 ['交通出行', 0.2583194375038147],
 ['城市', 0.06136595457792282],
 ['社会', 0.04893098771572113],
 ['道德伦理', 0.04017011821269989]]

In [15]:
V4_VALID_PATH = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_valid.json'

# v4 validation split, one JSON object per line (later cells read its 'text'
# and 'label' fields). Read as UTF-8 because the corpus is Chinese, and skip
# blank lines, which would otherwise make json.loads raise.
with open(V4_VALID_PATH, encoding='utf8') as frobj:
    v4_valid = [json.loads(line) for line in frobj if line.strip()]
        

In [53]:
from sklearn.metrics import classification_report
import numpy as np

def _iter_batches(data, batch_size):
    """Yield (texts, items) chunks of at most ``batch_size``; list-valued texts are newline-joined."""
    texts, items = [], []
    for item in tqdm(data):
        text = item['text']
        texts.append("\n".join(text) if isinstance(text, list) else text)
        items.append(item)
        if len(texts) == batch_size:
            yield texts, items
            texts, items = [], []
    if texts:
        yield texts, items


def eval_all(data, model, top_n=5, batch_size=128):
    """Evaluate ``model`` on ``data``; print a classification report and the top-n hit rate.

    Args:
        data: iterable of dicts with a ``'label'`` list (its first entry is the
            gold label used in the report) and a ``'text'`` that is a string or
            a list of strings (joined with newlines).
        model: object exposing ``predict_batch(list_of_texts)`` returning, per
            text, a list of ``[label, score]`` pairs.
        top_n: an item is a hit when any of its labels is among the ``top_n``
            highest-scoring predictions. The model may return only a few
            candidates per text, so a larger ``top_n`` has no extra effect.
        batch_size: number of texts per ``predict_batch`` call.

    Returns:
        (pred_score, gold): the raw per-item model outputs and the gold labels.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')

    gold, pred_score, total_pred = [], [], []
    hits = 0
    for texts, items in _iter_batches(data, batch_size):
        gold.extend(item['label'][0] for item in items)
        for result, item in zip(model.predict_batch(texts), items):
            ranked = sorted(result, key=lambda u: u[1], reverse=True)
            total_pred.append(ranked[0][0])
            if set(item['label']) & {p[0] for p in ranked[:top_n]}:
                hits += 1
            pred_score.append(result)

    # Empty input: classification_report and the hit-rate division would both fail.
    if not pred_score:
        print('eval_all: no predictions to evaluate', '===', top_n)
        return pred_score, gold

    print(classification_report(gold, total_pred, digits=4), '===', top_n)
    print(hits / len(pred_score))
    return pred_score, gold

In [140]:
# Validation items whose primary (first) gold label is 道德伦理 ("morality/ethics").
target = [item for item in v4_valid if item['label'][0] == '道德伦理']

In [157]:
# Items from `target` whose top-1 predicted topic is not among their gold labels,
# kept as (ranked predictions, item) pairs for manual inspection.
target_result = []
for item in target:
    result = topic_api.predict(item['text'])
    top_label = result[0][0]
    if top_label in item['label']:
        continue
    target_result.append((result, item))

In [97]:
# Top-2 hit rate and per-class report on the v4 validation split.
pred_score, gold = eval_all(data=v4_valid, model=topic_api, top_n=2)

100%|██████████| 86668/86668 [00:30<00:00, 2864.74it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        BDSM     0.3714    0.8667    0.5200        15
        LGBT     0.6643    0.8675    0.7524       317
         VPN     1.0000    0.2000    0.3333         5
          两性     0.2766    0.1919    0.2266      1084
          中医     0.3849    0.8014    0.5200       146
          二手     0.2800    0.7368    0.4058        19
          交友     0.0000    0.0000    0.0000        25
        交通出行     0.2712    0.0430    0.0742       372
          人物     0.4595    0.6267    0.5302       217
          人类     0.4146    0.5862    0.4857        29
        人际交往     0.4126    0.3812    0.3963       669
         价值观     0.3568    0.3838    0.3698       357
       体育/运动     0.5243    0.8467    0.6476      2100
          体验     0.4423    0.1885    0.2644       122
          保险     0.5401    0.8707    0.6667       116
          健康     0.7144    0.6894    0.7017      6459
          公司     0.3616    0.6514    0.4650       872
          其他     0.0000    

  _warn_prf(average, modifier, msg_start, len(result))


86668

In [69]:
RISK_CORPUS_PATH = '/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk'
RISK_TOPIC_PATH = RISK_CORPUS_PATH + '_topic'
ANNOTATE_BATCH_SIZE = 32


def _write_topic_batch(records, fwobj):
    """Attach topic_api.predict_batch output to each record (in place, key 'topic') and write it as a JSON line."""
    predictions = topic_api.predict_batch([record['text'] for record in records])
    for record, topics in zip(records, predictions):
        record['topic'] = topics
        fwobj.write(json.dumps(record, ensure_ascii=False) + '\n')


# Annotate every query of the risk corpus with its predicted topics.
# The input is opened first so that a missing input does not truncate the
# output, and both files are UTF-8 because records are written with
# ensure_ascii=False (raw Chinese characters).
with open(RISK_CORPUS_PATH, 'r', encoding='utf8') as frobj, \
        open(RISK_TOPIC_PATH, 'w', encoding='utf8') as fwobj:
    batch = []
    for line in tqdm(frobj):
        if not line.strip():
            continue
        batch.append(json.loads(line))
        if len(batch) == ANNOTATE_BATCH_SIZE:
            _write_topic_batch(batch, fwobj)
            batch = []
    if batch:
        _write_topic_batch(batch, fwobj)

24721it [00:13, 1896.10it/s]


In [114]:

# NOTE(review): the original loop appended to `result_list`, which is never
# initialised, and used `result`, a value left over from an earlier loop. It
# therefore failed on a fresh kernel, or stored the same stale prediction for
# every line. The fix assumes the intent was to pair each suspected-risk query
# with its topic prediction (mirroring `topic_api.predict(...)` usage in an
# earlier cell) — confirm.
result_list = []
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/疑似有风险query_from对话预训练数据-20221130.txt',
          encoding='utf8') as frobj:
    for line in tqdm(frobj):
        text = line.strip()
        result = topic_api.predict(text)
        result_list.append((text, result))

# Top-1 topics that always send a query to the risk ("filter") set.
RISK_TOPICS = frozenset(['经济', '历史', 'LGBT', 'BDSM', '法律', '灵异灵修', '国家', '社会', '军事',
                         '心理健康', '时事政治', '死亡', '毒品', '恐怖主义', '战争', '灵异事件'])
# Substrings that also send a query to the risk set: 艾滋病 (AIDS), 抑郁 (depression).
RISK_KEYWORDS = ('艾滋病', '抑郁')
# Queries need score_list['senti'][1][1] above this to be whitelisted. Presumably
# 'senti' holds [label, prob] pairs with index 1 the "safe" class — confirm.
SENTI_SAFE_THRESHOLD = 0.7

topic_filter = []
topic_white = []
left = []  # NOTE(review): not filled here; only an older, removed heuristic used it.

# A query is whitelisted only if it looks safe, its top-1 topic is not risky
# and it mentions no risk keyword; everything else is filtered.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk_topic',
          encoding='utf8') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        if (content['score_list']['senti'][1][1] > SENTI_SAFE_THRESHOLD
                and content['topic'][0][0] not in RISK_TOPICS
                and not any(keyword in content['text'] for keyword in RISK_KEYWORDS)):
            topic_white.append(content)
        else:
            topic_filter.append(content)
                

In [118]:
RISK_OUTPUT_DIR = '/data/albert.xht/raw_chat_corpus/model_risk_xiaoda'


def _dump_labeled(records, path, label):
    """Set ``record['label'] = [label]`` on every record (in place) and write them to ``path`` as UTF-8 JSON lines."""
    with open(path, 'w', encoding='utf8') as fwobj:
        for record in records:
            record['label'] = [label]
            fwobj.write(json.dumps(record, ensure_ascii=False) + '\n')


# 风险 = "risk", 正常 = "normal".
_dump_labeled(topic_filter, RISK_OUTPUT_DIR + '/query_risk_corpus.json.filter', '风险')
_dump_labeled(topic_white, RISK_OUTPUT_DIR + '/query_risk_corpus.json.white', '正常')

In [115]:
# (number whitelisted, number sent to the risk set)
(len(topic_white), len(topic_filter))

(2274, 22447)