In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import numpy as np
import torch.optim as optim


### 搭建模型

In [3]:
class LinearBNAC(nn.Module):
    """Fully connected block: a Linear layer followed by a role-dependent tail.

    The sub-modules are stored in ``self.linear`` (an ``nn.Sequential``), with the
    ``nn.Linear`` always at index 0.

    * Hidden block (``is_output=False``): Linear -> Dropout -> BatchNorm1d -> LeakyReLU.
    * Output block with a single unit (``is_output=True`` and ``out_channels == 1``):
      Linear -> Sigmoid.
    * Output block with several units (``is_output=True`` otherwise):
      Linear -> Softmax over dim 1.

    Args:
        in_channels: Number of input features of the Linear layer.
        out_channels: Number of output features of the Linear layer.
        bias: Whether the Linear layer has a bias term.
        dropout: Dropout probability; only used by hidden blocks.
        is_output: Build an output head instead of a hidden block.
    """

    def __init__(self, in_channels, out_channels, bias=True, dropout=0.3, is_output=False):
        super().__init__()
        # Every variant starts with the same affine projection.
        stages = [nn.Linear(in_channels, out_channels, bias=bias)]
        if not is_output:
            stages += [
                nn.Dropout(dropout),
                nn.BatchNorm1d(out_channels),
                nn.LeakyReLU(inplace=True),
            ]
        elif out_channels == 1:
            # Single output unit: squash to (0, 1).
            stages.append(nn.Sigmoid())
        else:
            # Several output units: normalize across the feature dimension.
            stages.append(nn.Softmax(dim=1))
        self.linear = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.linear(x)

In [14]:
class Model(nn.Module):
    """Five-stage multilayer perceptron assembled from ``LinearBNAC`` blocks.

    Feature widths: ``input_dimention`` -> 128 -> 256 -> 64 -> 32 -> ``output_classes``.
    The last stage is built with ``is_output=True``, so the network returns
    activated outputs (sigmoid for one class, softmax otherwise), not raw logits.

    Args:
        input_dimention: Number of features in each input sample.
        output_classes: Number of units in the output head.
    """

    def __init__(self, input_dimention, output_classes=1):
        super().__init__()
        # Hidden blocks.
        self.layer1 = LinearBNAC(input_dimention, 128)
        self.layer2 = LinearBNAC(128, 256)
        self.layer3 = LinearBNAC(256, 64)
        self.layer4 = LinearBNAC(64, 32)
        # Output head.
        self.output = LinearBNAC(32, output_classes, is_output=True)

    def forward(self, x):
        """Feed ``x`` through the four hidden blocks and then the output head."""
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.output)
        for stage in stages:
            x = stage(x)
        return x
        

### 準備輸入資料、優化器、標籤資料、模型輸出

In [15]:
# 10-class classifier over 256-dimensional feature vectors.
model = Model(input_dimention=256, output_classes=10)
# Adam over all model parameters, with learning rate 1e-3 and weight decay 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [16]:
batch_size = 4
input_features = 256
# Random standard-normal batch of shape (batch_size, input_features).
dummy_input = torch.randn(batch_size, input_features)

# One integer class index per sample, stored as torch.long.
target = torch.tensor([9, 5, 4, 4], dtype=torch.long)

In [17]:
# Forward pass on the dummy batch. No model.eval() call is visible before this
# point, so the module presumably runs in its default training mode (dropout
# active), which makes the printed values vary from run to run.
output = model(dummy_input)
print(output)

tensor([[0.2604, 0.0803, 0.0698, 0.0768, 0.0577, 0.0467, 0.1006, 0.1674, 0.0841,
         0.0561],
        [0.0940, 0.1246, 0.0746, 0.1233, 0.0955, 0.0661, 0.0601, 0.0963, 0.1273,
         0.1381],
        [0.1248, 0.0903, 0.1302, 0.0776, 0.1519, 0.0396, 0.0795, 0.1231, 0.0926,
         0.0904],
        [0.0752, 0.0761, 0.1104, 0.1441, 0.0783, 0.1187, 0.0733, 0.1862, 0.0806,
         0.0571]], grad_fn=<SoftmaxBackward>)


### 計算 CrossEntropy Loss
* 請注意哪一個 Loss最適合：我們已經使用 softmax
* 因為我們有使用 dropout，並隨機產生 dummy_input，所以各位學員得到的值會與解答不同，然而步驟與原理需要相同

In [18]:
from torch.nn import NLLLoss, LogSoftmax, CrossEntropyLoss

In [27]:
# The model's output head already applies softmax, and the loss is fed
# torch.log(output), i.e. log-probabilities. NLLLoss consumes log-probabilities
# directly; CrossEntropyLoss is meant for raw logits and would apply
# log_softmax a second time on top of the log.
criterion = NLLLoss()

In [31]:
# The model's output head already returns softmax probabilities, so take their
# log before handing them to the criterion.
loss = criterion(output.log(), target)

In [32]:
# Display the scalar loss tensor as the cell output.
loss

tensor(2.5074, grad_fn=<NllLossBackward>)

### 完成 back propagation 並更新權重

In [33]:
# Backpropagate: fills the .grad attribute of the model parameters
# (the next cell prints layer1's weight gradient).
loss.backward()

In [34]:
# Inspect the first Linear layer inside layer1: its weights and their gradient
# right after backward().
first_linear = model.layer1.linear[0]
print('weight : {}'.format(first_linear.weight))
print('\n')
print('grad : {}'.format(first_linear.weight.grad))

weight : Parameter containing:
tensor([[ 0.0483, -0.0293, -0.0423,  ...,  0.0426, -0.0160, -0.0508],
        [ 0.0396, -0.0108, -0.0328,  ...,  0.0317,  0.0289,  0.0034],
        [ 0.0489, -0.0236,  0.0154,  ...,  0.0439,  0.0460, -0.0065],
        ...,
        [ 0.0448,  0.0290,  0.0222,  ..., -0.0246,  0.0211,  0.0377],
        [ 0.0096, -0.0469,  0.0582,  ..., -0.0443, -0.0579, -0.0447],
        [ 0.0368,  0.0218, -0.0507,  ..., -0.0300,  0.0015,  0.0166]],
       requires_grad=True)


grad : tensor([[ 1.5044e-01, -2.6275e-01,  3.9103e-02,  ..., -4.6727e-02,
         -5.8239e-02, -1.3566e-01],
        [-1.5740e-07, -1.3140e-07, -7.4710e-08,  ..., -1.9736e-08,
          2.8783e-09,  6.4984e-08],
        [-1.3266e-02,  4.4235e-03, -9.8266e-03,  ...,  1.3352e-04,
         -7.0837e-03,  1.3571e-02],
        ...,
        [-3.9177e-02,  1.8807e-02,  3.6561e-02,  ..., -1.6867e-02,
         -4.9197e-03, -2.3362e-02],
        [ 1.4734e-03,  2.0869e-03, -1.0921e-03,  ...,  2.5712e-03,
       

In [35]:
# Apply one Adam update using the gradients stored by backward(); the weights
# change while the stored gradients stay as they are.
optimizer.step()

In [36]:
# Inspect the same layer after optimizer.step(): the weights have moved while the
# gradient is still the one computed by backward().
first_linear = model.layer1.linear[0]
print('weight : {}'.format(first_linear.weight))
print('\n')
print('grad : {}'.format(first_linear.weight.grad))

weight : Parameter containing:
tensor([[ 0.0473, -0.0283, -0.0433,  ...,  0.0436, -0.0150, -0.0498],
        [ 0.0386, -0.0098, -0.0318,  ...,  0.0307,  0.0279,  0.0024],
        [ 0.0499, -0.0246,  0.0164,  ...,  0.0429,  0.0470, -0.0075],
        ...,
        [ 0.0458,  0.0280,  0.0212,  ..., -0.0236,  0.0221,  0.0387],
        [ 0.0086, -0.0479,  0.0592,  ..., -0.0453, -0.0589, -0.0457],
        [ 0.0378,  0.0228, -0.0497,  ..., -0.0290,  0.0005,  0.0156]],
       requires_grad=True)


grad : tensor([[ 1.5044e-01, -2.6275e-01,  3.9103e-02,  ..., -4.6727e-02,
         -5.8239e-02, -1.3566e-01],
        [-1.5740e-07, -1.3140e-07, -7.4710e-08,  ..., -1.9736e-08,
          2.8783e-09,  6.4984e-08],
        [-1.3266e-02,  4.4235e-03, -9.8266e-03,  ...,  1.3352e-04,
         -7.0837e-03,  1.3571e-02],
        ...,
        [-3.9177e-02,  1.8807e-02,  3.6561e-02,  ..., -1.6867e-02,
         -4.9197e-03, -2.3362e-02],
        [ 1.4734e-03,  2.0869e-03, -1.0921e-03,  ...,  2.5712e-03,
       

### 清空 gradient

In [37]:
# Reset the accumulated gradients of every optimized parameter so the next
# backward() starts from a clean slate (gradients accumulate otherwise).
optimizer.zero_grad()

In [38]:
# Inspect the same layer after zero_grad(): the weights are untouched and the
# gradient has been cleared.
first_linear = model.layer1.linear[0]
print('weight : {}'.format(first_linear.weight))
print('\n')
print('grad : {}'.format(first_linear.weight.grad))

weight : Parameter containing:
tensor([[ 0.0473, -0.0283, -0.0433,  ...,  0.0436, -0.0150, -0.0498],
        [ 0.0386, -0.0098, -0.0318,  ...,  0.0307,  0.0279,  0.0024],
        [ 0.0499, -0.0246,  0.0164,  ...,  0.0429,  0.0470, -0.0075],
        ...,
        [ 0.0458,  0.0280,  0.0212,  ..., -0.0236,  0.0221,  0.0387],
        [ 0.0086, -0.0479,  0.0592,  ..., -0.0453, -0.0589, -0.0457],
        [ 0.0378,  0.0228, -0.0497,  ..., -0.0290,  0.0005,  0.0156]],
       requires_grad=True)


grad : tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
