In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F

##### Torch Sample Net Class
Pytorch 에서는 Neural Network Module 클래스를 제공해준다. Module은 파라미터 분석, GPU, Exporting, Loading 등의 편리한 기능을 제공한다.

특히 forward 부분만 제공해도 back propagation 부분을 자동으로 제공해주는데, Autograd 기능을 기본적으로 사용하기 때문에 사용하기 쉽다.

In [125]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        
        self.fc1 = nn.Linear(16 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x)) # flatten
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        
        return num_features

net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [126]:
params = list(net.parameters())

print(len(params))
print(params[0])

10
Parameter containing:
tensor([[[[ 0.1517,  0.2703, -0.0641],
          [-0.0864, -0.0806, -0.1732],
          [ 0.0268,  0.0227,  0.3258]]],


        [[[-0.1264, -0.2172, -0.1967],
          [ 0.2041,  0.2621,  0.1639],
          [ 0.1060, -0.2884,  0.1346]]],


        [[[ 0.0691,  0.0900,  0.1056],
          [-0.0724, -0.1461,  0.2538],
          [-0.1159, -0.2514, -0.2045]]],


        [[[-0.0328,  0.3147,  0.1305],
          [ 0.0503, -0.1428,  0.2043],
          [-0.1962,  0.0285,  0.2514]]],


        [[[-0.0648,  0.3001, -0.0829],
          [-0.0422,  0.2867, -0.2773],
          [-0.2023, -0.1014, -0.2773]]],


        [[[-0.1163,  0.2336, -0.0223],
          [-0.1994,  0.1215,  0.0291],
          [ 0.3244, -0.0400, -0.2256]]]], requires_grad=True)


In [127]:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

tensor([[-0.0932,  0.0126, -0.0428,  0.1333, -0.0101,  0.0387, -0.0174,  0.0377,
         -0.1365,  0.0199]], grad_fn=<AddmmBackward>)


In [128]:
# backward opartion.
net.zero_grad()
out.backward(torch.randn(1, 10))

In [155]:
# Loss 
output = net(input)
target = torch.randn(10)
target = target.view(1, -1)
criterion = nn.MSELoss()

loss = criterion(output, target)

print(loss)
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

tensor(1.4171, grad_fn=<MseLossBackward>)
<MseLossBackward object at 0x11e1fd810>
<AddmmBackward object at 0x11dd5aed0>
<AccumulateGrad object at 0x11e1fd810>


In [156]:
# Backpropagation
net.zero_grad()
print('conv1.bias.grad before backward')
print(net.conv1.weight.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.weight.grad)

conv1.bias.grad before backward
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
conv1.bias.grad after backward
tensor([[[[-0.0072,  0.0018, -0.0066],
          [ 0.0045,  0.0120, -0.0136],
          [-0.0083,  0.0027,  0.0231]]],


        [[[ 0.0267,  0.0181, -0.0174],
          [-0.0102, -0.0016, -0.0152],
          [-0.0170,  0.0095, -0.0092]]],


        [[[ 0.0130,  0.0023, -0.0120],
          [-0.0120,  0.0069,  0.0037],
          [-0.0149, -0.0054, -0.0069]]],


        [[[ 0.0009,  0.0162, -0.0086],
          [ 0.0080,  0.0076,  0.0048],
          [-0.0159,  0.0240,  0.0225]]],
