In [None]:
import copy
import random

import matplotlib.pyplot as plt
import torch
import torchvision

In [None]:
# Load a pre-trained MobileNetV3 CNN and its associated transform.
# The weights enum supplies both the pre-trained parameters and the preprocessing
# pipeline that matches them, so the dataset cells can pass `transform` straight
# to STL10 and feed the network inputs prepared the way it expects.
# Requires torchvision >= 0.13 (multi-weight API).
weights = torchvision.models.MobileNet_V3_Large_Weights.DEFAULT
neural_net = torchvision.models.mobilenet_v3_large(weights=weights)
transform = weights.transforms()

In [None]:
# Overwrite the final layer (neural_net.classifier[3]) with a linear layer having
# 10 outputs, one per STL10 class. Reading in_features from the existing layer keeps
# this correct for whichever MobileNetV3 variant was loaded above.
neural_net.classifier[3] = torch.nn.Linear(neural_net.classifier[3].in_features, 10)

In [None]:
# Training parameters.
# Adam optimizes every parameter of the network (backbone and replaced head alike).
# This cell must run AFTER the head replacement: neural_net.parameters() is
# evaluated here, so the optimizer only holds layers that exist at this moment.
learning_rate = 1e-3
weight_decay = 1e-6
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    neural_net.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)


In [None]:
# Load the labelled STL10 training split and hold out 20 % of it for validation.
stl10_dataset = torchvision.datasets.STL10(root='./', split='train', transform=transform,
                                           download=True)
# A seeded generator makes the train/validation split identical across kernel restarts.
split_seed = 0
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, validation_dataset = torch.utils.data.random_split(stl10_dataset, [0.8, 0.2],
                                                                  generator=split_generator)
batch_size = 16
# Reshuffle the training samples every epoch so mini-batches differ between epochs.
# Validation order does not affect the metrics, so it stays fixed.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size,
                                                    shuffle=False)
# List of human-readable classes, indexed by the dataset's integer class labels
stl10_classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']

In [None]:
# Display a grid of random training images, each titled with its class name.
number_of_rows = 6
number_of_columns = 4
fig, axs = plt.subplots(number_of_rows, number_of_columns)
for row in range(number_of_rows):
    for col in range(number_of_columns):
        index = random.randrange(len(stl10_dataset))  # Random index in [0, len - 1]
        img_tsr, class_ndx = stl10_dataset[index]  # Get the image tensor and the target class index
        img_tsr = torch.moveaxis(img_tsr, 0, 2)  # (C, H, W) -> (H, W, C), the layout imshow expects
        # Rescale to [0, 1] for display; the transformed tensor is not guaranteed to lie in
        # that range. clamp_min avoids a division by zero for a constant image.
        min_val = torch.min(img_tsr)
        max_val = torch.max(img_tsr)
        img_tsr = (img_tsr - min_val) / (max_val - min_val).clamp_min(1e-12)
        axs[row, col].imshow(img_tsr.numpy())
        axs[row, col].set_title(stl10_classes[class_ndx], fontsize=8)
        axs[row, col].axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Record statistics
epochs = []
train_losses = []
validation_losses = []
accuracies = []
number_of_epochs = 4
highest_validation_accuracy = 0.0
champion_neural_net = None

def numberOfCorrectPredictions(predictions_tsr, target_class_tsr):
    """Count the rows whose highest-scoring class equals the target class.

    Args:
        predictions_tsr: per-class scores; the class axis is dim 1,
            e.g. shape (batch, number_of_classes).
        target_class_tsr: integer class indices, expected shape (batch,).

    Returns:
        The number of correct predictions as a Python int (0 for an empty batch).
    """
    # Tensor .sum() stays vectorized. The builtin sum() iterated element by element
    # and, for an empty batch, returned a plain int 0 that has no .item().
    return (torch.argmax(predictions_tsr, dim=1) == target_class_tsr).sum().item()

In [None]:
# The training loop.
# Each epoch makes one gradient-update pass over the training split, then one
# gradient-free pass over the validation split to measure loss and accuracy.
# The model with the best validation accuracy is saved to disk and kept in memory.
for epoch in range(1, number_of_epochs + 1):
    # Set the neural network to training mode
    neural_net.train()
    running_loss = 0.0
    number_of_batches = 0
    for input_tsr, target_class_tsr in train_dataloader:
        optimizer.zero_grad()  # Gradients accumulate by default; clear the previous batch's
        output_tsr = neural_net(input_tsr)
        loss = criterion(output_tsr, target_class_tsr)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        number_of_batches += 1
        if number_of_batches % 10 == 1:
            print('.', flush=True, end='')
    average_training_loss = running_loss / number_of_batches

    # Evaluate with the validation dataset
    # Set the neural network to evaluation (inference) mode
    neural_net.eval()
    validation_running_loss = 0.0
    number_of_batches = 0
    number_of_correct_predictions = 0
    number_of_predictions = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for validation_input_tsr, validation_target_output_tsr in validation_dataloader:
            validation_output_tsr = neural_net(validation_input_tsr)
            validation_loss = criterion(validation_output_tsr, validation_target_output_tsr)
            validation_running_loss += validation_loss.item()
            number_of_correct_predictions += numberOfCorrectPredictions(validation_output_tsr,
                                                                        validation_target_output_tsr)
            number_of_predictions += validation_input_tsr.shape[0]
            number_of_batches += 1
    average_validation_loss = validation_running_loss / number_of_batches
    accuracy = number_of_correct_predictions / number_of_predictions
    print(
        f"Epoch {epoch}: average_training_loss = {average_training_loss}; average_validation_loss = {average_validation_loss}; accuracy = {accuracy}")
    epochs.append(epoch)
    train_losses.append(average_training_loss)
    validation_losses.append(average_validation_loss)
    accuracies.append(accuracy)

    if accuracy > highest_validation_accuracy:
        print(f" * * * * Champion! * * * * ")
        torch.save(neural_net.state_dict(), "./neural_net.pth")
        highest_validation_accuracy = accuracy
        # Snapshot the weights. A plain assignment would only alias neural_net, which keeps
        # training in later epochs, so the "champion" would silently become the last-epoch model.
        champion_neural_net = copy.deepcopy(neural_net)

In [None]:
# Display the evolution of the metrics: losses on the left y-axis and accuracy on a
# second y-axis, both sharing the epoch x-axis.
fig1, ax1 = plt.subplots()
ax1.set_title('Training metrics per epoch')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.set_xticks(epochs)  # Epochs are integers; avoid fractional tick labels
ax1.plot(epochs, train_losses, color='b', label='Training loss')
ax1.plot(epochs, validation_losses, color='r', label='Validation loss')
ax1.grid(True)
ax2 = ax1.twinx()  # Instantiate a second axis that shares the same x-axis
ax2.set_ylabel('accuracy', color='g')
ax2.tick_params(axis='y', labelcolor='g')
ax2.plot(epochs, accuracies, color='g', label='Accuracy')
# Each axis would draw its own legend, and the two can overlap. Merge all three
# series into one legend on the top axis so the lines do not cover it.
loss_handles, loss_labels = ax1.get_legend_handles_labels()
accuracy_handles, accuracy_labels = ax2.get_legend_handles_labels()
ax2.legend(loss_handles + accuracy_handles, loss_labels + accuracy_labels, loc='center right')
plt.show()

In [None]:
# Test the champion neural network.
# The test split goes through the same `transform` as training. download=False
# assumes the dataset files were already fetched by the training-data cell.
stl10_test_dataset = torchvision.datasets.STL10(
    root="./",
    split='test',
    download=False,
    transform=transform,
)

In [None]:
# Choose a random test index in [0, len - 1] ([0, 7999] for the STL10 test split),
# derived from the dataset length rather than hard-coded.
sample_ndx = random.randrange(len(stl10_test_dataset))
sample_test_tsr, class_ndx = stl10_test_dataset[sample_ndx]
champion_neural_net.eval()
with torch.no_grad():  # Inference only; no autograd graph is needed
    # unsqueeze(0) adds a batch dimension of size 1
    sample_output_tsr = champion_neural_net(sample_test_tsr.unsqueeze(0))
print(f"sample_output_tsr = \n{sample_output_tsr}")
predicted_class = torch.argmax(sample_output_tsr, dim=1).item()
print(f"predicted_class = {predicted_class} ({stl10_classes[predicted_class]}); True class = {class_ndx} ({stl10_classes[class_ndx]})")

In [None]:
# Show the tested sample. A separate variable keeps sample_test_tsr in (C, H, W)
# layout, so re-running this cell does not permute the axes a second time.
sample_display_tsr = torch.moveaxis(sample_test_tsr, 0, 2)  # (C, H, W) -> (H, W, C)
# Set the range to [0, 1]; clamp_min avoids a division by zero for a constant image
min_val = torch.min(sample_display_tsr)
max_val = torch.max(sample_display_tsr)
sample_display_tsr = (sample_display_tsr - min_val) / (max_val - min_val).clamp_min(1e-12)
plt.imshow(sample_display_tsr.numpy())
plt.title(f"predicted: {stl10_classes[predicted_class]}; true: {stl10_classes[class_ndx]}")
plt.axis('off')
plt.show()