Add default args for CUDA stream and events #52679
[ghstack-poisoned]
💊 CI failures summary and remediations
As of commit a66d5c9 (more details on the Dr. CI page):
🕵️ 4 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
Just some nits
test/jit/test_cuda.py
Outdated
  def forward(self):
      device_index = torch.cuda._current_device()
-     s = torch.jit.cuda.Stream(device_index, 0)
+     s = torch.jit.cuda.Stream(priority=0)

Suggested change:
-     s = torch.jit.cuda.Stream(priority=0)
+     s = torch.jit.cuda.Stream()
test/jit/test_cuda.py
Outdated
  device_index = torch.cuda._current_device()
  s0 = torch.cuda.current_stream(device_index)
- s1 = torch.jit.cuda.Stream(device_index, 0)
+ s1 = torch.jit.cuda.Stream(priority=0)

Suggested change:
- s1 = torch.jit.cuda.Stream(priority=0)
+ s1 = torch.jit.cuda.Stream()
test/jit/test_cuda.py
Outdated
  def test_event_synchronize() -> float:
      device_index = torch.cuda._current_device()
-     s = torch.jit.cuda.Stream(device_index, 0)
+     s = torch.jit.cuda.Stream(priority=0)

Suggested change:
-     s = torch.jit.cuda.Stream(priority=0)
+     s = torch.jit.cuda.Stream()
test/jit/test_cuda.py
Outdated
  def test_stream_synchronize() -> float:
      device_index = torch.cuda._current_device()
-     s = torch.jit.cuda.Stream(device_index, 0)
+     s = torch.jit.cuda.Stream(priority=0)

Suggested change:
-     s = torch.jit.cuda.Stream(priority=0)
+     s = torch.jit.cuda.Stream()
test/jit/test_cuda.py
Outdated
def event_default_args_1() -> bool:
    e = torch.jit.cuda.Event(blocking=True)
    return e is not None

@torch.jit.script
def event_default_args_2() -> bool:
    e = torch.jit.cuda.Event(enable_timing=True)
    return e is not None

@torch.jit.script
def event_default_args_3() -> bool:
    e = torch.jit.cuda.Event(interprocess=True)
    return e is not None

@torch.jit.script
def event_default_args_4() -> bool:
    e = torch.jit.cuda.Event(interprocess=True, blocking=True)
    return e is not None

@torch.jit.script
def event_default_args_5() -> bool:
    e = torch.jit.cuda.Event(enable_timing=True, blocking=True)
Consider using better test names here, and testing that the arguments and default arguments are being respected. That is, the event in event_default_args_3 should not be blocking, and we should be able to verify that (I think?).

To verify blocking we would ideally need sleep support in JIT; since that isn't available, it is much harder to verify that blocking was truly disabled. The same is true for interprocess. Some cases, such as event_default_args_5, are already covered by an existing test (test_event_wait), so I didn't add a redundant test for them.
test/jit/test_cuda.py
Outdated
  device_index = torch.cuda._current_device()
  current_stream = torch.cuda.current_stream(device_index)
- user_stream = torch.jit.cuda.Stream(device_index, 0)
+ user_stream = torch.jit.cuda.Stream(priority=0)

Suggested change:
- user_stream = torch.jit.cuda.Stream(priority=0)
+ user_stream = torch.jit.cuda.Stream()
test/jit/test_cuda.py
Outdated
  current_stream = torch.cuda.current_stream(device_index)
  default_stream = torch.cuda.default_stream(device_index)
- user_stream = torch.jit.cuda.Stream(device_index, 0)
+ user_stream = torch.jit.cuda.Stream(priority=0)

Suggested change:
- user_stream = torch.jit.cuda.Stream(priority=0)
+ user_stream = torch.jit.cuda.Stream()
test/jit/test_cuda.py
Outdated
  def test_simple_stream():
      device_index = torch.cuda._current_device()
-     s = torch.jit.cuda.Stream(device_index, 0)
+     s = torch.jit.cuda.Stream(priority=0)

Suggested change:
-     s = torch.jit.cuda.Stream(priority=0)
+     s = torch.jit.cuda.Stream()
[ghstack-poisoned]
Co-authored-by: SplitInfinity <meghanl@fb.com>
I have a few questions about the tests and how they make sure the default arguments work as expected.
As you alluded to in your comment, there isn't much we can do to really test these properties of events. Maybe it's better to have no test rather than a test that doesn't actually test anything?
test/jit/test_cuda.py
Outdated
e.record(s0)
e.wait(s)
A = torch.rand(1000, 1000, device="cuda")
with torch.jit.cuda.stream(s):
    B = torch.mm(A, A)
s.synchronize()
e.record(s0)
What does this test? I thought we would want to test that timing is disabled for this event (since that is the default argument). Would trying to call elapsed_time throw an exception that we could check for using self.assertRaisesRegex?

I tried this, but elapsed_time didn't throw any exception, so I tried to verify blocking=True instead. It is difficult, though, to prove that the event waits on the stream s here.
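For reference, the check the reviewer has in mind could look roughly like the sketch below. FakeEvent is a hypothetical stand-in written only for illustration; it is not torch.jit.cuda.Event, and its error message is made up. It assumes an event that rejects elapsed_time when timing is off, which, per the reply above, is not what the scripted event did in this PR.

```python
import unittest


class FakeEvent:
    """Hypothetical stand-in for a CUDA event; NOT the PyTorch class."""

    def __init__(self, enable_timing: bool = False, blocking: bool = False,
                 interprocess: bool = False) -> None:
        # Defaults mirror the documented eager-mode torch.cuda.Event defaults.
        self.enable_timing = enable_timing
        self.blocking = blocking
        self.interprocess = interprocess

    def elapsed_time(self, other: "FakeEvent") -> float:
        # Assumed behavior: refuse to report a time unless both events time.
        if not (self.enable_timing and other.enable_timing):
            raise RuntimeError("timing is disabled for this event")
        return 0.0


class DefaultArgsTest(unittest.TestCase):
    def test_timing_disabled_by_default(self) -> None:
        start, end = FakeEvent(), FakeEvent()
        # The test passes only if the call raises a matching error.
        with self.assertRaisesRegex(RuntimeError, "timing is disabled"):
            start.elapsed_time(end)


def run_sketch() -> bool:
    """Run the test case in-process and report whether it passed."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(DefaultArgsTest)
    result = unittest.TestResult()
    suite.run(result)
    return result.wasSuccessful()
```

A test of this shape only has value when the object under test actually raises; otherwise assertRaisesRegex fails, which is consistent with the author falling back to a different check.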
test/jit/test_cuda.py
Outdated
e_tik = torch.jit.cuda.Event(enable_timing=True)
e_tok = torch.jit.cuda.Event(enable_timing=True)
s0 = torch.cuda.current_stream(0)
e_tik.record(s0)
A = torch.rand(1000, 1000, device="cuda")
with torch.jit.cuda.stream(s0):
    B = torch.mm(A, A)
s0.synchronize()
e_tok.record(s0)
e_tok.synchronize()
return e_tik.elapsed_time(e_tok)
What does this test? I thought we would want to test that these events are nonblocking (since that is the default argument).

It's difficult to prove that they are nonblocking because sleep is not available. So I instead checked that enable_timing=True works.
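The timing pattern in the hunk above can be reproduced outside TorchScript with the eager-mode API. The sketch below uses torch.cuda.Event rather than torch.jit.cuda.Event, and the helper name time_matmul_ms is hypothetical. It returns None when PyTorch or a CUDA device is unavailable, so it is only a rough illustration of the tik/tok approach, not the test from this PR.

```python
from typing import Optional

try:
    import torch
except ImportError:  # keep the sketch importable without PyTorch installed
    torch = None


def time_matmul_ms(n: int = 1000) -> Optional[float]:
    """Time one n x n matmul on the GPU in milliseconds; None without CUDA."""
    if torch is None or not torch.cuda.is_available():
        return None
    # enable_timing defaults to False, so it must be requested explicitly.
    tik = torch.cuda.Event(enable_timing=True)
    tok = torch.cuda.Event(enable_timing=True)
    stream = torch.cuda.current_stream()
    a = torch.rand(n, n, device="cuda")
    tik.record(stream)
    torch.mm(a, a)
    tok.record(stream)
    tok.synchronize()  # block the host until the second event has completed
    return tik.elapsed_time(tok)
```

A non-negative result shows only that timing was enabled on both events; it says nothing about whether they are blocking, which is the gap the reviewer points out.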
test/jit/test_cuda.py
Outdated
e = torch.jit.cuda.Event(interprocess=True)
s1 = torch.jit.cuda.Stream()
s2 = torch.jit.cuda.Stream()
A = torch.rand(1000, 1000, device="cuda")
with torch.jit.cuda.stream(s1):
    B = torch.mm(A, A)
s1.record_event(e)

with torch.jit.cuda.stream(s2):
    C = torch.mm(A, A)
s2.record_event(e)
s1.synchronize()
s2.synchronize()
Same question here - I thought we would want to test that the event is nonblocking and has timing disabled.

I couldn't verify that the event is nonblocking and has timing disabled, so this test checks whether interprocess was enabled instead. With interprocess=True, the event can be shared between two processes, so in this test I tried recording the same event on two streams.
I agree, we can probably eliminate the tests. My only objective in introducing them was to make sure that, even with default arguments, the event object was getting created. But I guess that doesn't add much value here. Wdyt?
Yeah, let's just remove the (new) tests.
[ghstack-poisoned]
Done.
Make sure to address merge conflicts before merging.
Closing this PR. The change was merged through another PR, #53025.
Stack from ghstack: