In [1]:
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.nn import functional as fn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from matplotlib import pyplot as plt

In [2]:
# Load the FashionMNIST training and test splits (downloading into ./data on
# first use), converting each image to a tensor.
train_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

In [3]:
batch_size = 64

# Shuffle the training split so each epoch presents the samples in a new order
# (SGD on a fixed batch order every epoch is a common source of poorer
# generalization); the test split keeps its fixed order.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at a single test batch to confirm the input and label tensor shapes.
X, y = next(iter(test_dataloader))
print(X.shape)
print(y.shape)

torch.Size([64, 1, 28, 28])
torch.Size([64])


In [4]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [5]:
class DeepNet(nn.Module):
    """Fully connected classifier: flatten -> Linear/ReLU -> Linear/ReLU -> Linear.

    The defaults reproduce the original 28*28 -> 200 -> 200 -> 10 network.

    Args:
        in_features: number of values per sample after flattening.
        hidden_features: width of both hidden layers.
        num_classes: number of output scores per sample.
    """

    def __init__(self, in_features=28 * 28, hidden_features=200, num_classes=10):
        super().__init__()
        # nn.Flatten keeps dim 0 (the batch) and flattens everything after it.
        self.flatten = nn.Flatten()
        self.nn_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            # No final activation: outputs are raw scores (logits).
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, input):
        """Return one unnormalized score per class for each sample in ``input``."""
        x = self.flatten(input)
        return self.nn_stack(x)

In [6]:
# Build the network, then move its parameters onto the selected device
# (nn.Module.to returns the same module, so rebinding is equivalent).
model = DeepNet()
model = model.to(device)
print(model)

DeepNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (nn_stack): Sequential(
    (0): Linear(in_features=784, out_features=200, bias=True)
    (1): ReLU()
    (2): Linear(in_features=200, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=10, bias=True)
  )
)


In [7]:
# Cross-entropy applied to the network's raw output scores, minimized with
# momentum SGD over every model parameter.
loss_fn = CrossEntropyLoss()
optimizer = SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
)

In [8]:
def train(model, device, dataloader, loss_fn, optimizer, log_every=100):
    """Run one training pass over ``dataloader``, updating ``model`` in place.

    Args:
        model: network to train; it is switched to training mode first.
        device: device each batch is moved to (e.g. "cpu" or "cuda").
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used as the total in the progress line.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_every: print the current batch loss every ``log_every`` batches;
            0 or None disables logging.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Call the module rather than .forward() so registered hooks run.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients, compute new ones, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and batch % log_every == 0:
            train_loss = loss.item()
            # Index of this batch's first sample (assumes equal-sized batches).
            current = batch * len(X)
            print(f"loss: {train_loss:>7f}  [{current:>5d}/{size:>5d}]")

In [9]:
def test(model, device, dataloader, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
                X = X.to(device)
                y = y.to(device)

                # calculate error
                pred = model.forward(X)
                test_loss += loss_fn(pred, y).item()

                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [10]:
# Each epoch: one optimization pass over the training split, then an
# evaluation pass over the test split.
epochs = range(20)

for epoch in epochs:
    print(f"Epoch {epoch}:")
    train(
        model=model,
        device=device,
        dataloader=train_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
    )
    test(
        model=model,
        device=device,
        dataloader=test_dataloader,
        loss_fn=loss_fn,
    )

Epoch 0:
loss: 2.320418  [    0/60000]
loss: 0.981808  [ 6400/60000]
loss: 0.594293  [12800/60000]
loss: 0.710337  [19200/60000]
loss: 0.593947  [25600/60000]
loss: 0.493177  [32000/60000]
loss: 0.484941  [38400/60000]
loss: 0.639416  [44800/60000]
loss: 0.569331  [51200/60000]
loss: 0.557674  [57600/60000]
Test Error: 
 Accuracy: 81.7%, Avg loss: 0.507733 

Epoch 1:
loss: 0.359407  [    0/60000]
loss: 0.448707  [ 6400/60000]
loss: 0.326884  [12800/60000]
loss: 0.498036  [19200/60000]
loss: 0.448423  [25600/60000]
loss: 0.444438  [32000/60000]
loss: 0.394728  [38400/60000]
loss: 0.551251  [44800/60000]
loss: 0.483355  [51200/60000]
loss: 0.539059  [57600/60000]
Test Error: 
 Accuracy: 83.7%, Avg loss: 0.447450 

Epoch 2:
loss: 0.279203  [    0/60000]
loss: 0.400369  [ 6400/60000]
loss: 0.289571  [12800/60000]
loss: 0.426277  [19200/60000]
loss: 0.408787  [25600/60000]
loss: 0.410262  [32000/60000]
loss: 0.349683  [38400/60000]
loss: 0.525344  [44800/60000]
loss: 0.447980  [51200/60000]