In [1]:
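# Proof of concept: detect an object with YOLOv5, crop its bounding box,
# and classify the crop with a pretrained Vision Transformer (ViT).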

In [2]:
from pytorch_pretrained_vit import ViT
import json
from PIL import Image
import torch
from torchvision import transforms
import cv2

In [3]:
# Instantiate the Vision Transformer classifier with pretrained weights.
# The variant identifier is kept in a named constant so it is easy to spot and change.
VIT_VARIANT = 'B_16_imagenet1k'
model = ViT(VIT_VARIANT, pretrained=True)

Loaded pretrained weights.


In [4]:
# Load the input image.
# Convert to RGB explicitly: the preprocessing step later normalises with
# three-element mean/std lists, so a grayscale, palette, RGBA or CMYK file
# would otherwise fail (or be misinterpreted) further down the notebook.
img = Image.open('img.jpg').convert('RGB')

In [5]:
# Object detection
# torch.hub.load already returns the model with AutoShape applied (the load log
# prints "Adding AutoShape..."). The previously chained .autoshape() call only
# logged "AutoShape already enabled, skipping..." and returned the same object,
# so it is dropped here; this also avoids relying on that method being present
# in whatever yolov5 revision the unpinned hub checkout resolves to.
detector_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# Run inference on the PIL image; `size` is the inference image size passed to the model.
results = detector_model(img, size=640)
results.print()  # text summary of the detections
results.show()   # display the image with detections drawn on it

Using cache found in /Users/lautenschlager/.cache/torch/hub/ultralytics_yolov5_master
Fusing layers... 
Model Summary: 224 layers, 7266973 parameters, 0 gradients
Adding AutoShape... 
YOLOv5 🚀 2021-5-23 torch 1.8.1 CPU



AutoShape already enabled, skipping... 
image 1/1: 563x845 1 bird
Speed: 11.1ms pre-process, 195.7ms inference, 2.0ms NMS per image at shape (1, 3, 448, 640)


In [6]:
# Get bounding box coordinates from detection
# results.xyxy[0] holds the detection rows for the first (and only) input image.
detections = results.xyxy[0]
if len(detections) == 0:
    # Without this guard the indexing below fails with an opaque IndexError.
    raise ValueError('No objects detected in the image; cannot extract a bounding box.')

# Only the first detection row is used; its first four values are read as
# x1, y1, x2, y2 and truncated to integer pixel coordinates.
x1, y1, x2, y2 = (int(coord) for coord in detections[0][:4])
bbox_points = [x1, y1, x2, y2]
print('bounding box is:', x1, y1, x2, y2)

bounding box is: 253 87 587 541


In [7]:
# Cut the detected region out of the image and preview it.
crop_box = (x1, y1, x2, y2)  # (left, upper, right, lower) in pixels
cropped_img = img.crop(crop_box)
cropped_img.show()

In [8]:
# Build the preprocessing pipeline for the transformer and apply it to the crop.
preprocess_steps = [
    transforms.Resize(model.image_size),                     # resize to the size the model reports
    transforms.ToTensor(),                                   # image -> tensor
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # per-channel mean / std
]
tfms = transforms.Compose(preprocess_steps)

# NOTE: this rebinds `cropped_img` from the cropped image to a tensor, so
# re-running this cell on its own feeds that tensor back through `tfms`;
# re-run the crop cell first.
cropped_img = tfms(cropped_img).unsqueeze(0)  # unsqueeze(0) prepends a size-1 dimension

In [9]:
# Load class labels
# Read the JSON mapping inside a context manager so the file handle is closed
# deterministically instead of being left open by a bare open() call.
with open('labels_map.txt') as labels_file:
    labels_by_id = json.load(labels_file)

# Convert the mapping keyed by stringified index ("0" .. "999") into a list so
# that labels_map[i] is the label for integer class index i.
labels_map = [labels_by_id[str(i)] for i in range(1000)]

In [10]:
# Classify
model.eval()  # put the model in evaluation mode before inference
with torch.no_grad():  # no gradients are needed for a forward-only pass
    outputs = model(cropped_img).squeeze(0)

# Compute the softmax once; it does not depend on the loop variable, so there
# is no reason to recompute it for every printed class.
probs = torch.softmax(outputs, -1)

print('-----')
# Rank on the raw outputs (as before) and look up the matching probability.
for idx in torch.topk(outputs, k=3).indices.tolist():
    prob = probs[idx].item()
    print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))

-----
[24] great grey owl, great gray owl, Strix nebulosa                              (99.98%)
[21] kite                                                                        (0.00%)
[798] slide rule, slipstick                                                       (0.00%)
