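# MNIST Pretrained Vision Transformer (ViT) Model

Evaluates a Vision Transformer (ViT) classifier on the MNIST test split and reports its mean cross-entropy loss and accuracy.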

In [1]:
!pip3 install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, Dropout, Dense, Flatten

import tensorflow as tf
from vit import ViT

### Global Variables



In [3]:
# Input geometry: square 28x28 images with a single channel.
IMAGE_SIZE = 28
IMAGE_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 1)

# Model settings passed to the ViT constructor.
PATCH_SIZE = 4
NUM_CLASSES = 10  # digits 0-9

# Evaluation data settings.
DATASET_SIZE = 10_000  # size of the MNIST `test` split loaded below
# The whole split is scored as a single batch.
# NOTE(review): with a dim-1024 ViT this may exhaust accelerator memory; lower it if evaluation OOMs.
BATCH_SIZE = DATASET_SIZE

### Load the MNIST test split

In [4]:
# Only the held-out `test` split is needed for evaluation.
# as_supervised=True makes each element an (image, label) pair, which is what normalize_img expects.
# with_info=True also returns the dataset metadata.
test_ds, test_info = tfds.load(
    name='mnist',
    split='test',
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

### Preprocess Data

In [5]:
def normalize_img(image, label):
  """Cast `image` to float32 and divide by 255; return `label` unchanged.

  Used as a `tf.data` map function over `(image, label)` pairs. Assumes 8-bit
  pixel values, so the scaled image lies in [0, 1]. TODO: confirm for other inputs.
  """
  image_f32 = tf.cast(image, tf.float32)
  return image_f32 / 255., label

def preprocess_ds(ds):
  """Turn a dataset of `(image, label)` pairs into a batched evaluation pipeline.

  Pipeline steps:
  1. Normalize pixels with `normalize_img` (mapped in parallel).
  2. Cache the normalized examples. Caching happens before batching, so the cache holds single examples.
  3. Group the examples into batches of `BATCH_SIZE`.
  4. Prefetch batches ahead of consumption.
  """
  return (
      ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
      .cache()
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
  )

In [6]:
# Rebinds `test_ds` to the batched pipeline.
# preprocess_ds batches its input, so re-running this cell without re-running the load cell would batch a second time.
test_ds = preprocess_ds(ds=test_ds)

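### Load Pre-trained Model

> **TODO:** the cell below only constructs the ViT architecture. No pre-trained weights are restored anywhere in this notebook, so until weights are loaded the evaluation measures a freshly constructed model.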

In [7]:
# Instantiate the ViT architecture for the configured MNIST input geometry.
# TODO(review): nothing here restores pre-trained weights. Unless `ViT(...)` loads them
# internally (not visible here), this is a freshly constructed model; load the saved
# weights before evaluating.
vit = ViT(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    dim=1024,
    depth=6,
    heads=16,
    # NOTE(review): mlp_dim equal to NUM_CLASSES (10) is unusually small next to dim=1024.
    # Confirm it matches the configuration the weights were trained with.
    mlp_dim=NUM_CLASSES,
    dropout=0.1,
    emb_dropout=0.1,
)

### Evaluate Model


In [8]:
def evaluate_vit(vit, test_ds, from_logits=False):
  """Compute the per-example mean loss and accuracy of `vit` over all of `test_ds`.

  Args:
    vit: callable mapping a batch of images to per-class scores (called as `vit(X)`).
    test_ds: iterable of `(images, labels)` batches with integer class labels,
      e.g. the output of `preprocess_ds`.
    from_logits: pass True if `vit` returns unnormalized logits rather than
      probabilities. NOTE(review): defaults to False to match a plain
      `SparseCategoricalCrossentropy()`; confirm what the ViT head outputs.

  Returns:
    `(avg_loss, avg_accuracy)` as scalar tensors. Both are averaged over examples,
    so they do not depend on the batch size or on a hard-coded dataset size.
    Both are 0 for an empty dataset.
  """
  # Fresh metric objects per call. Keras metrics are stateful, so a shared instance
  # would carry counts over from earlier evaluations. Local objects also cannot be
  # clobbered by a cell that rebinds module-level names such as `loss`/`accuracy`.
  loss_mean = tf.keras.metrics.Mean()
  accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()

  for X, y in test_ds:
    pred = vit(X)
    # Keras losses and metrics take (y_true, y_pred).
    # Per-example losses let `Mean` weight every example equally, including a smaller final batch.
    per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
        y, pred, from_logits=from_logits)
    loss_mean.update_state(per_example_loss)
    accuracy_metric.update_state(y, pred)

  return loss_mean.result(), accuracy_metric.result()


In [None]:
# Use result names that cannot collide with the loss/metric objects used by the
# evaluation, so this cell can be re-run safely.
test_loss, test_accuracy = evaluate_vit(vit, test_ds)
{'loss': float(test_loss), 'accuracy': float(test_accuracy)}