Skip to content

Commit

Permalink
Update to ZygoteRules 0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Sep 21, 2019
1 parent d66f3aa commit 49c239c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ZygoteRules = "0.2"
julia = "1"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ julia> B = [
5 3 5
];

julia> C, back = Zygote.forward(A, B) do A, B
julia> C, back = Zygote.pullback(A, B) do A, B
cut(A) * B
end;

Expand Down Expand Up @@ -55,7 +55,7 @@ julia> using ChainCutters: uncut

julia> using Setfield

julia> C, back = Zygote.forward((A = A, B = B, alpha = 2)) do p
julia> C, back = Zygote.pullback((A = A, B = B, alpha = 2)) do p
q = cut(@set p.B = uncut(p.B)) # only treat `B` as varying
q.A * q.B * q.alpha
end;
Expand Down
6 changes: 3 additions & 3 deletions test/test_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ using Zygote
B = rand(n, n)
Δ = rand(n, n)

y_plain, back_plain = Zygote.forward(op, A, B)
y_cut1, back_cut1 = Zygote.forward((a, b) -> op(cut(a), b), A, B)
y_cut2, back_cut2 = Zygote.forward((a, b) -> op(a, cut(b)), A, B)
y_plain, back_plain = Zygote.pullback(op, A, B)
y_cut1, back_cut1 = Zygote.pullback((a, b) -> op(cut(a), b), A, B)
y_cut2, back_cut2 = Zygote.pullback((a, b) -> op(a, cut(b)), A, B)
diff_plain = back_plain(Δ)
diff_cut1 = back_cut1(Δ)
diff_cut2 = back_cut2(Δ)
Expand Down
18 changes: 9 additions & 9 deletions test/test_broadcastablecallable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,23 @@ end
v = rand(5)

#=
y_actual, back_actual = Zygote.forward(v -> sum(f.(cut(u), v)), v)
y_desired, back_desired = Zygote.forward(v -> sum(f.(u, v)), v)
y_actual, back_actual = Zygote.pullback(v -> sum(f.(cut(u), v)), v)
y_desired, back_desired = Zygote.pullback(v -> sum(f.(u, v)), v)
@test y_actual == y_desired
@test back_actual(1) == back_desired(1)
y_actual, back_actual = Zygote.forward(u -> sum(f.(u, cut(v))), u)
y_desired, back_desired = Zygote.forward(u -> sum(f.(u, v)), u)
y_actual, back_actual = Zygote.pullback(u -> sum(f.(u, cut(v))), u)
y_desired, back_desired = Zygote.pullback(u -> sum(f.(u, v)), u)
@test y_actual == y_desired
@test back_actual(1) == back_desired(1)
=#

y_actual, back_actual = Zygote.forward(f.a) do a
y_actual, back_actual = Zygote.pullback(f.a) do a
g = cut(@set f.a = uncut(a))
# g = Zygote.@showgrad g
sum(g.(u, v))
end
y_desired, back_desired = Zygote.forward(f.a) do a
y_desired, back_desired = Zygote.pullback(f.a) do a
g = @set f.a = a
# g = Zygote.@showgrad g
sum(g.(u, v))
Expand All @@ -66,7 +66,7 @@ end
g = cut(@set f.a = a)
sum(g.(cut(u), cut(v)))
end
y_actual, back_actual = Zygote.forward(h, f.a)
y_actual, back_actual = Zygote.pullback(h, f.a)
@test y_actual == h(f.a)
@test back_actual(1) == (nothing,)
end
Expand All @@ -76,12 +76,12 @@ end
f = AddCall(Poly3(rand(4)...), Poly3(rand(4)...))
x = rand(5)

y_actual, back_actual = Zygote.forward(f.f.c2) do c
y_actual, back_actual = Zygote.pullback(f.f.c2) do c
g = cut(@set f.f.c2 = uncut(c))
# g = Zygote.@showgrad g
sum(g.(x))
end
y_desired, back_desired = Zygote.forward(f.f.c2) do c
y_desired, back_desired = Zygote.pullback(f.f.c2) do c
g = @set f.f.c2 = c
# g = Zygote.@showgrad g
sum(g.(x))
Expand Down

0 comments on commit 49c239c

Please sign in to comment.