In [1]:
import json
import sys,os
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Put the project checkout on the import path so modules under this directory can be imported.
sys.path.extend(['/root/xiaoda/query_topic/'])

In [3]:
import torch
from torch.nn import functional as F
import numpy as np
import random
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

"""
https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py
"""

class ECE(nn.Module):
    """Expected Calibration Error computed over equal-width confidence bins.

    Every sample is assigned to a bin by the probability of its predicted class;
    the result is the sum over bins of |mean confidence - accuracy| weighted by
    the fraction of samples that fall in the bin.
    """

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECE, self).__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        # Bin i is the half-open interval (bin_lowers[i], bin_uppers[i]].
        self.bin_lowers = edges[:-1]
        self.bin_uppers = edges[1:]

    def forward(self, logits, labels, mode='logits'):
        """Return the ECE as a tensor of shape (1,) on the device of `logits`.

        logits: per-class scores, one row per sample. They are passed through a
            softmax over dim 1 when `mode == 'logits'`; for any other `mode`
            they are used as-is (i.e. treated as probabilities).
        labels: gold class indices, compared element-wise with the argmax.
        """
        probs = F.softmax(logits, dim=1) if mode == 'logits' else logits
        confidences, predictions = probs.max(dim=1)
        correct = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for idx in range(len(self.bin_lowers)):
            low = self.bin_lowers[idx].item()
            high = self.bin_uppers[idx].item()
            # Samples whose confidence lies in (low, high].
            members = confidences.gt(low) & confidences.le(high)
            weight = members.float().mean()
            # `not (... > 0)` also skips the NaN produced by an empty input.
            if not weight.item() > 0:
                continue
            bin_accuracy = correct[members].float().mean()
            bin_confidence = confidences[members].mean()
            ece += torch.abs(bin_confidence - bin_accuracy) * weight

        return ece

In [4]:
import torch
import json
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast
import transformers
from datetime import timedelta

import os, sys
cur_dir_path = '/root/xiaoda/query_topic/'

sys.path.extend([cur_dir_path])

from nets.them_classifier import MyBaseModel, RobertaClassifier

import configparser
from tqdm import tqdm

class TopicInfer(object):
    """Inference wrapper around a RoFormer encoder plus a topic classification head.

    The constructor builds the tokenizer, the label maps and the network from an
    INI config. No checkpoint is loaded there; call `reload` to load trained
    weights before predicting.
    """

    def __init__(self, config_path):
        """Build tokenizer, label vocabulary and network.

        Args:
            config_path: path of the INI file, joined onto the module-level
                `cur_dir_path`. The merged `[paths]` and `[para]` sections must
                provide `model_path` and `label_path`; `[para]` must provide
                `out_dropout_rate` and `dropout_type`.
        """
        con = configparser.ConfigParser()
        con_path = os.path.join(cur_dir_path, config_path)
        con.read(con_path, encoding='utf8')

        args_path = dict(dict(con.items('paths')), **dict(con.items("para")))
        self.tokenizer = BertTokenizerFast.from_pretrained(args_path["model_path"], do_lower_case=True)

        # One label per line; the line index is the class id. The file is read as
        # UTF-8 explicitly (like the config) instead of relying on the locale.
        label_path = os.path.join(cur_dir_path, args_path['label_path'])
        with open(label_path, 'r', encoding='utf8') as frobj:
            label_list = [line.strip() for line in frobj]
        n_classes = len(label_list)

        self.label2id = {label: idx for idx, label in enumerate(label_list)}
        self.id2label = {idx: label for idx, label in enumerate(label_list)}

        print(self.label2id, '===', self.id2label)

        from roformer import RoFormerModel, RoFormerConfig

        config = RoFormerConfig.from_pretrained(args_path["model_path"])
        encoder = RoFormerModel(config=config)

        encoder_net = MyBaseModel(encoder, config)

        classify_net = RobertaClassifier(
            hidden_size=config.hidden_size,
            dropout_prob=con.getfloat('para', 'out_dropout_rate'),
            num_labels=n_classes,
            dropout_type=con.get('para', 'dropout_type'))

        if torch.cuda.is_available():
            # Keep the original preference for the second GPU, but fall back to
            # the first one on single-GPU hosts where "cuda:1" does not exist.
            self.device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
        else:
            self.device = "cpu"

        class TopicClassifier(nn.Module):
            """Encoder followed by the classification head.

            The attribute names `transformer` and `classifier` are part of the
            state-dict keys expected by `reload`, so they must not be renamed.
            """

            def __init__(self, transformer, classifier):
                super().__init__()

                self.transformer = transformer
                self.classifier = classifier

            def forward(self, input_ids, input_mask,
                        segment_ids=None, transformer_mode='mean_pooling'):
                hidden_states = self.transformer(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              return_mode=transformer_mode)
                ce_logits = self.classifier(hidden_states)
                return ce_logits, hidden_states

        self.net = TopicClassifier(encoder_net, classify_net).to(self.device)
        # This class is inference-only: disable dropout even before `reload`.
        self.net.eval()

    def reload(self, model_path):
        """Load a checkpoint (a state dict for `self.net`) and switch to eval mode."""
        ckpt = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(ckpt)
        self.net.eval()

    def _ranked_labels(self, probs, top_n):
        """Turn one probability vector into `[label, probability]` pairs, best first."""
        pairs = [[self.id2label[index], float(score)] for index, score in enumerate(probs)]
        pairs.sort(key=lambda item: item[1], reverse=True)
        return pairs[:top_n]

    def _batch_probabilities(self, text_list):
        """Run the network on a list of texts and return the softmax matrix (numpy)."""
        model_input = self.tokenizer(text_list, return_tensors="pt", padding=True)
        model_input = {key: model_input[key].to(self.device) for key in model_input}
        with torch.no_grad():
            logits, _ = self.net(model_input['input_ids'],
                                 model_input['attention_mask'],
                                 model_input['token_type_ids'],
                                 transformer_mode='cls')
            return torch.nn.Softmax(dim=1)(logits).data.cpu().numpy()

    def predict(self, text, top_n=5):
        """Return the `top_n` most probable topics of a single text.

        Args:
            text: input string; it is truncated to 256 tokens.
            top_n: number of `[label, probability]` pairs to return (previously
                this argument was ignored and 5 pairs were always returned).

        Returns:
            A list of `[label, probability]` pairs sorted by descending probability.
        """
        # `truncation=True` makes the max_length truncation explicit.
        encoder_txt = self.tokenizer.encode_plus(text, max_length=256, truncation=True)
        input_ids = torch.tensor(encoder_txt["input_ids"]).long().unsqueeze(0).to(self.device)
        token_type_ids = torch.tensor(encoder_txt["token_type_ids"]).unsqueeze(0).to(self.device)
        attention_mask = torch.tensor(encoder_txt["attention_mask"]).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits, _ = self.net(input_ids, attention_mask, token_type_ids,
                                 transformer_mode='cls')
            scores = torch.nn.Softmax(dim=1)(logits)[0].data.cpu().numpy()

        return self._ranked_labels(scores, top_n)

    def predict_batch(self, text, top_n=5):
        """Return the `top_n` most probable topics for each text.

        Args:
            text: a string or a list of strings. Unlike `predict`, inputs are
                padded to the longest text and are not truncated.
            top_n: number of `[label, probability]` pairs kept per text.

        Returns:
            One ranked list of `[label, probability]` pairs per input text.
        """
        text_list = text if isinstance(text, list) else [text]
        if not text_list:
            return []
        scores = self._batch_probabilities(text_list)
        return [self._ranked_labels(score_, top_n) for score_ in scores]

    def infer_batch(self, text):
        """Return the full class-probability matrix for one text or a list of texts.

        Returns:
            A numpy array with one row per text and one column per class id
            (previously the probabilities were computed but never returned).
        """
        text_list = text if isinstance(text, list) else [text]
        return self._batch_probabilities(text_list)

# Build the tokenizer, label maps and network from the v5 config. The constructor loads no
# checkpoint, so `topic_api.reload(...)` must run before the predictions are meaningful.
topic_api = TopicInfer('./topic_data_v5/config.ini')

# text = '王二今天打车去了哪里，从哪里出发，到哪里了'
# print(topic_api.predict(text), text)


{'夜店': 0, '睡眠': 1, '编程': 2, '金融': 3, '抄袭': 4, '高铁': 5, '城市': 6, '性侵犯': 7, '网络安全': 8, '人物': 9, '移民': 10, '成长': 11, '经济': 12, '购房置业': 13, '音乐': 14, '养生': 15, '婚姻': 16, '时事政治': 17, '星座': 18, '常识': 19, '法律': 20, '死亡': 21, '性生活': 22, '公司': 23, '人际交往': 24, '游戏': 25, '国家': 26, '交通出行': 27, '中医': 28, '心理健康': 29, '女性': 30, '家居装修': 31, '美容/塑身': 32, 'LGBT': 33, '爱情': 34, '建筑': 35, '价值观': 36, '赚钱': 37, '资源共享': 38, '家庭关系': 39, '歧视': 40, '灾害意外': 41, '股票': 42, '风水': 43, '算命': 44, '思维': 45, '体验': 46, '交友': 47, '恐怖主义': 48, '冷知识': 49, '时尚': 50, '购物': 51, '育儿': 52, '男性': 53, '生活': 54, '色情': 55, '恋爱': 56, '汽车': 57, '创业投资': 58, '二手': 59, '电子商务': 60, '惊悚': 61, '军事': 62, '语言': 63, '娱乐': 64, '女权': 65, '道德伦理': 66, '市场营销': 67, '商业/理财': 68, '资本主义': 69, '相貌': 70, '房地产': 71, '教育/科学': 72, '暴力': 73, '美食/烹饪': 74, '动植物': 75, '爱国': 76, '神话': 77, '情商': 78, '小说': 79, '学习': 80, '电脑/网络': 81, '家电': 82, '旅游': 83, '明星': 84, '写作': 85, '宠物': 86, '期货': 87, '故事': 88, '审美': 89, '民族': 90, '战争': 91, '毒品': 92, '性骚扰': 93, '人类': 94, '广告

01/18/2023 16:20:54 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++


In [5]:
# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v4_label_smoothing/them/cls.pth.9')
# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v4/them/cls.pth.9')
# topic_api.reload('//data/albert.xht/xiaodao/topic_classification_v4/them_20epoch/cls.pth.19')


# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v4/them_knn_20epoch/cls.pth.9')
# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v5/them_knn_10epoch//cls.pth.9')

# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v5/normal_knn_20epoch/cls.pth.19')
# Active checkpoint: load this state dict into the classifier and switch it to eval mode.
# The commented-out calls around this line are alternative checkpoints.
topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v5/normal_knn_20epoch_logitclip_1.2/cls.pth.19')



# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v4/them_3epoch/cls.pth.2')

In [14]:
# Keep only the records whose first label is '风险' and attach the classifier's top-1
# [label, score] pair to each of them under 'topic'.
# Files are opened as UTF-8 explicitly because the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all.negative', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/xiaodao/query_risk_v10/biake_qa_web_text_zh_train.json.offensive.all', encoding='utf8') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        # Check the label first so no inference is spent on records that are dropped anyway.
        if content['label'][0] not in ['风险']:
            continue
        # predict_batch returns one ranked list per input; [0][0] is the best pair of the only input.
        topic = topic_api.predict_batch(content['text'])[0][0]
        content['topic'] = [topic]
        fwobj.write(json.dumps(content, ensure_ascii=False)+'\n')
    


In [7]:
# Sample short queries kept for manual probing.
# NOTE(review): the call at the end of this cell does not use `text_list`.
text_list = ['你能不能准时一点',
'你太严肃了我不喜欢你',
'学神是什么',
'是不是网络断开了',
'世界上有没有鬼呀',
'我老婆不开心了怎么办',
'你是个傻瓜']
# Classify one space-separated query; the ranked [label, score] list is the cell's output.
topic_api.predict_batch('为 什 么 大 家 在  上 这 么 喜 欢 刷')

[[['游戏', 0.9999998807907104],
  ['电脑/网络', 2.62903965264627e-09],
  ['公司', 1.719765108099125e-09],
  ['音乐', 1.701626950456614e-09],
  ['购物', 1.6975975070110394e-09]]]

In [147]:
# Problem setup
# NOTE(review): neither value is referenced by any other cell visible in this notebook;
# presumably left over from a calibration / coverage experiment — confirm or delete.
n = 1000 # number of calibration points
alpha = 0.2 # 1-alpha is the desired coverage

1 2


In [73]:
def _read_jsonl(path):
    """Load a JSON-lines file (one JSON object per line) into a list of objects.

    The file is decoded as UTF-8 explicitly rather than with the locale default.
    """
    with open(path, encoding='utf8') as frobj:
        return [json.loads(line.strip()) for line in frobj]


# v4 validation split, and the variant stored under the `.topic.knn.final` suffix.
v4_valid = _read_jsonl('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_valid.json')
v4_knn_valid = _read_jsonl('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_valid.json.topic.knn.final')


In [8]:
# Load the v5 validation split: the raw records plus parallel lists of texts and
# gold class ids (only the first gold label of each record is mapped to an id).
v5_valid = []
v5_text = []
v5_label = []
# Decode as UTF-8 explicitly; the labels and texts are Chinese.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v5/biake_qa_web_text_zh_valid.json', encoding='utf8') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        v5_valid.append(content)
        v5_text.append(content['text'])
        v5_label.append(topic_api.label2id[content['label'][0]])



In [17]:
from sklearn.metrics import classification_report
import numpy as np

def eval_all(data, model, top_n=5):
    """Evaluate `model` on labelled records and print top-1 / top-`top_n` quality.

    Args:
        data: iterable of dicts with a 'text' field (a string, or a list of strings
            which is joined with newlines) and a 'label' field (list of gold labels).
        model: object exposing `predict_batch(list_of_texts)` that returns, for each
            text, a list of `[label, score]` pairs.
        top_n: an example counts as a hit when any of its gold labels is among the
            `top_n` highest-scoring predictions. The ranking is limited to the pairs
            `predict_batch` returns, so a larger `top_n` has no further effect.

    Returns:
        `(pred_score, gold)`: the raw prediction of every example and the first gold
        label of every example, both in input order.

    Prints a classification report of the top-1 predictions and the top-`top_n` hit
    rate. On empty `data` it prints a notice instead of dividing by zero.
    """
    batch_size = 128
    gold = []
    pred_score = []
    total_pred = []
    hits = 0

    def _score_batch(items, texts):
        """Predict one batch, record raw results and top-1 labels, return the hit count."""
        batch_hits = 0
        for result, item in zip(model.predict_batch(texts), items):
            ranked = sorted(result, key=lambda u: u[1], reverse=True)
            total_pred.append(ranked[0][0])
            if set(item['label']) & {p[0] for p in ranked[:top_n]}:
                batch_hits += 1
            pred_score.append(result)
        return batch_hits

    queue, pending = [], []
    for item in tqdm(data):
        gold.append(item['label'][0])
        text = item['text']
        if isinstance(text, list):
            text = "\n".join(text)
        queue.append(text)
        pending.append(item)
        if len(queue) == batch_size:
            hits += _score_batch(pending, queue)
            queue, pending = [], []
    # Final partial batch.
    if queue:
        hits += _score_batch(pending, queue)

    if not pred_score:
        print('eval_all: no examples to evaluate')
        return pred_score, gold

    print(classification_report(gold, total_pred, digits=4), '===', top_n)
    print(hits/len(pred_score))
    return pred_score, gold

In [140]:
# Validation records whose first gold label is '道德伦理'.
target = [item for item in v4_valid if item['label'][0] in ['道德伦理']]

In [157]:
# Collect the records in `target` whose top-1 prediction is not one of their gold
# labels, paired with the full ranked prediction for inspection.
target_result = []
for item in target:
    result = topic_api.predict(item['text'])
    top_label = result[0][0]
    if top_label in item['label']:
        continue
    target_result.append((result, item))

In [108]:
# Evaluate on the v5 validation split; a hit means a gold label is within the top-2 predictions.
pred_score, gold = eval_all(v5_valid, topic_api, top_n=2)

100%|██████████| 83069/83069 [00:28<00:00, 2959.74it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        BDSM     0.0000    0.0000    0.0000        15
        LGBT     0.6034    0.6625    0.6316       317
         VPN     0.0000    0.0000    0.0000         5
          两性     0.1923    0.0079    0.0151       636
          中医     0.7321    0.5616    0.6357       146
          二手     0.0000    0.0000    0.0000        19
          交友     0.0000    0.0000    0.0000        25
        交通出行     0.6793    0.7231    0.7005       372
          人物     0.0000    0.0000    0.0000       218
          人类     0.0000    0.0000    0.0000        29
        人际交往     0.4295    0.3049    0.3566       669
         价值观     0.0000    0.0000    0.0000       303
       体育/运动     0.7583    0.8480    0.8006      2039
          体验     0.0000    0.0000    0.0000       122
          保险     0.7373    0.7500    0.7436       116
          健康     0.7543    0.8040    0.7783      5754
          公司     0.5736    0.5986    0.5859       872
          养生     0.0000    

  _warn_prf(average, modifier, msg_start, len(result))


In [10]:
# pred_score, gold = eval_all(v4_knn_valid, topic_api, top_n=5)

In [11]:
# pred_score, gold = eval_all(v4_valid, topic_api, top_n=2)

In [14]:
# Replace the 'label' field of every record with the classifier's ranked topic list
# (as returned by predict_batch), scoring BATCH_SIZE texts per forward pass.
BATCH_SIZE = 32


def _write_topic_batch(fwobj, batch, field):
    """Predict topics for `batch` ((text, record) pairs) and append the records to `fwobj`."""
    probs = topic_api.predict_batch([text for text, _ in batch])
    for prob_dict, (_, record) in zip(probs, batch):
        record[field] = prob_dict
        fwobj.write(json.dumps(record, ensure_ascii=False)+'\n')


# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive.topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/xiaodao/topic_classification_v7/biake_qa_web_text_zh_train.json.positive', 'r', encoding='utf8') as frobj:
    batch = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        batch.append((content['text'], content))
        if len(batch) == BATCH_SIZE:
            _write_topic_batch(fwobj, batch, 'label')
            batch = []
    # Final partial batch.
    if batch:
        _write_topic_batch(fwobj, batch, 'label')

25507it [00:11, 2218.33it/s]


In [15]:
# Replace the 'label' field of every record with the classifier's ranked topic list
# (as returned by predict_batch), scoring BATCH_SIZE texts per forward pass.
BATCH_SIZE = 32


def _write_topic_batch(fwobj, batch, field):
    """Predict topics for `batch` ((text, record) pairs) and append the records to `fwobj`."""
    probs = topic_api.predict_batch([text for text, _ in batch])
    for prob_dict, (_, record) in zip(probs, batch):
        record[field] = prob_dict
        fwobj.write(json.dumps(record, ensure_ascii=False)+'\n')


# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json.topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/embed_linear_small_white.json', 'r', encoding='utf8') as frobj:
    batch = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        batch.append((content['text'], content))
        if len(batch) == BATCH_SIZE:
            _write_topic_batch(fwobj, batch, 'label')
            batch = []
    # Final partial batch.
    if batch:
        _write_topic_batch(fwobj, batch, 'label')

23116it [00:10, 2115.69it/s]


In [69]:
# Add a 'topic' field holding the classifier's ranked topic list (as returned by
# predict_batch) to every record, scoring BATCH_SIZE texts per forward pass.
BATCH_SIZE = 32


def _write_topic_batch(fwobj, batch, field):
    """Predict topics for `batch` ((text, record) pairs) and append the records to `fwobj`."""
    probs = topic_api.predict_batch([text for text, _ in batch])
    for prob_dict, (_, record) in zip(probs, batch):
        record[field] = prob_dict
        fwobj.write(json.dumps(record, ensure_ascii=False)+'\n')


# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk_topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk', 'r', encoding='utf8') as frobj:
    batch = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        batch.append((content['text'], content))
        if len(batch) == BATCH_SIZE:
            _write_topic_batch(fwobj, batch, 'topic')
            batch = []
    # Final partial batch.
    if batch:
        _write_topic_batch(fwobj, batch, 'topic')

24721it [00:13, 1896.10it/s]


In [19]:
# Add a 'new_topic' field holding the classifier's ranked topic list (as returned by
# predict_batch) to every v4 training record, scoring BATCH_SIZE texts per forward pass.
BATCH_SIZE = 32


def _write_topic_batch(fwobj, batch, field):
    """Predict topics for `batch` ((text, record) pairs) and append the records to `fwobj`."""
    probs = topic_api.predict_batch([text for text, _ in batch])
    for prob_dict, (_, record) in zip(probs, batch):
        record[field] = prob_dict
        fwobj.write(json.dumps(record, ensure_ascii=False)+'\n')


# NOTE(review): the output name contains a comma ("train.json,new_topic"); this looks like a
# typo for ".new_topic" but is kept verbatim because other code may read this exact path.
# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json,new_topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_train.json', 'r', encoding='utf8') as frobj:
    batch = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        batch.append((content['text'], content))
        if len(batch) == BATCH_SIZE:
            _write_topic_batch(fwobj, batch, 'new_topic')
            batch = []
    # Final partial batch.
    if batch:
        _write_topic_batch(fwobj, batch, 'new_topic')

2071401it [16:00, 2157.32it/s]


In [8]:
# Add a 'topic' field holding the classifier's ranked topic list (as returned by
# predict_batch) to every record whose text is at most MAX_CHARS characters long;
# longer records are skipped and do not appear in the output.
BATCH_SIZE = 32
MAX_CHARS = 192


def _write_topic_batch(fwobj, batch, field):
    """Predict topics for `batch` ((text, record) pairs) and append the records to `fwobj`."""
    probs = topic_api.predict_batch([text for text, _ in batch])
    for prob_dict, (_, record) in zip(probs, batch):
        record[field] = prob_dict
        fwobj.write(json.dumps(record, ensure_ascii=False)+'\n')


# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/sentiment/green_teenager.json.all.detail.topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/sentiment/green_teenager.json.all.detail', 'r', encoding='utf8') as frobj:
    batch = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        if len(content['text']) > MAX_CHARS:
            continue
        batch.append((content['text'], content))
        if len(batch) == BATCH_SIZE:
            _write_topic_batch(fwobj, batch, 'topic')
            batch = []
    # Final partial batch.
    if batch:
        _write_topic_batch(fwobj, batch, 'topic')

241035it [02:58, 1348.68it/s]


In [12]:

import re
# Classify the 'title' of every efaqa record (titles longer than MAX_CHARS characters are
# skipped) and write one {'text', 'topic', 'source'} object per kept title.
BATCH_SIZE = 32
MAX_CHARS = 192


def _write_efaqa_batch(fwobj, titles):
    """Predict topics for `titles` and append one JSON line per title to `fwobj`."""
    for prob_dict, text in zip(topic_api.predict_batch(titles), titles):
        tmp = {
            'text':text,
            'topic':prob_dict,
            'source':'efaqa'
        }
        fwobj.write(json.dumps(tmp, ensure_ascii=False)+'\n')


# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/pretrained_model_risk/corpus/efaqa-corpus-zh/efaqa-corpus-zh.utf8.topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/pretrained_model_risk/corpus/efaqa-corpus-zh/efaqa-corpus-zh.utf8', 'r', encoding='utf8') as frobj:
    queue = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        title = content['title']
        if len(title) > MAX_CHARS:
            continue
        queue.append(title)
        if len(queue) == BATCH_SIZE:
            _write_efaqa_batch(fwobj, queue)
            queue = []
    # Final partial batch.
    if queue:
        _write_efaqa_batch(fwobj, queue)

20000it [00:11, 1762.52it/s]


In [14]:

import re
# Classify the 'text' of every labelled record (texts longer than MAX_CHARS characters are
# skipped) and write one {'text', 'topic', 'label', 'source'} object per kept record.
BATCH_SIZE = 32
MAX_CHARS = 192


def _write_labeled_batch(fwobj, batch):
    """Predict topics for `batch` ((text, record) pairs) and append one JSON line per pair."""
    probs = topic_api.predict_batch([text for text, _ in batch])
    for prob_dict, (text, tt) in zip(probs, batch):
        tmp = {
            'text':text,
            'topic':prob_dict,
            'label':tt['label'],
            # NOTE(review): 'efaqa' looks copied from the efaqa cell although this input is
            # offensive_select_labeled.txt; kept verbatim — confirm the intended source tag.
            'source':'efaqa'
        }
        fwobj.write(json.dumps(tmp, ensure_ascii=False)+'\n')


# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/xiaodao/query_risk_v11/offensive_select_labeled.txt.topic', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/xiaodao/query_risk_v11/offensive_select_labeled.txt', 'r', encoding='utf8') as frobj:
    batch = []
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        title = content['text']
        if len(title) > MAX_CHARS:
            continue
        batch.append((title, content))
        if len(batch) == BATCH_SIZE:
            _write_labeled_batch(fwobj, batch)
            batch = []
    # Final partial batch.
    if batch:
        _write_labeled_batch(fwobj, batch)

20641it [00:10, 1970.45it/s]


In [114]:

# Pair every candidate query (one per line) with its ranked topic prediction.
# NOTE(review): the original cell appended to `result_list` and read `result` without
# defining either, so it only ran on leftover kernel state and paired every query with
# the same stale prediction. This assumes the intent was one prediction per query — confirm.
result_list = []
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/疑似有风险query_from对话预训练数据-20221130.txt', encoding='utf8') as frobj:
    for line in tqdm(frobj):
        text = line.strip()
        result = topic_api.predict(text)
        result_list.append((text, result))

# Split the topic-annotated risk corpus into `topic_filter` and `topic_white`.
# A record goes to `topic_white` only when all three conditions hold:
#   1. content['score_list']['senti'][1][1] is above SENTI_THRESHOLD
#      (NOTE(review): presumably the score of the second sentiment class — confirm),
#   2. its top-1 topic is not one of SENSITIVE_TOPICS,
#   3. its text mentions none of RISK_TERMS.
# Every other record goes to `topic_filter`.
SENTI_THRESHOLD = 0.7
SENSITIVE_TOPICS = ['经济', '历史', 'LGBT', 'BDSM', '法律', '灵异灵修', '国家', '社会', '军事', '心理健康', '时事政治', '死亡', '毒品', '恐怖主义', '战争', '灵异事件']
RISK_TERMS = ('艾滋病', '抑郁')

topic_filter = []
topic_white = []
left = []  # only referenced by the commented-out rule set that follows this cell's code
# Decode as UTF-8 explicitly; the corpus is Chinese text.
with open('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.risk_topic', encoding='utf8') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        whitelisted = (
            content['score_list']['senti'][1][1] > SENTI_THRESHOLD
            and content['topic'][0][0] not in SENSITIVE_TOPICS
            and not any(term in content['text'] for term in RISK_TERMS)
        )
        if whitelisted:
            topic_white.append(content)
        else:
            topic_filter.append(content)
        # if content['topic'][0][0] in ['文化/艺术', '动漫', '音乐', '娱乐', '体育/运动', '小说', '星座', '时尚', '健康', '明星', '游戏'] and content['topic'][0][1] > 0.2:
        #     if content['score_list']['senti'][0][1] < 0.3:
        #         topic_white.append(content)
        #     elif content['score_list']['senti'][0][1] > 0.9:
        #         topic_filter.append(content)
        # elif content['topic'][0][0] in ['LGBT', 'BDSM', '法律', '灵异灵修', '国家', '社会', '军事', '心理健康', '时事政治', '死亡', '毒品', '恐怖主义', '战争', '灵异事件']:
        #     topic_filter.append(content)
        # elif content['topic'][1][0] in ['LGBT', 'BDSM', '法律', '灵异灵修', '国家', '社会', '军事', '心理健康', '时事政治', '死亡', '毒品', '恐怖主义', '战争', '灵异事件']:
        #     topic_filter.append(content)
        # else:
        #     left.append(content)
                

In [118]:
def _dump_with_label(path, records, label):
    """Set every record's 'label' to `[label]` (in place) and write the records as JSON lines.

    The file is written as UTF-8 explicitly because ensure_ascii=False emits raw Chinese text.
    """
    with open(path, 'w', encoding='utf8') as fwobj:
        for d in records:
            d['label'] = [label]
            fwobj.write(json.dumps(d, ensure_ascii=False)+'\n')


# Records kept by the topic filter are labelled '风险'; whitelisted ones are labelled '正常'.
_dump_with_label('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.filter', topic_filter, '风险')
_dump_with_label('/data/albert.xht/raw_chat_corpus/model_risk_xiaoda/query_risk_corpus.json.white', topic_white, '正常')

In [1]:
import os, sys

# Put the project checkout on the import path so modules under this directory can be imported.
sys.path.extend(['/root/xiaoda/query_topic/'])

In [21]:
from keyword_processor import KeywordProcesser
# Load the risk keywords; the file stores them separated by '||'.
# Every line contributes keywords and blank entries are dropped. Previously only the last
# line of the file was used, and an empty file fell through to an undefined (or stale)
# `content` variable.
keyword_list = []
with open('/data/albert.xht/xiaoda/sentiment/query_risk_v14//risk_event.txt', encoding='utf8') as frobj:
    for line in frobj:
        keyword_list.extend(word for word in line.strip().split('||') if word)

keyword_api_v1 = KeywordProcesser(keywords=keyword_list)

In [22]:
import jieba_fast as jieba
# Register every risk keyword in jieba's dictionary so segmentation can emit it as one token.
for word in keyword_list:
    jieba.add_word(word)
    
from tqdm import tqdm
import json, re

# Relabel the training records that contain a risk keyword as a standalone jieba token:
#   - original first label in NORMAL_TOPICS  -> label ['正常']
#   - original first label in SKIPPED_TOPICS -> record is not written at all
#   - any other original label               -> label ['风险']
# The original labels are preserved under 'topic'. Records without a matching keyword
# are not written.
NORMAL_TOPICS = ['游戏', '小说', '网络安全', '动漫', '电子数码']
SKIPPED_TOPICS = ['幽默滑稽']

# UTF-8 is set explicitly: the JSON is written with ensure_ascii=False.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v5/biake_qa_web_text_zh_train.json.knn.final.keyword', 'w', encoding='utf8') as fwobj, \
        open('/data/albert.xht/raw_chat_corpus/topic_classification_v5/biake_qa_web_text_zh_train.json.knn.final', encoding='utf8') as frobj:
    for line in tqdm(frobj):
        content = json.loads(line.strip())
        # Lower-case and replace every run of characters other than CJK, digits, ASCII
        # letters and spaces with a newline.
        text = re.sub(u'[^\u4e00-\u9fa50-9a-zA-Z ]+', '\n', content['text'].lower())
        content['keywords'] = keyword_api_v1.extract_keywords(text)
        words_set = set(jieba.cut(text))
        # Keep only hits whose last element is also a jieba token
        # (presumably the matched keyword string — verify against KeywordProcesser).
        # A dedicated name is used so the global `keyword_list` is not overwritten,
        # which previously made this cell unsafe to re-run.
        matched_keywords = [keyword_ for keyword_ in content['keywords'] if keyword_[-1] in words_set]
        if not matched_keywords:
            continue
        first_label = content['label'][0]
        if first_label in SKIPPED_TOPICS:
            continue
        content['topic'] = content['label']
        content['label'] = ['正常'] if first_label in NORMAL_TOPICS else ['风险']
        fwobj.write(json.dumps(content, ensure_ascii=False)+'\n')

1973027it [03:03, 10726.40it/s]


In [7]:
# Manual probe of jieba's segmentation on mixed Latin/Chinese text; the token list is the cell's output.
list(jieba.cut('javaweb程序员是否适合用MacBook'))

['j', 'av', 'aweb', '程序员', '是否', '适合', '用', 'MacBook']