In [1]:
import math
import torch
import numpy as np


def get_feature_map_size(input_img_size, stride):
    """Return the spatial size of the feature map for a given input size.

    Each dimension of ``input_img_size`` is divided by ``stride`` and rounded
    down.

    Args:
        input_img_size: Sequence of input dimensions, e.g. ``(height, width)``.
        stride: Downsampling factor between the input and the feature map.

    Returns:
        Integer ``numpy.ndarray`` with one entry per input dimension.
    """
    # The builtin ``float``/``int`` replace the ``np.float``/``np.int``
    # aliases, which no longer exist in current NumPy releases.
    input_size = np.asarray(input_img_size, dtype=float)

    return np.floor(input_size / stride).astype(int)
    
    
class AnchorBoxGenerator():
    """Generates a dense set of anchor boxes over a strided feature map.

    Every feature-map cell gets one anchor per (area, aspect ratio)
    combination. Boxes are stored as ``(center_row, center_col, height,
    width)`` in input-image pixel units.
    """

    def __init__(self,
                 anchor_areas=(128*128, 256*256, 512*512),
                 aspect_ratios=(1/2., 1/1., 2/1.),
                 stride=16
                ):
        """Store the anchor configuration.

        Args:
            anchor_areas: Iterable of anchor areas in square pixels.
            aspect_ratios: Iterable of width / height ratios.
            stride: Downsampling factor between the input and the feature map.
        """
        # Copy into fresh lists so instances never share (or mutate) the
        # default values or the caller's sequences.
        self.anchor_areas = list(anchor_areas)
        self.aspect_ratios = list(aspect_ratios)
        self.stride = stride

    def get_anchor_boxes_sizes(self):
        """Return the (height, width) of every anchor type.

        Anchor types are ordered with the area as the outer loop and the
        aspect ratio as the inner loop, so the type index is
        ``area_index * len(aspect_ratios) + aspect_ratio_index``.

        Returns:
            Float array of shape ``(1, number_of_anchor_types, 2)``.
        """
        anchor_boxes_sizes = []

        for current_anchor_area in self.anchor_areas:

            for current_aspect_ratio in self.aspect_ratios:

                # With aspect_ratio = w / h and anchor_area = w * h:
                #   aspect_ratio * anchor_area = (w / h) * (w * h) = w * w
                # hence w = sqrt(aspect_ratio * anchor_area), h = anchor_area / w.
                w = math.sqrt(current_aspect_ratio * current_anchor_area)
                h = current_anchor_area / w

                anchor_boxes_sizes.append((h, w))

        # Leading singleton axis lets the sizes be broadcast over all spatial
        # positions; the explicit reshape keeps the last axis at 2 even when
        # no anchor types are configured.
        return np.asarray(anchor_boxes_sizes, dtype=float).reshape(1, -1, 2)

    def get_anchor_boxes_center_coordinates(self, input_size):
        """Return the center of every feature-map cell in input-image pixels.

        Positions are enumerated column by column: the flat index of the cell
        at ``(row, col)`` is ``col * feature_map_height + row``.

        Args:
            input_size: ``(height, width)`` of the input image.

        Returns:
            Float array of shape ``(number_of_cells, 1, 2)`` holding
            ``(center_row, center_col)`` pairs.
        """
        feature_map_height, feature_map_width = get_feature_map_size(input_size, stride=self.stride)

        # Default 'xy' indexing yields arrays of shape (width, height), so
        # flattening walks the rows of one column before moving to the next.
        meshgrid_height, meshgrid_width = np.meshgrid(np.arange(feature_map_height),
                                                      np.arange(feature_map_width))

        # Stack directly instead of np.asarray(zip(...)): on Python 3 zip()
        # returns an iterator that NumPy cannot convert to a numeric array.
        anchor_coordinates_feature_map = np.stack(
            (meshgrid_height.ravel(), meshgrid_width.ravel()), axis=1)

        # Shift by half a cell to land on cell centers, then scale to pixels.
        anchor_coordinates_input = (anchor_coordinates_feature_map + 0.5) * self.stride

        return np.expand_dims(anchor_coordinates_input, axis=1)

    def get_anchor_boxes(self, input_size):
        """Return all anchor boxes for an input of the given size.

        Args:
            input_size: ``(height, width)`` of the input image.

        Returns:
            Float array of shape ``(number_of_cells, number_of_anchor_types, 4)``
            whose last axis is ``(center_row, center_col, height, width)``.
        """
        anchor_boxes_sizes = self.get_anchor_boxes_sizes()
        anchor_boxes_center_coordinates = self.get_anchor_boxes_center_coordinates(input_size)

        anchor_boxes_sizes_number = anchor_boxes_sizes.shape[1]
        anchor_boxes_center_coordinates_number = anchor_boxes_center_coordinates.shape[0]
        full_shape = (anchor_boxes_center_coordinates_number, anchor_boxes_sizes_number, 2)

        # Pair every center with every anchor size via broadcasting.
        centers = np.broadcast_to(anchor_boxes_center_coordinates, full_shape)
        sizes = np.broadcast_to(anchor_boxes_sizes, full_shape)

        return np.concatenate((centers, sizes), axis=2)
        

In [38]:
import skimage.data as data

def center_xy_anchor_boxes_to_upper_left_xy(anchor_boxes):
    """Convert boxes from center-based to top-left-based coordinates.

    Each box is four consecutive values: two position components followed by
    the two matching size components. Half of each size is subtracted from its
    position; the sizes are left untouched.

    Args:
        anchor_boxes: ``numpy.ndarray`` whose total size is a multiple of 4.

    Returns:
        A new array with the same shape as ``anchor_boxes``; the input is not
        modified.
    """
    anchor_boxes_copy = anchor_boxes.copy()
    flattened = anchor_boxes_copy.reshape((-1, 4))

    # One vectorized update replaces the per-box Python loop.
    flattened[:, :2] = flattened[:, :2] - flattened[:, 2:] / 2

    return flattened.reshape(anchor_boxes_copy.shape)

def draw_stride_grid_on_image(image, stride, grid_color=[0, 0, 0]):
    """Return a copy of ``image`` with grid lines drawn every ``stride`` pixels.

    Args:
        image: Array of shape ``(height, width, channels)``.
        stride: Spacing in pixels between consecutive grid lines.
        grid_color: Per-channel color written on the grid lines.

    Returns:
        A new array; ``image`` itself is left unchanged.
    """
    # Copy the argument, not the notebook-level ``img`` global.
    img_with_grid = image.copy()

    # Vertical lines, then horizontal lines.
    img_with_grid[:, ::stride, :] = grid_color
    img_with_grid[::stride, :, :] = grid_color

    return img_with_grid

def draw_separate_grid_cell(image, x, y, stride, grid_color=[255, 255, 0]):
    """Return a copy of ``image`` with the outline of one grid cell drawn.

    Args:
        image: Array of shape ``(height, width, channels)``.
        x: Column index of the cell on the stride grid.
        y: Row index of the cell on the stride grid.
        stride: Size of a grid cell in pixels.
        grid_color: Per-channel color used for the outline.

    Returns:
        A new array; ``image`` itself is left unchanged.
    """
    # Copy the argument, not the notebook-level ``img`` global.
    img_with_cell = image.copy()
    height, width = img_with_cell.shape[:2]

    top, left = y * stride, x * stride
    # The far edges of the last row/column fall one pixel outside the image
    # when its size is a multiple of the stride; clamp them to the border.
    bottom = min((y + 1) * stride, height - 1)
    right = min((x + 1) * stride, width - 1)

    # Inclusive ranges so the bottom-right corner of the outline is closed.
    img_with_cell[top:bottom + 1, [left, right], :] = grid_color
    img_with_cell[[top, bottom], left:right + 1, :] = grid_color

    return img_with_cell
    

# Notebook-level setup: the globals defined below (stride, img,
# feature_map_height/width, the anchor counts and anchor_boxes_top_left) are
# read by the interactive display cell further down.
stride = 16

img = data.astronaut()
# Take the (height, width) before img is replaced by its grid-overlaid copy.
img_shape = img.shape[:2]

feature_map_height, feature_map_width = get_feature_map_size(img_shape, stride=stride)

# NOTE: rebinding img makes this line depend on the previous value of img;
# the cell should be run from the data.astronaut() line to start clean.
img = draw_stride_grid_on_image(img, stride=stride)

anchor_generator = AnchorBoxGenerator(stride=stride)

anchor_boxes_different_areas_number = len(anchor_generator.anchor_areas)
anchor_boxes_different_aspect_ratios_number = len(anchor_generator.aspect_ratios)

# Anchors as (center_row, center_col, height, width), then converted so the
# first two components hold the top-left corner instead of the center.
anchor_boxes_centered = anchor_generator.get_anchor_boxes(input_size=img_shape)
anchor_boxes_top_left = center_xy_anchor_boxes_to_upper_left_xy(anchor_boxes_centered)

In [41]:
%matplotlib inline

from matplotlib import pyplot as plt
import matplotlib.patches as patches

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from collections import OrderedDict


def display_anchors(x, y, anchor_area, anchor_aspect_ratio):
    """Plot one anchor box on top of the image with its grid cell highlighted.

    Reads the notebook globals ``img``, ``stride``, ``feature_map_height``,
    ``anchor_boxes_different_aspect_ratios_number`` and
    ``anchor_boxes_top_left``.

    Args:
        x: Column index of the feature-map cell.
        y: Row index of the feature-map cell.
        anchor_area: Index into the generator's anchor areas.
        anchor_aspect_ratio: Index into the generator's aspect ratios.
    """
    # Create figure and axes
    fig, ax = plt.subplots(1, figsize=(10, 10))

    image = draw_separate_grid_cell(img, x, y, stride)

    # Display the image
    ax.imshow(image)

    # Cells are flattened column by column by
    # AnchorBoxGenerator.get_anchor_boxes_center_coordinates.
    spatial_position = x * feature_map_height + y

    # Anchor types are ordered area-major with the aspect ratio varying
    # fastest, so the row length is the number of aspect ratios (not the
    # number of areas).
    anchor_type = (anchor_area * anchor_boxes_different_aspect_ratios_number
                   + anchor_aspect_ratio)

    # Box layout is (top, left, height, width); separate names avoid
    # overwriting the x / y cell indices.
    top, left, box_height, box_width = anchor_boxes_top_left[spatial_position, anchor_type, :]

    # Create a Rectangle patch
    rect = patches.Rectangle((left, top), box_width, box_height,
                             linewidth=1, edgecolor='b', facecolor='none')
    # Add the patch to the Axes
    ax.add_patch(rect)

    # Mark the center of the anchor box.
    ax.plot([left + box_width / 2], [top + box_height / 2], 'bs')

    plt.show()
    


# Sliders selecting the feature-map cell (x = column, y = row). The initial
# value is clamped so that small feature maps do not start out of range.
max_cell_x = int(feature_map_width) - 1
max_cell_y = int(feature_map_height) - 1

spatial_position_widget = widgets.IntSlider(min=0,
                                            max=max_cell_x,
                                            step=1,
                                            value=min(10, max_cell_x),
                                            continuous_update=False)


spatial_position_widget2 = widgets.IntSlider(min=0,
                                             max=max_cell_y,
                                             step=1,
                                             value=min(10, max_cell_y),
                                             continuous_update=False)


# Selection sliders take (label, value) pairs; the value is the index of the
# area / aspect ratio inside the anchor generator's lists.
areas_sqrt = [str(int(math.sqrt(area))) + '^2' for area in anchor_generator.anchor_areas]
anchor_area_slider_options = list(zip(areas_sqrt, range(anchor_boxes_different_areas_number)))

anchor_area_widget = widgets.SelectionSlider(
    options=anchor_area_slider_options,
    description='anchor area',
    continuous_update=False,
)


aspect_ratios_list_of_strings = [str(ratio) for ratio in anchor_generator.aspect_ratios]
anchor_aspect_slider_options = list(zip(aspect_ratios_list_of_strings,
                                        range(anchor_boxes_different_aspect_ratios_number)))

anchor_aspect_widget = widgets.SelectionSlider(
    options=anchor_aspect_slider_options,
    description='anchor aspect ratio',
    continuous_update=False,
)


# Trailing semicolon keeps the function repr out of the cell output.
interact(display_anchors,
         x=spatial_position_widget,
         y=spatial_position_widget2,
         anchor_area=anchor_area_widget,
         anchor_aspect_ratio=anchor_aspect_widget);

A Jupyter Widget

<function __main__.display_anchors>