In [None]:
import os
import sys

# Append `<parent of the current working directory>/utils` to the module search
# path so the project's helper code can be imported from this notebook.
# NOTE(review): the import below is package-qualified (`utils.vector_db`), which
# resolves against a directory that *contains* `utils`, not against `utils`
# itself — confirm the parent directory is importable some other way.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath('')), 'utils'))
from utils.vector_db import BookkeepingVectorDB

# Single vector-database handle shared by the cells below.
vector_db = BookkeepingVectorDB()

### 根据数据集 DB 训练 SVM 模型，并进行评估

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score

# Pull every stored record (metadata, document text and embedding vector) out of
# the vector database's collection.
collection = vector_db.collection
results = collection.get(include=["metadatas", "documents", "embeddings"])
documents = results["documents"]
# The `type` field of each record's metadata is used as the class label.
categories = [metadata["type"] for metadata in results["metadatas"]]
embeddings = results["embeddings"]

print(f"从向量数据库中获取了 {len(documents)} 条记录")

# Hold out 20% of the records for evaluation. Stratify by label when there is
# more than one class so the split keeps the class proportions.
stratify_labels = categories if len(set(categories)) > 1 else None
try:
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, categories, test_size=0.2, random_state=42, stratify=stratify_labels
    )
except ValueError as split_error:
    # A stratified split is impossible when some class has a single sample or
    # when the test split is smaller than the number of classes. Fall back to a
    # plain random split instead of aborting the cell. Errors that are not
    # caused by stratification are re-raised unchanged.
    if stratify_labels is None:
        raise
    print(f"无法进行分层抽样，改用随机划分: {split_error}")
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, categories, test_size=0.2, random_state=42
    )

print(f"训练集大小: {len(X_train)}")
print(f"测试集大小: {len(X_test)}")

# Train a linear SVM. `probability=True` enables `predict_proba`; its internal
# calibration shuffles the data, so the seed is fixed to keep the fitted model
# reproducible across Restart & Run All.
print("开始训练SVM模型...")
svm_model = SVC(kernel='linear', probability=True, random_state=42)
svm_model.fit(X_train, y_train)

# Evaluate on the held-out split.
y_pred = svm_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.4f}")

# Per-class precision / recall / F1.
print("\n分类报告:")
print(classification_report(y_test, y_pred))

### 结合向量数据库和 SVM 模型进行混合预测

In [None]:
def hybrid_predict(merchant, product, threshold):
    """
    Predict a category for a transaction, accepting only confident answers.

    The vector database is queried first and its answer is returned unchanged
    when its confidence reaches ``threshold``. The SVM fallback is currently
    disabled (kept below as commented-out code), so a low-confidence lookup
    yields a result with no category.

    Args:
        merchant: Counterparty of the transaction.
        product: Product name.
        threshold: Minimum confidence required to accept a prediction.

    Returns:
        The vector database's prediction dict when it is confident enough,
        otherwise ``{"category": None, "confidence": <vector confidence>,
        "source": "none"}``.
    """
    db_result = vector_db.predict_category(merchant, product, threshold)
    db_confidence = db_result['confidence']

    # Disabled SVM fallback, kept for reference:
    # query = f"{merchant}:{product}"
    # query_embedding = vector_db.embed_fn([query])[0]
    #
    # svm_prediction = svm_model.predict([query_embedding])[0]
    # svm_proba = svm_model.predict_proba([query_embedding])
    # svm_confidence = max(svm_proba[0])
    #
    # if svm_confidence >= threshold:
    #     return {
    #         "category": svm_prediction,
    #         "confidence": svm_confidence,
    #         "source": "svm_model"
    #     }

    # Below the threshold no category is assigned; the vector database's
    # confidence is still reported to the caller.
    return db_result if db_confidence >= threshold else {
        "category": None,
        "confidence": db_confidence,
        "source": "none"
    }

In [None]:
import os
import json

# Locate the repository root (one level above the parent of the current working
# directory) and load the project settings from `config/settings.json`.
repo_path = os.path.join(os.path.dirname(os.path.abspath('')), '..')
config_path = os.path.join(repo_path, 'config', 'settings.json')
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

output_config = config.get('output', {})
output_path = output_config.get('path')
similarity_threshold = config.get('model', {}).get('similarity_threshold')

# Input: the merged bill CSV. Output: the same data with predicted categories.
merged_bill_path = os.path.join(repo_path, output_path, output_config.get('merged_filename'))
predict_file_path = os.path.join(repo_path, output_path, output_config.get('processed_filename'))

if os.path.exists(merged_bill_path):
    print(f"读取合并账单: {merged_bill_path}")
    merged_df = pd.read_csv(merged_bill_path)
    print(f"共读取 {len(merged_df)} 条记录")

    # Report how many rows have no category yet.
    missing_mask = merged_df['类型'].isna()
    missing_type_count = missing_mask.sum()
    print(f"缺少分类信息的记录数: {missing_type_count}")

    # A column that is entirely empty is read as float64; make it an object
    # column so writing category strings into it does not rely on pandas'
    # deprecated silent upcasting.
    merged_df['类型'] = merged_df['类型'].astype(object)

    # Fill in the missing categories with the hybrid predictor.
    updated_count = 0
    processed_count = 0
    for idx, row in merged_df[missing_mask].iterrows():
        merchant = str(row['交易对方']) if not pd.isna(row['交易对方']) else ""
        product = str(row['商品名称']) if not pd.isna(row['商品名称']) else ""

        prediction = hybrid_predict(merchant, product, similarity_threshold)

        # Only overwrite the cell when a category was actually predicted.
        if prediction['category'] is not None:
            merged_df.at[idx, '类型'] = prediction['category']
            updated_count += 1

        # Progress is tracked per processed row, not per successful update.
        # Otherwise the line below would be printed for every row while the
        # update counter sits at 0 or at a multiple of 100.
        processed_count += 1
        if processed_count % 100 == 0:
            print(f"已处理 {processed_count}/{missing_type_count} 条记录")

    print(f"共更新了 {updated_count} 条记录的分类信息")

    # Persist the updated table.
    merged_df.to_csv(predict_file_path, index=False, encoding='utf-8')
    print(f"已将分类结果保存至: {predict_file_path}")
else:
    # Make the skipped run visible instead of finishing silently.
    print(f"未找到合并账单: {merged_bill_path}")