## Weight decay (L2正則化) ありなしの比較

### モデルの定義

In [2]:
import random

import torch
import torch.utils.data
from torch import nn
from torch import optim

class MyModel(nn.Module):
    """Two-parameter toy model for comparing SGD with and without weight decay.

    The model takes no input: its loss is a fixed quadratic in the two scalar
    weights, so only the optimizer settings influence the trajectory.

    Args:
        init_value1: Initial value of ``w_1`` (any real number).
        init_value2: Initial value of ``w_2`` (any real number).
    """

    def __init__(self, init_value1, init_value2):
        super().__init__()
        # Coerce to float first: torch.tensor([<int>]) would build an integer
        # tensor, and nn.Parameter rejects integer tensors because they cannot
        # require gradients.
        self.w_1 = nn.Parameter(torch.tensor([float(init_value1)]))
        self.w_2 = nn.Parameter(torch.tensor([float(init_value2)]))

    def forward_loss(self):
        """Return the loss ``(1/20) * w_1**2 + w_2**2`` as a one-element tensor."""
        return (1/20) * self.w_1 ** 2 + self.w_2 ** 2

In [5]:
# Initial weights, shared by both models so that the runs with and without
# weight decay start from exactly the same point. A dedicated, seeded generator
# makes the draw reproducible across kernel restarts and cell re-runs without
# touching the global `random` state.
SEED = 0
rng = random.Random(SEED)
init_val1 = rng.random()
init_val2 = rng.random()

# Model instances: both start from (init_val1, init_val2) for a fair comparison.
model = MyModel(init_val1, init_val2)
model_wd = MyModel(init_val1, init_val2)

### 最適化手法の定義
weight decayなしのSGD (`op`) と、weight decayありのSGD (`op_wd`) をそれぞれ定義する。

In [7]:
# Baseline: plain SGD on `model` (lr is the learning rate).
op = optim.SGD(params=model.parameters(), lr=0.01)
# Comparison: same learning rate on `model_wd`, with weight_decay=0.01 added.
op_wd = optim.SGD(params=model_wd.parameters(), weight_decay=0.01, lr=0.01)

### 学習

In [23]:
for epoch in range(30):
    # Forward pass: evaluate the toy loss for both models.
    loss = model.forward_loss()
    loss_wd = model_wd.forward_loss()

    # Log the state *before* this epoch's update. `.item()` returns the Python
    # scalar of a one-element tensor, replacing the legacy `.data` access and
    # the implicit float() conversion of grad-requiring parameters.
    print("epoch: {}, loss:{}, loss_wd:{}".format(epoch, loss.item(), loss_wd.item()))
    print("sgd w1:{}, w2:{}".format(model.w_1.item(), model.w_2.item()))
    print("sgd_wd w1:{}, w2:{}".format(model_wd.w_1.item(), model_wd.w_2.item()))
    print("-"*10)

    # Backpropagation (computes the gradients of each loss)
    loss.backward()
    loss_wd.backward()

    # Parameter update
    op.step()
    op_wd.step()

    # Clear the gradients before the next epoch
    model.zero_grad()
    model_wd.zero_grad()


epoch: 0, loss:0.010829604230821133, loss_wd:0.009758861735463142
sgd w1:0.465394526720047, w2:2.5314715458080173e-05
sgd_wd w1:0.4417886435985565, w2:2.4006429157452658e-05
----------
epoch: 1, loss:0.01080795656889677, loss_wd:0.009737404063344002
sgd w1:0.46492913365364075, w2:2.4808421585476026e-05
sgd_wd w1:0.44130268692970276, w2:2.3523900381405838e-05
----------
epoch: 2, loss:0.01078635174781084, loss_wd:0.009715993888676167
sgd w1:0.4644642174243927, w2:2.431225402688142e-05
sgd_wd w1:0.440817266702652, w2:2.3051070456858724e-05
----------
epoch: 3, loss:0.010764788836240768, loss_wd:0.009694630280137062
sgd w1:0.46399974822998047, w2:2.3826009055483155e-05
sgd_wd w1:0.4403323531150818, w2:2.2587744751945138e-05
----------
epoch: 4, loss:0.010743271559476852, loss_wd:0.009673313237726688
sgd w1:0.46353575587272644, w2:2.3349488401436247e-05
sgd_wd w1:0.4398479759693146, w2:2.2133730453788303e-05
----------
epoch: 5, loss:0.01072179526090622, loss_wd:0.00965204369276762
sgd w1: