Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - Reentrant concurrent snoopi_deep profiles. #309

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 183 additions & 10 deletions SnoopCompileCore/src/snoopi_deep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,28 +70,201 @@ function addchildren!(parent::InferenceTimingNode, t::Core.Compiler.Timings.Timi
end
end

module SnoopiDeepParallelism

# Mutex ordering: MUTEX > jl_typeinf_lock
const MUTEX = ReentrantLock()

mutable struct Invocation
# start_idx is mutated when older invocations are deleted and the profile is shifted.
start_idx::Int
stop_idx::Int
start_time::UInt64
root_start_excl_time::UInt64
root_stop_excl_time::UInt64
end
function Invocation(start_idx, start_root_excl_time)
# Start at the current time.
return Invocation(start_idx, 0, time_ns(), start_root_excl_time, 0)
end

"""
Global (locked) vector tracking running snoopi calls, and when they started.
- When one finishes, we lock(inference), export results, clear the inference profiles up to
the next oldest snoopi call, then unlock(inference).

Imagine this is an ongoing inference profile, where each letter is another inference profile
result, and we start two profiles, 1 and 2, at the times indicated below:
ABCDEFGHIJKLMNOPQRSTUVWX
1> 2> <1 <2

- invocations: [(1,A), (2,D)]
- 1 ends:
copy out ABCDEFGHIJKLMNOPQRSTU
pop (1,A) from invocations
read oldest invocation: (2,D)
delete up to D.
- New profile:
DEFGHIJKLMNOPQRSTUVWX
2> <2

- 2 ends:
copy out DEFGHIJKLMNOPQRSTUVWX
pop (2,D) from invocations
no active invocations, so ...
... delete up to X (end of this profile).
"""
const invocations = Invocation[]

function _current_profile_stats_locked()
ccall(:jl_typeinf_lock_begin, Cvoid, ())
try
inference_root_timing = Core.Compiler.Timings._timings[1]
children = inference_root_timing.children
# Since we were able to grab the lock, we must not be in an inference profile,
# meaning we are in ROOT(). So to get an accurate ROOT timing, we have to add the
# accumulated time since the ROOT was last updated:
accum_root_time = time_ns() - inference_root_timing.cur_start_time
current_root_time = inference_root_timing.time + accum_root_time
Comment on lines +127 to +128
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit hazy on these details by now but trust that it is correct 👍


return length(children), current_root_time
finally
ccall(:jl_typeinf_lock_end, Cvoid, ())
end
end

function _fetch_profile_buffer_locked(start_idx, stop_idx)
ccall(:jl_typeinf_lock_begin, Cvoid, ())
try
inference_root_timing = Core.Compiler.Timings._timings[1]
children = inference_root_timing.children
return children[start_idx:stop_idx]
finally
ccall(:jl_typeinf_lock_end, Cvoid, ())
end
end

function start_timing_invocation()
# Locking respects mutex ordering.
Base.@lock MUTEX begin
current_profile_length, current_root_time = _current_profile_stats_locked()
profile_start_idx = current_profile_length + 1
invocation = Invocation(profile_start_idx, current_root_time)
push!(invocations, invocation)
return invocation
end
end

function stop_timing_invocation!(invocation)
invocation.stop_idx, invocation.root_stop_excl_time = _current_profile_stats_locked()
end

function finish_timing_invocation_and_clear_profile(invocation)
# Locking respects mutex ordering.
Base.@lock MUTEX begin
# Check if this invocation was the oldest. If so, we'll want to clear the parts of
# the profile only it was using.
if invocations[1] !== invocation
idx = findfirst(==(invocation), invocations)
@assert idx !== nothing "invocation wasn't found in invocations: $invocation."
deleteat!(invocations, idx)
return
end

# Clear this invocation from the invocations vector.
popfirst!(invocations)

# Now clear the global inference profile up to the start of the next invocation.
# If no next invocations, clear them all.
if isempty(invocations)
ccall(:jl_typeinf_lock_begin, Cvoid, ())
try
Core.Compiler.Timings.reset_timings()
finally
ccall(:jl_typeinf_lock_end, Cvoid, ())
end
return
end

# Else, we stop at the next oldest invocation.
next_oldest = invocations[1]
start_idx = next_oldest.start_idx
to_delete = start_idx - 1
if to_delete == 0
return
end
# Shift back the indices for all the running invocations
for running_invocation in invocations
running_invocation.start_idx -= to_delete
running_invocation.stop_idx -= to_delete
end
# Clear the profile up to the start of the new oldest invocation.
ccall(:jl_typeinf_lock_begin, Cvoid, ())
try
inference_root_timing = Core.Compiler.Timings._timings[1]
children = inference_root_timing.children
deleteat!(children, 1:to_delete)
finally
ccall(:jl_typeinf_lock_end, Cvoid, ())
end
end
end

end # module

function start_deep_timing()
Core.Compiler.Timings.reset_timings()
invocation = SnoopiDeepParallelism.start_timing_invocation()
Core.Compiler.__set_measure_typeinf(true)
return invocation
end
function stop_deep_timing()
function stop_deep_timing!(invocation)
Core.Compiler.__set_measure_typeinf(false)
Core.Compiler.Timings.close_current_timer()
return SnoopiDeepParallelism.stop_timing_invocation!(invocation)
end

function finish_snoopi_deep()
return InferenceTimingNode(Core.Compiler.Timings._timings[1])
function finish_snoopi_deep(invocation)
buffer = SnoopiDeepParallelism._fetch_profile_buffer_locked(invocation.start_idx, invocation.stop_idx)

# Clean up the profile buffer, so that we don't leak memory.
SnoopiDeepParallelism.finish_timing_invocation_and_clear_profile(invocation)

root_node = _create_finished_ROOT_Timing(invocation, buffer)
return InferenceTimingNode(root_node)
end

# The MethodInstance for ROOT(), and default empty values for other fields.
# Copied from julia typeinf
root_inference_frame_info() =
Core.Compiler.Timings.InferenceFrameInfo(Core.Compiler.Timings.ROOTmi, 0x0, Any[], Any[Core.Const(Core.Compiler.Timings.ROOT)], 1)

function _create_finished_ROOT_Timing(invocation, buffer)
total_time = time_ns() - invocation.start_time

# Create a new ROOT() node, specific to this profiling invocation, which wraps the
# current profile buffer, and contains the total time for the profile.
return Core.Compiler.Timings.Timing(
root_inference_frame_info(),
invocation.start_time,
0,
# Total exclusive time spent in ROOT during the lifetime of this node.
invocation.root_stop_excl_time - invocation.root_start_excl_time,
# Use the copied-out section of the profile buffer as the children of ROOT()
buffer,
)
end



function _snoopi_deep(cmd::Expr)
return quote
start_deep_timing()
invocation = start_deep_timing()
try
$(esc(cmd))
finally
stop_deep_timing()
stop_deep_timing!(invocation)
end
finish_snoopi_deep()
# return the timing result:
finish_snoopi_deep(invocation)
end
end

Expand Down Expand Up @@ -134,5 +307,5 @@ end
# These are okay to come at the top-level because we're only measuring inference, and
# inference results will be cached in a `.ji` file.
precompile(start_deep_timing, ())
precompile(stop_deep_timing, ())
precompile(finish_snoopi_deep, ())
precompile(stop_deep_timing!, (SnoopiDeepParallelism.Invocation,))
precompile(finish_snoopi_deep, (SnoopiDeepParallelism.Invocation,))
158 changes: 158 additions & 0 deletions test/snoopi_deep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,7 @@ end
# pgdsgui(axs[2], rit; bystr="Inclusive", consts=true, interactive=false)
end


Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Snuck in? :)

Suggested change

@testset "Stale" begin
cproj = Base.active_project()
cd(joinpath("testmodules", "Stale")) do
Expand Down Expand Up @@ -944,6 +945,163 @@ end
Pkg.activate(cproj)
end

_name(frame::SnoopCompileCore.InferenceTiming) = frame.mi_info.mi.def.name

@testset "reentrant concurrent profiles 1 - overlap" begin
# Warmup
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps something like the following? :)

Suggested change
# Warmup
# Warmup to prevent, e.g., `+` from appearing in the profiles; only `ROOT` and `foo*` should appear.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Likewise on the warmup comments below if you agree.)

@eval foo1(x) = x+2
@eval foo1(2)

# Test:
t1 = SnoopCompileCore.start_deep_timing()

@eval foo1(x) = x+2
@eval foo1(2)

t2 = SnoopCompileCore.start_deep_timing()

@eval foo2(x) = x+2
@eval foo2(2)

SnoopCompileCore.stop_deep_timing!(t1)
SnoopCompileCore.stop_deep_timing!(t2)

prof1 = SnoopCompileCore.finish_snoopi_deep(t1)
prof2 = SnoopCompileCore.finish_snoopi_deep(t2)

@test Set(_name.(SnoopCompile.flatten(prof1))) == Set([:ROOT, :foo1, :foo2])
@test Set(_name.(SnoopCompile.flatten(prof2))) == Set([:ROOT, :foo2])

# Test Cleanup
@test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations)
@test isempty(Core.Compiler.Timings._timings[1].children)
end

@testset "reentrant concurrent profiles 2 - interleaved" begin
# Warmup
@eval foo1(x) = x+2
@eval foo1(2)

# Test:
t1 = SnoopCompileCore.start_deep_timing()

@eval foo1(x) = x+2
@eval foo1(2)

t2 = SnoopCompileCore.start_deep_timing()

@eval foo2(x) = x+2
@eval foo2(2)

SnoopCompileCore.stop_deep_timing!(t1)

@eval foo3(x) = x+2
@eval foo3(2)

SnoopCompileCore.stop_deep_timing!(t2)

@eval foo4(x) = x+2
@eval foo4(2)

prof1 = SnoopCompileCore.finish_snoopi_deep(t1)

@eval foo5(x) = x+2
@eval foo5(2)

prof2 = SnoopCompileCore.finish_snoopi_deep(t2)

@test Set(_name.(SnoopCompile.flatten(prof1))) == Set([:ROOT, :foo1, :foo2])
@test Set(_name.(SnoopCompile.flatten(prof2))) == Set([:ROOT, :foo2, :foo3])

# Test Cleanup
@test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations)
@test isempty(Core.Compiler.Timings._timings[1].children)
end

@testset "reentrant concurrent profiles 3 - nested" begin
# Warmup
@eval foo1(x) = x+2
@eval foo1(2)

# Test:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That this works is thoroughly pleasing.

local prof1, prof2, prof3
prof1 = SnoopCompileCore.@snoopi_deep begin
@eval foo1(x) = x+2
@eval foo1(2)
prof2 = SnoopCompileCore.@snoopi_deep begin
@eval foo2(x) = x+2
@eval foo2(2)
prof3 = SnoopCompileCore.@snoopi_deep begin
@eval foo3(x) = x+2
@eval foo3(2)
end
@eval foo4(x) = x+2
@eval foo4(2)
end
@eval foo5(x) = x+2
@eval foo5(2)
end

@test Set(_name.(SnoopCompile.flatten(prof1))) == Set([:ROOT, :foo1, :foo2, :foo3, :foo4, :foo5])
@test Set(_name.(SnoopCompile.flatten(prof2))) == Set([:ROOT, :foo2, :foo3, :foo4])
@test Set(_name.(SnoopCompile.flatten(prof3))) == Set([:ROOT, :foo3])

# Test Cleanup
@test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations)
@test isempty(Core.Compiler.Timings._timings[1].children)
end

@testset "reentrant concurrent profiles 3 - parallelism + accurate timing" begin
# Warmup
@eval foo1(x) = x+2
@eval foo1(2)

# Test:
local ts
snoop_times = Float64[0.0, 0.0, 0.0, 0.0]
# Run it twice to ensure we warmup the eval block
for _ in 1:2
@sync begin
ts = [
Threads.@spawn begin
sleep((i-1) / 10) # (Divide by 10 so the test isn't too slow)
snoop_time = @timed SnoopCompile.@snoopi_deep @eval begin
$(Symbol("foo$i"))(x) = x + 1
sleep(1.5 / 10)
Comment on lines +1067 to +1070
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope that these delays are large enough --- and think they should be by orders of magnitude --- to prevent nondeterministic failures due to jitter. But I've found the machines are just lying in wait, for the moment I think something like that. Then 🦈 🩸.

$(Symbol("foo$i"))(2)
end
snoop_times[i] = snoop_time.time
return snoop_time.value
end
for i in 1:4
]
end
end
profs = fetch.(ts)

@test Set(_name.(SnoopCompile.flatten(profs[1]))) == Set([:ROOT, :foo1])
@test Set(_name.(SnoopCompile.flatten(profs[2]))) == Set([:ROOT, :foo1, :foo2])
@test Set(_name.(SnoopCompile.flatten(profs[3]))) == Set([:ROOT, :foo2, :foo3])
@test Set(_name.(SnoopCompile.flatten(profs[4]))) == Set([:ROOT, :foo3, :foo4])

# Test the sanity of the reported Timings
@testset for i in eachindex(profs)
prof = profs[i]
# Test that the time for the inference is accounted for
@test 0.15 < prof.mi_timing.exclusive_time
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fear the definitions of these times has already become hazy to me 😅. This one? 🤔How do know that the exclusive rather than inclusive time should be over 0.15 seconds?

@test prof.mi_timing.exclusive_time < prof.mi_timing.inclusive_time
# Test that the inclusive time (the total time reported by snoopi_deep) matches
# the actual time to do the snoopi_deep, as measured by `@time`.
# These should both be approximately ~0.15 seconds.
@info prof.mi_timing.inclusive_time
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debugging straggler? :)

@test prof.mi_timing.inclusive_time <= snoop_times[i]
end

# Test Cleanup
@test isempty(SnoopCompileCore.SnoopiDeepParallelism.invocations)
@test isempty(Core.Compiler.Timings._timings[1].children)
end

if Base.VERSION >= v"1.7"
@testset "JET integration" begin
f(c) = sum(c[1])
Expand Down