tkf/ChainCutters.jl

ChainCutters

Treating arguments as constants

Use ChainCutters.cut(x) to treat x as a constant during differentiation. Only the operations *, + and - are supported.

julia> using ChainCutters: cut

julia> using LinearAlgebra, Zygote

julia> A = [
           1  9  1
           9  1  2
           5  3  5
       ];

julia> B = [
           7  9  1
           9  1  6
           5  3  5
       ];

julia> C, back = Zygote.pullback(A, B) do A, B
           cut(A) * B
       end;

julia> C == A * B
true

julia> ∂A, ∂B = back(I(3));

julia> ∂A === nothing  # `A` is treated as a constant
true

julia> ∂B
3×3 Array{Int64,2}:
 1  9  5
 9  1  3
 1  2  5
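The printed ∂B follows from the standard reverse-mode rule for a matrix product. For C = A * B with incoming cotangent ΔC, the pullback gives ∂A = ΔC * B' and ∂B = A' * ΔC. The sketch below recomputes both by hand using only the LinearAlgebra standard library (no Zygote or ChainCutters). It illustrates the math and is not how the package computes gradients. Seeding with the identity, as back(I(3)) does, makes ∂B simply the transpose of A. The variable name ∂A_uncut is hypothetical: it shows what ∂A would be if A were not wrapped in cut.

```julia
using LinearAlgebra

A = [1 9 1; 9 1 2; 5 3 5]
B = [7 9 1; 9 1 6; 5 3 5]

# Seed cotangent for C = A * B, playing the role of `I(3)` in `back(I(3))`
ΔC = Matrix{Int}(I, 3, 3)

# Reverse-mode rule for a matrix product C = A * B:
#   ∂A = ΔC * B'   (this is the gradient that `cut(A)` discards)
#   ∂B = A' * ΔC
∂A_uncut = ΔC * B'
∂B = A' * ΔC

println(∂B)  # with ΔC = I, this is the transpose of A
```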

Treating specific fields of a constant object as variables

Fields inside an object marked as constant by cut can be marked as variables again using uncut.

julia> using ChainCutters: uncut

julia> using Setfield

julia> C, back = Zygote.pullback((A = A, B = B, alpha = 2)) do p
           q = cut(@set p.B = uncut(p.B))  # only treat `B` as varying
           q.A * q.B * q.alpha
       end;

julia> C == A * B * 2
true

julia> ∂p, = back(I(3));

julia> ∂p
(A = nothing, B = [2 18 10; 18 2 6; 2 4 10], alpha = nothing)
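The same rule, scaled by alpha, explains why ∂p.B equals 2 * A'. The sketch below again uses only LinearAlgebra. For comparison, it also computes by hand the gradients for A and alpha that cut suppresses here. Those extra values are a hand derivation for illustration, not output produced by ChainCutters or Zygote.

```julia
using LinearAlgebra

A = [1 9 1; 9 1 2; 5 3 5]
B = [7 9 1; 9 1 6; 5 3 5]
alpha = 2
ΔC = Matrix{Int}(I, 3, 3)  # seed cotangent, as in `back(I(3))`

# For C = A * B * alpha, only B is left varying by `uncut`:
#   ∂B = alpha * A' * ΔC
∂B = alpha * (A' * ΔC)

# Gradients that `cut` discards, derived by hand for comparison only:
∂A = alpha * (ΔC * B')          # ∂A = alpha * ΔC * B'
∂alpha = sum(ΔC .* (A * B))     # Frobenius inner product ⟨ΔC, A * B⟩

println(∂B)
```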
