
Implement similar PyTorch function as model.summary() in keras? #2001

Open · iabhi7 opened this issue Jul 7, 2017 · 22 comments
Labels
feature (A request for a proper, new feature.) · function request (A request for a new function or the addition of new arguments/modes to an existing function.) · module: nn (Related to torch.nn) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

iabhi7 commented Jul 7, 2017

model.summary() in Keras gives a very nice overview of your model, and it's very convenient when it comes to debugging the network. Can we implement something like it in PyTorch?

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry

apaszke (Contributor) commented Jul 7, 2017

What kind of information that's not in str(model) would you like to see? Output shapes?
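
(For reference, a minimal sketch of what str(model) already shows: the module tree and constructor arguments, but no output shapes or parameter counts. The toy model below is only an example, and the printed output is reproduced approximately.)

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Linear(16, 10))
    print(model)
    # Prints (approximately):
    # Sequential(
    #   (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    #   (1): ReLU()
    #   (2): Linear(in_features=16, out_features=10, bias=True)
    # )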

ncullen93 commented Jul 7, 2017

Sure you can! Here's something to get you started. (Adapted from other code, so it's not tested in the wild.) Note that you HAVE to know the input size, and you HAVE to make a forward pass through the network. Those are the only requirements, I think.

    def summary(input_size, model):
        def register_hook(module):
            def hook(module, input, output):
                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                module_idx = len(summary)

                m_key = '%s-%i' % (class_name, module_idx+1)
                summary[m_key] = OrderedDict()
                summary[m_key]['input_shape'] = list(input[0].size())
                summary[m_key]['input_shape'][0] = -1
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = -1

                params = 0
                if hasattr(module, 'weight'):
                    params += th.prod(th.LongTensor(list(module.weight.size())))
                    if module.weight.requires_grad:
                        summary[m_key]['trainable'] = True
                    else:
                        summary[m_key]['trainable'] = False
                if hasattr(module, 'bias'):
                    params +=  th.prod(th.LongTensor(list(module.bias.size())))
                summary[m_key]['nb_params'] = params
                
            if not isinstance(module, nn.Sequential) and \
               not isinstance(module, nn.ModuleList) and \
               not (module == model):
                hooks.append(module.register_forward_hook(hook))
        
        # check if there are multiple inputs to the network
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(th.rand(1,*in_size)) for in_size in input_size]
        else:
            x = Variable(th.rand(1,*input_size))

        # create properties
        summary = OrderedDict()
        hooks = []
        # register hook
        model.apply(register_hook)
        # make a forward pass
        model(x)
        # remove these hooks
        for h in hooks:
            h.remove()

        return summary
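
(A minimal usage sketch, not part of the original comment: the imports the snippet relies on and an example call. The torchvision AlexNet is only an example; Variable reflects the pre-0.4 API used above.)

    from collections import OrderedDict

    import torch as th
    import torch.nn as nn
    from torch.autograd import Variable  # pre-0.4 API, as used in the snippet above
    from torchvision import models

    model = models.alexnet()
    # input_size excludes the batch dimension; the helper prepends a batch of 1
    layer_info = summary((3, 224, 224), model)
    for name, info in layer_info.items():
        print(name, info['output_shape'], int(info['nb_params']))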

Here's an example of Keras summary, FYI:

[Screenshot: example Keras model.summary() output]

iabhi7 (Author) commented Jul 11, 2017

@apaszke str(model) works well for most cases but misses some things; output_shape is one of them.
@ncullen93 The code looks good and works well. Can we turn this into a PR with some code refactoring?

soumith added the enhancement and todo (Not as important as medium or high priority tasks, but we will work on these.) labels Jul 13, 2017
soumith added this to Uncategorized in Issue Status Aug 23, 2017
soumith moved this from Uncategorized to Low Priority in Issue Status Aug 23, 2017
soumith added this to usability / simple-fixes in Issue Categories Sep 11, 2017
aditya1702 commented:

Hello, I am new to pytorch and contributing to it. Can I try this one out?

iabhi7 (Author) commented Sep 19, 2017

@aditya1702 Yeah, sure. As far as I know, we don't have any function yet to see the output_size or anything similar to Keras model.summary.

EtienneDesticourt commented:

I think it would be useful to add hooks to the output of str(model).
I use str(model) in my logging system, and currently I can't tell whether my models use weight normalization or not, since it's applied as a hook.
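
(A possible workaround, sketched here rather than taken from the thread: nn.utils.weight_norm registers a WeightNorm forward pre-hook on the module, so you can walk the module tree and report which submodules carry one.)

    import torch.nn as nn
    from torch.nn.utils.weight_norm import WeightNorm  # hook class installed by nn.utils.weight_norm

    def weight_normed_modules(model):
        # str(model) doesn't surface this because the pre-hook leaves __repr__ untouched.
        return [name for name, module in model.named_modules()
                if any(isinstance(h, WeightNorm) for h in module._forward_pre_hooks.values())]

    m = nn.Sequential(nn.utils.weight_norm(nn.Linear(10, 10)), nn.ReLU())
    print(weight_normed_modules(m))  # ['0']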

isaykatsman commented Oct 9, 2017

After @ncullen93 posted his code, I added model summary to my local build of pytorch. Recently decided to clean the code up a little and make a PR. Can hopefully merge soon.

iabhi7 (Author) commented Oct 10, 2017

Hi @isaykatsman the implementation is nice. Thanks.

HTLife commented Mar 15, 2018

@ncullen93
Is that 'Variable' at the line x = Variable(th.rand(1,*input_size)) from a package?
How can I deal with this error?

global name 'Variable' is not defined

apaszke (Contributor) commented Mar 15, 2018

from torch.autograd import Variable
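
(Side note: on PyTorch 0.4 and later, Variable has been merged into Tensor, so the wrapper can be dropped; a minimal sketch with an example shape:)

    import torch as th

    input_size = (3, 224, 224)   # example shape, excluding the batch dimension
    x = th.rand(1, *input_size)  # no Variable(...) wrapping needed on >= 0.4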

HTLife commented Mar 16, 2018

I did some modifications to the printout style to match the style of Keras. Hope it helps.

[Screenshots: Keras-style summary printout produced by the modified code]

(2018/05/18 update)
sksq96 reorganized the code into the Python package sksq96/pytorch-summary.

backpropper commented Apr 26, 2018

This doesn't work for LSTMs where the output is a tuple.
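
(One possible adjustment, sketched here rather than taken from the thread: record shapes recursively in the hook so tuple outputs, such as nn.LSTM's (output, (h_n, c_n)), don't break it.)

    def _shape(out):
        # Recursively collect shapes so tuple/list outputs (e.g. from nn.LSTM) are handled.
        if isinstance(out, (list, tuple)):
            return [_shape(o) for o in out]
        return list(out.size())

    # inside the forward hook, instead of list(output.size()):
    # summary[m_key]['output_shape'] = _shape(output)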

ShuvenduRoy commented May 24, 2018

@ncullen93 It does not work with the nn.Sequential() container. Do you have a solution for this?

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)

Traceback (most recent call last):
  File ".\test.py", line 118, in <module>
    summary(model, (3, 28, 28))
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torchsummary\torchsummary.py", line 57, in summary
    model(x)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File ".\test.py", line 93, in forward
    return self.model(img)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 91, in forward
    input = module(input)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    hook_result = hook(self, input, result)
  File "C:\Users\bikas\Anaconda3\lib\site-packages\torchsummary\torchsummary.py", line 26, in hook
    params += torch.prod(torch.LongTensor(list(module.weight.size())))
AttributeError: 'NoneType' object has no attribute 'size'
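
(For context, an observation that is not in the original report: the crash comes from nn.InstanceNorm2d, whose affine argument defaults to False, so module.weight exists but is None. A minimal reproduction of that attribute state:)

    import torch.nn as nn

    m = nn.InstanceNorm2d(64)    # affine defaults to False
    print(hasattr(m, 'weight'))  # True
    print(m.weight is None)      # True -> module.weight.size() in the hook fails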

jongwook (Contributor) commented Jul 17, 2018

I like the tree format of the builtin __repr__() but wanted to show the number of parameters, like:

[Screenshot: module tree repr annotated with per-module parameter counts]

import sys
from functools import reduce

from torch.nn.modules.module import _addindent

def summary(model, file=sys.stderr):
    def repr(model):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = model.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        total_params = 0
        for key, module in model._modules.items():
            mod_str, num_params = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
            total_params += num_params
        lines = extra_lines + child_lines

        for name, p in model._parameters.items():
            # guard against unregistered parameters (e.g. bias=None) and scalar shapes
            if p is not None:
                total_params += reduce(lambda x, y: x * y, p.shape, 1)

        main_str = model._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        if file is sys.stderr:
            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
        else:
            main_str += ', {:,} params'.format(total_params)
        return main_str, total_params

    string, count = repr(model)
    if file is not None:
        print(string, file=file)
    return count
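
(A usage sketch, assuming torchvision is installed; resnet18 is only an example, and the parameter count shown is approximate.)

    import sys
    from torchvision import models

    model = models.resnet18()
    total = summary(model)           # prints the annotated tree to stderr, with colors
    summary(model, file=sys.stdout)  # same tree without ANSI colors
    print('{:,} params in total'.format(total))  # roughly 11.7 million for resnet18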

PistonY commented Mar 27, 2019

What kind of information that's not in str(model) would you like to see? Output shapes?

Total parameters and total FLOPS.
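
(For the parameter part, a minimal sketch that needs no extra tooling; FLOP counting is more involved and is usually done with per-module hooks or third-party tools. The toy model is only an example.)

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('total={:,} trainable={:,}'.format(total, trainable))  # total=8,906 trainable=8,906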

ezyang added the feature (A request for a proper, new feature.) label and removed the enhancement label Apr 1, 2019
houseroad added a commit to houseroad/pytorch that referenced this issue May 13, 2019
facebook-github-bot pushed a commit that referenced this issue May 13, 2019 (#20443)
ray-lee-94 commented:

Can I draw the forward-pass graph, similar to torchviz?
If I could summarize the model and get the visualization, that would be good.
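
(For what it's worth, torchviz can already render the autograd graph from a forward pass; a minimal sketch, assuming torchviz and the graphviz binaries are installed, with a toy model as the example:)

    import torch
    import torch.nn as nn
    from torchviz import make_dot  # pip install torchviz

    model = nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 1))
    y = model(torch.randn(1, 8))
    dot = make_dot(y, params=dict(model.named_parameters()))
    dot.render('model_graph', format='png')  # writes model_graph.png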

jeffreyksmithjr added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Aug 19, 2019
zamnius commented May 10, 2020

NameError: name 'th' is not defined
For example in the line: x = [Variable(th.rand(1,*in_size)).type(dtype) for in_size in input_size]
What is missing?

hmrishavbandy commented:

@zamnius , here is the complete list of imports:

from torch.autograd import Variable
import torch as th
from torch import nn as nn
from collections import OrderedDict
from model import UNet3D

@ShuvenduBikash , this is probably what you figured out:
Lines 15-22 should look like this, since not all layers have a bias.

                if hasattr(module,'weight') and module.weight is not None:
                    params += th.prod(th.LongTensor(list(module.weight.size())))
                    if module.weight.requires_grad:
                        summary[m_key]['trainable'] = True
                    else:
                        summary[m_key]['trainable'] = False
                if hasattr(module,'bias') and module.bias is not None:
                    params +=  th.prod(th.LongTensor(list(module.bias.size())))

zou3519 (Contributor) commented Jan 22, 2021

I don't remember why this was originally in the low-priority bin, but I'm tentatively marking this as triage review for discussion on whether or not this should be high-pri

zou3519 added the triage review label and removed the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jan 22, 2021
mruberry added the function request (A request for a new function or the addition of new arguments/modes to an existing function.), high priority, module: nn (Related to torch.nn), and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels, and removed the todo (Not as important as medium or high priority tasks, but we will work on these.) and triage review labels Jan 25, 2021
jbschlosser (Contributor) commented:

Removing high-pri for now since this is provided by pytorch-summary, which appears to work for most cases described here (including nn.Sequential models).

As of now, I don't see huge benefits to maintaining this within PyTorch core and would prefer to keep it in the separate repo, but I'm open to being convinced otherwise if it doesn't support your use case, etc.
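
(For reference, a typical-usage sketch for that package; the call signature matches the traceback earlier in the thread, while the torchvision model and the device keyword are assumptions that may vary between versions.)

    from torchsummary import summary  # pip install torchsummary
    from torchvision import models

    model = models.vgg16()
    # Prints a Keras-style per-layer table with output shapes and parameter counts.
    summary(model, (3, 224, 224), device='cpu')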

tuiiitendinh commented:

Hi, I'm new to PyTorch.
Is there any way to plot the graph of the model, as well as export the summary of the model, without initializing an instance of the model? Like in Keras, I would just load the model and then visualize it.

mert-kurttutan commented:

Hi, I'm new to PyTorch. Is there any way to plot the graph of the model, as well as export the summary of the model, without initializing an instance of the model? Like in Keras, I would just load the model and then visualize it.

Hi @tuiiitendinh,
I created a package to visualize PyTorch models; you can see it here.
There are a few notebooks to get familiar with it, in addition to the documentation.
Any feedback/issues/PRs are much appreciated.
