This repository has been archived by the owner on May 18, 2022. It is now read-only.

merge into NNlib and CUDA? #32

Closed
CarloLucibello opened this issue Dec 21, 2020 · 8 comments

Comments

CarloLucibello commented Dec 21, 2020

Hi,
in FluxML/Flux.jl#1431 there was some talk about making the primitives defined here more widely available in the ecosystem. To do this, the Zygote and CUDA dependencies should be dropped, because they would be an unnecessary and heavy payload for other packages. That suggests the following steps:

  1. Replace the Zygote and ZygoteRules adjoint definitions with ChainRules ones
  2. Move the CPU implementations to NNlib.jl
  3. Move the GPU kernels to CUDA.jl (which already depends on NNlib.jl), if @maleadt is willing to accept them

@yuehhua does this plan make sense?

cc @dfdx @jeremiedb @chengchingwen
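
For context, the primitives under discussion can be sketched roughly as follows. This is a hedged, simplified illustration with hypothetical vector-only signatures; the actual ScatterNNlib methods are more general (multi-dimensional, multiple reduction ops):

```julia
# Simplified sketch of scatter_add! and gather semantics
# (hypothetical signatures, vectors only, for illustration).
function scatter_add!(ys::AbstractVector, us::AbstractVector, xs::AbstractVector{<:Integer})
    # Accumulate us[i] into ys at the index given by xs[i].
    for i in eachindex(us, xs)
        ys[xs[i]] += us[i]
    end
    return ys
end

# gather is the adjoint access pattern: read ys at the indices xs.
gather(ys::AbstractVector, xs::AbstractVector{<:Integer}) = ys[xs]

ys = zeros(Int, 3)
scatter_add!(ys, [10, 20, 30], [1, 1, 3])  # ys == [30, 0, 30]
gather(ys, [1, 1, 3])                      # == [30, 30, 30]
```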

dfdx commented Dec 22, 2020

Currently, NNlib doesn't depend on ChainRules, so it may be better to move the adjoint definitions to Flux directly.
I also suggest splitting pullbacks into separate forward and reverse functions, so they can be used by libraries with other gradient-calculation schemes (e.g. Yota). For example, instead of this:

@adjoint function scatter_add!(ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
    ys_ = copy(ys)                          # avoid mutating the caller's array
    scatter_add!(ys_, us, xs)
    ys_, Δ -> (Δ, gather(Δ, xs), nothing)   # gradients for ys and us; none for the indices xs
end

have this (assuming I understood the semantics of @adjoint correctly):

∇scatter_add_ys!(Δ, xs) = Δ
∇scatter_add_us!(Δ, xs) = gather(Δ, xs)

function rrule(::typeof(scatter_add!), ys, us, xs)
    ys_ = copy(ys)
    scatter_add!(ys_, us, xs)
    # A ChainRules pullback also returns a tangent for the function itself,
    # and the index argument xs is non-differentiable.
    ys_, Δ -> (NoTangent(), ∇scatter_add_ys!(Δ, xs), ∇scatter_add_us!(Δ, xs), NoTangent())
end

If this looks good to everyone, I can try it out in NNlib / CUDA / Flux during this or next weekend.
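
The payoff of the split can be sketched as follows (a hedged illustration with hypothetical simplified signatures): both styles of AD consume the same two standalone reverse functions, so nothing is locked inside a closure:

```julia
# Hedged sketch of why splitting helps: the same two reverse functions
# serve both pullback-based and tape-based consumers.
gather(ys, xs) = ys[xs]
∇scatter_add_ys(Δ, xs) = Δ               # destination gradient passes Δ through
∇scatter_add_us(Δ, xs) = gather(Δ, xs)   # source gradient reads Δ at the indices xs

Δ  = [1.0, 2.0, 3.0]
xs = [3, 1]

# Pullback style (Zygote/ChainRules): compose the pieces into one closure.
pb = Δ -> (∇scatter_add_ys(Δ, xs), ∇scatter_add_us(Δ, xs), nothing)
pb(Δ)  # == ([1.0, 2.0, 3.0], [3.0, 1.0], nothing)

# Tape style (e.g. Yota): record the two calls as ordinary graph nodes,
# with no opaque closure to peer into.
∇scatter_add_us(Δ, xs)  # == [3.0, 1.0]
```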

yuehhua (Owner) commented Dec 23, 2020

It makes sense to me. If there is anything I can help with, just let me know.
So currently, as @dfdx suggested, we put the adjoint definitions directly in Flux, and the ChainRules definitions go to NNlib and CUDA separately?
And we also separate the forward and reverse functions?
I will generalize the scatter operations to every dimension as well.

@chengchingwen commented:

@dfdx Some gradient functions need intermediate values from the forward pass. If we split the definitions, some backward functions will need an extra argument to receive those values instead of recomputing them.
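
One way to handle that is to compute the intermediate once in the forward pass and thread it through the reverse function as an explicit argument. A hedged sketch (softmax is a hypothetical stand-in here, not one of ScatterNNlib's ops; its gradient needs the forward output):

```julia
# Hypothetical example of the pattern described above: softmax's
# reverse pass needs the forward output y, so the split reverse
# function takes y as an explicit extra argument instead of
# recomputing it.
softmax(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# Standalone reverse function: receives the saved forward output y.
∇softmax(Δ, y) = y .* (Δ .- sum(Δ .* y))

function rrule_softmax(x)
    y = softmax(x)             # intermediate value computed once here
    y, Δ -> ∇softmax(Δ, y)     # pullback closure captures y; a tape-based
                               # AD would store y on the tape explicitly
end
```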

@CarloLucibello (Author) commented:

The only correction to the comments above is that NNlib's rules are currently being moved to NNlib itself FluxML/NNlib.jl#242, so scatter's rules should go there as well.

@yuehhua it would be nice if you could file the PR to NNlib yourself so that you preserve authorship.

dfdx commented Dec 23, 2020

> @dfdx Some gradient functions need intermediate values from the forward pass. If we split the definitions, some backward functions will need an extra argument to receive those values instead of recomputing them.

True, and in general some refactoring of the forward-pass functions may be needed. But in ScatterNNlib all the adjoint definitions follow the same, easy-to-split pattern.

All in all, it looks much easier to have separate forward- and reverse-pass functions and combine them into a pullback than to have only a pullback and try to extract the forward and reverse passes from it. This is essentially the reason Yota.jl (and perhaps any non-pullback-based library) still doesn't use ChainRules.jl.

(I hope that doesn't sound like a selfish argument :))

yuehhua (Owner) commented Jul 2, 2021

All migrations are complete! Thank you everyone.

@yuehhua yuehhua closed this as completed Jul 2, 2021
@CarloLucibello (Author) commented:

Amazing and relentless work, thanks @yuehhua!
