Skip to content

Commit

Permalink
Use ZygoteRules.literal_getproperty
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Sep 20, 2019
1 parent d3eed62 commit 2c3ddb6
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions src/ChainCutters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,13 @@ Base.getproperty(x::Variable, name::Symbol) = _uncut(getproperty(unwrap(x), name
nothingsfor(obj) =
NamedTuple{__fieldnames(obj)}(ntuple(_ -> nothing, nfields(obj)))

# Let's use this ugly formatting until `literal_getproperty` is moved
# to ZygoteRules.jl: https://github.com/FluxML/ZygoteRules.jl/issues/3
function __init__()
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin

using .Zygote: Zygote, unbroadcast

@adjoint function Zygote.literal_getproperty(obj::Wrapper, ::Val{name}) where name
Zygote.literal_getproperty(obj, Val(name)), function(Δ)
@adjoint function ZygoteRules.literal_getproperty(obj::Wrapper, ::Val{name}) where name
ZygoteRules.literal_getproperty(obj, Val(name)), function(Δ)
nt = nothingsfor(unwrap(obj))
(setproperties(nt, NamedTuple{(name,)}((Δ,))), nothing)
end
end

end # @require begin
end # function __init__


Setfield.setproperties(obj::Const, patch) =
Const(setproperties(unwrap(obj), patch))
Expand Down Expand Up @@ -202,4 +192,10 @@ using BroadcastableStructs: BroadcastableCallable, calling, splitargsfor
return y, broadcastablecallable_pullback
end

function __init__()
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
using .Zygote: unbroadcast
end
end

end # module

0 comments on commit 2c3ddb6

Please sign in to comment.