In [1]:
import torch
from torch import nn

In [2]:
# Instantiate PyTorch's placeholder identity module (its forward returns the input as-is).
torch.nn.Identity()

Identity()

In [3]:
class IdentityShortcut(nn.Module):
    """Residual block whose input and output shapes match (identity shortcut).

    Computes ``conv(x) + x``. The convolution keeps the channel count, and
    with kernel 3, stride 1 and padding 1 it also keeps H and W. The input
    can therefore be added back directly, without any projection.

    Args:
        channels: Number of input channels, which is also the number of
            output channels. Defaults to 3, the value previously hard-coded.
    """

    def __init__(self, channels: int = 3):
        super().__init__()
        # 3x3 kernel, stride 1, padding 1 preserves the spatial size, so the
        # residual sum in ``forward`` is shape-compatible with the input.
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Return ``conv(x) + x``.

        Args:
            x: Input tensor, normally 4-D ``(N, channels, H, W)``. Unbatched
                3-D input is also accepted by ``nn.Conv2d``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        return self.conv(x) + x

In [4]:
# Dummy input for the blocks below: 2 samples, 3 channels, 32x32 (NCHW layout, as nn.Conv2d expects).
X = torch.randn((2, 3, 32, 32))

In [5]:
# Identity-shortcut block with its default configuration (3 channels in and out).
identity_sc = IdentityShortcut()

In [6]:
# Shape check: the identity-shortcut block leaves the input shape unchanged.
identity_sc(X).size()

torch.Size([2, 3, 32, 32])

In [7]:
class ProjectionShortcut(nn.Module):
    """Residual block with a projection shortcut (input and output shapes differ).

    The main branch is a 3x3 convolution that changes the channel count and
    downsamples by ``stride``. The shortcut is a 1x1 convolution with the same
    channel change and stride, so both branches produce tensors of the same
    shape and can be summed.

    Args:
        in_channels: Channels of the input tensor. Defaults to 3.
        out_channels: Channels of the output tensor. Defaults to 8.
        stride: Stride shared by both branches. Defaults to 2, which halves
            H and W (rounding up for odd sizes).
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 8, stride: int = 2):
        super().__init__()
        # Main branch: extract features (3x3 conv, padding 1).
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        # Shortcut branch: 1x1 conv that only adjusts channels and spatial size.
        # With (k=3, p=1) vs (k=1, p=0) and the same stride, both branches give
        # an output size of floor((H - 1) / stride) + 1, so the sum is valid.
        self.short_cut = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
        )

    def forward(self, x):
        """Return ``conv(x) + short_cut(x)``.

        Args:
            x: Input tensor, normally 4-D ``(N, in_channels, H, W)``.

        Returns:
            A tensor with ``out_channels`` channels and spatial size
            ``floor((H - 1) / stride) + 1`` (likewise for W).
        """
        return self.conv(x) + self.short_cut(x)

In [8]:
# Projection-shortcut block with its default configuration (3 -> 8 channels, stride 2).
# NOTE(review): despite the name, this is not a bottleneck block; consider renaming
# (e.g. to `projection_sc`) together with its later use.
bottleneck_sc = ProjectionShortcut()

In [9]:
# Input shape, for comparison with the projection block's output.
X.size()

torch.Size([2, 3, 32, 32])

In [10]:
# Shape check: channels go 3 -> 8 and the stride-2 convolutions halve H and W.
bottleneck_sc(X).size()

torch.Size([2, 8, 16, 16])