Refer for code: https://github.com/lukemelas/PyTorch-Pretrained-ViT

In [1]:
import glob
import os
import re

from PIL import Image
import torch
from torchvision import transforms
from pytorch_pretrained_vit import ViT

In [2]:
# Download ImageNet Labels
# !wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

In [3]:
# Read ImageNet categories
with open('./imagenet_classes.txt', 'r') as f:
    categories = [s.strip() for s in f.readlines()]

print(categories)

['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'robin', 'bulbul', 'jay', 'magpie', 'chickadee', 'water ouzel', 'kite', 'bald eagle', 'vulture', 'great grey owl', 'European fire salamander', 'common newt', 'eft', 'spotted salamander', 'axolotl', 'bullfrog', 'tree frog', 'tailed frog', 'loggerhead', 'leatherback turtle', 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'common iguana', 'American chameleon', 'whiptail', 'agama', 'frilled lizard', 'alligator lizard', 'Gila monster', 'green lizard', 'African chameleon', 'Komodo dragon', 'African crocodile', 'American alligator', 'triceratops', 'thunder snake', 'ringneck snake', 'hognose snake', 'green snake', 'king snake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor', 'rock python', 'Indian cobra', 'green mamba', 'sea snake', 'horned viper', 'diamondback', 

In [4]:
# Path to all image files (jpg/jpeg)
data_path = '../data/*.j*'

# Creating a mapping in the form of filename: path
basename_path_mapping = {}

for file in glob.glob(data_path):
    basename_path_mapping[os.path.basename(file).split('.')[0]] = file
    
print(basename_path_mapping)

{'crab8': '../data/crab8.jpg', 'couch1': '../data/couch1.jpeg', 'rose3': '../data/rose3.jpg', 'mushroom5': '../data/mushroom5.jpeg', 'rose5': '../data/rose5.jpeg', 'rose6': '../data/rose6.jpg', 'rose4': '../data/rose4.jpeg', 'mushroom4': '../data/mushroom4.jpeg', 'rose7': '../data/rose7.jpg', 'mushroom8': '../data/mushroom8.jpeg', 'dinosaur8': '../data/dinosaur8.jpg', 'mushroom3': '../data/mushroom3.jpeg', 'aquariumfish8': '../data/aquariumfish8.jpg', 'couch7': '../data/couch7.jpeg', 'couch6': '../data/couch6.jpeg', 'rose2': '../data/rose2.jpeg', 'mushroom2': '../data/mushroom2.jpeg', 'aquariumfish1': '../data/aquariumfish1.jpg', 'mushroom1': '../data/mushroom1.jpeg', 'couch5': '../data/couch5.jpeg', 'aquariumfish2': '../data/aquariumfish2.jpg', 'aquariumfish3': '../data/aquariumfish3.jpg', 'aquariumfish7': '../data/aquariumfish7.jpg', 'aquariumfish6': '../data/aquariumfish6.jpg', 'couch4': '../data/couch4.jpeg', 'aquariumfish4': '../data/aquariumfish4.jpg', 'aquariumfish5': '../data/a

In [5]:
# Initializing the ViT model
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

Loaded pretrained weights.


ViT(
  (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (positional_embedding): PositionalEmbedding1D()
  (transformer): Transformer(
    (blocks): ModuleList(
      (0): Block(
        (attn): MultiHeadedSelfAttention(
          (proj_q): Linear(in_features=768, out_features=768, bias=True)
          (proj_k): Linear(in_features=768, out_features=768, bias=True)
          (proj_v): Linear(in_features=768, out_features=768, bias=True)
          (drop): Dropout(p=0.1, inplace=False)
        )
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (pwff): PositionWiseFeedForward(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
      (1): Block(
 

In [6]:
# Traversing each image and getting top 15 probabilities using ViT
for label, path in basename_path_mapping.items():
    input_image = Image.open(path)
    preprocess = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
    # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
    # print(output[0])
    # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # print(probabilities)
    print('-----------------')
    print(label)
    print('-----------------')
    top15_prob, top15_catid = torch.topk(probabilities, 15)
    for i in range(top15_prob.size(0)):
        print(categories[top15_catid[i]], top15_prob[i].item())
    print()
    print()

-----------------
crab8
-----------------
fiddler crab 0.8811163306236267
rock crab 0.09143275767564774
king crab 0.014799821190536022
Dungeness crab 0.005423028953373432
hermit crab 0.0011191068915650249
crayfish 0.0006180530763231218
isopod 0.0006104574422352016
bucket 0.0003821624268312007
spiny lobster 0.00037418559077195823
barn spider 0.00034916569711640477
American lobster 0.00025957185425795615
scorpion 0.0002067642635665834
whiptail 8.623821486253291e-05
wolf spider 7.21437027095817e-05
centipede 5.8764951972989365e-05


-----------------
couch1
-----------------
studio couch 0.683493971824646
velvet 0.021855374798178673
park bench 0.016639310866594315
rocking chair 0.006784731522202492
pillow 0.003509133355692029
hook 0.0021252837032079697
television 0.001935790409334004
wall clock 0.00170386943500489
throne 0.0016014310531318188
home theater 0.0015256119659170508
worm fence 0.0014912920305505395
abacus 0.001400043722242117
fire screen 0.0012710518203675747
tub 0.001264809980

-----------------
aquariumfish1
-----------------
goldfish 0.944821298122406
rock beauty 0.007750560063868761
gar 0.002005645539611578
axolotl 0.001873932546004653
tench 0.001468376605771482
electric ray 0.0014128837501630187
puffer 0.0014059051172807813
eel 0.001240819925442338
barracouta 0.0006603776128031313
terrapin 0.0006490605301223695
coral reef 0.0005666229990310967
anemone fish 0.0004353403637651354
common newt 0.0003996992309112102
hen-of-the-woods 0.000390605564462021
dragonfly 0.00037895995774306357


-----------------
mushroom1
-----------------
agaric 0.5238325595855713
mushroom 0.3958464562892914
earthstar 0.02262791432440281
hen-of-the-woods 0.018186427652835846
stinkhorn 0.015649685636162758
bolete 0.010051422752439976
gyromitra 0.006777067203074694
coral fungus 0.0032449238933622837
eft 6.487260543508455e-05
buckeye 6.073539043427445e-05
slug 4.638858445105143e-05
common newt 3.602137076086365e-05
European fire salamander 3.566651503206231e-05
acorn 2.9788950996589847

-----------------
crab2
-----------------
rock crab 0.8966579437255859
fiddler crab 0.04471419006586075
king crab 0.04116491973400116
Dungeness crab 0.013864537701010704
American lobster 0.0005015460774302483
hermit crab 0.00044036476174369454
spiny lobster 0.00024589133681729436
crayfish 0.00021135390852577984
scorpion 0.00018759128579404205
isopod 0.00017054435738828033
chiton 3.49488909705542e-05
conch 3.0501574656227604e-05
barn spider 2.9936474675196223e-05
night snake 2.946560925920494e-05
centipede 1.97442204807885e-05


-----------------
crab3
-----------------
rock crab 0.40357211232185364
Dungeness crab 0.3126281797885895
fiddler crab 0.18045194447040558
king crab 0.06602153182029724
American lobster 0.01008610799908638
crayfish 0.00788826122879982
spiny lobster 0.0041211070492863655
isopod 0.002299370476976037
hermit crab 0.001364665338769555
barn spider 0.0005854995688423514
scorpion 0.0003645554243121296
jay 0.00020355165179353207
chiton 0.0001706638722680509
tick 0.000159

In [7]:
# Create a mapping of class items to super class
classes = {
    'dinosaur': [51],
    'couch': [831],
    'mushroom': [947, 992, 993, 994, 995, 996, 997],
    'crab': [118, 119, 120, 121, 125],
    'aquariumfish': [1, 5, 107, 123, 124, 327, 395, 396, 397],
    'rose': []
}

# Displaying the names of the classes in imagenet
for key, value in classes.items():
    print('{}: {}'.format(key, [categories[idx] for idx in value]))

dinosaur: ['triceratops']
couch: ['studio couch']
mushroom: ['mushroom', 'agaric', 'gyromitra', 'stinkhorn', 'earthstar', 'hen-of-the-woods', 'bolete']
crab: ['Dungeness crab', 'rock crab', 'fiddler crab', 'king crab', 'hermit crab']
aquariumfish: ['goldfish', 'electric ray', 'jellyfish', 'spiny lobster', 'crayfish', 'starfish', 'gar', 'lionfish', 'puffer']
rose: []


In [8]:
# Traversing each image and getting top 15 probabilities using ViT
for label, path in basename_path_mapping.items():
    input_image = Image.open(path)
    preprocess = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
    # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
    # print(output[0])
    # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    
    # Getting the class probability for the image
    image_class = re.sub(r'\d', '', label)
    print('{}: {}'.format(label, sum([probabilities[idx] for idx in classes[image_class]])))

crab8: 0.9938910603523254
couch1: 0.683493971824646
rose3: 0
mushroom5: 0.9903771877288818
rose5: 0
rose6: 0
rose4: 0
mushroom4: 0.9803153872489929
rose7: 0
mushroom8: 0.996789813041687
dinosaur8: 0.3074803352355957
mushroom3: 0.6795634627342224
aquariumfish8: 0.14485429227352142
couch7: 0.9757304191589355
couch6: 0.8093156218528748
rose2: 0
mushroom2: 0.44311320781707764
aquariumfish1: 0.9505128264427185
mushroom1: 0.9929715394973755
couch5: 0.9808145761489868
aquariumfish2: 0.0011414405889809132
aquariumfish3: 0.306693971157074
aquariumfish7: 0.0014961115084588528
aquariumfish6: 0.9948068857192993
couch4: 0.9767090082168579
aquariumfish4: 0.16591893136501312
aquariumfish5: 0.15849505364894867
couch8: 0.994729220867157
dinosaur5: 0.9514707922935486
crab1: 0.9986653923988342
couch3: 0.07007555663585663
dinosaur4: 0.3790968656539917
dinosaur6: 0.8528656363487244
crab2: 0.9968419671058655
crab3: 0.9640384316444397
mushroom7: 0.9844620823860168
dinosaur7: 0.9997953772544861
dinosaur3: 0.9