In [2]:
import cv2
import torch
import torchvision
import torch.nn as nn
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

In [3]:
# 图像数据
img = cv2.imread('1.jpg')
# label数据，好像有几个类
label = {
    "boxes":[[100, 100, 200, 300]],
    "label": 1,
    "image_id": 2
    }

In [4]:
img = []
img.append(cv2.imread('1.jpg'))
img = torch.LongTensor(img).permute(0, 3, 1, 2).float()
img.shape

torch.Size([1, 3, 333, 500])

In [5]:
backbone = resnet_fpn_backbone('resnet50', False)
output = backbone(img)
print(len(output))

5


In [6]:
class RPNhead(nn.Module):

    def __init__(self, in_channels, num_anchors):
        super(RPNhead, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.cls_logits = nn.Conv2d(
            in_channels, num_anchors, kernel_size=1, stride=1
        )
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

        # init parameters
        for l in self.children():
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

    def forward(self, x):
        # 输入的是feature_map, feature_map有可能是多层(resnet)
        # x: C * [batch_size, out_channel, H_out, W_out]
        features = list(x.values())
        logits = []
        bbox_reg = []
        for feature in features:
            t = nn.functional.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return bbox_reg, logits

In [7]:
input_channel = backbone.out_channels
Anchor_sizes = (16, 64, 128)
aspect_ratios = (0.5, 1, 2)

In [8]:
rpn_head = RPNhead(input_channel, len(Anchor_sizes) * len(aspect_ratios))
bbox_reg, logits = rpn_head(output)

In [123]:
# 生成anchor的各个中心点
# 已知feature_map有5层
[i.shape for i in output.values()]

[torch.Size([1, 256, 84, 125]),
 torch.Size([1, 256, 42, 63]),
 torch.Size([1, 256, 21, 32]),
 torch.Size([1, 256, 11, 16]),
 torch.Size([1, 256, 6, 8])]

In [125]:
# 获取每层feature_map的尺寸
feature_size = [i.shape[-2:] for i in output.values()]

In [126]:
# 获取输入图片的尺寸
img_size = img.shape[-2:]

In [127]:
print(feature_size)
print(img_size)

[torch.Size([84, 125]), torch.Size([42, 63]), torch.Size([21, 32]), torch.Size([11, 16]), torch.Size([6, 8])]
torch.Size([333, 500])


In [131]:
#计算每层的stride
stride = torch.LongTensor([[img_size[0]//f[0], img_size[1]//f[1]] for f in feature_size])

In [132]:
print(stride)

tensor([[ 3,  4],
        [ 7,  7],
        [15, 15],
        [30, 31],
        [55, 62]])


In [142]:
# 这样的话我们就可以生成,每个anchor的中心点
centers = []
for f, s in zip(feature_size, stride):
    y = torch.arange(0, f[0]) * s[0]
    x = torch.arange(0, f[1]) * s[1]
    y, x = torch.meshgrid(y,x)
    y = y.reshape(-1)
    x = x.reshape(-1)
    center = torch.stack((y,x,y,x), dim = 1)
    centers.append(center)

In [143]:
print(len(centers))
[i.shape for i in centers]

5


[torch.Size([10500, 4]),
 torch.Size([2646, 4]),
 torch.Size([672, 4]),
 torch.Size([176, 4]),
 torch.Size([48, 4])]

In [148]:
# 生成每个anchor的长宽
# Anchor_sizes = (16, 64, 128)
# aspect_ratios = (0.5, 1, 2)
Anchor_scales = torch.as_tensor(Anchor_sizes, dtype = torch.float32)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype = torch.float32)

In [152]:
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1/h_ratios

ws = h_ratios[:, None] * Anchor_scales[None, :].view(-1)
hs = w_ratios[:, None] * Anchor_scales[None, :].view(-1)

In [158]:
base_anchor = torch.stack([-ws, -hs, ws, hs], dim = 1)/2

In [159]:
base_anchor

tensor([[[ -5.6569, -22.6274, -45.2548],
         [-11.3137, -45.2548, -90.5097],
         [  5.6569,  22.6274,  45.2548],
         [ 11.3137,  45.2548,  90.5097]],

        [[ -8.0000, -32.0000, -64.0000],
         [ -8.0000, -32.0000, -64.0000],
         [  8.0000,  32.0000,  64.0000],
         [  8.0000,  32.0000,  64.0000]],

        [[-11.3137, -45.2548, -90.5097],
         [ -5.6569, -22.6274, -45.2548],
         [ 11.3137,  45.2548,  90.5097],
         [  5.6569,  22.6274,  45.2548]]])