In [1]:
import json
import sys,os
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to project .py files are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Add the project directory to the module search path so that project-local
# packages can be imported from this notebook.
sys.path.extend(['/root/xiaoda/query_topic/'])

In [3]:
import torch
from torch.nn import functional as F
import numpy as np
import random
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

"""
https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py
"""

class ECE(nn.Module):
    """Expected Calibration Error over equal-width confidence bins.

    Predictions are grouped by their top-class confidence into ``n_bins``
    half-open intervals ``(lower, upper]`` spanning [0, 1]. The result is the
    sum over non-empty bins of ``|mean confidence - accuracy|`` weighted by the
    fraction of samples falling in the bin.

    Reference implementation (per the note in this cell):
    https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py
    """

    def __init__(self, n_bins=15):
        """
        Args:
            n_bins (int): number of confidence interval bins.
        """
        super(ECE, self).__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        # Consecutive edges form the (lower, upper] pair of each bin.
        self.bin_lowers = edges[:-1]
        self.bin_uppers = edges[1:]

    def forward(self, logits, labels, mode='logits'):
        """Compute the calibration error.

        Args:
            logits: 2-D tensor of per-class scores; softmax is applied over
                dim 1 when ``mode == 'logits'``, otherwise the values are used
                as-is (i.e. they are treated as already-normalised scores).
            labels: tensor of gold class indices compared element-wise with
                the arg-max predictions.
            mode: ``'logits'`` to apply softmax, anything else to skip it.

        Returns:
            A tensor of shape ``(1,)`` on ``logits.device``.
        """
        probs = F.softmax(logits, dim=1) if mode == 'logits' else logits
        confidences, predictions = probs.max(dim=1)
        hits = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for lower, upper in zip(self.bin_lowers.tolist(), self.bin_uppers.tolist()):
            # Samples whose confidence lies in (lower, upper].
            mask = (confidences > lower) & (confidences <= upper)
            weight = mask.float().mean()
            if not (weight.item() > 0):
                # Empty bin (or empty input, where the mean is NaN): no contribution.
                continue
            bin_confidence = confidences[mask].mean()
            bin_accuracy = hits[mask].float().mean()
            ece += torch.abs(bin_confidence - bin_accuracy) * weight

        return ece

In [4]:
import torch
import json
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast
import transformers
from datetime import timedelta

import os, sys
# Root of the project checkout. TopicInfer joins the config, label and output
# paths onto this directory, and it is also added to the import path below.
cur_dir_path = '/root/xiaoda/query_topic/'

sys.path.extend([cur_dir_path])

from nets.them_classifier import MyBaseModel, RobertaClassifier

import configparser
from tqdm import tqdm

class TopicInfer(object):
    """Inference wrapper around a RoFormer encoder plus a classification head.

    The constructor only builds the network from its config (no trained
    weights are loaded); call :meth:`reload` with a checkpoint path before
    trusting the output of :meth:`predict` / :meth:`predict_batch`.
    """

    def __init__(self, config_path):
        """Build tokenizer, label maps and the (untrained) network.

        Args:
            config_path: path of the ``.ini`` file, relative to the module-level
                ``cur_dir_path``. It must provide ``[paths]`` and ``[para]``
                sections containing at least ``model_path``, ``label_path``,
                ``output_path``, ``out_dropout_rate`` and ``dropout_type``.
        """
        con = configparser.ConfigParser()
        con_path = os.path.join(cur_dir_path, config_path)
        con.read(con_path, encoding='utf8')

        # Merge both sections into one flat lookup; "para" wins on key clashes.
        args_path = dict(dict(con.items('paths')), **dict(con.items("para")))
        self.tokenizer = BertTokenizerFast.from_pretrained(args_path["model_path"], do_lower_case=True)

        # One label per line; the line index is the class id. Read as UTF-8
        # (same as the config) rather than relying on the locale default.
        label_path = os.path.join(cur_dir_path, args_path['label_path'])
        with open(label_path, 'r', encoding='utf8') as frobj:
            label_list = [line.strip() for line in frobj]
        n_classes = len(label_list)

        self.label2id = {label: idx for idx, label in enumerate(label_list)}
        self.id2label = {idx: label for idx, label in enumerate(label_list)}

        print(self.label2id, '===', self.id2label)

        # Directory configured for training outputs; kept for locating
        # checkpoints to pass to reload().
        self.output_path = os.path.join(cur_dir_path, args_path['output_path'])

        from roformer import RoFormerModel, RoFormerConfig

        config = RoFormerConfig.from_pretrained(args_path["model_path"])
        # Built from the config only: parameters are freshly initialised here.
        encoder = RoFormerModel(config=config)

        encoder_net = MyBaseModel(encoder, config)

        classify_net = RobertaClassifier(
            hidden_size=config.hidden_size,
            dropout_prob=con.getfloat('para', 'out_dropout_rate'),
            num_labels=n_classes,
            dropout_type=con.get('para', 'dropout_type'))

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        class TopicClassifier(nn.Module):
            """Chains the encoder wrapper and the classification head."""

            def __init__(self, transformer, classifier):
                super().__init__()

                self.transformer = transformer
                self.classifier = classifier

            def forward(self, input_ids, input_mask,
                        segment_ids=None, transformer_mode='mean_pooling'):
                """Return ``(class logits, encoder output)`` for a batch."""
                hidden_states = self.transformer(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              return_mode=transformer_mode)
                ce_logits = self.classifier(hidden_states)
                return ce_logits, hidden_states

        self.net = TopicClassifier(encoder_net, classify_net).to(self.device)
        # This class is inference-only: disable dropout even before reload().
        self.net.eval()

    def reload(self, model_path):
        """Load a state-dict checkpoint into the network and switch to eval mode.

        Args:
            model_path: path to a file readable by ``torch.load`` that contains
                a state dict matching ``self.net``.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted locations.
        ckpt = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(ckpt)
        self.net.eval()

    def _rank_labels(self, probs, top_n):
        """Pair each class probability with its label and keep the ``top_n`` best.

        Args:
            probs: iterable of per-class probabilities, indexed by class id.
            top_n: number of entries to keep (``None`` keeps all).

        Returns:
            List of ``[label, probability]`` sorted by descending probability.
        """
        ranked = sorted(
            ([self.id2label[index], float(score)] for index, score in enumerate(probs)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:top_n]

    def predict(self, text, top_n=5, max_length=256):
        """Classify a single text.

        Args:
            text: the input string.
            top_n: how many of the highest-scoring labels to return.
            max_length: token limit; longer inputs are truncated.

        Returns:
            List of ``[label, probability]`` pairs, best first, at most
            ``top_n`` long.
        """
        # Tokenize once, with truncation requested explicitly.
        encoded = self.tokenizer.encode_plus(text, max_length=max_length, truncation=True)
        input_ids = torch.tensor(encoded["input_ids"]).long().unsqueeze(0).to(self.device)
        token_type_ids = torch.tensor(encoded["token_type_ids"]).unsqueeze(0).to(self.device)
        attention_mask = torch.tensor(encoded["attention_mask"]).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits, _ = self.net(input_ids, attention_mask, token_type_ids,
                                 transformer_mode='cls')
            probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

        return self._rank_labels(probs, top_n)

    def predict_batch(self, text, top_n=5, max_length=256):
        """Classify one or several texts in a single forward pass.

        Args:
            text: a string or a list of strings.
            top_n: how many of the highest-scoring labels to return per text.
            max_length: token limit; longer inputs are truncated.

        Returns:
            A list with one entry per input text (a single string yields a
            one-element list); each entry is a list of ``[label, probability]``
            pairs, best first, at most ``top_n`` long.
        """
        text_list = text if isinstance(text, list) else [text]
        if not text_list:
            return []

        model_input = self.tokenizer(text_list, return_tensors="pt", padding=True,
                                     truncation=True, max_length=max_length)
        input_ids = model_input["input_ids"].to(self.device)
        attention_mask = model_input["attention_mask"].to(self.device)
        token_type_ids = model_input["token_type_ids"].to(self.device)

        with torch.no_grad():
            logits, _ = self.net(input_ids, attention_mask, token_type_ids,
                                 transformer_mode='cls')
            batch_probs = torch.softmax(logits, dim=1).cpu().tolist()

        return [self._rank_labels(probs, top_n) for probs in batch_probs]

# Build the inference wrapper from the v4 topic config (the path is resolved
# relative to cur_dir_path inside TopicInfer). This only constructs the
# network; checkpoint weights are loaded by a separate reload() call.
topic_api = TopicInfer('./topic_data_v4/config.ini')

# text = '王二今天打车去了哪里，从哪里出发，到哪里了'
# print(topic_api.predict(text), text)


{'基金': 0, '汽车': 1, '网络安全': 2, '情商': 3, '财务税务': 4, '旅游': 5, '建筑': 6, '炫富': 7, '家电': 8, '商业': 9, '性生活': 10, '风水': 11, '人物': 12, 'VPN': 13, '吸烟': 14, '市场营销': 15, '游戏': 16, '算命': 17, '编程': 18, '死亡': 19, '公司': 20, '国家': 21, '美食/烹饪': 22, '明星': 23, '城市': 24, '银行': 25, '期货': 26, '宗教': 27, '学习': 28, '电子数码': 29, '网络暴力': 30, 'LGBT': 31, '其他': 32, '故事': 33, '社会': 34, '二手': 35, '动漫': 36, '歧视': 37, '常识': 38, '星座': 39, '冷知识': 40, '职场职业': 41, '食品': 42, '心理健康': 43, '电子商务': 44, '道德伦理': 45, '商业/理财': 46, '赚钱': 47, '神话': 48, '校园生活': 49, '色情': 50, '婚姻': 51, '家居装修': 52, '生活': 53, '灵异灵修': 54, '股票': 55, '娱乐': 56, '女性': 57, '体验': 58, '广告': 59, '天气': 60, '女权': 61, '潜规则': 62, '人类': 63, '马克思主义': 64, '历史': 65, '音乐': 66, '毒品': 67, '摄影': 68, '金融': 69, '影视': 70, '语言': 71, '环境': 72, '高铁': 73, '人际交往': 74, '夜店': 75, '价值观': 76, '恋爱': 77, '相貌': 78, 'BDSM': 79, '恐怖主义': 80, '中医': 81, '性侵犯': 82, '阅读': 83, '时尚': 84, '体育/运动': 85, '资本主义': 86, '灾害意外': 87, '博彩': 88, '成长': 89, '校园暴力': 90, '移民': 91, '美容/塑身': 92, '经济': 93, '睡眠': 94, 

12/25/2022 11:14:29 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++


In [5]:
# Load checkpoint `cls.pth.9` from the label-smoothing v4 run directory.
# The commented line below points at the same file name in the v4 directory
# without the `_label_smoothing` suffix; swap the two to compare runs.
topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v4_label_smoothing/them/cls.pth.9')
# topic_api.reload('/data/albert.xht/xiaodao/topic_classification_v4/them/cls.pth.9')

In [None]:
# Quick sanity check: ranked topic predictions for a single short query.
topic_api.predict('察抓小偷')

In [19]:
# Load the v4 validation split: one JSON object per line.
v4_valid = []
# Decode explicitly as UTF-8 instead of relying on the locale default.
with open('/data/albert.xht/raw_chat_corpus/topic_classification_v4/biake_qa_web_text_zh_valid.json', encoding='utf-8') as frobj:
    for line in frobj:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # failing inside json.loads.
            continue
        content = json.loads(line)
        v4_valid.append(content)
        

In [165]:
from sklearn.metrics import classification_report

def eval_all(data, model, top_n=5):
    """Measure the top-``top_n`` hit rate of ``model`` on ``data``.

    An item counts as a hit when any of its gold labels appears among the
    ``top_n`` highest-scoring predicted labels. The hit rate is printed.

    Args:
        data: iterable of dicts with a ``'label'`` list and a ``'text'`` that
            is either a string or a list of strings (joined with newlines).
        model: object whose ``predict(text)`` returns ``[label, score]`` pairs.
            ``top_n`` is effectively capped by how many pairs it returns.
        top_n: number of best-scoring predictions checked against the labels.

    Returns:
        ``(pred_score, gold)``: the raw ``predict`` result for every item and
        the first gold label of every item, both in input order.
    """
    gold = []
    pred_score = []
    hits = 0
    for item in tqdm(data):
        labels = item['label']
        gold.append(labels[0])

        text = item['text']
        if isinstance(text, list):
            text = "\n".join(text)

        result = model.predict(text)
        # Re-sort defensively; predict() is not required to return sorted pairs.
        ranked = sorted(result, key=lambda pair: pair[1], reverse=True)
        top_labels = {pair[0] for pair in ranked[:top_n]}
        if top_labels & set(labels):
            hits += 1
        pred_score.append(result)

    # Guard against empty input, which would otherwise divide by zero.
    accuracy = hits / len(pred_score) if pred_score else 0.0
    print(accuracy)
    return pred_score, gold

In [140]:
# Keep only the validation items whose first gold label is in the list below.
target = [sample for sample in v4_valid if sample['label'][0] in ['道德伦理']]

In [157]:
# Collect the target items whose top-1 prediction is not among their gold
# labels, together with the full ranked prediction for inspection.
target_result = []
for item in target:
    query_text = item['text']
    # Same handling as eval_all: list-valued texts are joined with newlines
    # before being passed to predict().
    if isinstance(query_text, list):
        query_text = "\n".join(query_text)
    result = topic_api.predict(query_text)
    if result[0][0] not in item['label']:
        target_result.append((result, item))

In [167]:
# Top-3 hit rate over the full validation set. Keeps the raw per-item
# predictions and the first gold label of each item for later inspection.
pred_score, gold = eval_all(v4_valid, topic_api, top_n=3)

100%|██████████| 86668/86668 [10:50<00:00, 133.29it/s]

0.7505307610652144





In [149]:
# Inspect the ranked predictions of the first ten validation items.
pred_score[0:10]

[[['城市', 0.9565373659133911],
  ['文化/艺术', 0.03183119744062424],
  ['教育/科学', 0.0017084553837776184],
  ['历史', 0.0014964144211262465],
  ['电脑/网络', 0.0005834586918354034]],
 [['恋爱', 0.3455376625061035],
  ['两性', 0.29532861709594727],
  ['情感', 0.2196192741394043],
  ['女性', 0.04598280042409897],
  ['爱情', 0.024114560335874557]],
 [['情感', 0.40412089228630066],
  ['恋爱', 0.258561909198761],
  ['爱情', 0.22933520376682281],
  ['两性', 0.03647314012050629],
  ['心理健康', 0.013390579260885715]],
 [['汽车', 0.9744734764099121],
  ['市场营销', 0.004104863852262497],
  ['时事政治', 0.002023685025051236],
  ['财务税务', 0.0017918514786288142],
  ['二手', 0.001525147003121674]],
 [['学习', 0.4129531979560852],
  ['心理健康', 0.3989134430885315],
  ['校园生活', 0.12909871339797974],
  ['教育/科学', 0.02070748619735241],
  ['价值观', 0.007254814729094505]],
 [['国家', 0.5084620714187622],
  ['历史', 0.3961915969848633],
  ['时事政治', 0.0633591040968895],
  ['军事', 0.012993273325264454],
  ['战争', 0.008572555147111416]],
 [['影视', 0.9434626698493958],
  