## 最適化手法の比較
同じ初期重み・同じ学習率のもとで、SGDとAdamの収束の様子を比較する。

### モデルの定義

In [66]:
import random

import torch
import torch.utils.data
from torch import nn
from torch import optim

class MyModel(nn.Module):
    """Two-parameter toy model used to compare optimizers.

    The model has no inputs; its loss is the fixed quadratic
    ``(1/20) * w_1 ** 2 + w_2 ** 2``, whose minimum is at
    ``(w_1, w_2) = (0, 0)``.
    """

    def __init__(self, init_value1, init_value2):
        """Create the two trainable scalars.

        Args:
            init_value1: Starting value for ``w_1``.
            init_value2: Starting value for ``w_2``.
        """
        super().__init__()
        self.w_1 = self._make_param(init_value1)
        self.w_2 = self._make_param(init_value2)

    @staticmethod
    def _make_param(value):
        """Wrap ``value`` in a one-element trainable parameter."""
        tensor = torch.tensor([value])
        # Integer/bool inputs would produce a tensor that cannot require
        # gradients, so promote them to the default float dtype. Float (and
        # complex) inputs keep the dtype torch inferred for them.
        if not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.to(torch.get_default_dtype())
        return nn.Parameter(tensor)

    def forward_loss(self):
        """Return the loss as a one-element tensor.

        The optimum is at (w_1, w_2) = (0, 0).
        """
        loss = (1/20) * self.w_1 ** 2 + self.w_2 ** 2
        return loss


In [67]:
# Draw one shared pair of starting weights so that SGD and Adam begin from
# exactly the same point and the comparison is fair.
init_val1, init_val2 = random.random(), random.random()

# One model per optimizer, both initialised with (init_val1, init_val2).
model_sgd = MyModel(init_value1=init_val1, init_value2=init_val2)  # trained with SGD
model_adam = MyModel(init_value1=init_val1, init_value2=init_val2)  # trained with Adam

### SGDとAdamのオプティマイザの定義

In [68]:
# Build one optimizer per model; both use the same learning rate (lr) of 0.01.
op_sgd = optim.SGD(
    model_sgd.parameters(),
    lr=0.01,
)
op_adam = optim.Adam(
    model_adam.parameters(),
    lr=0.01,
)

### 学習

In [69]:
# Run 30 update steps on both models in lockstep, printing the loss and the
# current weights of each model before every update.
for epoch in range(30):
    loss_sgd = model_sgd.forward_loss()
    loss_adam = model_adam.forward_loss()

    # .item() extracts the Python scalar from a one-element tensor without
    # going through `.data` or converting a grad-requiring tensor via float().
    print("epoch: {}, loss_sgd:{}, loss_adam:{}".format(epoch, loss_sgd.item(), loss_adam.item()))
    print("sgd w1:{}, w2:{}".format(model_sgd.w_1.item(), model_sgd.w_2.item()))
    print("adam w1:{}, w2:{}".format(model_adam.w_1.item(), model_adam.w_2.item()))
    print("-"*10)

    # Backpropagation (populates the gradients of each model's parameters)
    loss_sgd.backward()
    loss_adam.backward()

    # Parameter update
    op_sgd.step()
    op_adam.step()

    # Clear the gradients so they do not accumulate into the next step
    model_sgd.zero_grad()
    model_adam.zero_grad()

epoch: 0, loss_sgd:0.0003822590515483171, loss_adam:0.0003822590515483171
sgd w1:0.08444901555776596, w2:0.005067271180450916
adam w1:0.08444901555776596, w2:0.005067271180450916
----------
epoch: 1, loss_sgd:0.00038052938180044293, loss_adam:0.00030146457720547915
sgd w1:0.08436456322669983, w2:0.0049659255892038345
adam w1:0.07444902509450912, w2:-0.004932718351483345
----------
epoch: 2, loss_sgd:0.0003788414760492742, loss_adam:0.00022864290804136544
sgd w1:0.08428020030260086, w2:0.004866607021540403
adam w1:0.06450152397155762, w2:-0.004540988709777594
----------
epoch: 3, loss_sgd:0.00037719361716881394, loss_adam:0.0001500424405094236
sgd w1:0.08419591933488846, w2:0.004769274964928627
adam w1:0.05465327575802803, w2:-0.0008327120449393988
----------
epoch: 4, loss_sgd:0.0003755843499675393, loss_adam:0.0001087170749087818
sgd w1:0.08411172032356262, w2:0.004673889372497797
adam w1:0.044962070882320404, w2:0.0027636343147605658
----------
epoch: 5, loss_sgd:0.000374012161046266