-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathgradient.jl
58 lines (41 loc) · 1.91 KB
/
gradient.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
using ReverseDiff: GradientTape, GradientConfig, gradient, gradient!, compile, DiffResults
#########
# setup #
#########

"""
    f(a, b)

An example objective function to differentiate: the sum of all entries
of `a' * b + a * b'`.
"""
function f(a, b)
    return sum(a' * b + a * b')
end
# Pre-record a GradientTape for `f` using inputs of shape 100x100 with Float64 elements.
# The `rand(100, 100)` arguments only seed the recording with correctly shaped/typed
# values; the resulting tape is reused below with other 100x100 Float64 inputs.
const f_tape = GradientTape(f, (rand(100, 100), rand(100, 100)))
# Compile `f_tape` into a more optimized representation.
const compiled_f_tape = compile(f_tape)
# Some inputs and work buffers to play around with.
a, b = rand(100, 100), rand(100, 100)
inputs = (a, b)
# Gradient output buffers, one per input (same shape/eltype as the corresponding input).
results = (similar(a), similar(b))
# `DiffResult` wrappers that can hold the primal value in addition to the gradients.
all_results = map(DiffResults.GradientResult, results)
# Pre-allocated configuration/work buffers for the tape-free `gradient` calls below.
cfg = GradientConfig(inputs)
####################
# taking gradients #
####################
# with pre-recorded/compiled tapes (generated in the setup above) #
#-----------------------------------------------------------------#
# This should be the fastest method, and non-allocating: the gradients are
# written in-place into the pre-allocated `results` buffers.
gradient!(results, compiled_f_tape, inputs)
# The same as the above, but in addition to calculating the gradients, the value `f(a, b)`
# is loaded into the provided `DiffResult` instances (see the DiffResults.jl documentation).
gradient!(all_results, compiled_f_tape, inputs)
# This should be the second fastest method, and also non-allocating; it replays the
# uncompiled tape recorded in the setup above.
gradient!(results, f_tape, inputs)
# You can also define your own gradient function if you want to abstract away
# the tape: this fills `out` in-place, just like the direct `gradient!` calls above.
function ∇f!(out, xs)
    return gradient!(out, compiled_f_tape, xs)
end
# with a pre-allocated GradientConfig #
#-------------------------------------#
# These methods are more flexible than a pre-recorded tape, but can be
# wasteful since the tape will be re-recorded for every call.
gradient!(results, f, inputs, cfg)
gradient(f, inputs, cfg)
# without a pre-allocated GradientConfig #
#----------------------------------------#
# Convenient, but pretty wasteful since it has to allocate the GradientConfig itself
# (in addition to re-recording the tape on every call, as above).
gradient!(results, f, inputs)
gradient(f, inputs)