In [1]:
#Ignoring TF warning messages
import logging, os
logging.disable(logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet', pretrained=True)
model.eval()

Using cache found in /home/yaduk/.cache/torch/hub/pytorch_vision_v0.10.0


GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

In [3]:
# Download an example image from the pytorch website
import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)

In [4]:
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)

tensor([ 3.4993e-02, -2.2940e-01, -3.2330e-01,  5.6914e-02,  1.1345e-01,
        -2.8261e-01,  6.1713e-01,  6.1639e-02,  9.4646e-01, -1.4962e+00,
        -5.5129e-01, -3.4354e-02, -1.2640e+00, -3.2054e-02,  5.3394e-01,
         1.8727e-01,  4.9359e-01, -2.8194e-01, -2.7175e-01, -2.6153e-01,
        -3.3352e-01, -7.6689e-02,  6.7201e-02, -5.7018e-01, -5.3244e-01,
        -4.9747e-02,  7.1980e-01,  1.1622e+00,  5.1527e-01,  1.3025e+00,
         6.8959e-01,  5.5545e-01,  1.3226e-01, -7.1036e-01, -5.0300e-01,
        -2.4448e-01, -6.0022e-01,  1.9392e-01, -3.2075e-01,  6.1728e-01,
         2.9609e-01, -2.4302e-01,  1.8587e-01, -4.5841e-01,  1.5209e-01,
        -6.8352e-01,  9.8970e-01,  6.1482e-01, -1.2936e+00, -4.6710e-01,
        -7.0906e-02,  3.1543e-04,  3.7383e-01,  1.8148e-01,  8.5809e-01,
         1.0046e+00, -3.0290e-01,  1.9980e-02,  7.7383e-02,  8.1170e-01,
         7.7608e-01, -7.0507e-01, -2.2673e-01, -2.5853e-01,  1.5494e-01,
        -4.5743e-01,  7.8234e-01,  1.4158e-01,  1.0

Download ImageNet labels
*!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt*

In [5]:
# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item()) #

Samoyed 0.9378912448883057
Pomeranian 0.008303353562951088
Great Pyrenees 0.005579737946391106
Arctic fox 0.0055304295383393764
white wolf 0.004736033733934164
