# In[1]:
import numpy as np
import librosa
import os
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, regularizers
from sklearn.metrics import precision_score, recall_score, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

# Initialize the FastAPI application; the route decorators below register on it.
app = FastAPI()

# Path to the trained model weights file. It can be overridden with the
# MODEL_PATH environment variable. The default is a machine-specific absolute
# path and will not exist on other hosts.
model_path = os.environ.get(
    "MODEL_PATH",
    "/Users/desperate/Desktop/FingerprintAndVoiceRecognition/src/dataset/updated_model.h5",
)


# Model definition.
def create_model(input_shape=(40, 500, 1)):
    """Build and compile the CNN + bidirectional-LSTM binary classifier.

    Layout:
      * three stages of Conv2D(3x3, 'same', ReLU) -> BatchNorm -> MaxPool(2x2),
        with 32, 64 and 128 filters;
      * a reshape of the pooled feature map into a sequence of 128-dim vectors;
      * two Bidirectional LSTM(128) layers, each followed by BatchNorm;
      * Dense(128, ReLU, L2=0.001) -> Dropout(0.5) -> Dense(1, sigmoid).

    Args:
        input_shape: shape of one sample without the batch axis.

    Returns:
        A Keras ``Sequential`` model compiled with the Adam optimizer,
        binary cross-entropy loss and an accuracy metric. Its weights are
        untrained.
    """
    stack = [layers.InputLayer(input_shape=input_shape)]

    # Convolutional front end: each stage halves both spatial axes.
    for n_filters in (32, 64, 128):
        stack.extend([
            layers.Conv2D(n_filters, kernel_size=(3, 3), activation='relu', padding='same'),
            layers.BatchNormalization(),
            layers.MaxPooling2D(pool_size=(2, 2)),
        ])

    # For the default input the map is now (5, 62, 128). Reshape flattens the
    # first two axes into a 310-step sequence of 128-channel vectors.
    # NOTE(review): that sequence is ordered row-major over the first axis, so
    # the LSTMs do not see a pure time axis. Changing this requires retraining.
    stack.append(layers.Reshape((-1, 128)))

    # Recurrent section.
    stack += [
        layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.3, recurrent_dropout=0.2)),
        layers.BatchNormalization(),
        layers.Bidirectional(layers.LSTM(128, dropout=0.3, recurrent_dropout=0.2)),
        layers.BatchNormalization(),
    ]

    # Classification head: a single sigmoid score.
    stack += [
        layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
        layers.Dropout(0.5),
        layers.Dense(1, activation='sigmoid'),
    ]

    classifier = Sequential(stack)
    classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return classifier


# Build the network once at import time and load the trained weights into it.
# NOTE(review): this runs on module import, so a missing or incompatible
# weights file at `model_path` raises here and the app cannot be imported.
model = create_model()
model.load_weights(model_path)


# Audio feature extraction.
def extract_features(audio_path, max_length=500, sr=16000, n_mfcc=40):
    """Load an audio file and return a fixed-width MFCC matrix.

    The audio is loaded at ``sr`` Hz and converted to ``n_mfcc`` MFCC
    coefficients per frame. The frame axis is then zero-padded on the right,
    or truncated, to exactly ``max_length`` columns.

    Args:
        audio_path: path of the audio file, read from the server's filesystem.
        max_length: number of frames (columns) in the result.
        sr: sample rate used for both loading and MFCC extraction.
        n_mfcc: number of MFCC coefficients (rows) to compute.

    Returns:
        NumPy array of shape ``(n_mfcc, max_length)``.

    Raises:
        HTTPException: status 400 if the audio cannot be loaded or analysed.
    """
    try:
        audio, _ = librosa.load(audio_path, sr=sr)
        mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc)
    except Exception as e:
        # Report any decode or IO failure to the API client as a bad request.
        # Chain the cause so the server-side traceback is preserved.
        raise HTTPException(status_code=400, detail=f"Error processing audio: {str(e)}") from e
    return _fit_frames(mfccs, max_length)


def _fit_frames(mfccs, max_length):
    """Zero-pad on the right or truncate the frame axis to exactly ``max_length`` columns."""
    missing = max_length - mfccs.shape[1]
    if missing > 0:
        return np.pad(mfccs, ((0, 0), (0, missing)), mode='constant')
    return mfccs[:, :max_length]


# Input data schema.
class PredictionInput(BaseModel):
    """Request body for the ``/predict/`` endpoint."""

    # Path of the audio file. It is opened on the server's filesystem;
    # the file is not uploaded.
    file_path: str


# API routes.
@app.get("/")
def root():
    """Liveness endpoint: reports that the service is up."""
    status_message = "Audio Classification API is running"
    return {"message": status_message}


@app.post("/predict/")
def predict(data: PredictionInput):
    """Classify a single audio file as "real" or "fake".

    Args:
        data: request body whose ``file_path`` names an audio file on the
            server's filesystem.

    Returns:
        ``{"prediction": "fake" | "real", "probability": float}``.
        ``probability`` is the model's sigmoid output; scores above 0.5 are
        labelled "fake".

    Raises:
        HTTPException: 400 if the audio cannot be processed or prediction fails.
    """
    # SECURITY NOTE(review): file_path comes straight from the client and is
    # opened on the server, so a caller can probe any file the process can
    # read. Consider accepting an upload (UploadFile) or restricting paths to
    # an allowed directory.
    try:
        features = extract_features(data.file_path)
        # Add batch and channel axes: (n_mfcc, frames) -> (1, n_mfcc, frames, 1).
        batch = features[np.newaxis, :, :, np.newaxis]
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(batch, verbose=0)
    except HTTPException:
        # This already carries a client-facing status and detail; don't wrap it again.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}") from e

    probability = float(prediction[0][0])
    label = "fake" if probability > 0.5 else "real"
    return {"prediction": label, "probability": probability}


@app.post("/evaluate/")
def evaluate(files: list[str], labels: list[int]):
    """Score the model on a labelled set of audio files.

    Label 1 means "fake" and 0 means "real". This matches the 0.5 threshold
    used for predictions and the heatmap tick labels.

    Args:
        files: paths of audio files on the server's filesystem.
        labels: ground-truth label for each file, in the same order.

    Returns:
        A dict containing:
          * ``precision``, ``recall`` and ``accuracy``, with fake as the
            positive class;
          * the 2x2 ``confusion_matrix`` as nested lists (rows are true
            real/fake, columns are predicted real/fake);
          * the filename of the saved heatmap image.

    Raises:
        HTTPException: 400 for mismatched or empty inputs, labels outside
            {0, 1}, unreadable audio, or any failure while predicting or
            plotting.
    """
    if len(files) != len(labels):
        raise HTTPException(status_code=400, detail="Files and labels must have the same length.")
    if not files:
        raise HTTPException(status_code=400, detail="At least one file is required.")
    if any(label not in (0, 1) for label in labels):
        raise HTTPException(status_code=400, detail="Labels must be 0 (real) or 1 (fake).")

    # NOTE(review): every request overwrites this same file in the working
    # directory, so concurrent evaluations race on it.
    image_path = "confusion_matrix.png"

    try:
        # Make one batched predict call instead of one call per file; Keras
        # still splits it into mini-batches internally.
        batch = np.stack([extract_features(path) for path in files])[..., np.newaxis]
        scores = model.predict(batch, verbose=0)[:, 0]
        predictions = [1 if score > 0.5 else 0 for score in scores]
        true_labels = list(labels)

        # zero_division=0 makes the default result explicit when there are no
        # positive predictions or labels, instead of emitting a warning.
        precision = precision_score(true_labels, predictions, zero_division=0)
        recall = recall_score(true_labels, predictions, zero_division=0)
        accuracy = accuracy_score(true_labels, predictions)
        # labels=[0, 1] keeps the matrix 2x2 even when one class appears in
        # neither list, so it always lines up with the real/fake tick labels.
        cm = confusion_matrix(true_labels, predictions, labels=[0, 1])

        _save_confusion_matrix_plot(cm, image_path)
    except HTTPException:
        # This already carries a client-facing status and detail; don't wrap it again.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Evaluation error: {str(e)}") from e

    return {
        "precision": float(precision),
        "recall": float(recall),
        "accuracy": float(accuracy),
        "confusion_matrix": cm.tolist(),
        "confusion_matrix_image": image_path,
    }


def _save_confusion_matrix_plot(cm, path):
    """Render ``cm`` as an annotated real/fake heatmap and save it to ``path``."""
    # Use a standalone Figure instead of pyplot's global state. Pyplot keeps
    # every figure alive until plt.close(), so an unclosed figure per request
    # leaks memory. Pyplot may also try to start a GUI backend from FastAPI's
    # worker thread, which Matplotlib warns is likely to fail.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=["real", "fake"], yticklabels=["real", "fake"], ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    fig.savefig(path)

