In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class RPN(nn.Module):
    """Region Proposal Network head.

    A shared 3x3 convolution (followed by ReLU) feeds two sibling 1x1
    convolutions. One emits 2 classification logits per anchor and the other
    emits 4 box-regression offsets per anchor, at every spatial position of
    the input feature map.

    Args:
        in_channels: channel count of the input feature map. The shared 3x3
            conv keeps this channel count (padding=1 keeps H and W).
        num_anchors: number of anchors predicted at each spatial position.
    """

    def __init__(self, in_channels, num_anchors):
        super().__init__()

        # Submodule names and creation order are kept fixed: they define the
        # state_dict keys and the order in which weight init draws from the RNG.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.cls_head = nn.Conv2d(in_channels, 2 * num_anchors, kernel_size=1, stride=1)
        self.reg_head = nn.Conv2d(in_channels, 4 * num_anchors, kernel_size=1, stride=1)

    @staticmethod
    def _to_per_anchor(head_out, width):
        """Flatten a (N, A*width, H, W) head map into (N, H*W*A, width) rows."""
        # Channels-last first, so each spatial position's A*width channels are
        # adjacent and split cleanly into A rows of `width` values.
        channels_last = head_out.permute(0, 2, 3, 1)
        return channels_last.reshape(head_out.size(0), -1, width)

    def forward(self, x):
        """Run the head on a batched feature map.

        Args:
            x: 4-D feature map (N, in_channels, H, W).

        Returns:
            A tuple ``(logits, offsets)``. ``logits`` has shape (N, H*W*A, 2)
            and ``offsets`` has shape (N, H*W*A, 4). Anchor ``a`` at position
            ``(h, w)`` occupies row ``(h * W + w) * A + a``.
        """
        shared = F.relu(self.conv(x))
        cls_rows = self._to_per_anchor(self.cls_head(shared), 2)
        reg_rows = self._to_per_anchor(self.reg_head(shared), 4)
        return cls_rows, reg_rows

In [None]:
# Smoke test: push a random feature map through the RPN head and check shapes.
IN_CHANNELS = 512
NUM_ANCHORS = 10
BATCH, HEIGHT, WIDTH = 256, 4, 4

torch.manual_seed(0)  # reproducible weights and inputs

net = RPN(IN_CHANNELS, NUM_ANCHORS)
batch_x = torch.randn(BATCH, IN_CHANNELS, HEIGHT, WIDTH)

# Inference-only demo: skip building the autograd graph for this large batch.
with torch.no_grad():
    # forward() returns a (logits, offsets) tuple, not a single tensor, so
    # calling .size() on the raw result would raise AttributeError.
    logits, offsets = net(batch_x)

print(batch_x.size(), "->", logits.size(), offsets.size())

num_rows = HEIGHT * WIDTH * NUM_ANCHORS  # one row per anchor per spatial position
assert logits.shape == (BATCH, num_rows, 2)
assert offsets.shape == (BATCH, num_rows, 4)