In [2]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

# 自定义 Dataset
class CustomMNISTDataset(Dataset):
    def __init__(self, transform=None):
        # 使用 torchvision 加载 MNIST 数据集
        self.mnist_data = datasets.MNIST(root='../data', train=True, download=True)
        self.transform = transform

    def __len__(self):
        return len(self.mnist_data)

    def __getitem__(self, idx):
        image, label = self.mnist_data[idx]
        
        # 应用自定义的 transform
        if self.transform:
            image = self.transform(image)
        
        return image, label

# 定义 transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # MNIST 只包含一个通道
])

# 创建 CustomMNISTDataset 实例
custom_dataset = CustomMNISTDataset(transform=transform)

# 获取第一个样本
image, label = custom_dataset[0]
print(f"Label: {label}, Image shape: {image.shape}")


Label: 5, Image shape: torch.Size([1, 28, 28])


In [4]:
from torch.utils.data import DataLoader

# 使用 DataLoader 加载数据集
batch_size = 32
trainloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

# 获取一个批次的数据
images, labels = next(iter(trainloader))

print(f"Batch of images: {images.shape}, Batch of labels: {labels.shape}")


Batch of images: torch.Size([32, 1, 28, 28]), Batch of labels: torch.Size([32])


In [None]:
# 定义 MNIST 数据集的常见 transform
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 重新调整大小（MNIST 默认是28x28，但这里示例演示尺寸变化）
    transforms.RandomRotation(30),  # 随机旋转图像
    transforms.ToTensor(),  # 将图像转换为Tensor
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 加载 MNIST 数据集并应用 transform
mnist_data = datasets.MNIST(root='../data', train=True, download=True, transform=transform)

# 创建 DataLoader
mnist_loader = DataLoader(mnist_data, batch_size=32, shuffle=True)

# 获取一个批次的数据
images, labels = next(iter(mnist_loader))
print(f"Batch of images: {images.shape}, Batch of labels: {labels.shape}")


Batch of images: torch.Size([32, 1, 28, 28]), Batch of labels: torch.Size([32])
