In [1]:
import torch
import torch.nn as nn

In [2]:
def cal_params(model, trainable_only=True):
    """Print and return the number of parameters in ``model``.

    Args:
        model: an ``nn.Module`` (anything exposing ``.parameters()``).
        trainable_only: when True (the default, matching the original
            behaviour) only tensors with ``requires_grad=True`` are counted;
            pass False to include frozen parameters as well.

    Returns:
        int: the parameter count that was printed, so callers can assert on it.
    """
    # model.parameters() is a generator; summing it directly avoids building a list.
    params_num = sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )
    print(f'Number of parameters: {params_num}')
    return params_num

In [3]:
class Conv_Block(nn.Module):
    """Convolutional stage: 3x3 conv -> ReLU -> BatchNorm2d -> 2x2 max-pool.

    The convolution has no padding and the pool uses stride 2, so an
    (N, in_dim, H, W) batch becomes (N, out_dim, (H - 2) // 2, (W - 2) // 2).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layers = [
            nn.Conv2d(in_dim, out_dim, 3),  # kernel_size=3, stride 1, no padding
            nn.ReLU(),
            nn.BatchNorm2d(out_dim),        # normalises the post-activation maps
            nn.MaxPool2d(2, 2),             # halves the spatial resolution
        ]
        # Attribute name and layer order are what define the state_dict keys
        # (conv_block.0.*, conv_block.2.*), so both are kept as-is.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv_block(x)
        return out


class Linear_Block(nn.Module):
    """Fully connected stage: Linear -> ReLU -> BatchNorm1d.

    Maps an (N, in_dim) batch to (N, out_dim). Note that in training mode
    BatchNorm1d rejects a batch with a single sample.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.BatchNorm1d(out_dim),  # normalises the post-activation features
        ]
        # Attribute name and layer order define the state_dict keys
        # (linear_block.0.*, linear_block.2.*), so both are kept as-is.
        self.linear_block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.linear_block(x)
        return out


class Net(nn.Module):
    """Small CNN classifier: three conv stages, a 2x2 adaptive pool, then an MLP head.

    Args:
        in_channels: channels of the input image (default 3).
        num_classes: number of output logits (default 10).

    With the defaults the architecture is unchanged (40,970 trainable
    parameters). Each conv stage maps a spatial size H to (H - 2) // 2, so
    inputs must be at least 22x22; a 32x32 image reaches the pool as 2x2.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.conv_net = nn.Sequential(Conv_Block(in_channels, 8),
                                      Conv_Block(8, 16),
                                      Conv_Block(16, 32))

        # Registered once here rather than constructed on every forward call.
        # It has no parameters or buffers, so state_dict keys are unchanged.
        # Pooling to a fixed 2x2 makes the MLP input size independent of H, W.
        self.pool = nn.AdaptiveAvgPool2d(2)

        self.mlp = nn.Sequential(Linear_Block(32 * 2 * 2, 128),
                                 Linear_Block(128, 128),
                                 nn.Linear(128, num_classes))

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, num_classes) logits."""
        x = self.conv_net(x)
        x = self.pool(x)
        # flatten(1) keeps the batch dimension explicit instead of inferring it
        # with view(-1, ...); a feature-size mismatch then errors in the Linear.
        x = x.flatten(1)
        return self.mlp(x)

In [4]:
# Seed torch's global RNG so the randomly initialised weights are identical on
# every Restart & Run All (the parameter *count* reported below does not depend on it).
torch.manual_seed(0)
model = Net()

In [5]:
# Report the number of trainable parameters; the trailing ';' suppresses any Out[] echo.
cal_params(model=model);

Number of parameters: 40970
