Enable scalar/broadcast operation for LazyPropagation #167
base: master
Conversation
Codecov Report
Base: 81.88% // Head: 81.59% // Decreases project coverage by -0.30%
@@ Coverage Diff @@
## master #167 +/- ##
==========================================
- Coverage 81.88% 81.59% -0.30%
==========================================
Files 28 28
Lines 2186 2200 +14
==========================================
+ Hits 1790 1795 +5
- Misses 396 405 +9
☔ View full report at Codecov.
src/rrules.jl
Outdated
@@ -32,16 +32,12 @@ LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)

for op in [:+, :-, :*, :/]
    @eval begin
        $(op)(F::LazyPropagation, y::AbstractArray{T}) where T = $(op)(eval_prop(F), y)
        $(op)(y::AbstractArray{T}, F::LazyPropagation) where T = $(op)(y, eval_prop(F))
        $(op)(F::LazyPropagation, y::Union{AbstractArray{T}, T}) where T = $(op)(eval_prop(F), y)
You are generalizing in a bad way. Unions aren't really nice across different abstract types: a Union of, say, Vector and Matrix is fine, but for a scalar and an AbstractArray it's not great.
You are also completely ignoring the setup here that keeps the types separate through the @eval loop; that Union completely defeats the purpose.
Finally, you are again not considering what this type is about nor what you are trying to do, and put here the first way "that works" you could come up with. The main point is to avoid evaluating PDEs when not necessary, and you are now forcing potentially unnecessary PDE solves. A LazyPropagation is linear: if you do a .* Lazy
then it's the same as Lazy(L.F, a .* L.q)
which does not evaluate anything.
Your change leads to ambiguities... please run these basic tests locally.
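For concreteness, a minimal sketch of the behavior the reviewer describes, written against JUDI.LazyPropagation with the post, F and q fields used elsewhere in this thread. This is an illustration of the suggestion under those assumptions, not the code merged in this PR: only * and / are folded into the lazy source, so nothing is propagated.

```julia
# Sketch only: fold scalars into the lazy source q instead of calling eval_prop,
# so no PDE is solved until the result is actually needed.
using JUDI
import Base: *, /
import Base.Broadcast: broadcasted

for op in (:*, :/)
    @eval begin
        # scalar on the right works for both * and /
        $(op)(F::JUDI.LazyPropagation, y::Number) = JUDI.LazyPropagation(F.post, F.F, $(op)(F.q, y))
        broadcasted(::typeof($op), F::JUDI.LazyPropagation, y::Number) =
            JUDI.LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y))
    end
end

# scalar on the left commutes only for *, so it gets its own methods
# (this is the same shape as the workaround used later in this thread)
*(y::Number, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, y * F.q)
broadcasted(::typeof(*), y::Number, F::JUDI.LazyPropagation) =
    JUDI.LazyPropagation(F.post, F.F, broadcasted(*, y, F.q))
```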
src/rrules.jl
Outdated
        $(op)(y::Union{AbstractArray{T}, T}, F::LazyPropagation) where T = $(op)(y, eval_prop(F))
        $(op)(F::LazyPropagation, y::AbstractArray{T}) where T = $(op)(eval_prop(F), y)
        $(op)(y::AbstractArray{T}, F::LazyPropagation) where T = $(op)(y, eval_prop(F))
        $(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y))
No, this is only true for * and /.
src/rrules.jl
Outdated
        end
    end
    @eval begin
        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), eval_prop(F.q), y))
What is eval_prop(F.q)??? Again, this is only true for * and /.
On a side note: does it make sense to move the scalar operations (all of +-*/) into LazyPropagation.post?
No, because then it's not a linear operation anymore.
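As a toy illustration of that linearity argument (a plain matrix standing in for a propagator, not JUDI code): scaling commutes with a linear operator, so * and / can be pushed onto the source, while + and - cannot.

```julia
using LinearAlgebra

F = [2.0 0.0; 0.0 3.0]   # stand-in for a linear propagator
q = [1.0, 1.0]           # stand-in for the source
a = 5.0

a * (F * q) == F * (a * q)     # true: the scalar can stay folded into the lazy source
a .+ (F * q) == F * (a .+ q)   # false: folding + into the source changes the result
```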
Hmm, would appreciate your comment on this one @mloubout: I am now on JUDI master and
julia> gs_inv = gradient(x -> norm(F(x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.50 s
Operator `gradient` ran in 0.34 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)
julia> gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.55 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
ERROR: MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
Closest candidates are:
*(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
*(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:385
*(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:414
...
Stacktrace:
[1] (::ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}})()
@ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:111
[2] unthunk
@ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
[3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, ChainRules.var"#1489#1493"{JUDI.LazyPropagation, Float32}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237
[4] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:105 [inlined]
[5] map
@ ./tuple.jl:223 [inlined]
[6] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
[7] ZBack
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
[8] Pullback
@ ./REPL[26]:1 [inlined]
[9] (::typeof(∂(#10)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[10] (::Zygote.var"#60#61"{typeof(∂(#10))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[11] gradient(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[12] top-level scope
@ REPL[26]:1
[13] top-level scope
@ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
julia> import Base.*;
julia> *(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));
julia> gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.56 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.43 s
Operator `gradient` ran in 0.35 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)
Full script below
using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit
Flux.Random.seed!(2022)
### Model
tti = false
viscoacoustic = false
nsrc = 1
dt = 1f0
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)
# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)
ra = false
stype = "Point"
Pq = Ps
opt = Options(return_array=ra, sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)
# Operators
F = Pr*A_inv*adjoint(Pq)
F0 = Pr*A_inv0*adjoint(Pq)
gs_inv = gradient(x -> norm(F(x)*q), m0)
gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
import Base.*;
*(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));
gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
That's quite curious indeed, I'll see if I can figure out what's going on.
Well that is baaaaaaaad, this is why people don't wanna use Julia for serious stuff. When you do ... so there is no trivial way out of it except maybe having ...
src/TimeModeling/Types/abstract.jl
Outdated
@@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...)

time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc]

reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N
This is for the failing example below:
using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit
Flux.Random.seed!(2022)
### Model
tti = false
viscoacoustic = false
nsrc = 1
dt = 1f0
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)
# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)
opt = Options(sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)
# Operators
F = Pr*A_inv*adjoint(Ps)
J = judiJacobian(F,q)
dm = vec(m-m0)
gs_inv = gradient(q -> norm(J(q)*dm), q)
ERROR: LoadError: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
[1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
@ Base ./reshapedarray.jl:41
[2] reshape
@ ./reshapedarray.jl:45 [inlined]
[3] reshape
@ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
[4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
@ Base ./reshapedarray.jl:111
[5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
@ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
[6] _project
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
[7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
@ Base ./tuple.jl:246
[8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
[9] top-level scope
@ ~/.julia/dev/JUDI/test/MFE.jl:33
[10] include(fname::String)
@ Base.MainInclude ./client.jl:476
[11] top-level scope
@ REPL[1]:1
It wasn't failing before, what changed?
The example above fails on the master branch.
To be more specific:
julia> gs_inv = gradient(() -> norm(J(q)*dm), Flux.params(q))
Operator `born` ran in 0.75 s
Grads(...)
this doesn't fail but
julia> gs_inv = gradient(q -> norm(J(q)*dm), q)
Operator `born` ran in 0.72 s
Operator `born` ran in 0.73 s
ERROR: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
[1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
@ Base ./reshapedarray.jl:41
[2] reshape
@ ./reshapedarray.jl:45 [inlined]
[3] reshape
@ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
[4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
@ Base ./reshapedarray.jl:111
[5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
@ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
[6] _project
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
[7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
@ Base ./tuple.jl:246
[8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
[9] top-level scope
@ REPL[25]:1
[10] top-level scope
@ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
this fails
Hum, ok, then split into Dims{N} and Dims{1} and just use an if/else; these try/catch are really a bad idea anywhere near Zygote.
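A minimal sketch of that suggestion, assuming the judiMultiSourceVector field used in this thread (nsrc) and not the code actually committed here: dispatch on Dims{1} for the nsrc-sized case, keep the generic path for the flattened nsrc * nt * nrec case, and avoid try/catch entirely.

```julia
# Sketch only: dispatch instead of try/catch (field names assumed from the thread).
using JUDI
import Base: reshape

# nsrc-sized reshape: ChainRules' ProjectTo may ask for a length-nsrc axis,
# in which case the multi-source vector is already the right "shape".
function reshape(ms::JUDI.judiMultiSourceVector, dims::Dims{1})
    dims[1] == ms.nsrc ? ms : reshape(vec(ms), dims)
end

# Generic case: flatten to nsrc * nt * nrec samples and reshape that vector.
reshape(ms::JUDI.judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
```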
Force-pushed from 4c9a276 to f045df7 (Compare)
A few tests would be appreciated considering the amount of changes.
@@ -130,6 +130,7 @@ end
size(jA::jAdjoint) = (jA.op.n, jA.op.m)
display(P::jAdjoint) = println("Adjoint($(P.op))")
display(P::judiProjection{D}) where D = println("JUDI projection operator $(repr(P.n)) -> $(repr(P.m))")
display(P::judiWavelet{T}) where T = println("JUDI wavelet injected at every grid point")
That's not quite true; it's just a container for a single time trace, there is nothing about "everywhere in space" in it.
src/rrules.jl
Outdated
Base.collect(F::LazyPropagation) = eval_prop(F)
LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
::Function (the post argument should be typed)
src/rrules.jl
Outdated
Base.collect(F::LazyPropagation) = eval_prop(F)
LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)
Constructor: LazyPropagation(identity, F, q, nothing), the extra function call is not needed.
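Taken together, the two suggestions above would amount to something like the following sketch, assuming the LazyPropagation and judiPropagator types shown in the diff (this is not the committed code):

```julia
# Sketch of the two suggestions: type the post argument as a Function and call
# the 4-argument constructor directly instead of chaining convenience methods.
LazyPropagation(post::Function, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q, nothing)
```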
Force-pushed from 16b3a5d to 171170f (Compare)
Could you enlighten me how (by code or something) you reached the conclusion here #167 (comment)? I am experiencing the issue below and would like to check what went wrong ...
julia> gs_inv = gradient(() -> norm(F(1f0*m)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.48 s
Operator `gradient` ran in 0.34 s
Grads(...)
julia> gs_inv = gradient(() -> norm(F(m*1f0)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
Grads(...)
Debug every eval_prop to see where it's called and what the inputs are. In that other case it was evaluated in dot, then you can infer why and check that it's indeed the gradient it computes by requesting it as a param.
Not sure where you are in the debug, but I can tell you that it's not super trivial and the fix will require some proper design to extend it cleanly to this type of case. But I'll leave it to you to at least find what the issue is, as an exercise.
Thanks! Yes, I agree this is not simple. I will pick it up some time later this week.
Force-pushed from 1c8c9af to 37d2e0c (Compare)
Force-pushed from 0148d00 to 5a2cf8f (Compare)
- Enable scalar/broadcast operations for LazyPropagation; add associated test (which won't pass with the current master)
- LazyPropagation now has an attribute val, which stores F * q if previously computed (a sketch of this caching follows below)
- Fix the reshape issue for multi source vector -- which can be in size of nsrc and also in size of nsrc * nt * nrec
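A hypothetical sketch of the val caching described above; the field names (post, F, q, val) come from this thread, while the body of eval_prop is assumed rather than copied from the actual diff:

```julia
# Hypothetical sketch, not the committed code: memoize the propagation result in val
# (assumes LazyPropagation is mutable) so repeated eval_prop calls don't re-solve the PDE.
function eval_prop(F::LazyPropagation)
    isnothing(F.val) && (F.val = F.post(F.F * F.q))  # compute and post-process F * q once
    return F.val
end
```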