In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision

from matplotlib import pyplot as plt
import numpy as np

# Record the installed PyTorch release next to the notebook's outputs.
print(f"{torch.__version__}")

In [None]:
# Choose the compute backend: CUDA if available, otherwise MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
def _load_fashion_mnist(train: bool):
    """Return one FashionMNIST split stored under ``data/``, with images converted by ``ToTensor``.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.
    """
    return torchvision.datasets.FashionMNIST(
        root="data",
        download=True,
        train=train,
        transform=torchvision.transforms.ToTensor(),
    )


# Download (if needed) the training and test splits from open datasets.
training_data = _load_fashion_mnist(train=True)
test_data = _load_fashion_mnist(train=False)


In [None]:
# Show the raw tensor size of each split (training first, then test).
for split in (training_data, test_data):
    print(split.data.size())

In [None]:
# from: https://www.kaggle.com/code/pankajj/fashion-mnist-with-pytorch-93-accuracy
# Human-readable names for the ten FashionMNIST class indices (built once, not per call).
FASHION_MNIST_LABELS = {
    0: "T-shirt/Top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}


def output_label(label):
    """Map a FashionMNIST class index to its human-readable name.

    Args:
        label: class index as a Python int, or a single-element ``torch.Tensor``
            (including Tensor subclasses such as ``nn.Parameter``).

    Returns:
        The class name, e.g. ``output_label(0) == "T-shirt/Top"``.

    Raises:
        KeyError: if the index is not one of 0..9.
    """
    # isinstance (not `type(...) ==`) so Tensor subclasses are unwrapped too;
    # `index` avoids shadowing the builtin `input`.
    index = label.item() if isinstance(label, torch.Tensor) else label
    return FASHION_MNIST_LABELS[index]



In [None]:
def plot_first_samples(dataloader, rows: int = 8, columns: int = 8):
    """Plot the first batch of ``dataloader`` as a grid of grayscale images titled by class.

    At most ``rows * columns`` samples are drawn. Previously a batch larger than
    the 8x8 grid made ``add_subplot`` fail. If the dataloader yields no batches,
    an empty figure is left open and nothing is drawn.

    Args:
        dataloader: iterable of ``(images, labels)`` batches. Each image's first
            channel (``image[0]``) is displayed; labels are passed to ``output_label``.
        rows: number of grid rows.
        columns: number of grid columns.
    """
    fig = plt.figure(figsize=(10, 10), layout="constrained")
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    images, labels = first_batch

    n_cells = rows * columns
    for i, (image, label) in enumerate(zip(images, labels)):
        if i >= n_cells:  # grid is full; extra samples have no subplot slot
            break
        ax = fig.add_subplot(rows, columns, i + 1)
        ax.axis("off")
        ax.set_title(output_label(label.item()), size="small")
        ax.imshow(image[0].detach().cpu().numpy(), cmap="gray")

batch_size = 64

# Create data loaders. The training loader reshuffles on every pass so successive
# epochs do not take gradient steps over the samples in one fixed order; the test
# loader keeps a fixed order so evaluations and the plots below are comparable.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Inspect the shapes of a single test batch.
X, y = next(iter(test_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

plot_first_samples(test_dataloader)
plt.show()


In [None]:
# Feed-forward (fully connected) neural network classifier

class Classifier(nn.Module):
    """Fully connected classifier: flatten -> [Linear -> activation] per hidden layer -> Linear logits.

    Args:
        input_shape: shape of one input sample without the batch dimension; its
            ``numel()`` is the width of the first Linear layer.
        hidden_sizes: widths of the hidden layers. An empty list (or None) keeps
            the original architecture of three hidden layers with 512 units each.
        num_classes: number of output logits.
        hidden_activation_function: activation module placed after every hidden
            Linear layer. Each layer gets its own deep copy, so parametrised
            activations (e.g. ``nn.PReLU``) are not silently shared between layers.
        loss_fn: callable ``loss_fn(logits, targets)``. ``targets`` are the class
            indices one-hot encoded in the logits' dtype.
    """

    # Used when hidden_sizes is empty; reproduces the original fixed 3x512 network.
    DEFAULT_HIDDEN_SIZES = (512, 512, 512)

    def __init__(self, input_shape : torch.Size, hidden_sizes : list[int], num_classes : int, hidden_activation_function, loss_fn):
        super().__init__()
        import copy  # only needed to give each hidden layer its own activation instance

        widths = list(hidden_sizes or self.DEFAULT_HIDDEN_SIZES)
        layers = []
        in_features = input_shape.numel()
        for out_features in widths:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(copy.deepcopy(hidden_activation_function))
            in_features = out_features
        # Final layer emits raw logits: no softmax, the loss is expected to take logits.
        layers.append(nn.Linear(in_features, num_classes))

        self.flatten = nn.Flatten()
        # Name kept as in the original so state_dict keys stay compatible.
        self.linear_relu_stack = nn.Sequential(*layers)

        self.num_classes = num_classes
        self.loss_fn = loss_fn

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Flatten each sample and return the ``(batch, num_classes)`` logits."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    def _device(self) -> torch.device:
        """Device holding the model's parameters; batches are moved here before the forward pass."""
        return next(self.parameters()).device

    def _loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Apply ``self.loss_fn`` to logits and one-hot encoded integer class targets."""
        return self.loss_fn(logits, F.one_hot(targets, self.num_classes).to(logits.dtype))

    def train_network(self, dataloader : DataLoader, optimizer : torch.optim.Optimizer):
        """Run one training epoch over ``dataloader``, printing the loss every 100 batches."""
        size = len(dataloader.dataset)
        device = self._device()
        self.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            loss = self._loss(self(X), y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                current = (batch + 1) * len(X)
                print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

    def test(self, dataloader : DataLoader):
        """Evaluate on ``dataloader``, print the average loss and accuracy, and return them.

        Returns:
            ``(avg_loss, accuracy)``: ``avg_loss`` is the mean of the per-batch losses
            and ``accuracy`` the fraction of correctly classified samples. Both are
            ``0.0`` for an empty dataloader instead of raising ZeroDivisionError.
        """
        num_batches = len(dataloader)
        num_samples = len(dataloader.dataset)
        device = self._device()
        self.eval()
        test_loss = 0.0
        num_right = 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                y_pred = self(X)
                test_loss += self._loss(y_pred, y).item()
                num_right += (y_pred.argmax(dim=1) == y).sum().item()
        test_loss = test_loss / num_batches if num_batches else 0.0
        accuracy = num_right / num_samples if num_samples else 0.0
        print(f"Test Error: Avg loss: {test_loss:>8f} Accuracy: {accuracy:>8f} \n")
        return test_loss, accuracy

In [None]:
# Train Classifier
# Seed torch's global RNG before building the model so weight initialisation (and
# any DataLoader shuffling that draws from the default generator) is reproducible
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

size = torch.Size((1, 28, 28))  # per-sample input shape (C, H, W)
num_classes = 10
classifier = Classifier(size, [], num_classes, nn.ReLU(), nn.CrossEntropyLoss()).to(device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    classifier.train_network(train_dataloader, optimizer)
    classifier.test(test_dataloader)
print("Done!")

In [None]:
def plot_classified_first_samples(dataloader, model, rows: int = 8, columns: int = 8):
    """Plot the first batch of ``dataloader`` titled with the model's predicted class.

    Correct predictions get a green title. Wrong ones get a red title with the
    true class on a second line. At most ``rows * columns`` samples are drawn;
    previously a batch larger than the 8x8 grid made ``add_subplot`` fail.

    The forward pass runs in eval mode under ``torch.no_grad()`` on the device
    of the model's parameters. The model's previous train/eval mode is restored
    afterwards.

    Args:
        dataloader: iterable of ``(images, labels)`` batches.
        model: module mapping a batch of images to ``(batch, num_classes)`` logits.
        rows: number of grid rows.
        columns: number of grid columns.
    """
    fig = plt.figure(figsize=(10, 10), layout="constrained")
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    images, labels = first_batch

    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = torch.argmax(model(images.to(model_device)), dim=1).to('cpu')
    finally:
        model.train(was_training)

    n_cells = rows * columns
    for i, (image, true_label, predicted_label) in enumerate(zip(images, labels, predictions)):
        if i >= n_cells:  # grid is full; extra samples have no subplot slot
            break
        correct = bool(predicted_label == true_label)
        title = output_label(predicted_label.item())
        if not correct:
            title += "\nCorrect: " + output_label(true_label)
        ax = fig.add_subplot(rows, columns, i + 1)
        ax.axis('off')
        ax.set_title(title, color="green" if correct else "red", size="small")
        ax.imshow(image[0].detach().cpu().numpy(), cmap='gray')

plot_classified_first_samples(test_dataloader, classifier)
plt.show()