In [None]:
import os
import numpy as np
from PIL import Image
import random
from feature import *

In [None]:
def load_dataset(data_folder):
    """Load every image under ``data_folder``, using each sub-folder name as its label.

    Expected layout::

        data_folder/
            <label_a>/img1.jpg, img2.jpg, ...
            <label_b>/...

    Non-directory entries directly under ``data_folder`` are ignored. Files that
    fail to open or convert are reported with ``print`` and skipped (best-effort).

    Args:
        data_folder: Path to the dataset root directory.

    Returns:
        (images, labels): two parallel lists. ``images[i]`` is ``np.array`` of the
        image converted to RGB, and ``labels[i]`` is the name of its sub-folder.
        Entries are ordered by sorted label name, then sorted file name.

    Raises:
        FileNotFoundError / NotADirectoryError: propagated from ``os.listdir`` when
        ``data_folder`` is not an existing directory.
    """
    images = []
    labels = []
    # os.listdir order is filesystem-dependent; sorting makes the load order (and any
    # seeded split done on it later) reproducible across machines.
    for label in sorted(os.listdir(data_folder)):
        label_path = os.path.join(data_folder, label)
        if not os.path.isdir(label_path):
            continue
        for img_name in sorted(os.listdir(label_path)):
            img_path = os.path.join(label_path, img_name)
            try:
                # The context manager closes the file handle Image.open keeps open;
                # pixels are materialised inside it, before the file is closed.
                with Image.open(img_path) as image:
                    pixels = np.array(image.convert('RGB'))
            except Exception as e:
                # Best-effort loading: report the unreadable file and keep going.
                print(f"載入圖片失敗 {img_path}，錯誤: {e}")
                continue
            images.append(pixels)
            labels.append(label)
    return images, labels

In [None]:
def split_train_test(images, labels, test_ratio=0.2, seed=42):
    """Stratified train/test split: each label's samples are shuffled and split separately.

    For every label with ``n`` samples, the first ``int(n * (1 - test_ratio))``
    shuffled samples go to the training set and the rest to the test set, so class
    proportions are roughly preserved. Because of the truncation, a label with very
    few samples can end up with no training samples (e.g. n=1 puts its only sample
    in the test set).

    Args:
        images: Sequence of samples, parallel to ``labels``.
        labels: Sequence of hashable class labels, same length as ``images``.
        test_ratio: Fraction of each label's samples reserved for testing, in [0, 1].
        seed: Seed for the shuffle; identical inputs and seed give an identical split.

    Returns:
        (train_images, train_labels, test_images, test_labels) as lists. Within each
        list, samples are grouped by label in order of the label's first appearance
        in ``labels``.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length, or ``test_ratio``
            is outside [0, 1].
    """
    if len(images) != len(labels):
        raise ValueError(f"images 與 labels 長度不一致: {len(images)} != {len(labels)}")
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio 必須介於 0 與 1 之間，目前為 {test_ratio}")

    # Private RNG: produces the same shuffles as random.seed(seed) + random.shuffle,
    # without resetting the notebook-wide global random state as a side effect.
    rng = random.Random(seed)

    # Group sample indices by label (dict keeps first-appearance order of labels).
    label_to_indices = {}
    for idx, label in enumerate(labels):
        label_to_indices.setdefault(label, []).append(idx)

    train_indices = []
    test_indices = []
    for indices in label_to_indices.values():
        rng.shuffle(indices)
        split_point = int(len(indices) * (1 - test_ratio))
        train_indices.extend(indices[:split_point])
        test_indices.extend(indices[split_point:])

    train_images = [images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    test_images = [images[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]
    return train_images, train_labels, test_images, test_labels

In [None]:


# Dataset root: one sub-folder per class label, each containing that class's images.
# NOTE(review): must point at a real folder before this cell is run.
data_path = ""

# Fail early with a clear message instead of an opaque os.listdir('') error.
if not os.path.isdir(data_path):
    raise FileNotFoundError(f"找不到資料夾: {data_path!r}，請先設定 data_path")

print("載入資料集...")
images, labels = load_dataset(data_path)
print(f"共載入圖片數量: {len(images)}")
# An empty dataset would otherwise surface much later as a ZeroDivisionError when
# computing accuracy. Any non-empty dataset yields a non-empty test split.
if not images:
    raise RuntimeError(f"在 {data_path!r} 中沒有載入任何圖片")

print("切分訓練集與測試集(8:2)...")
train_imgs, train_lbls, test_imgs, test_lbls = split_train_test(images, labels)
print(f"訓練集大小: {len(train_imgs)}，測試集大小: {len(test_imgs)}")

print("提取訓練集特徵...")
train_feats = [mango_feature_extractor(img) for img in train_imgs]
print("提取測試集特徵...")
test_feats = [mango_feature_extractor(img) for img in test_imgs]

print("訓練KNN分類器...")
# k=5 nearest neighbours; tune here if needed.
knn = KNNClassifier(k=5)
knn.fit(train_feats, train_lbls)

print("開始測試...")
# Predict one test sample at a time and count exact label matches.
correct = 0
for feat, lbl in zip(test_feats, test_lbls):
    pred = knn.predict(feat)
    if pred == lbl:
        correct += 1

accuracy = correct / len(test_lbls)
print(f"測試集數量: {len(test_lbls)}")
print(f"分類準確率: {accuracy*100:.2f}%")
print(f"錯誤率: {(1 - accuracy)*100:.2f}%")