In [1]:
import os
import pickle

import torch
from sklearn.metrics import accuracy_score

from dataset import ImageDataset, load_dataset, train_val_split, data_augmentation
from model import CNN, CNNWithNAL

# Ask PyTorch to release any cached, unoccupied GPU memory (e.g. left over from an
# earlier run in the same kernel) before the training cells allocate new tensors.
torch.cuda.empty_cache()

# CIFAR
## Load dataset

In [2]:
# Load the CIFAR train/test arrays and their labels from the .npz archive.
# The path is relative, so the notebook must be run from the project root.
training_data, training_labels, testing_data, testing_labels = load_dataset('datasets/CIFAR.npz')

## Preprocessing

In [3]:
# Hold out part of the training set for validation; the split ratio is whatever
# train_val_split uses by default, since no ratio is passed here.
training_subset, training_sub_labels, validation_subset, validation_sub_labels = train_val_split(training_data, training_labels)
# NOTE(review): aug_training / aug_labels are not referenced by any visible cell —
# train_dataset below is built from the un-augmented training_subset. Confirm whether
# training was meant to use the augmented arrays, or drop this call if it is not needed.
aug_training, aug_labels = data_augmentation(training_subset, training_sub_labels)

# Wrap each split in an ImageDataset; these three objects are what the training
# and evaluation cells consume.
train_dataset = ImageDataset(training_subset, training_sub_labels)
val_dataset = ImageDataset(validation_subset, validation_sub_labels)
test_dataset = ImageDataset(testing_data, testing_labels)

## CNN with Noise Adaptation Layer Training

In [4]:
# Train the CNN with a noise adaptation layer for several independent rounds and
# cache the per-round predictions/accuracies on disk. When the prediction cache
# already exists, reload it instead of retraining so that `prediction_results`
# and `accuracy_results` are defined either way (Restart & Run All safe).
PRED_RESULTS_PATH = 'results/cnnwithnal_CIFAR_pred_results.pkl'
ACC_RESULTS_PATH = 'results/cnnwithnal_CIFAR_acc_results.pkl'
NUM_ROUNDS = 10

if not os.path.exists(PRED_RESULTS_PATH):
    # Create the output directory up front: a missing folder would otherwise only
    # surface as an error after every round has finished training.
    os.makedirs(os.path.dirname(PRED_RESULTS_PATH), exist_ok=True)

    prediction_results = []
    accuracy_results = []
    # `round_idx` rather than `round`: the loop variable lives in the notebook's
    # global namespace and would shadow the builtin round() for all later cells.
    for round_idx in range(NUM_ROUNDS):
        print(f"----------Training CNN round {round_idx+1}/{NUM_ROUNDS}----------")
        cnn = CNNWithNAL(num_classes=3, dataset_name="CIFAR", batch_size=256, learning_rate=0.0005, patience=5)
        cnn.train(train_dataset, val_dataset)
        y_true, y_pred = cnn.predict(test_dataset)
        prediction_results.append((y_true, y_pred))
        accuracy = accuracy_score(y_true, y_pred)
        accuracy_results.append(accuracy)
        print(f"CNN Test Acc: {accuracy*100:.2f}%")

    with open(PRED_RESULTS_PATH, 'wb') as f:
        pickle.dump(prediction_results, f)

    with open(ACC_RESULTS_PATH, 'wb') as f:
        pickle.dump(accuracy_results, f)
else:
    # Cache hit: reload the stored predictions.
    # NOTE: pickle.load can execute arbitrary code — only load files produced by this notebook.
    with open(PRED_RESULTS_PATH, 'rb') as f:
        prediction_results = pickle.load(f)
    # Recompute accuracies from the predictions (same formula as the training branch)
    # so this branch does not depend on the separate accuracy pickle being present.
    accuracy_results = [accuracy_score(true_labels, pred_labels) for true_labels, pred_labels in prediction_results]

----------Training CNN round 1/10----------
Training noisy CNN to estimate NAL Layer params...
Epoch [1/100], Training Loss: 1.1693, Validation Loss: 1.1634, Validation Accuracy: 31.97%
Epoch [2/100], Training Loss: 1.0822, Validation Loss: 1.1325, Validation Accuracy: 35.03%
Epoch [3/100], Training Loss: 1.0452, Validation Loss: 1.1654, Validation Accuracy: 35.47%
Epoch [4/100], Training Loss: 0.9973, Validation Loss: 1.1948, Validation Accuracy: 35.33%
Epoch [5/100], Training Loss: 0.8859, Validation Loss: 1.4590, Validation Accuracy: 34.23%
Epoch [6/100], Training Loss: 0.7152, Validation Loss: 1.4837, Validation Accuracy: 36.10%
Epoch [7/100], Training Loss: 0.4795, Validation Loss: 2.1679, Validation Accuracy: 33.90%
No improvement for 5 epochs. Early stopping.
Epoch [1/100], Training Loss: 1.1339, Validation Loss: 1.1134, Validation Accuracy: 36.47%
Epoch [2/100], Training Loss: 1.0821, Validation Loss: 1.1014, Validation Accuracy: 37.50%
Epoch [3/100], Training Loss: 1.0550, Val

KeyboardInterrupt: 