* venv: torch2_0

In [1]:
import torch
import torch.nn.functional as F
import torchinfo

import os

In [2]:
class SampleModel(torch.nn.Module):
    """Small CNN classifier: Conv2d -> ReLU -> MaxPool2d -> flatten -> Linear -> Linear.

    With the default arguments this builds exactly the original model
    (3 input channels, 32x32 input, 10 outputs, 924,874 parameters).

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output logits produced by ``fc2``.
        input_size: Side length of the square input image (height == width);
            used to compute ``fc1.in_features``.

    Raises:
        ValueError: If ``input_size`` is too small to survive the conv + pool stages.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=32):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)
        # A 3x3 conv without padding shrinks each side by 2; the 2x2 max-pool halves it
        # (floor). For the default 32x32 input: 16 * 15 * 15 = 3600 features.
        pooled_side = (input_size - 2) // 2
        if pooled_side < 1:
            raise ValueError(f"input_size must be at least 4, got {input_size}")
        self.fc1 = torch.nn.Linear(in_features=16 * pooled_side * pooled_side, out_features=256)
        self.fc2 = torch.nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, input_size, input_size) tensor to (batch, num_classes) logits."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        # Flatten everything except the batch dim; unlike view(), this also works for
        # non-contiguous tensors and for an empty batch.
        x = torch.flatten(x, start_dim=1)
        # NOTE(review): there is no activation between fc1 and fc2, so these two layers
        # compose to a single affine map. Left unchanged to preserve the model as summarized.
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = SampleModel()

In [3]:
# Number of trainable parameters (elements of parameters with requires_grad=True).
def count_trainable_parameters(model):
    """Return the total element count of ``model``'s parameters whose ``requires_grad`` is True."""
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total

# Should agree with the "Trainable params" total reported by torchinfo.
count_trainable_parameters(model)

924874

In [4]:
# Display the model summary (using torchinfo).
#  - Without an explicit `device`, torchinfo.summary may pick cuda:0 and move `model`
#    there in place. Passing the model's current device keeps it where it already is.
torchinfo.summary(model=model, input_size=(1, 3, 32, 32), device=next(model.parameters()).device)

Layer (type:depth-idx)                   Output Shape              Param #
SampleModel                              [1, 10]                   --
├─Conv2d: 1-1                            [1, 16, 30, 30]           448
├─ReLU: 1-2                              [1, 16, 30, 30]           --
├─MaxPool2d: 1-3                         [1, 16, 15, 15]           --
├─Linear: 1-4                            [1, 256]                  921,856
├─Linear: 1-5                            [1, 10]                   2,570
Total params: 924,874
Trainable params: 924,874
Non-trainable params: 0
Total mult-adds (M): 1.33
Input size (MB): 0.01
Forward/backward pass size (MB): 0.12
Params size (MB): 3.70
Estimated Total Size (MB): 3.83

In [5]:
# Save the summary above to a text file.
output_path = './output/model_summary.txt'
# `or '.'` guards against a bare filename, where dirname() returns '' and makedirs('') fails.
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

# Pass the model's own device so torchinfo does not move `model` (e.g. to cuda:0).
# Build the summary before opening the file so a failure doesn't truncate an existing file.
model_summary = torchinfo.summary(model, (1, 3, 32, 32), device=next(model.parameters()).device)
# Write UTF-8 explicitly: the summary contains box-drawing characters (├─) that the
# platform default encoding may not be able to encode.
with open(output_path, mode='w', encoding='utf-8') as fw:
    fw.write(repr(model_summary))