# Flux/Zygote Demo (2)

Gradient descent using Zygote's automatic differentiation

Copyright (c) 2022 Tadashi Wadayama  
Released under the MIT license  
https://opensource.org/licenses/mit-license.php

In [51]:
using LinearAlgebra
using Plots
gr()
using Random
Random.seed!(1)
using Flux

In [52]:
f(x) = 4*x[1]^2 + x[2]^2

f (generic function with 1 method)

In [53]:
x0 = randn(2)

2-element Vector{Float64}:
 0.2972879845354616
 0.3823959677906078

### Gradient descent implementation (1)

The gradient vector of f is written out explicitly by hand, and 10 steps with step size 0.1 are taken.

In [54]:
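x = x0                              # x = x - ... allocates a new array, so x0 itself is not modified
for i in 1:10
    x = x - 0.1*[8*x[1], 2*x[2]]    # step size 0.1; [8x₁, 2x₂] is the hand-derived gradient of f
    println(x)
end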

[0.05945759690709232, 0.30591677423248625]
[0.011891519381418462, 0.244733419385989]
[0.0023783038762836915, 0.1957867355087912]
[0.0004756607752567383, 0.15662938840703294]
[9.513215505134763e-5, 0.12530351072562634]
[1.9026431010269525e-5, 0.10024280858050108]
[3.8052862020539047e-6, 0.08019424686440087]
[7.610572404107806e-7, 0.06415539749152069]
[1.5221144808215608e-7, 0.05132431799321655]
[3.044228961643121e-8, 0.041059454394573244]
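Because f is a simple quadratic, the update above can be written in closed form, which explains the trajectory that was printed. Differentiating $f(x) = 4x_1^2 + x_2^2$ and applying one step with step size $\eta = 0.1$ gives

$$
\nabla f(x) = \begin{pmatrix} 8x_1 \\ 2x_2 \end{pmatrix}, \qquad
x^{(k+1)} = x^{(k)} - \eta\,\nabla f\bigl(x^{(k)}\bigr)
= \begin{pmatrix} (1 - 8\eta)\,x_1^{(k)} \\ (1 - 2\eta)\,x_2^{(k)} \end{pmatrix}
= \begin{pmatrix} 0.2\,x_1^{(k)} \\ 0.8\,x_2^{(k)} \end{pmatrix},
$$

so $x^{(k)} = \bigl(0.2^k\,x_1^{(0)},\ 0.8^k\,x_2^{(0)}\bigr)$. For $k = 10$, $0.2^{10} \approx 1.02 \times 10^{-7}$ and $0.8^{10} \approx 0.107$, which reproduces the last printed point $\approx (3.04 \times 10^{-8},\ 0.0411)$. The second coordinate converges much more slowly because the curvature of f along $x_2$ (2) is smaller than along $x_1$ (8).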


### Gradient descent implementation (2)

The gradient is computed with automatic differentiation (`gradient`) from Zygote.jl instead of being written out by hand.

In [55]:
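x = x0
for i in 1:10
    ps = params(x)              # register the current x as an (implicit) parameter
    gs = gradient(ps) do        # differentiate f(x) with respect to the parameters in ps
        f(x)
    end
    x = x - 0.1*gs[x]           # gs[x] is the gradient of f at x; step size 0.1
    println(x)
end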

[0.05945759690709232, 0.30591677423248625]
[0.011891519381418462, 0.244733419385989]
[0.0023783038762836915, 0.1957867355087912]
[0.0004756607752567383, 0.15662938840703294]
[9.513215505134763e-5, 0.12530351072562634]
[1.9026431010269525e-5, 0.10024280858050108]
[3.8052862020539047e-6, 0.08019424686440087]
[7.610572404107806e-7, 0.06415539749152069]
[1.5221144808215608e-7, 0.05132431799321655]
[3.044228961643121e-8, 0.041059454394573244]
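The same computation can also be written with the explicit form `gradient(f, x)`, which differentiates `f` directly with respect to its argument and returns a tuple with one gradient per argument, so no `params`/`gs[x]` bookkeeping is needed (recent Flux documentation favors this explicit style). The sketch below assumes Flux is installed (it re-exports Zygote's `gradient`); the helpers `grad_f` and `descend` are hypothetical names introduced only for this illustration, and the starting point is the `x0` printed above.

```julia
using Flux  # re-exports Zygote's `gradient`

f(x) = 4*x[1]^2 + x[2]^2

# Hypothetical helper: hand-derived gradient, used only as a reference
grad_f(x) = [8*x[1], 2*x[2]]

x0 = [0.2972879845354616, 0.3823959677906078]  # starting point printed above

# Explicit-mode AD: gradient(f, x) returns a tuple, one entry per argument of f
println(gradient(f, x0)[1])
println(gradient(f, x0)[1] ≈ grad_f(x0))       # AD agrees with the hand-derived gradient

# Hypothetical helper: plain gradient descent driven by explicit-mode AD
function descend(f, x0; η = 0.1, iters = 10)
    x = copy(x0)
    for _ in 1:iters
        x = x - η * gradient(f, x)[1]          # same update as the cell above
    end
    return x
end

println(descend(f, x0))                        # final point after 10 steps
```

Wrapping the loop in a function also avoids Julia's global-scope rules for loops, which matter when the code is run as a script rather than in a notebook.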


### Gradient descent implementation (3)

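In addition to automatic differentiation, Flux's optimizer (`Descent`, i.e. plain gradient descent) is used to perform the parameter update.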

In [56]:
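opt = Descent(0.1)                      # plain gradient descent with step size 0.1
x = copy(x0)                            # update! modifies x in place; copy so that x0 is preserved
for i in 1:10
    ps = params(x)
    gs = gradient(ps) do
        f(x)
    end
    Flux.Optimise.update!(opt, ps, gs)  # in-place step: x .-= 0.1 .* gs[x]
    println(x)
end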

[0.05945759690709232, 0.30591677423248625]
[0.011891519381418462, 0.244733419385989]
[0.0023783038762836915, 0.1957867355087912]
[0.0004756607752567383, 0.15662938840703294]
[9.513215505134763e-5, 0.12530351072562634]
[1.9026431010269525e-5, 0.10024280858050108]
[3.8052862020539047e-6, 0.08019424686440087]
[7.610572404107806e-7, 0.06415539749152069]
[1.5221144808215608e-7, 0.05132431799321655]
[3.044228961643121e-8, 0.041059454394573244]
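For comparison, here is a sketch of the same loop using Flux's explicit optimiser interface (`Flux.setup` to create the optimiser state, then `Flux.update!(state, x, grad)`), assuming Flux 0.14 or later. The helper `descend_with_optimiser` is a hypothetical name for this illustration only. Like `Flux.Optimise.update!` above, `Flux.update!` modifies the array in place, so the function works on a copy and leaves `x0` untouched.

```julia
using Flux

f(x) = 4*x[1]^2 + x[2]^2
x0 = [0.2972879845354616, 0.3823959677906078]  # starting point printed above

# Hypothetical helper: gradient descent via Flux's explicit optimiser API (assumes Flux ≥ 0.14)
function descend_with_optimiser(f, x0; η = 0.1, iters = 10)
    x = copy(x0)                             # work on a copy; update! mutates its argument
    opt_state = Flux.setup(Descent(η), x)    # optimiser state for the array x
    for _ in 1:iters
        g = gradient(f, x)[1]                # explicit-mode gradient of f at x
        Flux.update!(opt_state, x, g)        # in-place step: x .-= η .* g
    end
    return x
end

x = descend_with_optimiser(f, x0)
println(x)    # final point after 10 steps
println(x0)   # still the original starting point
```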
