In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Define the network

In [34]:
class Net(nn.Module):
    """LeNet-style CNN mapping single-channel 32x32 images to 10 output scores.

    Spatial size per sample: 32 -> conv1 -> 30 -> pool -> 15 -> conv2 -> 13 -> pool -> 6,
    so the flattened feature vector fed to ``fc1`` has 16 * 6 * 6 = 576 entries.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Convolutional feature extractor: 3x3 kernels, stride 1 (the Conv2d default).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)

        # Fully connected head: 576 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the raw 10-way scores."""
        # Each conv stage is: convolution -> ReLU -> 2x2 max pooling
        # (max_pool2d defaults: stride = kernel size, no padding).
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), (2, 2))

        # Flatten every sample to a vector before the linear layers.
        flat = x.view(-1, self.num_flat_features(x))
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def num_flat_features(self, x):
        """Return the number of elements per sample, i.e. the product of all
        dimensions of ``x`` except the leading (batch) one."""
        feature_count = 1
        for extent in x.shape[1:]:
            feature_count *= extent
        return feature_count


In [35]:

# Instantiate the network and print its layer summary.
net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [43]:
# Collect every learnable tensor of the network (a weight and a bias per layer).
params = [p for p in net.parameters()]
print(params[0].shape)  # shape of the first parameter, conv1's weight
print(len(params))      # number of parameter tensors


torch.Size([6, 1, 3, 3])
10


torch.Size([6, 1, 3, 3])
10


### Set the input and run backpropagation

In [44]:
# Random dummy image laid out as (samples, channels, height, width).
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

tensor([[-0.1205,  0.0523, -0.0669, -0.0886,  0.0369,  0.0070, -0.0821,  0.0056,
          0.0342, -0.0915]], grad_fn=<AddmmBackward>)


In [45]:
# Reset the gradients of all parameters before backpropagating.
net.zero_grad()
# `out` is not a scalar, so backward() is given an explicit upstream gradient
# of shape (1, 10); a random one is used here.
out.backward(gradient=torch.randn(1, 10))

In [46]:
# Gradient accumulated on the first parameter (conv1's weight).
# Read `.grad` directly: the legacy `.data` accessor bypasses autograd's
# tracking and is unnecessary for simply inspecting the values.
params[0].grad


tensor([[[[-0.0533,  0.0213,  0.0073],
          [ 0.0366, -0.0253,  0.0403],
          [ 0.0161, -0.0613, -0.0129]]],


        [[[-0.0546, -0.0175, -0.0264],
          [ 0.0176, -0.0646,  0.0156],
          [-0.0412, -0.0139, -0.0150]]],


        [[[ 0.0537,  0.0302, -0.0493],
          [-0.0677,  0.0264,  0.0044],
          [ 0.0283,  0.0382, -0.0424]]],


        [[[-0.0235,  0.0437, -0.0437],
          [-0.0277, -0.0136, -0.0046],
          [ 0.0255, -0.0172, -0.0194]]],


        [[[ 0.0453,  0.0002,  0.0350],
          [ 0.0022, -0.0007, -0.0425],
          [-0.0112, -0.0146, -0.0246]]],


        [[[ 0.0369,  0.0070, -0.0192],
          [ 0.0380,  0.0396, -0.0202],
          [-0.0049, -0.0072, -0.0105]]]])

###  Loss function  

In [70]:
output = net(input)
# Random target values, reshaped to a single row.
target = torch.randn(10).view(1, -1)

In [71]:
# Mean-squared-error criterion used to compare the network output with the target.
criterion = nn.MSELoss()

In [72]:
# Loss between the network output and the target.
loss = criterion(output, target)
print(loss)

tensor(0.8953, grad_fn=<MseLossBackward>)


In [73]:
# Walk one step back through the autograd graph that produced `loss`.
print(loss.grad_fn)  # node that produced the loss
print(loss.grad_fn.next_functions[0][0])  # first node feeding into it
# Computation graph, from the input to the loss:
# input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
#       -> view -> linear -> relu -> linear -> relu -> linear
#       -> MSELoss
#       -> loss

<MseLossBackward object at 0x7f9105f60b90>
<AddmmBackward object at 0x7f9105f60f90>


In [74]:
net.zero_grad()  # reset the gradients of all parameters
print(net.conv1.bias.grad)  # conv1 bias gradient before the backward pass
loss.backward()  # backpropagate the loss, filling in each parameter's .grad
print(net.conv1.bias.grad)  # conv1 bias gradient after the backward pass

tensor([0., 0., 0., 0., 0., 0.])
tensor([-0.0005, -0.0013,  0.0095,  0.0131,  0.0216, -0.0140])


### Update the weights

In [84]:
# Simple way: one manual SGD step, param <- param - lr * grad.
lr = 0.01
# Update in place under no_grad() so autograd does not record the update;
# this replaces the legacy `.data` access.
with torch.no_grad():
    for i, f in enumerate(net.parameters()):
        print(i)
        if f.grad is None:
            # No gradient was computed for this parameter; nothing to apply.
            continue
        f.sub_(f.grad * lr)  # f = f - f.grad * lr

0
1
2
3
4
5
6
7
8
9


In [86]:
# Use an optimizer: the same SGD update, delegated to torch.optim.
# (`optim` is imported with the other imports at the top of the notebook.)
optimizer = optim.SGD(net.parameters(), lr=lr)

# One training iteration:
optimizer.zero_grad()             # clear gradients left over from the previous step
output = net(input)               # forward pass
loss = criterion(output, target)  # compute the loss
loss.backward()                   # backpropagate to fill in the gradients
optimizer.step()                  # apply the parameter update