<a href="https://colab.research.google.com/github/vannis422/CIFAR-10/blob/main/CIFARFolder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple
from PIL import Image
import torch
import torchvision
import os

class CIFAR10ImageFolder(torch.utils.data.Dataset):
    """Image dataset laid out as ``root_dir/<class_name>/<image file>``.

    Each immediate sub-directory of ``root_dir`` is one class. Class names are
    sorted alphabetically and given contiguous label ids starting at 0.
    Non-directory entries directly inside ``root_dir`` (e.g. ``.DS_Store``)
    are ignored and do not consume a label id.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its integer label id.
        image_paths: Path of every collected image file, grouped by class and
            sorted by filename within each class.
        labels: Label id for each entry of ``image_paths`` (same order).
    """

    # File suffixes treated as images; matched case-insensitively.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and collect image paths with their labels.

        Args:
            root_dir: Dataset split directory, e.g. 'cifar10/train'.
            transform: Optional callable (e.g. a torchvision.transforms
                pipeline) applied to each PIL image in ``__getitem__``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Filter to directories *before* enumerating so stray files in
        # root_dir cannot shift the label ids or leave gaps in them.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        for class_name in self.classes:
            label = self.class_to_idx[class_name]
            class_folder = os.path.join(root_dir, class_name)
            # os.listdir order is filesystem-dependent; sort so that sample
            # indices are reproducible across machines and runs.
            for filename in sorted(os.listdir(class_folder)):
                if filename.lower().endswith(self.IMG_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_folder, filename))
                    self.labels.append(label)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded with PIL and converted to RGB, then passed through
        ``self.transform`` if one was given.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Use a context manager so the underlying file is closed; convert()
        # returns a new, fully loaded image that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_paths)