In [None]:
import os
import pickle
import random

import torch
from sklearn.metrics import accuracy_score

from dataset import ImageDataset, load_dataset
from model import VisionTransformer

# Release unoccupied memory held by PyTorch's CUDA caching allocator (e.g. left
# over from models trained earlier in this kernel session).
torch.cuda.empty_cache()

# Seed Python's and PyTorch's RNGs once, up front, so torch-driven randomness in
# the training cells below is reproducible under Restart & Run All. Seeding once
# (not per round) still gives each of the training rounds a different draw.
# NOTE(review): numpy's RNG is not seeded here and CUDA kernels are not forced to
# be deterministic -- confirm whether `dataset`/`model` rely on either.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# FashionMNIST0.3
## Load Dataset

In [None]:
training_data, training_labels, testing_data, testing_labels = load_dataset('datasets/FashionMNIST0.3.npz')

# Hand-entered 3x3 transition matrix for the training split: 0.7 on the
# diagonal, the remaining 0.3 of each row placed on a single other class.
# NOTE(review): presumably T[i, j] = P(noisy label = j | clean label = i);
# confirm the orientation ImageDataset expects.
T = torch.tensor([[0.7, 0.3, 0.0],
                  [0.0, 0.7, 0.3],
                  [0.3, 0.0, 0.7]], dtype=torch.float32)

# Guard against typos in the matrix above: one row per class (3 classes), no
# negative entries, and every row a probability distribution.
assert T.shape == (3, 3), f"expected a 3x3 transition matrix, got {tuple(T.shape)}"
assert bool(torch.all(T >= 0)), "transition matrix has negative entries"
assert torch.allclose(T.sum(dim=1), torch.ones(T.shape[0])), "each row of T must sum to 1"

# Only the training split is given the transition matrix; the test split is built without one.
train_dataset = ImageDataset(training_data, training_labels, transition_matrix=T)
test_dataset = ImageDataset(testing_data, testing_labels)

## Model Training

In [None]:
# Train and evaluate a fresh VisionTransformer N_ROUNDS times on FashionMNIST0.3,
# then pickle the per-round (y_true, y_pred) pairs and test accuracies.
N_ROUNDS = 10
DATASET_NAME = "FashionMNIST0.3"
RESULTS_DIR = "results"

# Create the output folder before training, so the pickle writes at the end
# cannot fail on a missing directory after every round has already been trained.
os.makedirs(RESULTS_DIR, exist_ok=True)

prediction_results = []
accuracy_results = []
# Loop variable is `round_idx`, not `round`: a notebook-level `round` would
# shadow the builtin round() for every cell executed afterwards.
for round_idx in range(N_ROUNDS):
    print(f"----------Training VisionTransformer round {round_idx+1}/{N_ROUNDS}----------")
    vit = VisionTransformer(num_classes=3, dataset_name=DATASET_NAME)
    vit.train(train_dataset, round=round_idx)
    y_true, y_pred = vit.predict(test_dataset)
    prediction_results.append((y_true, y_pred))
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_results.append(accuracy)
    print(f"ViT Test Acc: {accuracy*100:.2f}%")

# File names use the lower-cased dataset name, e.g. vit_fashionmnist0.3_pred_results.pkl.
file_stem = DATASET_NAME.lower()
with open(os.path.join(RESULTS_DIR, f'vit_{file_stem}_pred_results.pkl'), 'wb') as f:
    pickle.dump(prediction_results, f)

with open(os.path.join(RESULTS_DIR, f'vit_{file_stem}_acc_results.pkl'), 'wb') as f:
    pickle.dump(accuracy_results, f)

# FashionMNIST0.6
## Load Dataset

In [None]:
training_data, training_labels, testing_data, testing_labels = load_dataset('datasets/FashionMNIST0.6.npz')

# Hand-entered 3x3 transition matrix for the training split: 0.4 on the
# diagonal, the remaining 0.6 of each row split evenly (0.3 / 0.3) over the
# other two classes.
# NOTE(review): presumably T[i, j] = P(noisy label = j | clean label = i);
# confirm the orientation ImageDataset expects.
T = torch.tensor([[0.4, 0.3, 0.3],
                  [0.3, 0.4, 0.3],
                  [0.3, 0.3, 0.4]], dtype=torch.float32)

# Guard against typos in the matrix above: one row per class (3 classes), no
# negative entries, and every row a probability distribution.
assert T.shape == (3, 3), f"expected a 3x3 transition matrix, got {tuple(T.shape)}"
assert bool(torch.all(T >= 0)), "transition matrix has negative entries"
assert torch.allclose(T.sum(dim=1), torch.ones(T.shape[0])), "each row of T must sum to 1"

# Only the training split is given the transition matrix; the test split is built without one.
train_dataset = ImageDataset(training_data, training_labels, transition_matrix=T)
test_dataset = ImageDataset(testing_data, testing_labels)

## Model Training

In [None]:
# Train and evaluate a fresh VisionTransformer N_ROUNDS times on FashionMNIST0.6,
# then pickle the per-round (y_true, y_pred) pairs and test accuracies.
N_ROUNDS = 10
DATASET_NAME = "FashionMNIST0.6"
RESULTS_DIR = "results"

# Create the output folder before training, so the pickle writes at the end
# cannot fail on a missing directory after every round has already been trained.
os.makedirs(RESULTS_DIR, exist_ok=True)

prediction_results = []
accuracy_results = []
# Loop variable is `round_idx`, not `round`: a notebook-level `round` would
# shadow the builtin round() for every cell executed afterwards.
for round_idx in range(N_ROUNDS):
    print(f"----------Training VisionTransformer round {round_idx+1}/{N_ROUNDS}----------")
    vit = VisionTransformer(num_classes=3, dataset_name=DATASET_NAME)
    vit.train(train_dataset, round=round_idx)
    y_true, y_pred = vit.predict(test_dataset)
    prediction_results.append((y_true, y_pred))
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_results.append(accuracy)
    print(f"ViT Test Acc: {accuracy*100:.2f}%")

# File names use the lower-cased dataset name, e.g. vit_fashionmnist0.6_pred_results.pkl.
file_stem = DATASET_NAME.lower()
with open(os.path.join(RESULTS_DIR, f'vit_{file_stem}_pred_results.pkl'), 'wb') as f:
    pickle.dump(prediction_results, f)

with open(os.path.join(RESULTS_DIR, f'vit_{file_stem}_acc_results.pkl'), 'wb') as f:
    pickle.dump(accuracy_results, f)

# Visualisation