Implement a PyTorch function similar to `model.summary()` in Keras? #2001
Comments
What kind of information that's not in |
Sure you can! Here's something to get you started. (Adapted from other code, so it's not tested in the wild.) Note that you HAVE to know the input size, and you HAVE to make a forward pass through the network. Those are the only requirements, I think.

```python
from collections import OrderedDict

import torch as th
import torch.nn as nn


def summary(input_size, model):
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = '%s-%i' % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1
            summary[m_key]['output_shape'] = list(output.size())
            summary[m_key]['output_shape'][0] = -1

            params = 0
            # weight/bias can exist but be None (e.g. nn.Linear(..., bias=False))
            if getattr(module, 'weight', None) is not None:
                params += module.weight.numel()
                summary[m_key]['trainable'] = module.weight.requires_grad
            if getattr(module, 'bias', None) is not None:
                params += module.bias.numel()
            summary[m_key]['nb_params'] = params

        # hook real layers only, not containers or the model itself
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and \
           module is not model:
            hooks.append(module.register_forward_hook(hook))

    # check if there are multiple inputs to the network;
    # a batch of 2 avoids BatchNorm errors in training mode
    if isinstance(input_size[0], (list, tuple)):
        x = [th.rand(2, *in_size) for in_size in input_size]
    else:
        x = [th.rand(2, *input_size)]

    # create properties
    summary = OrderedDict()
    hooks = []
    # register hook
    model.apply(register_hook)
    # make a forward pass (Variable wrappers are unnecessary since PyTorch 0.4)
    model(*x)
    # remove these hooks
    for h in hooks:
        h.remove()

    return summary
```

Here's an example of Keras summary, FYI:
@apaszke
Hello, I am new to PyTorch and to contributing to it. Can I try this one out?
@aditya1702 yeah sure. As far as I know, we don't have any function to see the output size, or anything similar to Keras `model.summary`.
I think it would be useful to add hooks to the output of `str(model)`.
After @ncullen93 posted his code, I added a model summary to my local build of PyTorch. I recently decided to clean the code up a little and make a PR; hopefully it can be merged soon.
Hi @isaykatsman, the implementation is nice. Thanks.
@ncullen93
I made some modifications to the print-out style to match Keras. Hope it helps. (2018/05/18 update)
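The modified print-out itself is not preserved on this page. As a rough sketch of the idea (the name `format_summary` and the column widths are made up here, not the code from that comment), an `OrderedDict` shaped like the one the hook-based `summary()` returns can be rendered as a Keras-style table:

```python
from collections import OrderedDict


def format_summary(summary, width=64):
    """Render a layer dict (like the one the hook-based summary() builds) Keras-style.

    Hypothetical helper: each value must hold 'output_shape' (a list) and
    'nb_params' (an int or 0-dim tensor); 'trainable' is optional.
    """
    row = '{:>20}  {:>25}  {:>15}'
    lines = ['-' * width, row.format('Layer (type)', 'Output Shape', 'Param #'), '=' * width]
    total, trainable = 0, 0
    for layer, info in summary.items():
        n = int(info['nb_params'])
        lines.append(row.format(layer, str(info['output_shape']), '{:,}'.format(n)))
        total += n
        if info.get('trainable', False):
            trainable += n
    lines += ['=' * width,
              'Total params: {:,}'.format(total),
              'Trainable params: {:,}'.format(trainable),
              'Non-trainable params: {:,}'.format(total - trainable),
              '-' * width]
    return '\n'.join(lines)


if __name__ == '__main__':
    layers = OrderedDict()
    layers['Linear-1'] = {'output_shape': [-1, 20], 'nb_params': 220, 'trainable': True}
    layers['ReLU-2'] = {'output_shape': [-1, 20], 'nb_params': 0}
    print(format_summary(layers))
```

Returning a string rather than printing directly makes it easy to send the table to a log file instead of stdout.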
This doesn't work for LSTMs, whose output is a tuple.
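One way around tuple outputs is to record shapes recursively instead of calling `.size()` on the top-level output. A minimal sketch, assuming the goal is only to log output shapes (the helpers `shapes_of` and `record_shapes` are made up for illustration): `nn.LSTM` returns `(output, (h_n, c_n))`, so the hook walks nested tuples and lists.

```python
import torch
import torch.nn as nn


def shapes_of(obj):
    """Shape of a tensor, or nested lists of shapes for nested tuples/lists of tensors."""
    if isinstance(obj, torch.Tensor):
        return list(obj.shape)
    if isinstance(obj, (list, tuple)):
        return [shapes_of(o) for o in obj]
    return None  # e.g. None entries or non-tensor outputs


def record_shapes(model, *inputs):
    """Run one forward pass and collect the output shapes of every leaf module."""
    shapes = {}
    hooks = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            key = name or type(module).__name__  # the root module has an empty name

            def hook(mod, inp, out, key=key):
                shapes[key] = shapes_of(out)

            hooks.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in hooks:
            h.remove()
    return shapes


if __name__ == '__main__':
    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
    x = torch.randn(4, 10, 8)  # (batch, seq_len, features)
    print(record_shapes(lstm, x))
```

For the LSTM above, the recorded entry holds the output shape `[4, 10, 16]` followed by the `h_n`/`c_n` shapes `[2, 4, 16]`; `batch_first` does not change the layout of the hidden states.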
@ncullen93 It does not work with `nn.Sequential()`. Do you have a solution for this?
I like the tree format of the builtin

```python
import sys

from torch.nn.modules.module import _addindent


def summary(model, file=sys.stderr):
    def module_repr(model):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = model.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        total_params = 0
        for key, module in model._modules.items():
            mod_str, num_params = module_repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
            total_params += num_params
        lines = extra_lines + child_lines

        for name, p in model._parameters.items():
            # skip parameters registered as None (e.g. bias=False);
            # numel() also handles 0-dim parameters
            if p is not None:
                total_params += p.numel()

        main_str = model._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        if file is sys.stderr:
            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
        else:
            main_str += ', {:,} params'.format(total_params)
        return main_str, total_params

    string, count = module_repr(model)
    if file is not None:
        print(string, file=file)
    return count
```
Total parameters and total FLOPs.
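For reference, both numbers can be estimated roughly without a full summary table. The sketch below assumes only `nn.Linear` and `nn.Conv2d` matter and counts multiply-accumulates (MACs, roughly half the FLOPs) rather than exact FLOPs. The helper name `count_params_and_macs` is made up, and every other layer type (BatchNorm, activations, pooling, etc.) is ignored.

```python
import torch
import torch.nn as nn


def count_params_and_macs(model, input_size):
    """Return (total_params, trainable_params, macs) for one forward pass.

    MACs are estimated only for nn.Linear and nn.Conv2d; other layers count as 0.
    input_size excludes the batch dimension; a batch of 1 is used.
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    macs = 0

    def hook(module, inputs, output):
        nonlocal macs
        if isinstance(module, nn.Linear):
            # each output element is a dot product of length in_features
            macs += output.numel() * module.in_features
        elif isinstance(module, nn.Conv2d):
            # each output element sums over (in_channels / groups) * kh * kw inputs
            kh, kw = module.kernel_size
            macs += output.numel() * (module.in_channels // module.groups) * kh * kw

    hooks = [m.register_forward_hook(hook)
             for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))
    finally:
        for h in hooks:
            h.remove()
    return total, trainable, macs


if __name__ == '__main__':
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                        nn.Flatten(), nn.Linear(8 * 4 * 4, 10))
    print(count_params_and_macs(net, (3, 4, 4)))
```

Bias additions are left out of the MAC count; they are usually negligible next to the multiplies.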
Can I draw the forward-pass graph, similar to torchviz?
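torchviz renders the autograd graph through Graphviz. A text-only alternative, assuming PyTorch 1.8 or newer and a model whose `forward` has no data-dependent control flow, is `torch.fx.symbolic_trace`, which records the forward pass as a graph of nodes. This is FX, a different tool from torchviz, and `TinyNet` is a made-up example model:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyNet(nn.Module):
    """Hypothetical example model."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


traced = symbolic_trace(TinyNet())
# Each node is one step of the forward pass: placeholder (input),
# call_module (submodule call), call_function (e.g. torch.relu), output.
for node in traced.graph.nodes:
    print(node.op, node.name, node.target)
```

The traced object is itself a module, so it can still be called like the original model.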
NameError: name 'th' is not defined
@zamnius, here is the complete list of imports:
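The list itself did not survive on this page. Judging only from the names the hook-based snippet uses (`th`, `nn`, `Variable`, `OrderedDict`), a set of imports that defines all of them would be the following; this is an assumption, not the poster's exact list, and it takes `th` to be an alias for `torch`:

```python
from collections import OrderedDict

import torch as th                   # the snippet refers to torch as `th`
import torch.nn as nn                # nn.Sequential, nn.ModuleList
from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors also work
```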
@ShuvenduBikash, this is probably what you figured out:
I don't remember why this was originally in the low-priority bin, but I'm tentatively marking this as triage review for discussion on whether or not this should be high-pri.
Removing high-pri for now, since this is provided by pytorch-summary, which appears to work for most cases described here (including …). As of now, I don't see huge benefits to maintaining this within PyTorch core and would prefer to keep it in the separate repo, but I'm open to being convinced otherwise if it doesn't support your use case, etc.
Hi, I'm new to PyTorch.
Hi @tuiiitendinh,
`model.summary` in Keras gives a very fine visualization of your model, and it's very convenient when it comes to debugging the network. Can we try to implement something like it in PyTorch?

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry
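A partial sketch that needs no input size at all: per-layer parameter counts gathered from `named_modules()`. Output shapes, which Keras also shows, would additionally need a forward pass with hooks, as the hook-based snippets in this thread do. The function names `param_table` and `print_param_table` are illustrative, not an existing PyTorch API.

```python
import torch.nn as nn


def param_table(model):
    """Return (name, type, #params, #trainable) rows without running the model.

    Leaf modules are always listed; containers only if they own parameters
    directly. A shared (tied) parameter is counted once per module holding it,
    so totals can overcount in that case.
    """
    rows = []
    for name, module in model.named_modules():
        own = list(module.parameters(recurse=False))  # parameters registered on this module only
        is_leaf = next(module.children(), None) is None
        if not is_leaf and not own:
            continue
        n = sum(p.numel() for p in own)
        n_trainable = sum(p.numel() for p in own if p.requires_grad)
        rows.append((name or type(module).__name__, type(module).__name__, n, n_trainable))
    return rows


def print_param_table(model):
    rows = param_table(model)
    print('{:<20} {:<15} {:>12}'.format('Layer', 'Type', 'Param #'))
    print('=' * 49)
    for name, kind, n, _ in rows:
        print('{:<20} {:<15} {:>12,}'.format(name, kind, n))
    print('=' * 49)
    print('Total params: {:,}'.format(sum(r[2] for r in rows)))
    print('Trainable params: {:,}'.format(sum(r[3] for r in rows)))


if __name__ == '__main__':
    print_param_table(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2, bias=False)))
```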