Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements #263

Merged
merged 9 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-judi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:

- name: Set julia python
run: |
echo "PYTHON=$(which python3)" >> $GITHUB_ENV
PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'

- name: Build JUDI
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci-op.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ jobs:

- name: Set julia python
run: |
echo "PYTHON=$(which python3)" >> $GITHUB_ENV
PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'

- name: Build JUDI
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JUDI"
uuid = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
authors = ["Philipp Witte, Mathias Louboutin"]
version = "3.4.4"
version = "3.4.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -51,4 +51,4 @@ test = ["Aqua", "JLD2", "Printf", "Test", "TimerOutputs", "Flux"]
[weakdeps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
25 changes: 10 additions & 15 deletions deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,32 @@ struct DevitoException <: Exception
msg::String
end

python = PyCall.pyprogramname

try
pk = pyimport("pkg_resources")
pk = try
pyimport("pkg_resources")
catch e
Cmd([python, "-m", "pip", "install", "--user", "setuptools"])
run(cmd)
pk = pyimport("pkg_resources")
run(PyCall.python_cmd(`-m pip install --user setuptools`))
pyimport("pkg_resources")
end

################## Devito ##################
# pip command
cmd = Cmd([python, "-m", "pip", "install", "-U", "--user", "devito[extras,tests]>=4.4"])
dvver = "4.8.10"
cmd = PyCall.python_cmd(`-m pip install --user devito\[extras,tests\]\>\=$(dvver)`)

try
dv_ver = split(pk.get_distribution("devito").version, "+")[1]
if cmp(dv_ver, "4.8.7") < 0
@info "Devito version too low, updating to >=4.8.7"
dv_ver = VersionNumber(split(pk.get_distribution("devito").version, "+")[1])
if dv_ver < VersionNumber(dvver)
@info "Devito version too low, updating to >=$(dvver)"
run(cmd)
end
catch e
@info "Devito not installed, installing with PyCall python"
run(cmd)
end


################## Matplotlib ##################
# pip command
cmd = Cmd([python, "-m", "pip", "install", "--user", "matplotlib"])
try
mpl = pyimport("matplotlib")
catch e
run(cmd)
run(PyCall.python_cmd(`-m pip install --user matplotlib`))
end
2 changes: 0 additions & 2 deletions docs/src/helper.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ remove_out_of_bounds_receivers
```@docs
devito_model
setup_grid
pad_sizes
pad_array
remove_padding
convertToCell
process_input_data
Expand Down
6 changes: 4 additions & 2 deletions examples/scripts/fwi_example_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#

using Statistics, Random, LinearAlgebra
using JUDI, SlimOptim, HDF5, SegyIO, PyPlot
using JUDI, HDF5, SegyIO, SlimOptim, SlimPlotting

# Load starting model
n,d,o,m0 = read(h5open("$(JUDI.JUDI_DATA)/overthrust_model.h5","r"), "n", "d", "o", "m0")
Expand Down Expand Up @@ -66,4 +66,6 @@ for j=1:niterations
model0.m .= proj(model0.m .+ step .* p)
end

figure(); imshow(sqrt.(1f0./adjoint(model0.m))); title("FWI with SGD")
figure()
plot_velocity(model0.m'.^(-.5))
title("FWI with SGD")
53 changes: 14 additions & 39 deletions examples/scripts/modeling_basic_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
#' This example is converted to a markdown file for the documentation.

#' # Import JUDI, Linear algebra utilities and Plotting
using JUDI, PyPlot, LinearAlgebra
using JUDI, LinearAlgebra, SlimPlotting

#+ echo = false; results = "hidden"
close("all")
imcmap = "cet_CET_L1"
dcmap = "PuOr"

#' # Create a JUDI model structure
#' In JUDI, a `Model` structure contains the grid information (origin, spacing, number of gridpoints)
Expand Down Expand Up @@ -91,7 +93,7 @@ q = judiVector(srcGeometry, wavelet)
#' condition for the propagation.

# Setup options
opt = Options(subsampling_factor=2, space_order=32)
opt = Options(subsampling_factor=2, space_order=16, free_surface=false)

#' Linear Operators
#' The core idea behind JUDI is to abstract seismic inverse problems in terms of linear algebra. In its simplest form, seismic inversion can be formulated as
Expand Down Expand Up @@ -119,10 +121,7 @@ dobs = Pr*F*adjoint(Ps)*q

#' Plot the shot record
fig = figure()
imshow(dobs.data[1], vmin=-1, vmax=1, cmap="PuOr", extent=[xrec[1], xrec[end], timeD/1000, 0], aspect="auto")
xlabel("Receiver position (m)")
ylabel("Time (s)")
title("Synthetic data")
plot_sdata(dobs[1]; new_fig=false, name="Synthetic data", cmap=dcmap)
display(fig)

#' Because we have abstracted the linear algebra, we can solve the adjoint wave-equation as well
Expand Down Expand Up @@ -152,19 +151,13 @@ rtm = adjoint(J)*dD

#' We show the linearized data.
fig = figure()
imshow(dD.data[1], vmin=-1, vmax=1, cmap="PuOr", extent=[xrec[1], xrec[end], timeD/1000, 0], aspect="auto")
xlabel("Receiver position (m)")
ylabel("Time (s)")
title("Linearized data")
# Plot the linearized (Born) data `dD`, not the observed data `dobs`
plot_sdata(dD[1]; new_fig=false, name="Linearized data", cmap=dcmap)
display(fig)


#' And the RTM image
fig = figure()
imshow(rtm', vmin=-1e2, vmax=1e2, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("RTM image")
plot_simage(rtm'; new_fig=false, name="RTM image", cmap=imcmap)
display(fig)

#' ## Inversion utility functions
Expand All @@ -185,10 +178,7 @@ f, g = fwi_objective(model0, q, dobs; options=opt)

#' Plot gradient
fig = figure()
imshow(g', vmin=-1e2, vmax=1e2, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("FWI gradient")
plot_simage(g'; new_fig=false, name="FWI gradient", cmap=imcmap)
display(fig)


Expand All @@ -199,17 +189,11 @@ fjn, gjn = lsrtm_objective(model0, q, dobs, dm; nlind=true, options=opt)

#' Plot gradients
fig = figure()
imshow(gj', vmin=-1, vmax=1, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("LSRTM gradient")
plot_simage(gj'; new_fig=false, name="LSRTM gradient", cmap=imcmap, cbar=true)
display(fig)

fig = figure()
imshow(gjn', vmin=-1, vmax=1, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("LSRTM gradient with background data substracted")
plot_simage(gjn'; new_fig=false, name="LSRTM gradient with background data substracted", cmap=imcmap, cbar=true)
display(fig)

#' By extension, lsrtm_objective is the same as fwi_objective when `dm` is zero
Expand All @@ -218,13 +202,10 @@ display(fig)
#' OMP_NUM_THREADS=1 (no parallelism) produces exactly the same result (difference == 0)
#' gjn2 == g
fjn2, gjn2 = lsrtm_objective(model0, q, dobs, 0f0.*dm; nlind=true, options=opt)
fig = figure()

#' Plot gradient
imshow(gjn2', vmin=-1e2, vmax=1e2, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("LSRTM gradient with zero perturbation")
fig = figure()
plot_simage(gjn2'; new_fig=false, name="LSRTM gradient with zero perturbation", cmap=imcmap)
display(fig)


Expand All @@ -236,15 +217,9 @@ f, gmf = twri_objective(model0, q, dobs, nothing; options=Options(frequencies=[[

#' Plot gradients
fig = figure()
imshow(gm', vmin=-1, vmax=1, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("TWRI gradient w.r.t m")
plot_simage(gm'; new_fig=false, name="TWRI gradient w.r.t m", cmap=imcmap)
display(fig)

fig = figure()
imshow(gy.data[1], vmin=-1e2, vmax=1e2, cmap="PuOr", extent=[xrec[1], xrec[end], timeD/1000, 0], aspect="auto")
xlabel("Receiver position (m)")
ylabel("Time (s)")
title("TWRI gradient w.r.t y")
plot_sdata(gy[1]; new_fig=false, name="TWRI gradient w.r.t y", cmap=dcmap)
display(fig)
8 changes: 3 additions & 5 deletions src/JUDI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ module JUDI
export JUDIPATH, set_verbosity, ftp_data, get_serial, set_serial, set_parallel
JUDIPATH = dirname(pathof(JUDI))


# Only needed if extension not available (julia < 1.9)
if !isdefined(Base, :get_extension)
using Requires
Expand Down Expand Up @@ -102,10 +101,12 @@ function _worker_pool()
return nothing
end
p = default_worker_pool()
pool = length(p) < 2 ? nothing : p
pool = nworkers(p) < 2 ? nothing : p
return pool
end

# Accepts an argument of any type and ignores it: the count always comes from `workers()`.
function nworkers(::Any)
    return length(workers())
end

_TFuture = Future
_verbose = false
_devices = []
Expand Down Expand Up @@ -178,9 +179,6 @@ function __init__()
copy!(devito, pyimport("devito"))
# Initialize lock at session start
PYLOCK[] = ReentrantLock()

# Prevent autopadding to use external allocator
set_devito_config("autopadding", false)

# Make sure there is no conflict for the cuda init thread with CUDA.jl
if get(ENV, "DEVITO_PLATFORM", "") == "nvidiaX"
Expand Down
27 changes: 20 additions & 7 deletions src/TimeModeling/Modeling/distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ end
x
end

"""
safe_gc()

Generic GC, compatible with different julia versions of it.
"""
safe_gc() = try Base.GC.gc(); catch; gc() end

"""
local_reduce!(future, other)
Expand Down Expand Up @@ -64,9 +58,28 @@ Adapted from `DistributedOperations.jl` (MIT license). Striped from custom types
with different reduction functions.
"""
function reduce!(futures::Vector{_TFuture})
    # Without a worker pool there is nothing to batch over: reduce everything at once.
    if isnothing(_worker_pool())
        return reduce_all_workers!(futures)
    end
    # Batches are at most as large as the number of parallel workers, so that
    # finished tasks do not sit holding memory while the binary tree reduction
    # works its way to their index.
    nfutures = length(futures)
    batch = min(nworkers(_worker_pool()), nfutures)
    # The first batch seeds the accumulator
    acc = reduce_all_workers!(futures[1:batch])
    # Every following batch is reduced on its own, then folded into the accumulator
    for lo in (batch + 1):batch:nfutures
        hi = min(nfutures, lo + batch - 1)
        single_reduce!(acc, reduce_all_workers!(futures[lo:hi]))
    end
    return acc
end


function reduce_all_workers!(futures::Vector{_TFuture})
# Number of futures and the largest power of two not exceeding it, for the binary reduction
M = length(futures)
L = round(Int,log2(prevpow(2,M)))
L = round(Int, log2(prevpow(2,M)))
m = 2^L
# remainder
R = M - m
Expand Down
2 changes: 2 additions & 0 deletions src/TimeModeling/Modeling/misfit_fg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function _multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes,
data_precon=nothing, model_precon=LinearAlgebra.I)
GC.gc(true)
devito.clear_cache()

# assert this is for single source LSRTM
@assert source.nsrc == 1 "Multiple sources are used in a single-source fwi_objective"
@assert dObs.nsrc == 1 "Multiple-source data is used in a single-source fwi_objective"
Expand Down Expand Up @@ -63,6 +64,7 @@ function _multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes,

length(options.frequencies) == 0 ? freqs = nothing : freqs = options.frequencies
IT = illum ? (PyArray, PyArray) : (PyObject, PyObject)

@juditime "Python call to J_adjoint" begin
argout = rlock_pycall(ac."J_adjoint", Tuple{Float32, PyArray, IT...}, modelPy,
src_coords, qIn, rec_coords, dObserved, t_sub=options.subsampling_factor,
Expand Down
20 changes: 4 additions & 16 deletions src/TimeModeling/Modeling/propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ the pool is empty, a standard loop and accumulation is ran. If the pool is a jul
any custom Distributed pool, the loop is distributed via `remotecall` followed by a binary tree remote reduction.
"""
function run_and_reduce(func, pool, nsrc, arg_func::Function; kw=nothing)
# Allocate devices
_set_devices!()
# Run distributed loop
res = Vector{_TFuture}(undef, nsrc)
for i = 1:nsrc
Expand All @@ -44,21 +42,11 @@ function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function; kw=nothing)
kw_loc = isnothing(kw) ? Dict() : kw(i)
next = func(arg_func(i)...; kw_loc...)
end
single_reduce!(out, next)
end
out
end

function _set_devices!()
ndevices = length(_devices)
if ndevices < 2
return
end
asyncmap(enumerate(workers())) do (pi, p)
remotecall_wait(p) do
pyut.set_device_ids(_devices[pi % ndevices + 1])
@juditime "Reducting $(func) for src $(i)" begin
single_reduce!(out, next)
end
end
out
end

# Returns `true` for any `judiPropagator{T, O}` argument, whatever `T` and `O` are.
# NOTE(review): presumably flags forward propagation — confirm against the other methods.
_prop_fw(::judiPropagator{T, O}) where {T, O} = true
Expand Down Expand Up @@ -112,7 +100,7 @@ function multi_src_fg!(G, model, q, dobs, dm; options=Options(), ms_func=multi_s
kw_func = i -> Dict(:illum=> illum, Dict(k => kw_i(v, i) for (k, v) in kw)...)
# Distribute source
res = run_and_reduce(ms_func, pool, nsrc, arg_func; kw=kw_func)
f, g = update_illum(res, model, :adjoint_born)
res = update_illum(res, model, :adjoint_born)
f, g = as_vec(res, Val(options.return_array))
G .+= g
return f
Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/Modeling/twri_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function _twri_objective(model_full::AbstractModel, source::judiVector, dObs::ju
dtComp = convert(Float32, modelPy."critical_dt")

# Extrapolate input data to computational grid
qIn = time_resample(source.data[1], source.geometry, dtComp)
qIn = time_resample(make_input(source), source.geometry, dtComp)
dObserved = time_resample(make_input(dObs), dObs.geometry, dtComp)

if isnothing(y)
Expand Down
Loading
Loading