In [40]:
!pip install torchmetrics
!pip install torchvision

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [41]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall

In [42]:
# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

In [43]:
# Initializing important variables
num_classes = len(train_data.classes)
num_input_channels = 1
num_output_channels = 16
image_size = train_data[0][0].shape[1]

# Training data loader
dataloader_train = DataLoader(
    train_data,
    batch_size = 10,
    shuffle = True,
)

# Testing data loader
dataloader_test = DataLoader(
    test_data,
    batch_size = 10,
    shuffle = False,
)

In [44]:
# Defining the neural network
class Net(nn.Module):
    def __init__(self, num_classes):
        super(Net, self).__init__()
        # Define feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(num_input_channels, num_output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
        )
        # Define classifier
        self.classifier = nn.Linear(num_output_channels * (image_size//2)**2, num_classes)
    
    def forward(self, x):  
        # Pass input through feature extractor and classifier
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x

In [45]:
# Define the model
net = Net(num_classes=num_classes)
# Define the loss function
criterion = nn.CrossEntropyLoss()
# Define the optimizer
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(2):
    running_loss = 0.0
    # Iterate over training batches
    for images, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    epoch_loss = running_loss / len(dataloader_train)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

Epoch 1, Loss: 0.3992
Epoch 2, Loss: 0.2935


In [46]:
# Define the metrics
accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric = Precision(task='multiclass', num_classes=num_classes, average=None)
recall_metric = Recall(task='multiclass', num_classes=num_classes, average=None)

# Run model on test set
net.eval()
predictions = []
for i, (features, labels) in enumerate(dataloader_test):
    output = net.forward(features.reshape(-1, 1, image_size, image_size))
    cat = torch.argmax(output, dim=-1)
    predictions.extend(cat.tolist())
    accuracy_metric(cat, labels)
    precision_metric(cat, labels)
    recall_metric(cat, labels)

# Compute the metrics
accuracy = accuracy_metric.compute().item()
precision = precision_metric.compute().tolist()
recall = recall_metric.compute().tolist()
print('Accuracy:', accuracy)
print('Precision (per class):')
for i, prec in enumerate(precision, start=1):
    print(f'    Class {i}: {prec:.16f}')

print('Recall (per class):')
for i, rec in enumerate(recall, start=1):
    print(f'    Class {i}: {rec:.16f}')

Accuracy: 0.8902000188827515
Precision (per class):
    Class 1: 0.8339843750000000
    Class 2: 0.9888211488723755
    Class 3: 0.8101145029067993
    Class 4: 0.9043210148811340
    Class 5: 0.8162055611610413
    Class 6: 0.9563530683517456
    Class 7: 0.7027027010917664
    Class 8: 0.9732906222343445
    Class 9: 0.9838056564331055
    Class 10: 0.9357622265815735
Recall (per class):
    Class 1: 0.8539999723434448
    Class 2: 0.9729999899864197
    Class 3: 0.8489999771118164
    Class 4: 0.8790000081062317
    Class 5: 0.8259999752044678
    Class 6: 0.9860000014305115
    Class 7: 0.6759999990463257
    Class 8: 0.9110000133514404
    Class 9: 0.9720000028610229
    Class 10: 0.9760000109672546
