Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New functionality: use torchsummary to build pytorch model with scalable input shape #80

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Keras style `model.summary()` in PyTorch
[![PyPI version](https://badge.fury.io/py/torchsummary.svg)](https://badge.fury.io/py/torchsummary)

Keras has a neat API to view the visualization of the model, which is very helpful while debugging your network. Here is barebones code that tries to mimic the same in PyTorch. The aim is to provide information complementary to what `print(your_model)` provides in PyTorch.
Keras has a neat API to view the visualization of the model, which is very helpful while debugging your network. Here is barebones code that tries to mimic the same in PyTorch. The aim is to provide information complementary to what `print(your_model)` provides in PyTorch. (**New functionality**) The main function `summary` (`from torchsummary import summary`) can also be used to infer the output shape of a PyTorch model. Thus, it provides a way to build a PyTorch model that supports any input shape, as in Keras (see an [example](#scalable) below).

### Usage

Expand Down Expand Up @@ -191,6 +191,100 @@ Estimated Total Size (MB): 0.78
----------------------------------------------------------------
```

### Build a PyTorch model with a scalable input shape (like Keras)<a name="scalable"></a>

```python

import torch
import torch.nn as nn
from torchsummary import summary

class AutoEncoder(nn.Module):
    """
    Convolutional autoencoder whose fully connected layers are sized from the
    input image shape, so the model can be built for any input shape (as in Keras).

    :param img_shape: (tuple, channel first) input image shape as
        (channels, height, width); ``img_shape[0]`` is used as the number of
        input/output channels
    :param state_dim: (int) latent state dimension
    """

    # Number of feature maps produced by every convolutional stage.
    CONV_CHANNELS = 64

    def __init__(self, img_shape=(3, 224, 224), state_dim=3):
        super().__init__()
        self.state_dim = state_dim
        self.img_shape = img_shape
        channels = self.CONV_CHANNELS

        self.encoder_conv = nn.Sequential(
            nn.Conv2d(self.img_shape[0], channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        # torchsummary infers the encoder's output shape for the given input
        # shape, which lets the fully connected layers below be sized
        # automatically instead of being hard-coded for one image size.
        # show=False suppresses the printed table.
        outshape = summary(self.encoder_conv, img_shape, show=False)  # [-1, channels, height, width]
        self.img_height, self.img_width = outshape[-2:]
        flat_features = self.img_height * self.img_width * channels

        self.encoder_fc = nn.Sequential(
            nn.Linear(flat_features, state_dim)
        )

        self.decoder_fc = nn.Sequential(
            nn.Linear(state_dim, flat_features)
        )

        # NOTE(review): the spatial size of the reconstruction follows from the
        # transposed-convolution arithmetic below; it is presumably not
        # guaranteed to equal the input size for every img_shape -- confirm.
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(channels, channels, kernel_size=3, stride=2),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            nn.ConvTranspose2d(channels, channels, kernel_size=3, stride=2),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            nn.ConvTranspose2d(channels, channels, kernel_size=3, stride=2),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            nn.ConvTranspose2d(channels, channels, kernel_size=3, stride=2),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            nn.ConvTranspose2d(channels, self.img_shape[0], kernel_size=4, stride=2),
            nn.Tanh()
        )

    def encode(self, x):
        """
        Encode image to latent state

        :param x: batch of images, first dimension is the batch
        :return: output of ``encoder_fc`` applied to the flattened conv features
        """
        encoded = self.encoder_conv(x)
        # Flatten everything except the batch dimension for the linear layer.
        encoded = encoded.view(encoded.size(0), -1)
        return self.encoder_fc(encoded)

    def decode(self, x):
        """
        Decode latent state to image

        :param x: batch of latent states, first dimension is the batch
        :return: output of ``decoder_conv`` applied to the reshaped features
        """
        decoded = self.decoder_fc(x)
        # Restore the (batch, channels, height, width) layout recorded in __init__.
        decoded = decoded.view(x.size(0), self.CONV_CHANNELS, self.img_height, self.img_width)
        return self.decoder_conv(decoded)

    def forward(self, x):
        """
        Run the full autoencoder: encode ``x`` and decode the latent state.
        """
        reconstruct = self.decode(self.encode(x))
        return reconstruct


# Build the autoencoder for a 3-channel 128x128 input (channels first, matching
# how AutoEncoder reads img_shape[0]) and print the layer-by-layer summary.
img_shape = (3,128,128)
model = AutoEncoder(img_shape=img_shape, state_dim=100)
summary(model, img_shape)

```


### References
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="torchsummary",
version="1.5.1",
version="1.5.2",
description="Model summary in PyTorch similar to `model.summary()` in Keras",
url="https://github.com/sksq96/pytorch-summary",
author="Shubham Chandel @sksq96",
Expand Down
71 changes: 40 additions & 31 deletions torchsummary/torchsummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np


def summary(model, input_size, batch_size=-1, device="cuda"):
def summary(model, input_size, batch_size=-1, device="cpu", show=True):

def register_hook(module):

Expand Down Expand Up @@ -40,24 +40,32 @@ def hook(module, input, output):
and not (module == model)
):
hooks.append(module.register_forward_hook(hook))

device = device.lower()
assert device in [
"cuda",
"cpu",
], "Input device is not valid, please specify 'cuda' or 'cpu'"

if device == "cuda" and torch.cuda.is_available():
dtype = torch.cuda.FloatTensor
else:
dtype = torch.FloatTensor
def cuda_device_valid(device_str):
    """
    Check that ``device_str`` names an existing CUDA device as ``'cuda:n'``.

    :param device_str: (str) device specifier expected in the form 'cuda:n',
        where n is a non-negative integer device index
    :return: (bool) True if the string is exactly 'cuda:<n>' and n is smaller
        than ``torch.cuda.device_count()``; False otherwise (a diagnostic
        message is printed in that case)
    """
    format_hint = "CUDA device should have form like 'cuda:0', 'cuda:n', etc. (n is an integer)"

    # Require exactly "<cuda>:<index>"; strings such as "cuda0:1" or "cuda:0:1"
    # must not be accepted just because their last field is an integer.
    prefix, separator, index_str = device_str.partition(":")
    if prefix != "cuda" or not separator:
        print(format_hint)
        return False

    try:
        device_index = int(index_str)
    except ValueError:
        # Only a malformed index is expected here; other errors should surface.
        print(format_hint)
        return False

    total_gpu_num = torch.cuda.device_count()
    # Reject negative indices as well as indices past the last device.
    if 0 <= device_index < total_gpu_num:
        return True

    print("Cuda device '{}' doesn't exist. Found {} GPU(s)".format(device_str, total_gpu_num))
    return False
if isinstance(device, str):
device = device.lower()
assert device in [
"cuda",
"cpu",
] or cuda_device_valid(device), "Input device is not valid, please specify 'cpu' or 'cuda' or 'cuda:n'"

# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]

# batch_size of 2 for batchnorm
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
x = [torch.rand(2, *in_size).to(device) for in_size in input_size]
# print(type(x[0]))

# create properties
Expand All @@ -74,11 +82,11 @@ def hook(module, input, output):
# remove these hooks
for h in hooks:
h.remove()

print("----------------------------------------------------------------")
line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
print(line_new)
print("================================================================")
if show:
print("----------------------------------------------------------------")
line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
print(line_new)
print("================================================================")
total_params = 0
total_output = 0
trainable_params = 0
Expand All @@ -94,22 +102,23 @@ def hook(module, input, output):
if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True:
trainable_params += summary[layer]["nb_params"]
print(line_new)
if show:
print(line_new)

# assume 4 bytes/number (float on cuda).
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size

print("================================================================")
print("Total params: {0:,}".format(total_params))
print("Trainable params: {0:,}".format(trainable_params))
print("Non-trainable params: {0:,}".format(total_params - trainable_params))
print("----------------------------------------------------------------")
print("Input size (MB): %0.2f" % total_input_size)
print("Forward/backward pass size (MB): %0.2f" % total_output_size)
print("Params size (MB): %0.2f" % total_params_size)
print("Estimated Total Size (MB): %0.2f" % total_size)
print("----------------------------------------------------------------")
# return summary
if show:
print("================================================================")
print("Total params: {0:,}".format(total_params))
print("Trainable params: {0:,}".format(trainable_params))
print("Non-trainable params: {0:,}".format(total_params - trainable_params))
print("----------------------------------------------------------------")
print("Input size (MB): %0.2f" % total_input_size)
print("Forward/backward pass size (MB): %0.2f" % total_output_size)
print("Params size (MB): %0.2f" % total_params_size)
print("Estimated Total Size (MB): %0.2f" % total_size)
print("----------------------------------------------------------------")
return summary[layer]["output_shape"]