
PyTorch ViT

License: MIT

This repo contains my implementation of the ViT paper - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - in PyTorch from scratch (for educational purposes). Here I am using MNIST as a toy dataset so that anyone with low hardware specs can train the transformer in a few minutes :). You can view the specs I used to train this transformer here - Hardware Requirements

Since this project is also highly inspired by mrdbourke's paper replication project from his PyTorch course, I have also added his dataset to train the model on.

What is ViT?


Figure 1. Vision Transformer Architecture

ViT is short for Vision Transformer, a model introduced by Dosovitskiy et al. in 2020. The paper, An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, uses transformers for computer vision tasks like classification.

Previously, transformers were mostly used in the NLP domain, in models such as BERT, RoBERTa, XLNet, and GPT-2. But due to their powerful attention mechanism, there was a lot of interest in using transformers for computer vision tasks as well.

1. Input Embeddings

One of the main challenges in using transformers for computer vision is the sheer computational cost of using image pixels directly as input embeddings. Instead, we first split the image into multiple patches and run them through a convolution layer to reduce their dimensionality and generate the embeddings.

Original image → Patches

The convolution layer also learns through backpropagation to generate meaningful features from the image to be used as embeddings.
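
As a rough sketch of this step (not necessarily the exact code in this repo; the patch_size and embed_dim values below are just illustrative for 28x28 MNIST images), a single nn.Conv2d with kernel size and stride equal to the patch size does the splitting and projection in one operation:

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into patches and project each patch to an embedding vector."""

    def __init__(self, in_channels=1, patch_size=7, embed_dim=64):
        super().__init__()
        # A conv with kernel_size == stride == patch_size slices the image into
        # non-overlapping patches and linearly projects each one in a single op.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (batch, channels, H, W) -> (batch, embed_dim, H/patch, W/patch)
        x = self.proj(x)
        # Flatten the spatial grid into a sequence of patch tokens:
        # (batch, embed_dim, n_patches) -> (batch, n_patches, embed_dim)
        return x.flatten(2).transpose(1, 2)

# e.g. a 28x28 MNIST image with patch_size=7 gives 16 patch tokens
tokens = PatchEmbedding()(torch.randn(1, 1, 28, 28))
print(tokens.shape)  # torch.Size([1, 16, 64])
```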

After generating the patch embeddings, we prepend a learnable class token and add position embeddings to the sequence.
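
A minimal sketch of that step, assuming the token shape from the patch-embedding sketch above (variable names are illustrative, not the repo's actual code):

```python
import torch
import torch.nn as nn

batch, n_patches, embed_dim = 1, 16, 64
patch_tokens = torch.randn(batch, n_patches, embed_dim)

# Learnable class token and learnable position embeddings (one per token,
# including the class token), as in the ViT paper.
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim))

# Prepend the class token to every sequence in the batch, then add positions.
cls = cls_token.expand(batch, -1, -1)       # (batch, 1, embed_dim)
x = torch.cat([cls, patch_tokens], dim=1)   # (batch, n_patches + 1, embed_dim)
x = x + pos_embed                           # broadcast over the batch
print(x.shape)  # torch.Size([1, 17, 64])
```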


Figure 2: Position embeddings of models trained with different hyperparameters.

2. Transformer Encoder

Once we have the embedded patches (see Figure 1), we send them through the Transformer Encoder.


Figure 3. Transformer Encoder

As you can probably see from the figure, the Transformer Encoder does a lot of things, but they are fairly easy to understand and implement given enough time. I have shared some helpful resources in the Acknowledgments section to help you understand the Transformer Encoder.

Each encoder block outputs the same dimensions as its input, so we can easily stack blocks together to capture more information and generalize better.
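
For intuition, here is a compact pre-norm encoder block in the spirit of Figure 3, built around PyTorch's nn.MultiheadAttention; the repo's own implementation may differ in details such as layer sizes and dropout:

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One ViT encoder block: LayerNorm -> MSA -> residual, LayerNorm -> MLP -> residual."""

    def __init__(self, embed_dim=64, n_heads=4, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, n_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_ratio * embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_ratio * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Multi-head self-attention with a residual connection
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # Position-wise MLP with a residual connection
        x = x + self.mlp(self.norm2(x))
        return x

# Same input/output shape, so blocks can be stacked with nn.Sequential
x = torch.randn(1, 17, 64)
print(EncoderBlock()(x).shape)  # torch.Size([1, 17, 64])
```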

3. MLP Head

The MLP head takes the output of the Transformer Encoder and generates the final classification output. It is simply a combination of a layer normalization and a fully connected layer.
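
A minimal sketch, assuming 10 MNIST classes and the illustrative embed_dim used in the sketches above; only the class-token output (index 0) is fed to the head:

```python
import torch
import torch.nn as nn

embed_dim, n_classes = 64, 10  # 10 classes for MNIST

mlp_head = nn.Sequential(
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, n_classes),
)

encoder_out = torch.randn(1, 17, embed_dim)   # (batch, 1 + n_patches, embed_dim)
logits = mlp_head(encoder_out[:, 0])          # classify using the class token only
print(logits.shape)  # torch.Size([1, 10])
```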

Setup

  1. Clone the repo - git clone https://github.com/Shubhamai/pytorch-vit/
  2. Create a new conda environment using conda env create --prefix env python=3.7.13 --file=environment.yml
  3. Activate the conda environment using conda activate ./env

Usage

Training

You can start training by running python train.py. You can view the default parameters by running python train.py -h. There are also config files available in the configs folder; you can use them to train the model by running, for example, python train.py --config_path configs/mnist.yaml. The arguments passed on the CLI will be overridden by the config file.

Here is roughly what the script does -

  • Downloads the MNIST dataset automatically and saves it in the data directory.
  • Trains the ViT model on the MNIST train set with the parameters passed on the CLI.
  • The training metrics are saved in the experiments/results directory with the corresponding dataset name.
  • The model is automatically saved in the experiments/models/ directory with the corresponding dataset name.


Training Results from 2 epochs on MNIST train set

Testing

To test the model, simply run python test.py. You can view the default parameters by running python test.py -h. There are also config files available in the configs folder; you can use them to test the model by running, for example, python test.py --config_path configs/mnist.yaml. The arguments passed on the CLI will be overridden by the config file.

Here is roughly what the script does -

  • The model is automatically loaded from the experiments/models/ directory with the corresponding dataset name.
  • Tests the model on the MNIST test set with the parameters passed on the CLI.
  • The test metrics are saved in the experiments/results directory with the corresponding dataset name.


Testing Results on MNIST test set

Notes

  1. The hidden dim in the MLP block is calculated by simply multiplying the input dim by a factor of 4, following a similar pattern observed in Table 1 of the ViT paper (see the worked example below).
  2. The hidden dim in the Attention Head for the feed-forward layers for query, key, and value is calculated as embed_dim // n_heads. I am still unsure why we do it this way.
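
As a worked example of these two notes, using the ViT-Base numbers from Table 1 of the paper (hidden size 768, 12 heads); the comment on why embed_dim is split across heads is a common explanation I am adding here, not something stated in this repo:

```python
# Worked example with the ViT-Base numbers from Table 1 of the paper
embed_dim = 768          # hidden size D
n_heads = 12

mlp_hidden_dim = 4 * embed_dim      # 3072, matching the paper's MLP size
head_dim = embed_dim // n_heads     # 64: each head attends in a smaller subspace

# Splitting embed_dim across heads keeps the concatenated output of all heads
# at embed_dim, so the block's size stays roughly constant as n_heads changes.
assert n_heads * head_dim == embed_dim
print(mlp_hidden_dim, head_dim)  # 3072 64
```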

FAQ 4 myself :)

What's the point of value, key, and query in Multi-Head Attention? [This video](https://youtu.be/H-4bmOxiKyU) from Alex-AI helped me a lot to understand this.
Why are position embeddings simply added to the embeddings? I found [this video](https://youtu.be/M2ToEXF6Olw) from AI Coffee Break with Letitia very helpful to understand this.
Why is a class token added to the embeddings? I found [this answer](https://datascience.stackexchange.com/a/110637) on the Data Science Stack Exchange to be the most satisfying.
Why the 0th index for the MLP Head input? One realization is that the 0th index is actually the class token embedding we concatenated when creating the input embeddings. The question `Why is a class token added to the embeddings?` answers the rest.

Hardware Requirements

The hardware I used to run the model on the default settings is -

  • Ryzen 5 4600H
  • NVIDIA GeForce GTX 1660Ti - 6 GB VRAM
  • 12 GB RAM

It took around 2 minutes per epoch on my machine. Since Google Colab has similar hardware in terms of compute power from what I understand, it should run just fine on Colab :)

Acknowledgments

I found these resources helpful in creating this project:

Citation

@misc{https://doi.org/10.48550/arxiv.2010.11929,
  doi = {10.48550/ARXIV.2010.11929},
  url = {https://arxiv.org/abs/2010.11929},
  author = {Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  publisher = {arXiv},
  year = {2020},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
