In [98]:
%matplotlib inline


from typing import Optional, Union

import os
import sys

import torch
import torchvision
from torch import Tensor
import torch.nn.functional as F
from torchvision.ops import boxes as box_ops
from torchvision.ops.boxes import box_area
from torchvision.ops.roi_align import _bilinear_interpolate
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(0)
np.random.seed(0)

In [3]:
device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device ', device)

main_dirpath='D:/data/'
tensor_dirpath=os.path.join(main_dirpath, 'mask_rcnn')
original_inputs=torch.load(os.path.join(tensor_dirpath, 'part-backbonefpn-orig_input.pt'), map_location=device, weights_only=True)
tfm_inputs=torch.load(os.path.join(tensor_dirpath, 'part-backbonefpn-transform.pt'), map_location=device, weights_only=False)
proposal_boxes=torch.load(os.path.join(tensor_dirpath, 'part-rpn-proposal_boxes.pt'), map_location=device, weights_only=True)

device  cpu


In [4]:
num_classes=2
# load an instance segmentation model pre-trained on COCO
model=torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights='DEFAULT', rpn_pre_nms_top_n_train=800,
        rpn_pre_nms_top_n_test=500,  rpn_post_nms_top_n_train=800,  rpn_post_nms_top_n_test=500, box_detections_per_img=100,
        box_batch_size_per_image =400, rpn_batch_size_per_image=100)
# get number of input features for the classifier
in_features=model.roi_heads.box_predictor.cls_score.in_features
print('the number of inpute features for classifiers ', in_features)
# replace the pre-trained head with a new one
model.roi_heads.box_predictor=FastRCNNPredictor(in_features,  num_classes)

# get the number of input features for the mask classifiers
in_features_mask=model.roi_heads.mask_predictor.conv5_mask.in_channels
print('the number of input features for mask ', in_features_mask)
hidden_layer=256
# and replace the mask predictor with a new one
model.roi_heads.mask_predictor=MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# move model to the right device
model.to(device);

the number of inpute features for classifiers  1024
the number of input features for mask  256


[`RoIHeads`](https://github.com/pytorch/vision/blob/main/torchvision/models/detection/roi_heads.py)

In [9]:
for name, child in model.roi_heads.named_children(): print(name)

box_roi_pool
box_head
box_predictor
mask_roi_pool
mask_head
mask_predictor


[`RoIHeads.forward`](https://github.com/pytorch/vision/blob/main/torchvision/models/detection/roi_heads.py)

In [5]:
features=tfm_inputs['out']
images=tfm_inputs['tfm_images']
targets=tfm_inputs['tmf_targets']
proposals=proposal_boxes
print('features ', {k:(v.shape, v.min().item(), v.max().item()) for k, v in features.items()})
print('proposals ', [(p.shape, p.min().item(), p.max().item()) for p in proposals])

features  {'0': (torch.Size([2, 256, 208, 264]), -2.1181082725524902, 2.168463945388794), '1': (torch.Size([2, 256, 104, 132]), -1.646379828453064, 1.6166932582855225), '2': (torch.Size([2, 256, 52, 66]), -1.5383321046829224, 1.5961196422576904), '3': (torch.Size([2, 256, 26, 33]), -1.5852569341659546, 1.9735815525054932), 'pool': (torch.Size([2, 256, 13, 17]), -1.4886037111282349, 1.6521222591400146)}
proposals  [(torch.Size([800, 4]), 0.0, 1053.0), (torch.Size([800, 4]), 0.0, 830.0)]


In [12]:
# check targets
for t in targets:
    if t['boxes'].dtype not in [torch.float, torch.half, torch.double]: raise TypeError(f'target boxes must be float, instead got {t["boxes"].dtype}')
    if t['labels'].dtype != torch.int64: raise TypeError(f'target labels must be int64, instead got {t["labels"].dtype}')
    if model.roi_heads.has_keypoint():
        if t['keypoints'].dtype!=torch.float32: raise TypeError(f'target keypoints must be float, instead got {t["keypoints"].dtype}')
        

`proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)`

[`select_training_samples`](https://github.com/pytorch/vision/blob/main/torchvision/models/detection/roi_heads.py#L633)

In [45]:
def check_targets(targets, use_mask=False):
    '''
    targets (list[dict[str, Tensor]])
    '''
    assert targets is not None, 'targets should not be None'
    if not all('boxes' in t for t in targets): raise ValueError('Every element of targets should have a boxes key')
    if not all('labels' in t for t in targets): raise ValueError('Every element of targets should have a labels key')
    if use_mask:
        if not all('masks' in t for t in targets): raise ValueError('Every element of targets should have a masks key')

def assign_targets_to_proposals(proposals, gt_boxes,gt_labels, matcher):
    '''
    Args:
        proposals (list[Tensor]): Mx4 boxes/proposals per image
        gt_boxes (list[Tensor]): Gx4 ground truth boxes per image
        gt_labels (list[Tensor]): G ground truth labels per image
        matcher (torchvision.models.detection._utils.Matcher)
    Returns:
        matched_idxs (list[Tensor]): M indices into corresponding grough truth boxes for each proposal
        labels (list[Tensor]): M ground truth labels for each proposal
    '''
    
    matched_idxs, labels=[],[]
    for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
        if gt_boxes_in_image.numel()==0: # background image
            device=proposals_in_image.device
            clamped_matched_idx_in_image=torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            labels_in_image=torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
        else:
            match_quality_matrix=box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
            matched_idxs_in_image=matcher(match_quality_matrix) # M indices

            clamped_matched_idx_in_image=matched_idxs_in_image.clamp(min=0)
            labels_in_image=gt_labels_in_image[clamped_matched_idx_in_image] # index into  G index-tensor by M indices -> M labels
            labels_in_image=labels_in_image.to(dtype=torch.int64)
            # label background (below low threshold)
            bg_inds=matched_idxs_in_image==matcher.BELOW_LOW_THRESHOLD
            labels_in_image[bg_inds]=0
            # label ignore proposals (between low and high thresholds)
            ignore_inds=matched_idxs_in_image==matcher.BETWEEN_THRESHOLDS
            labels_in_image[ignore_inds]=-1 # -1 is ignored by sampler
        matched_idxs.append(clamped_matched_idx_in_image)
        labels.append(labels_in_image)
    return matched_idxs, labels

def subsample(fg_bg_sampler, labels):
    '''
    Args:
        fg_bg_sampler (torchvision.models.detection._utils.BalancedPositiveNegativeSampler)
        labels (list[Tensor]): labels of ground truth for each proposals
    '''
    sampled_pos_inds, sampled_neg_inds=fg_bg_sampler(labels)
    sampled_inds=[]
    for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
        img_sampled_inds=torch.nonzero(torch.bitwise_or(pos_inds_img, neg_inds_img), as_tuple=True)[0]
        # above is similar to torch.where(pos_inds_img | neg_inds_img)[0]
        sampled_inds.append(img_sampled_inds)
    return sampled_inds
    
def select_training_samples(use_mask, matcher, fg_bg_sampler, box_coder, proposals, targets):
    '''
    Args:
        proposals (list[Tensor]): Mx4 boxes/proposals per image
        targets (list[dict[str, Tensor]]): target per image, each is dict of annotation including Gx4 target boxes in each image
    Returns:
        proposals (list[Tensor]): Sx4 filtered proposals (balanced positive and negative) per image, where S<=M
        matched_idxs (list[Tensor]): S indices to ground-truth for each proposal
        labels (list[Tensor]): S ground-truth labels for each proposal 
        regression_targets (list[Tensor]): Sx4 regression targets (delta/adjustment between proposal and ground-truth)
    '''
    check_targets(targets=targets, use_mask=use_mask)
    dtype, device=proposals[0].dtype, proposals[0].device

    gt_boxes=[t['boxes'] for t in targets] # list of Mx4 where M is the number of boxes in each image
    gt_labels=[t['labels'] for t in targets] # list of M labels (i.e., 1D tensor) each for each box in each image
    
    # append ground-truth bboxes to propos
    #proposals = self.add_gt_proposals(proposals, gt_boxes)
    proposals=[torch.cat([proposal, gt_box], dim=0) for proposal, gt_box in zip(proposals, gt_boxes)]# list of Mx4 boxes per image
    print('proposals ', [(p.shape, p.requires_grad) for p in proposals])

    # get matching gt indices for each proposal
    # matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
    # matched_idxs are indices into gt_boxes to get ground-truth boxes for each proposal
    # labels are ground truth labels for each ground-truth associated with proposals
    matched_idxs, labels=assign_targets_to_proposals(proposals, gt_boxes,gt_labels, matcher=matcher)
    # matched_idxs_ref, labels_ref = model.roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
    # print('matched_idxs ', [torch.allclose(m, n) for m, n in zip(matched_idxs, matched_idxs_ref)])
    # print('labels ', [torch.allclose(m, n) for m, n in zip(labels, labels_ref)])

    # sample a fixed porportion of positive-negative proposals
    #sampled_inds = self.subsample(labels)
    sampled_inds=subsample(fg_bg_sampler, labels)

    matched_gt_boxes=[]
    num_images=len(proposals)
    for img_id in range(num_images):
        img_sampled_inds=sampled_inds[img_id]
        proposals[img_id]=proposals[img_id][img_sampled_inds]
        labels[img_id]=labels[img_id][img_sampled_inds]
        matched_idxs[img_id]=matched_idxs[img_id][img_sampled_inds]

        gt_boxes_in_image=gt_boxes[img_id]
        if gt_boxes_in_image.numel()==0:
            gt_boxes_in_image=torch.zeros((1,4), dtype=dtype, device=device)
        matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

    regression_targets=box_coder.encode(matched_gt_boxes, proposals)
    print('regression_targets ',[r.shape for r in  regression_targets])
    print('proposals ',[p.shape for p in proposals])
    print('matched_idxs ', [m.shape for m in matched_idxs])
    print('labels ', [l.shape for l in labels])
    return proposals, matched_idxs, labels, regression_targets


            
proposals, matched_idxs, labels, regression_targets=select_training_samples(use_mask=model.roi_heads.has_mask(), matcher=model.roi_heads.proposal_matcher, 
                        fg_bg_sampler=model.roi_heads.fg_bg_sampler,box_coder=model.roi_heads.box_coder,
                        proposals=proposals, targets=targets)

proposals_ref, matched_idxs_ref, labels_ref, regression_targets_ref = model.roi_heads.select_training_samples(proposals, targets)
for i, (ref, obs) in enumerate(zip(regression_targets_ref, regression_targets)):
    print(i, '-'*50)
    print('ref ', ref.shape, ref.min().item(), ref.max().item())
    print('obs ', obs.shape, obs.min().item(), obs.max().item())
    print(torch.allclose(ref, obs)) # internally call torch.randperm so it is not deterministic

proposals  [(torch.Size([402, 4]), False), (torch.Size([401, 4]), False)]
regression_targets  [torch.Size([400, 4]), torch.Size([400, 4])]
proposals  [torch.Size([400, 4]), torch.Size([400, 4])]
matched_idxs  [torch.Size([400]), torch.Size([400])]
labels  [torch.Size([400]), torch.Size([400])]
0 --------------------------------------------------
ref  torch.Size([400, 4]) -36731.43359375 521.9777221679688
obs  torch.Size([400, 4]) -36731.43359375 521.9777221679688
False
1 --------------------------------------------------
ref  torch.Size([400, 4]) -221.0601043701172 1277.6134033203125
obs  torch.Size([400, 4]) -221.0601043701172 1277.6134033203125
False


[`MultiScaleRoIAlign`](https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py)

```
box_features = self.box_roi_pool(features, proposals, image_shapes)
```

In [49]:
print(model.roi_heads.box_roi_pool.featmap_names)
print('feature names ', features.keys())

def _filter_input(x: dict[str, Tensor], featmap_names: list[str])->list[Tensor]:
    '''
    Only select features whose names are in `featmap_names
    '''
    x_filtered=[]
    for k, v in x.items():
        if k in featmap_names: x_filtered.append(v)
    return x_filtered

['0', '1', '2', '3']
feature names  odict_keys(['0', '1', '2', '3', 'pool'])


In [81]:
x=features
# x_filtered=_filter_input(x, self.featmap_names)
x_filtered=_filter_input(x, model.roi_heads.box_roi_pool.featmap_names)
print('x_filtered ', [i.shape for i in x_filtered])

x_filtered  [torch.Size([2, 256, 208, 264]), torch.Size([2, 256, 104, 132]), torch.Size([2, 256, 52, 66]), torch.Size([2, 256, 26, 33])]


[_setup_scales](https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L110)

In [82]:
def _infer_scale(feature: Tensor, original_size: list[int])->float:
    # assume the scale is of the form 2**(-k) where k is integer
    size=feature.shape[-2:]
    possible_scales:list[float]=[]
    for s1, s2 in zip(size, original_size):
        approx_scale=float(s1)/float(s2)
        scale=2**float(torch.tensor(approx_scale).log2().round())
        possible_scales.append(scale)
        break # we note that x and y direction returns the same scale so we just compute on x direction only
    return possible_scales[0]
    
class LevelMapper: # see https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L47
    '''
    Determine which FPN level each RoI should map to based on the heuristic in the FPN paper
    Args:
        k_min (int): the 1st level of FPN considered (shallow level with finer feature)
        k_max (int): the last level of FPN (deeper level with coarser feature)
        canonical_scale (int): geometric size of input (size used to pretrain backbone, e.g., 224)
        canonical_level (int): base FPN level (e.g., 4)
        eps (float)
    '''
    def __init__(self, k_min:int, k_max: int, canonical_scale:int=224, canonical_level:int=4, eps:float=1.e-6):
        self.k_max=k_max
        self.k_min=k_min
        self.s0=canonical_scale
        self.lvl0=canonical_level
        self.eps=eps
    def __call__(self, boxlists: list[Tensor])->Tensor:
        '''
        Args:
            boxlists (list[Tensor])
        '''
        # compute level ids
        s=torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists])) # M geomtric mean size
        print(f'In LevelMapper s {s.shape}')

        # eq.1 in FPN paper
        target_lvls=torch.floor(self.lvl0+torch.log2(s/self.s0)+torch.tensor(self.eps, dtype=s.dtype))
        target_lvls=torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64)-self.k_min).to(torch.int64)
        
        

In [83]:
# if self.scales is None or self.map_levels is None:
#             self.scales, self.map_levels = _setup_scales(
#                 x_filtered, image_shapes, self.canonical_scale, self.canonical_level
#             )

def _setup_scales(features: list[Tensor], image_shapes: list[tuple[int, int]], canonical_scale:int, 
                  canonical_level:int, eps:float=1e-6)->tuple[list[float], LevelMapper]:
    
    assert image_shapes is not None, 'image list should not be empty'

    max_x=max_y=0
    for shape in image_shapes:
        max_x=max(shape[0], max_x)
        max_y=max(shape[1], max_y)
    original_input_shape=(max_x, max_y)

    scales=[]
    for feat in features:
        scale=_infer_scale(feat, original_input_shape)
        scales.append(scale)

    # get the levels in the feature map by leveraging the fact that the network always downsamples by
    # a factor of 2 at each level
    lvl_min=-np.log2(scales[0])
    lvl_max=-np.log2(scales[-1])

    map_levels=LevelMapper(k_min=int(lvl_min), k_max=int(lvl_max), canonical_scale=canonical_scale,
                           canonical_level=canonical_level, eps=eps)
    
    
    return scales, map_levels

    
model.roi_heads.box_roi_pool.scales, model.roi_heads.box_roi_pool.map_levels  = _setup_scales(\
                x_filtered, images.image_sizes, model.roi_heads.box_roi_pool.canonical_scale,
                model.roi_heads.box_roi_pool.canonical_level)
print('model.roi_heads.box_roi_pool.scales ', model.roi_heads.box_roi_pool.scales)

model.roi_heads.box_roi_pool.scales  [0.25, 0.125, 0.0625, 0.03125]


```
_multiscale_roi_align(
            x_filtered,
            boxes,
            self.output_size,
            self.sampling_ratio,
            self.scales,
            self.map_levels,
        )
```

[`_multiscale_roi_align`](https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L147)

In [101]:
def roi_align(input, boxes, output_size, spatial_scale, sampling_ratio, aligned=False):
    '''
    Perform region of interest (ROI) align operator with average pooling, as described in Mask R-CNN
    Args:
        input (Tensor[N,C,H,W]): the input tensor
        boxes (Tensor[K,5] or List[Tensor[L,4]]): the box coordinates in (x1,y1,x2,y2) where the regions will be traken from
            If a single tensor is passed, the first column should be the index of the corresponding element in the batch, i.e., a number in ``[0, N-1]``
            If a list of Tensors, the each Tensor will correspond to the boxes for an element i in the batch
        output_size (int or Tuple[int, int]): the size of output (in bins or pixels) after the pooling is performed as (height width)
        spatial_scale (float): a scaling factor that maps the box coordinates to the input coordinates. For example, if your boxes are defined on the 
            scale of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of the original image), y'll want to set
            this to 0.5
        sampling_ratio (int): number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. 
            if >0, exactly ``sampling_ratioxsampling_ratio`` sampling points per bin are used. If <=0, an adaptive number of grid points are used 
            (computed as ``ceil(roi_width/output_width)``, and likewise for height)
        aligned (bool): If False, use legacy implementation. Otherwise, pixel shift the box coordinates by -.5 for a better alignment with the two
            neighboring pixel indices. This version used in Detectron2
    Returns:
        Tensor[K,C,output_size[0], output_size[1]]: the pooled ROIs
    '''
    #_roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    pooled_height, pooled_width=output_size[0], output_size[1]
    orig_dtype=input.dtype
    rois=boxes # Kx5
    # input NxCxHxW
    height, width=input.shape[2:]
    ph=torch.arange(pooled_height, device=input.device) # [PH]
    pw=torch.arange(pooled_width, device=input.device) # [PW]

    roi_batch_ind=rois[:, 0].int() # [K]
    offset=0.5 if aligned else 0.
    roi_start_w=rois[:,1]*spatial_scale-offset #[K]
    roi_start_h=rois[:,2]*spatial_scale-offset #[K]
    roi_end_w=rois[:,3]*spatial_scale-offset #[K]
    roi_end_h=rois[:,4]*spatial_scale-offset #[K]

    roi_width=roi_end_w-roi_start_w # [K]
    roi_height=roi_end_h-roi_start_h # [K]
    if not aligned:
        roi_width=torch.clamp(roi_width, min=1.) # [K]
        roi_height=torch.clamp(roi_height, min=1.) # [K]  

    bin_size_h=roi_height/pooled_height # [K]
    bin_size_w=roi_width/pooled_width # [K]
    exact_sampling=sampling_ratio>0
    roi_bin_grid_h=sampling_ratio if exact_sampling else torch.ceil(roi_height/pooled_height) # scalar or [K]
    roi_bin_grid_w=sampling_ratio if exact_sampling else torch.ceil(roi_width/pooled_width) # scalar or [K]

    if exact_sampling:
        count=max(roi_bin_grid_h*roi_bin_grid_w, 1) # scalar
        iy=torch.arange(roi_bin_grid_h, device=input.device) # [IY]
        ix=torch.arange(roi_bin_grid_w, device=input.device) # [IX]
        ymask=xmask=None
    else:
        count=torch.clamp(roi_bin_grid_h*roi_bin_grid_w, min=1) # [K]
        # when doing adaptive sampling, the number of samples we need to do is data-dependent based on how big ROIs are. This is a bit awkward
        # because first class dims cannot actually handle this. So instead, we inefficiently suppose that we needed to sample all the points and 
        # mask out things that turned out to be unnecessary
        iy=torch.arange(height, device=input.device) # [IY]
        ix=torch.arange(width, device=input.device) # [IX]
        ymask=iy[None,:]<roi_bin_grid_h[:,None] # [K, IY]
        xmask=ix[None,:]<roi_bin_grid_w[:,None] # [K, IX]
    def from_K(t): return t[:,None,None]
 
        # [K,1,1]              #[1 PH 1] [K,1,1]                   # [1,1,IY][K,1,1]
    y=(from_K(roi_start_h)+   ph[None,:,None]*from_K(bin_size_h)+ (iy[None,None,:]+0.5).to(input.dtype)*from_K(bin_size_h/roi_bin_grid_h))#[K,PH,IY]
    x=(from_K(roi_start_w)+ pw[None,:,None]*from_K(bin_size_w)+ (ix[None, None,:]+.5).to(input.dtype)*from_K(bin_size_w/roi_bin_grid_w)) # [K PW IX]
    print('roi_batch_ind ', roi_batch_ind.shape, roi_batch_ind.unique())
    val=_bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K,C,PH, PW, IY, IX]

    # mask out samples that weren't actually adpatively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output=val.sum((-1, -2)) # remove IY, IX -> [K, C, PH, PW]
    if isinstance(count, torch.Tensor): output/=count[:,None,None,None]
    else:output/=count
    output=output.to(dtype=orig_dtype)
    return output


In [104]:
def _multiscale_roi_align(x_filtered: list[Tensor], boxes:list[Tensor], 
output_size:list[int], sampling_ratio:int, scales:Optional[list[float]],
mapper:Optional[LevelMapper])->Tensor:
    '''
    Args:
        x_filtered (list[Tensor]): list of features from backbone FPN
        boxes (list[Tensor[N,4]]): boxes to be used to perform the pooling operation, in (x1, y1, x2, y2) format
            and in the image reference size, not the feature map reference. The coordinate must satisfy 
            ``0<=x1<x2`` and ``0<=y1<y2``.
        output_size (Union[list[tuple[int, int]], list[int]]): size of output
        sampling_ratio (int): sampling ratio for RoIAlign
        scales (Optional[list[float]]): the ratio between feature size and image size
        mapper (Optional[LevelMapper]): 
    Returns:
        result (Tensor):
    '''
    assert all(x is not None for x in [scales, mapper]), 'scales and mapper should not be None'
    num_levels=len(x_filtered)

    # concatenate proposals to form a single tensor and append its with image id so we know which image
    # which proposal associated with
    #rois = _convert_to_roi_format(boxes)
    concat_boxes=torch.cat(boxes, dim=0) # Mx4
    device, dtype=concat_boxes.device, concat_boxes.dtype
    # Mx1
    ids=torch.cat([torch.full_like(b[:,:1], i, dtype=dtype, device=device) for i, b in enumerate(boxes)], dim=0)
    # Mx5
    rois=torch.cat([ids, concat_boxes], dim=1)

    levels=mapper(boxes) # M levels for each box 

    num_rois=len(rois)
    num_channels=x_filtered[0].shape[1]
    dtype, device=x_filtered[0].dtype, x_filtered[0].device
    result=torch.zeros((num_rois, num_channels)+output_size, dtype=dtype, device=device)
    print('result ', result.shape)
    for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
        # per_level_feature is of size BxCxHxW where B is batch size, C is the number of channel
        print(level, per_level_feature.shape, scale)
        idx_in_level=torch.nonzero(levels==level, as_tuple=True)[0]
        rois_per_level=rois[idx_in_level]

        print('rois_per_level ', rois_per_level.shape)
        #  result_idx_in_level = roi_align(per_level_feature, rois_per_level, output_size=output_size, spatial_scale=scale, sampling_ratio=sampling_ratio,)
        result_idx_in_level = roi_align(per_level_feature, rois_per_level, output_size=output_size, spatial_scale=scale,
                                        sampling_ratio=sampling_ratio,)
        print('result_idx_in_level ', result_idx_in_level.shape)
        result[idx_in_level]=result_idx_in_level.to(dtype=result.dtype)
    return result
    
result=_multiscale_roi_align(x_filtered, boxes=proposals, 
                      output_size=model.roi_heads.box_roi_pool.output_size, 
                      sampling_ratio=model.roi_heads.box_roi_pool.sampling_ratio, 
                      scales=model.roi_heads.box_roi_pool.scales,
                      mapper=model.roi_heads.box_roi_pool.map_levels)

In LevelMapper s torch.Size([800])
result  torch.Size([800, 256, 7, 7])
0 torch.Size([2, 256, 208, 264]) 0.25
rois_per_level  torch.Size([453, 5])
roi_batch_ind  torch.Size([453]) tensor([0, 1], dtype=torch.int32)
result_idx_in_level  torch.Size([453, 256, 7, 7])
1 torch.Size([2, 256, 104, 132]) 0.125
rois_per_level  torch.Size([154, 5])
roi_batch_ind  torch.Size([154]) tensor([0, 1], dtype=torch.int32)
result_idx_in_level  torch.Size([154, 256, 7, 7])
2 torch.Size([2, 256, 52, 66]) 0.0625
rois_per_level  torch.Size([159, 5])
roi_batch_ind  torch.Size([159]) tensor([0, 1], dtype=torch.int32)
result_idx_in_level  torch.Size([159, 256, 7, 7])
3 torch.Size([2, 256, 26, 33]) 0.03125
rois_per_level  torch.Size([34, 5])
roi_batch_ind  torch.Size([34]) tensor([0, 1], dtype=torch.int32)
result_idx_in_level  torch.Size([34, 256, 7, 7])


In [105]:
print(453+154+159+34)

800

In [109]:
torch.save({'box_features':result,
           'proposals':proposals, 
            'matched_idxs':matched_idxs, 'labels':labels, 'regression_targets':regression_targets},
           os.path.join(tensor_dirpath, 'part-roi_head-box_roi_pool.pt'))

In [108]:
print('result ', result.shape)

result  torch.Size([800, 256, 7, 7])
