Skip to content

Commit

Permalink
Add size information
Browse files Browse the repository at this point in the history
Added information for estimating the total size of the model. Estimates taken from here: http://jacobkimmel.github.io/pytorch_estimating_model_size/
-calculates size of input, parameters, and forward/backward pass intermediate variables
-prints out these estimates and total
-batch_size optional input
  • Loading branch information
rmchurch committed Aug 4, 2018
1 parent 6d9f77c commit 81df232
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions torchsummary/torchsummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def summary(model, input_size, device="cuda"):
def summary(model, input_size, batch_size=-1,device="cuda"):
def register_hook(module):
def hook(module, input, output):
class_name = str(module.__class__).split('.')[-1].split("'")[0]
Expand All @@ -14,12 +14,12 @@ def hook(module, input, output):
m_key = '%s-%i' % (class_name, module_idx+1)
summary[m_key] = OrderedDict()
summary[m_key]['input_shape'] = list(input[0].size())
summary[m_key]['input_shape'][0] = -1
summary[m_key]['input_shape'][0] = batch_size
if isinstance(output, (list,tuple)):
summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output]
else:
summary[m_key]['output_shape'] = list(output.size())
summary[m_key]['output_shape'][0] = -1
summary[m_key]['output_shape'][0] = batch_size

params = 0
if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
Expand Down Expand Up @@ -67,18 +67,31 @@ def hook(module, input, output):
print(line_new)
print('================================================================')
total_params = 0
total_output = 0
trainable_params = 0
for layer in summary:
# input_shape, output_shape, trainable, nb_params
line_new = '{:>20} {:>25} {:>15}'.format(layer, str(summary[layer]['output_shape']), '{0:,}'.format(summary[layer]['nb_params']))
total_params += summary[layer]['nb_params']
total_output += np.prod(summary[layer]['output_shape'])
if 'trainable' in summary[layer]:
if summary[layer]['trainable'] == True:
trainable_params += summary[layer]['nb_params']
print(line_new)
#assume 4 bytes/number (float on cuda).
total_input_size = abs(np.prod(input_size)*batch_size*4./(1024**2.))
total_output_size = abs(2.*total_output*4./(1024**2.)) #x2 for gradients
total_params_size = abs(total_params.numpy()*4./(1024**2.))
total_size = total_params_size + total_output_size + total_input_size

print('================================================================')
print('Total params: {0:,}'.format(total_params))
print('Trainable params: {0:,}'.format(trainable_params))
print('Non-trainable params: {0:,}'.format(total_params - trainable_params))
print('----------------------------------------------------------------')
# return summary
print('Input size (MB): %0.2f' % total_input_size)
print('Forward/backward pass size (MB): %0.2f' % total_output_size)
print('Params size (MB): %0.2f' % total_params_size)
print('Estimated Total Size (MB): %0.2f' % total_size)
print('----------------------------------------------------------------')
#return summary

0 comments on commit 81df232

Please sign in to comment.