In [7]:
from dataset import ImageDataset, load_dataset
from model import VisionTransformer
import torch
from sklearn.metrics import accuracy_score
import pickle

# Ask PyTorch's CUDA caching allocator to release cached memory it is not using
# before the training runs below (presumably to recover GPU memory left over
# from earlier runs in the same kernel — confirm this is still needed).
torch.cuda.empty_cache()

# FashionMNIST0.3
## Load Dataset

In [8]:
# FashionMNIST0.3: read the train/test arrays and labels from the .npz archive.
(
    training_data,
    training_labels,
    testing_data,
    testing_labels,
) = load_dataset('datasets/FashionMNIST0.3.npz')

# 3x3 matrix handed to the training dataset as `transition_matrix`: 0.7 on the
# diagonal and a single 0.3 off-diagonal entry in each row.
# NOTE(review): presumably T[i][j] is the probability that class i is observed
# with label j — confirm the row/column convention against ImageDataset.
T = torch.tensor(
    [
        [0.7, 0.3, 0.0],
        [0.0, 0.7, 0.3],
        [0.3, 0.0, 0.7],
    ],
    dtype=torch.float32,
)

# Only the training dataset receives T; the test dataset is built without it.
train_dataset = ImageDataset(training_data, training_labels, transition_matrix=T)
test_dataset = ImageDataset(testing_data, testing_labels)

In [9]:
from pathlib import Path

# Output locations, relative to the notebook's working directory.
# RESULTS_DIR receives the pickled predictions/accuracies written by the
# training cells; MODEL_DIR is not referenced in the visible cells
# (presumably used by the model code for checkpoints — confirm).
MODEL_DIR, RESULTS_DIR = Path("models"), Path("results")

# Create both directories up front; directories that already exist are accepted.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)


## Model Training

In [10]:
# Train and evaluate a fresh VisionTransformer NUM_ROUNDS times on
# FashionMNIST0.3, collecting each round's predictions and test accuracy.
NUM_ROUNDS = 10

prediction_results = []  # one (y_true, y_pred) pair per round
accuracy_results = []    # one accuracy_score value per round

# The loop variable is deliberately not named `round`: at notebook top level
# that would rebind the built-in round() for every cell executed afterwards.
for round_idx in range(NUM_ROUNDS):
    print(f"----------Training VisionTransformer round {round_idx + 1}/{NUM_ROUNDS}----------")
    # A new model instance is created for every round.
    vit = VisionTransformer(num_classes=3, dataset_name="FashionMNIST0.3")
    vit.train(train_dataset, round=round_idx)
    y_true, y_pred = vit.predict(test_dataset)
    prediction_results.append((y_true, y_pred))
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_results.append(accuracy)
    print(f"ViT Test Acc: {accuracy*100:.2f}%")

# Persist the collected results once all rounds have finished.
with open(RESULTS_DIR / 'vit_fashionmnist0.3_pred_results.pkl', 'wb') as f:
    pickle.dump(prediction_results, f)

with open(RESULTS_DIR / 'vit_fashionmnist0.3_acc_results.pkl', 'wb') as f:
    pickle.dump(accuracy_results, f)

----------Training VisionTransformer round 1/10----------
Epoch [1/100], Training Loss: 1.2220, Validation Loss: 1.1359, Validation Accuracy: 36.00%
Epoch [2/100], Training Loss: 1.0959, Validation Loss: 1.0877, Validation Accuracy: 42.72%
Epoch [3/100], Training Loss: 1.0749, Validation Loss: 1.0550, Validation Accuracy: 52.58%
Epoch [4/100], Training Loss: 1.0343, Validation Loss: 1.0221, Validation Accuracy: 50.64%
Epoch [5/100], Training Loss: 1.0257, Validation Loss: 1.0094, Validation Accuracy: 48.61%
Epoch [6/100], Training Loss: 1.0017, Validation Loss: 0.9956, Validation Accuracy: 46.08%
Epoch [7/100], Training Loss: 0.9976, Validation Loss: 1.0139, Validation Accuracy: 51.47%
Epoch [8/100], Training Loss: 0.9991, Validation Loss: 1.0073, Validation Accuracy: 46.31%
Epoch [9/100], Training Loss: 0.9836, Validation Loss: 1.0089, Validation Accuracy: 46.92%
Epoch [10/100], Training Loss: 0.9878, Validation Loss: 1.0011, Validation Accuracy: 47.08%
Epoch [11/100], Training Loss: 

# FashionMNIST0.6
## Load Dataset

In [11]:
# FashionMNIST0.6: read the train/test arrays and labels from the .npz archive.
# These assignments replace the FashionMNIST0.3 variables of the same names.
(
    training_data,
    training_labels,
    testing_data,
    testing_labels,
) = load_dataset('datasets/FashionMNIST0.6.npz')

# 3x3 matrix handed to the training dataset as `transition_matrix`: 0.4 on the
# diagonal and 0.3 in every off-diagonal entry.
# NOTE(review): presumably T[i][j] is the probability that class i is observed
# with label j — confirm the row/column convention against ImageDataset.
T = torch.tensor(
    [
        [0.4, 0.3, 0.3],
        [0.3, 0.4, 0.3],
        [0.3, 0.3, 0.4],
    ],
    dtype=torch.float32,
)

# Only the training dataset receives T; the test dataset is built without it.
train_dataset = ImageDataset(training_data, training_labels, transition_matrix=T)
test_dataset = ImageDataset(testing_data, testing_labels)

## Model Training

In [12]:
# Train and evaluate a fresh VisionTransformer NUM_ROUNDS times on
# FashionMNIST0.6, collecting each round's predictions and test accuracy.
NUM_ROUNDS = 10

prediction_results = []  # one (y_true, y_pred) pair per round
accuracy_results = []    # one accuracy_score value per round

# The loop variable is deliberately not named `round`: at notebook top level
# that would rebind the built-in round() for every cell executed afterwards.
for round_idx in range(NUM_ROUNDS):
    print(f"----------Training VisionTransformer round {round_idx + 1}/{NUM_ROUNDS}----------")
    # A new model instance is created for every round.
    vit = VisionTransformer(num_classes=3, dataset_name="FashionMNIST0.6")
    vit.train(train_dataset, round=round_idx)
    y_true, y_pred = vit.predict(test_dataset)
    prediction_results.append((y_true, y_pred))
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_results.append(accuracy)
    print(f"ViT Test Acc: {accuracy*100:.2f}%")

# Persist the collected results once all rounds have finished.
with open(RESULTS_DIR / 'vit_fashionmnist0.6_pred_results.pkl', 'wb') as f:
    pickle.dump(prediction_results, f)

with open(RESULTS_DIR / 'vit_fashionmnist0.6_acc_results.pkl', 'wb') as f:
    pickle.dump(accuracy_results, f)

----------Training VisionTransformer round 1/10----------
Epoch [1/100], Training Loss: 1.1085, Validation Loss: 1.1104, Validation Accuracy: 32.00%
Epoch [2/100], Training Loss: 1.1087, Validation Loss: 1.1104, Validation Accuracy: 32.00%
Epoch [3/100], Training Loss: 1.1088, Validation Loss: 1.1104, Validation Accuracy: 32.00%
Epoch [4/100], Training Loss: 1.1084, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [5/100], Training Loss: 1.1066, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [6/100], Training Loss: 1.1067, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [7/100], Training Loss: 1.1067, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [8/100], Training Loss: 1.1068, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [9/100], Training Loss: 1.1066, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [10/100], Training Loss: 1.1068, Validation Loss: 1.1055, Validation Accuracy: 34.50%
Epoch [11/100], Training Loss: 

# Visualisation