## 数据集Dataloader制作
![title](img/DataLoader.png)

In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

#### 任务1：读取txt文件中的路径和标签

In [2]:
def load_annotations(ann_file):
    data_infos = {}
    with open(ann_file) as f:
        samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            data_infos[filename] = np.array(gt_label, dtype=np.int64)
    return data_infos

In [3]:
print(load_annotations('./flower_data/train.txt'))

{'image_06734.jpg': array(0, dtype=int64), 'image_06735.jpg': array(0, dtype=int64), 'image_06736.jpg': array(0, dtype=int64), 'image_06737.jpg': array(0, dtype=int64), 'image_06738.jpg': array(0, dtype=int64), 'image_06740.jpg': array(0, dtype=int64), 'image_06741.jpg': array(0, dtype=int64), 'image_06742.jpg': array(0, dtype=int64), 'image_06744.jpg': array(0, dtype=int64), 'image_06745.jpg': array(0, dtype=int64), 'image_06746.jpg': array(0, dtype=int64), 'image_06747.jpg': array(0, dtype=int64), 'image_06748.jpg': array(0, dtype=int64), 'image_06750.jpg': array(0, dtype=int64), 'image_06751.jpg': array(0, dtype=int64), 'image_06753.jpg': array(0, dtype=int64), 'image_06757.jpg': array(0, dtype=int64), 'image_06759.jpg': array(0, dtype=int64), 'image_06761.jpg': array(0, dtype=int64), 'image_06762.jpg': array(0, dtype=int64), 'image_06766.jpg': array(0, dtype=int64), 'image_06767.jpg': array(0, dtype=int64), 'image_06768.jpg': array(0, dtype=int64), 'image_06770.jpg': array(0, dtype

#### 任务2：分别把数据和标签都存在list里

In [4]:
image_label = load_annotations('./flower_data/train.txt')
image_name = list(image_label.keys())
label = list(image_label.values())

In [5]:
image_name

['image_06734.jpg',
 'image_06735.jpg',
 'image_06736.jpg',
 'image_06737.jpg',
 'image_06738.jpg',
 'image_06740.jpg',
 'image_06741.jpg',
 'image_06742.jpg',
 'image_06744.jpg',
 'image_06745.jpg',
 'image_06746.jpg',
 'image_06747.jpg',
 'image_06748.jpg',
 'image_06750.jpg',
 'image_06751.jpg',
 'image_06753.jpg',
 'image_06757.jpg',
 'image_06759.jpg',
 'image_06761.jpg',
 'image_06762.jpg',
 'image_06766.jpg',
 'image_06767.jpg',
 'image_06768.jpg',
 'image_06770.jpg',
 'image_06771.jpg',
 'image_06772.jpg',
 'image_06773.jpg',
 'image_07086.jpg',
 'image_07087.jpg',
 'image_07088.jpg',
 'image_07089.jpg',
 'image_07091.jpg',
 'image_07092.jpg',
 'image_07093.jpg',
 'image_07095.jpg',
 'image_07096.jpg',
 'image_07097.jpg',
 'image_07098.jpg',
 'image_07099.jpg',
 'image_07100.jpg',
 'image_07103.jpg',
 'image_07105.jpg',
 'image_07106.jpg',
 'image_07108.jpg',
 'image_07109.jpg',
 'image_07110.jpg',
 'image_07111.jpg',
 'image_07112.jpg',
 'image_07113.jpg',
 'image_07114.jpg',


In [6]:
data_dir = './flower_data/'
train_dir = data_dir + '/train_filelist'
valid_dir = data_dir + '/val_filelist'

#### 任务4：把上面那几个事得写在一起
- 1.注意要使用from torch.utils.data import Dataset, DataLoader
- 2.类名定义class FlowerDataset(Dataset)，其中FlowerDataset可以改成自己的名字
- 3.def __init__(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
- 4.def __getitem__(self, idx):根据自己任务，返回图像数据和标签数据

In [7]:
from torch.utils.data import Dataset, DataLoader


class FlowerDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            #特征图像处理
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image, label

    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos

#### 任务5：数据预处理(transform)

In [8]:
data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize(64),
            transforms.RandomRotation(45),  #随机旋转，-45到45度之间随机选
            transforms.CenterCrop(64),  #从中心开始裁剪
            transforms.RandomHorizontalFlip(p=0.5),  #随机水平翻转 选择一个概率概率
            transforms.RandomVerticalFlip(p=0.5),  #随机垂直翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
            #参数1为亮度，参数2为对比度，参数3为饱和度，参数4为色相
            transforms.RandomGrayscale(p=0.025),  #概率转换成灰度率，3通道就是R=G=B
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  #均值，标准差
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
}

#### 任务6：根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader

In [9]:
train_dataset = FlowerDataset(root_dir=train_dir, ann_file='./flower_data/train.txt',
                              transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file='./flower_data/val.txt', transform=data_transforms['valid'])

In [10]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

In [11]:
len(train_dataset)

6552

#### 任务7：用之前先试试，整个数据和标签对应下，看看对不对

In [None]:
%matplotlib inline
image, label = next(iter(train_loader))
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
sample = np.clip(sample, 0, 1)
print(sample.shape)
print(sample)
plt.imshow(sample)

(64, 64, 3)
[[[1.9215684e-01 9.0196073e-02 9.4117656e-02]
  [2.7058825e-01 5.0980389e-02 1.0588236e-01]
  [5.9607846e-01 3.5294145e-02 2.1960786e-01]
  ...
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]]

 [[2.5490198e-01 3.9215684e-02 9.4117656e-02]
  [4.1960785e-01 3.5294145e-02 1.6470590e-01]
  [6.5490198e-01 3.1372547e-02 2.3921570e-01]
  ...
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]]

 [[2.6666668e-01 1.9607842e-02 1.0980393e-01]
  [4.1176471e-01 2.3529440e-02 1.5294121e-01]
  [6.2745100e-01 2.7450979e-02 2.3137257e-01]
  ...
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]]

 ...

 [[0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.1682510e-08]
  [0.0000000e+00 2.9563903e-08 1.16825

<matplotlib.image.AxesImage at 0x20e5ba71a50>