In [42]:
# Custom Dataset (FlowerDataset); it is wrapped in a DataLoader further below
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import os
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
class FlowerDataset(Dataset):
    """Image-classification dataset driven by a plain-text annotation file.

    Each non-blank line of ``ann_file`` holds an image file name (relative to
    ``root_dir``) followed by its integer class label, separated by whitespace.

    Args:
        root_dir: Directory that contains the image files.
        ann_file: Path to the annotation file.
        transform: Optional callable applied to each RGB ``PIL.Image`` in
            ``__getitem__`` (e.g. a ``torchvision.transforms.Compose``).
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        # {file_name: label_string}, in annotation-file order.
        self.img_label = self.load_annotations()
        # Parallel lists built from the same dict, so index i refers to the
        # same sample in both: full image paths and their (string) labels.
        self.img = [os.path.join(self.root_dir, name) for name in self.img_label]
        self.label = list(self.img_label.values())
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is decoded and converted to RGB, so every sample has three
        channels whatever the file's mode (grayscale, palette, RGBA). The file
        handle is closed before returning. ``label`` is a scalar
        ``torch.long`` tensor. The DataLoader batches these pairs.
        """
        with Image.open(self.img[idx]) as img:
            # convert() forces a full decode and returns a new image object,
            # so it stays valid after the underlying file is closed.
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = torch.tensor(int(self.label[idx]), dtype=torch.long)
        return image, label

    def __len__(self):
        """Number of annotated samples."""
        return len(self.img)

    def load_annotations(self):
        """Parse ``self.ann_file`` into an ordered ``{file_name: label}`` dict.

        Blank lines (including a trailing newline at end of file) are skipped.
        The label is the last whitespace-separated token, so file names
        containing spaces stay intact. Labels are returned as strings, exactly
        as written in the file.

        Raises:
            ValueError: if a non-blank line does not contain a label token.
        """
        data_infos = {}
        with open(self.ann_file) as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                parts = line.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        '{}:{}: expected "<file_name> <label>", got {!r}'.format(
                            self.ann_file, line_no, line))
                file_name, label = parts
                # NOTE: a repeated file name overwrites the earlier entry.
                data_infos[file_name] = label
        return data_infos

In [43]:
# Preprocessing pipelines keyed by split. Both finish with ToTensor plus a
# per-channel Normalize; only the training pipeline applies random augmentation.
data_transforms = {
    "train": transforms.Compose([
        transforms.Resize(size=256),                    # shorter side -> 256, aspect ratio kept
        transforms.RandomResizedCrop(size=224),         # random crop, resized to 224 (augmentation)
        transforms.RandomHorizontalFlip(),              # random left-right flip
        transforms.RandomRotation(degrees=30),          # small random rotation
        transforms.ColorJitter(                         # mild color perturbation
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.1,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    "valid": transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),                # deterministic center crop for stable evaluation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
}

In [44]:
# Image folder paths. The annotation files (train.txt / val.txt) also live under data_dir.
data_dir = './flower_data_dataloader/'
# os.path.join avoids the doubled separator that '+' concatenation produced
# ('./flower_data_dataloader//train_filelist').
train_dir = os.path.join(data_dir, 'train_filelist')
valid_dir = os.path.join(data_dir, 'val_filelist')


In [45]:
# Training split with augmenting transforms; the annotation path is derived from
# data_dir so relocating the dataset only requires changing data_dir.
train_dataset = FlowerDataset(
    root_dir=train_dir,
    ann_file=os.path.join(data_dir, 'train.txt'),
    transform=data_transforms['train'],
)

In [46]:
# Validation split with deterministic transforms; the annotation path is derived
# from data_dir so relocating the dataset only requires changing data_dir.
val_dataset = FlowerDataset(
    root_dir=valid_dir,
    ann_file=os.path.join(data_dir, 'val.txt'),
    transform=data_transforms['valid'],
)

In [47]:
# Shuffle only the training data. A fixed validation order keeps evaluation
# reproducible and lets per-sample predictions be matched back to val.txt.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [52]:
# Sanity-check the Dataset/DataLoader: take one training batch and display the
# first sample (after undoing Normalize) together with its label.
im, la = next(iter(train_loader))  # each next() yields one batch: (images, labels)
# First image of the batch as a (C, H, W) tensor. No squeeze(): it would be a
# no-op here, and it could silently drop a genuine size-1 dimension.
sample = im[0]
mean = torch.tensor([0.485, 0.456, 0.406], device=sample.device).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=sample.device).view(3, 1, 1)

# Undo Normalize per channel: x = x_norm * std + mean.
sample = sample * std + mean

# CHW -> HWC for matplotlib; clamp to [0, 1] so imshow does not complain about
# float values slightly outside the valid range.
sample = sample.permute(1, 2, 0).clamp(0, 1).cpu().numpy()
label = la[0].item()

fig, ax = plt.subplots()
ax.imshow(sample)
ax.set_title('Label: {}'.format(label))
ax.axis('off')
plt.show()
print('Label:{}'.format(label))

Label:56
