In [3]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# TODO: Are pretrained models available for VNet? How much does using a pretrained model actually help?

In [25]:
# All convolutions are applied with appropriate padding:
# the regular conv layers never change the volume size, they only change the channel count;
# the Down Conv layers are what change the volume size.

class Residual(nn.Module):
    """Residual block: `num_conv` x (Conv3d 5x5x5 -> BatchNorm3d -> ReLU) plus a shortcut.

    The last ReLU is applied *before* the shortcut is added, not after the sum
    as in the usual ResNet block.
    """
    def __init__(self, input_channels:int, output_channels:int, num_conv=3, use_1x1conv=False, stride=1, res_channels=None):
        """Build the conv stack and the optional 1x1x1 shortcut projection.

        Args:
            input_channels: channels of the tensor fed to the first conv.
            output_channels: channels produced by every conv in the stack.
            num_conv: number of Conv3d/BatchNorm3d pairs.
            use_1x1conv: if True, the shortcut is passed through a 1x1x1 Conv3d
                that maps it to `output_channels` channels.
            stride: stride given to every conv of the stack and to the 1x1x1
                projection. With kernel 5 / padding 2, stride=1 keeps the
                spatial size unchanged.
            res_channels: channels of the tensor used as shortcut source. It
                exists because the shortcut may be a custom tensor rather than
                the block input; when None it defaults to `input_channels`.
                Only used when `use_1x1conv` is True.
        """
        super().__init__()
        self.convs = nn.ModuleList([])
        self.bns = nn.ModuleList([])
        self.input_channels = input_channels
        self.num_conv = num_conv
        for i in range(num_conv):
            self.convs.append(nn.Conv3d(input_channels, output_channels, kernel_size=5, stride=stride, padding=2))
            self.bns.append(nn.BatchNorm3d(output_channels))
            # Every conv after the first one maps output_channels -> output_channels.
            input_channels = output_channels
        if use_1x1conv:
            if res_channels is None:
                res_channels = self.input_channels
            self.conv1x1 = nn.Conv3d(res_channels, output_channels, kernel_size=1, stride=stride)
        else:
            self.conv1x1 = None

    def forward(self, X, res=None):
        """Run the conv stack on `X` and add the shortcut.

        Args:
            X: input tensor of the conv stack.
            res: optional tensor to use as shortcut source instead of `X`.
                It is only consulted when the block was built with
                `use_1x1conv=True`; otherwise the shortcut is always `X`.

        Returns:
            A new tensor: conv-stack output plus shortcut. Neither `X` nor
            `res` is modified.
        """
        Y = X
        # Unlike the ResNet block as taught by Mu Li, VNet uses "ReLU last":
        # the final ReLU happens before the residual addition, not after it.
        for conv, bn in zip(self.convs, self.bns):
            Y = F.relu(bn(conv(Y)))
        if self.conv1x1 is not None:
            shortcut = self.conv1x1(X if res is None else res)
        else:
            shortcut = X
        # Out-of-place add on purpose: `Y += shortcut` would overwrite the ReLU
        # output that autograd saved for backward (RuntimeError on .backward()),
        # and with num_conv == 0 it would overwrite the caller's input tensor.
        return Y + shortcut

In [5]:
# Smoke test: a single-conv residual block that lifts 1 channel to 16.
blk = Residual(input_channels=1, output_channels=16, num_conv=1, use_1x1conv=True)
# Layout used in this notebook: batch size, channels, height, width, depth
X = torch.rand(1, 1, 128, 128, 64)
Y = blk(X)
Y.shape

torch.Size([1, 16, 128, 128, 64])

In [6]:
# Strided 2x2x2 convolution applied to the previous block output (16 -> 32 channels).
downConv = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=2, stride=2)
Y1 = downConv(Y)
print(Y1.shape)

torch.Size([1, 32, 64, 64, 32])


In [7]:
# Two-conv residual block with an identity shortcut (channel count unchanged).
blk1 = Residual(input_channels=32, output_channels=32, num_conv=2, use_1x1conv=False)
Y2 = blk1(Y1)
print(Y2.shape)

torch.Size([1, 32, 64, 64, 32])


In [8]:
# Show the layer structure of both residual blocks, one after the other.
print(blk, blk1, sep="\n")

Residual(
  (convs): ModuleList(
    (0): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
  )
  (bns): ModuleList(
    (0): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1x1): Conv3d(1, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)
Residual(
  (convs): ModuleList(
    (0): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
    (1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
  )
  (bns): ModuleList(
    (0): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


In [9]:
class Encoder(nn.Module):
    """Down-sampling path: one `Residual` block per stage, with a strided conv between stages."""
    def __init__(self, channels=(16, 32, 64, 128, 256), num_conv=(1, 2, 3, 3, 3), in_channels=1):
        """Build the encoder stages.

        Args:
            channels: output channels of each stage, shallowest first.
            num_conv: number of convs in each stage's residual block; its
                length defines the number of stages.
            in_channels: channels of the input volume (default 1, the value
                that used to be hard-coded).
        """
        super().__init__()
        self.num_stages = len(num_conv)
        self.enc_blks = nn.ModuleList([])
        self.down_convs = nn.ModuleList([])
        for i in range(self.num_stages):
            # Residual block. A 1x1 conv on the shortcut is only needed when the
            # input and output channel counts differ, i.e. in the first stage;
            # later stages receive channels[i] channels from the down conv.
            if i == 0:
                self.enc_blks.append(Residual(in_channels, channels[i], num_conv[i], use_1x1conv=True))
            else:
                self.enc_blks.append(Residual(channels[i], channels[i], num_conv[i], use_1x1conv=False))
            # The last stage is not followed by a Down Conv.
            if i != self.num_stages - 1:
                # Down Conv: kernel 2 / stride 2, channels[i] -> channels[i+1].
                self.down_convs.append(nn.Conv3d(channels[i], channels[i+1], kernel_size=2, stride=2))

    def forward(self, X):
        """Run all stages and return the list of per-stage outputs.

        Each stage output is recorded before it is down-sampled, so the list is
        ordered shallowest first and its last element is the deepest feature map.
        """
        ftrs = []
        for i in range(self.num_stages):
            X = self.enc_blks[i](X)
            ftrs.append(X)
            if i != self.num_stages - 1:
                X = F.relu(self.down_convs[i](X))
        return ftrs

# Run the default encoder on the sample volume X defined in an earlier cell
# and print the shape of every stage output (shallowest first).
encoder = Encoder()
ftrs = encoder(X)
for ftr in ftrs:
    print(ftr.shape)

torch.Size([1, 16, 128, 128, 64])
torch.Size([1, 32, 64, 64, 32])
torch.Size([1, 64, 32, 32, 16])
torch.Size([1, 128, 16, 16, 8])
torch.Size([1, 256, 8, 8, 4])


In [10]:
print(encoder)

Encoder(
  (enc_blks): ModuleList(
    (0): Residual(
      (convs): ModuleList(
        (0): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
      )
      (bns): ModuleList(
        (0): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv1x1): Conv3d(1, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
    (1): Residual(
      (convs): ModuleList(
        (0): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
      )
      (bns): ModuleList(
        (0): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): Residual(
      (convs): ModuleList(
        (0): Conv3d(64, 64, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (1): Conv3d(64, 64, ker

In [32]:
class Decoder(nn.Module):
    """Up-sampling path: per stage, a transposed conv, a skip concatenation and a `Residual` block."""
    def __init__(self, channels=(256, 128, 64, 32), num_conv=(1, 2, 3, 3, 3), in_channels=256):
        """Build the decoder stages.

        Args:
            channels: channels of each stage after the skip concatenation (and
                of that stage's output), deepest stage first. Each up conv
                produces `channels[i] // 2` channels.
            num_conv: number of convs in each stage's residual block, indexed
                by stage. Only the first `len(channels)` entries are read.
                NOTE(review): with the defaults the deepest stage gets 1 conv
                and the shallower ones 2, 3, 3 (the fifth entry is unused);
                this looks like the encoder tuple reused without reversing --
                confirm it is intended.
            in_channels: channels of the tensor fed to the first up conv
                (default 256, the value that used to be hard-coded).
        """
        super().__init__()
        self.num_stages = len(channels)
        self.up_convs = nn.ModuleList([])
        self.dec_blks = nn.ModuleList([])
        prev_chs = in_channels
        for i in range(self.num_stages):
            out_chs = channels[i] // 2
            # Up conv: kernel 2 / stride 2, prev_chs -> out_chs.
            self.up_convs.append(nn.ConvTranspose3d(prev_chs, out_chs, kernel_size=2, stride=2))
            # The residual block runs on the concatenated tensor (channels[i]),
            # while its shortcut is projected from the up-conv output (out_chs).
            self.dec_blks.append(Residual(channels[i], channels[i], num_conv[i], use_1x1conv=True, stride=1, res_channels=out_chs))
            # The next stage's up conv consumes this stage's output.
            prev_chs = channels[i]

    def forward(self, X, ftrs):
        """Decode `X` using the skip features in `ftrs`.

        Args:
            X: tensor fed to the first up conv.
            ftrs: skip features indexed by decoder stage (`ftrs[i]` is
                concatenated at stage `i`), i.e. deepest skip first. After
                concatenation along dim 1 the tensor must have `channels[i]`
                channels for the stage's residual block to accept it.

        Returns:
            Output of the last decoder stage.
        """
        for i in range(self.num_stages):
            # Keep the up-sampled tensor: it is the shortcut source of the block.
            X = res = F.relu(self.up_convs[i](X))
            X = torch.cat([X, ftrs[i]], dim=1)
            X = self.dec_blks[i](X, res)
        return X

In [36]:
# Shape check for the decoder.
# NOTE: this cell rebinds X and Y, which earlier cells used for the input
# volume and the first residual output.
decoder = Decoder()
X = torch.rand((1, 256, 8, 8, 4))
# Skip connections: every encoder output except the deepest one, deepest first.
Y = decoder(X, list(reversed(ftrs[:-1])))
print(Y.shape)

torch.Size([1, 32, 128, 128, 64])


In [37]:
# 1x1x1 convolution head mapping the 32 decoder channels to 2 output channels.
head = nn.Conv3d(in_channels=32, out_channels=2, kernel_size=1)
Y1 = head(Y)
print(Y1.shape)

torch.Size([1, 2, 128, 128, 64])


In [39]:
# Softmax over the channel axis (foreground / background).
smx = torch.softmax(Y1, dim=1)
# Summing over that axis gives a total probability of 1 at every position.
print(smx.sum(dim=1))

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 