# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from modeling.ctx_outs import autoctx_outputs
from model.roi_pooling.functions.roi_pool import RoIPoolFunction
from model.roi_crop.functions.roi_crop import RoICropFunction
from modeling.roi_xform.roi_align.functions.roi_align import RoIAlignFunction

import numpy as np

# When FPN is applied, spatial_scale should be a list with one scale per pyramid level.
def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoIPoolF',
                          resolution=7, spatial_scale=1. / 16., sampling_ratio=0,
                          compress_conv=None):
    """Add the specified RoI pooling method, with context RoIs on the FPN path.

    The sampling_ratio argument is supported for some, but not all, RoI
    transform methods (it is only forwarded to RoIAlign).

    RoIFeatureTransform abstracts away:
      - Use of FPN or not
      - Specifics of the transform method

    Args:
        self: owner object; only ``self.grid_size`` is read, and only when
            ``method == 'RoICrop'``.
        blobs_in: a single feature map, or (FPN) a list of feature maps whose
            index 0 is the coarsest level (``cfg.FPN.ROI_MAX_LEVEL``) and
            whose last entry is the finest (``cfg.FPN.ROI_MIN_LEVEL``).
        rpn_ret: dict of numpy RoI arrays. The single-level path reads
            ``rpn_ret[blob_rois]``; the FPN path reads
            ``rpn_ret[blob_rois + '_fpn<lvl>']`` for every level and
            ``rpn_ret[blob_rois + '_idx_restore_int32']``.
        blob_rois: key prefix of the RoIs in ``rpn_ret``.
        method: one of 'RoIPoolF', 'RoICrop', 'RoIAlign'.
        resolution: output height/width passed to RoIPool/RoIAlign.
        spatial_scale: FPN: list of per-level scales in the same order as
            ``blobs_in``. Single level: a scale, or a list whose first entry
            is used.
        sampling_ratio: RoIAlign sampling ratio.
        compress_conv: optional module that maps the channel-concatenated
            context features of one RoI (``num_ctx * C`` channels) back to
            ``C`` channels on the FPN path. Pass a submodule registered on the
            model so its weights are trained. When ``None``, an ad-hoc 1x1
            ``nn.Conv2d`` is created inside this call (see ``_compress_context``).

    Returns:
        Pooled features with one entry along dim 0 per RoI. On the FPN path the
        rows are un-shuffled with the restore index so they follow the RoI
        order of ``rpn_ret``.
    """
    assert method in {'RoIPoolF', 'RoICrop', 'RoIAlign'}, \
        'Unknown pooling method: {}'.format(method)

    # The list suggests the feature maps (blobs_in) from different levels
    if isinstance(blobs_in, list):
        # FPN case: pool each level's RoIs through their context RoIs
        device_id = blobs_in[0].get_device()
        # NOTE(review): `cfg` is not imported in the visible part of this file;
        # presumably it is the project config object - confirm it is in scope.
        k_max = cfg.FPN.ROI_MAX_LEVEL  # coarsest level of pyramid
        k_min = cfg.FPN.ROI_MIN_LEVEL  # finest level of pyramid
        assert len(blobs_in) == k_max - k_min + 1
        # One (1, C, H, W) entry per RoI, in level-then-RoI order. This is the
        # "shuffled" order that the restore index below undoes. It must stay a
        # list across levels: it is concatenated only once, after the loop.
        bl_out_list = []
        # Fallback compression convs keyed by channel shape, so one conv is
        # reused for every RoI of this call instead of one per RoI.
        fallback_convs = {}
        for lvl in range(k_min, k_max + 1):
            bl_in = blobs_in[k_max - lvl]  # blobs_in is in reversed order
            sc = spatial_scale[k_max - lvl]  # in reversed order
            # each level's rois keys
            bl_rois = blob_rois + '_fpn' + str(lvl)
            if not len(rpn_ret[bl_rois]):
                continue
            # Channel count of *this level's* feature map
            dim_in = bl_in.size(1)
            # set up ctx_rois instance in current level
            ctx_rois = autoctx_outputs(dim_in, sc)
            # the rois from single level, 2D Variable with shape (roi_num, 5)
            rois = Variable(torch.from_numpy(rpn_ret[bl_rois])).cuda(device_id)
            for object_roi in rois:
                # NOTE(review): ctx_rois(...) is expected to return an ndarray of
                # context RoIs (presumably (9, 5) rows of (batch_idx, x1, y1,
                # x2, y2)) - confirm against autoctx_outputs. Cast to float32
                # and move to the features' device so the CUDA RoI ops accept it.
                ctx_np = np.asarray(ctx_rois(bl_in, object_roi), dtype=np.float32)
                autoctx_rois = Variable(torch.from_numpy(ctx_np)).cuda(device_id)
                xform_out = _roi_xform(self, method, bl_in, autoctx_rois,
                                       resolution, sc, sampling_ratio)
                bl_out_list.append(
                    _compress_context(xform_out, compress_conv, fallback_convs))
        # The pooled features from all levels are concatenated along the
        # batch dimension into a single 4D tensor, one row per RoI.
        xform_shuffled = torch.cat(bl_out_list, dim=0)

        # Unshuffle to match rois from dataloader
        device_id = xform_shuffled.get_device()
        restore_bl = rpn_ret[blob_rois + '_idx_restore_int32']
        restore_bl = Variable(
            torch.from_numpy(restore_bl.astype('int64', copy=False))).cuda(device_id)
        return xform_shuffled[restore_bl]

    # Single feature level (no context RoIs on this path)
    # rois: holds R regions of interest, each is a 5-tuple
    # (batch_idx, x1, y1, x2, y2) specifying an image batch index and a
    # rectangle (x1, y1, x2, y2)
    device_id = blobs_in.get_device()
    rois = Variable(torch.from_numpy(rpn_ret[blob_rois])).cuda(device_id)
    # The default spatial_scale is a plain float, so only index it when it is
    # a sequence (the first entry is used for a single feature map).
    if isinstance(spatial_scale, (list, tuple)):
        sc = spatial_scale[0]
    else:
        sc = spatial_scale
    return _roi_xform(self, method, blobs_in, rois, resolution, sc, sampling_ratio)


def _roi_xform(self, method, bl_in, rois, resolution, sc, sampling_ratio):
    """Apply one RoI transform ``method`` to ``rois`` on feature map ``bl_in``.

    Shared by the FPN and single-level paths of ``roi_feature_transform``.
    ``self.grid_size`` is read only for 'RoICrop'.
    """
    if method == 'RoIPoolF':
        # Warning!: Not check if implementation matches Detectron
        return RoIPoolFunction(resolution, resolution, sc)(bl_in, rois)
    if method == 'RoICrop':
        # Warning!: Not check if implementation matches Detectron
        # NOTE(review): `net_utils` and `cfg` are not imported in the visible
        # part of this file - confirm they are in scope.
        grid_xy = net_utils.affine_grid_gen(rois, bl_in.size()[2:], self.grid_size)
        grid_yx = torch.stack(
            [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous()
        xform_out = RoICropFunction()(bl_in, Variable(grid_yx).detach())
        if cfg.CROP_RESIZE_WITH_MAX_POOL:
            xform_out = F.max_pool2d(xform_out, 2, 2)
        return xform_out
    if method == 'RoIAlign':
        return RoIAlignFunction(
            resolution, resolution, sc, sampling_ratio)(bl_in, rois)
    raise ValueError('Unknown pooling method: {}'.format(method))


def _compress_context(xform_out, compress_conv=None, fallback_convs=None):
    """Fold one RoI's pooled context features into a single (1, C, H, W) sample.

    ``xform_out`` is the 4D output of pooling one RoI's context boxes,
    (num_ctx, C, H, W). It is concatenated along the channel axis into
    (1, num_ctx * C, H, W), then reduced back to C channels by a 1x1 conv.
    """
    num_ctx, chans, height, width = xform_out.size()
    stacked = xform_out.view(1, num_ctx * chans, height, width)
    if compress_conv is None:
        # NOTE(review): this conv is built inside the forward pass, so its
        # weights are randomly initialised and not registered with the model
        # or any optimizer. Pass a trained `compress_conv` module instead.
        if fallback_convs is None:
            fallback_convs = {}
        key = (num_ctx * chans, chans)
        compress_conv = fallback_convs.get(key)
        if compress_conv is None:
            # Integer channel counts (the old `/ 9` produced a float).
            compress_conv = nn.Conv2d(num_ctx * chans, chans, 1, 1, 0)
            if stacked.is_cuda:
                # Keep the conv on the same device as its input.
                compress_conv = compress_conv.cuda(stacked.get_device())
            fallback_convs[key] = compress_conv
    return compress_conv(stacked)