In [3]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.applications import imagenet_utils
from tensorflow.keras.preprocessing.image import img_to_array
from imutils.object_detection import non_max_suppression

import numpy as np
import argparse
import cv2

In [4]:
def selective_search(image, method="fast"):
    """Run OpenCV selective search on ``image`` and return its region proposals.

    Parameters
    ----------
    image : array
        Image handed to ``setBaseImage``; presumably a BGR array as returned by
        ``cv2.imread`` (TODO confirm for other callers).
    method : str
        ``"fast"`` selects ``switchToSelectiveSearchFast``; ``"quality"`` selects
        ``switchToSelectiveSearchQuality``.

    Returns
    -------
    Whatever ``ss.process()`` returns: a sequence of ``(x, y, w, h)`` rectangles.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"fast"`` nor ``"quality"``.
    """
    # Validate before building the (comparatively expensive) segmentation object.
    if method not in ("fast", "quality"):
        raise ValueError(
            "method must be 'fast' or 'quality', got {!r}".format(method))

    ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
    ss.setBaseImage(image)

    if method == "fast":
        ss.switchToSelectiveSearchFast()
    else:
        # Note the OpenCV spelling: "Quality", not "Qualify".
        ss.switchToSelectiveSearchQuality()

    return ss.process()

In [8]:
# Image to run region proposal + classification on.
image_path = "beagle.png"

# Passed as `method` to selective_search(); "fast" selects OpenCV's fast mode.
method = "fast"

# Minimum top-1 probability a region needs to be kept as a detection.
conf = 0.9

# Optional list of ImageNet labels to keep (e.g. ["beagle"]); None keeps every label.
labelFilters = None

In [9]:
# Read the image first so a bad path fails fast, before the slow model load.
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising; without this check the failure surfaces later as an opaque
    # AttributeError on `.shape`.
    raise FileNotFoundError("could not read image at {!r}".format(image_path))
(H, W) = image.shape[:2]

# ImageNet-pretrained ResNet50 used to classify each region proposal.
model = ResNet50(weights='imagenet')

rects = selective_search(image, method=method)
print("[INFO] {} regions found by selective search".format(len(rects)))

[INFO] 922 regions found by selective search


In [10]:
# Build the classifier inputs: for each sufficiently large region, crop it,
# convert BGR -> RGB, resize to 224x224 and apply ResNet50 preprocessing.
# `boxes[k]` holds the (x, y, w, h) rect that produced `proposals[k]`.
proposals, boxes = [], []

for rect in rects:
    (x, y, w, h) = rect

    # Keep a region only if it spans at least 10% of the image width AND height.
    if w / float(W) >= 0.1 and h / float(H) >= 0.1:
        crop_rgb = cv2.cvtColor(image[y:y + h, x:x + w], cv2.COLOR_BGR2RGB)
        roi = preprocess_input(img_to_array(cv2.resize(crop_rgb, (224, 224))))

        proposals.append(roi)
        boxes.append((x, y, w, h))

In [11]:
# Stack the per-region inputs into a single batch for the model.
proposals = np.array(proposals)
print("[INFO] proposal shape: {}".format(proposals.shape))

# With no surviving regions, np.array([]) has shape (0,), which does not match
# the model's image input; fail with a clear message instead of inside Keras.
if len(proposals) == 0:
    raise ValueError("no region proposals survived filtering; nothing to classify")

# Score every proposal in one batched call, then decode to the top-1
# (imagenet_id, label, probability) entry per proposal.
raw_preds = model.predict(proposals)
preds = imagenet_utils.decode_predictions(raw_preds, top=1)

[INFO] proposal shape: (534, 224, 224, 3)


In [13]:
# Group confident detections by predicted label:
#   labels[label] -> [(box, prob), ...]
# with each box converted from (x, y, w, h) to corner form (x1, y1, x2, y2).
labels = {}

# preds and boxes are built in lockstep from the same proposals.
assert len(preds) == len(boxes), "each prediction must correspond to exactly one box"

for (p, (x, y, w, h)) in zip(preds, boxes):
    # Top-1 prediction for this proposal; the ImageNet class id is unused.
    (_imagenet_id, label, prob) = p[0]

    # Optional whitelist of labels to keep.
    if labelFilters is not None and label not in labelFilters:
        continue

    if prob >= conf:
        box = (x, y, x + w, y + h)
        labels.setdefault(label, []).append((box, prob))