
# Vision Transformer for MNIST and CIFAR10

A simplified from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (refer to model.py).

- A scaled-down version of the original ViT architecture from [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929).
- Has only 200k-800k parameters, depending on the embedding dimension (the original ViT-Base has 86 million).
- Works with small datasets by using a smaller patch size of 4 (see the sketch after this list).
- Supported datasets: MNIST, FashionMNIST, SVHN, and CIFAR10.
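The pieces above fit together roughly as in the following minimal sketch (the `MiniViT` name and the use of `nn.TransformerEncoder` are illustrative assumptions; the repository's actual implementation is in model.py):

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    """Illustrative scaled-down ViT; see model.py for the repository's real model."""
    def __init__(self, n_channels=1, image_size=28, patch_size=4, embed_dim=64,
                 n_layers=6, n_heads=4, forward_mul=2, dropout=0.1, n_classes=10):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2  # 49 for 28x28, 64 for 32x32
        # Patch embedding as a strided convolution: one step per 4x4 patch.
        self.patch_embed = nn.Conv2d(n_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads,
            dim_feedforward=forward_mul * embed_dim,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, n_patches, embed_dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)      # prepend a learned [CLS] token
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)
        return self.head(x[:, 0])                           # classify from the [CLS] token

# MiniViT()(torch.randn(2, 1, 28, 28)) -> logits of shape (2, 10)
```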



## Run Commands

Commands for each dataset (also available in `scripts.sh`):

| Dataset | Run command | Test Acc |
| --- | --- | --- |
| MNIST | `python main.py --dataset mnist --epochs 100` | 99.5 |
| Fashion MNIST | `python main.py --dataset fmnist` | 92.3 |
| SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128` | 96.2 |
| CIFAR10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128` | 82.5 (86.3 with RandAug) |
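The flags above correspond to the two configurations in the next section. A minimal sketch of how such a parser might look (defaults are assumptions inferred from the MNIST row; the real parser and any extra flags live in main.py):

```python
import argparse

def get_args():
    # Hypothetical defaults matching the MNIST command above; see main.py for the real parser.
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", default="mnist",
                   choices=["mnist", "fmnist", "svhn", "cifar10"])
    p.add_argument("--n_channels", type=int, default=1)   # 3 for SVHN/CIFAR10
    p.add_argument("--image_size", type=int, default=28)  # 32 for SVHN/CIFAR10
    p.add_argument("--embed_dim", type=int, default=64)   # 128 for SVHN/CIFAR10
    p.add_argument("--epochs", type=int, default=100)
    return p.parse_args()
```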



## Transformer Config

| Config | MNIST and FMNIST | SVHN and CIFAR10 |
| --- | --- | --- |
| Input Size | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch Size | 4 | 4 |
| Sequence Length | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Num. of Layers | 6 | 6 |
| Num. of Heads | 4 | 4 |
| Forward Multiplier | 2 | 2 |
| Dropout | 0.1 | 0.1 |
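The sequence lengths and parameter counts above can be sanity-checked with a back-of-the-envelope calculation (assuming standard transformer encoder blocks, as in the sketch earlier; exact totals depend on model.py):

```python
def encoder_params(embed_dim, n_layers=6, forward_mul=2):
    # Rough per-block count for a standard transformer encoder layer.
    attn = 4 * (embed_dim**2 + embed_dim)                  # Q, K, V, output projections (+ biases)
    mlp = 2 * forward_mul * embed_dim**2 + (forward_mul + 1) * embed_dim  # two linears (+ biases)
    norms = 4 * embed_dim                                  # two LayerNorms per block
    return n_layers * (attn + mlp + norms)

for image_size, embed_dim in [(28, 64), (32, 128)]:
    seq_len = (image_size // 4) ** 2                       # 7x7 = 49 and 8x8 = 64, as in the table
    print(f"seq_len={seq_len}, encoder params ~{encoder_params(embed_dim):,}")
# Prints ~200,832 and ~794,880; patch/positional embeddings and the
# classifier head bring the totals to roughly the table's 210k / 820k.
```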



## Training Graphs

Accuracy and loss curves for each dataset (MNIST, FMNIST, SVHN, and CIFAR10) are included as images in the repository.