In [1]:
import numpy as np
import torch
import torch.nn as nn

In [2]:
# build two models from (in_features, out_features) size pairs

# "wide": one hidden layer with 4 units feeding a 3-unit output layer
widenet = nn.Sequential(
    *[nn.Linear(n_in, n_out) for n_in, n_out in [(2, 4), (4, 3)]]
    )

# "deep": two hidden layers with 2 units each feeding a 3-unit output layer
deepnet = nn.Sequential(
    *[nn.Linear(n_in, n_out) for n_in, n_out in [(2, 2), (2, 2), (2, 3)]]
    )

# print them out to have a look (a one-space line separates the two reprs)
print(widenet, ' ', deepnet, sep='\n')

Sequential(
  (0): Linear(in_features=2, out_features=4, bias=True)
  (1): Linear(in_features=4, out_features=3, bias=True)
)
 
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
  (2): Linear(in_features=2, out_features=3, bias=True)
)


In [4]:
# show every (name, tensor) pair of the deep model, one blank line after each
for entry in deepnet.named_parameters():
    print(entry)
    print()

('0.weight', Parameter containing:
tensor([[-0.6642,  0.4010],
        [ 0.3723,  0.4078]], requires_grad=True))

('0.bias', Parameter containing:
tensor([-0.6598, -0.5642], requires_grad=True))

('1.weight', Parameter containing:
tensor([[-0.4614, -0.3372],
        [ 0.1609, -0.4838]], requires_grad=True))

('1.bias', Parameter containing:
tensor([ 0.6884, -0.4641], requires_grad=True))

('2.weight', Parameter containing:
tensor([[ 0.0590, -0.0448],
        [ 0.0494, -0.4010],
        [ 0.4525, -0.2821]], requires_grad=True))

('2.bias', Parameter containing:
tensor([-0.4849, -0.3421, -0.5981], requires_grad=True))



In [5]:
# count the number of nodes ( = the number of biases)

# named_parameters() yields (name, tensor) tuples; keep only the entries whose
# name contains 'bias' and add up the lengths of those vectors
numNodesInWide = sum(
    len(values) for label, values in widenet.named_parameters() if 'bias' in label
    )

numNodesInDeep = sum(
    len(values) for label, values in deepnet.named_parameters() if 'bias' in label
    )


print('There are %s nodes in the wide network.' %numNodesInWide)
print('There are %s nodes in the deep network.' %numNodesInDeep)

There are 7 nodes in the wide network.
There are 7 nodes in the deep network.


In [6]:
# just the parameters (tensors only, no names), each followed by a blank line
print(*widenet.parameters(), sep='\n\n', end='\n\n')

Parameter containing:
tensor([[ 0.1152, -0.2921],
        [-0.0285,  0.2294],
        [-0.4228,  0.3815],
        [ 0.5880, -0.6070]], requires_grad=True)

Parameter containing:
tensor([ 0.4153,  0.2286, -0.2910,  0.4066], requires_grad=True)

Parameter containing:
tensor([[ 1.4704e-04, -4.9102e-01,  4.4487e-01,  2.7680e-01],
        [-4.3859e-01, -3.6658e-01, -2.2687e-01, -4.2947e-02],
        [-2.7399e-01, -4.7468e-01,  2.5115e-01,  4.8581e-01]],
       requires_grad=True)

Parameter containing:
tensor([-0.4942, -0.4227,  0.0617], requires_grad=True)



In [7]:
# now count the total number of trainable parameters

# element count of every tensor that takes part in training
trainable_sizes = [w.numel() for w in widenet.parameters() if w.requires_grad]

# report each tensor's size, then the grand total
for size in trainable_sizes:
    print('This piece has %s parameters' %size)

nparams = sum(trainable_sizes)

print('\n\nTotal of %s parameters'%nparams)

This piece has 8 parameters
This piece has 4 parameters
This piece has 12 parameters
This piece has 3 parameters


Total of 27 parameters


In [8]:
# btw, can also do the count in one line with a generator expression

# Built-in sum() adds the per-tensor element counts directly: the total stays
# a plain Python int, and a model without trainable parameters gives 0
# rather than the float 0.0 that np.sum returns for an empty list.
nparams = sum( p.numel() for p in widenet.parameters() if p.requires_grad )
print('Widenet has %s parameters'%nparams)

nparams = sum( p.numel() for p in deepnet.parameters() if p.requires_grad )
print('Deepnet has %s parameters'%nparams)

Widenet has 27 parameters
Deepnet has 21 parameters


In [9]:
# A nice simple way to print out the model info.
from torchsummary import summary

# (1,2) is the shape of a single input sample; summary() adds the batch
# dimension itself (shown as -1 in the "Output Shape" column).
# widenet is never moved off the CPU in this notebook, so ask for a CPU dummy
# input explicitly: summary() defaults to device="cuda" and would feed a CUDA
# tensor to the CPU model on machines where CUDA is available.
summary(widenet,(1,2),device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1, 4]              12
            Linear-2                 [-1, 1, 3]              15
Total params: 27
Trainable params: 27
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
