
Conversation

@awav

@awav awav commented Nov 24, 2022

Hello,

In this pull request I would like to introduce the code from our paper accepted at NeurIPS 2022. This is joint work by Yuze An (@melody-an), Tilman Roeder (@dyedgreen), Mark van der Wilk (@markvdw), and myself (@awav).

We added optimization passes to the XLA optimization pipeline that automatically adapt the computational graph so that the computation fits into the memory of a single device. That is, code that would otherwise fail on a single device with an out-of-memory (OOM) error is automatically rewritten by the compiler into an equivalent form that fits into memory and runs successfully. More generally, many computational graphs can be transformed, without changing the result of the computation, into representations with speed and/or memory advantages. Our implementation does not modify the memory allocation manager and focuses solely on adjusting the computational graph. We call the set of introduced optimizations eXLA. eXLA consists of the following passes:

  • Match-and-replace (peephole) optimization (euclidean_distance_rewriter.[h|cc]). We replace Euclidean distance matrices of the form sum((A[..., :, None, :] - B[..., None, :, :])**2, axis=-1) with the more efficient and less memory-demanding counterpart sum(A ** 2, axis=-1)[..., :, None] + sum(B ** 2, axis=-1)[..., None, :] - 2 * A @ B.T (see the first sketch after this list).
  • Order optimization (hlo_mco.[h|cc]). In this pass, eXLA finds chains of matrix products and optimizes the order of evaluation using the classic dynamic-programming matrix-chain-ordering routine. For example, the expression (A @ B) @ v, where $A, B \in R^{n \times n}$ are matrices and $v \in R^{n}$ is a vector, is reordered to A @ (B @ v). Our implementation also covers transposes and reduce-sum operations (see the second sketch after this list).
  • Tensor splitting (tensor_splitter.[h|cc]). This pass searches for splitting paths that start with an operation generating a tensor too large to allocate and end with an operation reducing that large tensor to a smaller one. The outer product of two vectors is an example of an operation that generates a large tensor; a summation over all elements of a tensor is an example of a reducing operation. An important property of a splitting path is that every operation applied to the large tensors between the generating and reducing operations acts on at least one dimension linearly or independently. By a linear or independent operation, we mean an operation that can be applied to slices of a tensor along that dimension, such that the results on the slices can be combined to give the same result as applying the operation to the whole tensor (see the third sketch after this list).
  • We also implemented redundant code elimination (RCE, rce_optimizer.[h|cc]) and reshape sinker (reshape_sinker.[h|cc]) passes, which prepare the computational graph for the tensor splitting pass. In some cases, graphs are polluted with reshape, broadcast, transpose and reduce operations that can be safely removed from the graph, increasing the chance of splitting longer paths.
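
To illustrate the distance-matrix rewrite, here is a minimal NumPy sketch of our own (not code from this PR) checking that the rewritten expression matches the naive one while avoiding the large (n, m, d) intermediate:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))  # shape (n, d)
B = rng.standard_normal((7, 3))  # shape (m, d)

# Naive form: materializes an (n, m, d) intermediate tensor.
naive = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)

# Rewritten form: only (n, m) intermediates plus one matmul.
rewritten = (
    np.sum(A ** 2, axis=-1)[:, None]
    + np.sum(B ** 2, axis=-1)[None, :]
    - 2 * A @ B.T
)

assert np.allclose(naive, rewritten)
```

As an illustrative sketch of the order optimization (not the HLO implementation, which operates on HLO instructions rather than shape lists), the classic dynamic program below picks the cheapest parenthesization by scalar-multiplication count:

```python
def matrix_chain_cost(dims):
    """Classic O(k^3) matrix-chain-ordering DP.

    dims has length k + 1; the i-th matrix has shape (dims[i], dims[i + 1]).
    Returns the minimal number of scalar multiplications.
    """
    k = len(dims) - 1
    cost = [[0] * k for _ in range(k)]
    for length in range(2, k + 1):
        for i in range(k - length + 1):
            j = i + length - 1
            cost[i][j] = min(
                cost[i][s] + cost[s + 1][j] + dims[i] * dims[s + 1] * dims[j + 1]
                for s in range(i, j)
            )
    return cost[0][k - 1]

# (A @ B) @ v with A, B of shape (n, n) and v of shape (n, 1):
n = 1000
print(matrix_chain_cost([n, n, n, 1]))  # 2 * n**2, i.e. the A @ (B @ v) order
```

Finally, a toy NumPy sketch of the idea behind tensor splitting (again ours, not the pass itself): the generating operation (outer product) and the reducing operation (sum) can be applied slice by slice and the partial results combined:

```python
import numpy as np

a = np.random.randn(4096)
b = np.random.randn(4096)

# Full computation materializes a 4096 x 4096 tensor before reducing it.
full = np.sum(np.outer(a, b))

# Split computation: generate and reduce one slice of the first dimension
# at a time, then combine the partial results.
split = sum(np.sum(np.outer(a_chunk, b)) for a_chunk in np.array_split(a, 8))

assert np.allclose(full, split)
```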

Our extension lets a user control the splitting optimization via two options: xla_tensor_size_threshold and xla_tensor_split_size. The xla_tensor_size_threshold option controls when the splitting optimization searches for a splitting path: if a tensor is larger than the threshold, the splitting procedure is triggered. The xla_tensor_split_size option sets the size of the slices into which the splitting procedure chunks the tensors on the splitting path. By default, xla_tensor_split_size equals xla_tensor_size_threshold. An example:

XLA_FLAGS="--xla_tensor_size_threshold=1GB --xla_tensor_split_size=500MB" python train.py
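
As a usage sketch (assuming the eXLA passes from this PR are compiled into the TensorFlow/XLA binary; these flags do not exist in stock XLA), the same flags can also be set from Python, as long as this happens before TensorFlow initializes XLA:

```python
import os

# Must be set before TensorFlow/XLA is initialized.
os.environ["XLA_FLAGS"] = (
    "--xla_tensor_size_threshold=1GB --xla_tensor_split_size=500MB"
)

import tensorflow as tf  # noqa: E402
```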

We tested eXLA on sparse Gaussian process regression (SGPR), kNN, and Transformer models. In the case of SGPR, the determining factor for performance is the number of inducing points, i.e. the number of points in the input space used to describe the target function. We showed that, without any change to the SGPR code from the GPflow package, it is possible to scale the model to a much larger number of inducing points, making it perform best compared to its competitors.

[Figure: SGPR-plot-poster]
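
For reference, a minimal GPflow SGPR setup looks roughly like the sketch below (toy data shapes of our choosing, not the paper's experiment); the point is that the model code stays unchanged and only the XLA flags above are added:

```python
import numpy as np
import tensorflow as tf
import gpflow

# Toy data; in the paper, the number of inducing points M is what gets scaled up.
X = np.random.randn(10_000, 5)
Y = np.random.randn(10_000, 1)
Z = X[:2_000].copy()  # M = 2000 inducing points

model = gpflow.models.SGPR(
    data=(X, Y),
    kernel=gpflow.kernels.SquaredExponential(),
    inducing_variable=Z,
)

# XLA-compiled objective; the eXLA passes (if built in) rewrite the HLO graph.
loss_fn = tf.function(model.training_loss, jit_compile=True)
print(loss_fn())
```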

In the experiment with Transformer models, we used the configuration from the TensorFlow tutorial and varied the sequence length, and hence the size of the self-attention matrix. With the out-of-the-box standard TensorFlow implementation compiled with XLA, we were able to execute the model up to a sequence length of 2000; for longer sequences the TensorFlow code crashed with OOM. With eXLA, we ran the same code up to a sequence length of 7000.
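
As a rough back-of-the-envelope check (our own estimate of where the memory goes, not a measurement from the paper): a dense float32 self-attention matrix for sequence length $n$ takes about $4 n^2$ bytes per head per example, i.e. roughly 16 MB at $n = 2000$ and roughly 196 MB at $n = 7000$. Multiplied over heads, batch elements and softmax/gradient intermediates, the unsplit attention tensors quickly exceed single-device memory, which is exactly the generate-then-reduce pattern the tensor splitting pass targets.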

[Figure: Transformer-poster]

You can find more details in the paper.
I would appreciate it if someone could give feedback and an opinion on how these features could be incorporated into a future XLA release.

Thanks!

@google-ml-butler google-ml-butler bot added the size:XL CL Change Size:Extra Large label Nov 24, 2022
@google-ml-butler google-ml-butler bot requested a review from r4nt November 24, 2022 22:52
@google-ml-butler google-ml-butler bot added the awaiting review Pull request awaiting review label Nov 24, 2022
@gbaned gbaned added the comp:xla XLA label Nov 25, 2022
@gbaned gbaned requested a review from cheshire November 25, 2022 11:15
@cheshire
Contributor

Thanks a lot for the PR, looks very impressive! Please give us some time to figure this out!

@sherhut sherhut self-requested a review November 28, 2022 15:50
@gbaned gbaned requested review from cheshire and removed request for cheshire December 5, 2022 10:52
@gbaned
Contributor

gbaned commented Dec 29, 2022

@awav Can you please resolve conflicts? Thank you!

@gbaned gbaned added stat:awaiting response Status - Awaiting response from author and removed awaiting review Pull request awaiting review labels Dec 29, 2022
@gbaned
Contributor

gbaned commented Mar 21, 2023

Hi @awav Can you please resolve conflicts? Thank you!

@github-actions

github-actions bot commented Apr 5, 2023

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 5, 2023
@awav
Author

awav commented Apr 5, 2023

Hi, @gbaned! I will get back to the PR next week. I was going to split it into 3 parts to make it more review-friendly:

  1. Matrix chain optimisations
  2. Distance matrix optimisations
  3. Splitting optimisations

Wdyt? Or should I keep this PR as is (but synced with master)?

@google-ml-butler google-ml-butler bot removed stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author labels Apr 5, 2023
@awav
Author

awav commented Apr 20, 2023

@gbaned do you know anything about https://github.com/openxla/xla? Will PRs to tensorflow be automatically merged into OpenXLA?

@mihaimaruseac
Collaborator

Yes, for now.

@gbaned
Contributor

gbaned commented May 5, 2023

Hi @awav Can you please resolve conflicts? Thank you!

@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label May 5, 2023
@github-actions

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label May 20, 2023
@gbaned
Contributor

gbaned commented Jun 1, 2023

Hi @awav Any update on this PR? Please. Thank you!

@gbaned
Contributor

gbaned commented Jun 23, 2023

Hi @awav Any update on this PR? Please. Thank you!

@awav
Author

awav commented Jun 30, 2023

@gbaned, apologies for not being more active here. I'll try to find a workaround to catch up with the current updates, although this will be a bit challenging considering how many new things XLA has introduced. Any help would be appreciated.

@google-ml-butler google-ml-butler bot removed stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author labels Jun 30, 2023
@gbaned
Contributor

gbaned commented Aug 25, 2023

Hi @cheshire, can you please assist with the above comments from @awav? Thank you!

@gbaned
Contributor

gbaned commented Sep 25, 2023

Hi @cheshire Any update on this PR? Please. Thank you!

@awav
Author

awav commented Oct 25, 2023

Hi @cheshire! There will be another attempt to merge the current state of the branch with the main branch next week. Also, FlashAttention (https://arxiv.org/abs/2205.14135) does something similar to what this PR does. This PR implements a more naive version; however, it could be improved.

@gbaned
Contributor

gbaned commented Dec 14, 2023

Hi @cheshire Any update on this PR? Please. Thank you!

@gbaned gbaned added the awaiting review Pull request awaiting review label Dec 14, 2023
@gbaned
Contributor

gbaned commented Dec 29, 2023

Hi @cheshire Any update on this PR? Please. Thank you!

@gbaned
Contributor

gbaned commented Jan 19, 2024

Hi @cheshire Any update on this PR? Please. Thank you!

@gbaned
Contributor

gbaned commented Mar 8, 2024

Hi @awav Can you please resolve conflicts? Thank you!

@gbaned gbaned added stat:awaiting response Status - Awaiting response from author and removed awaiting review Pull request awaiting review labels Mar 8, 2024
@github-actions

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Mar 23, 2024
@github-actions

github-actions bot commented Apr 7, 2024

This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further.

@github-actions github-actions bot closed this Apr 7, 2024
@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Apr 7, 2024