-
Notifications
You must be signed in to change notification settings - Fork 471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case #7094
Conversation
ManfeiBai
commented
May 22, 2024
•
edited
Loading
edited
- Modify BuildForiLoop for while_loop/fori_loop
- Modify test code for while_loop/fori_loop with simple version
- Add linear_layer_model and MNIST_model(without BN layer) test case
l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device) | ||
l_out = torch.randn(bs, 10, dtype=torch.float32, device=device) | ||
iteri = torch.tensor(3, dtype=torch.int64, device=device) | ||
_, _, res = mnist(iteri, l_in_0, l_out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we init the l_out
inside the mnist forward and returns it? This way the only additional output we provide is the count.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, good idea, due to hard to gain result before finish one run, so current plan would be:
run body_fn with pure python while
, get result, then use this result as l_out
before implementation, did quick benchmark for this idea with code:
- iteri = 3:
while_loop(3.675 s), pure while(3.107 s), pure_while_then_while_loop_plan(3.724 s)
- iteri = 10:
while_loop(3.670s), pure while(3.119s), pure_while_then_while_loop_plan(3.932s)
- iteri = 10:
while_loop(3.744 s), pure while(3.110 s), pure_while_then_while_loop_plan(3.749 s)
- iteri = 10:
while_loop(3.878 s), pure while(7.053 s), pure_while_then_while_loop_plan(3.809 s)
so this plan looks like worth to try
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only review the test part, left some comments
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Outdated
Show resolved
Hide resolved
f3e0200
to
7699d56
Compare
@ManfeiBai give me a ping when CI is green |
617d3b2
to
5cf7c9c
Compare
fork PR failed, track in #7157 |
thanks for review and comments, since CI failed to start, track test result in #7157 |