In [1]:
import json
import sys,os
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Make the project root importable (presumably where `nets.them_classifier` lives). Only add
# it once so re-running this cell does not keep appending duplicate entries to sys.path.
if '/root/xiaoda/query_topic/' not in sys.path:
    sys.path.append('/root/xiaoda/query_topic/')

In [3]:
import torch
from torch.nn import functional as F
import numpy as np
import random
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

class ECE(nn.Module):
    """Expected Calibration Error over equal-width confidence bins.

    Adapted from
    https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py

    For each bin (lower, upper], the absolute gap between mean confidence and accuracy
    is weighted by the fraction of samples falling in that bin, and the gaps are summed.
    """

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of equal-width confidence bins on [0, 1].
        """
        super(ECE, self).__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = edges[:-1]
        self.bin_uppers = edges[1:]

    def forward(self, logits, labels, mode='logits'):
        """Return a 1-element tensor with the ECE.

        logits: (N, C) scores. With mode='logits' a softmax over dim 1 is applied;
            any other mode treats the input as probabilities already.
        labels: (N,) integer class targets.
        """
        probs = F.softmax(logits, dim=1) if mode == 'logits' else logits
        confidences, predictions = torch.max(probs, 1)
        correct = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for lo, hi in zip(self.bin_lowers, self.bin_uppers):
            mask = confidences.gt(lo.item()) * confidences.le(hi.item())
            weight = mask.float().mean()
            # Empty bins contribute nothing.
            if weight.item() > 0:
                gap = confidences[mask].mean() - correct[mask].float().mean()
                ece += torch.abs(gap) * weight
        return ece

In [4]:
import torch
import json
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast
import transformers
from datetime import timedelta

import os, sys

from nets.them_classifier import MyBaseModel, RobertaClassifier

import configparser
from tqdm import tqdm

# Project root; config and label paths below are resolved relative to it.
cur_dir_path = '/root/xiaoda/query_topic/'

def load_label(filepath):
    """Read a label file with one label per line and build id lookups.

    Line order defines the class id (0-based). Labels are stripped of surrounding
    whitespace. If a label repeats, label2id keeps its last index (as before).

    Returns:
        (label2id, id2label): dict label -> id and dict id -> label.
    """
    # Label files contain Chinese text; read them as UTF-8 regardless of the locale.
    with open(filepath, 'r', encoding='utf-8') as frobj:
        label_list = [line.strip() for line in frobj]
    label2id = {label: idx for idx, label in enumerate(label_list)}
    id2label = {idx: label for idx, label in enumerate(label_list)}
    return label2id, id2label

class RiskInfer(object):
    """Inference wrapper around the multi-task RoFormer risk classifier.

    The constructor reads an INI config (relative to ``cur_dir_path``), loads one label
    file per schema, and builds a shared encoder with one ``RobertaClassifier`` head per
    schema. It does not load weights: call :meth:`reload` with a checkpoint file first.
    """

    def __init__(self, config_path, device=None):
        """Build the tokenizer, label schemas and the multi-task network.

        Args:
            config_path: INI file path, joined onto ``cur_dir_path``. Must have
                ``[paths]`` and ``[para]`` sections providing ``model_path``,
                ``label_path`` ("schema:file,schema:file,..."), ``output_path``,
                ``out_dropout_rate`` and ``dropout_type``.
            device: optional torch device string. Defaults to ``"cuda:3"`` when CUDA
                is available, else ``"cpu"``.
        """
        from collections import OrderedDict

        con = configparser.ConfigParser()
        con.read(os.path.join(cur_dir_path, config_path), encoding='utf8')

        # Merge [paths] and [para] into one lookup; [para] wins on duplicate keys.
        args_path = dict(dict(con.items('paths')), **dict(con.items("para")))
        self.tokenizer = BertTokenizerFast.from_pretrained(args_path["model_path"], do_lower_case=True)

        # The order of schemas in label_path fixes the order of classifier heads.
        self.schema_dict = OrderedDict({})
        for label_index, schema_info in enumerate(args_path["label_path"].split(',')):
            schema_type, schema_path = schema_info.split(':')
            schema_path = os.path.join(cur_dir_path, schema_path)
            print(schema_type, schema_path, '===schema-path===')
            label2id, id2label = load_label(schema_path)
            self.schema_dict[schema_type] = {
                'label2id': label2id,
                'id2label': id2label,
                'label_index': label_index,
            }
            print(self.schema_dict[schema_type], '==schema_type==', schema_type)

        # Checkpoint directory from the config (kept for locating files to pass to reload()).
        self.output_path = os.path.join(cur_dir_path, args_path['output_path'])

        from roformer import RoFormerModel, RoFormerConfig

        config = RoFormerConfig.from_pretrained(args_path["model_path"])
        encoder = RoFormerModel(config=config)
        encoder_net = MyBaseModel(encoder, config)

        if device is None:
            device = "cuda:3" if torch.cuda.is_available() else "cpu"
        self.device = device

        # One classification head per schema, sized by that schema's label count.
        classifier_list = nn.ModuleList([
            RobertaClassifier(
                hidden_size=config.hidden_size,
                dropout_prob=con.getfloat('para', 'out_dropout_rate'),
                num_labels=len(schema['label2id']),
                dropout_type=con.get('para', 'dropout_type'))
            for schema in self.schema_dict.values()
        ])

        class MultitaskClassifier(nn.Module):
            """Shared encoder followed by one classification head per schema."""

            def __init__(self, transformer, classifier_list):
                super().__init__()
                self.transformer = transformer
                self.classifier_list = classifier_list

            def forward(self, input_ids, input_mask,
                        segment_ids=None,
                        transformer_mode='mean_pooling',
                        dt_idx=None):
                """Return ``(logits_per_head, hidden_states)``.

                When ``dt_idx`` is given only that head runs, so the list has one entry.
                """
                hidden_states = self.transformer(input_ids=input_ids,
                                                 input_mask=input_mask,
                                                 segment_ids=segment_ids,
                                                 return_mode=transformer_mode)
                outputs_list = [
                    classifier(hidden_states)
                    for idx, classifier in enumerate(self.classifier_list)
                    if dt_idx is None or idx == dt_idx
                ]
                return outputs_list, hidden_states

        self.net = MultitaskClassifier(encoder_net, classifier_list).to(self.device)

    def reload(self, model_path):
        """Load a state dict from the checkpoint *file* ``model_path`` and switch to eval mode.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        ckpt = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(ckpt)
        self.net.eval()

    def _encode_single(self, text):
        """Tokenize one text (max_length=256) into batch-of-one tensors on ``self.device``.

        Returns:
            (input_ids, attention_mask, token_type_ids), each shaped (1, seq_len).
        """
        encoded = self.tokenizer.encode_plus(text, max_length=256, truncation=True)
        input_ids = torch.tensor(encoded["input_ids"]).long().unsqueeze(0).to(self.device)
        token_type_ids = torch.tensor(encoded["token_type_ids"]).unsqueeze(0).to(self.device)
        attention_mask = torch.tensor(encoded["attention_mask"]).unsqueeze(0).to(self.device)
        return input_ids, attention_mask, token_type_ids

    def _label_scores(self, schema_type, scores):
        """Pair each probability in ``scores`` (label-id order) with its label name."""
        id2label = self.schema_dict[schema_type]['id2label']
        return [[id2label[index], float(score)] for index, score in enumerate(scores)]

    def predict(self, text):
        """Return per-schema softmax scores for a single text.

        Returns:
            dict: schema name -> list of ``[label, probability]`` pairs in label-id order.
        """
        input_ids, attention_mask, token_type_ids = self._encode_single(text)
        with torch.no_grad():
            logits_list, _ = self.net(input_ids, attention_mask, token_type_ids,
                                      transformer_mode='cls')
        return {
            schema_type: self._label_scores(
                schema_type, F.softmax(logits, dim=1)[0].detach().cpu().numpy())
            for schema_type, logits in zip(list(self.schema_dict.keys()), logits_list)
        }

    def get_logitnorm(self, text):
        """Return the L2-normalised logits of every head for a single text.

        Returns:
            dict: schema name -> 1-D numpy array of ``logits / (||logits||_2 + 1e-7)``.
        """
        input_ids, attention_mask, token_type_ids = self._encode_single(text)
        with torch.no_grad():
            logits_list, _ = self.net(input_ids, attention_mask, token_type_ids,
                                      transformer_mode='cls')
            # The epsilon belongs in the denominator to guard against zero-norm logits;
            # adding it to the quotient (as before) still produced NaN for all-zero logits.
            logits_norm_list = [
                logits / (torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7)
                for logits in logits_list
            ]
        return {
            schema_type: logit_norm[0].detach().cpu().numpy()
            for schema_type, logit_norm in zip(list(self.schema_dict.keys()), logits_norm_list)
        }

    def predict_batch(self, text, max_length=None):
        """Score a batch of texts; a single string is treated as a batch of one.

        Args:
            text: a string or a list of strings.
            max_length: optional token limit. When given, inputs are truncated to it;
                the default None pads only (no truncation).

        Returns:
            list with one dict per input text, shaped like :meth:`predict`'s result.
        """
        text_list = text if isinstance(text, list) else [text]
        tokenizer_kwargs = {'return_tensors': "pt", 'padding': True}
        if max_length is not None:
            tokenizer_kwargs.update(truncation=True, max_length=max_length)
        model_input = self.tokenizer(text_list, **tokenizer_kwargs)
        model_input = {key: value.to(self.device) for key, value in model_input.items()}
        with torch.no_grad():
            logits_list, _ = self.net(model_input['input_ids'],
                                      model_input['attention_mask'],
                                      model_input['token_type_ids'],
                                      transformer_mode='cls')
        # Softmax and host copy once per head, not once per (text, head) pair.
        probs_list = [F.softmax(logits, dim=1).detach().cpu().numpy() for logits in logits_list]
        schema_types = list(self.schema_dict.keys())
        return [
            {schema_type: self._label_scores(schema_type, probs[row])
             for schema_type, probs in zip(schema_types, probs_list)}
            for row in range(len(text_list))
        ]

# risk_api = RiskInfer('./risk_data/config.ini')
# Build the inference wrapper from the v5 risk-data config (weights are loaded later via reload()).
risk_api = RiskInfer(config_path='./risk_data_v5/config.ini')




senti_query /root/xiaoda/query_topic/risk_data_v5/senti_query_label.txt ===schema-path===
{'label2id': {'负向': 0, '中性': 1, '正向': 2}, 'id2label': {0: '负向', 1: '中性', 2: '正向'}, 'label_index': 0} ==schema_type== senti_query
senti /root/xiaoda/query_topic/risk_data_v5/senti_label.txt ===schema-path===
{'label2id': {'负向': 0, '正向': 1}, 'id2label': {0: '负向', 1: '正向'}, 'label_index': 1} ==schema_type== senti
bias /root/xiaoda/query_topic/risk_data_v5/bias_label.txt ===schema-path===
{'label2id': {'偏见': 0, '正常': 1}, 'id2label': {0: '偏见', 1: '正常'}, 'label_index': 2} ==schema_type== bias
ciron /root/xiaoda/query_topic/risk_data_v5/ciron_label.txt ===schema-path===
{'label2id': {'讽刺': 0, '正常': 1}, 'id2label': {0: '讽刺', 1: '正常'}, 'label_index': 3} ==schema_type== ciron
intent /root/xiaoda/query_topic/risk_data_v5/intention_label_v0.txt ===schema-path===
{'label2id': {'主观评价/比较/判断': 0, '寻求建议/帮助': 1, '其它': 2}, 'id2label': {0: '主观评价/比较/判断', 1: '寻求建议/帮助', 2: '其它'}, 'label_index': 4} ==schema_type== intent

01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++
01/04/2023 09:17:22 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++


In [5]:
import csv
# Count the rows of the full query-risk CSV (tab-separated, quoting disabled).
# The file is closed after counting; newline='' is what the csv module expects.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/Query风险分类_全部数据.csv',
          newline='', encoding='utf-8') as csv_file:
    l = sum(1 for _ in csv.reader(csv_file, delimiter="\t", quotechar=None))
l

9965

In [None]:
import re

# Print efaqa questions whose s3 category is ongoing/planned suicide or self-harm as risk samples.
# NOTE(review): `s3_mapping` is not defined anywhere in this notebook; it presumably maps the
# corpus' `label.s3` codes to category names — define or load it before running this cell.
with open("/data/albert.xht/pretrained_model_risk/corpus/efaqa-corpus-zh/efaqa-corpus-zh.utf8", "r", encoding="utf-8") as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        # Drop the first whitespace/comma-separated piece of the title and join the rest.
        title = ''.join(re.split(r'[\s,]', content['title'])[1:])
        if len(title) >= 5 and s3_mapping[content['label']['s3']] in ['正在进行的自杀行为', '策划进行的自杀行为', '自残']:
            tmp = {
                'title': title,
                'label': ['风险'],
            }
            print(tmp)

In [37]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_intent_v2_3/multitask_cls.pth.4')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_intent_v2-1_3//multitask_cls.pth.4')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_intent_v2-1_10/multitask_cls.pth.6')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_intent_v2-1_10_no_symbol/multitask_cls.pth.5')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/risk_classification/multitask_raw_filter_senti_query_risk_v7_intent_v2-1_10_no_symbol/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v8_intent_v2-1_10_no_symbol/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v8_intent_v2-1_10_no_symbol_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v9_intent_v2-1_10_no_symbol_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v11_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v1/multitask_cls.pth.8')

# NOTE(review): this path ends with '/' and looks like a checkpoint directory, but
# RiskInfer.reload passes it straight to torch.load, which needs a file (the other runs
# point at `multitask_cls.pth.N`). Confirm the intended checkpoint file.
risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v12_intent_v2-1_10_no_symbol_senti_query_senta_balanced_v2/')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_no_offensive_intent_v2-1_10//multitask_cls.pth.6')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_no_offensive_intent_v2-1_3/multitask_cls.pth.4')


# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2/multitask_cls.pth.9')


In [41]:
# Spot-check all heads on one negative, mocking news-style query.
risk_api.predict(text='唉成都又出猛女了成都一名26岁女司机无证驾驶连闯红灯还')

{'senti_query': [['负向', 0.756915271282196],
  ['中性', 0.20392245054244995],
  ['正向', 0.039162278175354004]],
 'senti': [['负向', 0.9995278120040894], ['正向', 0.00047222699504345655]],
 'bias': [['偏见', 0.9707301259040833], ['正常', 0.029269874095916748]],
 'ciron': [['讽刺', 0.8651314973831177], ['正常', 0.13486848771572113]],
 'intent': [['主观评价/比较/判断', 0.06287194788455963],
  ['寻求建议/帮助', 0.13903510570526123],
  ['其它', 0.7980930209159851]],
 'offensive': [['冒犯', 0.12341342121362686], ['正常', 0.8765866160392761]],
 'query_risk': [['风险', 0.5953889489173889],
  ['个人信息', 5.400699228630401e-05],
  ['正常', 0.4045570194721222]],
 'teenager': [['不良', 0.8138567805290222], ['正常', 0.18614326417446136]]}

In [45]:
def batch_inference(input_path, output_path, batch_size=128, predict_fn=None):
    """Score a JSON-lines corpus with the risk model and write one JSON line per record.

    Every non-empty input line must be a JSON object with 'text' and 'label' keys.
    '请问' is removed from the text. The copy sent to the model additionally has
    punctuation stripped. Each output line is
    {'text': <text without 请问>, 'topic': <input label>, 'score_list': <model scores>}.

    Args:
        input_path: JSON-lines file to read (UTF-8).
        output_path: file to (over)write with the scored records (UTF-8).
        batch_size: number of texts per model call; must be >= 1 (default 128).
        predict_fn: callable taking a list of texts and returning one score dict per
            text. Defaults to the notebook-global ``risk_api.predict_batch``.

    Raises:
        ValueError: if batch_size < 1.
    """
    import json
    import re

    try:
        from tqdm import tqdm
    except ImportError:  # progress bar is cosmetic; run without it if tqdm is missing
        def tqdm(iterable, **kwargs):
            return iterable

    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    if predict_fn is None:
        predict_fn = risk_api.predict_batch

    # Strip full-width and ASCII punctuation from the model input only.
    punct_re = re.compile(r"([，\_《。》、？；：‘’＂“”【「】」·！@￥…（）—\,\<\.\>\/\?\;\:\'\"\[\]\{\}\~\`\!\@\#\$\%\^\&\*\(\)\-\=\+])+")

    def flush(fwobj, texts, records):
        """Score `texts` and write one line per record (keeping the unstripped text)."""
        for prob_dict, record in zip(predict_fn(texts), records):
            content = {
                'text': record['text'],
                'topic': record['label'],
                'score_list': prob_dict,
            }
            fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')

    print(input_path, '===input-path===')
    print(output_path, '===output-path===')

    # ensure_ascii=False writes raw Chinese, so both files are explicitly UTF-8.
    with open(output_path, 'w', encoding='utf-8') as fwobj:
        with open(input_path, 'r', encoding='utf-8') as frobj:
            queue, records = [], []
            for line in tqdm(frobj):
                line = line.strip()
                if not line:
                    continue
                content = json.loads(line)
                content['text'] = re.sub('请问', '', content['text'])
                queue.append(punct_re.sub('', content['text']))
                records.append(content)
                if len(queue) >= batch_size:
                    flush(fwobj, queue, records)
                    queue, records = [], []
            if queue:
                flush(fwobj, queue, records)
                    



In [53]:
# Re-score the v7 positive-topic corpus; output sits next to the input with a '.v10' suffix.
input_path = '/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic'
output_path = input_path + '.v10'

batch_inference(input_path=input_path, output_path=output_path)

128it [00:00, 1057.32it/s]

/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic ===input-path===
/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic.v10 ===output-path===


25507it [00:23, 1074.57it/s]


In [54]:
# Re-score the small white-list topic corpus; output gets a '.v10' suffix.
input_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic'
output_path = input_path + '.v10'

batch_inference(input_path=input_path, output_path=output_path)

128it [00:00, 1055.62it/s]

/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic ===input-path===
/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic.v10 ===output-path===


23116it [00:21, 1058.94it/s]


In [None]:
# NOTE: same input/output as the following cell; running both scores the ~2M-line corpus twice.
input_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.topic.knn.final'
output_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1'

batch_inference(input_path=input_path, output_path=output_path)

In [43]:
# Score the full knn-filtered topic corpus; the black/white split reads this output.
input_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.topic.knn.final'
output_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1'

batch_inference(input_path=input_path, output_path=output_path)


0it [00:00, ?it/s]

/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.topic.knn.final ===input-path===
/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1 ===output-path===


2071401it [45:32, 758.09it/s]


In [65]:

# Score the offensive subset for query_risk_v10; output gets a '.v10.1' suffix.
input_path = '/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all'
output_path = input_path + '.v10.1'

batch_inference(input_path=input_path, output_path=output_path)



128it [00:00, 959.94it/s]

/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all ===input-path===
/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all.v10.1 ===output-path===


18456it [00:17, 1072.16it/s]


In [45]:
import fast_json as json
from tqdm import tqdm

# Scored corpus written by the all_risk.v9.1 batch_inference run (2,071,401 lines). The path is
# spelled out because the global `output_path` is reassigned by later cells (e.g. the
# query_risk_v10 run); re-running this cell would otherwise silently read the wrong file.
risk_scored_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1'
risk_threshold = 0.5

black = []
white = []
with open(risk_scored_path, encoding='utf-8') as frobj:
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        # query_risk[0] is the ['风险', prob] pair.
        if content['score_list']['query_risk'][0][1] > risk_threshold:
            black.append(content)
        else:
            white.append(content)

2071401it [04:34, 7542.81it/s] 


177560

In [48]:
black_v1 = []
black_white = []
for d in black:
    if d['score_list']['query_risk'][0][1] > 0.9:
        black_v1.append(d)
    else:
        black_white.append(d)

In [77]:
# From the borderline pool, keep records where at least 5 of 7 checks have a first-label
# probability below 0.3.
# NOTE(review): 'senti_query' is listed twice, so it is counted twice. This looks like a typo
# for another head (e.g. 'bias' or 'ciron'); confirm before changing.
sssss = []
for d in black_white:
    scores = d['score_list']
    head_votes = [int(scores[k][0][1] < 0.3)
                  for k in ['senti', 'offensive', 'teenager', 'senti_query', 'senti_query']]
    risk_votes = [int(scores['query_risk'][j][1] < 0.3) for j in (0, 1)]
    if sum(head_votes) + sum(risk_votes) >= 5:
        sssss.append(d)

In [86]:
from copy import copy
# Export the high-confidence risky queries (black_v1) labelled '风险', without their scores.
# ensure_ascii=False writes raw Chinese, so the file is opened as UTF-8.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.all_risk.v9.1.black', 'w', encoding='utf-8') as fwobj:
    for d in black_v1:
        content = copy(d)  # shallow copy so the pop below leaves black_v1's records intact
        content.pop('score_list')
        content['label'] = ['风险']
        fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')

In [None]:
# NOTE(review): `fwobj` is not opened in this cell. The `with open(...) as fwobj:` header this
# body was originally indented under appears to be missing, so `fwobj` would refer to a
# handle left over (and already closed) from an earlier cell. Open the intended output file
# before running.
import random
from copy import copy

# Group the low-risk ("white") records by their first topic label.
data_dict = {}
for d in white:
    data_dict.setdefault(d['topic'][0], []).append(d)

# Keep a random 15% of each topic.
train_sample = []
for key in data_dict:
    random.shuffle(data_dict[key])
    train_sample.extend(data_dict[key][:int(0.15 * len(data_dict[key]))])

excluded_topics = {'死亡', '毒品', '恐怖主义', '战争', '灵异事件',
                   '色情', '灵异灵修', 'LGBT', 'BDSM', '性侵犯', '性骚扰'}
for d in tqdm(train_sample):
    content = d.copy()
    topic = content['topic'][0]
    if topic in excluded_topics:
        continue
    # Legal-topic records are kept only if both senti and offensive first-label probs are < 0.5.
    if topic in ['法律'] and not (d['score_list']['senti'][0][1] < 0.5
                                  and d['score_list']['offensive'][0][1] < 0.5):
        continue
    content['label'] = ['正常']
    content.pop('score_list')
    content['source'] = 'query_white'
    fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')

In [50]:
from copy import copy
# Export borderline records from `black_white` whose topic is on a benign allow-list, without
# their scores. Earlier variants additionally filtered on senti_query/senti scores or excluded
# other topics; those experiments were dropped.
keep_topics = {'学习', '时尚', '宠物', '娱乐', '校园生活', '电子数码', '动漫'}
with open('/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.v8', 'w', encoding='utf-8') as fwobj:
    for d in tqdm(black_white):
        if d['topic'][0] not in keep_topics:
            continue
        content = d.copy()
        content.pop('score_list')
        fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')
        
    
            

100%|██████████| 97158/97158 [00:00<00:00, 297035.32it/s]


In [70]:
from collections import Counter
# Topic histogram (t) and topic -> records index (white_dict) for the low-risk pool.
t = Counter(d['topic'][0] for d in white)
white_dict = {}
for d in white:
    white_dict.setdefault(d['topic'][0], []).append(d)

In [58]:
# Spot-check a sarcastic-looking remark about minority groups.
risk_api.predict(text='说的好。亚洲人在那里也属于,少数族裔，但是被压迫的不行了')




{'senti_query': [['负向', 0.007037341594696045],
  ['中性', 0.2305670976638794],
  ['正向', 0.7623955607414246]],
 'senti': [['负向', 0.05454050004482269], ['正向', 0.9454594850540161]],
 'bias': [['偏见', 0.026843402534723282], ['正常', 0.9731566309928894]],
 'ciron': [['讽刺', 0.032426558434963226], ['正常', 0.9675734043121338]],
 'intent': [['主观评价/比较/判断', 0.16399823129177094],
  ['寻求建议/帮助', 0.2048383355140686],
  ['其它', 0.631163477897644]],
 'offensive': [['冒犯', 0.06728868931531906], ['正常', 0.9327113628387451]],
 'query_risk': [['风险', 0.02225596271455288],
  ['个人信息', 0.004046468064188957],
  ['正常', 0.9736975431442261]],
 'teenager': [['不良', 0.009918730705976486], ['正常', 0.990081250667572]]}

In [275]:
import re
# Split on [unusedN] markers, keeping the markers (capturing group). A raw string avoids
# the invalid '\d' escape warning of the previous non-raw pattern; the regex is unchanged.
re.split(r'(\[unused\d+\])', '[unused1]后的数据规范环境[unused2]gsdhkjsgaf[unused2][SEP]')

['', '[unused1]', '后的数据规范环境', '[unused2]', 'gsdhkjsgaf', '[unused2]', '[SEP]']

In [331]:
# Check what the tokenizer returns for an empty string without special tokens.
risk_api.tokenizer(text='', add_special_tokens=False)

{'input_ids': [], 'token_type_ids': [], 'attention_mask': []}

In [61]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score one text or a list of texts with the global `risk_api`.

    A bare string is wrapped in a one-element list, so the result is always a list with
    one score dict per input text.
    """
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

# Earlier targets of this run, in the same directory: '...json.all_risk' and '...json.all_risk_v5'.
qa_corpus_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json'
risk_scored_output_path = qa_corpus_path + '.all_risk_v10_logitclip'


def _write_scored_batch(fwobj, texts, records):
    """Score `texts` and write one {'text', 'topic', 'score_list'} JSON line per text.

    Unlike batch_inference, 'text' is the exact model input (only '请问' removed).
    """
    for prob_dict, text, record in zip(risk_predict_batch(texts), texts, records):
        content = {
            'text': text,
            'topic': record['label'],
            'score_list': prob_dict,
        }
        fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')


with open(risk_scored_output_path, 'w', encoding='utf-8') as fwobj:
    with open(qa_corpus_path, 'r', encoding='utf-8') as frobj:
        queue, t = [], []
        for line in tqdm(frobj):
            content = json.loads(line.strip())
            content['text'] = re.sub('请问', '', content['text'])
            queue.append(content['text'])
            t.append(content)
            if len(queue) >= 512:
                _write_scored_batch(fwobj, queue, t)
                queue, t = [], []
        if queue:
            _write_scored_batch(fwobj, queue, t)
                    

50175it [00:49, 1023.64it/s]


KeyboardInterrupt: 

In [76]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score one text or a list of texts with the global `risk_api`.

    Redefined here so this cell runs on its own. A bare string is wrapped in a
    one-element list, so the result is always a list of per-text score dicts.
    """
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

offensive = []
offensive_input_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.offensive.all'


def _collect_scored_batch(texts, records, sink):
    """Score `texts` and append {'text', 'topic', 'score_list'} dicts to `sink`."""
    for prob_dict, text, record in zip(risk_predict_batch(texts), texts, records):
        sink.append({
            'text': text,
            'topic': record['label'],
            'score_list': prob_dict,
        })


# Score the offensive subset in memory (512 texts per model call).
with open(offensive_input_path, 'r', encoding='utf-8') as frobj:
    queue, t = [], []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        content['text'] = re.sub('请问', '', content['text'])
        queue.append(content['text'])
        t.append(content)
        if len(queue) >= 512:
            _collect_scored_batch(queue, t, offensive)
            queue, t = [], []
    if queue:
        _collect_scored_batch(queue, t, offensive)


15414it [00:14, 1069.51it/s]


In [None]:
# NOTE(review): unfinished cell. Opening the '.white' file in 'w' mode truncates it, but
# nothing is ever written to `fwobj`, and the per-text lists in `data_dict` are never filled
# (only empty lists keyed by text are created). Running it as-is just empties the output
# file. Finish or remove it before running.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json.white', 'w') as fwobj:
    with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json', 'r') as frobj:
        data_dict = {}
        for line in frobj:
            content = json.loads(line.strip())
            # Presumably intended to group records by exact text (de-duplication) — TODO confirm.
            if content['text'] not in data_dict:
                data_dict[content['text']] = []
    
    

In [7]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score one text or a list of texts with the global `risk_api`.

    Redefined here so this cell runs on its own; always returns a list with one score
    dict per input text.
    """
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

In [10]:
# Show which device the model was placed on.
risk_api.device

'cuda:3'

In [None]:
from tqdm import tqdm
import numpy as np
import json, re

def risk_predict_batch(text):
    """Score one text or a list of texts with the global `risk_api`.

    Redefined here so this cell runs on its own; always returns a list with one score
    dict per input text.
    """
    text_list = text if isinstance(text, list) else [text]
    return risk_api.predict_batch(text_list)

import random

# Reset: this cell previously appended into whatever `data_dict` an earlier cell left behind
# (text-keyed or topic-keyed white records), contaminating the sample.
data_dict = {}


def _high_risk_scores(prob_dict, threshold=0.9):
    """Return [[name, prob], ...] for every check whose probability exceeds `threshold`.

    bias/ciron/teenager are reported under the head name, using their first label's prob.
    For query_risk, the first two labels are checked and reported under their label text.
    """
    score_list = []
    for key in ['bias', 'ciron', 'teenager']:
        if prob_dict[key][0][1] > threshold:
            score_list.append([key, float(prob_dict[key][0][1])])
    for label, prob in prob_dict['query_risk'][:2]:
        if prob > threshold:
            score_list.append([label, float(prob)])
    return score_list


def _write_high_risk_batch(fwobj, texts):
    """Score `texts` and write those with any high-risk check as '风险' JSON lines."""
    for prob_dict, text in zip(risk_predict_batch(texts), texts):
        score_list = _high_risk_scores(prob_dict)
        if score_list:
            content = {
                'text': text,
                'label': ['风险'],
                'score_list': score_list,
            }
            fwobj.write(json.dumps(content, ensure_ascii=False) + '\n')


qa_corpus_path = '/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json'
with open(qa_corpus_path + '.bias_ciron', 'w', encoding='utf-8') as fwobj:
    # Group the corpus by its first topic label.
    with open(qa_corpus_path, 'r', encoding='utf-8') as frobj:
        for line in frobj:
            content = json.loads(line.strip())
            data_dict.setdefault(content['label'][0], []).append(content)

    # Score a random 20% of each topic.
    train_sample = []
    for key in data_dict:
        random.shuffle(data_dict[key])
        train_sample.extend(data_dict[key][:int(0.2 * len(data_dict[key]))])

    queue = []
    for content in tqdm(train_sample):
        content['text'] = re.sub('请问', '', content['text'])
        queue.append(content['text'])
        if len(queue) >= 256:
            _write_high_risk_batch(fwobj, queue)
            queue = []
    if queue:
        _write_high_risk_batch(fwobj, queue)

In [107]:
# Inspect a (label, prob) pair using `prob_dict` and `key` left over from a previous cell's
# loop; this only works right after such a loop ran (the value depends on kernel state).
prob_dict[key][0]

['偏见', 0.4076650142669678]

In [41]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_intent_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_mtdnn_ce_256/multitask_cls.pth.9')
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_contrast_intent_v1/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256_20/multitask_cls.pth.19')


# risk_api.reload('/data/albert.xht/xiaodao/query_risk_v3/multitask_raw_filter_senti_query_risk_v3/multitask_cls.pth.9')


In [367]:
model_path = [
'/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256/multitask_cls.pth.9',
'/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_intent_v1/multitask_cls.pth.9',
'/data/albert.xht/xiaodao/risk_classification/multitask_mtdnn_ce_256/multitask_cls.pth.9',
'/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256_20/multitask_cls.pth.19',
]

from tqdm import tqdm

suspect_query_path = '/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/疑似有风险query_from对话预训练数据-20221130.txt'
# Read the queries once instead of re-reading the file for every checkpoint.
with open(suspect_query_path, encoding='utf-8') as frobj:
    suspect_queries = [line.strip() for line in frobj]

total_result = []
for model in model_path:
    risk_api.reload(model)
    result_list = [(text, risk_api.predict(text)) for text in tqdm(suspect_queries)]
    total_result.append(result_list)

# One row per query: for each checkpoint in order, a 0/1 flag per head (first-label prob > 0.6).
result_matrix = []
for i in range(len(total_result[0])):
    p = []
    for tmp in total_result:
        item = tmp[i]
        for key in ['senti', 'bias', 'ciron', 'offensive', 'teenager', 'query_risk']:
            p.append(1 if item[1][key][0][1] > 0.6 else 0)
    result_matrix.append(p)



56443it [07:27, 126.13it/s]
56443it [07:29, 125.59it/s]
56443it [07:25, 126.73it/s]
56443it [07:25, 126.60it/s]


array([0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 1])

In [176]:
# Score every line of the suspected-risk query file with the currently loaded
# risk_api checkpoint; result_list gets one (text, prediction dict) per line.
from tqdm import tqdm
result_list = []
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/疑似有风险query_from对话预训练数据-20221130.txt') as frobj:
    for line in tqdm(frobj):
        text = line.strip()
        result = risk_api.predict(text)
        result_list.append((text, result))
        

56443it [08:00, 117.48it/s]


In [362]:
# Binary flag matrix from the single-model result_list: per query, one flag per
# head (5 heads, query_risk excluded) set when the head's first-label
# probability exceeds 0.59.
result_matrix = []
for item in result_list:
    p = []
    for key in ['senti', 'bias', 'ciron', 'offensive', 'teenager']:
        if item[1][key][0][1] > 0.59:
            p.append(1)
        else:
            p.append(0)
    result_matrix.append(p)
        
        
# result_matrix = []
# for item in result_list:
#     p = []
#     for key in ['senti', 'bias', 'ciron', 'offensive', 'teenager', 'query_risk']:
#         if item[1][key][0][1] >= 0.6:
#             p.append(1)
#         else:
#             p.append(0)
#     result_matrix.append(p)
    

In [420]:
# Majority-vote seed labels: a row of result_matrix gets label 1 when at least
# 6 of its binary flags are set, else 0. The last line displays how many rows
# were labelled 1.
result_matrix = np.array(result_matrix)
votes = np.sum(result_matrix, axis=-1)
# np.int was removed in NumPy 1.24 (AttributeError); the builtin int is the
# documented replacement and yields the same default integer dtype.
labels = (votes >= 6).astype(int)
sum(labels)

36067

In [463]:

from sklearn.ensemble import RandomForestClassifier
result_matrix = np.array(result_matrix)


def SIMPLE(result_matrix, vote_threshold=6, n_rounds=5):
    """Iteratively relabel vote-based risk labels with a small random forest.

    Seed labels come from a vote: a row is labelled 1 when at least
    ``vote_threshold`` of its binary flags are set. A shallow, pruned
    RandomForestClassifier is then fit on (result_matrix, labels), and its
    most probable class replaces the labels. This fit/relabel step is repeated
    ``n_rounds`` times. The shape of each round's probabilities is printed.

    Args:
        result_matrix: 2-D array of 0/1 flags, one row per query.
        vote_threshold: minimum number of set flags for seed label 1.
        n_rounds: number of fit/relabel rounds; must be >= 1.

    Returns:
        ``predict_proba`` output of the last round, shape
        (n_rows, n_classes). Columns follow ``clf.classes_``; with a single
        class present there is only one column.

    Raises:
        ValueError: if ``n_rounds < 1`` (no probabilities would be produced).
    """
    if n_rounds < 1:
        raise ValueError('n_rounds must be >= 1, got %r' % (n_rounds,))
    votes = np.sum(result_matrix, axis=-1)
    # np.int was removed in NumPy 1.24; use the builtin int instead.
    labels = (votes >= vote_threshold).astype(int)
    clf = RandomForestClassifier(max_depth=2, random_state=0, ccp_alpha=0.1)

    for _ in range(n_rounds):
        clf.fit(result_matrix, labels)
        probs = clf.predict_proba(result_matrix)
        print(probs.shape)
        # argmax gives a column index; map it through classes_ so labels keep
        # their meaning even when only one class is present (for the usual
        # classes_ == [0, 1] this is identical to the raw argmax).
        labels = clf.classes_[np.argmax(probs, axis=-1)]
    return probs


probs = SIMPLE(result_matrix)
        

(56443, 2)
(56443, 2)
(56443, 2)
(56443, 2)
(56443, 2)


In [464]:
# Split rows by the relabelling probability of class column 1: rows with
# probs[idx, 1] >= 0.8 go to ff, the rest to clean. Each entry is
# (row index, result_list[row], probability row).
# NOTE(review): relies on result_list having the same row order as probs
# (i.e. the texts used to build result_matrix).
ff = []
clean = []
for idx in range(probs.shape[0]):
    if probs[idx,1] >= 0.8:
        ff.append((idx, result_list[idx], probs[idx]))
    else:
        clean.append((idx, result_list[idx], probs[idx]))
        


In [23]:
# Score every query of query_risk_corpus.json with risk_predict_batch and write
# one JSON line per query (text, the input row's 'label' as 'topic', and the
# full score dict) to query_risk_corpus.json.risk.
RISK_BATCH_SIZE = 512  # queries per risk_predict_batch call


def _write_risk_scores(fwobj, texts, records):
    """Score one batch of texts and write one JSON line per text to fwobj.

    Args:
        fwobj: open text file to write JSON lines to.
        texts: query strings passed to risk_predict_batch.
        records: parsed input rows aligned with ``texts``; their 'label' is
            copied into the output 'topic' field.

    An empty batch writes nothing.
    """
    if not texts:
        return
    batch_probs = risk_predict_batch(texts)
    for prob_dict, text, record in zip(batch_probs, texts, records):
        out = {
            'text': text,
            'topic': record['label'],
            'score_list': prob_dict,
        }
        fwobj.write(json.dumps(out, ensure_ascii=False) + '\n')


with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk', 'w') as fwobj:
    with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json', 'r') as frobj:
        queue = []
        t = []
        for line in tqdm(frobj):
            content = json.loads(line.strip())
            queue.append(content['text'])
            t.append(content)
            if len(queue) == RISK_BATCH_SIZE:
                _write_risk_scores(fwobj, queue, t)
                queue = []
                t = []
        # Flush the final partial batch (no-op when empty).
        _write_risk_scores(fwobj, queue, t)

24721it [00:22, 1076.00it/s]


In [480]:
# Write the high-confidence rows (ff) as a JSON-lines corpus: each line holds
# the query text and the fixed label ['风险'].
# NOTE(review): this overwrites query_risk_corpus.json, which is the input file
# read by the risk-scoring cell that writes query_risk_corpus.json.risk.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json', 'w') as fwobj:
    for item in ff:
        tmp = {
            'text':item[1][0],
            'label':['风险']
        }
        fwobj.write(json.dumps(tmp, ensure_ascii=False)+'\n')


In [221]:
# Split result_list by per-head thresholds on the first-label probability
# (senti >= 0.7, bias/ciron/offensive >= 0.8). Any hit goes to candidate.
# NOTE(review): `positive` collects the items that hit NO threshold,
# i.e. the apparently clean ones; the name is misleading.
candidate = [] 
positive = []
for item in result_list:
    if item[1]['senti'][0][1] >= 0.7 or item[1]['bias'][0][1] >= 0.8 or item[1]['ciron'][0][1] >= 0.8 or item[1]['offensive'][0][1] >= 0.8:
        candidate.append(item)
    else:
        positive.append(item)

In [9]:

def _read_jsonl(path):
    """Return the JSON objects stored one per line in the file at ``path``.

    Blank lines are skipped. The file is decoded as UTF-8 explicitly, so the
    Chinese text does not depend on the platform's locale encoding.
    """
    records = []
    with open(path, encoding='utf-8') as frobj:
        for line in frobj:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


# Evaluation sets used by eval_all / evaluation (one JSON object per line).
offensive = _read_jsonl('/data/albert.xht/sentiment/dev/offensive_cold.json')
offensive_test = _read_jsonl('/data/albert.xht/sentiment/test/offensive_cold.json')
cdia_bias = _read_jsonl('/data/albert.xht/sentiment/dev/cdial_bias.json')
senti_copr = _read_jsonl('/data/albert.xht/sentiment/dev/senti_copr.json')
ciron = _read_jsonl('/data/albert.xht/sentiment/dev/chinese_ciron.json')
senti_smp = _read_jsonl('/data/albert.xht/sentiment/dev/senti_smp_usual.json')
senti_smpecisa = _read_jsonl('/data/albert.xht/sentiment/dev/senti_smpecisa.json')
senti_query = _read_jsonl('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_valid.json.filter.0.7')

In [10]:
from sklearn.metrics import classification_report
from tqdm import tqdm
import re

def eval_all(data, model, key):
    """Run ``model`` over ``data`` and print a classification report for one head.

    Each item needs 'label' (a list whose first element is the gold label) and
    'text' (a str, or a list of str that is joined with newlines). Runs of the
    listed Chinese/ASCII punctuation are removed from the text before
    prediction.

    Args:
        data: iterable of dict items as described above.
        model: object whose ``predict(text)`` returns a dict mapping head names
            to lists of [label, probability] pairs.
        key: head name to evaluate (e.g. 'offensive', 'senti').

    Returns:
        (pred, gold, pred_score): top predicted label per item, gold label per
        item, and the raw [label, probability] list of head ``key`` per item.
    """
    pred = []
    gold = []
    pred_score = []
    for item in tqdm(data):
        gold.append(item['label'][0])
        if isinstance(item['text'], list):
            text = "\n".join(item['text'])
        else:
            text = item['text']
        # Remove runs of Chinese and ASCII punctuation (whitespace is NOT
        # touched) so predictions are made on punctuation-free text.
        text = re.sub(r"([，\_《。》、？；：‘’＂“”【「】」·！@￥…（）—\,\<\.\>\/\?\;\:\'\"\[\]\{\}\~\`\!\@\#\$\%\^\&\*\(\)\-\=\+])+", "", text)

        result = model.predict(text)
        # Highest-probability label. max() keeps the first entry on ties, the
        # same one a stable descending sort would put first.
        pred.append(max(result[key], key=lambda u: u[1])[0])
        pred_score.append(result[key])
    print(classification_report(gold, pred, digits=4))
    return pred, gold, pred_score
    


In [11]:

def evaluation_ece(pred_score, gold):
    """Print the expected calibration error (ECE, 15 bins) of one head.

    Args:
        pred_score: per item, a list of [label, probability] pairs (as returned
            by eval_all). A label's class index is the position at which it is
            first seen, so all items are assumed to list labels in the same
            order.
        gold: gold label string per item; each must occur in pred_score
            (otherwise KeyError).
    """
    label_index = {}
    prob_rows = []
    for scored in pred_score:
        row = []
        for position, pair in enumerate(scored):
            label_index.setdefault(pair[0], position)
            row.append(pair[1])
        prob_rows.append(row)
    probs_tensor = torch.tensor(prob_rows)
    gold_index = torch.tensor([label_index[label] for label in gold])

    calibration = ECE(n_bins=15)
    print(calibration(probs_tensor, gold_index, mode='probs'), '==ece==')
# pred, gold, pred_score = eval_all(offensive_test, risk_api, 'offensive')
# evaluation_ece(pred_score, gold)


{'冒犯': 0, '正常': 1}
tensor([0.1119]) ==ece==


In [40]:
# Debug inspection of the global `p` left over from an earlier cell.
p

['正常', 0.9768216013908386]

In [37]:
# Debug inspection of a global `mapping_dict` (evaluation_ece's mapping is local).
mapping_dict

{}

In [28]:
# Debug inspection of the global `item` left over from an earlier loop.
item

'正常'

In [18]:
def evaluation(model_path, tasks=None):
    """Reload risk_api from a checkpoint and score it on the benchmark sets.

    For each benchmark this prints a '===name===' header, then runs eval_all
    (classification report for the matching head) and evaluation_ece (ECE of
    that head's probabilities) on the dataset.

    Args:
        model_path: checkpoint path passed to risk_api.reload.
        tasks: optional iterable of benchmark names to run ('offensive',
            'cdia-bias', 'ciron', 'chsenti', 'senti_smpecisa', 'senti_smp',
            'senti_query'). None (default) runs all of them in this order,
            e.g. pass the first six to skip the large senti_query set.
    """
    risk_api.reload(model_path)
    # (benchmark name, dataset loaded in an earlier cell, risk_api head to score)
    benchmarks = [
        ('offensive', offensive_test, 'offensive'),
        ('cdia-bias', cdia_bias, 'bias'),
        ('ciron', ciron, 'ciron'),
        ('chsenti', senti_copr, 'senti'),
        ('senti_smpecisa', senti_smpecisa, 'senti'),
        ('senti_smp', senti_smp, 'senti'),
        ('senti_query', senti_query, 'senti'),
    ]
    selected = None if tasks is None else set(tasks)
    for name, data, head in benchmarks:
        if selected is not None and name not in selected:
            continue
        print('===' + name + '===')
        pred, gold, pred_score = eval_all(data, risk_api, head)
        evaluation_ece(pred_score, gold)
    

In [14]:
# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2/multitask_cls.pth.9')

# # risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_query_risk_20221222/multitask_cls.pth.9')

# risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2_5_aug/multitask_cls.pth.4')
# Load the logitclip_v1 checkpoint and evaluate its 'intent' head (report + ECE)
# on the intention_data_v2 dev set (one JSON object per line).
risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')

intent_v2 = []
with open('/data/albert.xht/raw_chat_corpus/xiaoda/intention_data_v2/dev.txt') as frobj:
    for line in frobj:
        intent_v2.append(json.loads(line.strip()))
pred, gold, pred_score = eval_all(intent_v2, risk_api, 'intent')
evaluation_ece(pred_score, gold)


100%|██████████| 3000/3000 [00:25<00:00, 119.55it/s]

              precision    recall  f1-score   support

  主观评价/比较/判断     0.9211    0.9867    0.9528       603
          其它     0.9908    0.9743    0.9825      2219
     寻求建议/帮助     0.9186    0.8876    0.9029       178

    accuracy                         0.9717      3000
   macro avg     0.9435    0.9496    0.9460      3000
weighted avg     0.9725    0.9717    0.9718      3000

tensor([0.0111]) ==ece==





In [16]:
# Evaluate the 'query_risk' head of the logitclip_v1 checkpoint on the manually
# labelled offensive_select_labeled set (one JSON object per line).
risk_api.reload('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')

risk_query = []
with open('/data/albert.xht/xiaodao/query_risk_v11/offensive_select_labeled.txt') as frobj:
    for line in frobj:
        risk_query.append(json.loads(line.strip()))
pred, gold, pred_score = eval_all(risk_query, risk_api, 'query_risk')
evaluation_ece(pred_score, gold)

100%|██████████| 20641/20641 [02:54<00:00, 118.21it/s]


              precision    recall  f1-score   support

          正常     0.5534    0.2285    0.3235      5514
          风险     0.7684    0.9328    0.8426     15127

    accuracy                         0.7446     20641
   macro avg     0.6609    0.5806    0.5830     20641
weighted avg     0.7109    0.7446    0.7039     20641

tensor([0.1919]) ==ece==


In [10]:
# Evaluate the 'intent' head of whichever checkpoint risk_api currently holds
# on the intention_data_v2 dev set (same loading as the earlier intent cell).
intent_v2 = []
with open('/data/albert.xht/raw_chat_corpus/xiaoda/intention_data_v2/dev.txt') as frobj:
    for line in frobj:
        intent_v2.append(json.loads(line.strip()))
pred, gold, pred_score = eval_all(intent_v2, risk_api, 'intent')
evaluation_ece(pred_score, gold)

  0%|          | 0/3000 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 3000/3000 [00:25<00:00, 118.61it/s]

              precision    recall  f1-score   support

  主观评价/比较/判断     0.9183    0.9884    0.9521       603
          其它     0.9899    0.9716    0.9807      2219
     寻求建议/帮助     0.8902    0.8652    0.8775       178

    accuracy                         0.9687      3000
   macro avg     0.9328    0.9417    0.9367      3000
weighted avg     0.9696    0.9687    0.9688      3000

tensor([0.0188]) ==ece==





In [11]:
# Spot check: full per-head [label, probability] output for one query.
risk_api.predict('笑死我是什么感觉')


{'senti_query': [['负向', 0.748834490776062],
  ['中性', 0.1530851274728775],
  ['正向', 0.09808038920164108]],
 'senti': [['负向', 0.9947615265846252], ['正向', 0.0052385409362614155]],
 'bias': [['偏见', 0.0014122501015663147], ['正常', 0.9985877275466919]],
 'ciron': [['讽刺', 1.7454854969400913e-05], ['正常', 0.9999825954437256]],
 'intent': [['主观评价/比较/判断', 0.01418764516711235],
  ['寻求建议/帮助', 0.0049115875735878944],
  ['其它', 0.9809007048606873]],
 'offensive': [['冒犯', 0.012532699853181839], ['正常', 0.987467348575592]],
 'query_risk': [['风险', 0.0017439085058867931],
  ['个人信息', 0.0008314871229231358],
  ['正常', 0.9974246025085449]],
 'teenager': [['不良', 0.07456959038972855], ['正常', 0.9254303574562073]]}

In [None]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_intent_v2_3/multitask_cls.pth.4')


In [19]:
# Full benchmark run for the logitclip_v1 checkpoint (epoch 9).
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip_v1/multitask_cls.pth.9')


  0%|          | 12/5304 [00:00<00:45, 116.77it/s]

===offensive===


100%|██████████| 5304/5304 [00:45<00:00, 116.38it/s]
  0%|          | 12/2829 [00:00<00:24, 112.99it/s]

              precision    recall  f1-score   support

          冒犯     0.7303    0.8613    0.7904      2106
          正常     0.8965    0.7905    0.8401      3198

    accuracy                         0.8186      5304
   macro avg     0.8134    0.8259    0.8153      5304
weighted avg     0.8305    0.8186    0.8204      5304

tensor([0.0965]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:24<00:00, 116.02it/s]
  1%|▏         | 12/875 [00:00<00:07, 117.36it/s]

              precision    recall  f1-score   support

          偏见     0.5202    0.6448    0.5759       718
          正常     0.8685    0.7977    0.8316      2111

    accuracy                         0.7589      2829
   macro avg     0.6944    0.7213    0.7037      2829
weighted avg     0.7801    0.7589    0.7667      2829

tensor([0.1289]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 117.49it/s]
  1%|          | 12/1200 [00:00<00:10, 115.68it/s]

              precision    recall  f1-score   support

          正常     0.9390    0.9089    0.9237       779
          讽刺     0.4132    0.5208    0.4608        96

    accuracy                         0.8663       875
   macro avg     0.6761    0.7148    0.6923       875
weighted avg     0.8813    0.8663    0.8729       875

tensor([0.0541]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:10<00:00, 114.52it/s]
  0%|          | 12/2529 [00:00<00:21, 117.82it/s]

              precision    recall  f1-score   support

          正向     0.9183    0.9292    0.9237       593
          负向     0.9300    0.9193    0.9246       607

    accuracy                         0.9242      1200
   macro avg     0.9242    0.9242    0.9242      1200
weighted avg     0.9242    0.9242    0.9242      1200

tensor([0.0321]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:21<00:00, 118.50it/s]
  0%|          | 12/2844 [00:00<00:23, 118.78it/s]

              precision    recall  f1-score   support

          正向     0.8015    0.8102    0.8058      1201
          负向     0.8266    0.8185    0.8226      1328

    accuracy                         0.8146      2529
   macro avg     0.8140    0.8143    0.8142      2529
weighted avg     0.8147    0.8146    0.8146      2529

tensor([0.1175]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:24<00:00, 117.36it/s]
  0%|          | 12/50509 [00:00<07:05, 118.65it/s]

              precision    recall  f1-score   support

          正向     0.8383    0.8934    0.8650      1126
          负向     0.9270    0.8871    0.9066      1718

    accuracy                         0.8896      2844
   macro avg     0.8827    0.8903    0.8858      2844
weighted avg     0.8919    0.8896    0.8901      2844

tensor([0.0691]) ==ece==
===senti_query===


 40%|███▉      | 20181/50509 [02:51<04:18, 117.51it/s]


KeyboardInterrupt: 

In [20]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v10_intent_v2-1_10_no_symbol_senti_query_senta_balanced_logitclip/multitask_cls.pth.9')


In [12]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v9_intent_v2-1_10_no_symbol_v1/multitask_cls.pth.9')


In [13]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v6_intent_v2-1_10_no_symbol/multitask_cls.pth.9')


In [19]:
# evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v5_no_offensive_intent_v2-1_3/multitask_cls.pth.4')


In [20]:
# evaluation('/data/albert.xht/xiaodao/multitask_raw_filter_senti_query_risk_v5_intent_v2-1_3/multitask_cls.pth.4')


In [1]:
# Scratch check: random.sample with k=0 returns [] and range(1, 0) is empty.
import random
random.randint(0, 1), random.sample(list(range(1, 2)), k=0), range(1, 0)

(0, [], range(1, 0))

In [23]:
# Full benchmark run for the v4_intent_v2_5_aug checkpoint (epoch 4).
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2_5_aug/multitask_cls.pth.4')


===offensive===


100%|██████████| 5304/5304 [00:43<00:00, 121.00it/s]


              precision    recall  f1-score   support

          冒犯     0.7306    0.8523    0.7868      2106
          正常     0.8908    0.7930    0.8390      3198

    accuracy                         0.8166      5304
   macro avg     0.8107    0.8227    0.8129      5304
weighted avg     0.8272    0.8166    0.8183      5304

tensor([0.1362]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:23<00:00, 120.11it/s]


              precision    recall  f1-score   support

          偏见     0.5761    0.5376    0.5562       718
          正常     0.8462    0.8655    0.8557      2111

    accuracy                         0.7823      2829
   macro avg     0.7112    0.7015    0.7060      2829
weighted avg     0.7777    0.7823    0.7797      2829

tensor([0.0758]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 122.01it/s]


              precision    recall  f1-score   support

          正常     0.9246    0.9602    0.9421       779
          讽刺     0.5303    0.3646    0.4321        96

    accuracy                         0.8949       875
   macro avg     0.7275    0.6624    0.6871       875
weighted avg     0.8813    0.8949    0.8861       875

tensor([0.0579]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:10<00:00, 118.20it/s]


              precision    recall  f1-score   support

          正向     0.9480    0.9528    0.9504       593
          负向     0.9536    0.9489    0.9513       607

    accuracy                         0.9508      1200
   macro avg     0.9508    0.9509    0.9508      1200
weighted avg     0.9508    0.9508    0.9508      1200

tensor([0.0192]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 122.04it/s]


              precision    recall  f1-score   support

          正向     0.7700    0.8501    0.8081      1201
          负向     0.8504    0.7703    0.8084      1328

    accuracy                         0.8082      2529
   macro avg     0.8102    0.8102    0.8082      2529
weighted avg     0.8122    0.8082    0.8082      2529

tensor([0.1563]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:23<00:00, 121.86it/s]


              precision    recall  f1-score   support

          正向     0.8057    0.8988    0.8497      1126
          负向     0.9282    0.8580    0.8917      1718

    accuracy                         0.8741      2844
   macro avg     0.8670    0.8784    0.8707      2844
weighted avg     0.8797    0.8741    0.8751      2844

tensor([0.0911]) ==ece==
===senti_query===


 60%|█████▉    | 30108/50509 [04:09<02:48, 120.89it/s]


KeyboardInterrupt: 

In [19]:

# Full benchmark run for the v4_intent_v2 checkpoint (epoch 8).
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_filter_senti_query_risk_v4_intent_v2/multitask_cls.pth.8')


===offensive===


100%|██████████| 5304/5304 [00:47<00:00, 111.59it/s]


              precision    recall  f1-score   support

          冒犯     0.7309    0.8680    0.7936      2106
          正常     0.9008    0.7896    0.8415      3198

    accuracy                         0.8207      5304
   macro avg     0.8159    0.8288    0.8176      5304
weighted avg     0.8334    0.8207    0.8225      5304

tensor([0.1185]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:23<00:00, 122.14it/s]


              precision    recall  f1-score   support

          偏见     0.5839    0.5139    0.5467       718
          正常     0.8411    0.8754    0.8579      2111

    accuracy                         0.7837      2829
   macro avg     0.7125    0.6947    0.7023      2829
weighted avg     0.7758    0.7837    0.7789      2829

tensor([0.0372]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 124.73it/s]


              precision    recall  f1-score   support

          正常     0.9191    0.9628    0.9404       779
          讽刺     0.5085    0.3125    0.3871        96

    accuracy                         0.8914       875
   macro avg     0.7138    0.6376    0.6638       875
weighted avg     0.8741    0.8914    0.8797       875

tensor([0.0317]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 121.65it/s]


              precision    recall  f1-score   support

          正向     0.9305    0.9477    0.9390       593
          负向     0.9480    0.9308    0.9393       607

    accuracy                         0.9392      1200
   macro avg     0.9392    0.9393    0.9392      1200
weighted avg     0.9393    0.9392    0.9392      1200

tensor([0.0284]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 124.54it/s]


              precision    recall  f1-score   support

          正向     0.7702    0.8510    0.8085      1201
          负向     0.8511    0.7703    0.8087      1328

    accuracy                         0.8086      2529
   macro avg     0.8106    0.8106    0.8086      2529
weighted avg     0.8127    0.8086    0.8086      2529

tensor([0.1475]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:23<00:00, 123.29it/s]


              precision    recall  f1-score   support

          正向     0.8225    0.9014    0.8602      1126
          负向     0.9311    0.8725    0.9008      1718

    accuracy                         0.8840      2844
   macro avg     0.8768    0.8870    0.8805      2844
weighted avg     0.8881    0.8840    0.8847      2844

tensor([0.0719]) ==ece==
===senti_query===


 78%|███████▊  | 39435/50509 [05:17<01:29, 124.29it/s]


KeyboardInterrupt: 

In [25]:
# Full benchmark run for the multitask_query_risk_20221222 checkpoint (epoch 9).
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_query_risk_20221222/multitask_cls.pth.9')


===offensive===


100%|██████████| 5304/5304 [00:43<00:00, 122.71it/s]


              precision    recall  f1-score   support

          冒犯     0.7344    0.8599    0.7922      2106
          正常     0.8961    0.7952    0.8426      3198

    accuracy                         0.8209      5304
   macro avg     0.8152    0.8276    0.8174      5304
weighted avg     0.8319    0.8209    0.8226      5304

tensor([0.1135]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:23<00:00, 122.61it/s]


              precision    recall  f1-score   support

          偏见     0.5700    0.4930    0.5288       718
          正常     0.8351    0.8735    0.8539      2111

    accuracy                         0.7770      2829
   macro avg     0.7026    0.6833    0.6913      2829
weighted avg     0.7679    0.7770    0.7714      2829

tensor([0.0479]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 122.63it/s]


              precision    recall  f1-score   support

          正常     0.9168    0.9615    0.9386       779
          讽刺     0.4828    0.2917    0.3636        96

    accuracy                         0.8880       875
   macro avg     0.6998    0.6266    0.6511       875
weighted avg     0.8692    0.8880    0.8755       875

tensor([0.0333]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:10<00:00, 119.86it/s]


              precision    recall  f1-score   support

          正向     0.9316    0.9410    0.9362       593
          负向     0.9418    0.9325    0.9371       607

    accuracy                         0.9367      1200
   macro avg     0.9367    0.9367    0.9367      1200
weighted avg     0.9367    0.9367    0.9367      1200

tensor([0.0270]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 122.65it/s]


              precision    recall  f1-score   support

          正向     0.7767    0.8543    0.8136      1201
          负向     0.8551    0.7779    0.8147      1328

    accuracy                         0.8142      2529
   macro avg     0.8159    0.8161    0.8142      2529
weighted avg     0.8179    0.8142    0.8142      2529

tensor([0.1410]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:23<00:00, 122.76it/s]


              precision    recall  f1-score   support

          正向     0.8195    0.9032    0.8593      1126
          负向     0.9320    0.8696    0.8997      1718

    accuracy                         0.8829      2844
   macro avg     0.8758    0.8864    0.8795      2844
weighted avg     0.8875    0.8829    0.8837      2844

tensor([0.0736]) ==ece==
===senti_query===


100%|██████████| 50509/50509 [06:45<00:00, 124.52it/s]


              precision    recall  f1-score   support

          正向     0.9686    0.9721    0.9704     28251
          负向     0.9645    0.9600    0.9622     22258

    accuracy                         0.9668     50509
   macro avg     0.9665    0.9661    0.9663     50509
weighted avg     0.9668    0.9668    0.9668     50509

tensor([0.0188]) ==ece==


In [11]:
# Full benchmark run for the query_risk_v3 checkpoint (epoch 9).
evaluation('/data/albert.xht/xiaodao/query_risk_v3/multitask_raw_filter_senti_query_risk_v3/multitask_cls.pth.9')

===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 123.49it/s]


              precision    recall  f1-score   support

          冒犯     0.7448    0.8538    0.7956      2106
          正常     0.8934    0.8074    0.8482      3198

    accuracy                         0.8258      5304
   macro avg     0.8191    0.8306    0.8219      5304
weighted avg     0.8344    0.8258    0.8273      5304

tensor([0.1187]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 123.55it/s]


              precision    recall  f1-score   support

          偏见     0.5665    0.5752    0.5708       718
          正常     0.8548    0.8503    0.8525      2111

    accuracy                         0.7805      2829
   macro avg     0.7106    0.7128    0.7117      2829
weighted avg     0.7816    0.7805    0.7810      2829

tensor([0.0452]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 124.43it/s]


              precision    recall  f1-score   support

          正常     0.9299    0.9371    0.9335       779
          讽刺     0.4556    0.4271    0.4409        96

    accuracy                         0.8811       875
   macro avg     0.6927    0.6821    0.6872       875
weighted avg     0.8779    0.8811    0.8795       875

tensor([0.0460]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 120.96it/s]


              precision    recall  f1-score   support

          正向     0.9282    0.9595    0.9436       593
          负向     0.9591    0.9275    0.9430       607

    accuracy                         0.9433      1200
   macro avg     0.9437    0.9435    0.9433      1200
weighted avg     0.9438    0.9433    0.9433      1200

tensor([0.0386]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 124.37it/s]


              precision    recall  f1-score   support

          正向     0.7543    0.8485    0.7986      1201
          负向     0.8455    0.7500    0.7949      1328

    accuracy                         0.7968      2529
   macro avg     0.7999    0.7992    0.7967      2529
weighted avg     0.8022    0.7968    0.7966      2529

tensor([0.1659]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:23<00:00, 123.31it/s]


              precision    recall  f1-score   support

          正向     0.8233    0.8899    0.8553      1126
          负向     0.9238    0.8749    0.8987      1718

    accuracy                         0.8808      2844
   macro avg     0.8736    0.8824    0.8770      2844
weighted avg     0.8840    0.8808    0.8815      2844

tensor([0.0784]) ==ece==


In [156]:

# Full benchmark run for the multitask_raw_all_ce_256_20 checkpoint (epoch 9).
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_ce_256_20/multitask_cls.pth.9')

===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 124.52it/s]


              precision    recall  f1-score   support

          冒犯     0.7084    0.8837    0.7864      2106
          正常     0.9085    0.7605    0.8279      3198

    accuracy                         0.8094      5304
   macro avg     0.8084    0.8221    0.8072      5304
weighted avg     0.8290    0.8094    0.8114      5304

tensor([0.1319]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 124.59it/s]


              precision    recall  f1-score   support

          偏见     0.5622    0.6045    0.5826       718
          正常     0.8619    0.8399    0.8508      2111

    accuracy                         0.7801      2829
   macro avg     0.7121    0.7222    0.7167      2829
weighted avg     0.7859    0.7801    0.7827      2829

tensor([0.0458]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 125.92it/s]


              precision    recall  f1-score   support

          正常     0.9287    0.9538    0.9411       779
          讽刺     0.5200    0.4062    0.4561        96

    accuracy                         0.8937       875
   macro avg     0.7244    0.6800    0.6986       875
weighted avg     0.8839    0.8937    0.8879       875

tensor([0.0358]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 122.62it/s]


              precision    recall  f1-score   support

          正向     0.9112    0.9174    0.9143       593
          负向     0.9187    0.9127    0.9157       607

    accuracy                         0.9150      1200
   macro avg     0.9150    0.9150    0.9150      1200
weighted avg     0.9150    0.9150    0.9150      1200

tensor([0.0297]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 126.30it/s]


              precision    recall  f1-score   support

          正向     0.8326    0.7868    0.8091      1201
          负向     0.8164    0.8569    0.8361      1328

    accuracy                         0.8236      2529
   macro avg     0.8245    0.8219    0.8226      2529
weighted avg     0.8241    0.8236    0.8233      2529

tensor([0.0308]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 126.33it/s]


              precision    recall  f1-score   support

          正向     0.8724    0.8321    0.8518      1126
          负向     0.8932    0.9203    0.9065      1718

    accuracy                         0.8854      2844
   macro avg     0.8828    0.8762    0.8792      2844
weighted avg     0.8850    0.8854    0.8849      2844

tensor([0.0111]) ==ece==


In [46]:
# Full benchmark run for the multitask_raw_all_intent_v1 checkpoint (epoch 9).
evaluation('/data/albert.xht/xiaodao/risk_classification/multitask_raw_all_intent_v1/multitask_cls.pth.9')

===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.32it/s]


              precision    recall  f1-score   support

          冒犯     0.7050    0.8794    0.7826      2106
          正常     0.9051    0.7577    0.8249      3198

    accuracy                         0.8060      5304
   macro avg     0.8051    0.8185    0.8037      5304
weighted avg     0.8257    0.8060    0.8081      5304

tensor([0.1119]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:25<00:00, 111.92it/s]


              precision    recall  f1-score   support

          偏见     0.6208    0.4903    0.5479       718
          正常     0.8382    0.8982    0.8671      2111

    accuracy                         0.7946      2829
   macro avg     0.7295    0.6942    0.7075      2829
weighted avg     0.7830    0.7946    0.7861      2829

tensor([0.0231]) ==ece==
===ciron===


100%|██████████| 875/875 [00:08<00:00, 98.27it/s]


              precision    recall  f1-score   support

          正常     0.9255    0.9730    0.9487       779
          讽刺     0.6250    0.3646    0.4605        96

    accuracy                         0.9063       875
   macro avg     0.7753    0.6688    0.7046       875
weighted avg     0.8925    0.9063    0.8951       875

tensor([0.0363]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:12<00:00, 96.01it/s]


              precision    recall  f1-score   support

          正向     0.8838    0.9106    0.8970       593
          负向     0.9100    0.8830    0.8963       607

    accuracy                         0.8967      1200
   macro avg     0.8969    0.8968    0.8967      1200
weighted avg     0.8971    0.8967    0.8967      1200

tensor([0.0226]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:22<00:00, 111.83it/s]


              precision    recall  f1-score   support

          正向     0.8053    0.8160    0.8106      1201
          负向     0.8316    0.8215    0.8265      1328

    accuracy                         0.8189      2529
   macro avg     0.8184    0.8188    0.8186      2529
weighted avg     0.8191    0.8189    0.8190      2529

tensor([0.0208]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 125.04it/s]


              precision    recall  f1-score   support

          正向     0.8540    0.8366    0.8452      1126
          负向     0.8943    0.9063    0.9003      1718

    accuracy                         0.8787      2844
   macro avg     0.8742    0.8714    0.8727      2844
weighted avg     0.8784    0.8787    0.8785      2844

tensor([0.0247]) ==ece==


In [48]:
# Evaluate checkpoint multitask_cls.pth.9 of the raw_all + focal-loss variant.
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_raw_all_focal/multitask_cls.pth.9'
)

===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.80it/s]


              precision    recall  f1-score   support

          冒犯     0.7068    0.8746    0.7818      2106
          正常     0.9021    0.7611    0.8256      3198

    accuracy                         0.8062      5304
   macro avg     0.8045    0.8179    0.8037      5304
weighted avg     0.8246    0.8062    0.8082      5304

tensor([0.0195]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 125.39it/s]


              precision    recall  f1-score   support

          偏见     0.5951    0.4749    0.5283       718
          正常     0.8329    0.8901    0.8605      2111

    accuracy                         0.7847      2829
   macro avg     0.7140    0.6825    0.6944      2829
weighted avg     0.7725    0.7847    0.7762      2829

tensor([0.1170]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 125.42it/s]


              precision    recall  f1-score   support

          正常     0.9230    0.9692    0.9455       779
          讽刺     0.5789    0.3438    0.4314        96

    accuracy                         0.9006       875
   macro avg     0.7510    0.6565    0.6884       875
weighted avg     0.8852    0.9006    0.8891       875

tensor([0.1106]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 121.55it/s]


              precision    recall  f1-score   support

          正向     0.8807    0.9089    0.8946       593
          负向     0.9082    0.8797    0.8937       607

    accuracy                         0.8942      1200
   macro avg     0.8944    0.8943    0.8942      1200
weighted avg     0.8946    0.8942    0.8942      1200

tensor([0.1561]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 126.43it/s]


              precision    recall  f1-score   support

          正向     0.8163    0.8251    0.8207      1201
          负向     0.8403    0.8321    0.8362      1328

    accuracy                         0.8288      2529
   macro avg     0.8283    0.8286    0.8284      2529
weighted avg     0.8289    0.8288    0.8288      2529

tensor([0.1556]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 125.37it/s]


              precision    recall  f1-score   support

          正向     0.8627    0.8313    0.8467      1126
          负向     0.8920    0.9133    0.9025      1718

    accuracy                         0.8808      2844
   macro avg     0.8773    0.8723    0.8746      2844
weighted avg     0.8804    0.8808    0.8804      2844

tensor([0.1674]) ==ece==


In [50]:
# Evaluate checkpoint multitask_cls.pth.9 of the balanced + intent_v1 variant.
# The path previously contained a stray '//' before the file name; it is now
# normalized to a single '/' to match the other evaluation cells.
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_balanced_intent_v1/multitask_cls.pth.9'
)



===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 124.95it/s]


              precision    recall  f1-score   support

          冒犯     0.7034    0.8818    0.7826      2106
          正常     0.9065    0.7552    0.8240      3198

    accuracy                         0.8054      5304
   macro avg     0.8050    0.8185    0.8033      5304
weighted avg     0.8259    0.8054    0.8075      5304

tensor([0.1129]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 125.58it/s]


              precision    recall  f1-score   support

          偏见     0.6035    0.4833    0.5367       718
          正常     0.8354    0.8920    0.8628      2111

    accuracy                         0.7883      2829
   macro avg     0.7194    0.6876    0.6998      2829
weighted avg     0.7765    0.7883    0.7800      2829

tensor([0.0261]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 126.61it/s]


              precision    recall  f1-score   support

          正常     0.9126    0.9653    0.9382       779
          讽刺     0.4706    0.2500    0.3265        96

    accuracy                         0.8869       875
   macro avg     0.6916    0.6077    0.6324       875
weighted avg     0.8641    0.8869    0.8711       875

tensor([0.0374]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 122.86it/s]


              precision    recall  f1-score   support

          正向     0.8831    0.9174    0.8999       593
          负向     0.9161    0.8814    0.8984       607

    accuracy                         0.8992      1200
   macro avg     0.8996    0.8994    0.8992      1200
weighted avg     0.8998    0.8992    0.8992      1200

tensor([0.0285]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:19<00:00, 127.17it/s]


              precision    recall  f1-score   support

          正向     0.8099    0.8301    0.8199      1201
          负向     0.8428    0.8238    0.8332      1328

    accuracy                         0.8268      2529
   macro avg     0.8264    0.8270    0.8266      2529
weighted avg     0.8272    0.8268    0.8269      2529

tensor([0.0214]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 126.09it/s]


              precision    recall  f1-score   support

          正向     0.8552    0.8393    0.8472      1126
          负向     0.8959    0.9069    0.9014      1718

    accuracy                         0.8801      2844
   macro avg     0.8756    0.8731    0.8743      2844
weighted avg     0.8798    0.8801    0.8799      2844

tensor([0.0243]) ==ece==


In [51]:
# Evaluate checkpoint multitask_cls.pth.9 of the raw_all + focal-loss variant
# (the "_128" suffix of the run directory).
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_raw_all_focal_128/multitask_cls.pth.9'
)


===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.00it/s]


              precision    recall  f1-score   support

          冒犯     0.7314    0.8661    0.7930      2106
          正常     0.8996    0.7905    0.8415      3198

    accuracy                         0.8205      5304
   macro avg     0.8155    0.8283    0.8173      5304
weighted avg     0.8328    0.8205    0.8223      5304

tensor([0.0103]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 124.37it/s]


              precision    recall  f1-score   support

          偏见     0.5877    0.4944    0.5371       718
          正常     0.8369    0.8820    0.8589      2111

    accuracy                         0.7837      2829
   macro avg     0.7123    0.6882    0.6980      2829
weighted avg     0.7736    0.7837    0.7772      2829

tensor([0.1111]) ==ece==
===ciron===


100%|██████████| 875/875 [00:07<00:00, 114.74it/s]


              precision    recall  f1-score   support

          正常     0.9174    0.9692    0.9426       779
          讽刺     0.5385    0.2917    0.3784        96

    accuracy                         0.8949       875
   macro avg     0.7279    0.6304    0.6605       875
weighted avg     0.8758    0.8949    0.8807       875

tensor([0.1099]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:10<00:00, 116.51it/s]


              precision    recall  f1-score   support

          正向     0.8941    0.9258    0.9097       593
          负向     0.9249    0.8929    0.9086       607

    accuracy                         0.9092      1200
   macro avg     0.9095    0.9094    0.9092      1200
weighted avg     0.9097    0.9092    0.9092      1200

tensor([0.1246]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 125.53it/s]


              precision    recall  f1-score   support

          正向     0.8103    0.8218    0.8160      1201
          负向     0.8368    0.8261    0.8314      1328

    accuracy                         0.8240      2529
   macro avg     0.8236    0.8239    0.8237      2529
weighted avg     0.8242    0.8240    0.8241      2529

tensor([0.1440]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 124.13it/s]


              precision    recall  f1-score   support

          正向     0.8532    0.8464    0.8498      1126
          负向     0.8998    0.9045    0.9022      1718

    accuracy                         0.8815      2844
   macro avg     0.8765    0.8754    0.8760      2844
weighted avg     0.8814    0.8815    0.8814      2844

tensor([0.1548]) ==ece==


In [52]:
# Evaluate checkpoint multitask_cls.pth.9 of the contrast + balanced + intent_v1 variant.
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_contrast_balanced_intent_v1/multitask_cls.pth.9'
)


===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.22it/s]


              precision    recall  f1-score   support

          冒犯     0.7074    0.8770    0.7831      2106
          正常     0.9038    0.7611    0.8263      3198

    accuracy                         0.8071      5304
   macro avg     0.8056    0.8191    0.8047      5304
weighted avg     0.8258    0.8071    0.8092      5304

tensor([0.1762]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 124.54it/s]


              precision    recall  f1-score   support

          偏见     0.5777    0.5125    0.5432       718
          正常     0.8403    0.8726    0.8561      2111

    accuracy                         0.7812      2829
   macro avg     0.7090    0.6926    0.6997      2829
weighted avg     0.7737    0.7812    0.7767      2829

tensor([0.1900]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 125.17it/s]


              precision    recall  f1-score   support

          正常     0.9294    0.9461    0.9377       779
          讽刺     0.4878    0.4167    0.4494        96

    accuracy                         0.8880       875
   macro avg     0.7086    0.6814    0.6935       875
weighted avg     0.8809    0.8880    0.8841       875

tensor([0.1048]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 121.88it/s]


              precision    recall  f1-score   support

          正向     0.9314    0.9157    0.9235       593
          负向     0.9190    0.9341    0.9265       607

    accuracy                         0.9250      1200
   macro avg     0.9252    0.9249    0.9250      1200
weighted avg     0.9251    0.9250    0.9250      1200

tensor([0.0596]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 125.97it/s]


              precision    recall  f1-score   support

          正向     0.8183    0.8285    0.8233      1201
          负向     0.8431    0.8336    0.8383      1328

    accuracy                         0.8312      2529
   macro avg     0.8307    0.8310    0.8308      2529
weighted avg     0.8313    0.8312    0.8312      2529

tensor([0.1288]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 126.45it/s]


              precision    recall  f1-score   support

          正向     0.8613    0.8437    0.8524      1126
          负向     0.8989    0.9109    0.9049      1718

    accuracy                         0.8843      2844
   macro avg     0.8801    0.8773    0.8786      2844
weighted avg     0.8840    0.8843    0.8841      2844

tensor([0.0493]) ==ece==


In [66]:
# Evaluate checkpoint multitask_cls.pth.9 of the mtdnn + cross-entropy variant
# (the "_256" suffix of the run directory).
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_mtdnn_ce_256/multitask_cls.pth.9'
)


===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.73it/s]


              precision    recall  f1-score   support

          冒犯     0.7229    0.8756    0.7919      2106
          正常     0.9048    0.7789    0.8372      3198

    accuracy                         0.8173      5304
   macro avg     0.8138    0.8273    0.8145      5304
weighted avg     0.8326    0.8173    0.8192      5304

tensor([0.1323]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 126.27it/s]


              precision    recall  f1-score   support

          偏见     0.5892    0.5153    0.5498       718
          正常     0.8419    0.8778    0.8595      2111

    accuracy                         0.7858      2829
   macro avg     0.7155    0.6966    0.7046      2829
weighted avg     0.7778    0.7858    0.7809      2829

tensor([0.0623]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 126.06it/s]


              precision    recall  f1-score   support

          正常     0.9188    0.9589    0.9384       779
          讽刺     0.4839    0.3125    0.3797        96

    accuracy                         0.8880       875
   macro avg     0.7013    0.6357    0.6591       875
weighted avg     0.8711    0.8880    0.8771       875

tensor([0.0520]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 122.59it/s]


              precision    recall  f1-score   support

          正向     0.9129    0.9191    0.9160       593
          负向     0.9204    0.9143    0.9174       607

    accuracy                         0.9167      1200
   macro avg     0.9166    0.9167    0.9167      1200
weighted avg     0.9167    0.9167    0.9167      1200

tensor([0.0421]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:19<00:00, 126.73it/s]


              precision    recall  f1-score   support

          正向     0.8085    0.8085    0.8085      1201
          负向     0.8268    0.8268    0.8268      1328

    accuracy                         0.8181      2529
   macro avg     0.8177    0.8177    0.8177      2529
weighted avg     0.8181    0.8181    0.8181      2529

tensor([0.0537]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 124.94it/s]


              precision    recall  f1-score   support

          正向     0.8543    0.8330    0.8435      1126
          负向     0.8923    0.9069    0.8995      1718

    accuracy                         0.8776      2844
   macro avg     0.8733    0.8700    0.8715      2844
weighted avg     0.8773    0.8776    0.8774      2844

tensor([0.0230]) ==ece==


In [67]:
# Evaluate checkpoint multitask_cls.pth.9 of the mtdnn + focal-loss variant
# (the "_256" suffix of the run directory).
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_mtdnn_focal_256/multitask_cls.pth.9'
)


===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.26it/s]


              precision    recall  f1-score   support

          冒犯     0.7248    0.8590    0.7862      2106
          正常     0.8942    0.7852    0.8362      3198

    accuracy                         0.8145      5304
   macro avg     0.8095    0.8221    0.8112      5304
weighted avg     0.8269    0.8145    0.8163      5304

tensor([0.0393]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 124.76it/s]


              precision    recall  f1-score   support

          偏见     0.5753    0.5056    0.5382       718
          正常     0.8385    0.8730    0.8554      2111

    accuracy                         0.7798      2829
   macro avg     0.7069    0.6893    0.6968      2829
weighted avg     0.7717    0.7798    0.7749      2829

tensor([0.0811]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 125.40it/s]


              precision    recall  f1-score   support

          正常     0.9243    0.9564    0.9401       779
          讽刺     0.5072    0.3646    0.4242        96

    accuracy                         0.8914       875
   macro avg     0.7158    0.6605    0.6822       875
weighted avg     0.8786    0.8914    0.8835       875

tensor([0.0768]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 121.72it/s]


              precision    recall  f1-score   support

          正向     0.9132    0.9224    0.9178       593
          负向     0.9235    0.9143    0.9189       607

    accuracy                         0.9183      1200
   macro avg     0.9183    0.9184    0.9183      1200
weighted avg     0.9184    0.9183    0.9183      1200

tensor([0.0882]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:20<00:00, 125.80it/s]


              precision    recall  f1-score   support

          正向     0.8040    0.8027    0.8033      1201
          负向     0.8218    0.8230    0.8224      1328

    accuracy                         0.8134      2529
   macro avg     0.8129    0.8129    0.8129      2529
weighted avg     0.8134    0.8134    0.8134      2529

tensor([0.0975]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 124.79it/s]


              precision    recall  f1-score   support

          正向     0.8523    0.8250    0.8384      1126
          负向     0.8877    0.9063    0.8969      1718

    accuracy                         0.8741      2844
   macro avg     0.8700    0.8657    0.8677      2844
weighted avg     0.8737    0.8741    0.8738      2844

tensor([0.1215]) ==ece==


In [71]:
# Evaluate checkpoint multitask_cls.pth.9 of the raw_all + cross-entropy variant
# (the "_256" suffix of the run directory).
# NOTE(review): in this run the offensive eval set has 6417 examples (vs 5304
# in most other runs), so its offensive scores are not directly comparable.
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_raw_all_ce_256/multitask_cls.pth.9'
)




===offensive===


100%|██████████| 6417/6417 [00:51<00:00, 124.88it/s]


              precision    recall  f1-score   support

          冒犯     0.9025    0.9264    0.9143      3208
          正常     0.9245    0.9000    0.9120      3209

    accuracy                         0.9132      6417
   macro avg     0.9135    0.9132    0.9132      6417
weighted avg     0.9135    0.9132    0.9132      6417

tensor([0.0343]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 124.98it/s]


              precision    recall  f1-score   support

          偏见     0.5808    0.5153    0.5461       718
          正常     0.8412    0.8735    0.8571      2111

    accuracy                         0.7826      2829
   macro avg     0.7110    0.6944    0.7016      2829
weighted avg     0.7752    0.7826    0.7782      2829

tensor([0.0418]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 126.12it/s]


              precision    recall  f1-score   support

          正常     0.9220    0.9718    0.9462       779
          讽刺     0.5926    0.3333    0.4267        96

    accuracy                         0.9017       875
   macro avg     0.7573    0.6525    0.6865       875
weighted avg     0.8859    0.9017    0.8892       875

tensor([0.0266]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 122.70it/s]


              precision    recall  f1-score   support

          正向     0.9028    0.9241    0.9133       593
          负向     0.9241    0.9028    0.9133       607

    accuracy                         0.9133      1200
   macro avg     0.9135    0.9135    0.9133      1200
weighted avg     0.9136    0.9133    0.9133      1200

tensor([0.0361]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:19<00:00, 126.84it/s]


              precision    recall  f1-score   support

          正向     0.8196    0.8285    0.8240      1201
          负向     0.8433    0.8351    0.8392      1328

    accuracy                         0.8319      2529
   macro avg     0.8315    0.8318    0.8316      2529
weighted avg     0.8321    0.8319    0.8320      2529

tensor([0.0217]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 126.04it/s]


              precision    recall  f1-score   support

          正向     0.8630    0.8446    0.8537      1126
          负向     0.8995    0.9121    0.9058      1718

    accuracy                         0.8854      2844
   macro avg     0.8813    0.8783    0.8797      2844
weighted avg     0.8851    0.8854    0.8852      2844

tensor([0.0128]) ==ece==


In [72]:
# Evaluate checkpoint multitask_cls.pth.9 of the raw_all + focal-loss variant
# (the "_256" suffix of the run directory).
# NOTE(review): in this run the offensive eval set has 6417 examples (vs 5304
# in most other runs), so its offensive scores are not directly comparable.
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_raw_all_focal_256/multitask_cls.pth.9'
)


===offensive===


100%|██████████| 6417/6417 [00:51<00:00, 125.02it/s]


              precision    recall  f1-score   support

          冒犯     0.9075    0.9261    0.9167      3208
          正常     0.9246    0.9056    0.9150      3209

    accuracy                         0.9158      6417
   macro avg     0.9160    0.9159    0.9158      6417
weighted avg     0.9160    0.9158    0.9158      6417

tensor([0.0770]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:25<00:00, 109.98it/s]


              precision    recall  f1-score   support

          偏见     0.5963    0.4958    0.5414       718
          正常     0.8378    0.8858    0.8612      2111

    accuracy                         0.7869      2829
   macro avg     0.7171    0.6908    0.7013      2829
weighted avg     0.7765    0.7869    0.7800      2829

tensor([0.1125]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 126.53it/s]


              precision    recall  f1-score   support

          正常     0.9221    0.9576    0.9395       779
          讽刺     0.5000    0.3438    0.4074        96

    accuracy                         0.8903       875
   macro avg     0.7111    0.6507    0.6735       875
weighted avg     0.8758    0.8903    0.8812       875

tensor([0.1056]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 122.87it/s]


              precision    recall  f1-score   support

          正向     0.8992    0.9174    0.9082       593
          负向     0.9176    0.8995    0.9085       607

    accuracy                         0.9083      1200
   macro avg     0.9084    0.9084    0.9083      1200
weighted avg     0.9085    0.9083    0.9083      1200

tensor([0.1026]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:19<00:00, 127.02it/s]


              precision    recall  f1-score   support

          正向     0.8167    0.8127    0.8147      1201
          负向     0.8313    0.8351    0.8332      1328

    accuracy                         0.8244      2529
   macro avg     0.8240    0.8239    0.8239      2529
weighted avg     0.8244    0.8244    0.8244      2529

tensor([0.1432]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 126.06it/s]


              precision    recall  f1-score   support

          正向     0.8662    0.8393    0.8525      1126
          负向     0.8967    0.9150    0.9058      1718

    accuracy                         0.8850      2844
   macro avg     0.8815    0.8771    0.8791      2844
weighted avg     0.8846    0.8850    0.8847      2844

tensor([0.1520]) ==ece==


In [75]:
# Evaluate checkpoint multitask_cls.pth.9 of the contrast + intent_v1 variant.
evaluation(
    '/data/albert.xht/xiaodao/risk_classification/'
    'multitask_contrast_intent_v1/multitask_cls.pth.9'
)


===offensive===


100%|██████████| 5304/5304 [00:42<00:00, 125.36it/s]


              precision    recall  f1-score   support

          冒犯     0.7202    0.8371    0.7743      2106
          正常     0.8799    0.7858    0.8302      3198

    accuracy                         0.8062      5304
   macro avg     0.8000    0.8115    0.8022      5304
weighted avg     0.8165    0.8062    0.8080      5304

tensor([0.0912]) ==ece==
===cdia-bias===


100%|██████████| 2829/2829 [00:22<00:00, 125.30it/s]


              precision    recall  f1-score   support

          偏见     0.6115    0.3245    0.4240       718
          正常     0.8019    0.9299    0.8612      2111

    accuracy                         0.7762      2829
   macro avg     0.7067    0.6272    0.6426      2829
weighted avg     0.7536    0.7762    0.7502      2829

tensor([0.0236]) ==ece==
===ciron===


100%|██████████| 875/875 [00:06<00:00, 126.19it/s]


              precision    recall  f1-score   support

          正常     0.8976    0.9795    0.9368       779
          讽刺     0.3600    0.0938    0.1488        96

    accuracy                         0.8823       875
   macro avg     0.6288    0.5366    0.5428       875
weighted avg     0.8387    0.8823    0.8503       875

tensor([0.0214]) ==ece==
===chsenti===


100%|██████████| 1200/1200 [00:09<00:00, 121.94it/s]


              precision    recall  f1-score   support

          正向     0.8849    0.8820    0.8834       593
          负向     0.8851    0.8880    0.8865       607

    accuracy                         0.8850      1200
   macro avg     0.8850    0.8850    0.8850      1200
weighted avg     0.8850    0.8850    0.8850      1200

tensor([0.0228]) ==ece==
===senti_smpecisa===


100%|██████████| 2529/2529 [00:19<00:00, 126.66it/s]


              precision    recall  f1-score   support

          正向     0.7964    0.8010    0.7987      1201
          负向     0.8191    0.8148    0.8169      1328

    accuracy                         0.8082      2529
   macro avg     0.8077    0.8079    0.8078      2529
weighted avg     0.8083    0.8082    0.8082      2529

tensor([0.0255]) ==ece==
===senti_smp===


100%|██████████| 2844/2844 [00:22<00:00, 126.16it/s]


              precision    recall  f1-score   support

          正向     0.8475    0.8144    0.8306      1126
          负向     0.8814    0.9040    0.8925      1718

    accuracy                         0.8685      2844
   macro avg     0.8644    0.8592    0.8616      2844
weighted avg     0.8680    0.8685    0.8680      2844

tensor([0.0368]) ==ece==


In [24]:
# Expected Calibration Error metric: the ECE module bins the max-class
# confidence into ECE_N_BINS equal-width bins and sums the per-bin
# |confidence - accuracy| gap, weighted by bin occupancy.
ECE_N_BINS = 15
ece_fn = ECE(n_bins=ECE_N_BINS)

In [45]:
# ECE on the offensive-detection predictions.
# NOTE(review): `gold` and `pred_score` are not defined in this cell; they come
# from whichever evaluation ran last (hidden kernel state). Re-run the matching
# offensive evaluation before this cell.
mapping_dict = {
    '冒犯':0,
    '正常':1
}

# Class names ordered by their index in mapping_dict, so column j of
# pred_score_l is the probability of the class whose gold index is j.
index_to_label = sorted(mapping_dict, key=mapping_dict.get)

gold_l = torch.tensor([mapping_dict[item] for item in gold])
# Look each probability up by label name instead of by position
# (item[0], item[1]). Otherwise a per-sample score list in a different order
# silently misaligns ECE's argmax predictions with gold_l.
# NOTE(review): assumes each pred_score item is a sequence of (label, score) pairs.
pred_score_l = torch.tensor(
    [[dict(item)[label] for label in index_to_label] for item in pred_score]
)

ece_fn(pred_score_l, gold_l, mode='probs')

tensor([0.1104])

In [43]:
# ECE on the sentiment predictions.
# NOTE(review): `gold` and `pred_score` are not defined in this cell; they come
# from whichever evaluation ran last (hidden kernel state). Re-run the matching
# sentiment evaluation before this cell.
mapping_dict = {
    '负向':0,
    '正向':1
}

# Class names ordered by their index in mapping_dict, so column j of
# pred_score_l is the probability of the class whose gold index is j.
index_to_label = sorted(mapping_dict, key=mapping_dict.get)

gold_l = torch.tensor([mapping_dict[item] for item in gold])
# Look each probability up by label name instead of by position
# (item[0], item[1]). Otherwise a per-sample score list in a different order
# silently misaligns ECE's argmax predictions with gold_l.
# NOTE(review): assumes each pred_score item is a sequence of (label, score) pairs.
pred_score_l = torch.tensor(
    [[dict(item)[label] for label in index_to_label] for item in pred_score]
)

ece_fn(pred_score_l, gold_l, mode='probs')

tensor([0.0198])