In [None]:
def conv_size_out(size_in, kern, stride=1, pad=0, dilation=1):
    """Output length along one spatial axis of ``nn.Conv1d`` / ``nn.Conv2d``.

    Implements ``floor((size_in + 2*pad - dilation*(kern - 1) - 1) / stride) + 1``,
    the formula given in the PyTorch convolution docs.

    Args:
        size_in: input length along the axis.
        kern: kernel size along the axis.
        stride: step between kernel applications (PyTorch default 1).
        pad: zero padding added to each side (default 0).
        dilation: spacing between kernel taps (default 1).

    Returns:
        The output length as an int.
    """
    return (size_in + 2 * pad - dilation * (kern - 1) - 1) // stride + 1

def avg_size_out(size_in, kern, stride=None, pad=0):
    """Output length along one spatial axis of ``nn.AvgPool1d`` / ``nn.AvgPool2d``.

    Implements ``floor((size_in + 2*pad - kern) / stride) + 1`` (i.e. the
    default ``ceil_mode=False``).

    Args:
        size_in: input length along the axis.
        kern: pooling window size.
        stride: window step; ``None`` means ``kern``, matching PyTorch's default.
        pad: implicit zero padding on each side (default 0).

    Returns:
        The output length as an int.
    """
    if stride is None:
        stride = kern
    return (size_in + 2 * pad - kern) // stride + 1

def max_size_out(size_in, kern, stride=None, pad=0, dilation=1):
    """Output length along one spatial axis of ``nn.MaxPool1d`` / ``nn.MaxPool2d``.

    Implements ``floor((size_in + 2*pad - dilation*(kern - 1) - 1) / stride) + 1``
    (i.e. the default ``ceil_mode=False``).

    Args:
        size_in: input length along the axis.
        kern: pooling window size.
        stride: window step; ``None`` means ``kern``, matching PyTorch's default.
        pad: implicit padding on each side (default 0).
        dilation: spacing between window elements (default 1).

    Returns:
        The output length as an int.
    """
    if stride is None:
        stride = kern
    return (size_in + 2 * pad - dilation * (kern - 1) - 1) // stride + 1


In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel 32x32 images and 10 output classes.

    ``self.model`` is one flat ``nn.Sequential``:
    three tanh-activated 5x5 convolutions (1 -> 6 -> 16 -> 120 channels), with
    2x2 average pooling after the first two, followed by a 120 -> 84 -> 10 fully
    connected head. For a 32x32 input the last convolution sees a 5x5 map, so it
    produces 120 maps of size 1x1 -- hence ``in_features=120`` after Flatten.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh(),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=10),
        ]
        # Keep everything in a single flat Sequential so parameters are named
        # model.0.weight, model.3.weight, ... and layers can be iterated directly.
        self.model = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return ``(logits, probs)``; ``probs`` is the softmax of ``logits`` over dim 1."""
        scores = self.model(x)
        return scores, F.softmax(scores, dim=1)

In [None]:
# the_channels = 1
# the_numf = 32
# the_pool = 4
# the_stride = 2

In [None]:
# model = nn.Sequential(
#     nn.AvgPool1d(kernel_size = the_pool, stride = the_stride),
#     nn.Flatten(),
#     nn.Linear(
#         in_features = the_channels * avg_size_out(the_numf, the_pool, the_stride),
#         out_features = 10
#     )
# )

In [None]:
from collections import OrderedDict

def register_hook(module):
    """Attach a forward hook to `module` that logs its class name and input shape.

    Container modules (``nn.Sequential`` / ``nn.ModuleList``) are skipped so only
    leaf layers are logged. Records go into the module-level ``summary``
    OrderedDict under keys ``"Layer-1"``, ``"Layer-2"``, ... in call order, and each
    hook handle is appended to the module-level ``hooks`` list so it can be
    removed afterwards.
    """
    def hook(module, inputs, output):
        summary["Layer-%i" % (len(summary) + 1)] = {
            'name': module.__class__.__name__,
            'input_shape': list(inputs[0].size()),
        }

    if not isinstance(module, (nn.Sequential, nn.ModuleList)):
        hooks.append(module.register_forward_hook(hook))

# create properties
summary = OrderedDict()
hooks = []

# `x` is not defined at top level in any cell above: build a probe batch of 32
# single-channel 32x32 images, the input size LeNet5's layer stack expects.
x = torch.rand(32, 1, 32, 32)

# NOTE: `model` is created in a later cell (`model = Model.model`); run it first.
model.apply(register_hook)

try:
    # Only shapes are needed, so skip building the autograd graph.
    with torch.no_grad():
        model(x)
finally:
    # Remove the hooks: otherwise each re-run of this cell adds another set, and
    # all of them write into the global `summary`, duplicating every layer.
    for handle in hooks:
        handle.remove()

summary

In [None]:
# Show the instance attributes (`__dict__` keys) stored on each layer,
# each followed by a blank line.
# NOTE: `Model` is created in a later cell (`Model = LeNet5()`); run it first.
for module in Model.model:
    print(module.__dict__.keys(), end="\n\n")

In [3]:
# Fix the RNG before building the network so its randomly initialised weights
# (and any torch.rand calls that follow) are reproducible on Restart & Run All.
torch.manual_seed(0)

# `Model` is the full LeNet5 (its forward returns logits and softmax probs);
# `model` is its inner nn.Sequential, which other cells iterate layer by layer.
Model = LeNet5()
model = Model.model

In [14]:
def extract(model, the_dim, the_batch, the_numf, the_channels):
    """Describe each child layer of `model` by running one dummy forward pass.

    Forward hooks record every leaf layer's class name and input shape. These
    records are then paired, in order, with the children of `model`. Each record
    is enriched with a few hyper-parameters read from the layer (see ``Search``).
    All hooks are removed again before returning.

    Args:
        model: an ``nn.Module`` container that is iterable (e.g. ``nn.Sequential``).
            Assumes each child is a leaf layer called exactly once, in order, by
            ``model(x)``.
        the_dim: 1 builds a ``(batch, channels, numf)`` input; any other value
            builds a square ``(batch, channels, numf, numf)`` input.
        the_batch: batch size of the dummy input.
        the_numf: length (and, for 2-D, width) of the dummy input.
        the_channels: number of channels of the dummy input.

    Returns:
        A list with one dict per child. Each dict has these keys:
        ``name``, ``input_shape`` and ``batch``.
        ``dim`` is 0 for rank-2 inputs, 1 for rank-3 inputs and 2 otherwise.
        ``numf`` is always present; ``channels`` is present for rank >= 3.
        It also holds any of ``filters``, ``kernel``, ``pool``, ``stride``,
        ``drop`` and ``units`` that apply to that layer type.

    Raises:
        AssertionError: if a recorded layer name does not match the child it is
            paired with (e.g. nested containers or reused modules).
    """
    # Layers whose `kernel_size` is reported under "pool" instead of "kernel".
    pool_layers = {'AvgPool1d', 'AvgPool2d', 'MaxPool1d', 'MaxPool2d'}

    # Output key used for each PyTorch attribute copied from a layer.
    attr_keys = {
        'stride': 'stride',
        'p': 'drop',
        'in_channels': 'channels',
        'out_channels': 'filters',
        'out_features': 'units',
    }

    def MyAttr(attr, class_name):
        """Return the output key for attribute `attr` of a `class_name` layer."""
        if attr == 'kernel_size':
            return 'pool' if class_name in pool_layers else 'kernel'
        return attr_keys[attr]

    # Attributes to copy from each supported layer type.
    Search = {
        'Conv1d': ['out_channels', 'kernel_size'],
        'Conv2d': ['out_channels', 'kernel_size'],
        'AvgPool1d': ['kernel_size', 'stride'],
        'AvgPool2d': ['kernel_size', 'stride'],
        'MaxPool1d': ['kernel_size', 'stride'],
        'MaxPool2d': ['kernel_size', 'stride'],
        'Dropout': ['p'],
        'Dropout2d': ['p'],
        'Linear': ['out_features'],
    }

    def give_size_summary_dct(model, dim, batch, channels, numf):
        """Run one forward pass on random input; return per-leaf-layer records."""
        from collections import OrderedDict

        if dim == 1:
            x = torch.rand(batch, channels, numf)
        else:
            x = torch.rand(batch, channels, numf, numf)

        summary = OrderedDict()
        hooks = []

        def hook(module, inputs, output):
            # Keys are "Layer-1", "Layer-2", ... in call order.
            summary["Layer-%i" % (len(summary) + 1)] = {
                'name': module.__class__.__name__,
                'input_shape': list(inputs[0].size()),
            }

        # Hook leaf layers only; containers would duplicate their children's records.
        for module in model.modules():
            if not isinstance(module, (nn.Sequential, nn.ModuleList)):
                hooks.append(module.register_forward_hook(hook))

        try:
            # Only shapes are needed, so do not build an autograd graph.
            with torch.no_grad():
                model(x)
        finally:
            # Detach the hooks so they do not accumulate on `model` across calls.
            for handle in hooks:
                handle.remove()

        return summary

    summ_dct = give_size_summary_dct(model, the_dim, the_batch, the_channels, the_numf)

    def give_info(layer, inp_info):
        """Merge one hook record with the hyper-parameters of its layer."""
        name = layer.__class__.__name__
        assert name == inp_info['name'], (name, inp_info['name'])

        shape = inp_info['input_shape']
        Info = {'name': name, 'input_shape': shape, 'batch': shape[0]}
        if len(shape) == 2:
            # (batch, features), e.g. the input of a Linear layer.
            Info['dim'] = 0
            Info['numf'] = shape[1]
        else:
            Info['dim'] = 1 if len(shape) == 3 else 2
            Info['channels'] = shape[1]
            Info['numf'] = shape[2]

        for attr in Search.get(name, ()):
            Info[MyAttr(attr, name)] = getattr(layer, attr)

        return Info

    return [give_info(layer, inp_info) for layer, inp_info in zip(model, summ_dct.values())]

In [15]:
extract(model, 2, 32, 32, 1)

[{'name': 'Conv2d',
  'input_shape': [32, 1, 32, 32],
  'batch': 32,
  'dim': 2,
  'channels': 1,
  'numf': 32,
  'filters': 6,
  'kernel': (5, 5)},
 {'name': 'Tanh',
  'input_shape': [32, 6, 28, 28],
  'batch': 32,
  'dim': 2,
  'channels': 6,
  'numf': 28},
 {'name': 'AvgPool2d',
  'input_shape': [32, 6, 28, 28],
  'batch': 32,
  'dim': 2,
  'channels': 6,
  'numf': 28,
  'pool': 2,
  'stride': 2},
 {'name': 'Conv2d',
  'input_shape': [32, 6, 14, 14],
  'batch': 32,
  'dim': 2,
  'channels': 6,
  'numf': 14,
  'filters': 16,
  'kernel': (5, 5)},
 {'name': 'Tanh',
  'input_shape': [32, 16, 10, 10],
  'batch': 32,
  'dim': 2,
  'channels': 16,
  'numf': 10},
 {'name': 'AvgPool2d',
  'input_shape': [32, 16, 10, 10],
  'batch': 32,
  'dim': 2,
  'channels': 16,
  'numf': 10,
  'pool': 2,
  'stride': 2},
 {'name': 'Conv2d',
  'input_shape': [32, 16, 5, 5],
  'batch': 32,
  'dim': 2,
  'channels': 16,
  'numf': 5,
  'filters': 120,
  'kernel': (5, 5)},
 {'name': 'Tanh',
  'input_shape': [3