### 1. 原始数据读取
- 并不是把所有图像全部读进内存！
- 而是把所有图像的`路径`和`类别`归纳和梳理出来！
- img_path
- img_label

In [1]:
"""
    尝试读取 train 
"""
# Collect (image path, label string) pairs for the training split.
# Expected layout: gesture/train/<label>/<image file>. Images are NOT loaded
# here; only their paths and class names are gathered.
import os
train_root = os.path.join("gesture", "train")
train_paths = []
train_labels = []

# Sort both directory levels so the sample order is reproducible
# (os.listdir returns entries in arbitrary order). Skip hidden entries and
# anything that is not a class folder / regular file, e.g. .DS_Store or
# .ipynb_checkpoints, which would otherwise crash listdir or become a bogus label.
for label in sorted(os.listdir(train_root)):
    label_root = os.path.join(train_root, label)
    if label.startswith(".") or not os.path.isdir(label_root):
        continue
    for file in sorted(os.listdir(label_root)):
        file_path = os.path.join(label_root, file)
        if file.startswith(".") or not os.path.isfile(file_path):
            continue
        train_paths.append(file_path)
        train_labels.append(label)

In [2]:
"""
    尝试读取 test 
"""
# Collect (image path, label string) pairs for the test split.
# Expected layout: gesture/test/<label>/<image file>. Images are NOT loaded
# here; only their paths and class names are gathered.
import os
test_root = os.path.join("gesture", "test")
test_paths = []
test_labels = []

# Sort both directory levels so the evaluation order is reproducible
# (os.listdir returns entries in arbitrary order). Skip hidden entries and
# anything that is not a class folder / regular file, e.g. .DS_Store or
# .ipynb_checkpoints, which would otherwise crash listdir or become a bogus label.
for label in sorted(os.listdir(test_root)):
    label_root = os.path.join(test_root, label)
    if label.startswith(".") or not os.path.isdir(label_root):
        continue
    for file in sorted(os.listdir(label_root)):
        file_path = os.path.join(label_root, file)
        if file.startswith(".") or not os.path.isfile(file_path):
            continue
        test_paths.append(file_path)
        test_labels.append(label)

In [11]:
# Label dictionaries: class name <-> integer class id.
# A name's position in `labels` is its id, so "zero" -> 0, ..., "nine" -> 9.
labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
idx2label = dict(enumerate(labels))
label2idx = {name: idx for idx, name in idx2label.items()}

### 2. 批量化打包
- 继承 Dataset，自定义一个数据集
- 实例化 DataLoader

In [45]:
# 引入必要的工具类
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
import torch

In [46]:
class GestureDataset(Dataset):
    """
        Custom gesture-recognition dataset.

        Stores only image paths and label strings; each image is read from
        disk lazily in ``__getitem__``.
    """
    def __init__(self, X, y, label_map=None):
        """
            Args:
                X: sequence of image file paths.
                y: sequence of label strings, aligned element-wise with X.
                label_map: optional dict mapping label string -> integer id.
                    When None, the module-level ``label2idx`` is used
                    (looked up at access time, as before).

            Raises:
                ValueError: if X and y have different lengths.
        """
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y
        self.label_map = label_map

    def __getitem__(self, idx):
        """
            Load, preprocess and return the sample at position ``idx``.

            Returns:
                (img, label): ``img`` is a float tensor of shape [3, 32, 32]
                scaled to [-1, 1]; ``label`` is a 0-dim ``torch.long`` tensor
                holding the class id.

            Raises:
                KeyError: if the sample's label is not in the label map.
        """
        img_path = self.X[idx]
        # The context manager closes the file handle that Image.open keeps
        # open for lazy loading. convert("RGB") forces the pixel data to load
        # and guarantees 3 channels, so the 3-channel Normalize below also
        # works for grayscale, palette or RGBA files.
        with Image.open(fp=img_path) as raw:
            img = raw.convert("RGB").resize((32, 32))
        # [C, H, W], values in [0, 1]
        img = transforms.ToTensor()(img)
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
        img = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img)

        img_label = self.y[idx]
        label_map = label2idx if self.label_map is None else self.label_map
        try:
            img_idx = label_map[img_label]
        except KeyError:
            # Fail with a clear message instead of passing None to torch.tensor.
            raise KeyError(
                f"Unknown label {img_label!r} for image {img_path!r}"
            ) from None
        label = torch.tensor(data=img_idx, dtype=torch.long)

        return img, label

    def __len__(self):
        """
            Return the number of samples in the dataset.
        """
        return len(self.X)

In [51]:
# Wrap the path/label lists of both splits in datasets...
train_dataset = GestureDataset(X=train_paths, y=train_labels)
test_dataset = GestureDataset(X=test_paths, y=test_labels)
# ...then batch them: training batches are reshuffled every epoch,
# while the test split is iterated in a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [53]:
# Sanity check: peek at the first test batch (if any) and print its shapes.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(X.shape)
    print(y.shape)

torch.Size([32, 3, 32, 32])
torch.Size([32])
