In [1]:
import math
import torch
import numpy as np

import torchvision
import matplotlib.pyplot as plt


def get_feature_map_size(input_img_size, stride):
    """
    Given an image size and the output stride of a network, computes the
    size of the resulting feature map.

    ``input_img_size`` is any array-like of numbers (e.g. ``(height, width)``)
    and ``stride`` is the subsampling factor. Every entry is divided by
    ``stride`` and rounded down; the result is an integer numpy array with
    the same number of entries as ``input_img_size``.
    """

    # The builtin float/int are used instead of the np.float/np.int aliases,
    # which were deprecated in NumPy 1.20 and removed in NumPy 1.24.
    input_size = np.asarray(input_img_size, dtype=float)

    return np.floor(input_size / stride).astype(int)
    

class AnchorBoxGenerator():
    """
    Anchor box generator class that takes care of generation of anchor
    boxes for each element of the resulting feature map. The feature map size
    is computed from the input image size: it is the size of the input image
    subsampled by the ```stride``` parameter.

    The class accepts ```anchor_areas``` and ```aspect_ratios``` parameters.
    All (area, aspect ratio) pairs are generated, which results in
    ```size(anchor_areas) * size(aspect_ratios)``` anchor boxes for each
    element of the feature map.

    Boxes are stored as ```(y_center, x_center, height, width)``` in the
    coordinate system of the input image.
    """

    def __init__(self,
                 anchor_areas=None,
                 aspect_ratios=None,
                 stride=16
                ):
        """
        ```anchor_areas``` defaults to ```[128*128, 256*256, 512*512]``` and
        ```aspect_ratios``` (width / height) to ```[1/2., 1/1., 2/1.]```.
        ```stride``` is the output stride of the network.
        """

        # Defaults are created per instance: a list used as a default
        # argument would be shared (and mutable) across all instances.
        if anchor_areas is None:
            anchor_areas = [128*128, 256*256, 512*512]

        if aspect_ratios is None:
            aspect_ratios = [1/2., 1/1., 2/1.]

        self.anchor_areas = anchor_areas
        self.aspect_ratios = aspect_ratios
        self.stride = stride


    def get_anchor_boxes_sizes(self):
        """
        Computes the (height, width) of a box for every pair of
        ```anchor_areas``` and ```aspect_ratios```.

        Returns an array of shape ```(1, size(anchor_areas) * size(aspect_ratios), 2)```.
        Areas vary slowest, i.e. the entry for area ```i``` and aspect ratio
        ```j``` is at index ```i * size(aspect_ratios) + j```.
        """

        anchor_boxes_sizes = []

        for current_anchor_area in self.anchor_areas:

            for current_aspect_ratio in self.aspect_ratios:

                # Given:
                #   aspect_ratio = w / h
                #   anchor_area  = w * h
                # it follows that
                #   aspect_ratio * anchor_area = w * w  =>  w = sqrt(aspect_ratio * anchor_area)
                #   h = anchor_area / w
                w = math.sqrt(current_aspect_ratio * current_anchor_area)
                h = current_anchor_area / w

                anchor_boxes_sizes.append((h, w))

        # The leading dummy dimension makes it easy to .repeat() the sizes for
        # every spatial position later. The explicit reshape keeps the trailing
        # dimension equal to 2 even when there are no boxes at all.
        return np.asarray(anchor_boxes_sizes, dtype=float).reshape((1, -1, 2))


    def get_anchor_boxes_center_coordinates(self, input_size):
        """
        Computes the centers of the feature map cells expressed in the
        coordinate system of the input image. These are needed to intersect
        groundtruth boxes with the generated anchor boxes in order to classify
        each anchor box as positive/negative/ambiguous.

        Returns an array of shape ```(feature_map_height * feature_map_width, 1, 2)```
        holding ```(y_center, x_center)``` pairs. The row index varies fastest,
        so the cell ```(row, column)``` is at index
        ```column * feature_map_height + row```.
        """

        feature_map_height, feature_map_width = get_feature_map_size(input_size, stride=self.stride)

        meshgrid_height, meshgrid_width = np.meshgrid(np.arange(feature_map_height),
                                                      np.arange(feature_map_width))

        # Pair up (row, column) indices of every grid cell. column_stack is used
        # instead of np.asarray(zip(...)): on Python 3 zip() is a lazy iterator
        # and np.asarray() would wrap it into a useless 0-d object array.
        anchor_coordinates_feature_map = np.column_stack((meshgrid_height.ravel(),
                                                          meshgrid_width.ravel()))

        # Shift by half a cell to get the cell centers, then scale to input pixels
        anchor_coordinates_input = (anchor_coordinates_feature_map + 0.5) * self.stride

        # Dummy middle dimension so that the centers can be repeated per box size
        return np.expand_dims(anchor_coordinates_input, axis=1)


    def get_anchor_boxes(self, input_size):
        """
        Combines the previous functions to compute all anchor boxes with their
        coordinates with respect to the input image's coordinate system.

        Returns an array of shape
        ```(feature_map_height * feature_map_width, size(anchor_areas) * size(aspect_ratios), 4)```
        whose last dimension is ```(y_center, x_center, height, width)```.

        The number of anchor boxes therefore is:
         ```size(anchor_areas) * size(aspect_ratios) * floor(input_height / stride) * floor(input_width / stride)```
        """

        anchor_boxes_sizes = self.get_anchor_boxes_sizes()
        anchor_boxes_center_coordinates = self.get_anchor_boxes_center_coordinates(input_size)

        anchor_boxes_sizes_number = anchor_boxes_sizes.shape[1]
        anchor_boxes_center_coordinates_number = anchor_boxes_center_coordinates.shape[0]

        # Every center gets every box size and vice versa
        anchor_boxes_center_coordinates_repeated = anchor_boxes_center_coordinates.repeat(anchor_boxes_sizes_number, axis=1)
        anchor_boxes_sizes_repeated = anchor_boxes_sizes.repeat(anchor_boxes_center_coordinates_number, axis=0)

        anchor_boxes = np.dstack((anchor_boxes_center_coordinates_repeated, anchor_boxes_sizes_repeated))

        return anchor_boxes


def center_xy_anchor_boxes_to_upper_left_xy(anchor_boxes):
    """
    Converts boxes of the format
    (x_center, y_center, width, height) to
    (x_topleft, y_topleft, width, height).

    The conversion subtracts half of the third component from the first one
    and half of the fourth component from the second one, so it works equally
    well for boxes stored as (y_center, x_center, height, width).

    ``anchor_boxes`` is an array-like whose elements group into boxes of four
    values; it is not modified. A new array of the same shape and dtype is
    returned.
    """

    anchor_boxes = np.asarray(anchor_boxes)

    # Work on a private (n_boxes, 4) copy so that the input stays untouched
    anchor_boxes_flattened = anchor_boxes.reshape((-1, 4)).copy()

    # Single vectorised update instead of a Python loop over every box
    anchor_boxes_flattened[:, :2] = anchor_boxes_flattened[:, :2] - anchor_boxes_flattened[:, 2:] / 2

    return anchor_boxes_flattened.reshape(anchor_boxes.shape)

def convert_xywh_to_xyxy(bounding_box):
    """
    Turns a box given as
    (x_topleft, y_topleft, width, height) into
    (x_topleft, y_topleft, x_bottomright, y_bottomright).

    The argument is left untouched; a new numpy array is returned.
    """

    # Private copy of the input as a numpy array
    xyxy = np.array(bounding_box)

    # bottom-right corner = top-left corner + size
    xyxy[3] = xyxy[3] + xyxy[1]
    xyxy[2] = xyxy[2] + xyxy[0]

    return xyxy

def draw_stride_grid_on_image(image, stride, grid_color=(0, 0, 0)):
    """
    Draws a grid with the given stride on a copy of the input image.
    Used to show which regions of the input image are associated with the
    respective elements of the subsampled feature map.

    ``image`` is indexed as (height, width, channels) and is not modified.
    Every ``stride``-th row and column of the returned copy is painted with
    ``grid_color``.
    """

    # Copy the *argument*; the original code copied the global `img` and
    # silently ignored the image that was passed in.
    img_with_grid = np.array(image)

    # Vertical lines, then horizontal lines
    img_with_grid[:, ::stride, :] = grid_color
    img_with_grid[::stride, :, :] = grid_color

    return img_with_grid

def draw_separate_grid_cell(image, x, y, stride, grid_color=(255, 255, 0)):
    """
    Highlights the border of a single grid cell on a copy of the image.

    ``x`` and ``y`` are the column and row index of the cell in a grid with
    spacing ``stride``; ``image`` is indexed as (height, width, channels) and
    is not modified.
    """

    # Copy the *argument*; the original code copied the global `img` and
    # silently ignored the image that was passed in.
    img_with_grid = np.array(image)
    height, width = img_with_grid.shape[:2]

    top, left = y * stride, x * stride

    # The far edges are clamped to the last pixel so that a cell touching the
    # image border does not index out of bounds.
    bottom = min((y + 1) * stride, height - 1)
    right = min((x + 1) * stride, width - 1)

    # Left/right edges, then top/bottom edges (slices are clipped by numpy)
    img_with_grid[top:(y + 1) * stride, [left, right], :] = grid_color
    img_with_grid[[top, bottom], left:(x + 1) * stride, :] = grid_color

    return img_with_grid

def bboxes_ious(bboxes_group_1, bboxes_group_2):
    """
    Computes the intersection over union metric between each
    pair of boxes from group 1 and group 2.

    Both groups are array-likes of shape (n, 4) and (m, 4) holding boxes as
    (x_topleft, y_topleft, x_bottomright, y_bottomright).

    Returns a tuple ``(ious, intersections_bboxes_xyxy)``:

    * ``ious`` has shape (n, m); pairs whose union area is not positive get
      an IoU of 0 instead of NaN.
    * ``intersections_bboxes_xyxy`` has shape (n, m, 4) and holds the
      intersection rectangle of each pair in the same corner format. It is
      only a valid rectangle for pairs that actually intersect.
    """

    # Builtin float instead of the np.float alias removed in NumPy 1.24
    bboxes_group_1 = np.asarray(bboxes_group_1, dtype=float)
    bboxes_group_2 = np.asarray(bboxes_group_2, dtype=float)

    # Corners of the intersection between each pair of boxes from group 1
    # and 2; the inserted axis broadcasts (n, 1, 2) against (m, 2).
    top_left = np.maximum(bboxes_group_1[:, None, :2], bboxes_group_2[:, :2])
    bottom_right = np.minimum(bboxes_group_1[:, None, 2:], bboxes_group_2[:, 2:])

    intersections_bboxes_xyxy = np.dstack((top_left, bottom_right))

    # Negative extents mean "no overlap" and are clipped to an empty size
    intersections_bboxes_width_height = np.clip(bottom_right - top_left, a_min=0, a_max=None)

    intersections_bboxes_areas = intersections_bboxes_width_height[:, :, 0] * intersections_bboxes_width_height[:, :, 1]

    bboxes_group_1_areas = (bboxes_group_1[:, 2] - bboxes_group_1[:, 0]) * (bboxes_group_1[:, 3] - bboxes_group_1[:, 1])
    bboxes_group_2_areas = (bboxes_group_2[:, 2] - bboxes_group_2[:, 0]) * (bboxes_group_2[:, 3] - bboxes_group_2[:, 1])

    unions_areas = bboxes_group_1_areas[:, None] + bboxes_group_2_areas - intersections_bboxes_areas

    # Divide only where the union is positive to avoid 0/0 for degenerate boxes
    ious = np.zeros_like(intersections_bboxes_areas)
    np.divide(intersections_bboxes_areas, unions_areas, out=ious, where=unions_areas > 0)

    return ious, intersections_bboxes_xyxy
    
# Output stride of the network, change if you want a different one
stride = 32

# Loading the coco-like dataset file of PASCAL VOC 2012 detection
# NOTE: machine-specific absolute paths, adjust them to your local copy of the dataset
coco_db = torchvision.datasets.CocoDetection(annFile='/home/daniil/projects/pascal/PASCAL_VOC/pascal_val2012.json',
                                             root='/home/daniil/projects/pascal/dataset/VOCdevkit/VOC2012/JPEGImages/')

# Taking one predefined sample from the dataset
img, annotations = coco_db[4700]

# Getting the image as an array and its (height, width)
img = np.asarray(img)
img_shape = img.shape[:2]

# Getting the groundtruth bounding box, stored as (x_topleft, y_topleft, width, height),
# and converting it to a (1, 4) array of corner coordinates.
# Builtin float is used instead of the np.float alias removed in NumPy 1.24.
ground_truth_bbox = annotations[3]['bbox']
ground_truth_bbox_xyxy = convert_xywh_to_xyxy( ground_truth_bbox ).astype(float)
ground_truth_bbox_xyxy = ground_truth_bbox_xyxy[None, :]


feature_map_height, feature_map_width = get_feature_map_size(img_shape, stride=stride)

# From here on `img` is the image with the stride grid drawn on top of it
img = draw_stride_grid_on_image(img, stride=stride)

anchor_generator = AnchorBoxGenerator(stride=stride)

anchor_boxes_different_areas_number = len(anchor_generator.anchor_areas)
anchor_boxes_different_aspect_ratios_number = len(anchor_generator.aspect_ratios)

# Anchor boxes for every feature map cell: first with center coordinates,
# then with the first two components moved to the top-left corner
anchor_boxes_centered = anchor_generator.get_anchor_boxes(input_size=img_shape)
anchor_boxes_top_left = center_xy_anchor_boxes_to_upper_left_xy(anchor_boxes_centered)

loading annotations into memory...
Done (t=0.14s)
creating index...
index created!


In [4]:
from matplotlib import pyplot as plt
import matplotlib.patches as patches

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from collections import OrderedDict


def display_anchors(x, y, anchor_area, anchor_aspect_ratio):
    """
    Plots the selected anchor box (blue), the groundtruth box (red) and their
    intersection (white, hatched) on top of the image, and prints their
    intersection over union.

    ``x`` and ``y`` are the column and row index of the feature map cell,
    ``anchor_area`` and ``anchor_aspect_ratio`` are indices into the
    generator's ``anchor_areas`` and ``aspect_ratios`` lists.

    Relies on the module-level ``img``, ``stride``, ``feature_map_height``,
    ``anchor_boxes_top_left``, ``ground_truth_bbox`` and
    ``ground_truth_bbox_xyxy`` prepared by the previous cell.
    """

    # Create figure and axes
    fig, ax = plt.subplots(1, figsize=(9, 6.9))

    # Highlighting the selected feature map grid cell in the input image
    image = draw_separate_grid_cell(img, x, y, stride)

    # Display the image
    ax.imshow(image)

    # Anchor boxes are laid out with the row index varying fastest
    spatial_position = x * feature_map_height + y

    # Box sizes are laid out with the aspect ratio varying fastest, so the area
    # index has to be scaled by the number of *aspect ratios* (the original code
    # used the number of areas, which is only right when both counts are equal).
    anchor_type = anchor_area * anchor_boxes_different_aspect_ratios_number + anchor_aspect_ratio

    anchor_box = anchor_boxes_top_left[spatial_position, anchor_type, :].copy()
    # converting (y, x, h, w) to (x, y, w, h), basically just swapping values
    anchor_box[:2] = anchor_box[1], anchor_box[0]
    anchor_box[2:] = anchor_box[3], anchor_box[2]

    # Anchor box rectangle
    rect = patches.Rectangle(anchor_box[:2],
                             anchor_box[2],
                             anchor_box[3],
                             linewidth=2,
                             edgecolor='b',
                             facecolor='none')
    # Add the patch to the Axes
    ax.add_patch(rect)

    # Marking the center of the anchor box
    x_anchor_box_center = anchor_box[0] + anchor_box[2] / 2
    y_anchor_box_center = anchor_box[1] + anchor_box[3] / 2

    ax.plot(x_anchor_box_center, y_anchor_box_center, 'bs')

    # Groundtruth box rectangle
    rect = patches.Rectangle(ground_truth_bbox[:2],
                             ground_truth_bbox[2],
                             ground_truth_bbox[3],
                             linewidth=2,
                             edgecolor='r',
                             facecolor='none')
    # Add the patch to the Axes
    ax.add_patch(rect)

    anchor_box_xyxy = convert_xywh_to_xyxy(anchor_box)
    anchor_box_xyxy = anchor_box_xyxy[None, :]

    # One groundtruth box against one anchor box: a (1, 1) IoU matrix
    ious, intersection_bbox = bboxes_ious(ground_truth_bbox_xyxy, anchor_box_xyxy)
    intersection_bbox = intersection_bbox.squeeze()

    print("Intersection over union: ", ious[0, 0])

    # display intersection area if intersection exists
    if (ious > 0).all():

        # Intersection rectangle
        rect = patches.Rectangle(intersection_bbox[:2],
                                 intersection_bbox[2] - intersection_bbox[0],
                                 intersection_bbox[3] - intersection_bbox[1],
                                 linewidth=2,
                                 edgecolor='w',
                                 facecolor='none',
                                 hatch='\\')
        # Add the patch to the Axes
        ax.add_patch(rect)

    plt.show()
    


# Slider choosing the column (x) of the feature map cell
spatial_position_widget = widgets.IntSlider(
    value=10,
    min=0,
    max=feature_map_width - 1,
    step=1,
    continuous_update=False,
)

# Slider choosing the row (y) of the feature map cell
spatial_position_widget2 = widgets.IntSlider(
    value=5,
    min=0,
    max=feature_map_height - 1,
    step=1,
    continuous_update=False,
)

# Anchor area slider: labels such as '128^2' mapped to the index of the area
areas_sqrt = [str(int(math.sqrt(area))) + '^2'
              for area in anchor_generator.anchor_areas]
anchor_area_slider_names_values_dict = OrderedDict(
    zip(areas_sqrt, range(anchor_boxes_different_areas_number))
)

anchor_area_widget = widgets.SelectionSlider(
    options=anchor_area_slider_names_values_dict,
    description='anchor area',
    continuous_update=False,
)

# Aspect ratio slider: the ratio as text mapped to the index of the ratio
aspect_ratios_list_of_strings = [str(ratio)
                                 for ratio in anchor_generator.aspect_ratios]
anchor_aspect_slider_names_values_dict = OrderedDict(
    zip(aspect_ratios_list_of_strings,
        range(anchor_boxes_different_aspect_ratios_number))
)

anchor_aspect_widget = widgets.SelectionSlider(
    options=anchor_aspect_slider_names_values_dict,
    description='anchor aspect ratio',
    continuous_update=False,
)

# Re-draws the figure every time one of the sliders changes
interact(display_anchors,
         x=spatial_position_widget,
         y=spatial_position_widget2,
         anchor_area=anchor_area_widget,
         anchor_aspect_ratio=anchor_aspect_widget)

A Jupyter Widget

<function __main__.display_anchors>