In [None]:
# Fetch the project sources; git creates the "ai-explained-vit-from-scratch"
# directory that the next cell changes into.
!git clone https://github.com/tanle8/ai-explained-vit-from-scratch.git


In [None]:
%cd ai-explained-vit-from-scratch

!pip install -q -r requirements.txt

In [None]:
import os
import json
import torch
from vit import ViTForClassfication

from utils import visualize_images, visualize_attention, load_experiment
import matplotlib.pyplot as plt

In [None]:
# Show the GPU visible to this runtime; the training cell below is launched
# with --device cuda, so a GPU needs to be available here.
!nvidia-smi


In [None]:
# Launch training as a separate process. The --exp-name value is the same
# string later passed to load_experiment(...), so the two must stay in sync.
# NOTE(review): the experiment name says "ep100" but --epochs is 150 —
# confirm which one is intended and align the name or the flag.
!python train.py --exp-name "vit_bs256_ep100_run_1" --batch-size 256 --epochs 150 --lr 1e-2 --device cuda


## Results

In [None]:
# Show some training images (helper imported from utils, called with its
# default arguments).
visualize_images()

In [None]:
# Load the experiment by the same name that was passed to train.py via
# --exp-name. The returned tuple is unpacked into the config, the model, and
# the loss / accuracy series that the cells below plot and save.
config, model, train_losses, test_losses, accuracies = load_experiment("vit_bs256_ep100_run_1")


In [None]:

# One figure, two side-by-side panels: losses on the left, accuracy on the right
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: train and test loss curves, told apart by the legend
ax1.plot(train_losses, label="Train loss")
ax1.plot(test_losses, label="Test loss")
ax1.set(xlabel="Epoch", ylabel="Loss")
ax1.legend()

# Right panel: accuracy curve
ax2.plot(accuracies)
ax2.set(xlabel="Epoch", ylabel="Accuracy")

# Write the figure to disk first, then display it inline
fig.savefig("metrics.png")
plt.show()

In [None]:
# Visualize attention for the loaded model. "attention.png" is passed as the
# second argument — presumably the output image path; confirm in utils.
visualize_attention(model, "attention.png")

## Save results to Google Drive (optional)

In [None]:
# Mount Google Drive at /content/drive so the cells below can write into it.
# This import only works on Google Colab, which is why it lives in this
# optional section instead of the main import cell.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set a folder inside your Google Drive to store experiments.
# exist_ok=True makes this cell safe to re-run once the folder exists.
DRIVE_BASE_DIR = "/content/drive/MyDrive/vit_experiments"
os.makedirs(DRIVE_BASE_DIR, exist_ok=True)

In [None]:
# Save attention visualization to Google Drive.
# The earlier visualize_attention(model, "attention.png") call was given
# "attention.png" as its output name (presumably written to the current
# working directory — confirm in utils). The file must be copied explicitly:
# building the destination path alone does not put anything on Drive.
import shutil

attention_path = os.path.join(DRIVE_BASE_DIR, "attention.png")
# Raises FileNotFoundError if the figure was never generated, so the success
# message below is only printed after a real copy.
shutil.copy("attention.png", attention_path)
print(f"Attention visualization saved to {attention_path}")

In [None]:
# Destination for the trained weights inside the Drive experiment folder
model_path = os.path.join(DRIVE_BASE_DIR, "vit_final_model.pt")

# Save model weights only (the state_dict), not the pickled model object;
# reloading requires constructing the model first and then loading this dict.
torch.save(model.state_dict(), model_path)

print(f"Final model saved to {model_path}")

In [None]:
# Persist the experiment config and the recorded metrics to Drive as
# pretty-printed JSON with sorted keys.
config_path = os.path.join(DRIVE_BASE_DIR, "config.json")
metrics_json_path = os.path.join(DRIVE_BASE_DIR, "metrics.json")

# Bundle the three metric series under the keys used in metrics.json
metrics_data = dict(
    train_losses=train_losses,
    test_losses=test_losses,
    accuracies=accuracies,
)

# Same dump settings for both files; config is written first, then metrics
for target_path, payload in ((config_path, config), (metrics_json_path, metrics_data)):
    with open(target_path, 'w') as f:
        json.dump(payload, f, sort_keys=True, indent=4)

print(f"Config and metrics saved to {DRIVE_BASE_DIR}")