Use libkineto in profiler #46470

Closed
wants to merge 91 commits into from
a4d4124
Use libkineto in profiler
ilia-cher Oct 16, 2020
e27f74c
Update on "Use libkineto in profiler"
ilia-cher Oct 27, 2020
5c3833e
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
662431b
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
ea956aa
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
7dfdbc9
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
6725778
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
e9a219b
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
49a9fee
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
8edb346
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
f288623
Update on "Use libkineto in profiler"
ilia-cher Nov 2, 2020
979cdfa
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
c8cbeb0
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
226089c
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
266b75f
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
6958eac
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
97e5070
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
8d111d2
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
bfb0360
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
1ff1a12
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
b3b69d8
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
2faeb8a
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
67c890d
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
ed8babe
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
ffc11fd
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
fe76b84
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
76ee80c
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
5761ea2
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
dde5ec3
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
3a25bd2
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
6023998
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
0bc66a6
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
aa2d09e
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
91718ac
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
1556a7c
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
4a0fec9
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
bb6396a
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
60b5dee
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
e1a5480
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
38a37dd
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
c6c6039
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
17767d1
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
3537e9d
Update on "Use libkineto in profiler"
ilia-cher Nov 3, 2020
043dcd2
Update on "Use libkineto in profiler"
ilia-cher Nov 4, 2020
aa17339
Update on "Use libkineto in profiler"
ilia-cher Nov 4, 2020
9262f92
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
8371b33
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
67d4acb
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
9f1d24f
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
e864205
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
380b874
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
165bb7c
Update on "Use libkineto in profiler"
ilia-cher Nov 11, 2020
445b8c1
Update on "Use libkineto in profiler"
ilia-cher Nov 12, 2020
7c317f5
Update on "Use libkineto in profiler"
ilia-cher Nov 12, 2020
c904443
Update on "Use libkineto in profiler"
ilia-cher Nov 12, 2020
1f600f8
Update on "Use libkineto in profiler"
ilia-cher Nov 12, 2020
5aacc1c
Update on "Use libkineto in profiler"
ilia-cher Nov 12, 2020
651f556
Update on "Use libkineto in profiler"
ilia-cher Nov 12, 2020
9997011
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
cfd0424
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
30114d8
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
27e4e9c
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
bde96f6
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
b1a0292
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
b7fda07
Update on "Use libkineto in profiler"
ilia-cher Nov 13, 2020
459df8e
Update on "Use libkineto in profiler"
ilia-cher Nov 17, 2020
09a4762
Update on "Use libkineto in profiler"
ilia-cher Nov 17, 2020
cafee0f
Update on "Use libkineto in profiler"
ilia-cher Nov 17, 2020
39ff2b3
Update on "Use libkineto in profiler"
ilia-cher Nov 17, 2020
5502837
Update on "Use libkineto in profiler"
ilia-cher Nov 17, 2020
7c2017b
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
525e5b5
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
1f50e4b
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
f70a95c
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
5fed8be
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
2494879
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
4f401ff
Update on "Use libkineto in profiler"
ilia-cher Nov 20, 2020
c689e6b
Update on "Use libkineto in profiler"
ilia-cher Nov 21, 2020
4a5632f
Update on "Use libkineto in profiler"
ilia-cher Nov 22, 2020
6d0e7ab
Update on "Use libkineto in profiler"
ilia-cher Nov 22, 2020
95b686f
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
d98a5fb
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
cb7367e
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
5ad0a34
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
d6bd96e
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
0c4faaa
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
ab754e1
Update on "Use libkineto in profiler"
ilia-cher Nov 23, 2020
aee38e8
Update on "Use libkineto in profiler"
ilia-cher Nov 24, 2020
671785f
Update on "Use libkineto in profiler"
ilia-cher Nov 24, 2020
8fde042
Update on "Use libkineto in profiler"
ilia-cher Nov 24, 2020
ca6cb73
Update on "Use libkineto in profiler"
ilia-cher Nov 25, 2020
8 changes: 4 additions & 4 deletions aten/src/ATen/record_function.cpp
@@ -10,13 +10,13 @@ namespace {
 
 // Used to generate unique callback handles
 CallbackHandle next_unique_callback_handle() {
-  static std::atomic<uint64_t> unique_cb_id {0};
-  return CallbackHandle(++unique_cb_id);
+  static std::atomic<uint64_t> unique_cb_id {1};
+  return CallbackHandle(unique_cb_id++);
 }
 
 RecordFunctionHandle next_unique_record_function_handle() {
-  static std::atomic<uint64_t> unique_rf_id {0};
-  return RecordFunctionHandle(++unique_rf_id);
+  static std::atomic<uint64_t> unique_rf_id {1};
+  return RecordFunctionHandle(unique_rf_id++);
 }
 
 thread_local RecordFunctionTLS rf_tls_;
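Note on the handle counters in `record_function.cpp`: the hunk replaces a counter that starts at 0 and is pre-incremented with one that starts at 1 and is post-incremented. The diff does not state the motivation. The sketch below is a standalone check, not PyTorch code, suggesting that both forms hand out the same values (1, 2, 3, ...) when called from a single thread, so 0 is never issued either way. The names `next_handle_old`, `next_handle_new`, and `take` are hypothetical and exist only for this illustration.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the old form: start at 0, pre-increment.
std::uint64_t next_handle_old() {
  static std::atomic<std::uint64_t> id{0};
  return ++id;  // increments first, then returns the new value
}

// Hypothetical stand-in for the new form: start at 1, post-increment.
std::uint64_t next_handle_new() {
  static std::atomic<std::uint64_t> id{1};
  return id++;  // returns the current value, then increments
}

// Collects the first n values produced by a generator function.
std::vector<std::uint64_t> take(std::uint64_t (*gen)(), int n) {
  std::vector<std::uint64_t> out;
  for (int i = 0; i < n; ++i) {
    out.push_back(gen());
  }
  return out;
}
```

Calling `take(next_handle_old, 3)` and `take(next_handle_new, 3)` from a `main` and comparing the results should give identical vectors. Each function keeps its own static counter, so the comparison only holds for the first calls in a process.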
43 changes: 17 additions & 26 deletions benchmarks/profiler_benchmark/profiler_bench.py
@@ -1,10 +1,9 @@
 import argparse
-import statistics
 import sys
 import timeit
 import torch
 
-from torch.utils._benchmark import Timer
+from torch.utils.benchmark import Timer
 
 PARALLEL_TASKS_NUM = 4
 INTERNAL_ITER = None
@@ -34,29 +33,30 @@ def parallel_task(x):
     parser.add_argument('--with_cuda', action='store_true')
     parser.add_argument('--with_stack', action='store_true')
     parser.add_argument('--use_script', action='store_true')
+    parser.add_argument('--use_kineto', action='store_true')
     parser.add_argument('--profiling_tensor_size', default=1, type=int)
     parser.add_argument('--workload', default='loop', type=str)
     parser.add_argument('--internal_iter', default=256, type=int)
-    parser.add_argument('--n', default=100, type=int)
-    parser.add_argument('--use_timer', action='store_true')
-    parser.add_argument('--timer_min_run_time', default=100, type=int)
+    parser.add_argument('--timer_min_run_time', default=10, type=int)
+    parser.add_argument('--cuda_only', action='store_true')
 
     args = parser.parse_args()
 
     if args.with_cuda and not torch.cuda.is_available():
         print("No CUDA available")
         sys.exit()
 
-    print("Payload: {}; {} iterations, N = {}\n".format(
-        args.workload, args.internal_iter, args.n))
+    print("Payload: {}, {} iterations; timer min. runtime = {}\n".format(
+        args.workload, args.internal_iter, args.timer_min_run_time))
     INTERNAL_ITER = args.internal_iter
 
     for profiling_enabled in [False, True]:
-        print("Profiling {}, tensor size {}x{}, use cuda: {}, with stacks: {}, use script: {}".format(
+        print("Profiling {}, tensor size {}x{}, use cuda: {}, use kineto: {}, with stacks: {}, use script: {}".format(
             "enabled" if profiling_enabled else "disabled",
             args.profiling_tensor_size,
             args.profiling_tensor_size,
             args.with_cuda,
+            args.use_kineto,
             args.with_stack,
             args.use_script))
 
@@ -83,27 +83,18 @@ def payload():
                 x = None
                 with torch.autograd.profiler.profile(
                     use_cuda=args.with_cuda,
-                    with_stack=args.with_stack) as prof:
+                    with_stack=args.with_stack,
+                    use_kineto=args.use_kineto,
+                    use_cpu=not args.cuda_only) as prof:
                     x = workload(input_x)
                 return x
         else:
             def payload():
                 return workload(input_x)
 
-        if args.use_timer:
-            t = Timer(
-                "payload()",
-                globals={"payload": payload},
-                timer=timeit.default_timer,
-            ).blocked_autorange(min_run_time=args.timer_min_run_time)
-            print(t)
-        else:
-            runtimes = timeit.repeat(payload, repeat=args.n, number=1)
-            avg_time = statistics.mean(runtimes) * 1000.0
-            stddev_time = statistics.stdev(runtimes) * 1000.0
-            print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format(
-                avg_time, stddev_time))
-            if args.workload == "loop":
-                print("\ttime per iteration: {:.3f} ms".format(
-                    avg_time / args.internal_iter))
-            print()
+        t = Timer(
+            "payload()",
+            globals={"payload": payload},
+            timer=timeit.default_timer,
+        ).blocked_autorange(min_run_time=args.timer_min_run_time)
+        print(t)
32 changes: 30 additions & 2 deletions cmake/Dependencies.cmake
@@ -1751,7 +1751,8 @@ endif()
 #
 # End ATen checks
 #
-
+set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
+set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
 add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
 
 # Disable compiler feature checks for `fmt`.
@@ -1764,6 +1765,7 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
 set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "")
 
 list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only)
+set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
 
 # ---[ Kineto
 if(USE_KINETO)
@@ -1774,8 +1776,34 @@ if(USE_KINETO)
     set(KINETO_LIBRARY_TYPE "static" CACHE STRING "")
     set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "")
 
+    message(STATUS "Configuring Kineto dependency:")
+    message(STATUS "  KINETO_SOURCE_DIR = ${KINETO_SOURCE_DIR}")
+    message(STATUS "  KINETO_BUILD_TESTS = ${KINETO_BUILD_TESTS}")
+    message(STATUS "  KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}")
+    message(STATUS "  CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}")
+
+    if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/include)
+      set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include")
+    elseif(EXISTS ${CUDA_SOURCE_DIR}/include/cupti.h)
+      set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/include")
+    endif()
+
+    if((NOT DEFINED CUDA_cupti_LIBRARY) OR (${CUDA_cupti_LIBRARY} STREQUAL "CUDA_cupti_LIBRARY-NOTFOUND"))
+      if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a)
+        set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a")
+      elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti_static.a)
+        set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti_static.a")
+      elseif(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so)
+        set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so")
+      elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti.so)
+        set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti.so")
+      endif()
+    endif()
+    message(STATUS "  CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}")
+    message(STATUS "  CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
+
     add_subdirectory("${KINETO_SOURCE_DIR}")
-    message(STATUS "Configured libkineto as a dependency.")
+    message(STATUS "Configured Kineto as a dependency.")
   endif()
 
   list(APPEND Caffe2_DEPENDENCY_LIBS kineto)
20 changes: 10 additions & 10 deletions test/cpp/jit/test_misc.cpp
@@ -2163,7 +2163,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
   // test running callbacks with propagation of TLS state.
   {
     // Enable the profiler in this thread
-    torch::autograd::profiler::enableProfiler(
+    torch::autograd::profiler::enableProfilerLegacy(
         torch::autograd::profiler::ProfilerConfig(
             torch::autograd::profiler::ProfilerState::CPU, false, false));
     auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2172,12 +2172,12 @@ TEST(TLSFutureCallbacksTest, Basic) {
     // Since we join here, we can ensure that all callbacks corresponding to
     // markCompleted() have finished.
     t.join();
-    torch::autograd::profiler::disableProfiler();
+    torch::autograd::profiler::disableProfilerLegacy();
   }
   // then() with TLS State
   {
     // Enable the profiler in this thread
-    torch::autograd::profiler::enableProfiler(
+    torch::autograd::profiler::enableProfilerLegacy(
         torch::autograd::profiler::ProfilerConfig(
             torch::autograd::profiler::ProfilerState::CPU, false, false));
     auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2190,7 +2190,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
     std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
     t.join();
     s2->wait();
-    torch::autograd::profiler::disableProfiler();
+    torch::autograd::profiler::disableProfilerLegacy();
   }
 }
 
@@ -2199,7 +2199,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   auto profilerEnabledCb = []() {
     ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
   };
-  torch::autograd::profiler::enableProfiler(
+  torch::autograd::profiler::enableProfilerLegacy(
       torch::autograd::profiler::ProfilerConfig(
           torch::autograd::profiler::ProfilerState::CPU, false, false));
   auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2212,10 +2212,10 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   // Don't cleanup TLSState, and just consolidate.
   auto opts = torch::autograd::profiler::ProfilerDisableOptions(false, true);
   auto thread_event_lists =
-      torch::autograd::profiler::disableProfiler(std::move(opts));
+      torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
   // Ensure that the events from this thread are still profiled and we obtain
   // the expected in events in our consolidated list when calling
-  // disableProfiler().
+  // disableProfilerLegacy().
   bool found_ones = false;
   bool found_add = false;
   for (const auto& li : thread_event_lists) {
@@ -2237,21 +2237,21 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   s1->addCallback(verifyProfilerCb);
   // Disable the profiler, but do not consolidate results in the main thread.
   auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
-  torch::autograd::profiler::disableProfiler(std::move(opts));
+  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
   std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
   t.join();
 
   // Similar to above test, but verifies correctness in the case where
   // continuation runs on the main thread.
-  torch::autograd::profiler::enableProfiler(
+  torch::autograd::profiler::enableProfilerLegacy(
       torch::autograd::profiler::ProfilerConfig(
           torch::autograd::profiler::ProfilerState::CPU, false, false));
   s1 = c10::make_intrusive<Future>(IntType::get());
   s1->addCallback(verifyProfilerCb);
   // Runs callback inline
   s1->markCompleted(at::IValue(1));
   opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
-  torch::autograd::profiler::disableProfiler(std::move(opts));
+  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
 }
 
 TEST(IValueKWargsTest, Basic) {