# 勾配クリッピングの実装例

In [1]:
import numpy as np
from common.layers import TwoLayerNet

In [2]:
def clip_grads(grads, max_norm):
    """Rescale gradients so that their global L2 norm does not exceed ``max_norm``.

    The norm is taken over all entries of ``grads`` together (square root of
    the sum of every squared element), not per entry. When the norm is above
    the threshold, every entry is multiplied by the same factor, so the
    direction of the overall gradient is preserved.

    Parameters
    ----------
    grads : dict
        Maps names to gradient arrays. Each entry is updated with ``*=``;
        for NumPy arrays this modifies the array objects themselves, so any
        other reference to the same array sees the clipped values.
    max_norm : float
        Threshold for the global L2 norm.

    Returns
    -------
    None
        The function works purely through its effect on ``grads``.
    """
    # Global L2 norm across every gradient array.
    total_norm = np.sqrt(sum(np.sum(grad ** 2) for grad in grads.values()))

    # The small epsilon keeps the division finite when the norm is zero.
    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:
        # The norm exceeds the threshold: shrink every gradient by `rate`.
        for key in grads:
            grads[key] *= rate
            

### 実行例

In [3]:
# Fixed seed so the random demo gradients are reproducible.
np.random.seed(1324)
g1 = np.random.rand(2,5) * 10
g2 = np.random.rand(3,4) * 5
grads = {"W1": g1, "W2": g2}

# Show the gradients (rounded) and their largest element before clipping.
print("クリッピング前")
print("grads=")
print(*(grad.round(2) for grad in grads.values()), sep="\n")
print("max of grads=", max(grad.max() for grad in grads.values()))
print()

# Clip so that the global L2 norm is at most 5.
clip_grads(grads, max_norm=5)

# Same report after clipping.
print("クリッピング後")
print("grads=")
print(*(grad.round(2) for grad in grads.values()), sep="\n")
print("max of grads=", max(grad.max() for grad in grads.values()))
print()

クリッピング前
grads=
[[9.91 7.49 4.71 3.08 9.2 ]
 [6.65 0.75 9.93 0.61 1.48]]
[[0.43 2.04 3.38 0.49]
 [3.65 0.34 0.81 4.78]
 [2.7  0.39 2.7  0.71]]
max of grads= 9.932045233468367

クリッピング後
grads=
[[2.25 1.7  1.07 0.7  2.09]
 [1.51 0.17 2.26 0.14 0.34]]
[[0.1  0.46 0.77 0.11]
 [0.83 0.08 0.18 1.09]
 [0.61 0.09 0.61 0.16]]
max of grads= 2.2551488241293045



### TwoLayerNetを用いた実行例

In [4]:
# Build the project's two-layer network with the sizes used in this demo.
tnet = TwoLayerNet(input_size=5, hidden_size=4, output_size=3)

# A single input row and its target row (presumably a one-hot label; confirm
# against TwoLayerNet.gradient).
x = np.array([[1, 2, 3, 4, 5]])
t = np.array([[0, 0, 1]])

# Compute the gradients and print each entry before clipping.
grads = tnet.gradient(x, t)
print("クリッピング前")
for key, value in grads.items():
    print(f"{key} =")
    print(value, end="\n\n")

# Clip so that the global L2 norm is at most 0.5.
clip_grads(grads, max_norm=0.5)

# Print each entry again after clipping.
print("クリッピング後")
for key, value in grads.items():
    print(f"{key} =")
    print(value, end="\n\n")

クリッピング前
W1 =
[[-0.00414807  0.         -0.00257031  0.        ]
 [-0.00829613  0.         -0.00514062  0.        ]
 [-0.0124442   0.         -0.00771094  0.        ]
 [-0.01659226  0.         -0.01028125  0.        ]
 [-0.02074033  0.         -0.01285156  0.        ]]

b1 =
[-0.00414807  0.         -0.00257031  0.        ]

W2 =
[[ 0.01641493  0.01636278 -0.03277771]
 [ 0.          0.         -0.        ]
 [ 0.03962442  0.03949854 -0.07912296]
 [ 0.          0.         -0.        ]]

b2 =
[ 0.3337785   0.33271811 -0.66649661]

クリッピング後
W1 =
[[-0.00251761  0.         -0.00156001  0.        ]
 [-0.00503522  0.         -0.00312003  0.        ]
 [-0.00755282  0.         -0.00468004  0.        ]
 [-0.01007043  0.         -0.00624005  0.        ]
 [-0.01258804  0.         -0.00780007  0.        ]]

b1 =
[-0.00251761  0.         -0.00156001  0.        ]

W2 =
[[ 0.0099628   0.00993115 -0.01989395]
 [ 0.          0.         -0.        ]
 [ 0.02404946  0.02397306 -0.04802252]
 [ 0.          0.  