In [1]:
import json
import sys,os
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

sys.path.extend(['/root/xiaoda/query_topic/'])

In [3]:
import torch
from torch.nn import functional as F
import numpy as np
import random
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

# Adapted from:
# https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py

class ECE(nn.Module):
    """Expected Calibration Error over equal-width confidence bins.

    ECE = sum_b (|B_b| / N) * |conf(B_b) - acc(B_b)|, where a sample falls in
    bin b when its top-class confidence lies in the half-open interval
    (lower_b, upper_b].
    """

    def __init__(self, n_bins=15):
        """
        Args:
            n_bins (int): number of equal-width confidence bins on [0, 1].
        """
        super(ECE, self).__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = edges[:-1]
        self.bin_uppers = edges[1:]

    def forward(self, logits, labels, mode='logits'):
        """Compute ECE for a batch.

        Args:
            logits: scores reduced over dim 1 (presumably (N, C)). Raw logits
                when ``mode == 'logits'``; any other mode treats them as
                already-normalised probabilities.
            labels: integer class ids compared element-wise with the argmax.
            mode: ``'logits'`` to apply softmax first; anything else skips it.

        Returns:
            A 1-element tensor on ``logits.device`` holding the ECE.
        """
        probs = F.softmax(logits, dim=1) if mode == 'logits' else logits
        confidences, predictions = torch.max(probs, 1)
        correct = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for lower, upper in zip(self.bin_lowers, self.bin_uppers):
            mask = confidences.gt(lower.item()) * confidences.le(upper.item())
            weight = mask.float().mean()
            # Empty bins contribute nothing.
            if not weight.item() > 0:
                continue
            mean_conf = confidences[mask].mean()
            mean_acc = correct[mask].float().mean()
            ece += torch.abs(mean_conf - mean_acc) * weight
        return ece

In [4]:
import torch
import json
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast
import transformers
from datetime import timedelta

import os, sys

from nets.them_classifier import MyBaseModel, RobertaClassifier

import configparser
from tqdm import tqdm

# Project root. The config path given to RiskInfer, the label files it lists
# and its output_path are all resolved relative to this directory.
cur_dir_path = '/root/xiaoda/query_topic/'

def load_label(filepath):
    """Read a label vocabulary file with one label per line.

    Line order defines the class id: the i-th line (0-based) gets id i.

    Args:
        filepath: path to a UTF-8 text file.

    Returns:
        (label2id, id2label): dicts mapping label -> id and id -> label.
    """
    # Labels are Chinese; decode explicitly instead of relying on the locale.
    with open(filepath, 'r', encoding='utf-8') as frobj:
        # NOTE(review): blank lines are not skipped, so a stray empty line
        # becomes an '' label and changes num_labels (kept as before).
        label_list = [line.strip() for line in frobj]
    label2id = {label: idx for idx, label in enumerate(label_list)}
    id2label = {idx: label for idx, label in enumerate(label_list)}
    return label2id, id2label

class RiskInfer(object):
    """Multi-task text risk classifier used for offline scoring.

    A shared RoFormer encoder feeds one ``RobertaClassifier`` head per schema
    listed in the config's ``label_path``. ``__init__`` only builds the
    network; call :meth:`reload` with a checkpoint before predicting.
    """

    def __init__(self, config_path, device=None):
        """Build the tokenizer, label schemas and the multi-task network.

        Args:
            config_path: INI config path, joined onto ``cur_dir_path``. Keys
                read from its ``[paths]``/``[para]`` sections: ``model_path``,
                ``label_path`` (comma-separated ``schema_type:label_file``),
                ``output_path``, ``out_dropout_rate`` and ``dropout_type``.
            device: torch device string. ``None`` keeps the original default:
                ``"cuda:3"`` when CUDA is available, otherwise ``"cpu"``.
        """
        con = configparser.ConfigParser()
        con_path = os.path.join(cur_dir_path, config_path)
        con.read(con_path, encoding='utf8')

        args_path = dict(dict(con.items('paths')), **dict(con.items("para")))
        self.tokenizer = BertTokenizerFast.from_pretrained(args_path["model_path"], do_lower_case=True)

        from collections import OrderedDict
        # Insertion order is significant: schema i is paired with classifier
        # head i and with entry i of the network's logits list.
        self.schema_dict = OrderedDict({})

        for label_index, schema_info in enumerate(args_path["label_path"].split(',')):
            schema_type, schema_path = schema_info.split(':')
            schema_path = os.path.join(cur_dir_path, schema_path)
            print(schema_type, schema_path, '===schema-path===')
            label2id, id2label = load_label(schema_path)
            self.schema_dict[schema_type] = {
                'label2id': label2id,
                'id2label': id2label,
                'label_index': label_index
            }
            print(self.schema_dict[schema_type], '==schema_type==', schema_type)

        # Not used by this class itself; exposed for locating checkpoints.
        self.output_path = os.path.join(cur_dir_path, args_path['output_path'])

        from roformer import RoFormerModel, RoFormerConfig

        config = RoFormerConfig.from_pretrained(args_path["model_path"])
        encoder = RoFormerModel(config=config)
        encoder_net = MyBaseModel(encoder, config)

        if device is None:
            device = "cuda:3" if torch.cuda.is_available() else "cpu"
        self.device = device

        classifier_list = nn.ModuleList([
            RobertaClassifier(
                hidden_size=config.hidden_size,
                dropout_prob=con.getfloat('para', 'out_dropout_rate'),
                num_labels=len(self.schema_dict[schema_key]['label2id']),
                dropout_type=con.get('para', 'dropout_type'))
            for schema_key in self.schema_dict
        ])

        class MultitaskClassifier(nn.Module):
            """Shared encoder followed by one classification head per schema."""

            def __init__(self, transformer, classifier_list):
                super().__init__()
                self.transformer = transformer
                self.classifier_list = classifier_list

            def forward(self, input_ids, input_mask,
                        segment_ids=None,
                        transformer_mode='mean_pooling',
                        dt_idx=None):
                """Return ([logits per head], hidden_states); only head ``dt_idx`` if given."""
                hidden_states = self.transformer(input_ids=input_ids,
                                                 input_mask=input_mask,
                                                 segment_ids=segment_ids,
                                                 return_mode=transformer_mode)
                outputs_list = []
                for idx, classifier in enumerate(self.classifier_list):
                    if dt_idx is not None and idx != dt_idx:
                        continue
                    outputs_list.append(classifier(hidden_states))
                return outputs_list, hidden_states

        self.net = MultitaskClassifier(encoder_net, classifier_list).to(self.device)

    def reload(self, model_path):
        """Load a state dict from ``model_path`` into the network and switch to eval mode."""
        ckpt = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(ckpt)
        self.net.eval()

    def _encode_single(self, text, max_length):
        """Tokenize one text (truncated to ``max_length``) into batch-of-one tensors on ``self.device``."""
        encoded = self.tokenizer.encode_plus(text, max_length=max_length, truncation=True)
        input_ids = torch.tensor(encoded["input_ids"]).long().unsqueeze(0).to(self.device)
        token_type_ids = torch.tensor(encoded["token_type_ids"]).unsqueeze(0).to(self.device)
        attention_mask = torch.tensor(encoded["attention_mask"]).unsqueeze(0).to(self.device)
        return input_ids, attention_mask, token_type_ids

    def _forward(self, input_ids, attention_mask, token_type_ids):
        """Run the network without gradients (CLS pooling); return the per-schema logits list."""
        with torch.no_grad():
            logits_list, _ = self.net(input_ids, attention_mask, token_type_ids,
                                      transformer_mode='cls')
        return logits_list

    def _label_scores(self, schema_type, probs):
        """Pair each probability with its label name: [[label, float(prob)], ...]."""
        id2label = self.schema_dict[schema_type]['id2label']
        return [[id2label[index], float(score)] for index, score in enumerate(probs)]

    def predict(self, text, max_length=256):
        """Score one text.

        Returns:
            {schema_type: [[label, probability], ...]} with softmax probabilities.
        """
        logits_list = self._forward(*self._encode_single(text, max_length))
        scores_dict = {}
        for schema_type, logits in zip(self.schema_dict.keys(), logits_list):
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
            scores_dict[schema_type] = self._label_scores(schema_type, probs)
        return scores_dict

    def get_logitnorm(self, text, max_length=256, eps=1e-7):
        """Return the L2-normalised logits of one text per schema (numpy arrays)."""
        logits_list = self._forward(*self._encode_single(text, max_length))
        scores_dict = {}
        for schema_type, logits in zip(self.schema_dict.keys(), logits_list):
            norms = torch.norm(logits, p=2, dim=-1, keepdim=True)
            # eps belongs in the denominator so an all-zero logit vector gives
            # zeros instead of NaN (it used to be added after the division).
            scores_dict[schema_type] = (logits / (norms + eps))[0].cpu().numpy()
        return scores_dict

    def predict_batch(self, text, max_length=256):
        """Score a text or a list of texts in one forward pass.

        Inputs are truncated to ``max_length`` tokens, matching :meth:`predict`.

        Returns:
            One ``{schema_type: [[label, probability], ...]}`` dict per text.
        """
        text_list = text if isinstance(text, list) else [text]
        model_input = self.tokenizer(text_list, return_tensors="pt", padding=True,
                                     truncation=True, max_length=max_length)
        model_input = {key: value.to(self.device) for key, value in model_input.items()}
        logits_list = self._forward(model_input['input_ids'],
                                    model_input['attention_mask'],
                                    model_input['token_type_ids'])
        # Softmax and host transfer once per head, not once per (text, head).
        probs_per_schema = [
            (schema_type, torch.softmax(logits, dim=1).cpu().numpy())
            for schema_type, logits in zip(self.schema_dict.keys(), logits_list)
        ]
        return [
            {schema_type: self._label_scores(schema_type, probs[idx])
             for schema_type, probs in probs_per_schema}
            for idx in range(len(text_list))
        ]

# Previous label/config set: './risk_data/config.ini'
risk_config_path = './risk_data_v5/config.ini'
risk_api = RiskInfer(risk_config_path)




senti_query /root/xiaoda/query_topic/risk_data_v5/senti_query_label.txt ===schema-path===
{'label2id': {'负向': 0, '中性': 1, '正向': 2}, 'id2label': {0: '负向', 1: '中性', 2: '正向'}, 'label_index': 0} ==schema_type== senti_query
senti /root/xiaoda/query_topic/risk_data_v5/senti_label.txt ===schema-path===
{'label2id': {'负向': 0, '正向': 1}, 'id2label': {0: '负向', 1: '正向'}, 'label_index': 1} ==schema_type== senti
bias /root/xiaoda/query_topic/risk_data_v5/bias_label.txt ===schema-path===
{'label2id': {'偏见': 0, '正常': 1}, 'id2label': {0: '偏见', 1: '正常'}, 'label_index': 2} ==schema_type== bias
ciron /root/xiaoda/query_topic/risk_data_v5/ciron_label.txt ===schema-path===
{'label2id': {'讽刺': 0, '正常': 1}, 'id2label': {0: '讽刺', 1: '正常'}, 'label_index': 3} ==schema_type== ciron
intent /root/xiaoda/query_topic/risk_data_v5/intention_label_v0.txt ===schema-path===
{'label2id': {'主观评价/比较/判断': 0, '寻求建议/帮助': 1, '其它': 2}, 'id2label': {0: '主观评价/比较/判断', 1: '寻求建议/帮助', 2: '其它'}, 'label_index': 4} ==schema_type== intent

01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++


In [5]:
import csv

# Count the rows of the tab-separated risk-query dump (last expression displays it).
# quotechar=None: quote characters are read literally, not as field quoting.
risk_csv_path = '/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/Query风险分类_全部数据.csv'
# `with` closes the handle (it used to be leaked); newline='' is what the csv
# module documents for files handed to csv.reader.
with open(risk_csv_path, newline='') as csv_file:
    l = sum(1 for _ in csv.reader(csv_file, delimiter="\t", quotechar=None))
l

9965

In [None]:
import re

# Print efaqa titles whose s3 label is a suicide / self-harm category.
# NOTE(review): `s3_mapping` (s3 code -> category name) is not defined anywhere
# in this notebook; this cell raises NameError until it is loaded.
SELF_HARM_CATEGORIES = {'正在进行的自杀行为', '策划进行的自杀行为', '自残'}
with open("/data/albert.xht/pretrained_model_risk/corpus/efaqa-corpus-zh/efaqa-corpus-zh.utf8", "r", encoding="utf-8") as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        # Drop the first whitespace/comma-separated token of the title, join the rest.
        title = ''.join(re.split(r'[\s,]', content['title'])[1:])
        if len(title) >= 5 and s3_mapping[content['label']['s3']] in SELF_HARM_CATEGORIES:
            tmp = {
                'title': title,
                'label': ['风险']
            }
            print(tmp)

In [43]:
# Checkpoint currently under evaluation.
ACTIVE_RISK_CKPT = '/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v2/multitask_cls.pth.5'
risk_api.reload(ACTIVE_RISK_CKPT)

# Previously evaluated checkpoints (kept for reference):
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_intent_v2_3/multitask_cls.pth.4')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_intent_v2-1_3//multitask_cls.pth.4')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_intent_v2-1_10/multitask_cls.pth.6')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_intent_v2-1_10_no_symbol/multitask_cls.pth.5')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/risk_classification/multitask_raw_filter_senti_query_risk_v7_intent_v2-1_10_no_symbol/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v8_intent_v2-1_10_no_symbol/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v8_intent_v2-1_10_no_symbol_v1/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v9_intent_v2-1_10_no_symbol_v1/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v11_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v1/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v1/multitask_cls.pth.8')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_no_offensive_intent_v2-1_10//multitask_cls.pth.6')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_no_offensive_intent_v2-1_3/multitask_cls.pth.4')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2/multitask_cls.pth.9')


In [72]:
# Spot-check one query with the loaded checkpoint.
sample_query = 'AAB的成语有哪些'
risk_api.predict(sample_query)

{'senti_query': [['负向', 0.00832600798457861],
  ['中性', 0.6789711713790894],
  ['正向', 0.31270286440849304]],
 'senti': [['负向', 0.7019690871238708], ['正向', 0.29803088307380676]],
 'bias': [['偏见', 0.045200686901807785], ['正常', 0.9547992944717407]],
 'ciron': [['讽刺', 0.48146119713783264], ['正常', 0.518538773059845]],
 'intent': [['主观评价/比较/判断', 2.1881326119910227e-06],
  ['寻求建议/帮助', 8.18282387626823e-06],
  ['其它', 0.9999896287918091]],
 'offensive': [['冒犯', 0.023488081991672516], ['正常', 0.9765118360519409]],
 'query_risk': [['风险', 0.010151089169085026],
  ['个人信息', 0.00015539505693595856],
  ['正常', 0.9896935820579529]],
 'teenager': [['不良', 0.02531786449253559], ['正常', 0.9746821522712708]]}

In [45]:
def batch_inference(input_path, output_path, batch_size=128):
    """Score every record of a JSON-lines file with ``risk_api`` and write results.

    Each input line must be a JSON object with ``text`` and ``label`` keys.
    '请问' is removed from ``text``. The model sees that text with punctuation
    additionally stripped, while the output keeps the '请问'-free text.

    Args:
        input_path: JSON-lines file to read (UTF-8).
        output_path: JSON-lines file to (over)write. Each line holds ``text``,
            ``topic`` (the input ``label``) and ``score_list`` (the scores
            returned by ``risk_api.predict_batch``).
        batch_size: number of texts sent to the model per call.
    """
    from tqdm import tqdm
    import json, re

    # Removes runs of CJK and ASCII punctuation from the model input.
    punct_re = re.compile(r"([，\_《。》、？；：‘’＂“”【「】」·！@￥…（）—\,\<\.\>\/\?\;\:\'\"\[\]\{\}\~\`\!\@\#\$\%\^\&\*\(\)\-\=\+])+")

    def flush(fwobj, texts, records):
        """Score one batch of model texts and write the matching records."""
        if not texts:
            return
        probs = risk_api.predict_batch(texts)
        for prob_dict, record in zip(probs, records):
            out = {
                'text': record['text'],
                'topic': record['label'],
                'score_list': prob_dict,
            }
            fwobj.write(json.dumps(out, ensure_ascii=False) + '\n')

    print(input_path, '===input-path===')
    print(output_path, '===output-path===')

    # Open the input first so a bad input path fails before the output is truncated.
    with open(input_path, 'r', encoding='utf-8') as frobj, \
            open(output_path, 'w', encoding='utf-8') as fwobj:
        queue, records = [], []
        for line in tqdm(frobj):
            content = json.loads(line.strip())
            content['text'] = re.sub('请问', '', content['text'])
            queue.append(punct_re.sub("", content['text']))
            records.append(content)
            if len(queue) >= batch_size:
                flush(fwobj, queue, records)
                queue, records = [], []
        flush(fwobj, queue, records)
                    



In [53]:
# Re-score the v7 positive-topic set with the loaded checkpoint.
input_path = '/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic'
output_path = input_path + '.v10'
batch_inference(input_path, output_path)

128it [00:00, 1057.32it/s]

/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic ===input-path===
/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic.v10 ===output-path===


25507it [00:23, 1074.57it/s]


In [54]:
# Re-score the embed_linear small white set with the loaded checkpoint.
input_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic'
output_path = input_path + '.v10'
batch_inference(input_path, output_path)

128it [00:00, 1055.62it/s]

/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic ===input-path===
/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic.v10 ===output-path===


23116it [00:21, 1058.94it/s]


In [None]:
# NOTE(review): identical to the next cell (In [43]); this copy was never executed.
corpus_prefix = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json'
input_path = corpus_prefix + '.topic.knn.final'
output_path = corpus_prefix + '.all_risk.v9.1'
batch_inference(input_path, output_path)

In [43]:
# Score the full KNN-topic-labelled biake corpus (the ~2M-line all_risk.v9.1 file).
corpus_prefix = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json'
input_path = corpus_prefix + '.topic.knn.final'
output_path = corpus_prefix + '.all_risk.v9.1'
batch_inference(input_path, output_path)


0it [00:00, ?it/s]

/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.topic.knn.final ===input-path===
/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1 ===output-path===


2071401it [45:32, 758.09it/s]


In [65]:

# Re-score the offensive subset with the loaded checkpoint.
input_path = '/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all'
output_path = input_path + '.v10.1'
batch_inference(input_path, output_path)



128it [00:00, 959.94it/s]

/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all ===input-path===
/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all.v10.1 ===output-path===


18456it [00:17, 1072.16it/s]


In [45]:
import fast_json
from tqdm import tqdm

# Split the scored corpus into risky ('black') and non-risky ('white') records.
# Import fast_json under its own name: the old `import fast_json as json`
# silently replaced the stdlib json for every later cell.
# The path is explicit instead of reusing a leaked `output_path`. The
# 2,071,401-line count matches the all_risk.v9.1 scoring run (In [43]).
scored_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1'
RISK_THRESHOLD = 0.5

black = []
white = []
with open(scored_path, encoding='utf-8') as frobj:
    for line in tqdm(frobj):
        content = fast_json.loads(line.strip())
        # query_risk[0] is the ['风险', prob] pair.
        if content['score_list']['query_risk'][0][1] > RISK_THRESHOLD:
            black.append(content)
        else:
            white.append(content)

2071401it [04:34, 7542.81it/s] 


177560

In [48]:
# Split the risky bucket again: very confident ('风险' prob > 0.9) vs. borderline.
black_v1 = [d for d in black if d['score_list']['query_risk'][0][1] > 0.9]
black_white = [d for d in black if not d['score_list']['query_risk'][0][1] > 0.9]

In [77]:
# Keep borderline items that score low (< 0.3) on the first class of most heads.
# `p` holds one 0/1 flag per check; an item is kept when at least 5 of 7 checks pass.
# NOTE(review): 'senti_query' appears twice in the key list, so it is counted
# twice; looks like a typo (maybe 'bias'/'ciron' were intended) -- confirm.
LOW_SCORE = 0.3
sssss = []
for d in black_white:
    scores = d['score_list']
    checked = [scores[key][0][1] for key in ['senti', 'offensive', 'teenager', 'senti_query', 'senti_query']]
    checked.append(scores['query_risk'][0][1])
    checked.append(scores['query_risk'][1][1])
    p = [1 if value < LOW_SCORE else 0 for value in checked]
    if sum(p) >= 5:
        sssss.append(d)

In [86]:
from copy import copy

# Export high-confidence risky queries as training data labelled '风险'.
black_out_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1.black'
# Explicit UTF-8: rows contain raw Chinese (ensure_ascii=False).
with open(black_out_path, 'w', encoding='utf-8') as fwobj:
    for d in black_v1:
        content = copy(d)
        # Shallow copy, so popping here leaves the source record's scores intact.
        content.pop('score_list')
        content['label'] = ['风险']
        fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')

In [None]:
import random
from copy import copy

# Sample 15% of the low-risk ('white') queries per topic as '正常' training data.
# NOTE(review): the original cell was mis-indented (IndentationError) and wrote
# to an undefined `fwobj`; its `with open(...)` line appears to be lost. The
# output path below is modelled on the '.black' export -- confirm before running.
white_out_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1.white'
WHITE_SAMPLE_RATIO = 0.15
# Topics never exported as '正常', whatever their scores.
SKIP_TOPICS = {'死亡', '毒品', '恐怖主义', '战争', '灵异事件',
               '色情', '灵异灵修', 'LGBT', 'BDSM', '性侵犯', '性骚扰'}

data_dict = {}
for d in white:
    data_dict.setdefault(d['topic'][0], []).append(d)

rng = random.Random(42)  # fixed seed so the sample is reproducible
train_sample = []
for key in data_dict:
    rng.shuffle(data_dict[key])
    train_sample.extend(data_dict[key][:int(WHITE_SAMPLE_RATIO * len(data_dict[key]))])

with open(white_out_path, 'w', encoding='utf-8') as fwobj:
    for d in tqdm(train_sample):
        topic = d['topic'][0]
        if topic in SKIP_TOPICS:
            continue
        # '法律' items are kept only when both the senti '负向' prob and the
        # offensive '冒犯' prob are below 0.5.
        if topic == '法律' and not (d['score_list']['senti'][0][1] < 0.5
                                   and d['score_list']['offensive'][0][1] < 0.5):
            continue
        content = d.copy()
        content['label'] = ['正常']
        content.pop('score_list')
        content['source'] = 'query_white'
        fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')

In [50]:
from copy import copy

# Export borderline-risk queries from benign topics (study, fashion, pets,
# entertainment, campus life, electronics, anime) as extra positive samples.
POSITIVE_TOPICS = {'学习', '时尚', '宠物', '娱乐', '校园生活', '电子数码', '动漫'}
positive_out_path = '/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.v8'
with open(positive_out_path, 'w', encoding='utf-8') as fwobj:
    for d in tqdm(black_white):
        if d['topic'][0] not in POSITIVE_TOPICS:
            continue
        content = d.copy()
        content.pop('score_list')
        fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')
        
    
            

100%|██████████| 97158/97158 [00:00<00:00, 297035.32it/s]


In [70]:
from collections import Counter

# Group white (low-risk) records by their first topic, counting each topic in `t`.
t = Counter()
white_dict = {}
for d in white:
    topic = d['topic'][0]
    t[topic] += 1
    white_dict.setdefault(topic, []).append(d)

In [58]:
# Spot-check a comment about minorities being oppressed.
sample_comment = '说的好。亚洲人在那里也属于,少数族裔，但是被压迫的不行了'
risk_api.predict(sample_comment)




{'senti_query': [['负向', 0.007037341594696045],
  ['中性', 0.2305670976638794],
  ['正向', 0.7623955607414246]],
 'senti': [['负向', 0.05454050004482269], ['正向', 0.9454594850540161]],
 'bias': [['偏见', 0.026843402534723282], ['正常', 0.9731566309928894]],
 'ciron': [['讽刺', 0.032426558434963226], ['正常', 0.9675734043121338]],
 'intent': [['主观评价/比较/判断', 0.16399823129177094],
  ['寻求建议/帮助', 0.2048383355140686],
  ['其它', 0.631163477897644]],
 'offensive': [['冒犯', 0.06728868931531906], ['正常', 0.9327113628387451]],
 'query_risk': [['风险', 0.02225596271455288],
  ['个人信息', 0.004046468064188957],
  ['正常', 0.9736975431442261]],
 'teenager': [['不良', 0.009918730705976486], ['正常', 0.990081250667572]]}

In [275]:
import re

# Split text on [unusedN] placeholder tokens, keeping them (capturing group).
# Raw string: '\d' in a plain literal is an invalid escape (warns on Python 3.12+).
unused_token_split = re.split(r'(\[unused\d+\])', '[unused1]后的数据规范环境[unused2]gsdhkjsgaf[unused2][SEP]')
unused_token_split

['', '[unused1]', '后的数据规范环境', '[unused2]', 'gsdhkjsgaf', '[unused2]', '[SEP]']

In [331]:
# Tokenizing an empty string without special tokens yields empty id lists.
empty_text = ''
risk_api.tokenizer(empty_text, add_special_tokens=False)

{'input_ids': [], 'token_type_ids': [], 'attention_mask': []}

In [61]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score one text or a list of texts with ``risk_api.predict_batch``."""
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

def _iter_batches(lines, size):
    """Yield (texts, records) batches of at most ``size`` JSON lines, with '请问' removed from each text."""
    texts, records = [], []
    for line in lines:
        content = json.loads(line.strip())
        content['text'] = re.sub('请问', '', content['text'])
        texts.append(content['text'])
        records.append(content)
        if len(texts) >= size:
            yield texts, records
            texts, records = [], []
    if texts:
        yield texts, records

def _score_records(texts, records):
    """Score ``texts`` in one batch; return text/topic/score_list dicts (topic = record['label'])."""
    return [{'text': text, 'topic': record['label'], 'score_list': prob_dict}
            for prob_dict, text, record in zip(risk_predict_batch(texts), texts, records)]

# Score the whole biake QA corpus into a JSON-lines file.
# Earlier output files: ...json.all_risk and ...json.all_risk_v5.
SCORE_BATCH_SIZE = 512
corpus_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json'
scored_path_v10 = corpus_path + '.all_risk_v10_logitclip'

with open(corpus_path, 'r', encoding='utf-8') as frobj, \
        open(scored_path_v10, 'w', encoding='utf-8') as fwobj:
    for texts, records in _iter_batches(tqdm(frobj), SCORE_BATCH_SIZE):
        for row in _score_records(texts, records):
            fwobj.write(json.dumps(row, ensure_ascii=False) + '\n')
                    

50175it [00:49, 1023.64it/s]


KeyboardInterrupt: 

In [76]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score one text or a list of texts with ``risk_api.predict_batch``."""
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

def _iter_batches(lines, size):
    """Yield (texts, records) batches of at most ``size`` JSON lines, with '请问' removed from each text."""
    texts, records = [], []
    for line in lines:
        content = json.loads(line.strip())
        content['text'] = re.sub('请问', '', content['text'])
        texts.append(content['text'])
        records.append(content)
        if len(texts) >= size:
            yield texts, records
            texts, records = [], []
    if texts:
        yield texts, records

def _score_records(texts, records):
    """Score ``texts`` in one batch; return text/topic/score_list dicts (topic = record['label'])."""
    return [{'text': text, 'topic': record['label'], 'score_list': prob_dict}
            for prob_dict, text, record in zip(risk_predict_batch(texts), texts, records)]

# Score the offensive subset into an in-memory list (`offensive`).
OFFENSIVE_BATCH_SIZE = 512
offensive = []
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.offensive.all', 'r', encoding='utf-8') as frobj:
    for texts, records in _iter_batches(tqdm(frobj), OFFENSIVE_BATCH_SIZE):
        offensive.extend(_score_records(texts, records))


15414it [00:14, 1069.51it/s]


In [None]:
# Index the corpus by text.
# NOTE(review): this cell used to open '...json.white' for writing without
# ever writing to it, which silently truncated that file; the writer was removed.
# Values stay empty lists (nothing is appended) -- the cell looks unfinished.
data_dict = {}
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json', 'r', encoding='utf-8') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        data_dict.setdefault(content['text'], [])
    
    

In [7]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score text(s) with ``risk_api.predict_batch``.

    A single text is wrapped in a one-element list; lists pass through as-is.
    Returns one score dict per text.
    """
    batch = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(batch)

In [10]:
# Device chosen for the model in RiskInfer.__init__.
current_device = risk_api.device
current_device

'cuda:3'

In [None]:
from tqdm import tqdm
import numpy as np
import json, re
import random

def risk_predict_batch(text):
    """Score one text or a list of texts with ``risk_api.predict_batch``."""
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

# Mine high-confidence risky queries from a per-topic 20% sample of the corpus.
# A query is written (label '风险') when the first-class prob of bias/ciron/
# teenager, or query_risk's first/second-class prob, exceeds 0.9; score_list
# records which heads fired.
SAMPLE_RATIO = 0.2
HIGH_CONF = 0.9
MINE_BATCH_SIZE = 256
corpus_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json'
mined_path = corpus_path + '.bias_ciron'

# Build the per-topic index from scratch. The original appended into whatever
# `data_dict` earlier cells had left in the kernel (hidden state).
data_dict = {}
with open(corpus_path, 'r', encoding='utf-8') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        data_dict.setdefault(content['label'][0], []).append(content)

rng = random.Random(42)  # fixed seed so the sample is reproducible
train_sample = []
for key in data_dict:
    rng.shuffle(data_dict[key])
    train_sample.extend(data_dict[key][:int(SAMPLE_RATIO * len(data_dict[key]))])

for content in train_sample:
    content['text'] = re.sub('请问', '', content['text'])

with open(mined_path, 'w', encoding='utf-8') as fwobj:
    for start in tqdm(range(0, len(train_sample), MINE_BATCH_SIZE)):
        queue = [c['text'] for c in train_sample[start:start + MINE_BATCH_SIZE]]
        # Loop variables stay module-level on purpose: a later cell inspects prob_dict[key].
        for prob_dict, text in zip(risk_predict_batch(queue), queue):
            score_list = []
            for key in ['bias', 'ciron', 'teenager']:
                if prob_dict[key][0][1] > HIGH_CONF:
                    score_list.append([key, float(prob_dict[key][0][1])])
            key = 'query_risk'
            for rank in (0, 1):
                label, score = prob_dict[key][rank]
                if score > HIGH_CONF:
                    score_list.append([label, float(score)])
            if score_list:
                row = {'text': text, 'label': ['风险'], 'score_list': score_list}
                fwobj.write(json.dumps(row, ensure_ascii=False) + '\n')

In [107]:
# NOTE(review): relies on `prob_dict` / `key` leaked from the last iteration of
# an earlier scoring loop; the displayed output came from a run where key == 'bias'.
first_entry = prob_dict[key][0]
first_entry

['偏见', 0.4076650142669678]

In [41]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_intent_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_mtdnn_ce_256/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_contrast_intent_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256_20/multitask_cls.pth.19')


# risk_api.reload('/data/albert.xht/xiaodao/query_risk_v3/multitask_raw_filter_senti_query_risk_v3/multitask_cls.pth.9')


In [367]:
model_path = [
'/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256/multitask_cls.pth.9',
'/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_intent_v1/multitask_cls.pth.9',
'/data/albert.xht/xiaodao/risk_classification/multitask_mtdnn_ce_256/multitask_cls.pth.9',
'/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256_20/multitask_cls.pth.19',
]

from tqdm import tqdm

# Heads whose first class is the risky one ('负向', '偏见', '讽刺', '冒犯',
# '不良', '风险'); a head votes 1 when that probability exceeds 0.6.
VOTE_KEYS = ['senti', 'bias', 'ciron', 'offensive', 'teenager', 'query_risk']
VOTE_THRESHOLD = 0.6

# Read the query list once instead of once per checkpoint.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/疑似有风险query_from对话预训练数据-20221130.txt') as frobj:
    query_texts = [line.strip() for line in frobj]

total_result = []
for ckpt_path in model_path:
    risk_api.reload(ckpt_path)
    total_result.append([(text, risk_api.predict(text)) for text in tqdm(query_texts)])
# NOTE(review): risk_api is now left holding the LAST checkpoint above.

# One row per query: for each checkpoint (outer) and head (inner), a 0/1 vote.
result_matrix = []
for i in range(len(total_result[0])):
    p = []
    for per_model in total_result:
        scores = per_model[i][1]
        p.extend(1 if scores[key][0][1] > VOTE_THRESHOLD else 0 for key in VOTE_KEYS)
    result_matrix.append(p)



56443it [07:27, 126.13it/s]
56443it [07:29, 125.59it/s]
56443it [07:25, 126.73it/s]
56443it [07:25, 126.60it/s]


array([0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 1])

In [176]:
from tqdm import tqdm

# Score every line of the suspected-risk query dump with the currently loaded
# checkpoint. Each entry is (query text, risk_api.predict output).
result_list = []
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/疑似有风险query_from对话预训练数据-20221130.txt') as frobj:
    for raw_line in tqdm(frobj):
        query = raw_line.strip()
        result_list.append((query, risk_api.predict(query)))
        

56443it [08:00, 117.48it/s]


In [362]:
# Per-query 0/1 flag matrix from the single-checkpoint `result_list`: one column
# per head, set when the probability of that head's first-listed label exceeds
# the threshold. (In[367] builds the wider 4-checkpoint x 6-head variant at 0.6.)
# NOTE(review): with only 5 columns, the `votes >= 6` rule in In[420]/SIMPLE can
# never fire on this matrix.
SINGLE_MODEL_HEADS = ['senti', 'bias', 'ciron', 'offensive', 'teenager']
SINGLE_MODEL_THRESHOLD = 0.59

result_matrix = [
    [int(prediction[head][0][1] > SINGLE_MODEL_THRESHOLD) for head in SINGLE_MODEL_HEADS]
    for _, prediction in result_list
]
    

In [420]:
# Pseudo-labels from the flag matrix: a query is positive when at least 6 of its
# flag columns fire. The threshold is meaningful for the 24-column matrix from
# In[367]; a 5-column matrix (In[362]) can never reach it.
result_matrix = np.array(result_matrix)
votes = result_matrix.sum(axis=-1)
# `np.int` was removed in NumPy 1.24; the builtin `int` is the documented replacement.
labels = (votes >= 6).astype(int)
labels.sum()

36067

In [463]:

from sklearn.ensemble import RandomForestClassifier
result_matrix = np.array(result_matrix)

def SIMPLE(result_matrix, vote_threshold=6, n_rounds=5):
    """Self-train a shallow random forest on vote-derived pseudo-labels.

    The initial labels come from counting fired flags per row. The forest is then
    refit ``n_rounds`` times, each time on its own argmax predictions from the
    previous round.

    Args:
        result_matrix: 2-D array of 0/1 flags, one row per query and one column
            per flag.
        vote_threshold: minimum number of fired flags for a row's initial label
            to be positive (default 6, the previous hard-coded value).
        n_rounds: number of fit / relabel rounds (default 5, the previous
            hard-coded value).

    Returns:
        The ``predict_proba`` output from the last round, with shape
        ``(n_rows, n_classes)``, or ``None`` when ``n_rounds`` is 0.
    """
    votes = np.sum(result_matrix, axis=-1)
    # `np.int` was removed in NumPy 1.24; use the builtin `int` instead.
    labels = (votes >= vote_threshold).astype(int)
    clf = RandomForestClassifier(max_depth=2, random_state=0, ccp_alpha=0.1)

    probs = None
    for _ in range(n_rounds):
        clf.fit(result_matrix, labels)
        probs = clf.predict_proba(result_matrix)
        print(probs.shape)
        # Relabel with the forest's own predictions for the next round.
        labels = np.argmax(probs, axis=-1)
    return probs


probs = SIMPLE(result_matrix)
        

(56443, 2)
(56443, 2)
(56443, 2)
(56443, 2)
(56443, 2)


In [464]:
# Route each query by the forest's class-1 probability: >= 0.8 goes to `ff`
# (flagged) and everything else to `clean`. Each entry is
# (row index, (text, prediction) from result_list, probability row).
ff = []
clean = []
for idx, prob_row in enumerate(probs):
    bucket = ff if prob_row[1] >= 0.8 else clean
    bucket.append((idx, result_list[idx], prob_row))
        


In [23]:
def _write_risk_batch(fwobj, records):
    """Score one batch of corpus records with `risk_predict_batch` and append JSON lines.

    Each output line carries the query ``text``, the record's original
    ``label`` (stored as ``topic``), and the raw per-query output of
    ``risk_predict_batch`` (stored as ``score_list``). ``probs`` stays local, so
    this no longer overwrites the global ``probs`` that other cells read.
    """
    texts = [record['text'] for record in records]
    probs = risk_predict_batch(texts)
    for prob_dict, record in zip(probs, records):
        content = {
            'text':record['text'],
            'topic':record['label'],
            'score_list':prob_dict
        }
        fwobj.write(json.dumps(content, ensure_ascii=False)+'\n')


RISK_BATCH_SIZE = 512

# Explicit utf-8: output is written with ensure_ascii=False, so it contains raw
# Chinese text regardless of the platform's default encoding.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk', 'w', encoding='utf-8') as fwobj:
    with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json', 'r', encoding='utf-8') as frobj:
        batch = []
        for line in tqdm(frobj):
            # Optional preprocessing used previously: re.sub('请问', '', text)
            batch.append(json.loads(line.strip()))
            if len(batch) >= RISK_BATCH_SIZE:
                _write_risk_batch(fwobj, batch)
                batch = []
        # Flush the final partial batch.
        if batch:
            _write_risk_batch(fwobj, batch)

24721it [00:22, 1076.00it/s]


In [480]:
# Persist the flagged queries (`ff`) as a JSON-lines corpus, each labelled
# '风险'. Note that In[23] reads this same file back.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json', 'w') as fwobj:
    for _idx, (text, _prediction), _prob in ff:
        record = {'text': text, 'label': ['风险']}
        fwobj.write(json.dumps(record, ensure_ascii=False) + '\n')


In [221]:
# A query goes to `candidate` as soon as any monitored head's first-listed
# probability reaches its cutoff. Heads are checked in this order and
# short-circuit like the original `or` chain. Otherwise the query goes to
# `positive`, i.e. queries that passed every check.
candidate_thresholds = [('senti', 0.7), ('bias', 0.8), ('ciron', 0.8), ('offensive', 0.8)]
candidate = []
positive = []
for item in result_list:
    flagged = any(item[1][head][0][1] >= cutoff for head, cutoff in candidate_thresholds)
    (candidate if flagged else positive).append(item)

In [73]:

def load_jsonl(path):
    """Read a JSON-lines file into a list with one parsed object per line.

    Blank lines are skipped instead of making ``json.loads`` raise. The file is
    decoded as utf-8 explicitly, because these corpora contain Chinese text.
    """
    with open(path, encoding='utf-8') as frobj:
        return [json.loads(line.strip()) for line in frobj if line.strip()]


# Benchmark splits used by `evaluation` (dev splits unless noted).
offensive = load_jsonl('/data/albert.xht/sentiment/dev/offensive_cold.json')
offensive_test = load_jsonl('/data/albert.xht/sentiment/test/offensive_cold.json')  # test split
cdia_bias = load_jsonl('/data/albert.xht/sentiment/dev/cdial_bias.json')
senti_copr = load_jsonl('/data/albert.xht/sentiment/dev/senti_copr.json')
ciron = load_jsonl('/data/albert.xht/sentiment/dev/chinese_ciron.json')
senti_smp = load_jsonl('/data/albert.xht/sentiment/dev/senti_smp_usual.json')
senti_smpecisa = load_jsonl('/data/albert.xht/sentiment/dev/senti_smpecisa.json')
senti_query = load_jsonl('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_valid.json.filter.0.7')

In [74]:
from sklearn.metrics import classification_report
from tqdm import tqdm
import re

def eval_all(data, model, key):
    """Run ``model`` over ``data`` and print a classification report for one head.

    Args:
        data: iterable of dicts with ``'text'`` (a string, or a list of strings
            joined with newlines) and ``'label'`` (the first element is the gold
            label).
        model: object with ``predict(text)`` returning a dict that maps each head
            name to a list of ``[label, probability]`` pairs.
        key: name of the head to evaluate, e.g. ``'offensive'`` or ``'senti'``.

    Returns:
        ``(pred, gold, pred_score)``: the argmax label per example, the gold
        label per example, and the raw ``[label, probability]`` list for ``key``
        per example (the input expected by ``evaluation_ece``).
    """
    pred = []
    gold = []
    pred_score = []
    for item in tqdm(data):
        gold.append(item['label'][0])
        text = item['text']
        if isinstance(text, list):
            text = "\n".join(text)
        # Remove Chinese and ASCII punctuation (including runs of it) before
        # prediction. Whitespace and newlines are left untouched.
        text = re.sub(r"([，\_《。》、？；：‘’＂“”【「】」·！@￥…（）—\,\<\.\>\/\?\;\:\'\"\[\]\{\}\~\`\!\@\#\$\%\^\&\*\(\)\-\=\+])+", "", text)

        result = model.predict(text)
        # Highest-probability label. On ties, max() keeps the first listed pair,
        # exactly like the stable reverse sort it replaces.
        pred.append(max(result[key], key=lambda u: u[1])[0])
        pred_score.append(result[key])
    print(classification_report(gold, pred, digits=4))
    return pred, gold, pred_score
    


In [75]:

def evaluation_ece(pred_score, gold, n_bins=15):
    """Print and return the expected calibration error (ECE) of one head's predictions.

    Args:
        pred_score: one entry per example, each a list of ``[label, probability]``
            pairs (as produced by ``eval_all``).
        gold: gold label strings aligned with ``pred_score``.
        n_bins: number of confidence bins passed to ``ECE`` (default 15).

    Returns:
        The ECE tensor returned by ``ECE`` (also printed with an ``==ece==`` tag).

    Raises:
        ValueError: if ``pred_score`` is empty, if an example's label set differs
            from the first example's, or if a gold label is not among the
            predicted labels.
    """
    if not pred_score:
        raise ValueError('evaluation_ece: pred_score is empty')

    # The first example fixes the column order. Each row is then filled by label
    # name, so examples that list their labels in another order still land in
    # the correct column. The old positional version misaligned such rows.
    label_order = [pair[0] for pair in pred_score[0]]
    mapping_dict = {label: idx for idx, label in enumerate(label_order)}

    rows = []
    for item in pred_score:
        prob_by_label = {pair[0]: pair[1] for pair in item}
        if set(prob_by_label) != set(mapping_dict):
            raise ValueError(
                f'evaluation_ece: label set {sorted(prob_by_label)} differs from {label_order}')
        rows.append([prob_by_label[label] for label in label_order])

    unknown = sorted(set(gold) - set(mapping_dict))
    if unknown:
        raise ValueError(
            f'evaluation_ece: gold labels {unknown} not among predicted labels {label_order}')

    pred_score_l = torch.tensor(rows)
    gold_l = torch.tensor([mapping_dict[label] for label in gold])

    ece = ECE(n_bins=n_bins)(pred_score_l, gold_l, mode='probs')
    print(ece, '==ece==')
    return ece
# pred, gold, pred_score = eval_all(offensive_test, risk_api, 'offensive')
# evaluation_ece(pred_score, gold)


{'冒犯': 0, '正常': 1}
tensor([0.1119]) ==ece==


In [40]:
# Debug inspection of whatever global `p` is currently bound to. Hidden kernel
# state; safe to delete.
p

['正常', 0.9768216013908386]

In [37]:
# Debug inspection of the *global* `mapping_dict`. evaluation_ece builds its
# own local mapping_dict, so this does not show what evaluation_ece used.
# Hidden kernel state; safe to delete.
mapping_dict

{}

In [28]:
# Debug inspection of whatever global `item` the last top-level loop left bound.
# Hidden kernel state; safe to delete.
item

'正常'

In [81]:
def evaluation(model_path):
    """Reload ``risk_api`` from ``model_path`` and evaluate it on every benchmark split.

    For each benchmark this prints a ``===name===`` banner, the classification
    report from ``eval_all`` and the ECE line from ``evaluation_ece``. The
    dataset lists are module-level globals from the data-loading cell. They are
    looked up when this function is called, so reloading data is picked up.
    """
    risk_api.reload(model_path)
    # (banner name, dataset, head key)
    benchmarks = [
        ('offensive', offensive_test, 'offensive'),
        ('cdia-bias', cdia_bias, 'bias'),
        ('ciron', ciron, 'ciron'),
        ('chsenti', senti_copr, 'senti'),
        ('senti_smpecisa', senti_smpecisa, 'senti'),
        ('senti_smp', senti_smp, 'senti'),
        ('senti_query', senti_query, 'senti'),
    ]
    for name, data, key in benchmarks:
        print(f'==={name}===')
        pred, gold, pred_score = eval_all(data, risk_api, key)
        evaluation_ece(pred_score, gold)
    

In [76]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2/multitask_cls.pth.9')

# # risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_query_risk_20221222/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2_5_aug/multitask_cls.pth.4')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')
# Evaluate the 'intent' head of the balanced_v2 (epoch 5) checkpoint on the
# intention_data_v2 dev split.
risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v2/multitask_cls.pth.5')

with open('/data/albert.xht/raw_chat_corpus/xiaoda/intention_data_v2/dev.txt') as frobj:
    intent_v2 = [json.loads(raw.strip()) for raw in frobj]
pred, gold, pred_score = eval_all(intent_v2, risk_api, 'intent')
evaluation_ece(pred_score, gold)


100%|██████████| 3000/3000 [00:25<00:00, 117.58it/s]

              precision    recall  f1-score   support

  主观评价/比较/判断     0.9233    0.9784    0.9501       603
          其它     0.9913    0.9725    0.9818      2219
     寻求建议/帮助     0.9022    0.9326    0.9171       178

    accuracy                         0.9713      3000
   macro avg     0.9389    0.9612    0.9497      3000
weighted avg     0.9723    0.9713    0.9716      3000

tensor([0.0181]) ==ece==





In [163]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v2/multitask_cls.pth.7')

# Evaluate the 'query_risk' head of the mtdnn_v6 checkpoint on the labelled
# offensive-selection set.
risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_mtdnn_v6/multitask_cls.pth.9')

with open('/data/albert.xht/xiaodao/query_risk_v11/offensive_select_labeled.txt') as frobj:
    risk_query = [json.loads(raw.strip()) for raw in frobj]
pred, gold, pred_score = eval_all(risk_query, risk_api, 'query_risk')
evaluation_ece(pred_score, gold)

100%|██████████| 20641/20641 [02:57<00:00, 115.98it/s]


              precision    recall  f1-score   support

          正常     0.8883    0.5134    0.6507      5514
          风险     0.8463    0.9765    0.9067     15127

    accuracy                         0.8528     20641
   macro avg     0.8673    0.7449    0.7787     20641
weighted avg     0.8575    0.8528    0.8383     20641

tensor([0.0591]) ==ece==


In [122]:
# Same query_risk evaluation as above, for the balanced_v4 checkpoint.
risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v4/multitask_cls.pth.9')

with open('/data/albert.xht/xiaodao/query_risk_v11/offensive_select_labeled.txt') as frobj:
    risk_query = [json.loads(raw.strip()) for raw in frobj]
pred, gold, pred_score = eval_all(risk_query, risk_api, 'query_risk')
evaluation_ece(pred_score, gold)

100%|██████████| 20641/20641 [02:57<00:00, 116.23it/s]


              precision    recall  f1-score   support

          正常     0.7792    0.6578    0.7133      5514
          风险     0.8820    0.9320    0.9063     15127

    accuracy                         0.8588     20641
   macro avg     0.8306    0.7949    0.8098     20641
weighted avg     0.8545    0.8588    0.8548     20641

tensor([0.0275]) ==ece==


In [117]:
# Scratch check of np.random.choice sampling with replacement. It is unseeded,
# so the displayed output is not reproducible.
np.random.choice([1,2,3,4], 20, replace=True)

array([1, 2, 4, 3, 4, 2, 2, 1, 4, 2, 3, 3, 4, 4, 2, 4, 3, 1, 4, 2])

In [149]:
# Evaluate the 'intent' head of whichever checkpoint is currently loaded on the
# intention_data_v2 dev split.
with open('/data/albert.xht/raw_chat_corpus/xiaoda/intention_data_v2/dev.txt') as frobj:
    intent_v2 = [json.loads(raw.strip()) for raw in frobj]
pred, gold, pred_score = eval_all(intent_v2, risk_api, 'intent')
evaluation_ece(pred_score, gold)

100%|██████████| 3000/3000 [00:25<00:00, 117.97it/s]

              precision    recall  f1-score   support

  主观评价/比较/判断     0.9367    0.9818    0.9587       603
          其它     0.9931    0.9775    0.9852      2219
     寻求建议/帮助     0.9239    0.9551    0.9392       178

    accuracy                         0.9770      3000
   macro avg     0.9513    0.9714    0.9611      3000
weighted avg     0.9777    0.9770    0.9772      3000

tensor([0.0156]) ==ece==





In [172]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v2/multitask_cls.pth.9')

# Spot-check a single query against the currently loaded checkpoint and
# inspect every head's [label, probability] output.
risk_api.predict('头疼恶心')


{'senti_query': [['负向', 0.9976022839546204],
  ['中性', 0.0023331306874752045],
  ['正向', 6.459149881266057e-05]],
 'senti': [['负向', 0.9968634843826294], ['正向', 0.0031365305185317993]],
 'bias': [['偏见', 0.6690516471862793], ['正常', 0.3309483826160431]],
 'ciron': [['讽刺', 0.30152297019958496], ['正常', 0.698477029800415]],
 'intent': [['主观评价/比较/判断', 0.0004942239611409605],
  ['寻求建议/帮助', 0.9988555908203125],
  ['其它', 0.0006501346942968667]],
 'offensive': [['冒犯', 0.9131566882133484], ['正常', 0.0868433341383934]],
 'query_risk': [['风险', 0.10928872972726822],
  ['个人信息', 0.0005252966075204313],
  ['正常', 0.8901860117912292]],
 'teenager': [['不良', 0.55097496509552], ['正常', 0.44902506470680237]]}

In [164]:
# Full benchmark sweep (see `evaluation`) for the mtdnn_v6 checkpoint. The
# captured run below was interrupted during the senti_query split.
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_mtdnn_v6/multitask_cls.pth.9')


  0%|          | 12/5304 [00:00<00:46, 114.35it/s]

===offensive===


100%|██████████| 5304/5304 [00:45<00:00, 115.80it/s]
  0%|          | 12/2829 [00:00<00:24, 116.27it/s]

              precision    recall  f1-score   support

          冒犯     0.7233    0.8651    0.7879      2106
          正常     0.8980    0.7821    0.8360      3198

    accuracy                         0.8150      5304
   macro avg     0.8107    0.8236    0.8120      5304
weighted avg     0.8287    0.8150    0.8169      5304

tensor([0.1024]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:24<00:00, 115.61it/s]
  1%|▏         | 12/875 [00:00<00:07, 116.87it/s]

              precision    recall  f1-score   support

          偏见     0.4979    0.6727    0.5723       718
          正常     0.8736    0.7693    0.8181      2111

    accuracy                         0.7448      2829
   macro avg     0.6858    0.7210    0.6952      2829
weighted avg     0.7782    0.7448    0.7557      2829

tensor([0.0150]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 116.38it/s]
  1%|          | 12/1200 [00:00<00:10, 114.53it/s]

              precision    recall  f1-score   support

          正常     0.9301    0.9397    0.9349       779
          讽刺     0.4659    0.4271    0.4457        96

    accuracy                         0.8834       875
   macro avg     0.6980    0.6834    0.6903       875
weighted avg     0.8792    0.8834    0.8812       875

tensor([0.0316]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:10<00:00, 113.46it/s]
  0%|          | 12/2529 [00:00<00:21, 116.41it/s]

              precision    recall  f1-score   support

          正向     0.9121    0.9275    0.9197       593
          负向     0.9280    0.9127    0.9203       607

    accuracy                         0.9200      1200
   macro avg     0.9200    0.9201    0.9200      1200
weighted avg     0.9201    0.9200    0.9200      1200

tensor([0.0268]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:21<00:00, 116.69it/s]
  0%|          | 12/2844 [00:00<00:24, 117.74it/s]

              precision    recall  f1-score   support

          正向     0.8096    0.8177    0.8136      1201
          负向     0.8336    0.8261    0.8298      1328

    accuracy                         0.8221      2529
   macro avg     0.8216    0.8219    0.8217      2529
weighted avg     0.8222    0.8221    0.8221      2529

tensor([0.0986]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:24<00:00, 116.16it/s]
  0%|          | 12/50509 [00:00<07:12, 116.80it/s]

              precision    recall  f1-score   support

          正向     0.8117    0.9112    0.8586      1126
          负向     0.9367    0.8615    0.8975      1718

    accuracy                         0.8812      2844
   macro avg     0.8742    0.8863    0.8780      2844
weighted avg     0.8872    0.8812    0.8821      2844

tensor([0.0661]) ==ece==
===senti_query===


 18%|█▊        | 9193/50509 [01:18<05:52, 117.10it/s]


KeyboardInterrupt: 

In [158]:
# Scratch arithmetic check.
192/6

32.0

In [None]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_intent_v2_3/multitask_cls.pth.4')


In [83]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v2/multitask_cls.pth.5')

In [79]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')


In [20]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip/multitask_cls.pth.9')


In [12]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v9_intent_v2-1_10_no_symbol_v1/multitask_cls.pth.9')


In [13]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_intent_v2-1_10_no_symbol/multitask_cls.pth.9')


In [19]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_no_offensive_intent_v2-1_3/multitask_cls.pth.4')


In [20]:
# evaluation('/data/albert.xht/xiaodao/multitask_raw_filter_senti_query_risk_v5_intent_v2-1_3/multitask_cls.pth.4')


In [1]:
import random
# Scratch check of edge cases: randint's inclusive bounds, random.sample with
# k=0 (returns []), and an empty range object.
random.randint(0, 1), random.sample(list(range(1, 2)), k=0), range(1, 0)

(0, [], range(1, 0))

In [97]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2_5_aug/multitask_cls.pth.4')


In [98]:

# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2/multitask_cls.pth.8')


In [99]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_query_risk_20221222/multitask_cls.pth.9')


In [100]:
# evaluation('/data/albert.xht/xiaodao/query_risk_v3/multitask_raw_filter_senti_query_risk_v3/multitask_cls.pth.9')

In [101]:

# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256_20/multitask_cls.pth.9')

In [102]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_intent_v1/multitask_cls.pth.9')

In [103]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_focal/multitask_cls.pth.9')

In [104]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_balanced_intent_v1//multitask_cls.pth.9')



In [105]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_focal_128/multitask_cls.pth.9')


In [106]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_contrast_balanced_intent_v1/multitask_cls.pth.9')


In [107]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_mtdnn_ce_256/multitask_cls.pth.9')


In [108]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_mtdnn_focal_256/multitask_cls.pth.9')


In [109]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256/multitask_cls.pth.9')




In [110]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_focal_256/multitask_cls.pth.9')


In [111]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_contrast_intent_v1/multitask_cls.pth.9')


In [24]:
# Standalone ECE metric with 15 confidence bins, used by the manual checks below.
ece_fn = ECE(n_bins=15)

In [45]:
# Manual ECE for the offensive head: column j holds the probability of the label
# that mapping_dict maps to j.
mapping_dict = {
    '冒犯':0,
    '正常':1
}

gold_l = torch.tensor([mapping_dict[item] for item in gold])
# Build rows by label name rather than by position, so the columns line up with
# mapping_dict whatever order `pred_score` lists its [label, prob] pairs in.
column_labels = sorted(mapping_dict, key=mapping_dict.get)
pred_score_l = torch.tensor([[dict(item)[label] for label in column_labels] for item in pred_score])

ece_fn(pred_score_l, gold_l, mode='probs')

tensor([0.1104])

In [43]:
# Manual ECE for the senti head: column j holds the probability of the label
# that mapping_dict maps to j.
mapping_dict = {
    '负向':0,
    '正向':1
}

gold_l = torch.tensor([mapping_dict[item] for item in gold])
# Build rows by label name rather than by position, so the columns line up with
# mapping_dict whatever order `pred_score` lists its [label, prob] pairs in.
column_labels = sorted(mapping_dict, key=mapping_dict.get)
pred_score_l = torch.tensor([[dict(item)[label] for label in column_labels] for item in pred_score])

ece_fn(pred_score_l, gold_l, mode='probs')

tensor([0.0198])