In [1]:
import torch
import torchsummary
from torch import nn
from torch.nn import functional as F

In [2]:
class Inception(nn.Module):
    """GoogLeNet-style Inception block with four parallel branches.

    Every branch preserves the spatial size of its input, so the four branch
    outputs can be concatenated along the channel dimension:

    * branch 1: 1x1 conv -> ReLU
    * branch 2: 1x1 conv -> ReLU -> 3x3 conv (padding=1) -> ReLU
    * branch 3: 1x1 conv -> ReLU -> 5x5 conv (padding=2) -> ReLU
    * branch 4: 3x3 max-pool (stride=1, padding=1) -> 1x1 conv -> ReLU

    Args:
        in_channel: Number of channels of the input tensor.
        c1: Output channels of branch 1.
        c2: Pair ``(reduce, out)`` for branch 2: the 1x1 conv maps to
            ``c2[0]`` channels, the 3x3 conv maps to ``c2[1]`` channels.
        c3: Pair ``(reduce, out)`` for branch 3: the 1x1 conv maps to
            ``c3[0]`` channels, the 5x5 conv maps to ``c3[1]`` channels.
        c4: Output channels of branch 4's 1x1 projection conv.

    The output therefore has ``c1 + c2[1] + c3[1] + c4`` channels.
    """

    def __init__(self, in_channel, c1, c2, c3, c4) -> None:
        super().__init__()
        # Branch 1: single 1x1 convolution.
        self.branch1 = nn.Conv2d(in_channel, c1, 1)
        # Branch 2: 1x1 channel reduction, then 3x3 conv (padding keeps H, W).
        self.branch2_1 = nn.Conv2d(in_channel, c2[0], 1)
        self.branch2_2 = nn.Conv2d(c2[0], c2[1], 3, padding=1)
        # Branch 3: 1x1 channel reduction, then 5x5 conv (padding keeps H, W).
        self.branch3_1 = nn.Conv2d(in_channel, c3[0], 1)
        self.branch3_2 = nn.Conv2d(c3[0], c3[1], 5, padding=2)
        # Branch 4: stride-1 max-pool (keeps H, W), then 1x1 projection.
        self.branch4_1 = nn.MaxPool2d(3, stride=1, padding=1)
        self.branch4_2 = nn.Conv2d(in_channel, c4, 1)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dim=1.

        Args:
            x: Batched input tensor; concatenation along dim=1 assumes the
                layout (N, in_channel, H, W).

        Returns:
            Tensor of shape (N, c1 + c2[1] + c3[1] + c4, H, W).
        """
        b1 = F.relu(self.branch1(x))
        b2 = F.relu(self.branch2_2(F.relu(self.branch2_1(x))))
        b3 = F.relu(self.branch3_2(F.relu(self.branch3_1(x))))
        # No ReLU between the max-pool and the projection conv. Rectifying
        # the pooled raw input would discard negative activations before
        # the 1x1 conv sees them. The reference Inception block applies
        # ReLU only after the conv.
        b4 = F.relu(self.branch4_2(self.branch4_1(x)))
        return torch.cat([b1, b2, b3, b4], dim=1)

In [6]:
# Small smoke-test instance: 3 input channels (RGB), branch widths 3/3/5/4,
# summarised on a single 224x224 image on CPU.
net = Inception(in_channel=3, c1=3, c2=(2, 3), c3=(3, 5), c4=4)
torchsummary.summary(net, input_size=(3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 3, 224, 224]              12
            Conv2d-2          [-1, 2, 224, 224]               8
            Conv2d-3          [-1, 3, 224, 224]              57
            Conv2d-4          [-1, 3, 224, 224]              12
            Conv2d-5          [-1, 5, 224, 224]             380
         MaxPool2d-6          [-1, 3, 224, 224]               0
            Conv2d-7          [-1, 4, 224, 224]              16
Total params: 485
Trainable params: 485
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 8.80
Params size (MB): 0.00
Estimated Total Size (MB): 9.38
----------------------------------------------------------------
