
Update on "New profiler API"
Summary:
Adding a new API for the Kineto profiler that supports an enable predicate
function, giving better support for profiling training loops (example in test_profiler_kineto_api).
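
For illustration, a minimal sketch of the training-loop usage this enables, based on the API added in this diff; ``N`` and ``train_one_step`` are hypothetical placeholders:

import torch.profiler

# The "enable predicate" is a callable mapping a step index to a
# ProfilerAction; torch.profiler.schedule builds one from wait/warmup/active.
with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
        on_trace_ready=lambda prof: print(prof.key_averages().table(
            sort_by="self_cpu_time_total", row_limit=10))
) as p:
    for step in range(N):      # N: number of training iterations (placeholder)
        train_one_step()       # placeholder for one iteration of the loop
        p.next_step()          # tell the profiler a new iteration has started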

Test Plan:
python test/test_profiler.py -k test_profiler_kineto_api

Differential Revision: [D25142220](https://our.internmc.facebook.com/intern/diff/D25142220)

[ghstack-poisoned]
ilia-cher committed Dec 18, 2020
2 parents cc3b96c + 48636e1 commit 1c6a8ba
Showing 4 changed files with 75 additions and 41 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -66,6 +66,7 @@ Features described in this documentation are classified by release status:
torch.jit <jit>
torch.linalg <linalg>
torch.overrides
profiler
nn.init
onnx
optim
17 changes: 17 additions & 0 deletions docs/source/profiler.rst
@@ -0,0 +1,17 @@
.. currentmodule:: torch.profiler

torch.profiler
==============

Overview
--------
.. automodule:: torch.profiler


API Reference
-------------

.. autoclass:: torch.profiler.profile
:members:

.. autofunction:: torch.profiler.schedule
8 changes: 7 additions & 1 deletion torch/profiler/__init__.py
@@ -1,6 +1,12 @@
# type: ignore
r'''
PyTorch Profiler API
PyTorch Profiler is a tool that allows the collection of performance metrics during model training and inference.
The profiler's context manager API can be used to better understand which model operators are the most expensive, examine their input shapes and
stack traces, study device kernel activity and visualize the execution trace.
.. note::
An earlier version of the API in the ``torch.autograd`` module is considered legacy and will be deprecated.
'''

from .profiler import profile, schedule, ProfilerAction, ProfilerActivity
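
A short sketch of the context-manager usage described above, using the default schedule; ``model`` and ``inputs`` are hypothetical placeholders for user code:

from torch.profiler import profile, ProfilerActivity

# Default schedule: record continuously for the duration of the context manager.
with profile(activities=[ProfilerActivity.CPU],
             record_shapes=True,        # also collect operator input shapes
             with_stack=True) as prof:  # and stack traces (adds overhead)
    model(inputs)                       # placeholder for the code being profiled

# Aggregate events by operator and show the most expensive ones.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))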
90 changes: 50 additions & 40 deletions torch/profiler/profiler.py
@@ -15,11 +15,10 @@ class ProfilerAction(Enum):

def schedule(*, wait: int, warmup: int, active: int):
"""
Represents profiler behavior:
- wait for 'wait' steps
- do the warmup for the next 'warmup' steps
- do the active recording for the next 'active' steps
- repeat the cycle
Represents profiler behavior: wait for ``wait`` steps, then
do the warmup for the next ``warmup`` steps, then
do the active recording for the next ``active`` steps and then
repeat the cycle starting with the next step.
"""
def schedule_fn(step: int) -> ProfilerAction:
assert step >= 0
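
To make the wait/warmup/active cycle concrete, here is a small sketch that prints the action chosen for the first few steps; the exact actions produced (in particular a save on the last active step) are an assumption based on the handling in ``_exit_actions`` further down:

from torch.profiler import schedule

sched = schedule(wait=1, warmup=1, active=2)
for step in range(8):
    # Expected pattern (assumed): step 0 idle, step 1 warmup,
    # steps 2-3 recording (the trace is saved after the last active step),
    # then the 4-step cycle repeats from step 4.
    print(step, sched(step))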
@@ -41,38 +40,43 @@ def schedule_fn(step: int) -> ProfilerAction:

def _default_schedule_fn(_: int) -> ProfilerAction:
"""
Default profiler behavior - immediately start recording the events,
keep doing it on every step
Default profiler behavior - immediately starts recording the events,
keeps doing it on every profiler step.
"""
return ProfilerAction.RECORD


class profile(object):
"""
PyTorch profiler context manager.
Profiler context manager.
Arguments:
activities - list of activity groups (CPU, CUDA);
schedule - a callable takes step (int) as a single parameter and returns
ProfilerAction value - specifies the profiler action on each step;
on_trace_ready (optional) - callable, called each time the trace is ready
during the profiling;
record_shapes - save information about operator's input shapes;
profile_memory - track tensor memory allocation/deallocation;
with_stack - save stack traces;
use_gpu - (deprecated, use activities)
Notes:
- profiler is using Kineto library - system profiler library, with support for CUPTI tracing
- with default schedule profiler immediately starts recording the events, a single trace produced
when context manager exits
- non-default profiler schedules can be useful when tracing training loops, allowing users to enable the
profiler on certain iterations (steps) and account for the warmup.
Warning: enabling shape and stack tracing causes an additional overhead.
- ``activities`` - list of activity groups (CPU, CUDA) to use in profiling;
- ``schedule`` - callable that takes step (int) as a single parameter and returns
``ProfilerAction`` value that specifies the profiler action on each step;
- ``on_trace_ready`` (optional) - callable, called each time the trace is ready
during the profiling;
- ``record_shapes`` - save information about operator's input shapes;
- ``profile_memory`` - track tensor memory allocation/deallocation;
- ``with_stack`` - save stack traces;
- ``use_gpu`` - (deprecated, use ``activities``).
.. note::
Use ``torch.profiler.schedule`` to generate the callable schedule.
Non-default schedules are useful when profiling long training jobs
and allow the user to obtain multiple traces at different iterations
of the training process.
The default schedule simply records all the events continuously for the
duration of the context manager.
.. note::
Enabling shape and stack tracing results in additional overhead.
Examples:
- example that uses default schedule:
.. code-block:: python
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
@@ -82,18 +86,25 @@ class profile(object):
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
- the following example demonstrates the usage of schedule() and next_step():
def trace_handler(p):
print(p.key_averages().table(
Using the profiler's ``schedule``, ``on_trace_ready`` and ``next_step`` functions:
.. code-block:: python
# Non-default profiler schedule allows the user to turn the profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
print(prof.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
# p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step()) + ".json")
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
# Profiler will skip the first step/iteration,
# In this example with wait=1, warmup=1, active=2,
# profiler will skip the first step/iteration,
# start warming up on the second, record
# the third and the fourth iterations,
# after which the trace will become available
@@ -105,11 +116,11 @@ def trace_handler(p):
warmup=1,
active=2),
on_trace_ready=trace_handler
) as p:
for iter in range(N):
code_iteration_to_profile(iter)
# send a signal to the profiler that the next iteration has started
p.next_step()
) as p:
for iter in range(N):
code_iteration_to_profile(iter)
# send a signal to the profiler that the next iteration has started
p.next_step()
"""
def __init__(
self,
@@ -256,9 +267,8 @@ def _exit_actions(self):
if self.current_action == ProfilerAction.WARMUP:
self._start_trace()
self._stop_trace()
elif self.current_action == ProfilerAction.RECORD:
self._stop_trace()
elif self.current_action == ProfilerAction.RECORD_AND_SAVE:
elif self.current_action in \
[ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]:
self._stop_trace()
if self.on_trace_ready:
self.on_trace_ready(self)
