In [None]:
!pip install -U gdown -q

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from dataclasses import dataclass
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt

## Setup configuration

In [None]:
# Square input side length (pixels) the center crop produces, and the ViT patch size.
RESOLUTION, PATCH_SIZE = 224, 16

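## Data preprocessing

Every image is resized with bicubic interpolation to 256×256 and then center-cropped to 224×224. The original ViT models additionally expect pixels rescaled to [-1, 1]. DeiT and DINO instead expect per-channel mean/std normalization.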



In [None]:
# Preprocessing layers built once at module level and reused by `preprocess_image`.

# Center crop to the square RESOLUTION x RESOLUTION model input.
crop_layer = layers.CenterCrop(RESOLUTION, RESOLUTION)

# Per-channel mean/std normalization. The statistics are given in [0, 1]
# units and scaled to [0, 255] pixel units (variance = (std * 255) ** 2).
# Applied for the non-ViT (DeiT / DINO) models.
norm_layer = layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)],
    variance=[(channel_std * 255) ** 2 for channel_std in (0.229, 0.224, 0.225)],
)

# x / 127.5 - 1: maps [0, 255] pixels onto [-1, 1] for the original ViT models.
rescale_layer = layers.Rescaling(scale=1 / 127.5, offset=-1)

In [None]:
def preprocess_image(image, model_type, size=RESOLUTION):
  """Turn a single image into a batched, model-ready numpy array.

  Pipeline:
    1. Force 3 channels.
    2. Add a batch dim.
    3. ('original_vit' only) Rescale to [-1, 1].
    4. Bicubic resize to int(256 / 224 * size).
    5. Center crop to `size` x `size`.
    6. (other model types) Mean/std normalization via `norm_layer`.

  Args:
    image: a PIL image or an array-like of pixel values. Presumably uint8
      in [0, 255], since the rescale/normalization constants assume that
      range. TODO confirm for non-PIL callers.
    model_type: 'original_vit' selects [-1, 1] rescaling. Any other value
      (e.g. DeiT / DINO) selects mean/std normalization.
    size: side length of the square output crop.

  Returns:
    A numpy array of shape (1, size, size, 3).
  """
  # Palette/grayscale/RGBA images would otherwise give the wrong channel
  # count for the 3-channel normalization and for the models.
  if isinstance(image, Image.Image):
    image = image.convert("RGB")
  image = np.array(image)
  if image.ndim == 2:
    image = image[..., np.newaxis]
  if image.shape[-1] == 1:
    image = np.repeat(image, 3, axis=-1)
  elif image.shape[-1] == 4:
    image = image[..., :3]  # drop alpha
  image = tf.expand_dims(image, 0)

  # If the `model_type` is ViT, rescale the image to [-1, 1]
  if model_type == 'original_vit':
    image = rescale_layer(image)

  # Resize to slightly larger than the target (256/224 ratio), then center crop.
  resize_size = int((256 / 224) * size)
  image = tf.image.resize(image, (resize_size, resize_size), method='bicubic')

  # `crop_layer` is fixed at RESOLUTION; build a matching cropper for any
  # other `size` so the output really is size x size.
  cropper = crop_layer if size == RESOLUTION else layers.CenterCrop(size, size)
  image = cropper(image)

  # If the `model_type` is DeiT or DINO, normalize the image
  if model_type != 'original_vit':
    image = norm_layer(image)

  return image.numpy()

## Load a test image and display it

In [None]:
def load_image_from_url(url, model_type, timeout=30):
  """Download an image and return it alongside its preprocessed form.

  Args:
    url: address of the image to fetch.
    model_type: forwarded to `preprocess_image`.
    timeout: seconds to wait on the server before giving up, so a stalled
      download cannot hang the notebook indefinitely.

  Returns:
    (image, preprocessed_image): the RGB PIL image (for display) and the
    output of `preprocess_image`.

  Raises:
    requests.HTTPError: if the server answers with an error status.
  """
  response = requests.get(url, timeout=timeout)
  # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL,
  # which would surface as a confusing "cannot identify image file".
  response.raise_for_status()
  # Normalize palette/grayscale/RGBA images to 3-channel RGB.
  image = Image.open(BytesIO(response.content)).convert("RGB")
  preprocessed_image = preprocess_image(image, model_type)
  return image, preprocessed_image

In [None]:
# ILSVRC-2012 class-name list; keras.utils.get_file downloads it once and
# returns the path of the locally cached copy.
imagenet_labels_url = (
  "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)
mapping_file = keras.utils.get_file(origin=imagenet_labels_url)

In [None]:
# Build the index -> class-name lookup: line i of the file is the name for
# class index i. An explicit encoding keeps decoding independent of the
# platform's default locale (e.g. cp1252 on Windows).
with open(mapping_file, 'r', encoding='utf-8') as f:
  lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]

In [None]:
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="original_vit")

# Sanity-check the model input before later cells consume it:
# one image, RESOLUTION x RESOLUTION, 3 channels.
assert preprocessed_image.shape == (1, RESOLUTION, RESOLUTION, 3), preprocessed_image.shape

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title("Test image")
ax.axis("off")
plt.show()