### 1. P_Net 网络架构图

![P_NET](p_net.jpg)

### 2. 搭建模型

In [1]:
import torch
from torch import nn

In [2]:
class PNet(nn.Module):
    """
    P-Net: a fully convolutional network.

    Every layer is a convolution or a pooling op, so the network accepts
    inputs of arbitrary spatial size (as long as they survive the 3x3
    convolutions) and returns dense output maps rather than single vectors:

    - ``cls``: ``[N, 2, H', W']`` raw (un-normalized) classification scores.
    - ``reg``: ``[N, 4, H', W']`` four regression values per output location.

    A 12x12 input produces a 1x1 output. In general
    ``H' = ceil((H - 12) / 2 + 1)`` (likewise for ``W'``): the output grid
    matches a 12x12 window moved with stride 2, the stride coming from ``mp1``.

    A channel-wise PReLU follows every hidden convolution. Without an
    activation, stacked convolutions compose into a single affine filter and
    the only non-linearity left would be the max-pool. The two 1x1 output
    heads are deliberately left linear so a loss can be applied on top.
    """

    def __init__(self):
        super().__init__()
        # Trunk: conv -> PReLU -> max-pool -> (conv -> PReLU) x 2.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10,
                               kernel_size=3, stride=1, padding=0)
        self.prelu1 = nn.PReLU(num_parameters=10)
        # Only layer with stride 2; it sets the output-map stride.
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=16,
                               kernel_size=3, stride=1, padding=0)
        self.prelu2 = nn.PReLU(num_parameters=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=3, stride=1, padding=0)
        self.prelu3 = nn.PReLU(num_parameters=32)
        # Heads: 1x1 convolutions, so each output location is scored
        # independently (the convolutional equivalent of a linear layer).
        self.classifier = nn.Conv2d(in_channels=32, out_channels=2,
                                    kernel_size=1, stride=1, padding=0)
        self.regressor = nn.Conv2d(in_channels=32, out_channels=4,
                                   kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the network.

        Args:
            x: input tensor laid out as ``[N, 3, H, W]``.

        Returns:
            ``(cls, reg)`` with shapes ``[N, 2, H', W']`` and ``[N, 4, H', W']``.
        """
        x = self.prelu1(self.conv1(x))
        x = self.mp1(x)
        x = self.prelu2(self.conv2(x))
        x = self.prelu3(self.conv3(x))
        cls = self.classifier(x)
        reg = self.regressor(x)
        return cls, reg

In [3]:
# Instantiate P-Net.
# Inputs use the [N, C, H, W] layout, e.g. [2, 3, 12, 12]; a 12x12 input
# yields a 1x1 output map, i.e. a single classification/regression result.

model = PNet()

In [4]:
# [N, C, H, W]: a batch of 2 three-channel inputs with H=1920 and W=1080
# (portrait orientation; a landscape 1080p frame would be (1080, 1920)).
# P-Net is fully convolutional, so it accepts this size and returns dense
# output maps. Only the shapes are inspected here, so run under no_grad to
# avoid keeping every large intermediate activation alive for autograd.
X = torch.randn(2, 3, 1920, 1080)
with torch.no_grad():
    cls, reg = model(X)
print(cls.shape)
print(reg.shape)

torch.Size([2, 2, 955, 535])
torch.Size([2, 4, 955, 535])


In [5]:
import math

In [6]:
# Expected P-Net output height for H = 1920: the output grid matches a 12x12
# window moved with stride 2, so H' = ceil((H - 12) / 2) + 1 (-> 955).
math.ceil((1920 - 12) / 2) + 1

955

In [7]:
# Number of output locations (H' x W' = 955 x 535) P-Net produces for one
# 1920x1080 input, i.e. how many classification/regression results per image.
955 * 535

510925

### 3. R_Net 网络架构图

![R_NET](r_net.jpg)

In [43]:
class RNet(nn.Module):
    """
    R-Net: a convolutional trunk followed by fully connected heads.

    Unlike P-Net, this network ends in ``nn.Linear`` layers, so its input
    size is effectively fixed. ``self.linear`` expects 576 = 64 x 3 x 3
    flattened features, which is what a 24x24 input produces:

        24 -conv1(k3)-> 22 -mp1(k3, s2, p1)-> 11 -conv2(k3)-> 9
           -mp2(k3, s2)-> 4 -conv3(k2)-> 3

    Inputs whose conv stack does not end at 64 x 3 x 3 fail at ``self.linear``.

    A channel-wise PReLU follows every hidden conv layer and the 128-unit
    hidden FC layer; without activations the layers between the pools and
    after the last pool would collapse into affine maps. The two output heads
    are left linear so a loss can be applied on top.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=28,
                               kernel_size=3, stride=1, padding=0)
        self.prelu1 = nn.PReLU(num_parameters=28)
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=28, out_channels=48,
                               kernel_size=3, stride=1, padding=0)
        self.prelu2 = nn.PReLU(num_parameters=48)
        self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(in_channels=48, out_channels=64,
                               kernel_size=2, stride=1, padding=0)
        self.prelu3 = nn.PReLU(num_parameters=64)
        self.flatten = nn.Flatten()
        # 576 = 64 channels x 3 x 3 spatial, valid for 24x24 inputs (see class docstring).
        self.linear = nn.Linear(in_features=576, out_features=128)
        self.prelu4 = nn.PReLU(num_parameters=128)
        self.classifier = nn.Linear(in_features=128, out_features=2)
        self.regressor = nn.Linear(in_features=128, out_features=4)

    def forward(self, x):
        """Run the network.

        Args:
            x: input tensor laid out as ``[N, 3, 24, 24]``.

        Returns:
            ``(cls, reg)`` with shapes ``[N, 2]`` (raw, un-normalized
            classification scores) and ``[N, 4]`` (regression values).
        """
        x = self.mp1(self.prelu1(self.conv1(x)))
        x = self.mp2(self.prelu2(self.conv2(x)))
        x = self.prelu3(self.conv3(x))
        x = self.flatten(x)
        x = self.prelu4(self.linear(x))
        cls = self.classifier(x)
        reg = self.regressor(x)
        return cls, reg

In [45]:
# Sanity-check R-Net on a dummy batch laid out as [N, C, H, W].
# R-Net ends in fully connected layers sized for 24x24 inputs
# (64 x 3 x 3 = 576 flattened features feed its first Linear layer).
X = torch.randn(2, 3, 24, 24)
model = RNet()
cls, reg = model(X)

In [46]:
# Expected [N, 2]: one pair of classification scores per input.
cls.shape

torch.Size([2, 2])

In [47]:
# Expected [N, 4]: four regression values per input.
reg.shape

torch.Size([2, 4])