In [28]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn 

In [24]:
# Geometric augmentation applied to training images only: a random rotation of up to
# 15 degrees, a random shift (up to 10% per axis) with a 0.9-1.1 zoom, and a mild
# perspective warp applied half of the time; ToTensor converts the result last.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        transforms.ToTensor(),
    ]
)

In [25]:
# FashionMNIST splits, downloaded into ./data on first run. Only the training split is
# augmented; the test split is just converted with ToTensor so evaluation is unperturbed.
training_data = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=train_transform,
)
test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=transforms.ToTensor(),
)

In [26]:
BATCH_SIZE = 64

# Shuffle only the training stream. Evaluation metrics do not depend on order, and a
# fixed test order keeps per-batch inspection of test predictions reproducible.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Batches per epoch; the final partial batch is kept (drop_last is not set).
len(train_dataloader), len(test_dataloader)

(938, 157)

In [27]:
# Sanity-check one augmented training batch before building the model:
# expect images of shape (64, 1, 28, 28) and labels of shape (64,).
first_batch = next(iter(train_dataloader))
train_features, train_labels = first_batch
print(f"Feature batch shape: {train_features.shape}")
print(f"Labels batch shape: {train_labels.shape}")

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])


In [None]:
class SimpleCNN_MNIST(nn.Module):
    """Two-stage convolutional classifier for small square images such as FashionMNIST.

    Each conv stage is Conv3x3(padding=1) -> BatchNorm -> ReLU -> MaxPool(2). The padded
    3x3 conv keeps the spatial size, and each pool floor-halves it. After both stages an
    ``input_size x input_size`` image becomes ``input_size // 4`` per side with 64 channels.
    That is what the classifier's first Linear layer is sized for.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input images (1 for grayscale FashionMNIST).
        input_size: height/width of the square input images (28 for FashionMNIST).
    """

    def __init__(self, num_classes=10, in_channels=1, input_size=28):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )

        # Two floor-halving pools: 28 -> 14 -> 7 for FashionMNIST. The flattened width
        # must use the conv block's real output (64 channels), not a hard-coded guess.
        feature_side = (input_size // 2) // 2
        self.classifier = nn.Sequential(
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map images (B, in_channels, input_size, input_size) to logits (B, num_classes)."""
        x = self.conv_block(x)   # (B, 64, input_size // 4, input_size // 4)
        x = torch.flatten(x, 1)  # (B, 64 * (input_size // 4) ** 2)
        return self.classifier(x)