In [1]:
# 1.a Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor())

testing_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor())

training_data.data.shape, testing_data.data.shape

(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))

In [2]:
# 1.b DataLoader
from torch.utils.data import DataLoader

training_dataloader = DataLoader(training_data, 64, shuffle=True)
testing_dataloader = DataLoader(testing_data, 64, shuffle=False)

X, y = next(iter(training_dataloader))
X.shape, y.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))

In [3]:
# play with conv layer
from torch import nn
broken_cnn = nn.Sequential(
                nn.Conv2d(1, 30, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(30, 1, kernel_size=3, padding=1)
            )
print(broken_cnn(X).shape)
# build a conv layer with basic parameter
# n input channel, n output feature, ks:kenel size, act:activate
def conv(ni, nf, ks=3, act=True):
    layers = [nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)]
    layers.append(nn.BatchNorm2d(nf))
    if act: layers.append(nn.ReLU())
    res = nn.Sequential(*layers)
    return res

def linear(ni, nf, act=True):
    res = nn.Linear(ni, nf)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res

# refactor code to using basic parameter simplify the function. easy to debug and more obvious to readers.
# note: capacity of layer is the number of activation, we using stride=2 then we need to x4 features to
# keep the capacity the same.

# we can add comment to each conv to make sure we have 
simple_cnn = nn.Sequential(
                           conv(1,4),               #14x14
                           conv(4,8),               #7x7
                           conv(8,16),              #4x4
                           conv(16, 32),            #2x2
                           nn.Flatten(),            #128
                           linear(128, 10),
)

print(X.shape, simple_cnn(X).shape)

torch.Size([64, 1, 28, 28])
torch.Size([64, 1, 28, 28]) torch.Size([64, 10])


In [4]:
# 2. Model
from torch import nn

# This is simple model that take input -> linear -> relu -> linear -> output.
class ConvModel(nn.Module):
    def __init__(self):
        super(ConvModel, self).__init__()
        self.model = nn.Sequential(
                           conv(1,4),               #14x14
                           conv(4,8),               #7x7
                           conv(8,16),              #4x4
                           conv(16, 32),            #2x2
                           nn.Flatten(),            #128
                           linear(128, 10),
                    )
    def forward(self, x):
        out = self.model(x)
        return out
    
model = ConvModel()

y_hat = model(X)

In [5]:
# 3. Loss
loss_fn = nn.CrossEntropyLoss()

loss = loss_fn(y_hat, y)

loss.backward()

In [6]:
# 4. Optimizer SGD
from torch.optim import SGD

In [7]:
# 5. Combine things together:
# train model
def train(model, dataloader, optimizer, epochs=3):
    size = len(dataloader.dataset)
    for epoch in range(epochs):
        for batch, (X, y) in enumerate(dataloader):
            y_hat = model(X)
            loss = loss_fn(y_hat, y)

            # backward
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if batch % 300 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

for lr in [0.03, 0.01, 0.001]:
    optimizer = SGD(model.parameters(), lr)
    train(model, training_dataloader, optimizer)

loss: 2.356814  [    0/60000]
loss: 0.693739  [19200/60000]
loss: 0.411766  [38400/60000]
loss: 0.624989  [57600/60000]
loss: 0.349086  [    0/60000]
loss: 0.239517  [19200/60000]
loss: 0.395932  [38400/60000]
loss: 0.253282  [57600/60000]
loss: 0.469419  [    0/60000]
loss: 0.417089  [19200/60000]
loss: 0.252391  [38400/60000]
loss: 0.257266  [57600/60000]
loss: 0.349539  [    0/60000]
loss: 0.241535  [19200/60000]
loss: 0.300125  [38400/60000]
loss: 0.335051  [57600/60000]
loss: 0.345725  [    0/60000]
loss: 0.403337  [19200/60000]
loss: 0.376659  [38400/60000]
loss: 0.196540  [57600/60000]
loss: 0.356316  [    0/60000]
loss: 0.297076  [19200/60000]
loss: 0.454761  [38400/60000]
loss: 0.395886  [57600/60000]
loss: 0.367261  [    0/60000]
loss: 0.429320  [19200/60000]
loss: 0.234397  [38400/60000]
loss: 0.196323  [57600/60000]
loss: 0.308056  [    0/60000]
loss: 0.348796  [19200/60000]
loss: 0.371335  [38400/60000]
loss: 0.552765  [57600/60000]
loss: 0.250464  [    0/60000]
loss: 0.25

In [8]:
size = len(testing_dataloader.dataset)
total = 0
for X, y in testing_dataloader:
    y_hat = model(X)
    total += sum(y_hat.argmax(1) == y).item()
print(f'Accuracy: {total/size:>2f}')

Accuracy: 0.876000
