New profiler API #48280

Closed
wants to merge 33 commits

Commits (changes shown from 27 of 33 commits):
7896a1c  New profiler API (ilia-cher, Nov 19, 2020)
3697dd0  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
c8c1e71  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
3a94803  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
6f604e4  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
0700b58  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
bcf4edf  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
50fb369  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
9081309  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
a4405f8  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
adae39e  Update on "New profiler API" (ilia-cher, Nov 20, 2020)
8be3509  Update on "New profiler API" (ilia-cher, Nov 21, 2020)
07d46bb  Update on "New profiler API" (ilia-cher, Nov 22, 2020)
0fc16e9  Update on "New profiler API" (ilia-cher, Nov 22, 2020)
d0ae1f2  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
48915a5  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
9b2328c  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
c249cab  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
d40387e  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
5b1b598  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
5c6a45d  Update on "New profiler API" (ilia-cher, Nov 23, 2020)
d8db2bf  Update on "New profiler API" (ilia-cher, Nov 24, 2020)
9fdc30e  Update on "New profiler API" (ilia-cher, Nov 24, 2020)
a46e065  Update on "New profiler API" (ilia-cher, Nov 24, 2020)
3af4699  Update on "New profiler API" (ilia-cher, Nov 25, 2020)
8eca63e  Update on "New profiler API" (ilia-cher, Nov 26, 2020)
4cc04a7  Update on "New profiler API" (ilia-cher, Nov 30, 2020)
5d3e080  Update on "New profiler API" (ilia-cher, Dec 16, 2020)
244b6c9  Update on "New profiler API" (ilia-cher, Dec 16, 2020)
890562e  Update on "New profiler API" (ilia-cher, Dec 16, 2020)
cc3b96c  Update on "New profiler API" (ilia-cher, Dec 18, 2020)
1c6a8ba  Update on "New profiler API" (ilia-cher, Dec 18, 2020)
4c5734f  Update on "New profiler API" (ilia-cher, Dec 18, 2020)
44 changes: 44 additions & 0 deletions test/test_profiler.py
@@ -9,6 +9,8 @@
from torch.autograd.profiler import profile
from torch.autograd import kineto_available

import torch.profiler

try:
import psutil
HAS_PSUTIL = True
@@ -129,5 +131,47 @@ def test_kineto(self):
self.assertTrue(found_memcpy)
# p.export_chrome_trace("/tmp/test_trace.json")


@unittest.skipIf(not kineto_available(), "Kineto is required")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_kineto_profiler_api(self):
called_num = [0]

def test_output_fn(p):
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
# p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
called_num[0] += 1

with profile(use_cuda=True, use_kineto=True):
self.payload()

with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
enable_pred=torch.profiler.EnablePred(
wait=1,
warmup=1,
active=2,
output_fn=test_output_fn)
) as p:
for idx in range(8):
self.payload()
p.next_step()

self.assertEqual(called_num[0], 2)

# case without enable_pred
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA]
) as p:
self.payload()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))


if __name__ == '__main__':
run_tests()
9 changes: 9 additions & 0 deletions torch/autograd/profiler.py
@@ -447,6 +447,15 @@ def __enter__(self):
torch.autograd._enable_profiler_legacy(self.config())
return self

def _prepare_kineto_trace(self):
assert self.kineto_activities
self.entered = True
torch.autograd._prepare_profiler(self.config(), self.kineto_activities)

def _start_kineto_trace(self):
assert self.kineto_activities
torch.autograd._enable_profiler(self.config(), self.kineto_activities)

def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
6 changes: 6 additions & 0 deletions torch/profiler/__init__.py
@@ -0,0 +1,6 @@
# type: ignore
r'''
PyTorch Profiler API
'''

from .profiler import profile, EnablePred, ProfilerActivity
191 changes: 191 additions & 0 deletions torch/profiler/profiler.py
@@ -0,0 +1,191 @@
import torch.autograd.profiler as prof
from torch.autograd import ProfilerActivity

from typing import Callable, Iterable, Optional

class EnablePred(object):
"""
EnablePred describes on which steps the profiler is active:
Collaborator:

What does "EnablePred" stand for?

Collaborator:

I think "enable predicate". It's probably not the best name, maybe "IterationSchedule" or something like that? (I'm bad at names too)

Contributor Author (ilia-cher):

or ProfilerSchedule

- the profiler starts in the inactive state and stays inactive for the first 'wait' steps
- the profiler then enters a warmup state and stays in it for the next 'warmup' steps
- the profiler then actively traces and collects stats for the next 'active' steps
- after this, the profiler returns to the inactive state and the cycle repeats

If output_fn is specified, it is called every time a trace is ready
Collaborator:

This sounds like a callback instead of an "output_fn", is there a name involving "callback" that would fit?

Collaborator:

Follow-up question, what is the signature of this function?

Collaborator:

an example of initialization will be useful here
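
For illustration, a minimal initialization sketch using the parameters documented above (the trace_ready callback body is hypothetical, not part of this PR):

    def trace_ready(p):
        # p is the underlying torch.autograd.profiler.profile with a finished trace
        print(p.key_averages().table(sort_by="self_cuda_time_total"))

    # skip 1 step, warm up for 1 step, trace for 2 steps, then repeat the cycle
    pred = torch.profiler.EnablePred(wait=1, warmup=1, active=2, output_fn=trace_ready)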

"""
class Action(object):
Collaborator:

I think both Action and State can inherit from enum. See https://docs.python.org/3.5/library/enum.html.

Contributor Author (ilia-cher):

yeah, i was curious if our minimum supported python version supports enums, also the same for dataclasses

Collaborator:

we're minimum at 3.6 now: https://github.com/pytorch/pytorch/blob/master/cmake/Dependencies.cmake#L963 so you're safe. Please update

START_WARMUP = 0
START_TRACE = 1
STOP_TRACE = 2

class State(object):
INACTIVE = 0
WARMUP = 1
ACTIVE = 2

def __init__(self, wait: int, warmup: int, active: int, output_fn: Optional[Callable[[prof.profile], None]]):
Collaborator:

Suggest making the arguments to this function kwarg-only so it's more readable (it would also allow for default values)
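
A sketch of the kwarg-only signature suggested here (the default values shown are hypothetical):

    def __init__(self, *, wait: int = 0, warmup: int = 1, active: int = 1,
                 output_fn: Optional[Callable[[prof.profile], None]] = None):
        ...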

Collaborator:

And maybe "output_fn" -> "done_cb" or "on_done". Otherwise I'd expect that "output_fn" specifies where to put results

assert wait >= 0 and warmup >= 0 and active > 0
if warmup == 0:
print("Warning: profiler won't be using a warmup, which can skew profiler results")
self.wait = wait
self.warmup = warmup
self.active = active
self.output_fn = output_fn

def active_active_fn(step):
Collaborator:

The code here is very challenging to read and looks almost like a state machine. I guess I was expecting something simpler (like taking the modulus of the current iteration or epoch). Is this additional complexity really necessary?

Contributor Author (ilia-cher, Dec 3, 2020):

it is indeed exactly a finite state machine, yes
and we do need to track the states and transitions between them, so a finite state machine seems reasonable

Collaborator:

Can you elaborate on why that's necessary (vs taking a modulus)?

Contributor Author (ilia-cher):

I can e.g. give an example showing that it's not so simple.
So you take the modulus and you know your current state, is it enough?
You still need to know what the previous state was and where in the sequence you are.
If the current state is active, and the previous is warmup - action is start_trace
If the current state is active, and the previous is inactive - action is start_warmup, followed by start_trace
If the current state is active, and the previous state is active, does it mean we don't need to do anything?
It depends, consider (wait=0, warmup=0, active=2) - sometimes we don't do anything, sometimes we do stop, save, warmup, start.

So you still have to consider a lot of possibilities and edge cases, then in that case, why not just encode states and actions explicitly in the map? The original version of the code that didn't use it was way less readable imo.

Collaborator:

What actually happens when wait=0, warmup=0 and active=2? I would expect the following pattern:

ACTIVE ACTIVE ACTIVE ACTIVE ACTIVE...

Is this not correct?

Contributor Author (ilia-cher):

"wait=0, warmup=0 and active=2" means don't wait, don't do warm up, trace for two iterations

Contributor Author (ilia-cher, Dec 3, 2020):

then the sequence repeats, and we provide access to the traces once they are available to print/save

Collaborator:

Oh interesting. So it's more correctly:

ACTIVE ACTIVE REPORT ACTIVE ACTIVE REPORT

?

Contributor Author (ilia-cher):

"REPORT" is implied at the end of every enable_pred cycle; I will make sure to make that clear in the docs.

if self._mod_step(step) == 1:
return [EnablePred.Action.STOP_TRACE, EnablePred.Action.START_WARMUP, EnablePred.Action.START_TRACE]
else:
return []

def inactive_warmup_fn(_):
raise RuntimeError("Incorrect profiler state sequence")

self.actions_map = {
Collaborator:

I still feel it can be simplified. Can we replace the entire EnablePred with a function that returns one of the few enums of what to do on this iteration:

  • NONE
  • WARMUP (i.e. enable but not record)
  • RECORD (i.e. enabled)
  • RECORD_AND_FLUSH (i.e. enabled but dump a new iteration)

Then the profiler doesn't even need to depend on EnablePred class. It can be just an arbitrary callable with that maps step_num into action. Of course you'd provide the code you have here as the default implementation but the user can do their own. E.g. the simple one would be: enable_pred = lambda num: WARMUP if num < 100 else RECORD for simple profile that doesn't dump intermediate results.

This logic about state machine transition is best moved to Profiler directly

As for output_fn - what benefit does it have over just putting if p.step % 100 == 0: <my code> directly in the training loop? You can keep it, but I'd add it as a separate arg in addition to enable_pred so that both can be easily-passable lambdas
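
A minimal sketch of the callable-based shape proposed in this comment (the Action enum and its values follow the list above; none of this is part of the PR as written):

    from enum import Enum

    class Action(Enum):
        NONE = 0
        WARMUP = 1            # enable but do not record
        RECORD = 2            # enabled and recording
        RECORD_AND_FLUSH = 3  # enabled, and dump the finished trace

    # warm up for the first 100 steps, then record, without intermediate dumps
    enable_pred = lambda step: Action.WARMUP if step < 100 else Action.RECORD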

Contributor:

Good discussion here and above - I agree we should strive to simplify this as much as possible, and then make it as easy to understand as possible via good names and examples.

Warning: I haven't read the entire code yet so I might be missing the point - my comments are about user experience, not so much implementation details.

In my experience, a default behavior is good for the majority of cases, so I'll use that as an example: one warmup iteration followed by three recorded iterations.
What frequently changes is when you want to trigger it (should allow simple modulo approach here I think) and what to do with the result. Less frequently, how many iterations to profile and which activities to include.

Given that, the standard use case ought to be something like:
trigger_when: p.step % 1000 == 100
handle_result: write_chrome_trace / print_summary / extract_metrics .... (list of predefined convenience functions)

i.e. trigger profiling at iteration 100, 1100 and so on. Important to note here is that profiling cannot be enabled while at the same time reporting. We need to warmup, record, then stop and process, in order for the overhead to be manageable during the record window.
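
Sketched as code, that standard use case might read as follows (trigger_when and handle_result are hypothetical argument names from this comment; loader and train_step stand in for a user's training loop; export_chrome_trace does exist on the profile object in this PR):

    with torch.profiler.profile(
            trigger_when=lambda step: step % 1000 == 100,                  # hypothetical
            handle_result=lambda p: p.export_chrome_trace("trace.json"),  # hypothetical
    ) as p:
        for batch in loader:
            train_step(batch)
            p.next_step()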

Contributor:

To clarify the above, this will be the behavior:
1100: Warmup
1101-1102: Record
1103: Record, then stop and process, calling the handle_result function asynchronously when processing is complete. As much work as possible is deferred to the processing stage to minimize overhead during recording. Recording should not continue in parallel with processing for detailed profiling, but we also need to consider a light-weight continuous profiling use case, possibly streaming a trace or ongoing extraction of metrics.

Contributor Author (ilia-cher):

@dzhulgakov suggestion sounds good, basically let's provide a way to define a sequence of states explicitly through lambda + suggested default impl

one caveat - we might then allow users to specify sequences like: warmup none warmup none, etc - so we'll need to check for that, but i guess that's fine - assuming most users use supplied default and the power users know what they are doing;

@gdankel I think both solutions above should basically allow these kind of use cases

Contributor:

Sure... As long as users don't have to deal with this level of detail in the majority of cases, it seems reasonable.

EnablePred.State.ACTIVE: {
EnablePred.State.ACTIVE: active_active_fn,
EnablePred.State.WARMUP: [EnablePred.Action.START_TRACE],
EnablePred.State.INACTIVE: [EnablePred.Action.START_WARMUP, EnablePred.Action.START_TRACE],
},
EnablePred.State.WARMUP: {
EnablePred.State.ACTIVE: [EnablePred.Action.STOP_TRACE, EnablePred.Action.START_WARMUP],
EnablePred.State.WARMUP: [],
EnablePred.State.INACTIVE: [EnablePred.Action.START_WARMUP],
},
EnablePred.State.INACTIVE: {
EnablePred.State.ACTIVE: [EnablePred.Action.STOP_TRACE],
EnablePred.State.WARMUP: inactive_warmup_fn,
EnablePred.State.INACTIVE: [],
}
}

def _mod_step(self, step: int):
sum_states = self.wait + self.warmup + self.active
r = step % sum_states
if r == 0:
r = sum_states
return r

def _num_state(self, step: int):
mod_step = self._mod_step(step)
if mod_step <= self.wait:
return EnablePred.State.INACTIVE
elif mod_step <= self.wait + self.warmup:
return EnablePred.State.WARMUP
else:
return EnablePred.State.ACTIVE

def actions(self, step: int):
if step == 1:
st = self._num_state(step)
if st == EnablePred.State.ACTIVE:
return [EnablePred.Action.START_WARMUP, EnablePred.Action.START_TRACE]
elif st == EnablePred.State.WARMUP:
return [EnablePred.Action.START_WARMUP]
else:
return []
else:
st = self._num_state(step)
prev_st = self._num_state(step - 1)
acts = self.actions_map[st][prev_st]
if callable(acts):
return acts(step)
else:
return acts
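
For reference, a standalone sketch of the wait/warmup/active arithmetic implemented by _mod_step and _num_state above, showing the state sequence it produces for wait=1, warmup=1, active=2 (the configuration used in the test in this PR):

    def num_state(step, wait=1, warmup=1, active=2):
        cycle = wait + warmup + active
        r = step % cycle or cycle          # 1-based position within the cycle
        if r <= wait:
            return "INACTIVE"
        if r <= wait + warmup:
            return "WARMUP"
        return "ACTIVE"

    print([num_state(s) for s in range(1, 9)])
    # ['INACTIVE', 'WARMUP', 'ACTIVE', 'ACTIVE',
    #  'INACTIVE', 'WARMUP', 'ACTIVE', 'ACTIVE']

A trace is completed each time the schedule leaves ACTIVE, which is why the test above expects output_fn to have been called twice after 8 steps.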


class profile(object):
"""
PyTorch profiler context manager.

Arguments:
activities - list of activity groups (CPU, CUDA)
enable_pred (optional) - iteration predicate function, used together with `next_step` call
Collaborator:

can you describe what enable_pred is, and also give an example below of using profiler with enable_pred?

Collaborator:

I'm not really sure what a reader is supposed to take away from this statement. What is an "iteration prediction function"? How is it used with next_step? Why is that relevant?


Notes:
- the profiler is based on Kineto, a system profiling library with support for CUPTI tracing
- enable_pred is used for training loop tracing, allowing users to enable the profiler on certain
iterations and to account for the warmup
- when enable_pred is not set, the profiler is always active
- next_step uses the record_function API to add information about steps to the trace
Collaborator:

  1. The documentation does not explain what is next_step and does not explain what the user should do with it (the docs should say that users should call next_step once per training loop iteration, and that an event corresponding to the iteration will be recorded in this call)
  2. There should be a link to record_function documentation here, otherwise record_function api is not clear
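
A usage sketch addressing both points, mirroring the call pattern in test_profiler.py (loader, train_step, and trace_ready are hypothetical stand-ins):

    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            enable_pred=torch.profiler.EnablePred(
                wait=1, warmup=1, active=2, output_fn=trace_ready)) as p:
        for batch in loader:
            train_step(batch)
            p.next_step()  # call once per iteration; records a ProfilerStep# event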

"""
def __init__(
self,
activities: Iterable[ProfilerActivity],
enable_pred: Optional[EnablePred] = None,
record_shapes=False,
profile_memory=False,
with_stack=False):
Collaborator:

I wonder whether we should still support use_cuda for backward compatibility?

self.activities = activities
self.enable_pred = enable_pred
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.step_num = 0
self.profiler: Optional[prof.profile] = None
self.step_rec_fn: Optional[prof.record_function] = None

if not self.enable_pred:
print("Warning: using profiler without enable predicate may result in skewed " +
Collaborator:

prints are bad, use warnings module

"results, use enable_pred to control the warmup time")
Collaborator:

Give a more detailed error message, with an example of how to instantiate and pass enable_pred


def __enter__(self):
self.next_step()
Contributor:

Maybe this 'self.next_step()' should be put before the return? Otherwise the first step's info will be lost when self.enable_pred == None.

if not self.enable_pred:
Collaborator:

should you just have a default enable_pred which does this behavior?

Contributor Author (ilia-cher):

in this case it's an endless tracing (for all iterations), then we should somehow represent it in EnablePred

self._run_action(EnablePred.Action.START_WARMUP)
self._run_action(EnablePred.Action.START_TRACE)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.step_rec_fn:
self.step_rec_fn.__exit__(None, None, None)
if self.profiler:
if self.enable_pred:
if self.enable_pred._num_state(self.step_num) == EnablePred.State.WARMUP:
Collaborator:

calling state again here leaks details. Can't the profile object remember its current state instead?

self._run_action(EnablePred.Action.START_TRACE)
self._run_action(EnablePred.Action.STOP_TRACE, keep_profiler=True)
Collaborator:

why is keep_profiler true in this case?

Contributor Author (ilia-cher):

so that we can also use it after the context manager is finished


def next_step(self):
if self.step_rec_fn:
self.step_rec_fn.__exit__(None, None, None)
self.step_num += 1
if self.enable_pred:
self._run_actions(self.step_num)

self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num))
self.step_rec_fn.__enter__()

def export_chrome_trace(self, path: str):
assert self.profiler
return self.profiler.export_chrome_trace(path)

def key_averages(self, group_by_input_shape: bool = False, group_by_stack_n: int = 0):
assert self.profiler
return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)

def _run_actions(self, step_num):
assert self.enable_pred
for act in self.enable_pred.actions(step_num):
self._run_action(act)

def _run_action(self, act, keep_profiler=False):
if act == EnablePred.Action.START_WARMUP:
self.profiler = prof.profile(
use_cuda=(ProfilerActivity.CUDA in self.activities),
use_cpu=(ProfilerActivity.CPU in self.activities),
record_shapes=self.record_shapes,
profile_memory=self.profile_memory,
with_stack=self.with_stack,
use_kineto=True,
)
self.profiler._prepare_kineto_trace()
elif act == EnablePred.Action.START_TRACE:
assert self.profiler is not None
self.profiler._start_kineto_trace()
elif act == EnablePred.Action.STOP_TRACE:
assert self.profiler is not None
self.profiler.__exit__(None, None, None)
if self.enable_pred and self.enable_pred.output_fn:
self.enable_pred.output_fn(self.profiler)
if not keep_profiler:
self.profiler = None