# DataSet

In [None]:
from torch.utils.data import Dataset, DataLoader

## 自定义数据集
# 数据集必须实现两个接口, 一个是len一个是getitem
class NumbersDataset(Dataset):
    def __init__(self, training=True):
        if training:
            self.sample = list(range(1, 10000))
        else:
            self.sample = list(range(10001, 20000))

    def __len__(self):  # 当前数据集的长度
        return len(self.sample)

    def __getitem__(self, index):  # 根据index读取数据
        return self.sample[index]


一个复杂的自定义数据集, 可以直接用.

当然, 也可以直接用:

```python
from torchvision import transforms
tf = transforms.Compose([transforms.Resize(64), transforms.ToTensor()])
torchvision.datasets.ImageFolder(root='xxxx',transforms=tf)
```

In [None]:
import torch
import os, glob
import random, csv

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class MyDataset(Dataset):
    def __init__(
        self,
        path,
        mode,
        resize=None,
        to_tensor=False,
        transformers=None,
        write_to_csv=False,
    ):
        self.path = path
        self.resize = resize
        self.mode = mode
        self.write_to_csv = write_to_csv
        self.to_tensor = to_tensor
        self.transformers = transformers
        ## 1. 对每一个label进行编码
        self.name2label = {}  # 将label进行编码
        for name in sorted(os.listdir(os.path.join(path))):
            if not os.path.isdir(os.path.join(path, name)):
                continue
            # 使用当前key的长度为id
            # {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
            self.name2label[name] = len(self.name2label.keys())
        ## 2. 对每一个image获取一个label
        self.images_total, self.labels_total = self.load_csv("files.csv")

        ## 3. 训练集和测试集
        if mode == "train":  # 60%
            self.images = self.images_total[: int(0.6 * len(self.images_total))]
            self.labels = self.labels_total[: int(0.6 * len(self.labels_total))]
        if mode == "val":  # 20%
            self.images = self.images[
                int(0.6 * len(self.images_total)) : int(0.8 * len(self.images_total))
            ]
            self.labels = self.images[
                int(0.6 * len(self.labels_total)) : int(0.8 * len(self.labels_total))
            ]
        if mode == "test":  # 20%
            self.images = self.images_total[int(0.8 * len(self.images_total)) :]
            self.labels = self.labels_total[int(0.8 * len(self.labels_total)) :]

    def load_csv(self, filename):
        output, labels = [], []
        if self.write_to_csv and os.path.exists(os.path.join(self.path, filename)):
            with open(os.path.join(self.path, filename)) as f:
                reader = csv.reader(f)
                for (output_path, label) in reader:
                    label = int(label)
                    output.append(output_path)
                    labels.append(label)
        else:
            images = []
            for name in self.name2label.keys():
                folder_images = []
                folder_images += glob.glob(os.path.join(self.path, name, "*.png"))
                folder_images += glob.glob(os.path.join(self.path, name, "*.jpg"))
                folder_images += glob.glob(os.path.join(self.path, name, "*.jpeg"))
                for image in folder_images:
                    images.append((image, name))
            random.shuffle(images)
            for (output_path, label) in images:
                output.append(os.path.abspath(output_path))
                labels.append(self.name2label[label])
            if self.write_to_csv:
                with open(os.path.join(self.path, filename), mode="w", newline="") as f:
                    writer = csv.writer(f)
                    for (path, label) in images:
                        writer.writerow([os.path.abspath(path), self.name2label[label]])
        return output, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img, label = self.images[index], self.labels[index]
        ## 将路径读取进来, 然后使用torchvision输出. 这里考试构建转换器
        transformers = [lambda x: Image.open(x).convert("RGB")]
        if self.resize is not None:
            if isinstance(self.resize, int):
                transformers.append(transforms.Resize((self.resize, self.resize)))
            elif isinstance(self.resize, tuple):
                transformers.append(transforms.Resize(self.resize))
        if self.to_tensor:
            transformers.append(transforms.ToTensor())
            label = torch.tensor(label)

        if self.transformers is not None:
            if isinstance(self.transformers, bool) and self.transformers:
                transformers.append(transforms.RandomRotation(15))
                # 这里使用的是imagenet的统计结果, 一般就直接使用
                transformers.append(
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                    )
                )
            elif not isinstance(self.transformers, bool):
                transformers = self.transformers
        # 转换
        tf = transforms.Compose(transformers)
        return (tf(img), label)

    def denormalize(self, x_hat):
        if not self.to_tensor:
            raise Exception("To tensor is required to use denormalization")
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x: [c,h,w]
        # mean : [3] => [3,1,1] # 自动broadcast
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean
