Multi-GPU Kineto profiler test (#51391)
Summary:
Pull Request resolved: #51391

Adds a test that checks the Kineto profiler on multiple GPUs.

Test Plan: python test/test_profiler.py

Reviewed By: ngimel

Differential Revision: D26160788

Pulled By: ilia-cher

fbshipit-source-id: f3554f52176cc26e7f331d205f1a514eb03aa758
Ilia Cherniavskii authored and facebook-github-bot committed Jan 30, 2021
Parent: 11cda92 · Commit: 17b5683
Showing 3 changed files with 52 additions and 16 deletions.
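The core API change in the test file: the legacy autograd profiler import is renamed to _profile, so that the bare name profile now refers to the new Kineto-based torch.profiler API. A minimal sketch of the two entry points side by side (illustrative only, not part of the diff; the use_cuda flag and activities list both appear in the tests below):

    # Sketch: legacy vs. new profiler entry points, mirroring the renamed imports.
    import torch
    from torch.autograd.profiler import profile as _profile  # legacy API
    from torch.profiler import profile, ProfilerActivity     # new Kineto-based API

    x = torch.randn(10, 10)

    # Legacy autograd profiler, configured with keyword flags.
    with _profile(use_cuda=torch.cuda.is_available()):
        torch.mm(x, x)

    # New torch.profiler API, configured with an activities list.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        torch.mm(x, x)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))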
test/test_profiler.py (43 additions, 15 deletions)
@@ -7,11 +7,11 @@
 import torch.nn as nn
 import torch.optim
 import torch.utils.data
+from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName)
-import torch.autograd.profiler as profiler
-from torch.autograd.profiler import profile
-from torch.autograd import kineto_available
+from torch.autograd.profiler import profile as _profile
+from torch.profiler import profile, kineto_available, DeviceType, ProfilerActivity
 
 try:
     import psutil
@@ -33,7 +33,7 @@ def test_mem_leak(self):
         p = psutil.Process()
         last_rss = collections.deque(maxlen=5)
         for outer_idx in range(10):
-            with profile(use_cuda=True):
+            with _profile(use_cuda=True):
                 for _ in range(1024):
                     t = torch.mm(t, t)
 
@@ -79,7 +79,7 @@ def forward(self, x):
 
         mod = DummyModule()
 
-        with profile(with_stack=True, use_kineto=kineto_available()) as p:
+        with _profile(with_stack=True, use_kineto=kineto_available()) as p:
             x = torch.randn(10, 10, requires_grad=True)
             y = torch.randn(10, 10, requires_grad=True)
             z = x + y
@@ -115,11 +115,11 @@ def payload(self):
     @unittest.skipIf(not kineto_available(), "Kineto is required")
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
     def test_kineto(self):
-        with profile(use_cuda=True, use_kineto=True):
+        with _profile(use_cuda=True, use_kineto=True):
             self.payload()
 
         # rerun to avoid initial start overhead
-        with profile(use_cuda=True, use_kineto=True) as p:
+        with _profile(use_cuda=True, use_kineto=True) as p:
             self.payload()
         print(p.key_averages().table(
             sort_by="self_cuda_time_total", row_limit=-1))
@@ -134,6 +134,34 @@ def test_kineto(self):
         self.assertTrue(found_memcpy)
         # p.export_chrome_trace("/tmp/test_trace.json")
 
+    @unittest.skipIf(not kineto_available(), "Kineto is required")
+    @unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed")
+    def test_kineto_multigpu(self):
+        with profile(
+            activities=[
+                ProfilerActivity.CPU,
+                ProfilerActivity.CUDA]) as prof:
+            for gpu_id in [0, 1]:
+                x = torch.randn(10, 10).cuda(gpu_id)
+                y = torch.randn(10, 10).cuda(gpu_id)
+                z = x.matmul(y)
+
+        found_gemm_0 = False
+        found_gemm_1 = False
+        found_cuda = False
+        for evt in prof.events():
+            if "gemm" in evt.name.lower() and evt.device_type == DeviceType.CUDA:
+                if evt.device_index == 0:
+                    found_gemm_0 = True
+                elif evt.device_index == 1:
+                    found_gemm_1 = True
+            if "cuda" in evt.name.lower() and evt.device_type == DeviceType.CPU:
+                found_cuda = True
+
+        self.assertTrue(found_gemm_0)
+        self.assertTrue(found_gemm_1)
+        self.assertTrue(found_cuda)
+
     def test_high_level_trace(self):
         """Checks that python side high level events are recorded.
         """
@@ -200,7 +228,7 @@ def judge(expected_event_count, prof):
             for key, count in expected_event_count.items():
                 self.assertTrue((key in actual_event_count.keys()) and (count == actual_event_count[key]))
 
-        with profile() as prof:
+        with _profile() as prof:
             train()
         expected_event_count = {
             # "+1" because the final iteration will enter __next__ but skip the loop body.
@@ -212,13 +240,13 @@ def judge(expected_event_count, prof):
 
         # Test on pickle/unpickle. Expect to work in multi-processing.
         optimizer = pickle.loads(pickle.dumps(optimizer))
-        with profile() as prof:
+        with _profile() as prof:
             train()
         judge(expected_event_count, prof)
 
         # Test on customized optimizer.
         optimizer = CustomSGD(model.parameters(), lr=1e-4)
-        with profile() as prof:
+        with _profile() as prof:
             train()
         expected_event_count = {
             "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
@@ -235,7 +263,7 @@ def test_flops(self):
             nn.ReLU(),
         )
         inputs = torch.randn(40, 16, 18, 260)
-        with profiler.profile(record_shapes=True, with_flops=True) as prof:
+        with _profile(record_shapes=True, with_flops=True) as prof:
             model(inputs)
         profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
         print(profiler_output)
@@ -246,7 +274,7 @@
     def test_kineto_profiler_api(self):
         called_num = [0]
 
-        with profile(use_cuda=True, use_kineto=True):
+        with _profile(use_cuda=True, use_kineto=True):
             self.payload()
 
         def trace_handler(p):
@@ -255,7 +283,7 @@ def trace_handler(p):
             # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
             called_num[0] += 1
 
-        with torch.profiler.profile(
+        with profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA],
@@ -272,7 +300,7 @@ def trace_handler(p):
         self.assertEqual(called_num[0], 2)
 
         # case without enable_pred
-        with torch.profiler.profile(
+        with profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA]
@@ -283,7 +311,7 @@ def trace_handler(p):
                 sort_by="self_cuda_time_total", row_limit=-1))
 
     def test_export_stacks(self):
-        with profile(with_stack=True, use_kineto=kineto_available()) as p:
+        with _profile(with_stack=True, use_kineto=kineto_available()) as p:
             x = torch.randn(10, 10)
             y = torch.randn(10, 10)
             z = torch.mm(x, y)
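Outside the test harness, the new test_kineto_multigpu case boils down to the following usage pattern; a standalone sketch, assuming a machine with at least two CUDA devices:

    # Sketch: profile work on two GPUs, then attribute events to devices.
    import torch
    from torch.profiler import profile, ProfilerActivity, DeviceType

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for gpu_id in [0, 1]:
            x = torch.randn(10, 10).cuda(gpu_id)
            y = torch.randn(10, 10).cuda(gpu_id)
            x.matmul(y)

    # events() returns unaggregated events; CUDA-side events carry the index
    # of the device they executed on, which is what the test asserts on.
    for evt in prof.events():
        if evt.device_type == DeviceType.CUDA:
            print(evt.name, "ran on device", evt.device_index)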
torch/profiler/__init__.py (1 addition, 1 deletion)
@@ -10,4 +10,4 @@
 '''
 
 from .profiler import profile, schedule, ProfilerAction, ProfilerActivity
-from torch.autograd import kineto_available
+from torch.autograd import kineto_available, DeviceType
torch/profiler/profiler.py (8 additions, 0 deletions)
@@ -268,6 +268,14 @@ def key_averages(self, group_by_input_shape: bool = False, group_by_stack_n: int
         assert self.profiler
         return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)
 
+    def events(self):
+        """
+        Returns the list of unaggregated profiler events,
+        to be used in the trace callback or after the profiling is finished
+        """
+        assert self.profiler
+        return self.profiler.function_events
+
     def _enter_actions(self):
         if self.current_action == ProfilerAction.WARMUP:
             self._start_warmup()
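The docstring of the new events() accessor points at use from the trace callback; here is a sketch of that usage, assuming the schedule and on_trace_ready keywords of the torch.profiler API exercised in the test above (the threshold and iteration counts are arbitrary):

    # Sketch: inspect unaggregated events from an on_trace_ready callback.
    import torch
    from torch.profiler import profile, schedule, ProfilerActivity

    def trace_handler(p):
        # FunctionEvent times are reported in microseconds.
        slow = [e for e in p.events() if e.cpu_time_total > 1000]
        print(len(slow), "events took more than 1 ms in this cycle")

    with profile(
            activities=[ProfilerActivity.CPU],
            schedule=schedule(wait=1, warmup=1, active=2),
            on_trace_ready=trace_handler) as p:
        for _ in range(8):
            torch.mm(torch.randn(64, 64), torch.randn(64, 64))
            p.step()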
