In [1]:
import json
import sys,os
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

sys.path.extend(['/root/xiaoda/query_topic/'])

In [3]:
import torch
from torch.nn import functional as F
import numpy as np
import random
import torch.nn as nn
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

"""
https://github.com/ondrejbohdal/meta-calibration/blob/main/Metrics/metrics.py
"""

class ECE(nn.Module):
    
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECE, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels, mode='logits'):
        if mode == 'logits':
            softmaxes = F.softmax(logits, dim=1)
        else:
            softmaxes = logits
        # softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
        
        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

In [10]:
import torch
import json
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast
import transformers
from datetime import timedelta

import os, sys
cur_dir_path = '/root/xiaoda/query_topic/'

sys.path.extend([cur_dir_path])

from nets.them_classifier import MyBaseModel, RobertaClassifier

import configparser
from tqdm import tqdm

class TopicInfer(object):
    def __init__(self, config_path):

        import torch, os, sys

        con = configparser.ConfigParser()
        con_path = os.path.join(cur_dir_path, config_path)
        con.read(con_path, encoding='utf8')

        args_path = dict(dict(con.items('paths')), **dict(con.items("para")))
        self.tokenizer = BertTokenizerFast.from_pretrained(args_path["model_path"], do_lower_case=True)

        label_list = []
        label_path = os.path.join(cur_dir_path, args_path['label_path'])
        with open(label_path, 'r') as frobj:
            for line in frobj:
                label_list.append(line.strip())
        n_classes = len(label_list)

        self.label2id, self.id2label = {}, {}
        for idx, label in enumerate(label_list):
            self.label2id[label] = idx
            self.id2label[idx] = label
            
        print(self.label2id, '===', self.id2label)
        
        output_path = os.path.join(cur_dir_path, args_path['output_path'])

        from roformer import RoFormerModel, RoFormerConfig

        config = RoFormerConfig.from_pretrained(args_path["model_path"])
        encoder = RoFormerModel(config=config)
        
        encoder_net = MyBaseModel(encoder, config)

        classify_net = RobertaClassifier(
            hidden_size=config.hidden_size, 
            dropout_prob=con.getfloat('para', 'out_dropout_rate'),
            num_labels=n_classes, 
            dropout_type=con.get('para', 'dropout_type'))

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        class TopicClassifier(nn.Module):
            def __init__(self, transformer, classifier):
                super().__init__()

                self.transformer = transformer
                self.classifier = classifier

            def forward(self, input_ids, input_mask, 
                        segment_ids=None, transformer_mode='mean_pooling'):
                hidden_states = self.transformer(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              return_mode=transformer_mode)
                ce_logits = self.classifier(hidden_states)
                return ce_logits, hidden_states

        import os
        self.net = TopicClassifier(encoder_net, classify_net).to(self.device)
        # eo = 9
        # ckpt = torch.load(os.path.join(output_path, 'cls.pth.{}'.format(eo)), map_location=self.device)
        # self.topic_net.load_state_dict(ckpt)
        # self.topic_net.eval()
        
    def reload(self, model_path):
        ckpt = torch.load(model_path, map_location=self.device)
        self.net.load_state_dict(ckpt)
        self.net.eval() 

    def predict(self, text, top_n=5):

        """抽取输入text所包含的类型
        """
        token2char_span_mapping = self.tokenizer(text, return_offsets_mapping=True, max_length=256)["offset_mapping"]
        encoder_txt = self.tokenizer.encode_plus(text, max_length=256)
        input_ids = torch.tensor(encoder_txt["input_ids"]).long().unsqueeze(0).to(self.device)
        token_type_ids = torch.tensor(encoder_txt["token_type_ids"]).unsqueeze(0).to(self.device)
        attention_mask = torch.tensor(encoder_txt["attention_mask"]).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            scores, hidden_states = self.net(input_ids, attention_mask, token_type_ids, 
                         transformer_mode='cls'
                         )
            scores = torch.nn.Softmax(dim=1)(scores)[0].data.cpu().numpy()
        
        schema_types = []
        for index, score in enumerate(scores):
             schema_types.append([self.id2label[index], float(score)])
        return schema_types

topic_api = TopicInfer('./risk_data/config_senti.ini')

# text = '王二今天打车去了哪里，从哪里出发，到哪里了'
# print(topic_api.predict(text), text)


{'负向': 0, '正向': 1} === {0: '负向', 1: '正向'}


12/18/2022 13:15:44 - INFO - nets.them_classifier - ++RobertaClassifier++ apply stable dropout++


In [6]:
senti_copr = []
with open('/data/albert.xht/sentiment/dev/senti_copr.json') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        senti_copr.append(content)
        
senti_smp = []
with open('/data/albert.xht/sentiment/dev/senti_smp_usual.json') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        senti_smp.append(content)
        
senti_smpecisa = []
with open('/data/albert.xht/sentiment/dev/senti_smpecisa.json') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        senti_smpecisa.append(content)

In [11]:
from sklearn.metrics import classification_report
from tqdm import tqdm

def eval_all(data, model):
    pred = []
    gold = []
    pred_score = []
    for item in tqdm(data):
        gold.append(item['label'][0])
        if isinstance(item['text'], list):
            text = "\n".join(item['text'])
        else:
            text = item['text']
        result = model.predict(text)
        score = sorted(result, key=lambda u:u[1], reverse=True)
        pred.append(score[0][0])
        pred_score.append(result)
    print(classification_report(gold, pred, digits=4))
    return pred, gold, pred_score
    


In [12]:
model_path = '/data/albert.xht/xiaodao/risk_classification/them_senti/cls.pth.9'
topic_api.reload(model_path)
print('===senti_copr===')
pred, gold, pred_score = eval_all(senti_copr, topic_api)

===senti_copr===


  0%|          | 0/1200 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  0%|          | 0/1200 [00:00<?, ?it/s]


AttributeError: 'TopicInfer' object has no attribute 'topic_net'