Skip to content

Commit

Permalink
Merge pull request #198 from slimgroup/bcast-fix
Browse files Browse the repository at this point in the history
types: fix vector broadcasting
  • Loading branch information
mloubout committed Aug 4, 2023
2 parents 6004a63 + 9ad678f commit 77ca35e
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JUDI"
uuid = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
authors = ["Philipp Witte, Mathias Louboutin"]
version = "3.3.6"
version = "3.3.7"

This comment has been minimized.

Copy link
@mloubout

mloubout Aug 4, 2023

Author Member

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 3 additions & 3 deletions src/TimeModeling/LinearOperators/callable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
function (F::judiPropagator)(;kwargs...)
Fl = deepcopy(F)
for (k, v) in kwargs
k in _mparams(Fl.model) && getfield(Fl.model, k) .= v
k in _mparams(Fl.model) && getfield(Fl.model, k) .= reshape(v, size(Fl.model))
end
Fl
end
Expand All @@ -29,7 +29,7 @@ end

function (F::judiPropagator)(m::AbstractArray)
@info "Assuming m to be squared slowness for $(typeof(F))"
return F(;m=m)
return F(;m=reshape(m, size(F.model)))
end

(F::judiPropagator)(m::AbstractModel, q) = F(m)*as_src(q)
Expand All @@ -42,7 +42,7 @@ end

function (J::judiJacobian{D, O, FT})(x::Array{D, N}) where {D, O, FT, N}
if length(x) == prod(size(J.model))
return J(;m=m)
return J(;m=reshape(x, size(F.model.n)))
end
new_q = _as_src(J.qInjection.op, J.model, x)
newJ = judiJacobian{D, O, FT}(J.m, J.n, J.F, new_q)
Expand Down
8 changes: 4 additions & 4 deletions src/TimeModeling/Types/ModelStructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ function materialize!(A::PhysicalParameter{T, N}, ev::PhysicalParameter{T, N}) w
end

materialize!(A::PhysicalParameter{T, N}, B::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{PhysicalParameter}}) where {T<:Real, N} = materialize!(A, B.f(B.args...))
materialize!(A::PhysicalParameter{T, N}, B::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}}) where {T<:Real, N} = materialize!(A.data, reshape(materialize(B), A.n))
materialize!(A::PhysicalParameter{T, N}, B::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{Na}}) where {T<:Real, N, Na} = materialize!(A.data, reshape(materialize(B), A.n))
materialize!(A::AbstractArray{T, N}, B::Broadcast.Broadcasted{Broadcast.ArrayStyle{PhysicalParameter}}) where {T<:Real, N} = materialize!(A, reshape(materialize(B).data, size(A)))

for op in [:+, :-, :*, :/, :\]
Expand Down Expand Up @@ -400,9 +400,9 @@ function Model(d, o, m::Array{mT, N}; epsilon=nothing, delta=nothing, theta=noth
return IsoModel{T, N}(G, m, rho)
end

Model(n, d, o, m::Array, rho::Array; nb=40) = Model(d, o, m; rho=rho, nb=nb)
Model(n, d, o, m::Array, rho::Array, qp::Array; nb=40) = Model(d, o, m; rho=rho, qp=qp, nb=nb)
Model(n, d, o, m::Array; kw...) = Model(d, o, m; kw...)
Model(n, d, o, m::Array, rho::Array; nb=40) = Model(d, o, reshape(m, n...); rho=reshape(rho, n...), nb=nb)
Model(n, d, o, m::Array, rho::Array, qp::Array; nb=40) = Model(d, o, reshape(m, n...); rho=reshape(rho, n...), qp=reshape(qp, n...), nb=nb)
Model(n, d, o, m::Array; kw...) = Model(d, o, reshape(m, n...); kw...)

size(m::MT) where {MT<:AbstractModel} = size(m.G)
origin(m::MT) where {MT<:AbstractModel} = origin(m.G)
Expand Down
56 changes: 32 additions & 24 deletions test/test_linear_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ model = example_model()
test_getindex(F_adjoint, nsrc)

if VERSION>v"1.2"
a = randn(Float32, model.n...)
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a)
@test F2.model.n == model.n
a0 = a = randn(Float32, model.n...)
for a in [a0, a0[:]]
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a0)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a0)
@test F2.model.n == model.n
end
end
end
end
Expand Down Expand Up @@ -128,12 +130,14 @@ end
test_getindex(F_adjoint, nsrc)

if VERSION>v"1.2"
a = randn(Float32, model.n...)
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a)
@test F2.model.n == model.n
a0 = a = randn(Float32, model.n...)
for a in [a0, a0[:]]
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a0)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a0)
@test F2.model.n == model.n
end
end

# SimSources
Expand Down Expand Up @@ -180,12 +184,14 @@ end
test_getindex(F_adjoint, nsrc)

if VERSION>v"1.2"
a = randn(Float32, model.n...)
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a)
@test F2.model.n == model.n
a0 = a = randn(Float32, model.n...)
for a in [a0, a0[:]]
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a0)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a0)
@test F2.model.n == model.n
end
end

# SimSources
Expand Down Expand Up @@ -236,12 +242,14 @@ end
test_getindex(F_adjoint, nsrc)

if VERSION>v"1.2"
a = randn(Float32, model.n...)
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a)
@test F2.model.n == model.n
a0 = a = randn(Float32, model.n...)
for a in [a0, a0[:]]
F2 = F_forward(;m=a)
@test isapprox(F2.model.m, a0)
F2 = F_forward(Model(model.n, model.d, model.o, a))
@test isapprox(F2.model.m, a0)
@test F2.model.n == model.n
end
end

# SimSources
Expand Down
6 changes: 6 additions & 0 deletions test/test_physicalparam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ ftol = 1f-5
@test u.data[1:10] == 1:10
@test u[11] == 0f0

u .= u.data[:]
@test norm(u, 1) == 55
@test u[1:10] == 1:10
@test u.data[1:10] == 1:10
@test u[11] == 0f0

if nd == 2
tmp = randn(Float32, u[1:10, :].n)
u[1:10, :] .= tmp
Expand Down

1 comment on commit 77ca35e

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/89031

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.3.7 -m "<description of version>" 77ca35e9246fc6bc60a035236b0d85ba3f4d2f60
git push origin v3.3.7

Please sign in to comment.