In [2]:
import torch
import torch.nn as nn

In [3]:
def double_conv(in_channels, out_channels):
    """Build a two-stage ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Both convolutions use a 3x3 kernel with ``padding=1`` and the default
    stride of 1, so height and width pass through unchanged; only the channel
    count changes, and only in the first stage.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding six modules, indexed 0-5.
    """
    layers = []
    # Stage one maps in_channels -> out_channels; stage two keeps out_channels.
    for stage_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)

In [4]:
# Sanity check: build a single 128 -> 64 channel block and list its layers.
net = double_conv(in_channels=128, out_channels=64)
print(net)

Sequential(
  (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
)


In [10]:
class Unet(nn.Module):
    """U-Net style encoder/decoder preceded by a small input adapter.

    Pipeline (shapes are per sample, ``C x H x W``):

    1. ``adnet`` halves the height with a ``(2, 1)`` strided convolution and
       then grows both spatial dims by 8 with a 3x3 convolution that uses
       ``padding=5``: ``H x W -> (H // 2 + 8) x (W + 8)``.
       A 240 x 120 input therefore becomes 128 x 128.
    2. The encoder applies ``double_conv`` blocks with 32, 64, 128 and 256
       channels, each followed by 2x2 max pooling, and a 512-channel
       bottleneck block.
    3. The decoder doubles the resolution with stride-2 transposed
       convolutions, concatenates the matching encoder feature map along the
       channel axis and fuses the result with another ``double_conv`` block.
    4. A last transposed convolution doubles the resolution once more and a
       1x1 convolution maps the remaining 16 channels to ``out_channels``.

    The skip concatenations only line up when the height and width produced
    by ``adnet`` are multiples of 16 (four 2x poolings mirrored by four exact
    2x upsamplings).

    Args:
        in_channels: Number of channels in the input image. Defaults to 1.
        out_channels: Number of channels in the output map. Defaults to 1.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Input adapter; it keeps the channel count and only changes H and W.
        self.adnet = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                      kernel_size=(2, 1), padding=0, stride=(2, 1)),
            # BatchNorm: intended to improve generalization (reduce
            # overfitting) and speed up convergence.
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, stride=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

        # Encoder blocks.
        self.dconv_down0 = double_conv(in_channels, 32)
        self.dconv_down1 = double_conv(32, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)

        # kernel=3, stride=2, padding=1, output_padding=1 doubles H and W exactly.
        self.upsample4 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)
        self.upsample3 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)
        self.upsample2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.upsample1 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.upsample0 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)

        # Decoder blocks; the input width is upsampled channels + skip channels.
        self.dconv_up3 = double_conv(256 + 256, 256)
        self.dconv_up2 = double_conv(128 + 128, 128)
        self.dconv_up1 = double_conv(64 + 64, 64)
        self.dconv_up0 = double_conv(32 + 32, 32)

        self.conv_last = nn.Conv2d(16, out_channels, 1)

    def forward(self, x):
        """Run the adapter, encoder and decoder on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``. ``H // 2 + 8`` and
                ``W + 8`` must both be multiples of 16, otherwise the skip
                concatenations fail on mismatched sizes.

        Returns:
            Tensor of shape ``(N, out_channels, 2 * (H // 2 + 8), 2 * (W + 8))``.

        The shape comments below trace a single-channel 240 x 120 input.
        """
        # reshape
        x = self.adnet(x)  # 1x128x128

        # encode
        conv0 = self.dconv_down0(x)  # 32x128x128
        x = self.maxpool(conv0)  # 32x64x64

        conv1 = self.dconv_down1(x)  # 64x64x64
        x = self.maxpool(conv1)  # 64x32x32

        conv2 = self.dconv_down2(x)  # 128x32x32
        x = self.maxpool(conv2)  # 128x16x16

        conv3 = self.dconv_down3(x)  # 256x16x16
        x = self.maxpool(conv3)  # 256x8x8

        x = self.dconv_down4(x)  # 512x8x8

        # decode
        x = self.upsample4(x)  # 256x16x16
        # The 3x3 convolutions with padding=1 keep the spatial size unchanged,
        # so the skip tensors already match and no crop is needed.
        x = torch.cat([x, conv3], dim=1)  # 512x16x16

        x = self.dconv_up3(x)  # 256x16x16
        x = self.upsample3(x)  # 128x32x32
        x = torch.cat([x, conv2], dim=1)  # 256x32x32

        x = self.dconv_up2(x)  # 128x32x32
        x = self.upsample2(x)  # 64x64x64
        x = torch.cat([x, conv1], dim=1)  # 128x64x64

        x = self.dconv_up1(x)  # 64x64x64
        x = self.upsample1(x)  # 32x128x128
        x = torch.cat([x, conv0], dim=1)  # 64x128x128

        x = self.dconv_up0(x)  # 32x128x128
        x = self.upsample0(x)  # 16x256x256

        out = self.conv_last(x)  # 1x256x256

        return out

In [11]:
# Instantiate the full model with its default arguments and print the module
# tree; `net` is reused by the shape check in a later cell.
net = Unet()
print(net)

Unet(
  (adnet): Sequential(
    (0): Conv2d(1, 1, kernel_size=(2, 1), stride=(2, 1))
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5))
    (4): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (dconv_down0): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (dconv_down1): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [16]:
# Smoke test: push a random batch of 255 single-channel 240 x 120 images
# through the model and inspect the output shape.
X = torch.randn((255, 1, 240, 120), dtype=torch.float)

# Only the shape is needed, so disable autograd; otherwise every intermediate
# activation of this large batch is kept alive for a backward pass that never runs.
with torch.no_grad():
    out_shape = net(X).shape
out_shape

torch.Size([255, 1, 256, 256])