In [18]:
import os
from pydub import AudioSegment
import librosa
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from dsp import *

我们希望使用 SVM 实现语音识别，步骤如下：

1. 特征提取

2. 模型训练

3. 模型识别

In [8]:
def readDat(filepath):
    """Load a headerless 16-bit PCM recording as normalised float samples.

    Args:
        filepath: Path of the raw ``.dat`` file. Its bytes are interpreted as
            a flat sequence of ``int16`` values.

    Returns:
        A 1-D ``float32`` array holding every sample divided by 32768, so the
        values lie in ``[-1.0, 1.0)``.
    """
    full_scale = 32768.0  # magnitude of the most negative int16 value
    pcm = np.fromfile(filepath, dtype=np.int16)
    return np.divide(pcm.astype(np.float32), full_scale)

In [9]:
# Sanity check: load a single recording with readDat and print the shape of
# the resulting sample array.
y = readDat('22307110206-00-01.dat')
print(y.shape)

(20310,)


In [10]:
# 特征提取
def extractFeature(audio_file, sr=8000, win=256):
    """Turn one raw recording into a fixed-length MFCC statistics vector.

    The recording is loaded with ``readDat``, trimmed to the span running from
    the start of the first segment reported by ``vad`` to the end of the last
    one, and handed to ``computeMFCC``. The result is the mean of the MFCC
    matrix along axis 1 followed by its standard deviation along axis 1.

    Args:
        audio_file: Path of the raw ``.dat`` recording.
        sr: Sampling rate passed on to ``vad`` and ``computeMFCC``.
        win: Window length in samples. ``win // 2`` is passed to ``vad`` as
            its fourth argument (presumably the hop size - confirm against
            ``dsp.vad``).

    Returns:
        A 1-D array ``[mean..., std...]``, or ``None`` when ``vad`` reports no
        segments.
    """
    signal = readDat(audio_file)

    speech_spans = vad(signal, sr, win, win // 2)
    if len(speech_spans) == 0:
        print("No segments found.")
        return None

    # Keep everything between the first segment's start and the last
    # segment's end; both values are used directly as sample indices.
    first_start = int(speech_spans[0][0])
    last_end = int(speech_spans[-1][1])
    y_speech = signal[first_start:last_end]

    # Zero-pad very short excerpts up to two windows; the original author's
    # intent is to guarantee at least one complete STFT frame.
    min_length = 2 * win
    if len(y_speech) < min_length:
        print(f"警告: {audio_file} 语音段太短({len(y_speech)}采样点)，填充至{min_length}采样点")
        y_speech = np.pad(y_speech, (0, min_length - len(y_speech)), 'constant')

    mfcc = computeMFCC(y_speech, sr, win, D=13, M=26)

    # NOTE(review): assumes axis 1 of the MFCC matrix is the frame axis, so
    # the statistics are taken per coefficient - confirm in dsp.computeMFCC.
    stats = [np.mean(mfcc, axis=1), np.std(mfcc, axis=1)]
    return np.concatenate(stats)

In [19]:
# 准备数据集
# def prepareData(base_dir, classes):
#     features = []
#     labels = []
    
#     total_files = 0         # 总文件数
#     processed_files = 0     # 处理的文件数
    
#     print("计算总文件数...")
#     for idx, name in enumerate(classes):
#         class_dir = os.path.join(base_dir, str(idx).zfill(2))
#         if os.path.exists(class_dir):
#             for filename in os.listdir(class_dir):
#                 if filename.endswith('.dat'):
#                     total_files += 1
#     print(f"找到 {total_files} 个文件")
    
#     for idx, name in enumerate(classes):
#         class_dir = os.path.join(base_dir, str(idx).zfill(2))
#         if not os.path.exists(class_dir):
#             print(f"警告: 目录 {class_dir} 不存在")
#             continue
            
#         print(f"处理类别 {idx}: {name}")
#         for filename in os.listdir(class_dir):
#             if filename.endswith('.dat'):
#                 filepath = os.path.join(class_dir, filename)
#                 feature = extractFeature(filepath)
#                 if feature is not None:
#                     features.append(feature)
#                     labels.append(idx)
                
#                 processed_files += 1
#                 percent = processed_files / total_files * 100
#                 print(f"\r进度: [{processed_files}/{total_files}] {percent:.1f}%", end="")
    
#     print("\n处理完成!")
#     return np.array(features), np.array(labels)

# 准备数据集
def prepareData(base_dir, classes):
    """Extract features for every ``.dat`` file of every class directory.

    Class ``i`` (the i-th entry of ``classes``) is expected in the
    sub-directory ``<base_dir>/<i zero-padded to two digits>``. Each file is
    passed to ``extractFeature``; files for which it returns ``None`` or
    raises are counted as failures and skipped.

    Args:
        base_dir: Root directory holding one sub-directory per class.
        classes: Sequence of class names; the position of a name is the
            integer label stored for its samples.

    Returns:
        ``(features, labels)`` as NumPy arrays, one entry per successfully
        processed file, in class order and sorted by file name within a class.

    Raises:
        ValueError: If no file at all could be processed.
    """
    features = []
    labels = []

    print(f"开始从 {base_dir} 加载数据...")

    # Scan every class directory exactly once so that the progress-bar total
    # and the files actually processed cannot disagree. Sorting makes the
    # sample order (and therefore any seeded train/test split) reproducible,
    # because os.listdir returns entries in arbitrary order.
    class_files = {}
    for idx, _ in enumerate(classes):
        class_dir = os.path.join(base_dir, str(idx).zfill(2))
        # isdir (rather than exists) so that a stray regular file with the
        # class's name is reported as missing instead of crashing listdir.
        if os.path.isdir(class_dir):
            class_files[idx] = sorted(
                f for f in os.listdir(class_dir) if f.endswith('.dat')
            )

    total_files = sum(len(files) for files in class_files.values())
    print(f"找到 {total_files} 个文件")

    processed_files = 0     # files attempted
    error_files = 0         # files that raised or produced no feature

    with tqdm(total=total_files, desc="总进度") as pbar:
        for idx, name in enumerate(classes):
            class_dir = os.path.join(base_dir, str(idx).zfill(2))

            # Messages inside the bar's lifetime go through tqdm.write so
            # they are printed above the bar instead of tearing it apart.
            if idx not in class_files:
                tqdm.write(f"警告：目录 {class_dir} 不存在，跳过类别 '{name}'")
                continue

            filenames = class_files[idx]
            tqdm.write(f"正在处理类别 '{name}' ({len(filenames)}个文件)...")

            for filename in filenames:
                filepath = os.path.join(class_dir, filename)
                try:
                    feature = extractFeature(filepath)
                except Exception as e:
                    # Deliberately best-effort: one unreadable recording must
                    # not abort the whole run, but it is reported.
                    tqdm.write(f"处理文件 {filepath} 出错: {str(e)}")
                    feature = None

                if feature is None:
                    error_files += 1
                else:
                    features.append(feature)
                    labels.append(idx)

                processed_files += 1
                pbar.update(1)

    print(f"\n处理完成: 共 {total_files} 个文件, 成功 {processed_files - error_files} 个, 失败 {error_files} 个")
    if len(features) == 0:
        raise ValueError("没有成功处理任何文件!请检查数据路径和文件格式")

    return np.array(features), np.array(labels)

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score
import joblib

# 训练SVM模型
def train_svm_model(features, labels, test_size=0.2, random_state=42,
                    model_path='svm_model.pkl', scaler_path='scaler.pkl'):
    """Train and evaluate an RBF-kernel SVM, then persist it with its scaler.

    The data is split into stratified train/test parts, standardised with
    statistics computed on the training part only, and used to fit an
    ``SVC`` with probability estimates enabled. Accuracy and a per-class
    report for the test part are printed.

    Args:
        features: 2-D array-like, one feature vector per sample.
        labels: 1-D array-like of class labels aligned with ``features``.
        test_size: Fraction of the samples held out for evaluation.
        random_state: Seed used for the train/test split and for the SVC's
            internal probability calibration.
        model_path: File the fitted classifier is written to with joblib.
        scaler_path: File the fitted ``StandardScaler`` is written to.

    Returns:
        ``(svm, scaler, (X_test_scaled, y_test))``.
    """
    # Stratify so every class keeps the same proportion in both parts.
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=random_state, stratify=labels
    )

    # Fit the scaler on the training part only; the test part is merely
    # transformed, so no test statistics leak into training.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # probability=True makes SVC shuffle the data for an internal
    # cross-validation; without a seed the resulting probabilities differ
    # from run to run, so the caller's seed is forwarded here as well.
    svm = SVC(kernel='rbf', C=10, gamma='scale', probability=True,
              random_state=random_state)
    svm.fit(X_train_scaled, y_train)

    # Evaluate on the held-out part.
    y_pred = svm.predict(X_test_scaled)
    accuracy = accuracy_score(y_test, y_pred)
    report = classification_report(y_test, y_pred)

    print(f"Accuracy: {accuracy:.4f}")
    print("Classification Report:")
    print(report)

    # Persist both objects: recognition needs the scaler that matches the model.
    joblib.dump(svm, model_path)
    joblib.dump(scaler, scaler_path)

    return svm, scaler, (X_test_scaled, y_test)

In [None]:
# 使用模型进行识别
def recognize_speech(audio_file, model, scaler, class_names):
    """Classify one recording with a fitted model.

    Args:
        audio_file: Path of the raw ``.dat`` recording.
        model: Fitted classifier exposing ``predict_proba`` (and, for
            scikit-learn estimators, ``classes_``).
        scaler: Fitted transformer whose ``transform`` is applied to the
            feature vector before prediction.
        class_names: Lookup from integer label to display name.

    Returns:
        ``(class_name, confidence)`` for the most probable class. When
        ``extractFeature`` yields no feature, the plain string
        ``"无法检测到语音"`` is returned instead of a tuple, so callers must
        check the result type before unpacking.
    """
    feature = extractFeature(audio_file)
    if feature is None:
        return "无法检测到语音"

    # The scaler expects a 2-D batch, hence the (1, n_features) reshape.
    features_scaled = scaler.transform(feature.reshape(1, -1))

    probabilities = model.predict_proba(features_scaled)[0]
    best_column = np.argmax(probabilities)
    confidence = probabilities[best_column]

    # The columns of predict_proba follow model.classes_, i.e. only the labels
    # seen during training. If a class was absent from the training data, the
    # column position no longer equals the label, so translate it back before
    # looking up the name. Models without classes_ keep the positional lookup.
    known_labels = getattr(model, "classes_", None)
    label = best_column if known_labels is None else known_labels[best_column]

    return class_names[label], confidence

In [17]:
# Class names. The position of a name in this list is the integer label that
# prepareData stores for it, and prepareData looks for the class's files in
# the sub-directory named after that position zero-padded to two digits.
classes = ["数字", "语音", "语言", "处理", "中国", "忠告", "北京", "背景", "上海", 
              "Speech", "Speaker", "Signal", "Sequence", "Processing", "Print", "Project", "File", "Open"]

# Build the feature array and label array from the .dat files under ../Data.
features, labels = prepareData('../Data', classes)


开始从 ../Data 加载数据...


总进度:   0%|          | 0/9719 [00:00<?, ?it/s]

正在处理类别 '数字' (539个文件)...


总进度:   0%|          | 1/9719 [00:00<52:41,  3.07it/s]

警告: ../Data/00/21307130052_00_19.dat 语音段太短(128采样点)，填充至512采样点


总进度:   0%|          | 3/9719 [00:06<6:35:48,  2.44s/it]

警告: ../Data/00/22307130038_00_18.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_11.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_10.dat 语音段太短(128采样点)，填充至512采样点


总进度:   0%|          | 44/9719 [00:11<36:22,  4.43it/s] 

警告: ../Data/00/22307130038_00_07.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_12.dat 语音段太短(128采样点)，填充至512采样点


总进度:   1%|          | 88/9719 [00:16<23:34,  6.81it/s]

警告: ../Data/00/22307130038_00_16.dat 语音段太短(128采样点)，填充至512采样点


总进度:   1%|▏         | 128/9719 [00:21<20:40,  7.73it/s]

警告: ../Data/00/22307130013_00_17.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_07.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_16.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_18.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/22307130038_00_12.dat 语音段太短(128采样点)，填充至512采样点


总进度:   2%|▏         | 172/9719 [00:26<19:37,  8.11it/s]

警告: ../Data/00/21307130052_00_13.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/22307130013_00_01.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_20.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/22307130038_00_06.dat 语音段太短(384采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_09.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/22307130038_00_11.dat 语音段太短(128采样点)，填充至512采样点


总进度:   2%|▏         | 241/9719 [00:28<11:17, 14.00it/s]

警告: ../Data/00/21307130150-00-20.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/22307130013_00_20.dat 语音段太短(128采样点)，填充至512采样点


总进度:   3%|▎         | 250/9719 [00:29<11:48, 13.36it/s]

警告: ../Data/00/21307130150-00-18.dat 语音段太短(128采样点)，填充至512采样点


总进度:   3%|▎         | 262/9719 [00:36<1:17:13,  2.04it/s]

警告: ../Data/00/21307130052_00_03.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/21307130052_00_02.dat 语音段太短(128采样点)，填充至512采样点
警告: ../Data/00/22307130013_00_06.dat 语音段太短(128采样点)，填充至512采样点


总进度:   3%|▎         | 291/9719 [00:34<18:23,  8.54it/s]  

警告: ../Data/00/21307130150-00-16.dat 语音段太短(128采样点)，填充至512采样点





KeyboardInterrupt: 