Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case #7094

Closed
wants to merge 383 commits into from

Conversation

ManfeiBai
Copy link
Collaborator

@ManfeiBai ManfeiBai commented May 22, 2024

  • Modify BuildForiLoop for while_loop/fori_loop
  • Modify test code for while_loop/fori_loop with simple version
  • Add linear_layer_model and MNIST_model(without BN layer) test case

l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device)
l_out = torch.randn(bs, 10, dtype=torch.float32, device=device)
iteri = torch.tensor(3, dtype=torch.int64, device=device)
_, _, res = mnist(iteri, l_in_0, l_out)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we init the l_out inside the mnist forward and returns it? This way the only additional output we provide is the count.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, good idea, due to hard to gain result before finish one run, so current plan would be:
run body_fn with pure python while, get result, then use this result as l_out

before implementation, did quick benchmark for this idea with code:

- iteri = 3: 
  while_loop(3.675 s), pure while(3.107 s), pure_while_then_while_loop_plan(3.724 s)
- iteri = 10: 
  while_loop(3.670s), pure while(3.119s), pure_while_then_while_loop_plan(3.932s)
- iteri = 10: 
  while_loop(3.744 s), pure while(3.110 s), pure_while_then_while_loop_plan(3.749 s)
- iteri = 10: 
  while_loop(3.878 s), pure while(7.053 s), pure_while_then_while_loop_plan(3.809 s)

so this plan looks like worth to try

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only review the test part, left some comments

@ManfeiBai ManfeiBai changed the title Enable while_loop/fori_loop, Add linear/MNIST and _get_xla_computation Enable while_loop/fori_loop, Add linear/MNIST test case May 30, 2024
@ManfeiBai ManfeiBai changed the title Enable while_loop/fori_loop, Add linear/MNIST test case [Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case May 30, 2024
@ManfeiBai ManfeiBai requested a review from JackCaoG May 30, 2024 18:20
@JackCaoG
Copy link
Collaborator

@ManfeiBai give me a ping when CI is green

@ManfeiBai ManfeiBai marked this pull request as ready for review May 30, 2024 18:37
@ManfeiBai
Copy link
Collaborator Author

fork PR failed, track in #7157

@ManfeiBai
Copy link
Collaborator Author

thanks for review and comments, since CI failed to start, track test result in #7157

@ManfeiBai ManfeiBai closed this May 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants