In [6]:
import os
import cv2
import pandas as pd
import numpy as np
from skimage.feature import hog
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder

# 定义预期的标签集合（小写形式）
expected_labels = {'stop', 'ceder', 'danger', 'forange', 'frouge', 'fvert', 'interdiction', 'obligation', 'unknown'}

# 读取图像并裁剪特征区域
def extract_region(image_path, csv_path):
    image = cv2.imread(image_path)
    if os.path.getsize(csv_path) == 0:
        height, width = image.shape[:2]
        # 使用整个图像作为区域，并标记为 unknown
        regions = [(cv2.resize(image, (64, 64)), 'unknown')]
        return regions
    try:
        data = pd.read_csv(csv_path, header=None)
    except pd.errors.EmptyDataError:
        height, width = image.shape[:2]
        regions = [(cv2.resize(image, (64, 64)), 'unknown')]
        return regions
    regions = []
    for index, row in data.iterrows():
        x1, y1, x2, y2, label = row
        label = label.strip().lower()  # 去除前后空格并转换为小写
        if label not in expected_labels:
            print(f"Unexpected label '{label}' found in file {csv_path}.")
            label = 'unknown'
        region = image[int(y1):int(y2), int(x1):int(x2)]
        if region.size == 0:
            continue
        region = cv2.resize(region, (64, 64))
        regions.append((region, label))
    return regions

# 提取HOG特征
def extract_hog_features(image):
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    features, _ = hog(image_gray, orientations=9, pixels_per_cell=(8, 8),
                      cells_per_block=(2, 2), block_norm='L2-Hys', visualize=True)
    return features

# 加载数据集
def load_dataset(image_folder, label_folder):
    features = []
    labels = []
    for image_file in os.listdir(image_folder):
        if image_file.endswith('.jpg'):
            image_path = os.path.join(image_folder, image_file)
            csv_path = os.path.join(label_folder, image_file.replace('.jpg', '.csv'))
            regions = extract_region(image_path, csv_path)
            for region, label in regions:
                features.append(extract_hog_features(region))
                labels.append(label)
    return np.array(features), np.array(labels)

# 读取数据
train_image_folder = './train/images'
train_label_folder = './train/labels'
val_image_folder = './val/images'
val_label_folder = './val/labels'

X_train, y_train = load_dataset(train_image_folder, train_label_folder)
X_val, y_val = load_dataset(val_image_folder, val_label_folder)

# 检查是否成功加载数据
print(f"Loaded {len(X_train)} training samples and {len(X_val)} validation samples.")

# 标签编码
le = LabelEncoder()
y_train = le.fit_transform(y_train)
y_val = le.transform(y_val)

# 打印标签与数字的对应关系
print("标签与数字的对应关系：", le.classes_)

# 特征降维
pca = PCA(n_components=50)
X_train_pca = pca.fit_transform(X_train)
X_val_pca = pca.transform(X_val)

# 训练SVM模型
svm_model = SVC(kernel='linear')
svm_model.fit(X_train_pca, y_train)

# 评估模型
y_pred = svm_model.predict(X_val_pca)
print(classification_report(y_val, y_pred))
print(confusion_matrix(y_val, y_pred))


Unexpected label 'ff' found in file ./train/labels\0073.csv.
Unexpected label 'ff' found in file ./train/labels\0073.csv.
Unexpected label 'ff' found in file ./train/labels\0079.csv.
Unexpected label 'ff' found in file ./train/labels\0079.csv.
Unexpected label 'ff' found in file ./train/labels\0079.csv.
Unexpected label 'ff' found in file ./train/labels\0085.csv.
Unexpected label 'ff' found in file ./train/labels\0086.csv.
Unexpected label 'ff' found in file ./train/labels\0086.csv.
Unexpected label 'ff' found in file ./train/labels\0090.csv.
Unexpected label 'ff' found in file ./train/labels\0090.csv.
Unexpected label 'ff' found in file ./val/labels\0088.csv.
Loaded 1072 training samples and 119 validation samples.
标签与数字的对应关系： ['ceder' 'danger' 'forange' 'frouge' 'fvert' 'interdiction' 'obligation'
 'stop' 'unknown']
              precision    recall  f1-score   support

           0       1.00      0.82      0.90        17
           1       0.94      1.00      0.97        16
       