In [1]:
import json
import os
from openai import OpenAI
from transformers.utils.versions import require_version
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

def predict(port, json_pth, model="test"):
    """
    Evaluate a locally served chat model on a labelled JSON dataset.

    For every record, the instruction and input are concatenated into a single
    user message and sent to the OpenAI-compatible endpoint at
    http://localhost:<port>/v1. The stripped reply is compared with the
    record's expected output by exact string match.

    Parameters:
    port:
        Port of the local endpoint (interpolated into the base URL).
    json_pth: str
        Path to a UTF-8 JSON file holding a sequence of records, each with
        "instruction", "input" and "output" keys.
    model: str
        Model name passed to the chat-completions call. Defaults to "test".

    Returns:
    tuple:
        (accuracy, confusion_matrix). accuracy is correct / total, or 0.0 when
        the dataset is empty. confusion_matrix maps each label to a
        defaultdict with "TP", "FN" and "FP" counts: a miss adds one "FN" to
        the true label and one "FP" to the predicted label.
    """
    # Client for the local endpoint. The key is a fixed placeholder value
    # (presumably the local server does not validate it -- confirm).
    client = OpenAI(
        api_key="0",
        base_url="http://localhost:{}/v1".format(port)
    )

    # JSON is UTF-8; do not depend on the platform's default encoding
    with open(json_pth, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    # Running statistics
    true_label_count = defaultdict(int)
    predict_label_count = defaultdict(int)
    confusion_matrix = defaultdict(lambda: defaultdict(int))  # per-label TP / FN / FP
    total_count = 0
    correct_predictions = 0

    for data in datas:
        prompt_input = data['input']
        instruction = data['instruction']
        label = data['output']
        message = [{"role": "user",
                    "content": instruction + prompt_input}]

        # Request errors are deliberately not caught here: a dead endpoint
        # should stop the run instead of silently scoring every sample wrong.
        result = client.chat.completions.create(messages=message, model=model, temperature=0)
        try:
            predict_label = result.choices[0].message.content.strip()
        except Exception as e:
            print(f'error occur {e}')
            # Show the whole response; indexing into it again could raise a
            # second time from inside this handler.
            print(result)
            # Sentinel label so the sample is still counted (as a miss)
            predict_label = "-1"

        # Update statistics
        true_label_count[label] += 1
        predict_label_count[predict_label] += 1
        total_count += 1

        if label == predict_label:
            correct_predictions += 1
            confusion_matrix[label]["TP"] += 1  # True Positive
        else:
            confusion_matrix[label]["FN"] += 1  # False Negative
            confusion_matrix[predict_label]["FP"] += 1  # False Positive

    # Overall accuracy; an empty dataset yields 0.0 instead of ZeroDivisionError
    accuracy = correct_predictions / total_count if total_count > 0 else 0.0

    # Print the report
    print("数据总数:", total_count)
    print("真实标签分布:", dict(true_label_count))
    print("预测标签分布:", dict(predict_label_count))
    print("混淆矩阵:")
    for matrix_label, metrics in confusion_matrix.items():
        print(f"标签 {matrix_label}: {metrics}")
    print("Accuracy:", accuracy)

    return accuracy, confusion_matrix

# Score the model served on local port 8000 against the TFNS validation file
accuracy, confusion_matrix = predict(
    port=8000,
    json_pth='./data/finance_sentiment/tfns_validation.json',
)

数据总数: 2388
真实标签分布: {'negative': 347, 'positive': 475, 'neutral': 1566}
预测标签分布: {'neutral': 1539, 'negative': 366, 'positive': 483}
混淆矩阵:
标签 negative: defaultdict(<class 'int'>, {'FN': 40, 'TP': 307, 'FP': 59})
标签 neutral: defaultdict(<class 'int'>, {'FP': 84, 'TP': 1455, 'FN': 111})
标签 positive: defaultdict(<class 'int'>, {'TP': 424, 'FN': 51, 'FP': 59})
Accuracy: 0.9154103852596315


In [2]:
def calculate_overall_f1(confusion_matrix, average="micro"):
    """
    Compute overall Precision, Recall and F1 Score from per-label counts.

    Parameters:
    confusion_matrix: dict
        Maps each label to its counts, for example:
        {
            "negative": {"TP": 1318, "FN": 124, "FP": 126},
            "neutral": {"TP": 5910, "FN": 268, "FP": 223},
            "positive": {"TP": 1807, "FN": 116, "FP": 159},
        }
        A missing "TP", "FN" or "FP" key counts as 0.
    average: str
        "micro" (default): sum TP/FN/FP over all labels, then compute the
        scores once. When every sample has exactly one true and one predicted
        label, total FP equals total FN, so all three scores equal accuracy.
        "macro": compute the scores per label and take the unweighted mean,
        so small classes weigh as much as large ones. Every key of
        confusion_matrix counts as a label, including one that only ever
        appears as a prediction.

    Returns:
    dict:
        {"Precision": ..., "Recall": ..., "F1 Score": ...}. Any ratio whose
        denominator is zero is reported as 0.0.

    Raises:
    ValueError:
        If average is neither "micro" nor "macro".
    """
    if average not in ("micro", "macro"):
        raise ValueError(
            "average must be 'micro' or 'macro', got {!r}".format(average)
        )

    def _scores(tp, fn, fp):
        # Precision, recall and F1 for one set of counts, guarding zero denominators
        p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * p * r) / (p + r) if (p + r) > 0 else 0.0
        return p, r, f1

    # (TP, FN, FP) per label
    per_label = [
        (metrics.get("TP", 0), metrics.get("FN", 0), metrics.get("FP", 0))
        for metrics in confusion_matrix.values()
    ]

    # Totals across labels (printed in both modes)
    total_TP = sum(counts[0] for counts in per_label)
    total_FN = sum(counts[1] for counts in per_label)
    total_FP = sum(counts[2] for counts in per_label)
    print(f'total_TP: {total_TP}')
    print(f'total_FN: {total_FN}')
    print(f'total_FP: {total_FP}')

    if average == "micro":
        precision, recall, f1_score = _scores(total_TP, total_FN, total_FP)
    else:
        label_scores = [_scores(*counts) for counts in per_label]
        n_labels = len(label_scores)
        if n_labels > 0:
            precision = sum(s[0] for s in label_scores) / n_labels
            recall = sum(s[1] for s in label_scores) / n_labels
            f1_score = sum(s[2] for s in label_scores) / n_labels
        else:
            precision = recall = f1_score = 0.0

    return {
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1_score
    }


# Overall scores derived from the confusion matrix of the validation run
overall_scores = calculate_overall_f1(confusion_matrix)

# Report each metric on its own line, rounded to four decimal places
print(
    "整体 Precision: {:.4f}".format(overall_scores["Precision"]),
    "整体 Recall: {:.4f}".format(overall_scores["Recall"]),
    "整体 F1 Score: {:.4f}".format(overall_scores["F1 Score"]),
    sep="\n",
)


total_TP: 2186
total_FN: 202
total_FP: 202
整体 Precision: 0.9154
整体 Recall: 0.9154
整体 F1 Score: 0.9154


In [3]:
# Peek at the first validation record to see its instruction/input/output fields
json_path = r'data/finance_sentiment/tfns_validation.json'
# Read explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding.
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

print(data[0])

{'instruction': 'What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.', 'input': '$ALLY - Ally Financial pulls outlook https://t.co/G9Zdi1boy5', 'output': 'negative'}
