In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Module):
    """Thin wrapper that stores a :class:`torch.nn.Conv2d` as ``self.conv``.

    Every argument is forwarded unchanged; the bias setting is left at the
    ``nn.Conv2d`` default.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
        )

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        layer = self.conv
        return layer(x)


class BatchNorm2d(nn.Module):
    """Thin wrapper that stores a :class:`torch.nn.BatchNorm2d` as ``self.bn``.

    ``eps`` and ``momentum`` are forwarded by keyword to the wrapped layer.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        norm = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum)
        self.bn = norm

    def forward(self, x):
        """Normalize ``x`` with the wrapped batch-norm layer."""
        normalized = self.bn(x)
        return normalized


class ReLU(nn.Module):
    """Parameter-free module applying ``F.relu`` element-wise."""

    def forward(self, x):
        """Return ``max(x, 0)`` element-wise (via ``F.relu``)."""
        activated = F.relu(x)
        return activated


class ChannelShffle(nn.Module):
    """Channel shuffle for 4-D ``(N, C, H, W)`` tensors.

    The channel axis is split into ``groups`` equal groups, the group axis and
    the within-group axis are swapped, and the result is flattened back, so
    consecutive output channels come from different groups. ``C`` must be
    divisible by ``groups``, otherwise the reshaping ``view`` raises.
    (Class name spelling kept as-is for compatibility.)
    """

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Return ``x`` with its channels interleaved across groups."""
        n, c, h, w = x.size()
        g = self.groups
        # (N, g, C/g, H, W): expose the group structure of the channel axis.
        grouped = x.view(n, g, c // g, h, w)
        # Swap group and within-group axes; contiguous() so view() works.
        swapped = grouped.transpose(1, 2).contiguous()
        return swapped.view(n, -1, h, w)


class MaxPool2d(nn.Module):
    """Thin wrapper that stores a :class:`torch.nn.MaxPool2d` as ``self.pool``.

    ``stride=None`` is passed straight through, letting PyTorch fall back to
    ``kernel_size`` as the stride.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        """Max-pool ``x`` with the wrapped layer."""
        pooled = self.pool(x)
        return pooled


class AdaptiveAvgPool2d(nn.Module):
    """Thin wrapper that stores a :class:`torch.nn.AdaptiveAvgPool2d` as ``self.pool``."""

    def __init__(self, output_size):
        super().__init__()
        pool = nn.AdaptiveAvgPool2d(output_size)
        self.pool = pool

    def forward(self, x):
        """Average-pool ``x`` down to the configured ``output_size``."""
        layer = self.pool
        return layer(x)


class Linear(nn.Module):
    """Thin wrapper that stores a :class:`torch.nn.Linear` as ``self.linear``."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Apply the wrapped affine map to ``x``."""
        projected = self.linear(x)
        return projected


class ShffleUnit(nn.Module):
    """Residual unit used by ``ShffleNet``.

    Two branches are computed from the same input, summed, passed through
    ReLU and finally channel-shuffled:

    * ``residual``: 1x1 conv -> BN -> ReLU -> 3x3 depthwise conv (``stride``)
      -> BN -> 1x1 conv -> BN.
    * ``shortcut``: 1x1 conv (``stride``) -> BN, projecting ``in_channels``
      to ``out_channels``.

    When ``stride == 2`` the bottleneck width is ``in_channels`` and the first
    1x1 conv is grouped with ``groups``; for any other stride the bottleneck
    width is ``out_channels // 4`` and every 1x1 conv is dense (``groups=1``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the depthwise conv and of the shortcut projection.
        groups: groups of the first 1x1 conv in stride-2 units and of the
            final channel shuffle. ``out_channels`` (and, for stride-2 units,
            ``in_channels``) must be divisible by it.

    NOTE(review): this layout appears to differ from the ShuffleNet v1 unit in
    the paper (grouped 1x1 convs in every unit, shuffle between them, identity
    / avg-pool shortcut). It is kept unchanged so existing weights still load;
    confirm whether the deviation is intentional.
    """

    def __init__(self, in_channels, out_channels, stride, groups):
        super(ShffleUnit, self).__init__()
        self.stride = stride
        self.groups = groups

        downsample = self.stride == 2
        # Stride-2 units keep the input width in the bottleneck and group the
        # first pointwise conv; other units squeeze to a quarter of the output.
        mid_channels = in_channels if downsample else out_channels // 4
        first_groups = groups if downsample else 1

        self.residual = nn.Sequential(
            Conv2d(in_channels, mid_channels, 1, 1, 0, groups=first_groups),
            BatchNorm2d(mid_channels),
            ReLU(),
            # Depthwise 3x3 conv: one filter per channel (groups == channels).
            Conv2d(mid_channels, mid_channels, 3, stride, 1, groups=mid_channels),
            BatchNorm2d(mid_channels),
            Conv2d(mid_channels, out_channels, 1, 1, 0, groups=1),
            BatchNorm2d(out_channels),
        )
        # The projection must use the same stride as the depthwise conv so both
        # branches have the same spatial size when summed. (Previously the
        # non-stride-2 path hard-coded stride 1 here, which crashed for any
        # stride other than 1 or 2.)
        self.shortcut = nn.Sequential(
            Conv2d(in_channels, out_channels, 1, stride, 0),
            BatchNorm2d(out_channels),
        )

        self.shuffle = ChannelShffle(groups)
        self.relu = ReLU()

    def forward(self, x):
        """Return ``shuffle(relu(residual(x) + shortcut(x)))``."""
        residual = self.residual(x)
        shortcut = self.shortcut(x)
        output = self.relu(residual + shortcut)
        return self.shuffle(output)


class ShffleNet(nn.Module):
    """ShuffleNet-style image classifier built from ``ShffleUnit`` blocks.

    Layout:
      * stem: 3x3 stride-2 conv (3 -> 24 channels) -> BN -> ReLU ->
        3x3 stride-2 max pool;
      * ``features``: three stages of 4, 8 and 4 units with output widths
        240, 480 and 960; the first unit of each stage uses stride 2;
      * head: 1x1 conv to 1024 channels -> BN -> ReLU -> global average
        pool -> flatten -> linear layer producing ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        groups: group count forwarded to every ``ShffleUnit``.

    NOTE(review): the stage widths are fixed and do not change with
    ``groups``; ``groups`` has to divide the stem width (24) and every stage
    width, i.e. it must be a divisor of 24.
    """

    def __init__(self, num_classes=1000, groups=3):
        super().__init__()
        self.groups = groups
        self.stage_repeats = [4, 8, 4]
        # Index 0 is an unused placeholder; index 1 is the stem width and
        # indices 2.. are the per-stage output widths.
        self.stage_out_channels = [-1, 24, 240, 480, 960]

        stem_channels = self.stage_out_channels[1]
        self.conv1 = Conv2d(3, stem_channels, 3, 2, 1)
        self.bn1 = BatchNorm2d(stem_channels)
        self.relu = ReLU()
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.features = nn.Sequential()
        channels = stem_channels
        stage_widths = self.stage_out_channels[2:]
        for stage, (repeats, width) in enumerate(zip(self.stage_repeats, stage_widths)):
            for unit in range(repeats):
                # Only the first unit of a stage downsamples.
                unit_stride = 2 if unit == 0 else 1
                block = ShffleUnit(channels, width, unit_stride, groups=self.groups)
                self.features.add_module('ShffleUnit_{}_{}'.format(stage, unit), block)
                channels = width

        self.conv5 = Conv2d(channels, 1024, 1, 1, 0)
        self.bn5 = BatchNorm2d(1024)
        self.avgpool = AdaptiveAvgPool2d((1, 1))
        self.fc = Linear(1024, num_classes)

    def forward(self, x):
        """Map a batch of 3-channel images ``x`` to ``(N, num_classes)`` logits."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.features,
            self.conv5, self.bn5, self.relu, self.avgpool,
        )
        for layer in pipeline:
            x = layer(x)
        # Pooled tensor is (N, 1024, 1, 1); flatten to (N, 1024) for the classifier.
        x = x.view(x.size(0), -1)
        return self.fc(x)


# Build a ShffleNet instance: 1000 output classes, 3 groups.
model = ShffleNet(groups=3, num_classes=1000)

# Random input batch: one 3-channel 224x224 image.
input_tensor = torch.randn(1, 3, 224, 224)

# Forward pass.
output = model(input_tensor)

print(output.size())  # prints: torch.Size([1, 1000])

torch.Size([1, 1000])
