In [1]:
from typing import Optional, Union


import os
import torch
import torch.nn as nn
from torch import Tensor

import numpy as np
from torch.nn.modules.utils import _pair

In [2]:
from torchvision.ops.poolers import LevelMapper

In [3]:
data_dirpath='D:/data/mask_rcnn'
data=torch.load(os.path.join(data_dirpath, 'box_roi_pool.pt'), map_location=torch.device("cpu"),weights_only=True)
print(data.keys())
out=dict() # we need to remove gradient since we need to save memory
for k, v in data['out'].items():
    v.requires_grad=False
    out[k]=v
proposals=data['proposals']
image_shapes=data['image_shapes']
print('out ', {k:(v.shape, v.requires_grad) for k, v in out.items()})
print('image_shapes ', image_shapes)
print('proposals ', [(p.shape, p.min().item(), p.max().item(), p.requires_grad) for p in proposals])

dict_keys(['out', 'proposals', 'image_shapes'])
out  {'0': (torch.Size([2, 256, 200, 296]), False), '1': (torch.Size([2, 256, 100, 148]), False), '2': (torch.Size([2, 256, 50, 74]), False), '3': (torch.Size([2, 256, 25, 37]), False), 'pool': (torch.Size([2, 256, 13, 19]), False)}
image_shapes  [(800, 861), (799, 1159)]
proposals  [(torch.Size([801, 4]), 0.0, 861.0, False), (torch.Size([803, 4]), 0.0, 1159.0, False)]


In [4]:
def _filter_input(x: dict[str, Tensor], featmap_names: list[str])->list[Tensor]:
    x_filtered=[]
    for k, v in out.items():
        if k in featmap_names: x_filtered.append(v)
    return x_filtered

In [5]:
def _infer_scale(feature: Tensor, original_size:list[int])->float:
    # assumption: the scale is of the form 2 ** (-k), with k integer
    size=feature.shape[-2:]
    possible_scales:list[float]=[]
    for s1, s2 in zip(size, original_size):
        approx_scale=float(s1)/float(s2)
        # print(f'In _infer_scale torch.tensor(approx_scale).log2(){torch.tensor(approx_scale).log2()}')
        scale=2**float(torch.tensor(approx_scale).log2().round())
        possible_scales.append(scale)
        print('s1 ', s1, ' s2 ', s2, ' scale ', scale, ' s1/s2 ', s1/s2, flush=True)
    return possible_scales[0]


[`setup_scales`](https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L110)


In [6]:
def _setup_scales(features:list[Tensor], image_shapes:list[tuple[int, int]], canonical_scale:int,
                 canonical_level:int)->tuple[list[float], LevelMapper]:

    if not image_shapes: raise ValueError('image size list should not be empty')
    max_x=max_y=0
    for shape in image_shapes:
        max_x=max(shape[0], max_x)
        max_y=max(shape[1], max_y)
    original_input_shape=(max_x, max_y)
    scales=[_infer_scale(feat, original_input_shape) for feat in features]
    # get the levels in the feature map by leveraging the fact that the network always
    # downsamples by a factor of 2 at each level
    lvl_min=-torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
    lvl_max=-torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
    print('original_input_shape ', original_input_shape)
    map_levels=LevelMapper(int(lvl_min), int(lvl_max), canonical_scale=canonical_scale,
                      canonical_level=canonical_level)
    return scales, map_levels


In [7]:
class MultiScaleRoIAlign(nn.Module):
    # reference from https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py
    def __init__(self, 
                 featmap_names: list[str], 
                 output_size: Union[int, tuple[int], list[int]],
                 sampling_ratio: int,
                 *,
                 canonical_scale:int=224,
                 canonical_level:int=4):

        super().__init__()
        if isinstance(output_size, int): output_size=(output_size, output_size)
        self.featmap_names=featmap_names
        self.sampling_ratio=sampling_ratio
        self.output_size=tuple(output_size)
        self.scales=None
        self.map_levels=None
        self.canonical_scale=canonical_scale
        self.canonical_level=canonical_level
    def forward(self, 
               x: dict[str: Tensor],
               boxes: list[Tensor], 
               image_shapes: list[tuple[int, int]])->Tensor:
        """
        Args:
            x (dict[Tensor]): feature maps for each level. They are assumed to have all the same number of 
                channels, but they can have different sizes
            boxes (List[Tensor[N,4]]): boxes to be used to to perform the pooling operation, in (x1,y1,x2,y2) format and
                in the image refence size, not the feature map reference. The coordinate must satisfy ``0<=x1<x2`` and
                ``0<=y1<y2``
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they have been fed to a CNN to obtain feature maps.
                This allows us to infor the scale factor for each one of the levels to be pooled
        Returns:
            result (Tensor)
        """
        pass


box_roi_pool=box_roi_pool=MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], 
                                             output_size=(7, 7), sampling_ratio=2)

In [8]:
x=out
x_filtered=_filter_input(x, box_roi_pool.featmap_names)
if box_roi_pool.scales is None or box_roi_pool.map_levels is None:
    box_roi_pool.scales, box_roi_pool.map_levels=_setup_scales(x_filtered, image_shapes,
                        box_roi_pool.canonical_scale, box_roi_pool.canonical_level)
print('box_roi_pool.scales ', box_roi_pool.scales)

s1  200  s2  800  scale  0.25  s1/s2  0.25
s1  296  s2  1159  scale  0.25  s1/s2  0.2553925798101812
s1  100  s2  800  scale  0.125  s1/s2  0.125
s1  148  s2  1159  scale  0.125  s1/s2  0.1276962899050906
s1  50  s2  800  scale  0.0625  s1/s2  0.0625
s1  74  s2  1159  scale  0.0625  s1/s2  0.0638481449525453
s1  25  s2  800  scale  0.03125  s1/s2  0.03125
s1  37  s2  1159  scale  0.03125  s1/s2  0.03192407247627265
original_input_shape  (800, 1159)
box_roi_pool.scales  [0.25, 0.125, 0.0625, 0.03125]


```
 _multiscale_roi_align(
            x_filtered,
            boxes,
            self.output_size,
            self.sampling_ratio,
            self.scales,
            self.map_levels,
        )
```
[_multiscale_roi_align](https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L147)

In [9]:
boxes=proposals
output_size=box_roi_pool.output_size
sampling_ratio=box_roi_pool.sampling_ratio
scales=box_roi_pool.scales
mapper=box_roi_pool.map_levels
print('boxes ', type(boxes), [b.shape for b in boxes])
print('output_size ', output_size)
print('scales ', scales)
print('mapper ', mapper, mapper.k_min, mapper.k_max)

boxes  <class 'list'> [torch.Size([801, 4]), torch.Size([803, 4])]
output_size  (7, 7)
scales  [0.25, 0.125, 0.0625, 0.03125]
mapper  <torchvision.ops.poolers.LevelMapper object at 0x0000016510256CD0> 2 5


In [10]:
if any(x is None for x in [scales, mapper]):
    raise ValueError('scales and mapper should not be None')

num_levels=len(x_filtered)
print('num_levels ', num_levels)
print('x_filtered ', type(x_filtered))

# _convert_to_roi_format https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L87
concat_boxes=torch.cat(boxes, dim=0) # Nx4
device, dtype=concat_boxes.device, concat_boxes.dtype

# create ids for a set of boxes of each image so we know which boxes for which image 
img_ids=torch.cat(
    [torch.full_like(b[:,:1], i, dtype=dtype, layout=torch.strided,device=device) for i, b in enumerate(boxes)],
    dim=0
)
print('img_ids ', img_ids.shape)

rois = torch.cat([img_ids, concat_boxes], dim=1) # Nx5
print('rois ', rois.shape)

num_levels  4
x_filtered  <class 'list'>
img_ids  torch.Size([1604, 1])
rois  torch.Size([1604, 5])


```
levels=mapper(boxes)
```
[`__call__`](https://github.com/pytorch/vision/blob/main/torchvision/ops/poolers.py#L87)

In [11]:
boxlists=boxes

# compute geometric mean area
box_area=lambda boxlist: (boxlist[..., 2] - boxlist[..., 0]) * (boxlist[..., 3] - boxlist[..., 1])
s=torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists])) # 1D tensor
print('s ', s.shape, s.min().item(), s.max().item())
# Eqn.1 in FPN paper
print('lvl0 ', mapper.lvl0, ' s0 ', mapper.s0, ' eps ', mapper.eps)
target_lvls=torch.floor(mapper.lvl0+torch.log2(s/mapper.s0)+torch.tensor(mapper.eps, dtype=s.dtype))
print('target_lvls ', target_lvls.shape, target_lvls.min().item(), target_lvls.max().item())
print('k_min ', mapper.k_min, ' k_max ', mapper.k_max)
target_lvls=torch.clamp(target_lvls, min=mapper.k_min, max=mapper.k_max)
# why do we delete k_min from level?
# ChatGPT said they did this for indexing purposes. So index start at 0
# So if level=0, ROI is place to index to 0 which mapped directly to the first feature map
levels=(target_lvls.to(torch.int64)-mapper.k_min).to(torch.int64) 
print('levels ', levels.shape, levels.min().item(), levels.max().item())

s  torch.Size([1604]) 8.276786804199219 797.4650268554688
lvl0  4  s0  224  eps  1e-06
target_lvls  torch.Size([1604]) -1.0 5.0
k_min  2  k_max  5
levels  torch.Size([1604]) 0 3


In [12]:
num_rois=len(rois)
num_channels=x_filtered[0].shape[1]
print('num_rois ', num_rois, ' num_channels ', num_channels, ' x_filtered ', [i.shape for i in x_filtered])

num_rois  1604  num_channels  256  x_filtered  [torch.Size([2, 256, 200, 296]), torch.Size([2, 256, 100, 148]), torch.Size([2, 256, 50, 74]), torch.Size([2, 256, 25, 37])]


In [14]:
device,dtype=x_filtered[0].device, x_filtered[0].dtype
result=torch.zeros( (num_rois, num_channels)+output_size, dtype=dtype, device=device)
print('result ', result.shape)

result  torch.Size([1604, 256, 7, 7])


In [15]:
# iterate from feature extracted from finer image to feature extracted from coarser image
for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)): 
    idx_in_level=torch.nonzero(levels==level, as_tuple=True)[0]
    rois_per_level=rois[idx_in_level]
    break

In [16]:
idx_in_level

tensor([   6,    9,   10,   12,   13,   17,   21,   23,   24,   25,   26,   27,
          29,   30,   32,   35,   36,   37,   40,   41,   42,   43,   45,   46,
          48,   50,   51,   53,   54,   55,   56,   57,   59,   60,   61,   65,
          66,   67,   69,   70,   72,   73,   76,   77,   78,   79,   80,   81,
          83,   85,   86,   88,   89,   91,   92,   93,   95,   96,   97,   98,
          99,  100,  101,  102,  103,  104,  106,  107,  108,  109,  110,  111,
         112,  113,  115,  116,  118,  119,  120,  121,  122,  123,  126,  128,
         129,  130,  131,  132,  134,  135,  136,  137,  138,  139,  140,  141,
         142,  143,  144,  146,  147,  149,  150,  151,  152,  154,  155,  156,
         157,  158,  159,  160,  161,  162,  163,  164,  165,  166,  167,  168,
         169,  170,  171,  172,  173,  175,  176,  177,  178,  180,  181,  183,
         185,  186,  187,  188,  189,  190,  191,  192,  193,  196,  197,  198,
         199,  201,  202,  203,  204,  2

[roi_align](https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py)

In [17]:
input=per_level_feature
boxes=rois_per_level
spatial_scale=scale
aligned=False
print('sampling_ratio ', sampling_ratio, ' output_size ', output_size)

sampling_ratio  2  output_size  (7, 7)


In [18]:
rois=boxes
print('rois ', type(rois), rois.shape)

rois  <class 'torch.Tensor'> torch.Size([849, 5])


[_roi_align](https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L9)

In [19]:
pooled_height=output_size[0] 
pooled_width=output_size[1] 

orig_dtype=input.dtype
_,_,height,width=input.size()
ph=torch.arange(pooled_height,device=input.device) # [PH]
pw=torch.arange(pooled_width, device=input.device) # [PW]

# input: [N,C,H,W]
# rois: [K, 5]

roi_batch_ind=rois[:,0].int() # [k]
offset=0.5 if aligned else 0.
# where spatial scale tell how much smaller features compared to image
roi_start_w=rois[:,1]*spatial_scale - offset # [K]
roi_start_h=rois[:,2]*spatial_scale - offset # [K]
roi_end_w=rois[:,3]*spatial_scale-offset # [K]
roi_end_h=rois[:,4]*spatial_scale-offset # [K]

roi_width=roi_end_w-roi_start_w  # [K]
roi_height=roi_end_h-roi_start_h # [K]
print('roi_width ', roi_width.shape, roi_width.min().item(), roi_width.max().item())
print('roi_height ', roi_height.shape, roi_height.min().item(), roi_height.max().item())

if not aligned:
    roi_width=torch.clamp(roi_width,  min=1.0) # [K]
    roi_height=torch.clamp(roi_height, min=1.0) # [K]
print('roi_width ', roi_width.shape, roi_width.min().item(), roi_width.max().item())
print('roi_height ', roi_height.shape, roi_height.min().item(), roi_height.max().item())


roi_width  torch.Size([849]) 0.70745849609375 120.55207824707031
roi_height  torch.Size([849]) 0.6410369873046875 127.11204528808594
roi_width  torch.Size([849]) 1.0 120.55207824707031
roi_height  torch.Size([849]) 1.0 127.11204528808594


In [20]:
bin_size_h=roi_height/pooled_height # [K]
bin_size_w=roi_width/pooled_width # [K]
print('bin_size_h ', bin_size_h[:10], ' bin_size_w ', bin_size_w[:10])

exact_sampling=sampling_ratio>0

roi_bin_grid_h=sampling_ratio if exact_sampling else torch.ceil(roi_height/pooled_height) # scalar or [K]
roi_bin_grid_w=sampling_ratio if exact_sampling else torch.ceil(roi_width/pooled_width) # scalar or [K]
print('roi_bin_grid_h ', roi_bin_grid_h,  ' roi_bin_grid_w ', roi_bin_grid_w)

bin_size_h  tensor([ 1.8297,  3.0553,  4.6488,  3.6692,  4.0668,  4.0846,  2.4883,  3.6748,
         1.6789,  1.8744,  3.6051,  1.6295,  3.2082,  4.3968,  3.7302,  4.4600,
         3.3769,  2.6490,  3.4099,  3.7661,  3.2102,  2.8646,  4.4398,  4.0115,
         3.5923,  2.5295,  3.3016,  2.0284,  3.6113,  2.8184,  4.2214,  3.4182,
         2.1966,  2.6943,  3.5541,  2.0056,  2.2153,  2.0122,  3.8218,  1.9553,
         2.3758,  2.3740,  2.8747,  2.8402,  2.5494,  4.7609,  2.6205,  2.0872,
         2.4516,  4.0521,  3.7127,  4.4879,  4.5062,  2.8869,  2.8347,  3.8302,
         2.9499,  2.4357,  2.1665,  2.8879,  2.5705,  2.7838,  1.5856,  2.1666,
         3.8744,  2.6102,  2.7108,  2.9725,  1.6692,  3.1174,  1.5822,  2.8359,
         2.8920,  4.1371,  4.8053,  2.6567,  3.1036,  4.7339,  2.6632,  2.6439,
         2.7375,  2.2078,  3.7429,  2.4996,  3.5402,  2.0305,  4.5843,  0.4955,
         2.6215,  2.9558,  1.5014,  3.4609,  2.9567,  1.5572,  3.1574,  2.3630,
         2.6275,  2.0698,  2

In [21]:
if exact_sampling:
    count=max(roi_bin_grid_h*roi_bin_grid_w, 1) # scalar
    iy=torch.arange(roi_bin_grid_h, device=input.device) # [IY]
    ix=torch.arange(roi_bin_grid_w, device=input.device) # [IX]
    ymask=xmask=None
    print('count ', count, ' iy ', iy.shape, ' ix ', ix.shape)
else:
    count=torch.clamp(roi_bin_grid_h*roi_bin_grid_w, min=1) # [K]
    # When doing adaptive sampling, the number of samples we need to do is data-dependent based on how big the ROIs are
    # This is a bit awkward because first class dims cannot actually handle this. So instead, we inefficiently suppose that 
    # we needed to sample ALL the points and mask out things that turned out to be unnecessary
    iy=torch.arange(height, device=input.device) #[IY]
    ix=torch.arange(width, device=input.device) # [IX]
    ymask=iy[None,:]<roi_bin_grid_h[:,None] # [K, IY]
    xmask=ix[None,:]<roi_bin_grid_w[:, None] # [K, IX]

def from_K(t): return t[:,None, None]


count  4  iy  torch.Size([2])  ix  torch.Size([2])


In [22]:
tmp=from_K(roi_start_h)
print('from_K(roi_start_h) ', tmp.shape, tmp.min().item(), tmp.max().item())
tmp=ph[None,:,None]*from_K(bin_size_h)
print('ph[None,:,None]*from(bin_size_h) ', tmp.shape, tmp.min().item(), tmp.max().item())
tmp=(iy[None,None,:]+0.5).to(input.dtype)*from_K(bin_size_h/roi_bin_grid_h)
print('(iy[None,None,:]+0.5).to(input.dtype)*from_K(bin_size_h/roi_bin_grid_h) ', tmp.shape, tmp.min().item(), tmp.max().item())

from_K(roi_start_h)  torch.Size([849, 1, 1]) 0.0 199.1089630126953
ph[None,:,None]*from(bin_size_h)  torch.Size([849, 7, 1]) 0.0 108.95317840576172
(iy[None,None,:]+0.5).to(input.dtype)*from_K(bin_size_h/roi_bin_grid_h)  torch.Size([849, 1, 2]) 0.0357142873108387 13.619147300720215


In [23]:
#[K, PH, IY] = [K,1,1] + [K,PH,1] + [K,1,IY] note: sub-bin location was move to the center pixel
y=from_K(roi_start_h) + ph[None,:,None]*from_K(bin_size_h) + (iy[None,None,:]+0.5)*from_K(bin_size_h/roi_bin_grid_h)
print('y ', y.shape, y.min().item(), y.max().item())
#[K, PW, IX] = [K,1,1] + [K,PW,1] + [K,1,IX] note: sub-bin location was move to the center pixel
x=from_K(roi_start_w) + pw[None,:,None]*from_K(bin_size_w) + (ix[None,None,:]+0.5)*from_K(bin_size_w/roi_bin_grid_w)
print('x', x.shape, x.min().item(), x.max().item())

y  torch.Size([849, 7, 2]) 0.08404339849948883 200.07325744628906
x torch.Size([849, 7, 2]) 0.051833927631378174 286.4441223144531


```
val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
```
[`_bilinear_interpolate`](https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py#L35)

In [24]:
_, channels, height, width=input.size()
# deal with inverse element out of feature map boundary
y=y.clamp(min=0)
x=x.clamp(min=0)
print('y ', y.shape, y.min().item(), y.max().item())
print('x ', x.shape, x.min().item(), x.max().item())
y_low=y.int()
x_low=x.int()
print('y_low ', y_low.shape, y_low.min().item(), y_low.max().item())
y_high=torch.where(y_low>=height-1, height-1, y_low+1)
y_low=torch.where(y_low>=height-1, height-1, y_low)
print('y_high ', y_high.shape, y_high.min().item(), y_high.max().item())
print('y_low ', y_low.shape, y_low.min().item(), y_low.max().item())
y=torch.where(y_low>=height-1, y.to(input.dtype), y)
print('y ', y.shape, y.min().item(), y.max().item())

y  torch.Size([849, 7, 2]) 0.08404339849948883 200.07325744628906
x  torch.Size([849, 7, 2]) 0.051833927631378174 286.4441223144531
y_low  torch.Size([849, 7, 2]) 0 200
y_high  torch.Size([849, 7, 2]) 1 199
y_low  torch.Size([849, 7, 2]) 0 199
y  torch.Size([849, 7, 2]) 0.08404339849948883 200.07325744628906


In [25]:
print('x ', x.shape, x.min().item(), x.max().item(), x.dtype)
x_high=torch.where(x_low>=width-1, width-1, x_low+1)
x_low=torch.where(x_low>=width-1, width-1, x_low)
print('x_high ', x_high.shape, x_high.min().item(), x_high.max().item(), x_high.dtype)
print('x_low ', x_low.shape, x_low.min().item(), x_low.max().item(), x_low.dtype)
x=torch.where(x_low>=width-1, x.to(input.dtype), x)
print('x ', x.shape, x.min().item(), x.max().item(), x.dtype)

x  torch.Size([849, 7, 2]) 0.051833927631378174 286.4441223144531 torch.float32
x_high  torch.Size([849, 7, 2]) 1 287 torch.int32
x_low  torch.Size([849, 7, 2]) 0 286 torch.int32
x  torch.Size([849, 7, 2]) 0.051833927631378174 286.4441223144531 torch.float32


In [26]:
ly=y-y_low
lx=x-x_low
hy=1.-ly
hx=1.-lx
print('ly ', ly.shape, ly.min().item(), ly.max().item())
print('lx ', lx.shape, lx.min().item(), lx.max().item())
print('hy ', hy.shape, hy.min().item(), hy.max().item())
print('hx ', hx.shape, hx.min().item(), hx.max().item())

ly  torch.Size([849, 7, 2]) 4.57763671875e-05 1.0732574462890625
lx  torch.Size([849, 7, 2]) 6.103515625e-05 0.999969482421875
hy  torch.Size([849, 7, 2]) -0.0732574462890625 0.9999542236328125
hx  torch.Size([849, 7, 2]) 3.0517578125e-05 0.99993896484375


In [27]:
print('input ', input.shape)
print('roi_batch_ind[:, None, None, None, None, None] ', roi_batch_ind[:, None, None, None, None, None].shape)
C=torch.arange(channels, device=input.device)[None, :, None, None, None, None]
print('C ', C.shape)
print('y[:, None, :, None, :, None],  # prev [K, PH, IY] ',
     y[:, None, :, None, :, None].shape)
print('x[:, None, None, :, None, :],  # prev [K, PW, IX] ',
     x[:, None, None, :, None, :].shape)

input  torch.Size([2, 256, 200, 296])
roi_batch_ind[:, None, None, None, None, None]  torch.Size([849, 1, 1, 1, 1, 1])
C  torch.Size([1, 256, 1, 1, 1, 1])
y[:, None, :, None, :, None],  # prev [K, PH, IY]  torch.Size([849, 1, 7, 1, 2, 1])
x[:, None, None, :, None, :],  # prev [K, PW, IX]  torch.Size([849, 1, 1, 7, 1, 2])


In [25]:
def masked_index(
    y, #[K, PH, IY]
    x, #[K, PW, IX]
):
    if ymask is not None:
        assert xmask is not None
        y=torch.where(ymask[:, None], y, 0)
        x=torch.where(xmask[:,None], x, 0)
    return input[roi_batch_ind[:,None, None, None,None,None],
    torch.arange(channels, device=input.device)[None,:,None,None,None,None],
    y[:,None,:,None,:,None], # prev [K, PH, IY]
    x[:,None,None,:,None,:], # prev [K, PW, IX]
    ] # [K, C, PH, PW, IY, IX]

v1=masked_index(y_low, x_low)
v2=masked_index(y_low, x_high)
v3=masked_index(y_high, x_low)
v4=masked_index(y_high, x_high)


In [28]:
from torchvision.ops.roi_align import _bilinear_interpolate

val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
print('val ', val.shape)

val  torch.Size([849, 256, 7, 7, 2, 2])
