In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import init
from core.config import cfg
from modeling.generate_anchors import generate_anchors


class autoctx_outputs(nn.Module):
    """Pick one context box around an object ROI for each neighbouring cell.

    A 1x1 convolution scores every anchor at every feature-map location.
    For each of the (up to) eight cells that surround ``object_roi`` the
    anchors located inside the cell are filtered by size and by IoU with an
    ROI-sized box centred on the cell, and the highest-scoring survivor is
    returned as that cell's context ROI.

    Row formats used throughout the class:

    * ROI / context cell: ``(image_idx, x1, y1, x2, y2)``
    * indexed anchor:     ``(flat_idx, x1, y1, x2, y2)`` where ``flat_idx``
      addresses the score tensor flattened in ``(b, c, h, w)`` order.

    NOTE(review): ``generate_anchors`` is called with
    ``stride = 1 / spatial_scale``, so the anchors it returns are presumably
    in image coordinates, whereas ``object_roi`` is expected in feature-map
    coordinates.  ``forward`` therefore multiplies the anchors by
    ``spatial_scale`` before comparing them with the ROI -- confirm this is
    the intended coordinate system.
    """

    # (dx, dy) offsets, in units of the ROI width / height, of the context
    # cells: right, left, down, top, down-right, down-left, top-right,
    # top-left.
    _CTX_OFFSETS = (
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1, 1),
        (-1, 1),
        (1, -1),
        (-1, -1),
    )

    def __init__(self, dim_in, spatial_scale):
        """Build the anchor set and the 1x1 anchor-scoring convolution.

        Args:
            dim_in: number of channels of the input feature map.
            spatial_scale: feature-map size divided by image size.
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.dim_in = dim_in
        self.spatial_scale = spatial_scale
        self.anchors = generate_anchors(
            stride=1. / spatial_scale,
            sizes=cfg.RPN.SIZES,
            aspect_ratios=cfg.RPN.ASPECT_RATIOS,
        )
        self.anchors_per_location = self.anchors.shape[0]
        # One score channel per anchor at each spatial location.
        self.cand_anchors_scores = nn.Conv2d(
            self.dim_in, self.anchors_per_location, 1, 1, 0)
        self._init_weights()

    def _init_weights(self):
        """Keep PyTorch's default initialisation of the scoring conv."""
        pass

    def forward(self, feature_map, object_roi):
        """Return the context ROIs of ``object_roi`` followed by the ROI.

        Args:
            feature_map: tensor accepted by the scoring conv, i.e.
                ``(b, dim_in, h, w)``.
            object_roi: ``(image_idx, x1, y1, x2, y2)`` in feature-map
                coordinates (array-like or tensor).

        Returns:
            ``np.ndarray`` of shape ``(num_ctx + 1, 5)``: one row per context
            cell and, last, ``object_roi`` itself.  When no anchor of a cell
            survives the filters, the cell itself is used as its context ROI.

        Raises:
            ValueError: if ``object_roi`` does not hold 5 values or its image
                index is outside the batch.

        NOTE(review): the selection runs in NumPy on detached scores, so no
        gradient flows through the returned boxes.
        NOTE(review): the size filter keeps anchors at most half the ROI in
        each dimension, so their IoU with the ROI-sized reference box is at
        most 0.25; with ``cfg.IOU_THRESHOLD >= 0.25`` every cell ends up on
        the fallback path.  Confirm the intended thresholds.
        """
        scores = self.cand_anchors_scores(feature_map)
        b, c, h, w = scores.size()
        if cfg.CTX_ACTIVATION == "softmax":
            probs = F.softmax(scores, dim=1)
        else:
            probs = torch.sigmoid(scores)
        # Flat (b*c*h*w,) probabilities, addressed by flat anchor index.
        flat_probs = probs.detach().cpu().numpy().reshape(-1)

        if torch.is_tensor(object_roi):
            object_roi = object_roi.detach().cpu().numpy()
        roi = np.asarray(object_roi, dtype=np.float64).reshape(-1)
        if roi.shape[0] != 5:
            raise ValueError(
                "object_roi must be (image_idx, x1, y1, x2, y2)")
        if not 0 <= int(roi[0]) < b:
            raise ValueError("object_roi image index is outside the batch")

        # Anchors are identical for every image, so keep a single copy and
        # address it with ``flat_idx % per_image``.
        grid_anchors = self._grid_anchors(h, w)
        per_image = c * h * w

        ctx_roi = []
        for ctx_cell in self.getting_ctx_cell(roi):
            best = None
            cand_idx = self._cell_anchor_indices(roi, ctx_cell, h, w)
            if cand_idx.size:
                anchors_in_cell = np.hstack((
                    cand_idx[:, np.newaxis].astype(np.float64),
                    grid_anchors[cand_idx % per_image],
                ))
                kept = self.filter_larger_anchors(roi, anchors_in_cell)
                kept = self.filter_small_iou_anchors(roi, kept, ctx_cell)
                if len(kept):
                    kept_idx = kept[:, 0].astype(np.int64)
                    best_idx = kept_idx[np.argmax(flat_probs[kept_idx])]
                    best = np.concatenate(
                        ([roi[0]], grid_anchors[best_idx % per_image]))
            if best is None:
                # No usable anchor: fall back to the cell itself.
                best = np.asarray(ctx_cell, dtype=np.float64)
            ctx_roi.append(best)
        ctx_roi.append(roi)

        return np.vstack(ctx_roi)

    def _grid_anchors(self, h, w):
        """Return the ``(c*h*w, 4)`` anchors of one image in (c, h, w) order."""
        shift_x = np.arange(0, w) / self.spatial_scale
        shift_y = np.arange(0, h) / self.spatial_scale
        shift_x, shift_y = np.meshgrid(shift_x, shift_y)
        shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),
                            shift_x.ravel(), shift_y.ravel())).transpose()
        # (h*w, A, 4) -> (A, h*w, 4) so the layout matches the score tensor.
        all_anchors = self.anchors[np.newaxis, :, :] + shifts[:, np.newaxis, :]
        all_anchors = all_anchors.transpose((1, 0, 2)).reshape(-1, 4)
        # Bring the anchors into the feature-map coordinates of the ROI.
        return all_anchors * self.spatial_scale

    def _cell_anchor_indices(self, object_roi, ctx_cell, h, w):
        """Return flat indices of every anchor located inside ``ctx_cell``."""
        begin_idx_1, end_idx_1, begin_idx_2, _ = self.get_candidate_anchors(
            object_roi, ctx_cell, h, w)
        n_cols = max(end_idx_1 - begin_idx_1, 0)
        n_rows = max((begin_idx_2 - begin_idx_1) // w, 0) if w else 0
        channel_offsets = np.arange(self.anchors_per_location) * (h * w)
        row_offsets = np.arange(n_rows) * w
        col_offsets = np.arange(n_cols)
        indices = begin_idx_1 + (channel_offsets[:, np.newaxis, np.newaxis]
                                 + row_offsets[np.newaxis, :, np.newaxis]
                                 + col_offsets[np.newaxis, np.newaxis, :])
        return indices.reshape(-1).astype(np.int64)

    @staticmethod
    def compute_iou(box_1, box_2, box1_area=None, box2_area=None):
        """Return the IoU of two ``(x1, y1, x2, y2)`` boxes.

        ``box_2`` may also have shape ``(4, N)`` to score ``N`` boxes at once.
        Areas are computed from the boxes when not supplied.  A zero union
        yields an IoU of 0 instead of a division by zero.
        """
        box_1 = np.asarray(box_1, dtype=np.float64)
        box_2 = np.asarray(box_2, dtype=np.float64)
        if box1_area is None:
            box1_area = (box_1[2] - box_1[0]) * (box_1[3] - box_1[1])
        if box2_area is None:
            box2_area = (box_2[2] - box_2[0]) * (box_2[3] - box_2[1])
        # Calculate intersection areas
        x1 = np.maximum(box_1[0], box_2[0])
        x2 = np.minimum(box_1[2], box_2[2])
        y1 = np.maximum(box_1[1], box_2[1])
        y2 = np.minimum(box_1[3], box_2[3])
        intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
        union = np.asarray(box1_area + box2_area - intersection,
                           dtype=np.float64)
        positive = union > 0
        iou = np.where(positive, intersection / np.where(positive, union, 1.0),
                       0.0)

        # ``[()]`` turns a 0-d result into a scalar and leaves arrays as-is.
        return iou[()]

    @staticmethod
    def filter_small_iou_anchors(object_roi, anchors, ctx_cell,
                                 iou_threshold=None):
        """Keep the anchors whose IoU with the cell's reference box is high.

        The reference box has the size of ``object_roi``, is centred on
        ``ctx_cell`` and has its coordinates clipped at 0.

        Args:
            object_roi: ``(image_idx, x1, y1, x2, y2)``.
            anchors: ``(N, 5)`` indexed anchors.
            ctx_cell: ``(image_idx, x1, y1, x2, y2)`` context cell.
            iou_threshold: anchors with IoU above it are kept; defaults to
                ``cfg.IOU_THRESHOLD``.

        Returns:
            The ``(M, 5)`` rows of ``anchors`` that pass the threshold.
        """
        anchors = np.asarray(anchors, dtype=np.float64)
        if anchors.size == 0:
            return anchors.reshape(0, 5)
        if iou_threshold is None:
            iou_threshold = cfg.IOU_THRESHOLD
        width = object_roi[3] - object_roi[1]
        height = object_roi[4] - object_roi[2]
        sum_x = ctx_cell[1] + ctx_cell[3]
        sum_y = ctx_cell[2] + ctx_cell[4]
        # The four coordinate of anchor_roi(x1,y1,x2,y2)
        anchor_roi = np.maximum(
            np.array([(sum_x - width) / 2, (sum_y - height) / 2,
                      (sum_x + width) / 2, (sum_y + height) / 2],
                     dtype=np.float64),
            0)
        iou = autoctx_outputs.compute_iou(anchor_roi, anchors[:, 1:5].T)

        return anchors[iou > iou_threshold]

    @staticmethod
    def filter_larger_anchors(object_roi, anchors_in_cell):
        """Drop the anchors that are larger than half the ROI.

        An anchor is kept when its width is at most half the ROI width and
        its height is at most half the ROI height.

        Args:
            object_roi: ``(image_idx, x1, y1, x2, y2)``.
            anchors_in_cell: ``(N, 5)`` indexed anchors.

        Returns:
            The ``(M, 5)`` rows of ``anchors_in_cell`` that are kept.
        """
        anchors_in_cell = np.asarray(anchors_in_cell)
        if anchors_in_cell.size == 0:
            return anchors_in_cell.reshape(0, 5)
        width = object_roi[3] - object_roi[1]
        height = object_roi[4] - object_roi[2]
        anchor_w = anchors_in_cell[:, 3] - anchors_in_cell[:, 1]
        anchor_h = anchors_in_cell[:, 4] - anchors_in_cell[:, 2]
        # Element-wise ``&``: Python ``and`` is undefined for arrays.
        keep = (anchor_w <= width / 2) & (anchor_h <= height / 2)

        return anchors_in_cell[keep]

    def get_candidate_anchors(self, object_roi, ctx_cell, h, w):
        """Return the flat indices of the four corners of ``ctx_cell``.

        The cell is clipped to the ``h x w`` grid and the indices refer to
        anchor channel 0 of the image given by ``object_roi[0]`` in the
        ``(b, c, h, w)`` flattening.  ``x2`` / ``y2`` are exclusive bounds.

        Returns:
            ``(begin_idx_1, end_idx_1, begin_idx_2, end_idx_2)`` for the
            corners ``(x1, y1)``, ``(x2, y1)``, ``(x1, y2)``, ``(x2, y2)``.
        """
        x1 = int(np.clip(np.floor(ctx_cell[1]), 0, w))
        y1 = int(np.clip(np.floor(ctx_cell[2]), 0, h))
        x2 = int(np.clip(np.ceil(ctx_cell[3]), 0, w))
        y2 = int(np.clip(np.ceil(ctx_cell[4]), 0, h))
        # Offset of the ROI's image inside the flattened score tensor.
        image_base = int(object_roi[0]) * self.anchors_per_location * h * w
        begin_idx_1 = image_base + y1 * w + x1
        end_idx_1 = image_base + y1 * w + x2
        begin_idx_2 = image_base + y2 * w + x1
        end_idx_2 = image_base + y2 * w + x2

        return begin_idx_1, end_idx_1, begin_idx_2, end_idx_2

    @staticmethod
    def getting_ctx_cell(object_roi, num_ctx=8):
        """Return the first ``num_ctx`` context cells around ``object_roi``.

        Each cell is the ROI shifted by one ROI width and/or height, in the
        order right, left, down, top, down-right, down-left, top-right,
        top-left, and carries the ROI's image index in column 0.

        Returns:
            ``float32`` array of shape ``(num_ctx, 5)``.

        Raises:
            ValueError: if ``num_ctx`` is not between 0 and 8.
        """
        offsets = autoctx_outputs._CTX_OFFSETS
        if not 0 <= num_ctx <= len(offsets):
            raise ValueError(
                "num_ctx must be between 0 and %d" % len(offsets))
        roi = np.asarray(object_roi, dtype=np.float64)
        width = roi[3] - roi[1]
        height = roi[4] - roi[2]
        ctx_cells = np.empty((num_ctx, 5), dtype=np.float32)
        # ctx_cells also keep the roi index with it
        ctx_cells[:, 0] = roi[0]
        for cell, (dx, dy) in zip(ctx_cells, offsets):
            cell[1:] = roi[1:5] + np.array(
                [dx * width, dy * height, dx * width, dy * height])

        return ctx_cells