[WIP] PyTorch Model Summary #3043

Closed
wants to merge 7 commits
86 changes: 86 additions & 0 deletions torch/nn/modules/module.py
@@ -551,6 +551,92 @@ def __repr__(self):
        tmpstr = tmpstr + ')'
        return tmpstr

    def summary(self, input_size):
        def register_hook(module):
            def hook(module, input, output):
                if module._modules:  # only want base layers
                    return
                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                module_idx = len(summary)
                m_key = '%s-%i' % (class_name, module_idx + 1)
                summary[m_key] = OrderedDict()
                summary[m_key]['input_shape'] = list(input[0].size())
                summary[m_key]['input_shape'][0] = None
                if output.__class__.__name__ == 'tuple':
                    summary[m_key]['output_shape'] = list(output[0].size())
                else:
                    summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = None

                params = 0
                # iterate through parameters and count num params
                for name, p in module._parameters.items():
                    if p is None:  # skip unset parameters, e.g. bias=False
                        continue
                    params += torch.numel(p.data)
                    summary[m_key]['trainable'] = p.requires_grad

                summary[m_key]['nb_params'] = params

            if not isinstance(module, torch.nn.Sequential) and \
               not isinstance(module, torch.nn.ModuleList) and \
               not (module == self):
                hooks.append(module.register_forward_hook(hook))

        # check if there are multiple inputs to the network
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
        else:
            x = Variable(torch.randn(1, *input_size))

        # create properties
        summary = OrderedDict()
        hooks = []
        # register hook
        self.apply(register_hook)
        # make a forward pass (unpack the inputs when the network takes several)
        if isinstance(x, list):
            self(*x)
        else:
            self(x)
        # remove these hooks
        for h in hooks:
            h.remove()

        # print out neatly
        def get_names(module, name, acc):
            if not module._modules:
                acc.append(name)
            else:
                for key in module._modules.keys():
                    p_name = key if name == "" else name + "." + key
                    get_names(module._modules[key], p_name, acc)
        names = []
        get_names(self, "", names)

        col_width = 25  # should be >= 12
        summary_width = 61

        def crop(s):
            return s[:col_width] if len(s) > col_width else s

        print('_' * summary_width)
        print('{0: <{3}} {1: <{3}} {2: <{3}}'.format(
            'Layer (type)', 'Output Shape', 'Param #', col_width))
        print('=' * summary_width)
        total_params = 0
        trainable_params = 0
        for (i, l_type), l_name in zip(enumerate(summary), names):
            d = summary[l_type]
            total_params += d['nb_params']
            if 'trainable' in d and d['trainable']:
                trainable_params += d['nb_params']
            print('{0: <{3}} {1: <{3}} {2: <{3}}'.format(
                crop(l_name + ' (' + l_type.rsplit('-', 1)[0] + ')'), crop(str(d['output_shape'])),
                crop(str(d['nb_params'])), col_width))
            if i < len(summary) - 1:
                print('_' * summary_width)
        print('=' * summary_width)
        print('Total params: ' + str(total_params))
        print('Trainable params: ' + str(trainable_params))
        print('Non-trainable params: ' + str(total_params - trainable_params))
        print('_' * summary_width)

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
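
For reference, a minimal usage sketch of the API this PR proposes (not part of the diff). It assumes the patch above is applied to nn.Module and that the pre-0.4 torch.autograd.Variable workflow is in use, as in the hunk; the model and input size below are purely illustrative.

import torch.nn as nn

# Illustrative model; any nn.Module would gain .summary() once the patch is applied.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

# input_size excludes the batch dimension; summary() builds a random batch of 1,
# runs a forward pass through hooked layers, and prints a Keras-style table of
# layer names, output shapes, and parameter counts.
model.summary((3, 32, 32))

With the widths hard-coded in the patch, the printed table is 61 characters wide and each cell is cropped to 25 characters.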