In [15]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)
# print (net.forward([1,1]))

Net (
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear (400 -> 120)
  (fc2): Linear (120 -> 84)
  (fc3): Linear (84 -> 10)
)


In [17]:
params = list(net.parameters())
print(len(params))
print(params[0].size()) 
print(params[1].size()) 
print(params[2].size()) 
print(params[3].size()) 
print(params[4].size()) 
print(params[5].size()) 

10
torch.Size([6, 1, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 400])
torch.Size([120])


In [14]:
list(net.parameters())

[Parameter containing:
 (0 ,0 ,.,.) = 
  -0.0636 -0.1017  0.0077 -0.1395 -0.0499
  -0.1712 -0.0065 -0.0785 -0.1342  0.1551
  -0.0449  0.1728  0.0711 -0.1573 -0.0655
  -0.0633 -0.0724 -0.1930  0.0825  0.1792
   0.1334 -0.0040 -0.1284  0.1560 -0.0304
 
 (1 ,0 ,.,.) = 
  -0.0988 -0.0327  0.0765 -0.1421 -0.1845
   0.1118  0.0993  0.0492  0.1954  0.0484
  -0.1850  0.0821 -0.1772 -0.0715  0.0456
   0.0831 -0.1694  0.1579  0.0285 -0.0280
   0.1337  0.1635  0.0806  0.0217  0.1132
 
 (2 ,0 ,.,.) = 
   0.0962 -0.0972  0.1386  0.1755  0.0119
   0.1209  0.0830 -0.0228 -0.1225  0.0494
  -0.0124  0.1878 -0.1487  0.1325 -0.0162
  -0.0508 -0.0511 -0.0315  0.1966 -0.0605
  -0.0025 -0.1639  0.1336  0.1297 -0.0662
 
 (3 ,0 ,.,.) = 
   0.1041 -0.0179 -0.0906  0.0124 -0.0355
   0.0596 -0.0457  0.0102 -0.1324 -0.0785
  -0.1454 -0.0310 -0.1913  0.1713 -0.0059
   0.0744 -0.0964 -0.1519  0.1004 -0.1245
   0.1229  0.1624  0.1045 -0.1566  0.0460
 
 (4 ,0 ,.,.) = 
  -0.0520  0.1165  0.1408  0.1597 -0.1905
   0.04