#------------------code for training------------------

In [None]:
#!pip install easydict

In [None]:
#!cp -r ../input/cascadercnn .

In [None]:
#cd cascadercnn/lib

In [None]:
#!python setup.py build develop

In [None]:
#cd ..

In [None]:
#!ls .

# Reference learning rate: lr = 0.00125 for one GPU (card) with one image per batch.

In [None]:
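# NOTE: --lr_decay_step is passed twice below (9, then 8); if the script parses flags with argparse, the last value (8) wins.
#!python train_cascade_fpn.py --dataset pascal_voc --net res50 --epoch 30 --lr_decay_step 9 --disp_interval 1 --bs 6 --nw 16 --lr 0.001 --lr_decay_step 8 --cuda --mGPUs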

In [None]:
#!rm -rf ../cascadercnn

#------------------code for testing------------------

In [None]:
!ls .

In [None]:
!rm cascade_fpn_1_64_1686.pth

In [None]:
import cv2
import math
import os
import numpy as np
import pandas as pd

# Expose only GPU 0 to this process. This line runs before `import torch` below;
# NOTE(review): CUDA reads this variable when it initialises, so keep it ahead of
# any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torch.nn.functional import avg_pool2d
from torch.autograd import Variable


# Class names: index 0 is the background class, index 1 the single object class.
pascal_classes = ['__background__', 'wheat']
classes = np.asarray(pascal_classes)
categoryList = dict(bg=0, wheat=1)

# Model / training / test hyper-parameters, looked up by key throughout the file.
cfg = dict(
    # --- anchors and RoI pooling ---------------------------------------------
    # ANCHOR_SCALES / FEAT_STRIDE are handed to the FPN RPN layers, which take
    # their anchor sizes from FPN_ANCHOR_SCALES instead.
    ANCHOR_RATIOS=[0.5, 1, 2],
    ANCHOR_SCALES=[4, 8, 16, 32],
    FEAT_STRIDE=[16],
    POOLING_SIZE=7,
    TRAIN_TRUNCATED=False,
    POOLING_MODE='align',
    CROP_RESIZE_WITH_MAX_POOL=False,
    # One anchor size per pyramid level, paired index-by-index with the strides.
    FPN_ANCHOR_SCALES=[32, 64, 128, 256, 512],
    FPN_FEAT_STRIDES=[4, 8, 16, 32, 64],
    FPN_ANCHOR_STRIDE=1,
    # --- RPN proposal generation ---------------------------------------------
    RPN_PRE_NMS_TOP_N=6000,
    RPN_POST_NMS_TOP_N=300,
    RPN_NMS_THRESH=0.7,
    RPN_MIN_SIZE=16,
    # --- RPN anchor targets (training) ---------------------------------------
    TRAIN_RPN_NEGATIVE_OVERLAP=0.3,
    TRAIN_RPN_POSITIVE_OVERLAP=0.7,
    TRAIN_RPN_FG_FRACTION=0.5,
    TRAIN_RPN_BATCHSIZE=256,
    TRAIN_RPN_BBOX_INSIDE_WEIGHTS=(1.0, 1.0, 1.0, 1.0),
    # A negative value selects uniform weighting of sampled anchors.
    TRAIN_RPN_POSITIVE_WEIGHT=-1.0,
    # --- proposal targets per cascade stage (training) -----------------------
    TRAIN_FG_THRESH=0.5,
    TRAIN_BG_THRESH_HI=0.5,
    TRAIN_BG_THRESH_LO=0.1,
    TRAIN_FG_THRESH_2ND=0.6,
    TRAIN_FG_THRESH_3RD=0.7,
    TRAIN_BBOX_NORMALIZE_TARGETS_PRECOMPUTED=True,
    TRAIN_BATCH_SIZE=128,
    TRAIN_FG_FRACTION=0.25,
    TRAIN_BBOX_NORMALIZE_MEANS=(0.0, 0.0, 0.0, 0.0),
    TRAIN_BBOX_NORMALIZE_STDS=(0.1, 0.1, 0.2, 0.2),
    TRAIN_BBOX_INSIDE_WEIGHTS=(1.0, 1.0, 1.0, 1.0),
    # --- backbone / input ----------------------------------------------------
    RESNET_FIXED_BLOCKS=1,
    # Per-channel means on a 0-255 scale, RGB order
    # (the BGR equivalent would be [102.9801, 115.9465, 122.7717]).
    PIXEL_MEANS=np.array([[[122.7717, 115.9465, 102.9801]]]),
    # --- test time -----------------------------------------------------------
    TEST_SCALES=(1024,),
    TEST_MAX_SIZE=1024,
    TEST_BBOX_REG=True,
)


#--------------------------------------------------#
def clip_boxes(boxes, im_shape, batch_size):
    """Clamp boxes in place so every coordinate lies inside its image.

    Args:
        boxes: tensor indexed as boxes[b, n, 4*k + c]; columns with c = 0, 2
            are bounded by the image width and c = 1, 3 by the image height.
            Modified in place.
        im_shape: per-image sizes; im_shape[b, 0] bounds the y columns and
            im_shape[b, 1] bounds the x columns.
        batch_size: number of leading entries of `boxes` to clip.

    Returns:
        The same `boxes` tensor, after clamping.
    """
    for b in range(batch_size):
        x_max = im_shape[b, 1] - 1
        y_max = im_shape[b, 0] - 1
        # Strided slices are views, so clamp_ writes through to `boxes`.
        for col, upper in ((0, x_max), (1, y_max), (2, x_max), (3, y_max)):
            boxes[b, :, col::4].clamp_(0, upper)
    return boxes
    
def bbox_transform_inv(boxes, deltas, batch_size):
    """Decode (dx, dy, dw, dh) regression deltas into (x1, y1, x2, y2) boxes.

    Args:
        boxes: reference boxes indexed as boxes[b, n, 0:4] = (x1, y1, x2, y2).
        deltas: predicted deltas whose last dimension holds consecutive groups
            of four (dx, dy, dw, dh) values; every group is decoded against
            the same reference box.
        batch_size: accepted for interface compatibility; not used.

    Returns:
        A new tensor shaped like `deltas` holding the decoded corners.
        Widths and heights follow the legacy "+1 pixel" convention.
    """
    ref_w = boxes[:, :, 2] - boxes[:, :, 0] + 1.0
    ref_h = boxes[:, :, 3] - boxes[:, :, 1] + 1.0
    # A trailing axis lets the reference geometry broadcast over the groups.
    ref_cx = (boxes[:, :, 0] + 0.5 * ref_w).unsqueeze(2)
    ref_cy = (boxes[:, :, 1] + 0.5 * ref_h).unsqueeze(2)
    ref_w = ref_w.unsqueeze(2)
    ref_h = ref_h.unsqueeze(2)

    cx = deltas[:, :, 0::4] * ref_w + ref_cx
    cy = deltas[:, :, 1::4] * ref_h + ref_cy
    half_w = 0.5 * (torch.exp(deltas[:, :, 2::4]) * ref_w)
    half_h = 0.5 * (torch.exp(deltas[:, :, 3::4]) * ref_h)

    decoded = deltas.clone()
    decoded[:, :, 0::4] = cx - half_w
    decoded[:, :, 1::4] = cy - half_h
    decoded[:, :, 2::4] = cx + half_w
    decoded[:, :, 3::4] = cy + half_h
    return decoded


def generate_anchors_single_pyramid(scales, ratios, shape, feature_stride, anchor_stride):
    """Generate the anchors of one feature-map level as (x1, y1, x2, y2) rows.

    Args:
        scales: anchor size(s) in pixels, scalar or 1D. Example: [32, 64, 128]
        ratios: anchor width/height ratios. Example: [0.5, 1, 2]
        shape: [height, width] of the feature map the anchors are laid on.
        feature_stride: stride of the feature map relative to the image, in pixels.
        anchor_stride: stride of anchors on the feature map; e.g. 2 places
            anchors on every other feature-map cell.

    Returns:
        Float array of shape (num_positions * num_anchors, 4). Rows are grouped
        by position (positions in row-major (y, x) order); inside a position the
        anchors are ordered by ratio, then by scale (scale varying fastest).
        Anchor centres sit at (cell index * feature_stride).
    """
    scale_grid, ratio_grid = np.meshgrid(np.array(scales), np.array(ratios))
    scale_col = scale_grid.flatten()
    ratio_col = ratio_grid.flatten()
    sqrt_ratio = np.sqrt(ratio_col)
    # Row vectors: one column per anchor shape.
    anchor_w = (scale_col * sqrt_ratio).reshape(1, -1)
    anchor_h = (scale_col / sqrt_ratio).reshape(1, -1)

    ys = np.arange(0, shape[0], anchor_stride) * feature_stride
    xs = np.arange(0, shape[1], anchor_stride) * feature_stride
    grid_x, grid_y = np.meshgrid(xs, ys)
    # Column vectors: one row per anchor position.
    cx = grid_x.reshape(-1, 1)
    cy = grid_y.reshape(-1, 1)

    # (positions, 1) op (1, anchors) broadcasts to (positions, anchors).
    corners = (cx - 0.5 * anchor_w,
               cy - 0.5 * anchor_h,
               cx + 0.5 * anchor_w,
               cy + 0.5 * anchor_h)
    return np.stack([c.reshape(-1) for c in corners], axis=1)


def generate_anchors_all_pyramids(scales, ratios, feature_shapes, feature_strides,
                             anchor_stride):
    """Generate anchors at every level of a feature pyramid.

    Each scale is associated with one pyramid level: scales[i] is paired with
    feature_shapes[i] and feature_strides[i]. Every ratio is used at every level.

    Args:
        scales: one anchor size (pixels) per pyramid level.
        ratios: anchor width/height ratios shared by all levels.
        feature_shapes: [height, width] of each level's feature map.
        feature_strides: stride (pixels) of each level relative to the image.
        anchor_stride: stride of anchors on every feature map.

    Returns:
        anchors: [N, (x1, y1, x2, y2)] array of all generated anchors, grouped
        by level in the order of `scales`. Anchors of scale[0] come first, then
        anchors of scale[1], and so on.

    Raises:
        IndexError: if feature_shapes or feature_strides has fewer entries
            than scales.
    """
    per_level = [
        generate_anchors_single_pyramid(scale, ratios, feature_shapes[level],
                                        feature_strides[level], anchor_stride)
        for level, scale in enumerate(scales)
    ]
    return np.concatenate(per_level, axis=0)



class _ProposalTargetLayer(nn.Module):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.

    For every image it samples cfg['TRAIN_BATCH_SIZE'] RoIs. Ground-truth boxes
    are added to the candidate pool first. Sampling aims for a
    cfg['TRAIN_FG_FRACTION'] share of foreground. The `stage` argument of
    forward() selects the IoU thresholds, so later cascade stages (2, 3) use
    stricter foreground definitions than stage 1.
    """

    def __init__(self, nclasses):
        super(_ProposalTargetLayer, self).__init__()
        self._num_classes = nclasses
        # Regression-target normalization constants and loss weights. They are
        # created as CPU float tensors and re-cast to gt_boxes' type on every
        # forward() call.
        self.BBOX_NORMALIZE_MEANS = torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_MEANS'])
        self.BBOX_NORMALIZE_STDS = torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_STDS'])
        self.BBOX_INSIDE_WEIGHTS = torch.FloatTensor(cfg['TRAIN_BBOX_INSIDE_WEIGHTS'])

    def forward(self, all_rois, gt_boxes, num_boxes, stage=1):
        """Sample RoIs for one cascade stage and build their training targets.

        Args:
            all_rois: candidate RoIs. They are concatenated along dim 1 with a
                tensor shaped like gt_boxes, and their rows are later read as
                (batch_idx, box[4]). Presumably (batch, num_rois, 5);
                TODO confirm against the caller.
            gt_boxes: ground-truth boxes; columns 0:4 are the box and column 4
                is the class label.
            num_boxes: not used by this layer.
            stage: cascade stage (1, 2 or 3) selecting the IoU thresholds;
                any other value raises RuntimeError.

        Returns:
            (rois, labels, gt_assign, bbox_targets, bbox_inside_weights,
             bbox_outside_weights), each with cfg['TRAIN_BATCH_SIZE'] rows per
            image.
        """

        self.BBOX_NORMALIZE_MEANS = self.BBOX_NORMALIZE_MEANS.type_as(gt_boxes)
        self.BBOX_NORMALIZE_STDS = self.BBOX_NORMALIZE_STDS.type_as(gt_boxes)
        self.BBOX_INSIDE_WEIGHTS = self.BBOX_INSIDE_WEIGHTS.type_as(gt_boxes)

        # Re-lay gt boxes in RoI layout: column 0 (batch index) stays 0 and
        # columns 1:5 hold the box. The real batch index is written later in
        # _sample_rois_pytorch.
        gt_boxes_append = gt_boxes.new(gt_boxes.size()).zero_()
        gt_boxes_append[:,:,1:5] = gt_boxes[:,:,:4]

        # Include ground-truth boxes in the set of candidate rois
        all_rois = torch.cat([all_rois, gt_boxes_append], 1)

        # num_images is fixed at 1, so the whole cfg['TRAIN_BATCH_SIZE'] budget
        # applies to each image.
        num_images = 1
        rois_per_image = int(cfg['TRAIN_BATCH_SIZE'] / num_images)
        fg_rois_per_image = int(np.round(cfg['TRAIN_FG_FRACTION'] * rois_per_image))
        # Always allow at least one foreground RoI.
        fg_rois_per_image = 1 if fg_rois_per_image == 0 else fg_rois_per_image

        labels, rois, gt_assign, bbox_targets, bbox_inside_weights = self._sample_rois_pytorch(
            all_rois, gt_boxes, fg_rois_per_image,
            rois_per_image, self._num_classes, stage=stage)

        # Outside weights are 1 wherever an inside weight was set, else 0.
        bbox_outside_weights = (bbox_inside_weights > 0).float()

        return rois, labels, gt_assign, bbox_targets, bbox_inside_weights, bbox_outside_weights

    def backward(self, top, propagate_down, bottom):
        """This layer does not propagate gradients."""
        pass

    def reshape(self, bottom, top):
        """Reshaping happens during the call to forward."""
        pass

    def _get_bbox_regression_labels_pytorch(self, bbox_target_data, labels_batch, num_classes):
        """Keep regression targets only for foreground RoIs.

        bbox_target_data holds one (tx, ty, tw, th) target per sampled RoI.
        Rows whose label is > 0 keep their target. All other rows are zero.
        This implementation is class-agnostic: the output has 4 columns
        (not 4*K), and num_classes is unused.

        Returns:
            bbox_targets: b x N x 4 tensor of regression targets
            bbox_inside_weights: b x N x 4 tensor of loss weights
                (cfg['TRAIN_BBOX_INSIDE_WEIGHTS'] on foreground rows, 0 elsewhere)
        """
        batch_size = labels_batch.size(0)
        rois_per_image = labels_batch.size(1)
        clss = labels_batch
        bbox_targets = bbox_target_data.new(batch_size, rois_per_image, 4).zero_()
        bbox_inside_weights = bbox_target_data.new(bbox_targets.size()).zero_()

        for b in range(batch_size):
            # assert clss[b].sum() > 0
            # No foreground RoI in this image: leave its rows at zero.
            if clss[b].sum() == 0:
                continue
            inds = torch.nonzero(clss[b] > 0).view(-1)
            for i in range(inds.numel()):
                ind = inds[i]
                bbox_targets[b, ind, :] = bbox_target_data[b, ind, :]
                bbox_inside_weights[b, ind, :] = self.BBOX_INSIDE_WEIGHTS

        return bbox_targets, bbox_inside_weights


    def _compute_targets_pytorch(self, ex_rois, gt_rois):
        """Compute bounding-box regression targets for an image.

        Encodes ex_rois -> gt_rois with bbox_transform_batch (defined outside
        this view; presumably the inverse of bbox_transform_inv — verify).
        When cfg['TRAIN_BBOX_NORMALIZE_TARGETS_PRECOMPUTED'] is set, the
        targets are normalized as (targets - means) / stds.
        """

        assert ex_rois.size(1) == gt_rois.size(1)
        assert ex_rois.size(2) == 4
        assert gt_rois.size(2) == 4

        batch_size = ex_rois.size(0)
        rois_per_image = ex_rois.size(1)

        targets = bbox_transform_batch(ex_rois, gt_rois)

        if cfg['TRAIN_BBOX_NORMALIZE_TARGETS_PRECOMPUTED']:
            # Optionally normalize targets by a precomputed mean and stdev
            targets = ((targets - self.BBOX_NORMALIZE_MEANS.expand_as(targets))
                        / self.BBOX_NORMALIZE_STDS.expand_as(targets))

        return targets


    def _sample_rois_pytorch(self, all_rois, gt_boxes, fg_rois_per_image, rois_per_image, num_classes, stage=1):
        """Generate a random sample of RoIs comprising foreground and background
        examples.

        Per image, RoIs whose best overlap is >= fg_thresh are foreground, and
        those in [bg_thresh_lo, bg_thresh_hi) are background.
        - If both kinds exist, up to fg_rois_per_image foreground RoIs are drawn
          without replacement. The rest of the rois_per_image budget is filled
          with background RoIs drawn WITH replacement.
        - If only one kind exists, the whole budget is drawn from it with
          replacement.
        - If neither exists, ValueError is raised.

        Sampling uses numpy's global RNG.

        Returns:
            labels_batch, rois_batch, gt_assign_batch, bbox_targets,
            bbox_inside_weights. rois_batch[:, :, 0] holds the image index, and
            labels of background rows are forced to 0.
        """
        # overlaps: (rois x gt_boxes). bbox_overlaps_batch is defined outside
        # this view; presumably it returns IoU. Dim 2 indexes gt boxes (see the
        # gt_boxes[i][gt_assignment[i][...]] lookup below).

        overlaps = bbox_overlaps_batch(all_rois, gt_boxes)
        
        max_overlaps, gt_assignment = torch.max(overlaps, 2)

        batch_size = overlaps.size(0)
        num_proposal = overlaps.size(1)
        num_boxes_per_img = overlaps.size(2)

        # Class label of each RoI's best-matching gt box: gather from the
        # flattened gt_boxes[:, :, 4] using a per-image offset.
        offset = torch.arange(0, batch_size)*gt_boxes.size(1)
        offset = offset.view(-1, 1).type_as(gt_assignment) + gt_assignment
        labels = gt_boxes[:,:,4].contiguous().view(-1)[(offset.view(-1),)].view(batch_size, -1)

        labels_batch = labels.new(batch_size, rois_per_image).zero_()
        rois_batch  = all_rois.new(batch_size, rois_per_image, 5).zero_()
        # NOTE(review): allocated with all_rois' (floating) type, so the gt
        # indices stored here are not integer-typed.
        gt_assign_batch = all_rois.new(batch_size, rois_per_image).zero_()
        gt_rois_batch = all_rois.new(batch_size, rois_per_image, 5).zero_()
        # Stage-specific IoU thresholds. For stages 2 and 3 the background upper
        # bound equals the foreground threshold, so no RoI is left unassigned
        # between them.
        if stage == 1:
            fg_thresh = cfg['TRAIN_FG_THRESH']
            bg_thresh_hi = cfg['TRAIN_BG_THRESH_HI']
            bg_thresh_lo = cfg['TRAIN_BG_THRESH_LO']
        elif stage == 2:
            fg_thresh = cfg['TRAIN_FG_THRESH_2ND']
            bg_thresh_hi = cfg['TRAIN_FG_THRESH_2ND']
            bg_thresh_lo = cfg['TRAIN_BG_THRESH_LO']
        elif stage == 3:
            fg_thresh = cfg['TRAIN_FG_THRESH_3RD']
            bg_thresh_hi = cfg['TRAIN_FG_THRESH_3RD']
            bg_thresh_lo = cfg['TRAIN_BG_THRESH_LO']
        else:
            raise RuntimeError('stage must be in [1, 2, 3]')
        for i in range(batch_size):

            fg_inds = torch.nonzero(max_overlaps[i] >= fg_thresh).view(-1)
            fg_num_rois = fg_inds.numel()

            # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
            bg_inds = torch.nonzero((max_overlaps[i] < bg_thresh_hi) &
                                    (max_overlaps[i] >= bg_thresh_lo)).view(-1)
            bg_num_rois = bg_inds.numel()

            if fg_num_rois > 0 and bg_num_rois > 0:
                # sampling fg (without replacement, capped at fg_rois_per_image)
                fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)

                # torch.randperm seems has a bug on multi-gpu setting that cause the segfault.
                # See https://github.com/pytorch/pytorch/issues/1868 for more details.
                # use numpy instead.
                #rand_num = torch.randperm(fg_num_rois).long().cuda()
                rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(gt_boxes).long()
                fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]

                # sampling bg (with replacement) to fill the remaining budget
                bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image

                # Seems torch.rand has a bug, it will generate very large number and make an error.
                # We use numpy rand instead.
                #rand_num = (torch.rand(bg_rois_per_this_image) * bg_num_rois).long().cuda()
                rand_num = np.floor(np.random.rand(bg_rois_per_this_image) * bg_num_rois)
                rand_num = torch.from_numpy(rand_num).type_as(gt_boxes).long()
                bg_inds = bg_inds[rand_num]

            elif fg_num_rois > 0 and bg_num_rois == 0:
                # sampling fg: the whole budget, with replacement
                #rand_num = torch.floor(torch.rand(rois_per_image) * fg_num_rois).long().cuda()
                rand_num = np.floor(np.random.rand(rois_per_image) * fg_num_rois)
                rand_num = torch.from_numpy(rand_num).type_as(gt_boxes).long()
                fg_inds = fg_inds[rand_num]
                fg_rois_per_this_image = rois_per_image
                bg_rois_per_this_image = 0
            elif bg_num_rois > 0 and fg_num_rois == 0:
                # sampling bg: the whole budget, with replacement
                #rand_num = torch.floor(torch.rand(rois_per_image) * bg_num_rois).long().cuda()
                rand_num = np.floor(np.random.rand(rois_per_image) * bg_num_rois)
                rand_num = torch.from_numpy(rand_num).type_as(gt_boxes).long()

                bg_inds = bg_inds[rand_num]
                bg_rois_per_this_image = rois_per_image
                fg_rois_per_this_image = 0
            else:
                print(i, overlaps[i], max_overlaps[i], gt_boxes[i])
                raise ValueError("bg_num_rois = 0 and fg_num_rois = 0, this should not happen!")

            # The indices that we're selecting (both fg and bg)
            keep_inds = torch.cat([fg_inds, bg_inds], 0)

            # Select sampled values from various arrays:
            labels_batch[i].copy_(labels[i][keep_inds])

            # Clamp labels for the background RoIs to 0
            if fg_rois_per_this_image < rois_per_image:
                labels_batch[i][fg_rois_per_this_image:] = 0

            rois_batch[i] = all_rois[i][keep_inds]
            rois_batch[i,:,0] = i

            # TODO: check the below line when batch_size > 1, no need to add offset here
            gt_assign_batch[i] = gt_assignment[i][keep_inds]

            gt_rois_batch[i] = gt_boxes[i][gt_assignment[i][keep_inds]]

        bbox_target_data = self._compute_targets_pytorch(
                rois_batch[:,:,1:5], gt_rois_batch[:,:,:4])

        bbox_targets, bbox_inside_weights = \
                self._get_bbox_regression_labels_pytorch(bbox_target_data, labels_batch, num_classes)

        return labels_batch, rois_batch, gt_assign_batch, bbox_targets, bbox_inside_weights
    
class _AnchorTargetLayer_FPN(nn.Module):
    """
        Assign anchors to ground-truth targets. Produces anchor classification
        labels and bounding-box regression targets.

        Anchors are generated for every FPN level (one scale per level, every
        ratio at every level) and flattened into one list. Only anchors lying
        inside the image are labelled; the others get label -1 and zero
        targets/weights.
    """
    def __init__(self, feat_stride, scales, ratios):
        super(_AnchorTargetLayer_FPN, self).__init__()
        self._anchor_ratios = ratios
        self._feat_stride = feat_stride
        # `scales` is accepted for interface compatibility; FPN anchor sizes
        # come from cfg['FPN_ANCHOR_SCALES'] instead.
        self._fpn_scales = np.array(cfg['FPN_ANCHOR_SCALES'])
        self._fpn_feature_strides = np.array(cfg['FPN_FEAT_STRIDES'])
        self._fpn_anchor_stride = cfg['FPN_ANCHOR_STRIDE']

        # allow boxes to sit over the edge by a small amount
        self._allowed_border = 0  # default is 0

    def forward(self, input):
        """Build RPN training targets for all FPN anchors.

        Args:
            input: tuple (scores, gt_boxes, im_info, num_boxes, feat_shapes).
                - scores only supplies the dtype/device for the anchors.
                - num_boxes is unused.
                - feat_shapes holds [height, width] per FPN level.
                - im_info[0][0] / im_info[0][1] are used as the image height /
                  width, and only the first image's size is used.
                - gt_boxes rows have 5 columns (box, label).

        Returns:
            [labels, bbox_targets, bbox_inside_weights, bbox_outside_weights],
            mapped back to all generated anchors via _unmap (defined outside
            this view). Labels: 1 foreground, 0 background, -1 ignored.
        """
        scores = input[0]
        gt_boxes = input[1]
        im_info = input[2]
        feat_shapes = input[4]

        batch_size = gt_boxes.size(0)

        anchors = torch.from_numpy(generate_anchors_all_pyramids(
            self._fpn_scales, self._anchor_ratios, feat_shapes,
            self._fpn_feature_strides, self._fpn_anchor_stride)).type_as(scores)
        total_anchors = anchors.size(0)

        # Python 3 has no `long`; int() is the correct conversion.
        im_height = int(im_info[0][0])
        im_width = int(im_info[0][1])
        keep = ((anchors[:, 0] >= -self._allowed_border) &
                (anchors[:, 1] >= -self._allowed_border) &
                (anchors[:, 2] < im_width + self._allowed_border) &
                (anchors[:, 3] < im_height + self._allowed_border))

        inds_inside = torch.nonzero(keep).view(-1)

        # keep only inside anchors
        anchors = anchors[inds_inside, :]

        # label: 1 is positive, 0 is negative, -1 is dont care
        labels = gt_boxes.new(batch_size, inds_inside.size(0)).fill_(-1)
        bbox_inside_weights = gt_boxes.new(batch_size, inds_inside.size(0)).zero_()
        bbox_outside_weights = gt_boxes.new(batch_size, inds_inside.size(0)).zero_()

        overlaps = bbox_overlaps_batch(anchors, gt_boxes)

        max_overlaps, argmax_overlaps = torch.max(overlaps, 2)
        gt_max_overlaps, _ = torch.max(overlaps, 1)

        labels[max_overlaps < cfg['TRAIN_RPN_NEGATIVE_OVERLAP']] = 0

        # Each gt box marks its best-matching anchor(s) as positive. A gt whose
        # best overlap is 0 would otherwise match every zero-overlap anchor, so
        # its maximum is bumped to 1e-5, which no anchor can equal.
        gt_max_overlaps[gt_max_overlaps == 0] = 1e-5
        keep = torch.sum(overlaps.eq(gt_max_overlaps.view(batch_size, 1, -1).expand_as(overlaps)), 2)

        if torch.sum(keep) > 0:
            labels[keep > 0] = 1

        # fg label: above threshold IOU
        labels[max_overlaps >= cfg['TRAIN_RPN_POSITIVE_OVERLAP']] = 1

        num_fg = int(cfg['TRAIN_RPN_FG_FRACTION'] * cfg['TRAIN_RPN_BATCHSIZE'])

        sum_fg = torch.sum((labels == 1).int(), 1)
        sum_bg = torch.sum((labels == 0).int(), 1)

        for i in range(batch_size):
            # subsample positive labels if we have too many
            if sum_fg[i] > num_fg:
                fg_inds = torch.nonzero(labels[i] == 1).view(-1)
                # numpy permutation instead of torch.randperm, which segfaulted
                # on multi-GPU setups (https://github.com/pytorch/pytorch/issues/1868).
                rand_num = torch.from_numpy(np.random.permutation(fg_inds.size(0))).type_as(gt_boxes).long()
                disable_inds = fg_inds[rand_num[:fg_inds.size(0) - num_fg]]
                labels[i][disable_inds] = -1

            # Recount positives AFTER subsampling. Using the stale pre-subsampling
            # count makes num_bg too small (or negative) and disables every
            # background anchor.
            num_bg = cfg['TRAIN_RPN_BATCHSIZE'] - int(torch.sum((labels[i] == 1).int()))

            # subsample negative labels if we have too many
            if sum_bg[i] > num_bg:
                bg_inds = torch.nonzero(labels[i] == 0).view(-1)
                rand_num = torch.from_numpy(np.random.permutation(bg_inds.size(0))).type_as(gt_boxes).long()
                disable_inds = bg_inds[rand_num[:bg_inds.size(0) - num_bg]]
                labels[i][disable_inds] = -1

        offset = torch.arange(0, batch_size) * gt_boxes.size(1)

        argmax_overlaps = argmax_overlaps + offset.view(batch_size, 1).type_as(argmax_overlaps)
        bbox_targets = _compute_targets_batch(anchors, gt_boxes.view(-1, 5)[argmax_overlaps.view(-1), :].view(batch_size, -1, 5))

        # use a single value instead of 4 values for easy index.
        bbox_inside_weights[labels == 1] = cfg['TRAIN_RPN_BBOX_INSIDE_WEIGHTS'][0]

        positive_weight = cfg['TRAIN_RPN_POSITIVE_WEIGHT']
        if positive_weight >= 0 and not 0 < positive_weight < 1:
            raise ValueError("TRAIN_RPN_POSITIVE_WEIGHT must be negative or in (0, 1)")

        # Normalize the outside weights per image (previously only the last
        # image's count was used, via a leaked loop variable).
        for i in range(batch_size):
            labels_i = labels[i]
            if positive_weight < 0:
                # Uniform weighting: every sampled anchor gets 1 / #sampled.
                # max(..., 1) avoids division by zero when nothing is sampled.
                num_examples = max(int(torch.sum(labels_i >= 0)), 1)
                positive_weights = 1.0 / num_examples
                negative_weights = 1.0 / num_examples
            else:
                positive_weights = positive_weight / max(int(torch.sum(labels_i == 1)), 1)
                negative_weights = (1.0 - positive_weight) / max(int(torch.sum(labels_i == 0)), 1)
            bbox_outside_weights[i][labels_i == 1] = positive_weights
            bbox_outside_weights[i][labels_i == 0] = negative_weights

        labels = _unmap(labels, total_anchors, inds_inside, batch_size, fill=-1)
        bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, batch_size, fill=0)
        bbox_inside_weights = _unmap(bbox_inside_weights, total_anchors, inds_inside, batch_size, fill=0)
        bbox_outside_weights = _unmap(bbox_outside_weights, total_anchors, inds_inside, batch_size, fill=0)

        return [labels, bbox_targets, bbox_inside_weights, bbox_outside_weights]

    def backward(self, top, propagate_down, bottom):
        """This layer does not propagate gradients."""
        pass

    def reshape(self, bottom, top):
        """Reshaping happens during the call to forward."""
        pass

    
    
class _ProposalLayer_FPN(nn.Module):
    """
    Outputs object detection proposals by applying estimated bounding-box
    transformations to a set of regular boxes (called "anchors").
    """

    def __init__(self, feat_stride, scales, ratios):
        super(_ProposalLayer_FPN, self).__init__()
        self._anchor_ratios = ratios
        self._feat_stride = feat_stride
        # `scales` is accepted for interface compatibility; FPN anchor sizes
        # come from cfg['FPN_ANCHOR_SCALES'] instead.
        self._fpn_scales = np.array(cfg['FPN_ANCHOR_SCALES'])
        self._fpn_feature_strides = np.array(cfg['FPN_FEAT_STRIDES'])
        self._fpn_anchor_stride = cfg['FPN_ANCHOR_STRIDE']

    def forward(self, input):
        """Turn RPN outputs into a fixed-size, zero-padded set of proposals.

        Steps:
        1. Regenerate all FPN anchors and decode the deltas against them.
        2. Clip the decoded boxes to the image.
        3. Sort by foreground score and keep the top cfg['RPN_PRE_NMS_TOP_N'].
        4. Apply soft-NMS (sigma=0.5, score threshold 0.001).
        5. Keep at most cfg['RPN_POST_NMS_TOP_N'] boxes.

        No minimum-size filter is applied (see _filter_boxes).

        Args:
            input: tuple (probs, bbox_deltas, im_info, cfg_key, feat_shapes).
                - probs[:, :, 1] is used as the foreground score.
                - bbox_deltas holds 4 deltas per anchor.
                - cfg_key is unused.
                - feat_shapes is the per-level [height, width] list used to
                  regenerate the anchors.

        Returns:
            (batch, cfg['RPN_POST_NMS_TOP_N'], 5) tensor of rows
            (batch_idx, x1, y1, x2, y2). Unused rows are zero apart from the
            batch index.
        """
        scores = input[0][:, :, 1]  # batch_size x num_anchors
        bbox_deltas = input[1]      # batch_size x num_anchors x 4
        im_info = input[2]
        feat_shapes = input[4]

        pre_nms_topN = cfg['RPN_PRE_NMS_TOP_N']
        post_nms_topN = cfg['RPN_POST_NMS_TOP_N']

        batch_size = bbox_deltas.size(0)

        anchors = torch.from_numpy(generate_anchors_all_pyramids(
            self._fpn_scales, self._anchor_ratios, feat_shapes,
            self._fpn_feature_strides, self._fpn_anchor_stride)).type_as(scores)
        num_anchors = anchors.size(0)
        anchors = anchors.view(1, num_anchors, 4).expand(batch_size, num_anchors, 4)

        # Convert anchors into proposals via bbox transformations, then clip.
        proposals = bbox_transform_inv(anchors, bbox_deltas, batch_size)
        proposals = clip_boxes(proposals, im_info, batch_size)

        # Descending score order, per image.
        _, order = torch.sort(scores, 1, True)

        output = scores.new(batch_size, post_nms_topN, 5).zero_()
        for i in range(batch_size):
            order_single = order[i]
            # Bound against this image's anchor count, not the whole batch's.
            if 0 < pre_nms_topN < order_single.numel():
                order_single = order_single[:pre_nms_topN]

            proposals_single = proposals[i][order_single, :]
            scores_single = scores[i][order_single]

            # soft_nms is defined outside this view; its result is used as a
            # list of kept indices into proposals_single.
            keep_idx_i = soft_nms(proposals_single, scores_single, sigma=0.5, thresh=0.001, cuda=1)
            keep_idx_i = keep_idx_i.long().view(-1)

            if post_nms_topN > 0:
                keep_idx_i = keep_idx_i[:post_nms_topN]
            proposals_single = proposals_single[keep_idx_i, :]

            # padding 0 at the end.
            num_proposal = proposals_single.size(0)
            output[i, :, 0] = i
            output[i, :num_proposal, 1:] = proposals_single

        return output

    def backward(self, top, propagate_down, bottom):
        """This layer does not propagate gradients."""
        pass

    def reshape(self, bottom, top):
        """Reshaping happens during the call to forward."""
        pass

    def _filter_boxes(self, boxes, min_size):
        """Return a mask of boxes whose width and height are both >= min_size.

        Not called by forward(); kept for the optional min-size filter.
        """
        ws = boxes[:, :, 2] - boxes[:, :, 0] + 1
        hs = boxes[:, :, 3] - boxes[:, :, 1] + 1
        keep = ((ws >= min_size) & (hs >= min_size))
        return keep


class _RPN_FPN(nn.Module):
    """ region proposal network """
    def __init__(self, din):
        super(_RPN_FPN, self).__init__()

        self.din = din  # get depth of input feature map, e.g., 512
        self.anchor_ratios = cfg['ANCHOR_RATIOS']
        self.anchor_scales = cfg['ANCHOR_SCALES']
        self.feat_stride = cfg['FEAT_STRIDE']

        # define the convrelu layers processing input feature map
        self.RPN_Conv = nn.Conv2d(self.din, 512, 3, 1, 1, bias=True)

        # define bg/fg classifcation score layer
        self.nc_score_out = 1 * len(self.anchor_ratios) * 2 # 2(bg/fg) * 3 (anchor ratios) * 1 (anchor scale)
        self.RPN_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0)

        # define anchor box offset prediction layer
        self.nc_bbox_out = 1 * len(self.anchor_ratios) * 4 # 4(coords) * 3 (anchors) * 1 (anchor scale)
        self.RPN_bbox_pred = nn.Conv2d(512, self.nc_bbox_out, 1, 1, 0)

        # define proposal layer
        self.RPN_proposal = _ProposalLayer_FPN(self.feat_stride, self.anchor_scales, self.anchor_ratios)

        # define anchor target layer
        self.RPN_anchor_target = _AnchorTargetLayer_FPN(self.feat_stride, self.anchor_scales, self.anchor_ratios)

        self.rpn_loss_cls = 0
        self.rpn_loss_box = 0

    @staticmethod
    def reshape(x, d):
        """View x (B, C, H, W) as (B, d, C*H/d, W)."""
        input_shape = x.size()
        x = x.contiguous().view(
            input_shape[0],
            int(d),
            int(float(input_shape[1] * input_shape[2]) / float(d)),
            input_shape[3]
        )
        return x

    def forward(self, rpn_feature_maps, im_info, gt_boxes, num_boxes):
        """Run the shared RPN head on every pyramid level and build proposals.

        Args:
            rpn_feature_maps: non-empty sequence of (B, din, H, W) feature maps,
                one per FPN level (same order as cfg['FPN_ANCHOR_SCALES']).
            im_info, gt_boxes, num_boxes: forwarded to the proposal / anchor
                target layers; gt_boxes is required in training mode.

        Returns:
            (rois, rpn_loss_cls, rpn_loss_box). The losses are 0 outside
            training mode.
        """
        rpn_cls_scores = []
        rpn_cls_probs = []
        rpn_bbox_preds = []
        rpn_shapes = []

        for feat_map in rpn_feature_maps:
            batch_size = feat_map.size(0)

            # return feature map after convrelu layer
            rpn_conv1 = F.relu(self.RPN_Conv(feat_map), inplace=True)
            # get rpn classification score and offsets to the anchor boxes
            rpn_cls_score = self.RPN_cls_score(rpn_conv1)
            rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)

            rpn_shapes.append([rpn_cls_score.size(2), rpn_cls_score.size(3)])
            # (B, 2A, H, W) -> (B, H*W*A, 2): each row is one anchor's
            # (bg, fg) logit pair, in the same order as the generated anchors.
            cls_score_flat = rpn_cls_score.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
            rpn_cls_scores.append(cls_score_flat)
            # Softmax over exactly the pairs that cross_entropy trains on below.
            # The previous reshape(x, 2) trick paired channel c with c+A, which
            # does not match the training loss's per-anchor (c, c+1) pairs.
            rpn_cls_probs.append(F.softmax(cls_score_flat, dim=2))
            rpn_bbox_preds.append(rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4))

        rpn_cls_score_alls = torch.cat(rpn_cls_scores, 1)
        rpn_cls_prob_alls = torch.cat(rpn_cls_probs, 1)
        rpn_bbox_pred_alls = torch.cat(rpn_bbox_preds, 1)

        # proposal layer
        cfg_key = 'TRAIN' if self.training else 'TEST'

        rois = self.RPN_proposal((rpn_cls_prob_alls.data, rpn_bbox_pred_alls.data,
                                  im_info, cfg_key, rpn_shapes))

        self.rpn_loss_cls = 0
        self.rpn_loss_box = 0

        # generating training labels and build the rpn loss
        if self.training:
            if gt_boxes is None:
                raise ValueError("gt_boxes is required in training mode")

            rpn_data = self.RPN_anchor_target((rpn_cls_score_alls.data, gt_boxes, im_info, num_boxes, rpn_shapes))

            # compute classification loss over anchors whose label is not -1
            rpn_label = rpn_data[0].view(batch_size, -1)
            rpn_keep = rpn_label.view(-1).ne(-1).nonzero().view(-1)
            rpn_cls_score = torch.index_select(rpn_cls_score_alls.view(-1, 2), 0, rpn_keep)
            rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep).long()
            self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)

            rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
            # compute bbox regression loss; the per-anchor scalar weights are
            # broadcast to the 4 box coordinates
            rpn_bbox_inside_weights = rpn_bbox_inside_weights.unsqueeze(2).expand(
                batch_size, rpn_bbox_inside_weights.size(1), 4)
            rpn_bbox_outside_weights = rpn_bbox_outside_weights.unsqueeze(2).expand(
                batch_size, rpn_bbox_outside_weights.size(1), 4)

            self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred_alls, rpn_bbox_targets, rpn_bbox_inside_weights,
                                                rpn_bbox_outside_weights, sigma=3)

        return rois, self.rpn_loss_cls, self.rpn_loss_box


class _ROIPool(Function):
    """Autograd wrapper around the compiled RoI max-pooling kernels in ``_C``.

    NOTE(review): ``_C`` is not imported in the visible part of this file;
    confirm it is defined before this op is used (``POOLING_MODE == 'pool'``).
    """

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale):
        """Pool ``roi`` regions of ``input`` to ``output_size``.

        ``output_size`` may be an int or an (h, w) pair; it is normalised with
        ``_pair`` and the normalised value is what the kernel receives.
        """
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.input_shape = input.size()
        output, argmax = _C.roi_pool_forward(
            input, roi, spatial_scale, ctx.output_size[0], ctx.output_size[1]
        )
        # The backward kernel needs the inputs plus the argmax map.
        ctx.save_for_backward(input, roi, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Route ``grad_output`` back to ``input``; rois and settings get no gradient."""
        input, rois, argmax = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        bs, ch, h, w = ctx.input_shape
        grad_input = _C.roi_pool_backward(
            grad_output,
            input,
            rois,
            argmax,
            spatial_scale,
            output_size[0],
            output_size[1],
            bs,
            ch,
            h,
            w,
        )
        return grad_input, None, None, None


# Functional entry point: roi_pool(input, rois, output_size, spatial_scale).
roi_pool = _ROIPool.apply


class ROIPool(nn.Module):
    """Module form of ``roi_pool`` holding a fixed output size and spatial scale."""

    def __init__(self, output_size, spatial_scale):
        super(ROIPool, self).__init__()
        self.output_size, self.spatial_scale = output_size, spatial_scale

    def forward(self, input, rois):
        """Pool ``rois`` out of ``input`` with the stored size and scale."""
        size, scale = self.output_size, self.spatial_scale
        return roi_pool(input, rois, size, scale)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(output_size={self.output_size!s}, spatial_scale={self.spatial_scale!s})"


class _ROIAlign(Function):
    """Autograd wrapper around the compiled RoI-align kernels in ``_C``.

    NOTE(review): ``_C`` is not imported in the visible part of this file, and
    the detector below calls ``torchvision.ops.roi_align`` instead of this op;
    confirm ``_C`` exists before using ``roi_align``.
    """

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
        """Align-pool ``roi`` regions of ``input`` to ``output_size``.

        ``output_size`` may be an int or an (h, w) pair; it is normalised with
        ``_pair`` and the normalised value is what the kernel receives.
        """
        ctx.save_for_backward(roi)
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        # Only the input shape is kept: the backward kernel needs the size of
        # the gradient buffer, not the input values.
        ctx.input_shape = input.size()
        output = _C.roi_align_forward(
            input, roi, spatial_scale, ctx.output_size[0], ctx.output_size[1], sampling_ratio
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Route ``grad_output`` back to ``input``; rois and settings get no gradient."""
        rois, = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        sampling_ratio = ctx.sampling_ratio
        bs, ch, h, w = ctx.input_shape
        grad_input = _C.roi_align_backward(
            grad_output,
            rois,
            spatial_scale,
            output_size[0],
            output_size[1],
            bs,
            ch,
            h,
            w,
            sampling_ratio,
        )
        return grad_input, None, None, None, None

# Functional entry point: roi_align(input, rois, output_size, spatial_scale, sampling_ratio).
roi_align = _ROIAlign.apply


def bbox_decode(rois, bbox_pred, batch_size, class_agnostic, classes, im_info, training, cls_prob):
    boxes = rois.data[:, :, 1:5]
    if cfg['TEST_BBOX_REG']:
        # Apply bounding-box regression deltas
        box_deltas = bbox_pred.data
        if cfg['TRAIN_BBOX_NORMALIZE_TARGETS_PRECOMPUTED']:
            # Optionally normalize targets by a precomputed mean and stdev
            if class_agnostic or training:
                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_STDS']).cuda() \
                             + torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_MEANS']).cuda()
                box_deltas = box_deltas.view(batch_size, -1, 4)
            else:
                cls_prob[:,0]=0
                bbox_pred_cls_argmax=torch.argmax(cls_prob,dim=1)
                # print(bbox_pred_cls_argmax)
                for i in range(bbox_pred.size(1)):
                    bbox_pred_cls_argmax[i]=bbox_pred_cls_argmax[i]+i*classes

                bbox_pred_max=bbox_pred.view(batch_size,-1,4)
                bbox_pred_max=torch.index_select(bbox_pred_max,1,bbox_pred_cls_argmax)
                box_deltas = bbox_pred_max.data

                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_STDS']).cuda() \
                             + torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_MEANS']).cuda()
                box_deltas = box_deltas.view(batch_size, -1,4)

        pred_boxes = bbox_transform_inv(boxes, box_deltas, batch_size)
        pred_boxes = clip_boxes(pred_boxes, im_info, batch_size)
    else:
        # Simply repeat the boxes, once for each class
        pred_boxes = boxes

    pred_boxes = pred_boxes.view(batch_size,-1,4)
    ret_boxes = torch.zeros(pred_boxes.size(0), pred_boxes.size(1), 5).cuda()
    ret_boxes[:, :, 1:] = pred_boxes
    for b in range(batch_size):
        ret_boxes[b, :, 0] = b

    return ret_boxes


class ROIAlignAvg(nn.Module):
    """RoI align followed by a 2x2, stride-1 average pool.

    An ``output_size`` of (P + 1, P + 1) therefore yields P x P features. The
    spatial scale given to ``forward`` replaces the stored one on every call.
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio):
        super(ROIAlignAvg, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    def forward(self, input, rois, spatial_scale):
        # Remember the caller's per-level scale (it is also what repr shows).
        self.spatial_scale = spatial_scale
        aligned = torchvision.ops.roi_align(
            input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
        )
        return avg_pool2d(aligned, kernel_size=2, stride=1)

    def __repr__(self):
        fields = ", ".join([
            "output_size=" + str(self.output_size),
            "spatial_scale=" + str(self.spatial_scale),
            "sampling_ratio=" + str(self.sampling_ratio),
        ])
        return "{}({})".format(self.__class__.__name__, fields)

        
class _FPN(nn.Module):
    """Three-stage Cascade R-CNN detector on a Feature Pyramid Network.

    Subclasses set ``self.dout_base_model`` before calling this constructor
    (it sizes the RPN). They also implement ``_init_modules`` and the three
    ``_head_to_tail*`` methods. ``_init_modules`` must create the backbone
    stages ``RCNN_layer0``..``RCNN_layer4``, the FPN layers
    (``RCNN_toplayer``, ``RCNN_latlayer1-3``, ``RCNN_smooth1-3``) and the
    three cascade heads. The ``_head_to_tail`` methods are ``_head_to_tail``,
    ``_head_to_tail_2nd`` and ``_head_to_tail_3rd``. Call
    ``create_architecture()`` after construction.
    """
    def __init__(self, classes, class_agnostic):
        super(_FPN, self).__init__()
        self.classes = classes
        self.n_classes = len(classes)
        self.class_agnostic = class_agnostic
        # loss
        self.RCNN_loss_cls = 0
        self.RCNN_loss_bbox = 0

        # Kernel-1, stride-2 max pooling subsamples P5 into P6 (used by the RPN only).
        self.maxpool2d = nn.MaxPool2d(1, stride=2)
        # define rpn
        self.RCNN_rpn = _RPN_FPN(self.dout_base_model)
        self.RCNN_proposal_target = _ProposalTargetLayer(self.n_classes)

        # RoI align samples a (POOLING_SIZE + 1)^2 grid. ROIAlignAvg then
        # applies a 2x2, stride-1 average pool, leaving POOLING_SIZE^2 cells.
        # The 1/16 scale is only an initial value: _PyramidRoI_Feat supplies
        # the scale of each pyramid level.
        self.RCNN_roi_pool = ROIPool((cfg['POOLING_SIZE'], cfg['POOLING_SIZE']), 1.0/16.0)
        self.RCNN_roi_align = ROIAlignAvg((cfg['POOLING_SIZE']+1, cfg['POOLING_SIZE']+1), 1.0/16.0, 0)
        self.grid_size = cfg['POOLING_SIZE'] * 2 if cfg['CROP_RESIZE_WITH_MAX_POOL'] else cfg['POOLING_SIZE']

    def _init_weights(self):
        """Initialise the FPN, RPN and cascade-head layers with small Gaussians."""
        def normal_init(m, mean, stddev, truncated=False):
            """
            weight initializer: truncated normal and random normal.
            """
            # m is a layer with .weight and .bias parameters
            if truncated:
                # NOTE(review): this branch leaves m.bias untouched, unlike the
                # plain-normal branch below which zeroes it. Confirm intent.
                m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean) # not a perfect approximation
            else:
                m.weight.data.normal_(mean, stddev)
                m.bias.data.zero_()

        # NOTE(review): mean/stddev/truncated are ignored, and only m's own
        # class name is inspected. For an nn.Sequential (RCNN_top is one,
        # presumably RCNN_top_2nd/3rd too) neither branch matches, so these
        # calls leave the heads at their default initialisation. Confirm intent.
        def weights_init(m, mean, stddev, truncated=False):
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
            elif classname.find('BatchNorm') != -1:
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

        normal_init(self.RCNN_toplayer, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_smooth1, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_smooth2, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_smooth3, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_latlayer1, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_latlayer2, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_latlayer3, 0, 0.01, cfg['TRAIN_TRUNCATED'])

        normal_init(self.RCNN_rpn.RPN_Conv, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_rpn.RPN_cls_score, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_rpn.RPN_bbox_pred, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_cls_score, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_bbox_pred, 0, 0.001, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_cls_score_2nd, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_bbox_pred_2nd, 0, 0.001, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_cls_score_3rd, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        normal_init(self.RCNN_bbox_pred_3rd, 0, 0.001, cfg['TRAIN_TRUNCATED'])
        weights_init(self.RCNN_top, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        weights_init(self.RCNN_top_2nd, 0, 0.01, cfg['TRAIN_TRUNCATED'])
        weights_init(self.RCNN_top_3rd, 0, 0.01, cfg['TRAIN_TRUNCATED'])

    def create_architecture(self):
        """Build the subclass-defined modules, then initialise their weights."""
        self._init_modules()
        self._init_weights()

    def _upsample_add(self, x, y):
        '''Upsample x to y's spatial size and add the two feature maps.

        Args:
          x: top feature map to be upsampled.
          y: lateral feature map.
        Returns:
          the summed feature map, with y's spatial size.

        Nearest-neighbour upsampling by a fixed factor of 2 does not always
        match the lateral size when the input size is odd. For example:
        input [N,_,15,15] -> conv feature map [N,_,8,8] -> x2 upsample
        [N,_,16,16]. Bilinear ``F.interpolate`` to an explicit size avoids that.
        '''
        _,_,H,W = y.size()
        return F.interpolate(x, size=(H,W), mode='bilinear',align_corners=True) + y

    def _build_pyramid(self, im_data):
        """Run the backbone and the FPN top-down path.

        Returns:
            (rpn_feature_maps, mrcnn_feature_maps): [P2, P3, P4, P5, P6] for
            the RPN and [P2, P3, P4, P5] for RoI pooling.
        """
        # Bottom-up
        c1 = self.RCNN_layer0(im_data)
        c2 = self.RCNN_layer1(c1)
        c3 = self.RCNN_layer2(c2)
        c4 = self.RCNN_layer3(c3)
        c5 = self.RCNN_layer4(c4)
        # Top-down: each merge with a lateral connection is followed by a smoothing conv.
        p5 = self.RCNN_toplayer(c5)
        p4 = self.RCNN_smooth1(self._upsample_add(p5, self.RCNN_latlayer1(c4)))
        p3 = self.RCNN_smooth2(self._upsample_add(p4, self.RCNN_latlayer2(c3)))
        p2 = self.RCNN_smooth3(self._upsample_add(p3, self.RCNN_latlayer3(c2)))
        p6 = self.maxpool2d(p5)
        return [p2, p3, p4, p5, p6], [p2, p3, p4, p5]

    def _PyramidRoI_Feat(self, feat_maps, rois, im_info):
        """Pool every roi from the pyramid level that matches its size.

        A roi of size w x h goes to level
        ``round(4 + log2(sqrt(w * h) / 224))``, clamped to [2, 5]
        (``feat_maps[0]`` is level 2). Each level is pooled at scale
        ``feature_map_height / im_info[0][0]``. The pooled features are
        returned in the row order of ``rois``.

        Args:
            feat_maps: [P2, P3, P4, P5] feature maps.
            rois: (K, 5) rois; columns 1:5 are x1, y1, x2, y2.
            im_info: image info; ``im_info[0][0]`` is the image height.

        Raises:
            ValueError: if ``cfg['POOLING_MODE']`` is neither 'align' nor 'pool'.
        """
        mode = cfg['POOLING_MODE']
        if mode not in ('align', 'pool'):
            raise ValueError("unsupported POOLING_MODE: %r" % (mode,))

        h = rois.data[:, 4] - rois.data[:, 2] + 1
        w = rois.data[:, 3] - rois.data[:, 1] + 1
        roi_level = torch.log2(torch.sqrt(h * w) / 224.0)
        roi_level = torch.round(roi_level + 4)
        roi_level[roi_level < 2] = 2
        roi_level[roi_level > 5] = 5

        roi_pool_feats = []
        box_to_levels = []
        for i, l in enumerate(range(2, 6)):
            # view(-1) keeps a 1-D index even when a level holds a single roi.
            idx_l = (roi_level == l).nonzero().view(-1)
            if idx_l.numel() == 0:
                continue
            box_to_levels.append(idx_l)
            scale = float(feat_maps[i].size(2) / im_info[0][0])
            if mode == 'align':
                feat = self.RCNN_roi_align(feat_maps[i], rois[idx_l], scale)
            else:
                # ROIPool.forward reads spatial_scale, so that is what must change.
                self.RCNN_roi_pool.spatial_scale = scale
                feat = self.RCNN_roi_pool(feat_maps[i], rois[idx_l])
            roi_pool_feats.append(feat)

        roi_pool_feat = torch.cat(roi_pool_feats, 0)
        box_to_level = torch.cat(box_to_levels, 0)
        # Features are grouped by level; sorting their source indices restores roi order.
        _, order = torch.sort(box_to_level)
        return roi_pool_feat[order]

    def _prepare_rois(self, rois, gt_boxes, num_boxes, stage=None):
        """Flatten rois to (K, 5) and, in training mode, sample their targets.

        In training mode the rois go through ``RCNN_proposal_target``, with
        ``stage`` forwarded only when it is given. The sampled rois, labels and
        regression targets/weights are then flattened to one row per roi. In
        eval mode only the flattened rois are returned; the other entries are
        ``None``.

        Returns:
            (rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws)
        """
        if not self.training:
            return rois.view(-1, 5), None, None, None, None

        if stage is None:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
        else:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes, stage=stage)
        rois, rois_label, _gt_assign, rois_target, rois_inside_ws, rois_outside_ws = roi_data

        rois_label = rois_label.view(-1).long()
        rois_target = rois_target.view(-1, rois_target.size(2))
        rois_inside_ws = rois_inside_ws.view(-1, rois_inside_ws.size(2))
        rois_outside_ws = rois_outside_ws.view(-1, rois_outside_ws.size(2))
        return rois.view(-1, 5), rois_label, rois_target, rois_inside_ws, rois_outside_ws

    @staticmethod
    def _select_label_deltas(bbox_pred, rois_label):
        """Keep, for each roi, the 4 regression outputs of its labelled class."""
        num_rois = rois_label.size(0)
        per_class = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
        index = rois_label.long().view(num_rois, 1, 1).expand(num_rois, 1, 4)
        return torch.gather(per_class, 1, index).squeeze(1)

    def _run_stage(self, feat_maps, rois, im_info, rois_label, rois_target,
                   rois_inside_ws, rois_outside_ws, head, bbox_layer, cls_layer):
        """Pool features for ``rois`` and evaluate one cascade head.

        Returns:
            (roi_pool_feat, bbox_pred, cls_prob, loss_cls, loss_bbox).
            ``bbox_pred`` and ``cls_prob`` have one row per roi. With
            class-specific regression in training mode, only the deltas of
            each roi's label are kept. Both losses are 0 outside training.
        """
        roi_pool_feat = self._PyramidRoI_Feat(feat_maps, rois, im_info)
        pooled_feat = head(roi_pool_feat)

        bbox_pred = bbox_layer(pooled_feat)
        if self.training and not self.class_agnostic:
            bbox_pred = self._select_label_deltas(bbox_pred, rois_label)

        cls_score = cls_layer(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        loss_cls = 0
        loss_bbox = 0
        if self.training:
            loss_cls = F.cross_entropy(cls_score, rois_label)
            # loss (smooth l1) for bounding box regression
            loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)
        return roi_pool_feat, bbox_pred, cls_prob, loss_cls, loss_bbox

    def _ensemble_cls_prob(self, roi_pool_feat, cls_prob_3rd, batch_size):
        """Average the stage-3 probabilities with those of the stage-1 and
        stage-2 classifiers, all evaluated on the stage-3 RoI features."""
        cls_prob_1st = F.softmax(self.RCNN_cls_score(self._head_to_tail(roi_pool_feat)), 1)
        cls_prob_1st = cls_prob_1st.view(batch_size, -1, cls_prob_1st.size(1))
        cls_prob_2nd = F.softmax(self.RCNN_cls_score_2nd(self._head_to_tail_2nd(roi_pool_feat)), 1)
        cls_prob_2nd = cls_prob_2nd.view(batch_size, -1, cls_prob_2nd.size(1))
        return (cls_prob_1st + cls_prob_2nd + cls_prob_3rd) / 3

    def forward(self, im_data, im_info, gt_boxes, num_boxes):
        """Run the RPN and the three cascade stages.

        Returns:
            A 12-tuple:
            (rois, cls_prob, bbox_pred_3rd, rpn_loss_cls, rpn_loss_bbox,
            RCNN_loss_cls, RCNN_loss_bbox, RCNN_loss_cls_2nd,
            RCNN_loss_bbox_2nd, RCNN_loss_cls_3rd, RCNN_loss_bbox_3rd,
            rois_label).

            - ``rois`` and ``bbox_pred_3rd`` come from stage 3 and are shaped
              (batch, N, .).
            - ``cls_prob`` is shaped (batch, N, n_classes). In eval mode it is
              the average of the three classifiers on the stage-3 rois.
            - The losses are 0 in eval mode.
            - ``rois_label`` is (batch, N) in training and ``None`` in eval.
        """
        batch_size = im_data.size(0)

        im_info = im_info.data
        gt_boxes = gt_boxes.data
        num_boxes = num_boxes.data

        rpn_feature_maps, mrcnn_feature_maps = self._build_pyramid(im_data)

        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(rpn_feature_maps, im_info, gt_boxes, num_boxes)
        if not self.training:
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        # ---- stage 1: RPN proposals ----
        rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = \
            self._prepare_rois(rois, gt_boxes, num_boxes)
        _, bbox_pred, cls_prob, RCNN_loss_cls, RCNN_loss_bbox = self._run_stage(
            mrcnn_feature_maps, rois, im_info, rois_label, rois_target,
            rois_inside_ws, rois_outside_ws,
            self._head_to_tail, self.RCNN_bbox_pred, self.RCNN_cls_score)
        rois = rois.view(batch_size, -1, rois.size(1))
        bbox_pred = bbox_pred.view(batch_size, -1, bbox_pred.size(1))
        rois = bbox_decode(rois, bbox_pred, batch_size, self.class_agnostic, self.n_classes,
                           im_info, self.training, cls_prob)

        # ---- stage 2: boxes refined by stage 1 ----
        rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = \
            self._prepare_rois(rois, gt_boxes, num_boxes, stage=2)
        _, bbox_pred, cls_prob_2nd, RCNN_loss_cls_2nd, RCNN_loss_bbox_2nd = self._run_stage(
            mrcnn_feature_maps, rois, im_info, rois_label, rois_target,
            rois_inside_ws, rois_outside_ws,
            self._head_to_tail_2nd, self.RCNN_bbox_pred_2nd, self.RCNN_cls_score_2nd)
        rois = rois.view(batch_size, -1, rois.size(1))
        bbox_pred_2nd = bbox_pred.view(batch_size, -1, bbox_pred.size(1))
        rois = bbox_decode(rois, bbox_pred_2nd, batch_size, self.class_agnostic, self.n_classes,
                           im_info, self.training, cls_prob_2nd)

        # ---- stage 3: boxes refined by stage 2 ----
        rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = \
            self._prepare_rois(rois, gt_boxes, num_boxes, stage=3)
        roi_pool_feat, bbox_pred, cls_prob_3rd, RCNN_loss_cls_3rd, RCNN_loss_bbox_3rd = self._run_stage(
            mrcnn_feature_maps, rois, im_info, rois_label, rois_target,
            rois_inside_ws, rois_outside_ws,
            self._head_to_tail_3rd, self.RCNN_bbox_pred_3rd, self.RCNN_cls_score_3rd)
        rois = rois.view(batch_size, -1, rois.size(1))
        cls_prob_3rd = cls_prob_3rd.view(batch_size, -1, cls_prob_3rd.size(1))
        bbox_pred_3rd = bbox_pred.view(batch_size, -1, bbox_pred.size(1))

        if self.training:
            rois_label = rois_label.view(batch_size, -1)
            cls_prob_3rd_avg = cls_prob_3rd
        else:
            cls_prob_3rd_avg = self._ensemble_cls_prob(roi_pool_feat, cls_prob_3rd, batch_size)

        return rois, cls_prob_3rd_avg, bbox_pred_3rd, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, RCNN_loss_cls_2nd, RCNN_loss_bbox_2nd, RCNN_loss_cls_3rd, RCNN_loss_bbox_3rd, rois_label

        
def conv3x3(in_planes, out_planes, stride=1):
  """Return a bias-free 3x3 convolution; padding 1 preserves size at stride 1."""
  layer = nn.Conv2d(
    in_planes,
    out_planes,
    kernel_size=3,
    stride=stride,
    padding=1,
    bias=False,
  )
  return layer


# class BasicBlock(nn.Module):
  # expansion = 1

  # def __init__(self, inplanes, planes, stride=1, downsample=None):
    # super(BasicBlock, self).__init__()
    # self.conv1 = conv3x3(inplanes, planes, stride)
    # self.bn1 = nn.BatchNorm2d(planes)
    # self.relu = nn.ReLU(inplace=True)
    # self.conv2 = conv3x3(planes, planes)
    # self.bn2 = nn.BatchNorm2d(planes)
    # self.downsample = downsample
    # self.stride = stride

  # def forward(self, x):
    # residual = x

    # out = self.conv1(x)
    # out = self.bn1(out)
    # out = self.relu(out)

    # out = self.conv2(out)
    # out = self.bn2(out)

    # if self.downsample is not None:
      # residual = self.downsample(x)

    # out += residual
    # out = self.relu(out)

    # return out


class Bottleneck(nn.Module):
  """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

  The stride is applied by the first 1x1 convolution; the 3x3 convolution
  always uses stride 1. The block outputs ``planes * 4`` channels.
  ``downsample`` (optional) maps the input to the output shape for the
  shortcut.
  """
  expansion = 4

  def __init__(self, inplanes, planes, stride=1, downsample=None):
    super(Bottleneck, self).__init__()
    widened = planes * 4
    # Submodules are registered in a fixed order, which fixes the state_dict
    # key order and the order in which their parameters are initialised.
    self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)
    self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm2d(widened)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    shortcut = x if self.downsample is None else self.downsample(x)

    y = self.relu(self.bn1(self.conv1(x)))
    y = self.relu(self.bn2(self.conv2(y)))
    y = self.bn3(self.conv3(y))

    y += shortcut
    return self.relu(y)


class ResNet(nn.Module):
  """ResNet trunk built from ``block`` with ``layers[i]`` blocks per stage.

  The stem max pool uses padding 0 with ``ceil_mode=True`` (marked "change"
  in the original code). Stages 2-4 halve the resolution. ``avgpool`` is a
  fixed 7x7 average pool feeding the ``fc`` classifier.
  """

  def __init__(self, block, layers, num_classes=1000):
    super(ResNet, self).__init__()
    # Running channel count consumed and updated by _make_layer.
    self.inplanes = 64
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
    # Registered in order layer1..layer4.
    for stage, (width, step) in enumerate(zip((64, 128, 256, 512), (1, 2, 2, 2))):
      setattr(self, 'layer%d' % (stage + 1),
              self._make_layer(block, width, layers[stage], stride=step))
    self.avgpool = nn.AvgPool2d(7)
    self.fc = nn.Linear(512 * block.expansion, num_classes)

    # He-style init for convolutions (fan = k_h * k_w * out_channels);
    # batch norms start as the identity.
    for module in self.modules():
      if isinstance(module, nn.Conv2d):
        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        module.weight.data.normal_(0, math.sqrt(2. / fan_out))
      elif isinstance(module, nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()

  def _make_layer(self, block, planes, blocks, stride=1):
    """Stack ``blocks`` blocks; the first may change stride/width via a projection."""
    out_channels = planes * block.expansion
    projection = None
    if stride != 1 or self.inplanes != out_channels:
      projection = nn.Sequential(
        nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(out_channels),
      )

    stack = [block(self.inplanes, planes, stride, projection)]
    self.inplanes = out_channels
    stack.extend(block(self.inplanes, planes) for _ in range(1, blocks))
    return nn.Sequential(*stack)

  def forward(self, x):
    for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                  self.layer1, self.layer2, self.layer3, self.layer4,
                  self.avgpool):
      x = stage(x)
    return self.fc(x.view(x.size(0), -1))

def _resnet_from_spec(layers, arch, pretrained):
  """Build a Bottleneck ResNet with ``layers`` blocks per stage.

  When ``pretrained`` is true, weights are fetched with
  ``model_zoo.load_url(model_urls[arch])``.
  NOTE(review): ``model_zoo`` and ``model_urls`` are not defined in the
  visible part of this file; confirm they exist before passing pretrained=True.
  """
  net = ResNet(Bottleneck, layers)
  if pretrained:
    net.load_state_dict(model_zoo.load_url(model_urls[arch]))
  return net


def resnet50(pretrained=False):
  """Constructs a ResNet-50 model.
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  return _resnet_from_spec([3, 4, 6, 3], 'resnet50', pretrained)


def resnet101(pretrained=False):
  """Constructs a ResNet-101 model.
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  return _resnet_from_spec([3, 4, 23, 3], 'resnet101', pretrained)
        
        
class resnet(_FPN):
  """ResNet-50/101 backbone + FPN layers + three cascade R-CNN heads.

  Args:
    classes: class names, forwarded to ``_FPN.__init__``.
    num_layers (int): backbone depth; only 50 and 101 are supported.
    pretrained (bool): load backbone weights from ``self.model_path`` when the
      modules are built in ``_init_modules``.
    class_agnostic (bool): predict 4 box deltas per ROI instead of
      4 per class.
  """

  def __init__(self, classes, num_layers=101, pretrained=False, class_agnostic=False):
    self.dout_base_model = 256
    self.pretrained = pretrained
    self.class_agnostic = class_agnostic
    self.num_layers = num_layers
    # Local pretrained-weight file per supported depth; None otherwise
    # (_init_modules rejects unsupported depths with a ValueError).
    self.model_path = {
        101: 'data/pretrained_model/resnet101.pth',
        50: 'data/pretrained_model/resnet50.pth',
    }.get(num_layers)

    _FPN.__init__(self, classes, class_agnostic)

  def _backbone_layers(self):
    """Return the five wrapped ResNet stages, stem first."""
    return (self.RCNN_layer0, self.RCNN_layer1, self.RCNN_layer2,
            self.RCNN_layer3, self.RCNN_layer4)

  @staticmethod
  def _make_top_head():
    """Build one ROI head: POOLING_SIZE conv 256->1024, ReLU, 1x1 conv, ReLU."""
    return nn.Sequential(
        nn.Conv2d(256, 1024, kernel_size=cfg['POOLING_SIZE'], stride=cfg['POOLING_SIZE'], padding=0),
        nn.ReLU(True),
        nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
        nn.ReLU(True)
    )

  def _init_modules(self):
    """Create backbone, FPN, ROI heads and predictors; freeze early layers.

    Raises:
      ValueError: if ``self.num_layers`` is not 50 or 101.
    """
    if self.num_layers == 101:
      backbone = resnet101()
    elif self.num_layers == 50:
      backbone = resnet50()
    else:
      raise ValueError(
          "resnet: unsupported num_layers=%r (expected 50 or 101)" % (self.num_layers,))

    if self.pretrained:
      print("Loading pretrained weights from %s" %(self.model_path))
      state_dict = torch.load(self.model_path)
      # Build the target state dict once instead of once per key.
      model_keys = backbone.state_dict()
      backbone.load_state_dict({k: v for k, v in state_dict.items() if k in model_keys})

    self.RCNN_layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
    self.RCNN_layer1 = nn.Sequential(backbone.layer1)
    self.RCNN_layer2 = nn.Sequential(backbone.layer2)
    self.RCNN_layer3 = nn.Sequential(backbone.layer3)
    self.RCNN_layer4 = nn.Sequential(backbone.layer4)

    # Top layer: reduce the last stage's 2048 channels to 256.
    self.RCNN_toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)

    # Smooth layers
    self.RCNN_smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
    self.RCNN_smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
    self.RCNN_smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

    # Lateral layers: project 1024/512/256-channel stage outputs to 256.
    self.RCNN_latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
    self.RCNN_latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)
    self.RCNN_latlayer3 = nn.Conv2d( 256, 256, kernel_size=1, stride=1, padding=0)

    # ROI Pool feature downsampling
    self.RCNN_roi_feat_ds = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)

    # One ROI head per cascade stage. Attribute names (and hence state-dict
    # keys) and creation order are unchanged so existing checkpoints load.
    self.RCNN_top = self._make_top_head()
    self.RCNN_top_2nd = self._make_top_head()
    self.RCNN_top_3rd = self._make_top_head()

    # self.n_classes is presumably set by _FPN.__init__ -- confirm there.
    bbox_out = 4 if self.class_agnostic else 4 * self.n_classes
    self.RCNN_cls_score = nn.Linear(1024, self.n_classes)
    self.RCNN_bbox_pred = nn.Linear(1024, bbox_out)
    self.RCNN_cls_score_2nd = nn.Linear(1024, self.n_classes)
    self.RCNN_bbox_pred_2nd = nn.Linear(1024, bbox_out)
    self.RCNN_cls_score_3rd = nn.Linear(1024, self.n_classes)
    self.RCNN_bbox_pred_3rd = nn.Linear(1024, bbox_out)

    # Fix blocks: conv1 and bn1 of the stem are always frozen.
    for p in self.RCNN_layer0[0].parameters(): p.requires_grad = False
    for p in self.RCNN_layer0[1].parameters(): p.requires_grad = False

    # Freeze stages 1..RESNET_FIXED_BLOCKS (at most stage 3).
    fixed_blocks = cfg['RESNET_FIXED_BLOCKS']
    for depth, layer in ((3, self.RCNN_layer3), (2, self.RCNN_layer2), (1, self.RCNN_layer1)):
      if fixed_blocks >= depth:
        for p in layer.parameters(): p.requires_grad = False

    def set_bn_fix(m):
      # BatchNorm affine parameters are never trained.
      if m.__class__.__name__.find('BatchNorm') != -1:
        for p in m.parameters(): p.requires_grad = False

    for layer in self._backbone_layers():
      layer.apply(set_bn_fix)

  def train(self, mode=True):
    """Set train/eval mode while keeping frozen backbone parts in eval mode.

    ``nn.Module.train`` already switches every child (FPN layers, heads,
    stages 2-4) to ``mode``; in training mode the stem and stage 1 are then
    put back into eval, and every backbone BatchNorm is kept in eval so its
    running statistics are not updated.

    Returns:
      self, matching the ``nn.Module.train`` contract (so ``eval()`` and
      chained calls return the module instead of None).
    """
    nn.Module.train(self, mode)
    if mode:
      self.RCNN_layer0.eval()
      self.RCNN_layer1.eval()

      def set_bn_eval(m):
        if m.__class__.__name__.find('BatchNorm') != -1:
          m.eval()

      for layer in self._backbone_layers():
        layer.apply(set_bn_eval)
    return self

  @staticmethod
  def _pool_to_vector(head, pool5):
    """Apply ``head`` to pooled ROI features, then average dims 3 and 2."""
    return head(pool5).mean(3).mean(2)

  def _head_to_tail(self, pool5):
    return self._pool_to_vector(self.RCNN_top, pool5)

  def _head_to_tail_2nd(self, pool5):
    return self._pool_to_vector(self.RCNN_top_2nd, pool5)

  def _head_to_tail_3rd(self, pool5):
    return self._pool_to_vector(self.RCNN_top_3rd, pool5)

#-----------------------------------------------------------------------#

P_img_ext = ["jpg", "png"]


def file_list(path, allfile=None):
    """Recursively collect image files below ``path``.

    Args:
        path: directory to scan; sub-directories are searched recursively.
        allfile: list that matching paths are appended to in place. A new
            list is created when omitted.

    Returns:
        The list that was appended to (``allfile`` itself when given), so
        callers may use either the return value or the in-place result.

    A file matches when its extension, compared case-insensitively, is in
    ``P_img_ext`` (so ``.JPG`` is picked up too). Entries are visited in
    sorted order so the result is deterministic across filesystems, since
    ``os.listdir`` order is arbitrary.
    """
    if allfile is None:
        allfile = []
    wanted = {ext.lower() for ext in P_img_ext}
    for filename in sorted(os.listdir(path)):
        filepath = os.path.join(path, filename)
        if os.path.isdir(filepath):
            file_list(filepath, allfile)
        elif os.path.splitext(filename)[1][1:].lower() in wanted:
            # Keep the path exactly as built: stripping whitespace could
            # point it at a file that does not exist.
            allfile.append(filepath)
    return allfile
    

def im_list_to_blob(ims):
    """Pack a list of 3-channel images into one zero-padded float32 batch.

    The batch is ``(len(ims), max_height, max_width, 3)``; each image is
    copied into the top-left corner of its slot and the remainder stays 0.
    Images are copied as-is (any mean subtraction / channel ordering must
    already have been applied by the caller).
    """
    shapes = np.array([im.shape for im in ims])
    max_h, max_w = shapes.max(axis=0)[:2]
    blob = np.zeros((len(ims), max_h, max_w, 3), dtype=np.float32)
    for slot, im in enumerate(ims):
        h, w = im.shape[:2]
        blob[slot, :h, :w, :] = im
    return blob
    
    
def _get_image_blob(im):
  """Mean-subtract ``im`` and rescale it once per test scale.

  Arguments:
    im (ndarray): colour image. ``cfg['PIXEL_MEANS']`` is subtracted per
      channel, so the channel order must match the one those means were
      computed for (documented here as BGR).
  Returns:
    blob (ndarray): float32 batch from ``im_list_to_blob``, one entry per
      value in ``cfg['TEST_SCALES']``
    im_scale_factors (ndarray): the resize factor used for each entry
  """
  centered = im.astype(np.float32, copy=True)
  centered -= cfg['PIXEL_MEANS']

  short_side = np.min(centered.shape[0:2])
  long_side = np.max(centered.shape[0:2])

  resized_images = []
  scales = []
  for target_size in cfg['TEST_SCALES']:
    # Scale the short side to target_size, unless that would push the
    # (rounded) long side above TEST_MAX_SIZE; then cap by the long side.
    scale = float(target_size) / float(short_side)
    if np.round(scale * long_side) > cfg['TEST_MAX_SIZE']:
      scale = float(cfg['TEST_MAX_SIZE']) / float(long_side)
    resized = cv2.resize(centered, None, None, fx=scale, fy=scale,
                         interpolation=cv2.INTER_LINEAR)
    scales.append(scale)
    resized_images.append(resized)

  return im_list_to_blob(resized_images), np.array(scales)

def soft_nms(dets, box_scores, sigma=0.5, thresh=0.001, cuda=0):
    """
    Gaussian Soft-NMS implemented with PyTorch tensors.

    # Arguments
        dets:        (N, 4) box tensor. Documented as [y1, x1, y2, x2]; because
                     IoU is symmetric in x and y, [x1, y1, x2, y2] boxes give
                     identical results.
        box_scores:  (N,) score tensor. It is NOT modified; decay is applied
                     to a private copy.
        sigma:       variance of the Gaussian decay exp(-IoU^2 / sigma)
        thresh:      boxes whose decayed score is <= thresh are dropped
        cuda:        kept for backward compatibility only; all working
                     tensors are created on ``dets.device``.
    # Return
        int32 tensor of indices into the input, in selection order
        (highest remaining score first), for the boxes that survive.
    """
    N = dets.shape[0]
    # Append each box's original index as a 5th column so it follows swaps.
    indexes = torch.arange(0, N, dtype=torch.float, device=dets.device).view(N, 1)
    dets = torch.cat((dets, indexes), dim=1)

    scores = box_scores.clone()
    # Column 0/2 and 1/3 are the two coordinate axes ("+1" pixel convention).
    areas = (dets[:, 3] - dets[:, 1] + 1) * (dets[:, 2] - dets[:, 0] + 1)

    for i in range(N):
        pos = i + 1

        # Move the highest-scoring remaining box into slot i.
        if pos < N:
            maxscore, maxpos = torch.max(scores[pos:], dim=0)
            if scores[i] < maxscore:
                j = maxpos.item() + pos
                swapped = [j, i]
                dets[[i, j]] = dets[swapped]
                scores[[i, j]] = scores[swapped]
                areas[[i, j]] = areas[swapped]

        # IoU of box i against every later box, computed on-device.
        yy1 = torch.max(dets[i, 0], dets[pos:, 0])
        xx1 = torch.max(dets[i, 1], dets[pos:, 1])
        yy2 = torch.min(dets[i, 2], dets[pos:, 2])
        xx2 = torch.min(dets[i, 3], dets[pos:, 3])
        w = (xx2 - xx1 + 1).clamp(min=0.0)
        h = (yy2 - yy1 + 1).clamp(min=0.0)
        inter = w * h
        ovr = inter / (areas[i] + areas[pos:] - inter)

        # Gaussian decay
        weight = torch.exp(-(ovr * ovr) / sigma)
        scores[pos:] = weight * scores[pos:]

    # Rows of dets are now in selection order; column 4 holds the input index.
    return dets[:, 4][scores > thresh].int()
  
def soft_nms_old(det_proposal, detr_scores, method, thresh, Nt, sigma=0.5):
    '''
    Soft-NMS (or plain NMS) implemented with numpy.

    :param det_proposal: (N, 4) array of boxes [x1, y1, x2, y2]
    :param detr_scores: (N, 1) or (N,) score array; not modified
    :param method: 1 = linear decay, 2 = gaussian decay, anything else = hard NMS
    :param thresh: boxes whose decayed score falls below this are discarded
    :param Nt: IoU threshold used by the linear and hard-NMS methods
    :param sigma: variance of the gaussian decay
    :return: list of indices into det_proposal, in selection order
    '''
    x1 = det_proposal[:, 0]
    y1 = det_proposal[:, 1]
    x2 = det_proposal[:, 2]
    y2 = det_proposal[:, 3]
    # Private float copy: decaying must not write through into the caller's array.
    scores = np.array(detr_scores, dtype=np.float64)
    if scores.ndim > 1:
        scores = scores.squeeze(1)

    areas = (y2 - y1 + 1.) * (x2 - x1 + 1.)
    orders = scores.argsort()[::-1]
    keep = []

    while orders.size > 0:
        i = orders[0]
        keep.append(i)
        rest = orders[1:]

        # IoU of box i against all remaining candidates at once.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(xx2 - xx1 + 1., 0.)
        h = np.maximum(yy2 - yy1 + 1., 0.)
        inter = w * h
        overlap = inter / (areas[i] + areas[rest] - inter)

        if method == 1:  # linear
            weight = np.where(overlap > Nt, 1 - overlap, 1.)
        elif method == 2:  # gaussian
            weight = np.exp(-(overlap * overlap) / sigma)
        else:  # original NMS
            weight = np.where(overlap > Nt, 0., 1.)
        scores[rest] = weight * scores[rest]

        # Drop boxes decayed below thresh, then re-rank the survivors by their
        # *current* score so the next pick is the true maximum (standard
        # Soft-NMS). For hard NMS scores are unchanged and the order is kept.
        rest = rest[~(scores[rest] < thresh)]
        orders = rest[np.argsort(-scores[rest], kind='stable')]
    return keep
    
    
def vis_detections(im, class_name, dets, thresh=0.0):
    """Draw detections scoring above ``thresh`` onto ``im`` (in place).

    Each row of ``dets`` is read as four box coordinates followed, in its
    last column, by a score. Boxes are drawn with colour (0, 204, 0) and
    thickness 4; a "<class_name>: <score>" label with colour (0, 0, 255) is
    written 15 px below the box's first corner. Returns ``im``.
    """
    for row_idx in range(dets.shape[0]):
        det = dets[row_idx]
        corners = tuple(int(np.round(coord)) for coord in det[:4])
        score = det[-1]
        if not score > thresh:
            continue
        cv2.rectangle(im, corners[0:2], corners[2:4], (0, 204, 0), 4)
        label_origin = (corners[0], corners[1] + 15)
        cv2.putText(im, '%s: %.3f' % (class_name, score), label_origin,
                    cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), thickness=2)
    return im

    
    
if __name__ == '__main__':
    # ---------------- inference parameters ----------------
    num_layers = 50
    num_session = 1
    num_epoch = 72
    checkpoint = 1686
    thresh_score_final = 0.05           # per-class score filter before soft-NMS
    thresh_score_final_soft_nms = 0.5   # minimum decayed score kept by soft_nms
    model_dir = os.path.join("models", "res" + str(num_layers), "pascal_voc")
    data_dir = '../input/global-wheat-detection/test'

    Flag_vis = True  # draw detections on a copy of every image

    # Checkpoint file: cascade_fpn_<session>_<epoch>_<checkpoint>.pth
    load_name = os.path.join("../input/mymodels", 'cascade_fpn_{}_{}_{}.pth'.format(num_session, num_epoch, checkpoint))
    # load_name = 'cascade_fpn_{}_{}_{}.pth'.format(num_session, num_epoch, checkpoint)

    # Network
    FPN = resnet(classes, num_layers, pretrained=False, class_agnostic=False)
    FPN.create_architecture()

    # Load weights (``checkpoint`` is rebound from the step number to the loaded dict).
    checkpoint = torch.load(load_name)
    FPN.load_state_dict(checkpoint['model'])

    FPN.cuda()
    FPN.eval()

    print("load checkpoin---->", load_name)

    # get test images
    imgs_list = []
    file_list(data_dir, imgs_list)

    # Box-decoding settings are the same for every image: resolve them once.
    class_agnostic = False
    args_cuda = True
    normalize_deltas = cfg['TEST_BBOX_REG'] and cfg['TRAIN_BBOX_NORMALIZE_TARGETS_PRECOMPUTED']
    if normalize_deltas:
        bbox_stds = torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_STDS'])
        bbox_means = torch.FloatTensor(cfg['TRAIN_BBOX_NORMALIZE_MEANS'])
        if args_cuda:
            bbox_stds, bbox_means = bbox_stds.cuda(), bbox_means.cuda()
    n_delta_cols = 4 if class_agnostic else 4 * len(pascal_classes)

    submission = []
    # Inference only: run the whole loop (including the forward pass, which
    # previously ran outside no_grad) without recording an autograd graph.
    with torch.no_grad():
        for img in imgs_list:
            print(img)

            # load an image (cv2 returns BGR)
            im = cv2.imread(img)
            if Flag_vis:
                img_show = im.copy()

            prediction_string = []

            # NOTE(review): cv2.imread already yields BGR, so this flip feeds RGB
            # into _get_image_blob, whose mean subtraction is documented for BGR
            # input. Confirm this matches the channel order used in training.
            im = im[:, :, ::-1]  # BGR--->RGB
            blobs, im_scales = _get_image_blob(im)
            im_blob = blobs
            im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)

            im_data = torch.from_numpy(im_blob).permute(0, 3, 1, 2).cuda()  # NHWC --> NCHW
            im_info = torch.from_numpy(im_info_np).cuda()
            num_boxes = torch.zeros((1), dtype=torch.int64).cuda()
            gt_boxes = torch.zeros((1, 1, 5), dtype=torch.float32).cuda()

            # forward
            (rois, cls_prob, bbox_pred,
             rpn_loss_cls, rpn_loss_box,
             RCNN_loss_cls, RCNN_loss_bbox,
             RCNN_loss_cls_2nd, RCNN_loss_bbox_2nd,
             RCNN_loss_cls_3rd, RCNN_loss_bbox_3rd,
             roi_labels) = FPN(im_data, im_info, gt_boxes, num_boxes)

            # parse result
            scores = cls_prob.data
            boxes = rois.data[:, :, 1:5]

            # box transform
            if cfg['TEST_BBOX_REG']:
                # Apply bounding-box regression deltas
                box_deltas = bbox_pred.data
                if normalize_deltas:
                    # Undo the precomputed target normalisation.
                    box_deltas = (box_deltas.view(-1, 4) * bbox_stds + bbox_means).view(1, -1, n_delta_cols)
                pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
                pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
            else:
                # Repeat each ROI once per class along the last dim. The class
                # count is scores' last dim for both (R, C) and (1, R, C) layouts
                # (np.tile on a torch tensor with shape[1] was wrong for 3-D scores).
                pred_boxes = boxes.repeat(1, 1, scores.shape[-1])

            pred_boxes /= im_scales[0]   # back to original image coordinates

            scores = scores.squeeze()
            pred_boxes = pred_boxes.squeeze()

            # filter boxes (class 0 is background)
            for j in range(1, len(pascal_classes)):
                inds = torch.nonzero(scores[:, j] > thresh_score_final).view(-1)
                if inds.numel() == 0:
                    prediction_string.append("")
                    continue

                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                if class_agnostic:
                    cls_boxes = pred_boxes[inds, :]
                else:
                    cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]

                cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)[order]
                # keep indexes the score-sorted boxes, i.e. rows of cls_dets.
                keep = soft_nms(cls_boxes[order, :], cls_scores[order],
                                thresh=thresh_score_final_soft_nms, cuda=1)
                det_final = cls_dets[keep.view(-1).long()].cpu().numpy()

                if det_final.shape[0] == 0:
                    prediction_string.append("")
                else:
                    # Submission format: "score x y w h" per box.
                    for det in det_final:
                        x1, y1, x2, y2 = (int(np.round(v)) for v in det[:4])
                        prediction_string.append("{} {} {} {} {}".format(
                            float(det[-1]), x1, y1, x2 - x1, y2 - y1))

                if Flag_vis:
                    im2show = vis_detections(img_show, pascal_classes[j], det_final)

            img_name_save = os.path.splitext(os.path.basename(img))[0]
            submission.append([img_name_save, " ".join(prediction_string)])

            #if Flag_vis:
                #result_path = os.path.join("results_vis", os.path.basename(img)) # chagne here
                #cv2.imwrite(result_path, im2show)

    sample_submission = pd.DataFrame(submission, columns=["image_id", "PredictionString"])
    sample_submission.to_csv('submission.csv', index=False)

    print("\n----------------------END----------------------")
   
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    