# autodiff.jl — ChainRulesCore reverse-mode rules for OMEinsum.
using ChainRulesCore
@doc raw"
    einsum_grad(ixs, xs, iy, size_dict, cdy, i)

Return the gradient of the result of evaluating the `EinCode` w.r.t.
the `i`th tensor in `xs`. `cdy` is the result of applying the `EinCode`
to the `xs`.

# example
```jldoctest; setup = :(using OMEinsum)
julia> using OMEinsum: einsum_grad, get_size_dict

julia> a, b = rand(2,2), rand(2,2);

julia> c = einsum(EinCode((('i','j'),('j','k')), ('i','k')), (a,b));

julia> sd = get_size_dict((('i','j'),('j','k')), (a,b));

julia> einsum_grad((('i','j'),('j','k')), (a,b), ('i','k'), sd, c, 1) ≈ c * transpose(b)
true
```
"
function einsum_grad(ixs, @nospecialize(xs), iy, size_dict, cdy, i)
    # Build the "swapped" contraction: the `i`th input is replaced by the
    # (conjugated) output cotangent, and the output labels become `ixs[i]`.
    grad_ixs = _insertat(ixs, i, iy)
    grad_xs = _insertat(xs, i, cdy)
    grad_iy = ixs[i]
    dxi = einsum(DynamicEinCode(grad_ixs, grad_iy), grad_xs, size_dict)
    # `conj` rather than `conj!`: mutating here would break higher-order
    # differentiation (Hessians).
    return ChainRulesCore.ProjectTo(xs[i])(conj(dxi))
end
# Reverse rule for `einsum(code, xs, size_dict)`: the gradient w.r.t. each
# input tensor is itself an einsum (see `einsum_grad`), wrapped in a thunk
# so it is only evaluated when the caller actually needs it.
function ChainRulesCore.rrule(::typeof(einsum), code::EinCode, @nospecialize(xs), size_dict)
    y = einsum(code, xs, size_dict)
    function einsum_pullback(dy)
        # Match the cotangent's container type to the primal output
        # (for filled array/cuarray et al.). Bind to a NEW name: reassigning
        # the captured argument `dy` would box it in the closure below.
        dyc = convert(typeof(y), dy)
        dxs = ChainRulesCore.@thunk begin
            # Hoist `conj` out of the per-tensor loop — the previous version
            # recomputed `conj(dy)` once for every input tensor.
            cdy = conj(dyc)
            ntuple(i -> einsum_grad(getixs(code), xs, getiy(code), size_dict, cdy, i),
                   length(xs))
        end
        return (NoTangent(), NoTangent(), dxs, NoTangent())
    end
    # Zero cotangent short-circuit: nothing to propagate.
    einsum_pullback(::NoTangent) = (NoTangent(), NoTangent(), NoTangent(), NoTangent())
    return y, einsum_pullback
end
# Reverse rule for `_safe_set(lst, i, x)`.
function ChainRulesCore.rrule(::typeof(_safe_set), lst, i, x)
    y = _safe_set(lst, i, x)
    # Pullback: the cotangent of `lst` is `dy` itself, and the cotangent of
    # the inserted element `x` is the `i`th entry of `dy`; the function and
    # the index are non-differentiable.
    set_pullback(dy) = (NoTangent(), dy, NoTangent(), dy[i])
    set_pullback(::NoTangent) = (NoTangent(), NoTangent(), NoTangent(), NoTangent())
    return y, set_pullback
end
# Code-construction helpers carry no differentiable information: size-dict
# bookkeeping and `DynamicEinCode` construction are treated as constants by AD.
@non_differentiable get_size_dict!(::Any, ::Any, ::Any)
@non_differentiable DynamicEinCode(::Any, ::Any)
@non_differentiable DynamicEinCode(::Any)