In [153]:
import os, sys
import json, shutil

import gdown
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras.layers import *
from imutils import paths
import numpy as np
from tqdm import tqdm

In [None]:
# Download the TensorFlow pretrained Swin weights from Google Drive.
# NOTE(review): no output path is passed to gdown.download, so where the file
# lands is left to gdown; the model-loading cell reads
# "swin_pretrained_weights/swin_tiny_patch4_window7_224.h5" — confirm this
# download actually produces that path.
swin_pretrained_weights_url = "https://drive.google.com/uc?id=10dSV_LkhQT97LZteVR10ZTACPa1T-TLH"
gdown.download(swin_pretrained_weights_url)

In [29]:
# for model architecture
!git clone https://github.com/shilu10/Swin-Transformer-TF2.git
!mv Swin-Transformer-TF2 swins

In [30]:
# Clone the CaiT-TF2 repository; a later cell reads
# imagenet1k_eval/imagenet_labels.json from this checkout (after it is renamed).
!git clone https://github.com/shilu10/CaiT-TF2.git

Cloning into 'CaiT-TF2'...
remote: Enumerating objects: 288, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (45/45), done.[K
remote: Total 288 (delta 9), reused 40 (delta 2), pack-reused 240[K
Receiving objects: 100% (288/288), 807.36 KiB | 7.34 MiB/s, done.
Resolving deltas: 100% (140/140), done.


In [31]:
# ImageNet-1k sample images used as the evaluation set; the clone creates the
# ./imagenet-sample-images directory that is scanned for image files later on.
!git clone https://github.com/EliSchwartz/imagenet-sample-images.git

Cloning into 'imagenet-sample-images'...
remote: Enumerating objects: 1012, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 1012 (delta 3), reused 5 (delta 2), pack-reused 1002[K
Receiving objects: 100% (1012/1012), 103.84 MiB | 19.89 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [32]:
# Rename the checkout to CaiT12, the directory name the label-loading cell reads from.
!mv CaiT-TF2 CaiT12

In [33]:
# Load the label table shipped with the CaiT12 checkout.
# The path is relative to the working directory (the same place the repos were
# cloned into and where "imagenet-sample-images" is looked up), so the notebook
# no longer depends on running from /content.
labels_path = os.path.join("CaiT12", "imagenet1k_eval", "imagenet_labels.json")
with open(labels_path, "r", encoding="utf-8") as labels_file:
  class_labels = json.load(labels_file)

In [34]:
def get_reverse_class_labels(class_labels):
  """Invert the label table into a ``name -> class index`` lookup.

  Every value of ``class_labels`` is treated as a comma-separated list of
  names. Each name is stripped of surrounding whitespace and mapped to the
  integer form of the key it was listed under. If the same name appears under
  several keys, the key visited last wins.

  Args:
    class_labels: Mapping whose keys are accepted by ``int()`` and whose
      values are comma-separated strings.

  Returns:
    dict mapping each individual name to its integer key.
  """
  return {
      name.strip(): int(index)
      for index, names in class_labels.items()
      for name in names.split(",")
  }

In [35]:
def preprocess_data(image, label):
  """Convert one image to float32 and pass its label through untouched.

  The conversion is delegated to ``tf.image.convert_image_dtype``; no file
  reading, decoding or resizing happens here.

  Args:
    image: Image tensor/array to convert.
    label: Label belonging to ``image``; returned as-is.

  Returns:
    Tuple ``(converted_image, label)``.
  """
  float_image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  return float_image, label

In [36]:
def get_encoded_labels(image_files, reverse_class_labels_dict):
  """Derive the integer class label of every image from its file name.

  For each path only the final component is inspected. Its first 10
  characters are dropped (presumably an id prefix such as ``n01440764_`` —
  verify against the dataset), everything from the first ``.`` onwards is
  removed, underscores become spaces, and the resulting class name is looked
  up in ``reverse_class_labels_dict``.

  Using the final path component (rather than a fixed position in the path)
  means absolute, nested and bare file paths are all handled the same way.

  Args:
    image_files: Iterable of image file paths.
    reverse_class_labels_dict: Mapping from class name to integer label.

  Returns:
    NumPy array with one label per entry of ``image_files``, in the same order.

  Raises:
    KeyError: If a derived class name is missing from the lookup; the message
      names both the class name and the offending file.
  """
  # Number of leading file-name characters that precede the class name.
  prefix_len = 10
  test_class_labels = []

  for image_file in image_files:
    file_name = os.path.basename(image_file)
    class_name = file_name[prefix_len:].split(".")[0].replace("_", " ")
    try:
      class_label = reverse_class_labels_dict[class_name]
    except KeyError:
      raise KeyError(
          f"no class label for {class_name!r} (derived from {image_file!r})"
      ) from None
    test_class_labels.append(class_label)

  return np.array(test_class_labels)

In [37]:
import cv2

def create_image_arr(image_files, img_size=224):
  """Load images from disk into one array of square RGB images.

  Each file is read with OpenCV, converted from BGR to RGB channel order and
  resized to ``img_size x img_size``.

  Args:
    image_files: Iterable of image file paths.
    img_size: Side length, in pixels, of the resized images.

  Returns:
    NumPy array stacking the processed images in the order of ``image_files``.

  Raises:
    ValueError: If a file cannot be read as an image.
  """
  images = []
  for image_path in image_files:
    img = cv2.imread(image_path)
    if img is None:
      # cv2.imread reports failure by returning None instead of raising.
      # Fail loudly here: skipping or stopping early would leave the image
      # array out of step with labels derived from the same file list.
      raise ValueError(f"could not read image: {image_path!r}")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (img_size, img_size))
    images.append(img)

  return np.array(images)


In [38]:
def create_tf_dataset(images, labels, batch_size=5):
  """Wrap image/label arrays in a batched, prefetching ``tf.data`` pipeline.

  The pipeline slices ``(images, labels)`` into per-example pairs, runs
  ``preprocess_data`` on each pair, groups them into batches of ``batch_size``
  (an incomplete final batch is discarded because ``drop_remainder=True``),
  and prefetches with ``tf.data.AUTOTUNE``.

  Args:
    images: Images to slice along the first axis.
    labels: Labels aligned with ``images``.
    batch_size: Number of examples per batch.

  Returns:
    The resulting ``tf.data`` dataset.
  """
  return (
      tf.data.Dataset.from_tensor_slices((images, labels))
      .map(preprocess_data)
      .batch(batch_size, drop_remainder=True)
      .prefetch(tf.data.AUTOTUNE)
  )

In [39]:
def make_prediction(model, test_ds):
  """Run ``model`` over ``test_ds`` and report accuracy metrics.

  Args:
    model: Callable that maps a batch of images to per-class predictions.
    test_ds: Iterable yielding ``(image_batch, label_batch)`` pairs with
      sparse (integer) labels.

  Returns:
    Tuple ``(accuracy, top1_accuracy, top5_accuracy)``, each obtained from
    ``metric.result().numpy()``.
  """
  metrics = (
      tf.keras.metrics.SparseCategoricalAccuracy(name="sparse_categorical_accuracy"),
      tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name="top1_sparse_categorical_accuracy"),
      tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top5_sparse_categorical_accuracy"),
  )

  # The size is only needed for the progress bar. len() raises TypeError on
  # inputs without a known finite size, so fall back to an unsized bar instead
  # of aborting the evaluation.
  try:
    total = len(test_ds)
  except TypeError:
    total = None

  for img_batch, label_batch in tqdm(test_ds, total=total):
    prediction = model(img_batch)
    for metric in metrics:
      metric.update_state(label_batch, prediction)

  return tuple(metric.result().numpy() for metric in metrics)


In [40]:
# Build the class-name -> integer-label lookup from the loaded label table.
reverse_class_labels = get_reverse_class_labels(class_labels)

In [41]:
# Collect the image file paths found under the cloned sample-image directory.
# Materialised as a list because it is consumed twice (labels and pixel data).
images_file_path = list(paths.list_images("imagenet-sample-images"))

In [42]:
# Integer label for every image, derived from its file name; order follows images_file_path.
encoded_test_labels = get_encoded_labels(images_file_path, reverse_class_labels)

In [43]:
# Load every image as RGB resized to 224x224; order follows images_file_path.
images_arr = create_image_arr(images_file_path, img_size=224)

In [44]:
# Assemble the evaluation dataset from the images and their labels, 5 examples per batch.
test_ds = create_tf_dataset(images_arr, encoded_test_labels, batch_size=5)
# Bare expression so the notebook displays the dataset's element spec.
test_ds

<_PrefetchDataset element_spec=(TensorSpec(shape=(5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(5,), dtype=tf.int64, name=None))>

In [None]:
# `!cd swins` would only change directory inside a throwaway subshell, so the
# kernel would never see the cloned repo. Put the checkout on the import path
# instead; the working directory (and therefore the relative weights path
# below) stays exactly as it was.
# NOTE(review): assumes `swin_transformer` sits at the root of the `swins`
# checkout — confirm against the repository layout.
if "swins" not in sys.path:
  sys.path.append("swins")

from swin_transformer import SwinTransformer

model = SwinTransformer()
model.load_weights("swin_pretrained_weights/swin_tiny_patch4_window7_224.h5")

In [126]:
# Evaluate the model on the whole test dataset; yields plain, top-1 and top-5 accuracy.
acc, top1_acc, top5_acc = make_prediction(model, test_ds)

100%|██████████| 200/200 [01:13<00:00,  2.73it/s]


In [127]:
# Display the sparse categorical accuracy returned by make_prediction.
acc

0.885

In [128]:
# Display the top-5 accuracy returned by make_prediction.
top5_acc

0.98