## MNIST 手写数字识别数据集

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing applied to every MNIST sample: convert the image to a tensor,
# then normalise it with mean 0.5 and standard deviation 0.5.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, normalize])

# Arguments shared by the training and test splits; the data is downloaded
# into ./data when it is not already there.
mnist_kwargs = dict(root='./data', transform=transform, download=True)

# Training split, served in shuffled batches of 64.
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Test split, served in batches of 64 without shuffling.
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [2]:
# Pull only the first batch to check the tensor shapes; nothing is printed
# when the loader yields no batches.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f'images shape : {images.shape}')
    print(f'labels shape : {labels.shape}')

images shape : torch.Size([64, 1, 28, 28])
labels shape : torch.Size([64])


## 重新实现数据集和数据集加载器

加载汽车图像分割数据集

In [3]:
from torch import nn
from torch.utils.data import Dataset


class TemplateDataset(Dataset):
    """Skeleton of a custom dataset usable from PyTorch.

    To write your own dataset, inherit from ``Dataset`` and implement the two
    methods ``__len__`` and ``__getitem__``. Both are left as placeholders
    here and currently return ``None``.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        """Placeholder: a real dataset returns its number of samples."""
        return None

    def __getitem__(self, i):
        """Placeholder: a real dataset returns the sample at index ``i``."""
        return None

In [4]:
import os
from PIL import Image
from torchvision.transforms import ToTensor


class BasicDataset(Dataset):
    """Dataset pairing each image with its segmentation mask.

    Files are looked up under ``data_dir`` as::

        imgs/<name>.jpg
        masks/<name>_mask.gif

    Indexing returns an ``(img, mask)`` pair, each converted with
    torchvision's ``ToTensor``.
    """

    def __init__(self, data_dir):
        """Index the images found in ``<data_dir>/imgs``.

        Args:
            data_dir: Directory containing the ``imgs`` and ``masks`` folders.

        Raises:
            OSError: If ``<data_dir>/imgs`` cannot be listed.
        """
        # A dataset only serves samples, so it derives from Dataset rather
        # than nn.Module (it has no parameters and no forward pass).
        super().__init__()
        self.data_dir = data_dir
        self.files = self._list_stems(os.path.join(self.data_dir, 'imgs'))
        self.convert = ToTensor()

    @staticmethod
    def _list_stems(img_dir):
        """Return the sorted base names of the ``.jpg`` files in ``img_dir``.

        Hidden entries (names starting with ``.``, e.g. ``.DS_Store`` or
        ``._*`` files) and files with any other extension are skipped, since
        they cannot be opened as ``<name>.jpg``. Only the final extension is
        stripped, so names containing extra dots are preserved. Sorting makes
        the index-to-sample mapping reproducible, which ``os.listdir`` alone
        does not guarantee.
        """
        candidates = (
            os.path.splitext(name)
            for name in os.listdir(img_dir)
            if not name.startswith('.')
        )
        return sorted(stem for stem, ext in candidates if ext.lower() == '.jpg')

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.files)

    def __getitem__(self, i):
        """Return the ``(img, mask)`` tensor pair for sample ``i``.

        Raises:
            IndexError: If ``i`` is outside the indexed files.
            OSError: If the image or its mask cannot be opened.
        """
        name = self.files[i]
        img = self.convert(self._load_img(name, is_mask=False))
        mask = self.convert(self._load_img(name, is_mask=True))
        return img, mask

    def _load_img(self, file_name, is_mask=False):
        """Open the image or the mask that belongs to ``file_name``.

        Args:
            file_name: Base name of the sample, without extension.
            is_mask: When true, open ``masks/<file_name>_mask.gif``;
                otherwise open ``imgs/<file_name>.jpg``.

        Returns:
            The PIL image returned by ``Image.open``.
        """
        if is_mask:
            sub_dir, full_name = 'masks', file_name + '_mask.gif'
        else:
            sub_dir, full_name = 'imgs', file_name + '.jpg'
        return Image.open(os.path.join(self.data_dir, sub_dir, full_name))

In [5]:
# Root folder of the segmentation data; BasicDataset reads the `imgs` and
# `masks` sub-folders beneath it.
data_dir = '/Volumes/SSD/SSD/blueberry/datasets/UNet'  # your data path
dataset = BasicDataset(data_dir=data_dir)

In [6]:
# Fetch the first (image, mask) pair through BasicDataset.__getitem__.
img, mask = dataset[0]

In [7]:
# Notebook expression: display the shapes of the image and mask tensors.
img.shape, mask.shape

(torch.Size([3, 1280, 1918]), torch.Size([1, 1280, 1918]))

In [8]:
# Serve the dataset in batches of 64 samples, without shuffling.
data_loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [9]:
# Pull only the first batch to check the batched tensor shapes; nothing is
# printed when the loader yields no batches.
first_batch = next(iter(data_loader), None)
if first_batch is not None:
    img, mask = first_batch
    print(img.shape)
    print(mask.shape)

torch.Size([64, 3, 1280, 1918])
torch.Size([64, 1, 1280, 1918])
