In [None]:
import os
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
import glob
import argparse
import joblib
from sklearn.preprocessing import StandardScaler

In [None]:
# Feature-extraction configuration
TARGET_SR = 22050       # target sample rate (Hz) used when loading audio
DURATION = 0.5          # length of audio kept per clip (seconds)
N_MELS = 128            # number of Mel bands
N_FFT = 2048            # FFT window size (samples)
HOP_LENGTH = 512        # hop between successive frames (samples)
# Maximum number of frames on the time axis. Derived from the values above
# instead of being hard-coded so it cannot drift when they are tuned: a
# centered STFT yields 1 + n_samples // hop frames (22 for the defaults).
MAX_LENGTH = 1 + int(DURATION * TARGET_SR) // HOP_LENGTH

In [None]:
def compute_mel_spectrogram(y, sr):
    """
    Compute a fixed-size log-Mel spectrogram for a single audio clip.

    The waveform is zero-padded or truncated to ``DURATION`` seconds, turned
    into a Mel power spectrogram, converted to decibels relative to its own
    peak, and finally padded or cropped along the time axis to exactly
    ``MAX_LENGTH`` frames.

    :param y: audio samples (expected to be a 1-D mono signal)
    :param sr: sample rate of ``y`` in Hz
    :return: log-Mel spectrogram with ``MAX_LENGTH`` columns (time frames)
    """
    # Force the waveform to exactly DURATION seconds.
    target_length = int(DURATION * sr)
    if len(y) < target_length:
        # Append silence; the pad width is (before, after) for the one axis.
        y = np.pad(y, (0, target_length - len(y)), mode='constant')
    else:
        y = y[:target_length]

    # Mel power spectrogram.
    mel_spec = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=N_MELS,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        fmax=sr//2
    )

    # Convert to a log (dB) scale, with the clip's own maximum as 0 dB.
    log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

    # Enforce a fixed number of time frames.
    n_frames = log_mel_spec.shape[1]
    if n_frames < MAX_LENGTH:
        # With ref=np.max, 0 dB is the loudest value, so zero-padding would
        # append full-energy frames. Pad with the quietest observed level
        # instead so the padding reads as silence.
        pad_value = log_mel_spec.min() if log_mel_spec.size else 0.0
        log_mel_spec = np.pad(
            log_mel_spec,
            ((0, 0), (0, MAX_LENGTH - n_frames)),
            mode='constant',
            constant_values=pad_value,
        )
    elif n_frames > MAX_LENGTH:
        log_mel_spec = log_mel_spec[:, :MAX_LENGTH]

    return log_mel_spec

def process_audio_to_mel_features(input_dir, output_dir, train_scaler=True):
    """
    Extract normalized Mel features for every MP3 file in a directory.

    Each ``*.mp3`` directly inside ``input_dir`` is loaded at ``TARGET_SR``,
    converted with ``compute_mel_spectrogram``, standardized with a
    ``StandardScaler`` and written to ``output_dir`` as ``<basename>.npy``.
    The scaler is stored as ``mel_scaler.pkl`` and a metadata dictionary as
    ``dataset_metadata.npy`` in the same directory. Files that fail to load
    or save are reported on stdout and skipped.

    The scaler treats every Mel band of every clip as one sample with one
    feature per time frame, i.e. it learns a mean/scale per frame position.

    :param input_dir: directory containing the input MP3 files
    :param output_dir: directory receiving features, scaler and metadata
    :param train_scaler: fit and save a new scaler when True; otherwise reuse
        ``mel_scaler.pkl`` from ``output_dir`` (a new one is still fitted if
        that file does not exist)
    """
    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)

    # Collect all MP3 files (non-recursive).
    audio_files = glob.glob(os.path.join(input_dir, "*.mp3"))

    if not audio_files:
        print(f"警告: 在目录 {input_dir} 中未找到MP3文件")
        return

    print(f"找到 {len(audio_files)} 个MP3文件，开始提取Mel特征...")

    # Spectrograms are kept in memory so the scaler can be fitted on all of
    # them before anything is written.
    all_mel_specs = []
    file_paths = []

    # Pass 1: compute a spectrogram for every file that loads successfully.
    for file_path in tqdm(audio_files, desc="收集特征"):
        try:
            y, sr = librosa.load(file_path, sr=TARGET_SR, mono=True)
            mel_spec = compute_mel_spectrogram(y, sr)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f"处理 {os.path.basename(file_path)} 失败: {str(e)}")
        else:
            all_mel_specs.append(mel_spec)
            file_paths.append(file_path)

    if not all_mel_specs:
        print("错误: 未成功处理任何文件")
        return

    # Stack into a 3-D array (samples, mel bands, time frames).
    all_mel_specs = np.array(all_mel_specs)
    n_samples, n_mels, n_time = all_mel_specs.shape

    # Fit a new scaler, or load the existing one.
    scaler_path = os.path.join(output_dir, "mel_scaler.pkl")
    if train_scaler or not os.path.exists(scaler_path):
        print("训练归一化器...")
        # Rows are (sample, mel band) pairs, columns are time frames.
        scaler = StandardScaler()
        scaler.fit(all_mel_specs.reshape(n_samples * n_mels, n_time))

        joblib.dump(scaler, scaler_path)
        print(f"归一化器已保存到 {scaler_path}")
    else:
        print(f"加载现有归一化器: {scaler_path}")
        # NOTE(security): joblib.load unpickles the file; only point this at
        # scaler files produced by a trusted run.
        scaler = joblib.load(scaler_path)

    # Pass 2: normalize each spectrogram and write it to disk.
    print("应用归一化并保存特征...")
    for file_path, mel_spec in zip(tqdm(file_paths, desc="保存特征"), all_mel_specs):
        # Resolved outside the try so the error message can always use it.
        filename = os.path.basename(file_path)
        try:
            basename = os.path.splitext(filename)[0]

            # The (n_mels, n_time) spectrogram already has the layout the
            # scaler was fitted on (one feature per time frame). Reshaping it
            # to a single column, as before, made transform() reject every
            # file because of the feature-count mismatch.
            normalized = scaler.transform(mel_spec)

            output_path = os.path.join(output_dir, f"{basename}.npy")
            np.save(output_path, normalized)

        except Exception as e:
            print(f"保存 {filename} 失败: {str(e)}")

    # Persist the extraction settings next to the features. np.save stores
    # the dict as a pickled object array, so reading it back requires
    # np.load(..., allow_pickle=True).
    metadata = {
        "sample_rate": TARGET_SR,
        "duration": DURATION,
        "n_mels": N_MELS,
        "n_fft": N_FFT,
        "hop_length": HOP_LENGTH,
        "max_length": MAX_LENGTH,
        "scaler_params": {
            "mean": scaler.mean_,
            "scale": scaler.scale_
        },
        "file_count": len(file_paths)
    }
    np.save(os.path.join(output_dir, "dataset_metadata.npy"), metadata)
    print("数据集元数据已保存")
