# 1. MTCNN 的第一阶段：P-Net 网络架构图

In [1]:
import torch
from torch import nn

![p_net](p_net.jpg)

In [2]:
class Pnet(nn.Module):
    """P-Net (proposal network), the first stage of MTCNN.

    Fully convolutional: any input of at least 12x12 is accepted, and a
    12x12 input yields a 1x1 output map.

    Every hidden conv layer is followed by a PReLU activation, as in the
    MTCNN paper. Without a nonlinearity, conv2 -> conv3 -> heads would
    collapse into a single linear map.

    Args:
        verbose: if True (the default, which reproduces the shape trace shown
            in this notebook), ``forward`` prints the tensor shape after each
            stage. Pass False for training / inference.

    ``forward(x)`` takes an (N, 3, H, W) batch and returns ``(cls, reg, loc)``
    with 2, 4 and 10 channels respectively:
        cls: face / non-face scores as raw logits (no softmax is applied).
        reg: bounding-box regression outputs.
        loc: landmark outputs (in MTCNN: 5 facial points x (x, y)).
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1, padding=0)
        self.prelu1 = nn.PReLU(num_parameters=10)
        # padding=1 makes 10x10 -> 5x5 for the canonical 12x12 input.
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.prelu2 = nn.PReLU(num_parameters=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=0)
        self.prelu3 = nn.PReLU(num_parameters=32)

        # 1x1 convolution heads keep the network fully convolutional.
        # No activation on the heads: they emit raw scores/offsets.
        self.classification = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=1, stride=1, padding=0)
        self.regression = nn.Conv2d(in_channels=32, out_channels=4, kernel_size=1, stride=1, padding=0)
        self.localization = nn.Conv2d(in_channels=32, out_channels=10, kernel_size=1, stride=1, padding=0)

    def _show_shape(self, x):
        """Print ``x.shape`` when verbose tracing is enabled."""
        if self.verbose:
            print(x.shape)

    def forward(self, x):
        self._show_shape(x)
        x = self.prelu1(self.conv1(x))
        self._show_shape(x)
        x = self.mp1(x)
        self._show_shape(x)
        x = self.prelu2(self.conv2(x))
        self._show_shape(x)
        x = self.prelu3(self.conv3(x))
        self._show_shape(x)
        cls = self.classification(x)
        reg = self.regression(x)
        loc = self.localization(x)
        return cls, reg, loc

In [3]:
# Fake input batch in (N, C, H, W) layout: two 3-channel 12x12 images.
X = torch.randn(size=(2, 3, 12, 12))
# Build the network after X so the random-number draw order is unchanged.
pnet = Pnet()
cls, reg, loc = pnet(X)

torch.Size([2, 3, 12, 12])
torch.Size([2, 10, 10, 10])
torch.Size([2, 10, 5, 5])
torch.Size([2, 16, 3, 3])
torch.Size([2, 32, 1, 1])


In [4]:
print(*(head.shape for head in (cls, reg, loc)))

torch.Size([2, 2, 1, 1]) torch.Size([2, 4, 1, 1]) torch.Size([2, 10, 1, 1])


# 2. MTCNN 的第二阶段：R-Net 网络架构图

![r_net](r_net.jpg)

In [5]:
class Rnet(nn.Module):
    """R-Net (refine network), the second stage of MTCNN.

    The spatial input size is fixed at 24x24: ``fc1`` expects exactly
    64 * 3 * 3 = 576 flattened features.

    Every hidden conv / fully connected layer is followed by a PReLU
    activation, as in the MTCNN paper. Without them, conv3 -> flatten ->
    fc1 -> heads would reduce to one linear map.

    Args:
        verbose: if True (the default, which reproduces the shape trace shown
            in this notebook), ``forward`` prints the tensor shape after each
            stage. Pass False for training / inference.

    ``forward(x)`` takes an (N, 3, 24, 24) batch and returns
    ``(cls, reg, loc)`` of shapes (N, 2), (N, 4) and (N, 10):
        cls: face / non-face scores as raw logits (no softmax is applied).
        reg: bounding-box regression outputs.
        loc: landmark outputs (in MTCNN: 5 facial points x (x, y)).
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=28, kernel_size=3, stride=1, padding=0)
        self.prelu1 = nn.PReLU(num_parameters=28)
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=28, out_channels=48, kernel_size=3, stride=1, padding=0)
        self.prelu2 = nn.PReLU(num_parameters=48)
        self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=2, stride=1, padding=0)
        self.prelu3 = nn.PReLU(num_parameters=64)
        self.flatten = nn.Flatten()
        # 576 = 64 channels * 3 * 3 spatial positions for a 24x24 input.
        self.fc1 = nn.Linear(in_features=576, out_features=128)
        self.prelu4 = nn.PReLU(num_parameters=128)

        # Output heads: no activation, they emit raw scores/offsets.
        self.classification = nn.Linear(in_features=128, out_features=2)
        self.regression = nn.Linear(in_features=128, out_features=4)
        self.localization = nn.Linear(in_features=128, out_features=10)

    def _show_shape(self, x):
        """Print ``x.shape`` when verbose tracing is enabled."""
        if self.verbose:
            print(x.shape)

    def forward(self, x):
        self._show_shape(x)
        x = self.prelu1(self.conv1(x))
        self._show_shape(x)
        x = self.mp1(x)
        self._show_shape(x)
        x = self.prelu2(self.conv2(x))
        self._show_shape(x)
        x = self.mp2(x)
        self._show_shape(x)
        x = self.prelu3(self.conv3(x))
        self._show_shape(x)
        x = self.flatten(x)
        x = self.prelu4(self.fc1(x))
        self._show_shape(x)
        cls = self.classification(x)
        reg = self.regression(x)
        loc = self.localization(x)
        return cls, reg, loc
        

In [6]:
# Fake input batch in (N, C, H, W) layout: two 3-channel 24x24 crops.
X = torch.randn(size=(2, 3, 24, 24))
# Build the network after X so the random-number draw order is unchanged.
rnet = Rnet()
cls, reg, loc = rnet(X)

torch.Size([2, 3, 24, 24])
torch.Size([2, 28, 22, 22])
torch.Size([2, 28, 11, 11])
torch.Size([2, 48, 9, 9])
torch.Size([2, 48, 4, 4])
torch.Size([2, 64, 3, 3])
torch.Size([2, 128])


In [7]:
print(*(head.shape for head in (cls, reg, loc)))

torch.Size([2, 2]) torch.Size([2, 4]) torch.Size([2, 10])


# 3. MTCNN 的第三阶段：O-Net 网络架构图

![o_net](o_net.jpg)

In [8]:
class Onet(nn.Module):
    """O-Net (output network), the third stage of MTCNN.

    The spatial input size is fixed at 48x48: ``fc1`` expects exactly
    128 * 3 * 3 = 1152 flattened features.

    Every hidden conv / fully connected layer is followed by a PReLU
    activation, as in the MTCNN paper. Without them, conv4 -> flatten ->
    fc1 -> heads would reduce to one linear map.

    Args:
        verbose: if True (the default, which reproduces the shape trace shown
            in this notebook), ``forward`` prints the tensor shape after each
            stage. Pass False for training / inference.

    ``forward(x)`` takes an (N, 3, 48, 48) batch and returns
    ``(cls, reg, loc)`` of shapes (N, 2), (N, 4) and (N, 10):
        cls: face / non-face scores as raw logits (no softmax is applied).
        reg: bounding-box regression outputs.
        loc: landmark outputs (in MTCNN: 5 facial points x (x, y)).
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=0)
        self.prelu1 = nn.PReLU(num_parameters=32)
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.prelu2 = nn.PReLU(num_parameters=64)
        self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.prelu3 = nn.PReLU(num_parameters=64)
        self.mp3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=1, padding=0)
        self.prelu4 = nn.PReLU(num_parameters=128)
        self.flatten = nn.Flatten()
        # 1152 = 128 channels * 3 * 3 spatial positions for a 48x48 input.
        self.fc1 = nn.Linear(in_features=1152, out_features=256)
        self.prelu5 = nn.PReLU(num_parameters=256)

        # Output heads: no activation, they emit raw scores/offsets.
        self.classification = nn.Linear(in_features=256, out_features=2)
        self.regression = nn.Linear(in_features=256, out_features=4)
        self.localization = nn.Linear(in_features=256, out_features=10)

    def _show_shape(self, x):
        """Print ``x.shape`` when verbose tracing is enabled."""
        if self.verbose:
            print(x.shape)

    def forward(self, x):
        self._show_shape(x)
        x = self.prelu1(self.conv1(x))
        self._show_shape(x)
        x = self.mp1(x)
        self._show_shape(x)
        x = self.prelu2(self.conv2(x))
        self._show_shape(x)
        x = self.mp2(x)
        self._show_shape(x)
        x = self.prelu3(self.conv3(x))
        self._show_shape(x)
        x = self.mp3(x)
        self._show_shape(x)
        x = self.prelu4(self.conv4(x))
        self._show_shape(x)
        x = self.flatten(x)
        self._show_shape(x)
        x = self.prelu5(self.fc1(x))
        self._show_shape(x)
        cls = self.classification(x)
        reg = self.regression(x)
        loc = self.localization(x)
        return cls, reg, loc
        

In [9]:
# Fake input batch in (N, C, H, W) layout: two 3-channel 48x48 crops.
X = torch.randn(size=(2, 3, 48, 48))
# Build the network after X so the random-number draw order is unchanged.
onet = Onet()
cls, reg, loc = onet(X)

torch.Size([2, 3, 48, 48])
torch.Size([2, 32, 46, 46])
torch.Size([2, 32, 23, 23])
torch.Size([2, 64, 21, 21])
torch.Size([2, 64, 10, 10])
torch.Size([2, 64, 8, 8])
torch.Size([2, 64, 4, 4])
torch.Size([2, 128, 3, 3])
torch.Size([2, 1152])
torch.Size([2, 256])


In [10]:
print(*(head.shape for head in (cls, reg, loc)))

torch.Size([2, 2]) torch.Size([2, 4]) torch.Size([2, 10])
