Add FLOPS support to the new profiler API. (#51734)
Summary:
The new profiler API was added in PR #48280. This PR adds FLOPS support to the new profiler API.

Pull Request resolved: #51734

Test Plan:
```python
python test/test_profiler.py -k test_flops
```
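
For reference, a minimal usage sketch of the option added here (CPU-only, loosely mirroring the test above; the model, input sizes, and sort key are illustrative and not taken from this PR):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Conv2d(3, 16, kernel_size=3)   # example model (hypothetical)
inputs = torch.randn(8, 3, 32, 32)              # example input batch (hypothetical)

# record_shapes is needed so the profiler can derive FLOPS from input shapes.
with profile(activities=[ProfilerActivity.CPU],
             record_shapes=True,
             with_flops=True) as prof:
    model(inputs)

# The resulting table contains a FLOPS column for supported ops
# (matrix multiplication and 2D convolution).
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```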

Reviewed By: xuzhao9

Differential Revision: D26261851

Pulled By: ilia-cher

fbshipit-source-id: dbeba4c197e6f51a9a8e640e8bb60ec38df87f73
xuzhao9 authored and facebook-github-bot committed Feb 5, 2021
1 parent 430329e commit 5c3a054
Showing 5 changed files with 23 additions and 5 deletions.
14 changes: 14 additions & 0 deletions test/test_profiler.py
@@ -385,6 +385,20 @@ def test_flops(self):
         profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
         self.assertIn("FLOPS", profiler_output)
 
+        if not (kineto_available() and torch.cuda.is_available()):
+            return
+
+        with profile(activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA],
+                record_shapes=True,
+                with_flops=True,
+        ) as kineto_profiler:
+            model(inputs)
+        profiler_output = kineto_profiler.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1)
+        self.assertIn("FLOPS", profiler_output)
+
     @unittest.skipIf(not kineto_available(), "Kineto is required")
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
     def test_kineto_profiler_api(self):
4 changes: 2 additions & 2 deletions torch/autograd/profiler.py
@@ -365,8 +365,8 @@ class profile(object):
         with_flops (bool, optional): If with_flops is set, the profiler will estimate
             the FLOPS (floating pointer operations per second) value using the operator's input shape
-            and total CPU time. This allows one to estimate the hardware performance. Currently,
-            this option only works for the matrix multiplication and convolution functions.
+            and total time. This allows one to estimate the hardware performance. Currently,
+            this option only works for the matrix multiplication and 2D convolution operators.
         profile_memory (bool, optional): track tensor memory allocation/deallocation.
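As a rough illustration of the estimate described in the docstring above, FLOPS for a matrix multiplication can be derived from the input shapes and the measured time. The sketch below assumes the common 2 * M * K * N operation count for an (M, K) x (K, N) matmul, which may not match the profiler's internal formula exactly:

```python
import time
import torch

# Illustrative only: estimate FLOPS for a single matmul from its input
# shapes and wall-clock time, analogous to what with_flops reports.
M, K, N = 1024, 1024, 1024
a, b = torch.randn(M, K), torch.randn(K, N)

start = time.time()
torch.mm(a, b)
elapsed = time.time() - start

flop_count = 2 * M * K * N  # one multiply + one add per K-step per output element
print(f"~{flop_count / elapsed / 1e9:.1f} GFLOPS")
```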
2 changes: 1 addition & 1 deletion torch/csrc/autograd/profiler_kineto.h
@@ -170,7 +170,7 @@ struct TORCH_API KinetoEvent {
   uint8_t activity_type_;
   c10::optional<std::vector<std::vector<int64_t>>> shapes_;
   c10::optional<std::vector<std::string>> stack_;
-  uint64_t flops_;
+  uint64_t flops_ = 0;
 
   std::string name_;
   uint64_t device_index_ = 0;
2 changes: 1 addition & 1 deletion torch/csrc/autograd/profiler_legacy.h
@@ -331,7 +331,7 @@ struct TORCH_API LegacyEvent {
   uint64_t correlation_id_;
   // Extra arguments for computing op flops
   std::unordered_map<std::string, c10::IValue> extra_args_;
-  uint64_t flops_;
+  uint64_t flops_ = 0;
 };
 
 // a linked-list of fixed sized vectors, to avoid
6 changes: 5 additions & 1 deletion torch/profiler/profiler.py
@@ -92,7 +92,8 @@ class profile(object):
       during the profiling;
     - ``record_shapes`` - save information about operator's input shapes;
     - ``profile_memory`` - track tensor memory allocation/deallocation;
-    - ``with_stack`` - record source information (file and line number) for the ops.
+    - ``with_stack`` - record source information (file and line number) for the ops;
+    - ``with_flops`` - use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution);
     - ``use_cuda`` - (deprecated, use ``activities``).
 
     .. note::
@@ -178,6 +179,7 @@ def __init__(
             record_shapes: bool = False,
             profile_memory: bool = False,
             with_stack: bool = False,
+            with_flops: bool = False,
             # deprecated:
             use_cuda: Optional[bool] = None):
         if activities:
@@ -207,6 +209,7 @@ def __init__(
         self.record_steps = False
         self.on_trace_ready = on_trace_ready
         self.record_shapes = record_shapes
+        self.with_flops = with_flops
         self.profile_memory = profile_memory
         self.with_stack = with_stack
         self.step_num = 0
@@ -353,6 +356,7 @@ def _start_warmup(self):
             use_cuda=(ProfilerActivity.CUDA in self.activities),
             use_cpu=(ProfilerActivity.CPU in self.activities),
             record_shapes=self.record_shapes,
+            with_flops=self.with_flops,
             profile_memory=self.profile_memory,
             with_stack=self.with_stack,
             use_kineto=True,
