In [1]:
import glob
import os

from PIL import Image
import torch
from torchvision import transforms

Refer for code: https://pytorch.org/hub/pytorch_vision_vgg/

In [2]:
# Download ImageNet Labels
# !wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

In [3]:
# Read ImageNet categories
with open('./imagenet_classes.txt', 'r') as f:
    categories = [s.strip() for s in f.readlines()]

print(categories)

['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'robin', 'bulbul', 'jay', 'magpie', 'chickadee', 'water ouzel', 'kite', 'bald eagle', 'vulture', 'great grey owl', 'European fire salamander', 'common newt', 'eft', 'spotted salamander', 'axolotl', 'bullfrog', 'tree frog', 'tailed frog', 'loggerhead', 'leatherback turtle', 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'common iguana', 'American chameleon', 'whiptail', 'agama', 'frilled lizard', 'alligator lizard', 'Gila monster', 'green lizard', 'African chameleon', 'Komodo dragon', 'African crocodile', 'American alligator', 'triceratops', 'thunder snake', 'ringneck snake', 'hognose snake', 'green snake', 'king snake', 'garter snake', 'water snake', 'vine snake', 'night snake', 'boa constrictor', 'rock python', 'Indian cobra', 'green mamba', 'sea snake', 'horned viper', 'diamondback', 

In [4]:
# Path to all image files (jpg/jpeg)
data_path = '../data/*.j*'

# Creating a mapping in the form of filename: path
basename_path_mapping = {}

for file in glob.glob(data_path):
    basename_path_mapping[os.path.basename(file).split('.')[0]] = file
    
print(basename_path_mapping)

{'crab8': '../data/crab8.jpg', 'couch1': '../data/couch1.jpeg', 'rose3': '../data/rose3.jpg', 'mushroom5': '../data/mushroom5.jpeg', 'rose5': '../data/rose5.jpeg', 'rose6': '../data/rose6.jpg', 'rose4': '../data/rose4.jpeg', 'mushroom4': '../data/mushroom4.jpeg', 'rose7': '../data/rose7.jpg', 'mushroom8': '../data/mushroom8.jpeg', 'dinosaur8': '../data/dinosaur8.jpg', 'mushroom3': '../data/mushroom3.jpeg', 'aquariumfish8': '../data/aquariumfish8.jpg', 'couch7': '../data/couch7.jpeg', 'couch6': '../data/couch6.jpeg', 'rose2': '../data/rose2.jpeg', 'mushroom2': '../data/mushroom2.jpeg', 'aquariumfish1': '../data/aquariumfish1.jpg', 'mushroom1': '../data/mushroom1.jpeg', 'couch5': '../data/couch5.jpeg', 'aquariumfish2': '../data/aquariumfish2.jpg', 'aquariumfish3': '../data/aquariumfish3.jpg', 'aquariumfish7': '../data/aquariumfish7.jpg', 'aquariumfish6': '../data/aquariumfish6.jpg', 'couch4': '../data/couch4.jpeg', 'aquariumfish4': '../data/aquariumfish4.jpg', 'aquariumfish5': '../data/a

In [5]:
# Initializing the VGG model
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
model.eval()

Using cache found in /Users/rajchoudhary/.cache/torch/hub/pytorch_vision_v0.10.0


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# Traversing each image and getting top 15 probabilities using VGG16
for label, path in basename_path_mapping.items():
    input_image = Image.open(path)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
    # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
    # print(output[0])
    # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # print(probabilities)
    print('-----------------')
    print(label)
    print('-----------------')
    top15_prob, top15_catid = torch.topk(probabilities, 15)
    for i in range(top15_prob.size(0)):
        print(categories[top15_catid[i]], top15_prob[i].item())
    print()
    print()

-----------------
crab8
-----------------
fiddler crab 0.5659475922584534
hermit crab 0.19332216680049896
rock crab 0.09858668595552444
crayfish 0.04314781352877617
scorpion 0.029799729585647583
grasshopper 0.01599862053990364
Dungeness crab 0.012048449367284775
cicada 0.007420102600008249
king crab 0.00702812010422349
cricket 0.0066971140913665295
mantis 0.003444529604166746
wolf spider 0.0026380245108157396
bee 0.0019025111105293036
isopod 0.001562517136335373
leafhopper 0.0014880168018862605


-----------------
couch1
-----------------
studio couch 0.9948638081550598
velvet 0.0023133382201194763
tile roof 0.0012970365351065993
panpipe 0.0004132630710955709
window shade 0.00032087654108181596
shoji 0.0002507360768504441
home theater 6.735726492479444e-05
bath towel 4.428693500813097e-05
quilt 3.9666610973654315e-05
French loaf 3.7423520552692935e-05
bookcase 2.1550840756390244e-05
table lamp 1.9564336980693042e-05
tobacco shop 1.8383429051027633e-05
rocking chair 1.698490814305842e-0

-----------------
aquariumfish1
-----------------
goldfish 0.9444636702537537
eft 0.02757934108376503
ladybug 0.019525503739714622
sea slug 0.0018515000119805336
leafhopper 0.0010651614284142852
tree frog 0.0010443690698593855
common newt 0.0008446556166745722
goose 0.000775810272898525
leaf beetle 0.0007427109521813691
tailed frog 0.00042301122448407114
agama 0.0003307318256702274
goldfinch 0.00015259893552865833
puffer 0.0001258340635104105
agaric 0.00010992075112881139
terrapin 8.206439815694466e-05


-----------------
mushroom1
-----------------
agaric 0.8123747110366821
mushroom 0.15526844561100006
earthstar 0.011140630580484867
stinkhorn 0.008341273292899132
bolete 0.006127296946942806
hen-of-the-woods 0.003587126499041915
gyromitra 0.002629122231155634
coral fungus 0.00015067579806782305
golf ball 5.876569775864482e-05
strawberry 3.3355179766658694e-05
croquet ball 2.739815681707114e-05
baseball 1.650319609325379e-05
armadillo 1.2272346793906763e-05
Brittany spaniel 1.2150486327

-----------------
crab2
-----------------
rock crab 0.8000825643539429
fiddler crab 0.19898657500743866
Dungeness crab 0.0003908384242095053
scorpion 0.0003306907892692834
king crab 0.0001692734658718109
crayfish 3.181925421813503e-05
hermit crab 6.334229965432314e-06
American lobster 1.1128391861348064e-06
cicada 1.7351941039578378e-07
spiny lobster 1.6162886140591581e-07
centipede 1.4593028652143403e-07
tick 1.2720305164748424e-07
tarantula 2.100154716799807e-08
mud turtle 1.843399388690159e-08
weevil 1.4871690545703586e-08


-----------------
crab3
-----------------
Dungeness crab 0.6113294363021851
rock crab 0.1876998394727707
king crab 0.13289770483970642
crayfish 0.04995255544781685
fiddler crab 0.00827249325811863
American lobster 0.0031063267961144447
hermit crab 0.0026174162048846483
spiny lobster 0.002561083994805813
scorpion 0.0009577951277606189
tick 0.00019847252406179905
cicada 0.00012364548456389457
centipede 3.313008710392751e-05
grasshopper 2.4506833142368123e-05
Europ