## Demo: object detection and recognition with DETR from Facebook AI Research (FAIR)

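Import the required libraries and download the model from PyTorch Hub. The *facebookresearch/detr* repository offers several models to choose from (for example *detr_resnet50* and *detr_resnet101*). Set *pretrained=True* to load weights trained on COCO, so the model can be used for inference right away.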

In [None]:
import torch as th
import torchvision.transforms as T
import requests
from PIL import Image, ImageDraw, ImageFont
import ipywidgets as widgets
from IPython.display import display

In [None]:
model = th.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True)
model.eval()
model = model.to('cuda' if th.cuda.is_available() else 'cpu')  # fall back to CPU when no GPU is available

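Define the image transforms expected by the ResNet backbone (conversion to a tensor plus ImageNet mean/std normalization), and the COCO class labels the model predicts.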

In [None]:
# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
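*T.ToTensor()* scales 8-bit pixel values to [0, 1], and *T.Normalize(mean, std)* then computes `(x - mean) / std` for each channel using the ImageNet statistics above. A minimal pure-Python sketch of that arithmetic for a single RGB pixel (the helper name `normalize_pixel` is hypothetical, not part of torchvision):

```python
# ImageNet per-channel statistics, as passed to T.Normalize above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one 8-bit (R, G, B) pixel."""
    # ToTensor: scale 0..255 values to floats in [0, 1]
    scaled = [c / 255.0 for c in rgb]
    # Normalize: subtract the channel mean, then divide by the channel std
    return [(s - m) / sd for s, m, sd in zip(scaled, MEAN, STD)]

# A mid-gray pixel
print([round(v, 3) for v in normalize_pixel((128, 128, 128))])
```

A pixel whose channels equal the ImageNet means maps to zero in every channel, which is the point of the normalization.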

Get user input for image search

In [None]:
print("Enter up to 5 comma-separated search terms, then run the next cell")
search_terms=widgets.Textarea(
    value="",
    description=""
)
display(search_terms)

In [None]:
# Trim whitespace, skip empty entries, and keep at most 5 terms
search_terms = [t.strip() for t in search_terms.value.split(",") if t.strip()][:5]
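The same parsing logic as a standalone, testable function (the name `parse_search_terms` is hypothetical). It trims whitespace and skips empty entries, so input such as `" dog, red car,, beach ,"` yields clean queries:

```python
def parse_search_terms(raw, max_terms=5):
    """Split a comma-separated string into at most max_terms clean terms."""
    # Strip surrounding whitespace and drop empty entries (e.g. from "a,,b" or a trailing comma)
    terms = [t.strip() for t in raw.split(",") if t.strip()]
    # Slicing is safe even when there are fewer than max_terms entries
    return terms[:max_terms]

print(parse_search_terms(" dog, red car,, beach ,"))  # ['dog', 'red car', 'beach']
```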

Set up code to fetch images using the Unsplash API.

For each search term, we save the top search result and its description from the JSON response.

In [None]:
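# NB: PIL's resize() takes (width, height), so the images below end up 800 px wide and 600 px tall
height, width = 800, 600
# Sign up for an Unsplash API access key here:
# https://unsplash.com/oauth/applications
access_key = "<Access Key here>"
search_url = "https://api.unsplash.com/search/photos"
keywords = search_terms
imgs = []
img_descs = []
for keyword in keywords:
  # Let requests URL-encode the query and client_id parameters (handles spaces, '&', etc.)
  r = requests.get(search_url, params={"query": keyword, "client_id": access_key})
  r.raise_for_status()
  results = r.json()['results']
  if not results:
    print(f"No Unsplash results for '{keyword}', skipping")
    continue
  img_url = results[0]['urls']['raw']
  img_desc = results[0]['description']
  img = Image.open(requests.get(img_url, stream=True).raw).convert('RGB').resize((height, width))
  imgs.append(img)
  img_descs.append(img_desc)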
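A stdlib-only sketch of the two steps this cell performs: building the search URL and pulling the top result out of the JSON response. The helper names and the sample response are hypothetical; the sample contains only the `urls.raw` and `description` fields that the cell reads.

```python
import json
from urllib.parse import urlencode

SEARCH_URL = "https://api.unsplash.com/search/photos"

def build_search_url(keyword, access_key):
    # urlencode escapes spaces and characters such as '&' that would break a hand-built query string
    return SEARCH_URL + "?" + urlencode({"query": keyword, "client_id": access_key})

def top_result(response_text):
    """Return (image URL, description) for the first result, or None if there are no results."""
    results = json.loads(response_text).get("results", [])
    if not results:
        return None
    first = results[0]
    return first["urls"]["raw"], first.get("description")

# Hypothetical, heavily trimmed response
sample = '{"results": [{"urls": {"raw": "https://images.example.com/abc"}, "description": null}]}'
print(build_search_url("red car", "MY_KEY"))
print(top_result(sample))
```

Guarding against an empty `results` list matters because a rare or misspelled search term can legitimately return no photos.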

This cell runs inference on the images gathered from Unsplash. For each image:
1. Make a copy of the image to draw on
2. Disable gradient tracking with *torch.no_grad()* (the model was already put in eval mode with *model.eval()*)
3. Forward-propagate the normalized image through the model. For each of DETR's 100 object queries, the output contains:
   - Class logits (the 91 class ids in *CLASSES* plus a final "no object" class)
   - A bounding box, as normalized (center x, center y, width, height)
4. Pick the *objects_to_detect* queries with the highest class confidence, then draw their boxes and labels on the copy of the image
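Steps 3 and 4 hinge on how each query is scored. Because the logits include a final "no object" class, the softmax should be taken over all classes before that last column is dropped; slicing it off first would renormalize the remaining probabilities, so a query that is confidently empty could still look like a confident detection. A pure-Python sketch of that scoring and the top-k selection on toy logits (function names are illustrative, not DETR's API):

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_detections(query_logits, k):
    """query_logits: one list of class logits per query, with "no object" as the LAST entry.

    Returns (query_index, class_index, confidence) for the k most confident queries.
    """
    scored = []
    for qi, logits in enumerate(query_logits):
        # Softmax over ALL classes, including "no object", then drop that last column
        probs = softmax(logits)[:-1]
        best = max(range(len(probs)), key=probs.__getitem__)
        scored.append((qi, best, probs[best]))
    # Keep the k queries whose best real class is most confident
    return sorted(scored, key=lambda t: t[2], reverse=True)[:k]

# Toy example: 3 queries, 2 real classes + "no object"; query 1 is almost surely empty
toy = [[2.0, 0.1, 0.0], [0.0, 0.0, 5.0], [0.2, 3.0, 0.1]]
print(top_k_detections(toy, 2))
```

Here query 1 is ranked last because most of its probability mass sits on "no object".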

In [None]:
objects_to_detect=3
# fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 10)
for idx, img in enumerate(imgs):
  im2 = img.copy()
  drw = ImageDraw.Draw(im2)
  img_tens = transform(img).unsqueeze(0).to(next(model.parameters()).device)  # same device as the model
  with th.no_grad():
    output = model(img_tens)
  # DETR returns 92 logits per query: the 91 class ids in CLASSES plus a final "no object" class.
  # Take the softmax over all of them, then drop the "no object" column.
  probas = output['pred_logits'][0].softmax(-1)[:, :-1]
  # Boxes are (center x, center y, width, height), normalized to [0, 1]
  pred_boxes = output['pred_boxes'][0]
  # Best class and its confidence for each query; keep the most confident queries
  scores, classes = probas.max(-1)
  topk = scores.topk(objects_to_detect)

  for cls, box in zip(classes[topk.indices], pred_boxes[topk.indices]):
    label = CLASSES[cls.item()]
    box = box.cpu() * th.Tensor([height, width, height, width])
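    x, y, w, h = box.tolist()
    # Convert center/size to corner coordinates (true division avoids flooring the half-sizes)
    x0, x1 = x - w / 2, x + w / 2
    y0, y1 = y - h / 2, y + h / 2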
    drw.rectangle([x0, y0, x1, y1], outline='black', width=5)
    label_pos=(x0, y0-15)
    drw.text(label_pos, label, fill='red')
  print(img_descs[idx])
  display(im2)
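DETR predicts each box as (center x, center y, width, height), normalized to [0, 1] by the image size, so drawing it requires scaling to pixels and converting to corner coordinates. A standalone sketch of that conversion (the helper name `cxcywh_to_xyxy` is hypothetical). Note that x-values scale by the image width, y-values by the image height, and that *PIL.Image.resize* takes a `(width, height)` tuple:

```python
def cxcywh_to_xyxy(box, img_w, img_h):
    """Convert a normalized (cx, cy, w, h) box to pixel corners (x0, y0, x1, y1)."""
    cx, cy, w, h = box
    # Scale normalized coordinates to pixels: x-values by width, y-values by height
    cx, w = cx * img_w, w * img_w
    cy, h = cy * img_h, h * img_h
    # True division by 2; floor division (//) would shift corners by a fraction of a pixel
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

# A box centered in an 800x600 image, covering half its width and half its height
print(cxcywh_to_xyxy((0.5, 0.5, 0.5, 0.5), 800, 600))  # (200.0, 150.0, 600.0, 450.0)
```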

To Do:
1. Return only detections above a certain confidence threshold
2. Measure the similarity between the image tags and the detected objects
3. Automatically add missing tags when a detection's confidence is above a certain threshold
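The To Do items could start from something as simple as the sketch below, which works on hypothetical `(label, confidence)` pairs like those computed in the inference loop. The 0.7 default threshold and the use of Jaccard similarity over exact, lowercased label matches are assumptions, not part of the original notebook; real image tags would likely need stemming or synonym matching (e.g. "puppy" vs "dog").

```python
def filter_by_confidence(detections, threshold=0.7):
    """To Do 1: keep (label, confidence) pairs at or above threshold (0.7 is a placeholder)."""
    return [(label, p) for label, p in detections if p >= threshold]

def tag_overlap(tags, detected_labels):
    """To Do 2: Jaccard similarity between the lowercased tag set and detected-label set."""
    a = {t.strip().lower() for t in tags if t.strip()}
    b = {d.strip().lower() for d in detected_labels if d.strip()}
    if not a and not b:
        return 0.0
    return len(a & b) / len(a | b)

def missing_tags(tags, detections, threshold=0.9):
    """To Do 3: confident detected labels that are not already among the tags."""
    existing = {t.strip().lower() for t in tags}
    confident = filter_by_confidence(detections, threshold)
    return sorted({label for label, _ in confident if label.lower() not in existing})

dets = [("dog", 0.98), ("person", 0.75), ("frisbee", 0.40)]
confident = filter_by_confidence(dets)
print(confident)  # [('dog', 0.98), ('person', 0.75)]
print(tag_overlap(["dog", "park"], [label for label, _ in confident]))
print(missing_tags(["dog", "park"], dets, threshold=0.7))  # ['person']
```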