Commit: Use libkineto in profiler

Summary:
Adding the ability to use Kineto (CUPTI) to profile CUDA kernels

Test Plan:
python test/test_profiler.py

[ghstack-poisoned]
ilia-cher committed Oct 16, 2020
1 parent 26383d8 commit a4d4124
Showing 6 changed files with 256 additions and 40 deletions.
20 changes: 20 additions & 0 deletions test/test_profiler.py
@@ -99,6 +99,26 @@ def forward(self, x):

torch._C._set_graph_executor_optimize(prev_opt)

@unittest.skipIf(not torch.autograd.kineto_available(), "Kineto is required")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_kineto(self):
x = torch.randn(10, 10).cuda()
y = torch.randn(10, 10).cuda()
with profile(use_cuda=True, use_kineto=True) as p:
z = torch.mm(x, y)
z = z + y
z = z.cpu()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
found_gemm = False
found_memcpy = False
for e in p.function_events:
if "gemm" in e.name:
found_gemm = True
if "Memcpy" in e.name or "memcpy" in e.name:
found_memcpy = True
self.assertTrue(found_gemm)
self.assertTrue(found_memcpy)

if __name__ == '__main__':
run_tests()
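
For reference, a minimal usage sketch of the new flag, mirroring the test above (assumes a CUDA build with Kineto compiled in; tensor sizes and ops are illustrative, not from the commit):

    import torch
    from torch.autograd.profiler import profile

    if torch.cuda.is_available() and torch.autograd.kineto_available():
        x = torch.randn(1024, 1024, device="cuda")
        y = torch.randn(1024, 1024, device="cuda")
        # use_kineto=True routes CUDA profiling through Kineto/CUPTI
        with profile(use_cuda=True, use_kineto=True) as p:
            z = torch.mm(x, y)   # recorded as a gemm kernel
            z = (z + y).cpu()    # forces a device-to-host Memcpy
        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
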
6 changes: 3 additions & 3 deletions torch/autograd/__init__.py
@@ -242,6 +242,6 @@ def variable(*args, **kwargs):
raise RuntimeError("autograd initialization failed")

# Import all native method/classes
from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent,
_enable_profiler, _disable_profiler, _profiler_enabled,
_enable_record_function, _set_empty_test_observer)
from torch._C._autograd import (ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent,
_prepare_profiler, _enable_profiler, _disable_profiler, _profiler_enabled,
_enable_record_function, _set_empty_test_observer, kineto_available)
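
The newly exported kineto_available() lets callers feature-gate the flag and fall back to the legacy profiler when Kineto is not compiled in; a small sketch (the workload is illustrative):

    import torch
    from torch.autograd.profiler import profile

    # Only request the Kineto path when this build actually provides it.
    use_kineto = torch.autograd.kineto_available()
    with profile(use_cuda=torch.cuda.is_available(), use_kineto=use_kineto) as p:
        torch.mm(torch.randn(128, 128), torch.randn(128, 128))
    print(p.key_averages().table(sort_by="cpu_time_total"))
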
65 changes: 51 additions & 14 deletions torch/autograd/profiler.py
@@ -364,32 +364,56 @@ def __init__(
use_cuda=False,
record_shapes=False,
profile_memory=False,
with_stack=False):
with_stack=False,
use_kineto=False):
self.enabled = enabled
self.use_cuda = use_cuda
self.function_events = None
if not self.enabled:
return
self.use_cuda = use_cuda
self.function_events = None
self.entered = False
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.use_kineto = use_kineto

self.profiler_kind = None
self.kineto_activities = []
if self.use_kineto:
if self.use_cuda:
self.profiler_kind = torch.autograd.ProfilerState.KINETO
self.kineto_activities = [
torch.autograd.ProfilerActivity.CPU,
# uses CUPTI
torch.autograd.ProfilerActivity.CUDA_RUNTIME,
torch.autograd.ProfilerActivity.CUDA]
else:
# initially we're not using Kineto for the CPU-only case
self.profiler_kind = torch.autograd.ProfilerState.CPU
elif self.use_cuda:
# legacy CUDA mode
self.profiler_kind = torch.autograd.ProfilerState.CUDA
else:
self.profiler_kind = torch.autograd.ProfilerState.CPU
self.kineto_activities = set(self.kineto_activities)

if self.profiler_kind == torch.autograd.ProfilerState.KINETO:
assert torch.autograd.kineto_available()

self.config = torch.autograd.ProfilerConfig(
self.profiler_kind,
self.record_shapes,
self.profile_memory,
self.with_stack)

def __enter__(self):
if not self.enabled:
return
if self.entered:
raise RuntimeError("autograd profiler traces are not reentrant")
self.entered = True
profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \
else torch.autograd.ProfilerState.CPU

config = torch.autograd.ProfilerConfig(
profiler_kind,
self.record_shapes,
self.profile_memory,
self.with_stack)
torch.autograd._enable_profiler(config)
torch.autograd._prepare_profiler(self.config, self.kineto_activities)
torch.autograd._enable_profiler(self.config)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
@@ -732,7 +756,7 @@ class FunctionEvent(FormattedTimesMixin):
def __init__(
self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None,
stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False,
is_remote=True, sequence_nr=-1):
is_remote=True, sequence_nr=-1, device_id=-1):
self.id: int = id
self.node_id: int = node_id
self.name: str = name
@@ -751,6 +775,7 @@ def __init__(
self.is_async: bool = is_async
self.is_remote: bool = is_remote
self.sequence_nr: int = sequence_nr
self.device_id: int = device_id

def append_kernel(self, name, device, start, end):
self.kernels.append(Kernel(name, device, Interval(start, end)))
@@ -802,15 +827,21 @@ def self_cpu_time_total(self):

@property
def cuda_time_total(self):
if self.device_id >= 0:
return self.cpu_interval.elapsed_us()
return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)

@property
def self_cuda_time_total(self):
if self.device_id >= 0:
return self.cuda_time_total - sum([child.cuda_time_total for child in self.cpu_children])
return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) - \
sum([child.cuda_time_total for child in self.cpu_children])

@property
def cpu_time_total(self):
if self.device_id >= 0:
return 0
return self.cpu_interval.elapsed_us()

@property
@@ -1045,6 +1076,7 @@ def adjusted_time(cuda_record, cuda_records_map):
is_async=is_async,
is_remote=is_remote_event,
sequence_nr=start.sequence_nr(),
device_id=start.device_id(),
)
# note: async events have only cpu total time
if not is_async and start.has_cuda():
@@ -1180,7 +1212,9 @@ def build_table(
has_input_shapes = any(
[(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events])

MAX_NAME_COLUMN_WIDTH = 55
name_column_width = max([len(evt.key) for evt in events]) + 4
name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH)

DEFAULT_COLUMN_WIDTH = 12

@@ -1288,8 +1322,11 @@ def append(s):
continue
else:
event_limit += 1
name = evt.key
if len(name) >= MAX_NAME_COLUMN_WIDTH-3:
name = name[:(MAX_NAME_COLUMN_WIDTH-3)] + "..."
row_values = [
evt.key, # Name
name,
# Self CPU total, 0 for async events. %
format_time_share(evt.self_cpu_time_total,
self_cpu_time_total),
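
The new device_id field drives the attribution rules above: Kineto-reported device-side events carry device_id >= 0, use their own interval as CUDA time, and contribute no CPU time, while host-side events keep device_id == -1 and the legacy behavior of summing attached kernel intervals. A short sketch of inspecting this after a Kineto-backed run (p is a finished profile, as in the test):

    # Separate host-side from device-side events using the new device_id field.
    host_events = [e for e in p.function_events if e.device_id < 0]
    device_events = [e for e in p.function_events if e.device_id >= 0]
    for e in device_events:
        # For device-side events, cuda_time_total is the event's own interval (us).
        print(e.name, e.cuda_time_total, "us on device", e.device_id)
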
14 changes: 12 additions & 2 deletions torch/csrc/autograd/init.cpp
@@ -39,7 +39,13 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
.value("Disabled", ProfilerState::Disabled)
.value("CPU", ProfilerState::CPU)
.value("CUDA", ProfilerState::CUDA)
.value("NVTX", ProfilerState::NVTX);
.value("NVTX", ProfilerState::NVTX)
.value("KINETO", ProfilerState::KINETO);

py::enum_<ActivityType>(m, "ProfilerActivity")
.value("CPU", ActivityType::CPU)
.value("CUDA_RUNTIME", ActivityType::CUDA_RUNTIME)
.value("CUDA", ActivityType::CUDA);

py::class_<ProfilerConfig>(m, "ProfilerConfig")
.def(py::init<ProfilerState, bool, bool, bool>());
@@ -61,11 +67,15 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
.def("is_remote", &Event::isRemote)
.def("sequence_nr", &Event::sequenceNr)
.def("stack", &Event::stack)
.def("scope", &Event::scope);
.def("scope", &Event::scope)
.def("device_id", &Event::device);

py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
.def(py::init<bool, bool>());

m.def("kineto_available", kinetoAvailable);

m.def("_prepare_profiler", prepareProfiler);
m.def("_enable_profiler", enableProfiler);
m.def(
"_disable_profiler",
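
For reference, the Python side drives these new bindings roughly as sketched below, based on the profiler.py changes above; real code should go through torch.autograd.profiler.profile(use_kineto=True) rather than calling the private functions directly:

    import torch

    # Sketch of the call sequence profiler.py performs with the bindings added here.
    config = torch.autograd.ProfilerConfig(
        torch.autograd.ProfilerState.KINETO,
        False,  # record_shapes
        False,  # profile_memory
        False)  # with_stack
    activities = {
        torch.autograd.ProfilerActivity.CPU,
        torch.autograd.ProfilerActivity.CUDA_RUNTIME,  # CUPTI runtime callbacks
        torch.autograd.ProfilerActivity.CUDA,          # GPU kernels and memcpys
    }
    if torch.autograd.kineto_available():
        torch.autograd._prepare_profiler(config, activities)
        torch.autograd._enable_profiler(config)
        # ... run the workload; results are collected when the profiler is
        # disabled in profile.__exit__ (see torch/autograd/profiler.py).
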
