<a href="https://colab.research.google.com/github/rexbrandy/classify_digits/blob/main/classify_digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


def get_dataloaders(batch_size=64, visualize=True):
    train_dataset = datasets.FashionMNIST(
        root='data',
        train=True,
        download=True,
        transform=ToTensor()
    )

    test_dataset = datasets.FashionMNIST(
        root='data',
        train=False,
        download=True,
        transform=ToTensor()
    )

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    if visualize:
        visualize_data(train_dataset)

    return train_dataloader, test_dataloader


In [2]:
import matplotlib.pyplot as plt


def visualize_data(training_data):
    labels_map = {
        0: "T-Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels_map[label])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()



In [3]:
import torch.nn as nn

class FeedForwardNet(nn.Module):
    def __init__(self):
        super(FeedForwardNet, self).__init__()
        self.flatten = nn.Flatten

        self.sequential = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.Linear(512, 128),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        output = self.flatten(x)
        output = self.sequential(output)

        return output


In [12]:
def train(model, dataloader, criterion, optimizer, n_epochs=20):
    model.train()

    epoch_loss = []

    for epoch in range(n_epochs):
        for batch, (X, y) in enumerate(dataloader):
            pred = model(X)
            loss = criterion(pred, y)

            loss.backprop()

            optimizer.step()
            optimizer.zero_grad()



In [11]:
n_epochs = 20
lr = 0.01

model = FeedForwardNet()
criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=lr)

training_dataloader, test_dataloader = get_dataloaders(visualize=False)

train(model, training_dataloader, criterion, optim, n_epochs)


0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
