In [2]:
import numpy as n
import numpy as np
import torch as t
from torch import nn

In [57]:
class AnchorGenerator(nn.Module):
    """Generate anchor boxes in (y1, x1, y2, x2) order for a list of feature maps.

    Each feature level has its own list of aspect ratios and scales. For every
    level a set of base anchors centred on the origin is built, and then copied
    to every cell of that level's feature map by adding the cell's offset.

    Args:
        ratios: one sequence of aspect ratios per feature level.
        anchor_scales: one sequence of scales per feature level.
        base_size: side length that ratios and scales are applied to.
    """

    def __init__(self,ratios=[[0.5,1,2]],anchor_scales=[[8,16,32]],base_size=16):
        super().__init__()
        # Shallow-copy the outer lists so instance state never aliases the
        # shared default arguments (or the caller's own list objects).
        self.ratios = list(ratios)
        self.anchor_scales = list(anchor_scales)
        self.base_size = base_size
        # Per-level base anchors; filled lazily by set_cell_anchors().
        self.cell_anchors = None

    def _ratio_enum(self,ratios):
        """Return (h, w) of the base box for each ratio in the tensor `ratios`."""
        h = self.base_size*t.sqrt(ratios)
        w = self.base_size*t.sqrt(1/ratios)
        return h,w

    def num_anchors_per_location(self):
        """Return, for each feature level, len(scales) * len(ratios)."""
        return [len(scales) * len(level_ratios)
                for scales, level_ratios in zip(self.anchor_scales, self.ratios)]

    def generate_anchors(self,ratios,scales,dtype=t.float32,device="cpu"):
        """Build the origin-centred base anchors for one feature level.

        Args:
            ratios: aspect ratios (any nesting; flattened internally).
            scales: anchor scales (any nesting; flattened internally).
            dtype: dtype of the returned tensor.
            device: device of the returned tensor.

        Returns:
            Tensor of shape (num_scales * num_ratios, 4) holding
            (-h/2, -w/2, h/2, w/2) rows, scale-major then ratio.
        """
        ratios = t.as_tensor(ratios, dtype=dtype, device=device)
        anchor_scales = t.as_tensor(scales, dtype=dtype, device=device)
        h,w = self._ratio_enum(ratios)
        # Outer product scale x ratio, flattened so the ratio index varies fastest.
        hs = (anchor_scales.reshape(-1,1)*h.reshape(1,-1)).reshape(-1)
        ws = (anchor_scales.reshape(-1,1)*w.reshape(1,-1)).reshape(-1)
        base_anchor = t.stack((-hs,-ws,hs,ws),dim=1)/2
        return base_anchor

    def set_cell_anchors(self,dtype,device):
        """Create (or reuse) the per-level base anchors for `dtype` / `device`."""
        # Normalise so that a string such as "cpu" compares equal to a tensor's device.
        device = t.device(device)
        cached = self.cell_anchors
        if cached is not None and len(cached) > 0:
            if cached[0].device == device and cached[0].dtype == dtype:
                return
        self.cell_anchors = [
            self.generate_anchors(level_ratios,level_scales,dtype,device)
            for level_ratios,level_scales in zip(self.ratios,self.anchor_scales)
        ]

    def forward(self,imgs,feature_maps):
        """Return one anchor tensor per feature map.

        Args:
            imgs: object whose `.tensors` attribute is the batched image
                tensor; only its last two dimensions (height, width) are read.
            feature_maps: list of tensors whose last two dimensions are the
                grid height and width of that level.

        Returns:
            List with one (grid_h * grid_w * anchors_per_cell, 4) tensor per
            feature map.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        img_size = imgs.tensors.shape[-2:]
        dtype,device = feature_maps[0].dtype,feature_maps[0].device
        # Integer stride of each level: image size divided by grid size.
        strides = [[t.tensor(img_size[0]//g[0],dtype=t.int64,device=device),
                    t.tensor(img_size[1]//g[1],dtype=t.int64,device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype,device)
        return self.grid_anchors(grid_sizes,strides)

    def grid_anchors(self,grid_sizes,strides):
        """Shift each level's base anchors to every cell of its feature grid.

        Args:
            grid_sizes: per-level (grid_height, grid_width).
            strides: per-level (stride_height, stride_width).

        Returns:
            List with one (grid_h * grid_w * anchors_per_cell, 4) tensor per
            level. Cells are visited row by row (y outer, x inner); within a
            cell the anchors keep the order produced by generate_anchors().

        Raises:
            RuntimeError: if set_cell_anchors() has not been called yet.
            ValueError: if the number of levels disagrees between arguments
                and the configured ratios/scales.
        """
        cell_anchors = self.cell_anchors
        if cell_anchors is None:
            raise RuntimeError("cell anchors are not set; call set_cell_anchors() first")
        if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
            raise ValueError(
                "expected one grid size and one stride per configured level "
                f"({len(cell_anchors)}), got {len(grid_sizes)} and {len(strides)}"
            )

        anchors = []
        for size,stride,base_anchors in zip(grid_sizes,strides,cell_anchors):
            grid_height,grid_width = size
            stride_height,stride_width = stride
            device = base_anchors.device
            shifts_y = t.arange(0,grid_height,dtype=t.float32,device=device)*stride_height
            shifts_x = t.arange(0,grid_width,dtype=t.float32,device=device)*stride_width
            # Row-major enumeration of the grid: each y is repeated for a whole
            # row while x cycles, matching how an (H, W) map is flattened.
            shift_y = shifts_y.repeat_interleave(grid_width)
            shift_x = shifts_x.repeat(grid_height)
            shifts = t.stack((shift_y,shift_x,shift_y,shift_x),dim=1)
            # (cells, 1, 4) + (1, anchors, 4) -> every anchor at every cell.
            anchors.append((shifts.view(-1,1,4)+base_anchors.view(1,-1,4)).reshape(-1,4))
        return anchors




In [58]:
# Instantiate the generator with its default ratios, scales and base size.
rpn_anchor_generator = AnchorGenerator()

In [59]:
# Build the base anchors for one level directly. The nested lists are
# flattened inside generate_anchors, so the result has one row per
# (scale, ratio) pair.
rpn_anchor_generator.generate_anchors(ratios=[[0.5,1,2]],device="cpu",dtype=t.float32,scales=[[8,16,32]])

tensor([[ -45.2548,  -90.5097,   45.2548,   90.5097],
        [ -64.0000,  -64.0000,   64.0000,   64.0000],
        [ -90.5097,  -45.2548,   90.5097,   45.2548],
        [ -90.5097, -181.0193,   90.5097,  181.0193],
        [-128.0000, -128.0000,  128.0000,  128.0000],
        [-181.0193,  -90.5097,  181.0193,   90.5097],
        [-181.0193, -362.0387,  181.0193,  362.0387],
        [-256.0000, -256.0000,  256.0000,  256.0000],
        [-362.0387, -181.0193,  362.0387,  181.0193]])

In [36]:
# Per-level configuration with the same nesting as the AnchorGenerator
# defaults: one inner list per feature level.
ratios=[[0.5,1,2]]
anchor_scales=[[8,16,32]]

In [38]:
# Check how zip pairs each level's ratio list with its scale list.
list(zip(ratios,anchor_scales))

[([0.5, 1, 2], [8, 16, 32])]

In [23]:
# Column vector of the three scales, shape (3, 1), for broadcasting against a row.
# NOTE(review): this rebinds `anchor_scales`, which other cells use as a plain
# list — the value seen downstream depends on execution order.
anchor_scales=t.as_tensor([8,16,32]).view(-1,1)

In [27]:
# NOTE(review): `h` is not assigned at top level in any visible cell —
# presumably left over from an earlier kernel session; confirm before relying
# on Restart & Run All.
h

tensor([11.3137, 16.0000, 22.6274])

In [26]:
# Broadcast a (1, n) row of `h` against `anchor_scales`; the 3x3 output has
# one row per scale, so `anchor_scales` is presumably the (3, 1) column built
# in an earlier cell — TODO confirm.
h.view(1,-1)*anchor_scales

tensor([[ 90.5097, 128.0000, 181.0193],
        [181.0193, 256.0000, 362.0387],
        [362.0387, 512.0000, 724.0773]])

In [22]:
# Stacking along the default dim=0 gives one row per coordinate (4 x n), not
# one row per anchor; AnchorGenerator.generate_anchors stacks with dim=1.
# NOTE(review): `h` and `w` come from kernel state not visible here.
t.stack((-h,-w,h,w))

tensor([[-11.3137, -16.0000, -22.6274],
        [-22.6274, -16.0000, -11.3137],
        [ 11.3137,  16.0000,  22.6274],
        [ 22.6274,  16.0000,  11.3137]])

In [31]:
def anchor_generator(ratios=[0.5,1,2],anchor_scales=[8,16,32],base_size=16):
    h,w = _ratio_enum(base_size, ratios)
    x_ctr = (base_size-1)/2
    y_ctr = (base_size-1)/2
    anchor_base = np.zeros((len(ratios)*len(anchor_scales),4),dtype = np.float32)
    for i,scale in enumerate(anchor_scales):
        for j in range(3):
            index = i*len(anchor_scales) +j
            anchor_base[index,0] = y_ctr - (h[j]*scale-1)/2
            anchor_base[index,1] = x_ctr - (w[j]*scale-1)/2
            anchor_base[index,2] = y_ctr + (h[j]*scale-1)/2
            anchor_base[index,3] = x_ctr + (w[j]*scale-1)/2
    return anchor_base
             



def _ratio_enum(base_size,ratios):
    h = base_size*np.sqrt(ratios)
    w = base_size*np.sqrt(1/np.array(ratios))
    return h,w



In [32]:
# Base anchors for the default ratios, scales and base size.
base_anchor = anchor_generator()

In [33]:
# Display the base anchors: one (y1, x1, y2, x2) row per scale/ratio pair.
base_anchor

array([[ -37.254833,  -82.50967 ,   52.254833,   97.50967 ],
       [ -56.      ,  -56.      ,   71.      ,   71.      ],
       [ -82.50967 ,  -37.254833,   97.50967 ,   52.254833],
       [ -82.50967 , -173.01933 ,   97.50967 ,  188.01933 ],
       [-120.      , -120.      ,  135.      ,  135.      ],
       [-173.01933 ,  -82.50967 ,  188.01933 ,   97.50967 ],
       [-173.01933 , -354.03867 ,  188.01933 ,  369.03867 ],
       [-248.      , -248.      ,  263.      ,  263.      ],
       [-354.03867 , -173.01933 ,  369.03867 ,  188.01933 ]],
      dtype=float32)

In [55]:
# Insert size-1 axes at positions 0 and 2 so the scales can broadcast against the ratios.
# NOTE(review): this rebinds `anchor_scales` from its own previous value, so
# re-running the cell changes the shape again. The (1, 3, 1) shape shown by
# the next cell implies the input was a flat 3-element sequence — confirm
# which earlier cell supplied it.
anchor_scales=np.expand_dims(anchor_scales,axis=(0,2))

In [56]:
# Confirm the shape after expand_dims.
anchor_scales.shape

(1, 3, 1)

In [57]:
# Broadcast the expanded scales against `ratios`; in the output each row
# corresponds to one scale and each column to one ratio.
anchor_scales*ratios

array([[[ 4.,  8., 16.],
        [ 8., 16., 32.],
        [16., 32., 64.]]])

In [6]:
# Pixel offsets of two neighbouring grid cells at a stride of 16, once per axis.
a = np.arange(2) * 16
b = a.copy()

In [8]:
# Expand the per-axis offsets into full grids: `a` supplies the x offsets and
# `b` the y offsets. The ravel() output in the next cell shows y constant
# within each row.
shift_x, shift_y = np.meshgrid(a,b)


In [14]:
# Flatten the y grid; the output repeats each y offset once per column.
shift_y.ravel()

array([ 0,  0, 16, 16])

In [21]:
# One (y, x, y, x) row per grid cell, so the same offset is added to both
# corners of a (y1, x1, y2, x2) anchor.
shift = np.stack((shift_y.ravel(), shift_x.ravel(),
                      shift_y.ravel(), shift_x.ravel()), axis=1)

In [29]:
# Show the base anchors again before shifting them across the grid.
base_anchor

array([[ -37.254833,  -82.50967 ,   52.254833,   97.50967 ],
       [ -56.      ,  -56.      ,   71.      ,   71.      ],
       [ -82.50967 ,  -37.254833,   97.50967 ,   52.254833],
       [ -82.50967 , -173.01933 ,   97.50967 ,  188.01933 ],
       [-120.      , -120.      ,  135.      ,  135.      ],
       [-173.01933 ,  -82.50967 ,  188.01933 ,   97.50967 ],
       [-173.01933 , -354.03867 ,  188.01933 ,  369.03867 ],
       [-248.      , -248.      ,  263.      ,  263.      ],
       [-354.03867 , -173.01933 ,  369.03867 ,  188.01933 ]],
      dtype=float32)

In [28]:
# Broadcast the 9 base anchors (1, 9, 4) against the 4 grid shifts (4, 1, 4):
# the result holds every anchor at every grid cell, shape (4, 9, 4).
base_anchor.reshape(1,9,4)+shift.reshape(4,1,4)

array([[[ -37.25483322,  -82.50966644,   52.25483322,   97.50966644],
        [ -56.        ,  -56.        ,   71.        ,   71.        ],
        [ -82.50966644,  -37.25483322,   97.50966644,   52.25483322],
        [ -82.50966644, -173.01933289,   97.50966644,  188.01933289],
        [-120.        , -120.        ,  135.        ,  135.        ],
        [-173.01933289,  -82.50966644,  188.01933289,   97.50966644],
        [-173.01933289, -354.03866577,  188.01933289,  369.03866577],
        [-248.        , -248.        ,  263.        ,  263.        ],
        [-354.03866577, -173.01933289,  369.03866577,  188.01933289]],

       [[ -37.25483322,  -66.50966644,   52.25483322,  113.50966644],
        [ -56.        ,  -40.        ,   71.        ,   87.        ],
        [ -82.50966644,  -21.25483322,   97.50966644,   68.25483322],
        [ -82.50966644, -157.01933289,   97.50966644,  204.01933289],
        [-120.        , -104.        ,  135.        ,  151.        ],
        [-173.0193