This notebook checks, step by step, the function [`filter_proposals`](https://github.com/pytorch/vision/blob/main/torchvision/models/detection/rpn.py) in torchvision's `RPN`.

In [1]:
import os
import torch
import numpy as np

from torchvision.ops import boxes as box_ops

In [2]:
# Directory holding the dumped RPN tensors. Set the MASK_RCNN_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
data_dirpath=os.environ.get('MASK_RCNN_DATA_DIR', 'D:/data/mask_rcnn')

device=torch.device("cpu")
# weights_only=True restricts unpickling to tensors / plain containers (safe loading).
objectness=torch.load(os.path.join(data_dirpath, "objectness.pt"),map_location=device, weights_only=True) # NHWAxC
proposals=torch.load(os.path.join(data_dirpath, "proposals.pt"),map_location=device, weights_only=True) # NxHWAx4
print('proposals ', proposals.shape, proposals.min(), proposals.max())
print('objectness ', objectness.shape, objectness.min(), objectness.max())

proposals  torch.Size([2, 211038, 4]) tensor(-430.7388) tensor(1432.4897)
objectness  torch.Size([422076, 1]) tensor(-23.0735, grad_fn=<MinBackward1>) tensor(9.0714, grad_fn=<MaxBackward1>)


In [3]:
# Input image sizes as (height, width), one entry per image in the batch.
image_shapes=[(800, 1033), (800, 1026)]
# Number of anchors (HWA) contributed by each level, in the order they are
# concatenated along dim 1 of proposals/objectness.
num_anchors_per_level=[np.int64(158400), np.int64(39600), np.int64(9900), np.int64(2475), np.int64(663)]

num_images=proposals.shape[0] # batch size
device=proposals.device

# Sanity checks: later cells split, gather and zip by these sizes, and a mismatch
# would otherwise silently mis-assign anchors to levels or drop images.
assert len(image_shapes) == num_images, 'need one image shape per image'
assert proposals.shape[1] == sum(num_anchors_per_level), 'anchor counts per level must sum to HWA'

# do not backprop through objectness
objectness=objectness.detach()
# from NHWAxC to NxHWAC; this equals NxHWA only when C == 1
objectness=objectness.reshape(num_images, -1)
assert objectness.shape[1] == proposals.shape[1], 'expected one objectness score (C == 1) per anchor'
print('objectness ', objectness.shape)

objectness  torch.Size([2, 211038])


In [4]:
def _topk_min(input, orig_kval, axis):
    """
    Given the original k-value (orig_kval), check whether the k-value <= the number of input values. 
    If not, change return the number of input values instead of original k-values
    Args:
        input (Tensor): The original input tensor.
        orig_kval (int): The provided k-value.
        axis(int): Axis along which we retrieve the input size.
    Returns:
        min_kval (int): Appropriately selected k-value.
    """
    return min(orig_kval, input.size(axis))

# pre_nms_top_n={'training':2000, 'testing':1000}
def get_top_n_idx(objectness, num_anchors_per_level, pre_nms_top_n=2000):
    '''
    Get the indices of the top n highest objectness values
    Args:
        objectness (tensor): NxHWA from all levels
        num_anchors_per_level (sequence): the list of number of bounding boxes per level
        pre_nms_top_n (int): the number of top k objectness before NMS
    Return:
        top_k (tensor): NxM where M=Sum(min(K, HWA)) for each level. If, for all levels, HWA>K, then Nx(LK) where L is the number of levels
    '''
    # select top_n boxes independently per level before applying NMS
    r,offset=[],0
    for ob in objectness.split(num_anchors_per_level, 1): # split objectness into the number per level
        num_anchors=ob.shape[1] # NxHWA
        #pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
        pre_nms_top_n=_topk_min(input=ob,orig_kval=pre_nms_top_n,axis=1)
        _,top_n_idx=ob.topk(pre_nms_top_n, dim=1) # NxK
        r.append(top_n_idx+offset)
        offset+=num_anchors
    return torch.cat(r, dim=1)

In [5]:
# Tag every anchor with the id of the level it came from (0, 1, 2, ...), so that
# the batched NMS further down only compares boxes of the same level.
per_level_ids = []
for level_id, count in enumerate(num_anchors_per_level):
    per_level_ids.append(torch.full((count,), level_id, dtype=torch.int64, device=device))
print('levels ', [ids.shape for ids in per_level_ids])

# A single 1xHWA row of level ids, broadcast (as a view) to every image: NxHWA.
levels = torch.cat(per_level_ids, dim=0).view(1, -1).expand_as(objectness)
print('levels ', levels.shape)

# Keep the top-n anchors of every level independently before NMS
# (the notebook counterpart of `self._get_top_n_idx` in torchvision's RPN).
top_n_idx = get_top_n_idx(objectness, num_anchors_per_level, pre_nms_top_n=2000)
print('top_n_idx ', top_n_idx)

levels  [torch.Size([158400]), torch.Size([39600]), torch.Size([9900]), torch.Size([2475]), torch.Size([663])]
levels  torch.Size([2, 211038])
top_n_idx  tensor([[ 37959,   3966,   3174,  ..., 210480, 210429, 210431],
        [ 69578,  68786,  69576,  ..., 210430, 210379, 210421]])


In [6]:
# One row index per image, shaped Nx1 so it broadcasts against the NxK top_n_idx.
image_range = torch.arange(num_images, device=device)  # num_images is the batch size
batch_idx = image_range.unsqueeze(1)  # Nx1
print('image_range ', image_range, ' batch_idx ', batch_idx)

# Gather the selected anchors per image: objectness/levels become NxK and
# proposals becomes NxKx4.
# NOTE: this overwrites objectness/levels/proposals, so re-running this cell
# without re-running the earlier cells would gather twice.
objectness, levels, proposals = [t[batch_idx, top_n_idx] for t in (objectness, levels, proposals)]
objectness_prob = objectness.sigmoid()  # raw scores -> probabilities in (0, 1)
print('objectness ', objectness.shape, ' levels ', levels.shape, ' proposals ', proposals.shape)
print('objectness ', objectness.min(), objectness.max())
print('objectness_prob ', objectness_prob.min(), objectness_prob.max())

image_range  tensor([0, 1])  batch_idx  tensor([[0],
        [1]])
objectness  torch.Size([2, 8663])  levels  torch.Size([2, 8663])  proposals  torch.Size([2, 8663, 4])
objectness  tensor(-15.5430) tensor(9.0714)
objectness_prob  tensor(1.7773e-07) tensor(0.9999)


In [13]:
min_size=0.001        # minimum box side length, see box_ops.remove_small_boxes
score_thresh=0.0      # minimum objectness probability kept
nms_thresh=0.7        # IoU threshold passed to box_ops.batched_nms
post_nms_top_n=2000   # maximum number of boxes kept per image after NMS


def _score_range(scores):
    """Format '(min, max)' of ``scores`` for logging; min()/max() would raise on an empty tensor."""
    if scores.numel() == 0:
        return '(empty)'
    return f'({scores.min().item()}, {scores.max().item()})'


# zip() would silently drop images if there were fewer image shapes than images.
assert len(image_shapes) == len(proposals), 'need one image shape per image'

final_boxes=[]
final_scores=[]
for i, (boxes, scores, lvl, img_shape) in enumerate(zip(proposals, objectness_prob, levels, image_shapes)):
    print(i, '-'*50)
    print(f'\tboxes {boxes.shape}, scores {scores.shape}, lvl {lvl.shape}')
    # boxes Kx4 where K from all levels, and the 4 are for x1,y1,x2,y2
    boxes = box_ops.clip_boxes_to_image(boxes, img_shape) # img_shape is HxW or YxX
    # remove small boxes
    keep_idx=box_ops.remove_small_boxes(boxes, min_size)
    boxes,scores,lvl=boxes[keep_idx],scores[keep_idx],lvl[keep_idx]
    print(f'\tboxes {boxes.shape}, scores {scores.shape}, lvl {lvl.shape}')
    # remove low scoring boxes
    # use >= for backward compatibility
    keep_idx=torch.where(scores>=score_thresh)[0]
    boxes,scores,lvl=boxes[keep_idx],scores[keep_idx],lvl[keep_idx]
    print(f'\tboxes {boxes.shape}, scores {scores.shape}, lvl {lvl.shape}')
    print(f'\tnms_thresh {nms_thresh}, scores {scores.shape}, {_score_range(scores)}')

    # non-maximum suppression, independently done per level;
    # returns 1-D indices of kept boxes, sorted by decreasing score
    keep_idx= box_ops.batched_nms(boxes, scores, lvl, nms_thresh)

    # keep only topk scoring prediction
    keep_idx=keep_idx[:post_nms_top_n]
    boxes,scores=boxes[keep_idx],scores[keep_idx]
    print(f'\tnms_thresh {nms_thresh}, scores {scores.shape}, {_score_range(scores)}')

    final_boxes.append(boxes)
    final_scores.append(scores)

print('\nfinal_boxes ', [x.shape for x in final_boxes])
print('\nfinal_scores ', [x.shape for x in final_scores])

0 --------------------------------------------------
	boxes torch.Size([8663, 4]), scores torch.Size([8663]), lvl torch.Size([8663])
	boxes torch.Size([8663, 4]), scores torch.Size([8663]), lvl torch.Size([8663])
	boxes torch.Size([8663, 4]), scores torch.Size([8663]), lvl torch.Size([8663])
	nms_thresh 0.7, scores torch.Size([8663]), (1.7773150773336965e-07, 0.999885082244873)
	nms_thresh 0.7, scores torch.Size([2000]), (0.10210991650819778, 0.999885082244873)
1 --------------------------------------------------
	boxes torch.Size([8663, 4]), scores torch.Size([8663]), lvl torch.Size([8663])
	boxes torch.Size([8658, 4]), scores torch.Size([8658]), lvl torch.Size([8658])
	boxes torch.Size([8658, 4]), scores torch.Size([8658]), lvl torch.Size([8658])
	nms_thresh 0.7, scores torch.Size([8658]), (0.00010085690882988274, 0.9996401071548462)
	nms_thresh 0.7, scores torch.Size([2000]), (0.0072363014332950115, 0.9996401071548462)

final_boxes  [torch.Size([2000, 4]), torch.Size([2000, 4])]

fi

In [15]:
# Per-image summary of the surviving proposals: column-wise min/max of
# (x1, y1, x2, y2) and the score range. min()/max() raise on empty tensors,
# so images that kept no proposals are reported explicitly.
for b, (boxes, scores) in enumerate(zip(final_boxes, final_scores)):
    print(b, '-'*50)
    if boxes.numel() == 0:
        print('\tboxes ', boxes.shape, '(empty)')
        print('\tscores ', scores.shape, '(empty)')
        continue
    print('\tboxes ', boxes.shape, boxes.min(0).values, boxes.max(0).values)
    print('\tscores ', scores.shape, scores.min(), scores.max())

0 --------------------------------------------------
	boxes  torch.Size([2000, 4]) tensor([0.0000, 0.0000, 2.8989, 4.0253]) tensor([1019.4376,  797.7085, 1033.0000,  800.0000])
	scores  torch.Size([2000]) tensor(0.1021) tensor(0.9999)
1 --------------------------------------------------
	boxes  torch.Size([2000, 4]) tensor([0.0000, 0.0000, 2.5135, 3.6417]) tensor([1023.5251,  797.5223, 1026.0000,  800.0000])
	scores  torch.Size([2000]) tensor(0.0072) tensor(0.9996)
