Memory Safe Computations with XLA #58679
Conversation
Thanks a lot for the PR, looks very impressive! Please give us some time to figure this out!
@awav Can you please resolve conflicts? Thank you!
Hi @awav Can you please resolve conflicts? Thank you!
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Hi, @gbaned! I will get back to the PR next week. I was going to split it into 3 parts to make it more review-friendly. Wdyt? Or should I keep this PR as is (but synced with master)?
@gbaned do you know anything about https://github.com/openxla/xla? Will PRs to tensorflow be automatically merged into OpenXLA?
Yes, for now.
Hi @awav Can you please resolve conflicts? Thank you!
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Hi @awav Any update on this PR? Please. Thank you!
Hi @awav Any update on this PR? Please. Thank you!
@gbaned, apologies for not being more active here. I'll try to find a workaround to catch up with the current updates, though this will be a bit challenging considering how many new things XLA has introduced. Any help would be appreciated.
Hi @cheshire Any update on this PR? Please. Thank you!
Hi @cheshire! There will be another attempt to merge the current state of the branch with the main branch next week. Also, flash attention (https://arxiv.org/abs/2205.14135) does something similar to what this PR does. This PR implements a more naive version; however, it could be improved.
Hi @cheshire Any update on this PR? Please. Thank you!
Hi @cheshire Any update on this PR? Please. Thank you!
Hi @cheshire Any update on this PR? Please. Thank you!
Hi @awav Can you please resolve conflicts? Thank you!
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further.
Hello,
In this pull request I would like to introduce the code for the paper that was accepted at NeurIPS 2022. This is joint work of Yuze An (@melody-an), Tilman Roeder (@dyedgreen), Mark van der Wilk (@markvdw) and myself (@awav).
We added optimization passes to the XLA optimization pipeline that automatically adapt the computational graph so that the computation fits into the memory of a single device. That is, code that fails on a single device with an out-of-memory (OOM) error is automatically modified by the compiler into an analogue that fits into memory and runs successfully. In general, many computational graphs can be transformed, without changing the computed result, into another representation with speed and/or memory advantages. Our implementation does not change the memory allocation manager and focuses solely on adjusting the computational graph. We call the set of introduced optimizations eXLA. eXLA consists of the following passes:
- Euclidean distance rewriter (`euclidean_distance_rewriter.[h|cc]`). We replace Euclidean distance matrices of the form `sum((A[..., :, None, :] - B[..., None, :, :])**2, axis=-1)` with the more efficient and less memory-demanding counterpart `sum(A ** 2, axis=-1)[..., :, None] + sum(B ** 2, axis=-1)[..., None, :] - 2 * A @ B.T` (see the sketch after this list).
- Matrix chain optimizer (`hlo_mco.[h|cc]`). In this optimization pass, eXLA finds a chain of matrix products and optimizes the order of evaluation using the classic dynamic programming routine. E.g. the expression `(A @ B) @ v`, where `A` and `B` are matrices and `v` is a vector, is rewritten into the `A @ (B @ v)` expression. Our implementation also covers transposes and reducing sum operations.
- Tensor splitter (`tensor_splitter.[h|cc]`). This optimization pass searches for splitting paths that start with operations generating tensors that are too big for allocation and end with operations reducing those big tensors to smaller ones. The outer product of two vectors is an example of an operation that generates a big tensor; a summation over all elements of a tensor is an example of a reducing operation. An important property of a splitting path is that all operations applied to big tensors between the generating and reducing operations act on at least one dimension linearly or independently. By a linear and independent operation we mean an operation that can be applied to slices of a tensor (slices along that dimension), after which the slices can be combined to give the same result as if the operation had been applied to the whole tensor.
- RCE optimizer (`rce_optimizer.[h|cc]`) and reshape sinker (`reshape_sinker.[h|cc]`). These passes prepare the computational graph for the tensor splitting pass. In some cases, graphs are polluted with reshape, broadcast, transpose and reduce operations that can be safely removed from the graph, increasing the chance of splitting longer paths.
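To make the first two rewrites concrete, here is a small NumPy sketch (illustrative only, not code from this PR; the shapes and variable names are made up) that checks the Euclidean-distance identity and compares the two evaluation orders of a matrix chain:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 32))   # n x d
B = rng.standard_normal((800, 32))    # m x d

# Naive pairwise squared distances: materialises an n x m x d intermediate.
naive = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)

# Rewritten form: only n x m intermediates, same result up to round-off.
rewritten = (
    (A ** 2).sum(axis=-1)[:, None]
    + (B ** 2).sum(axis=-1)[None, :]
    - 2 * A @ B.T
)
assert np.allclose(naive, rewritten)

# Matrix chain ordering: (M @ N) @ v materialises a 1000 x 3000 matrix,
# while M @ (N @ v) only ever produces intermediate vectors.
M = rng.standard_normal((1000, 2000))
N = rng.standard_normal((2000, 3000))
v = rng.standard_normal(3000)
assert np.allclose((M @ N) @ v, M @ (N @ v))
```

The point in both cases is that the rewritten expression computes the same values while never materialising the largest intermediate tensor.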
Our extension allows a user to control the splitting optimization via two options: `xla_tensor_size_threshold` and `xla_tensor_split_size`. The `xla_tensor_size_threshold` option controls when the splitting optimization searches for a splitting path, i.e. if a tensor is larger than the threshold, the splitting procedure is triggered. The `xla_tensor_split_size` option determines the size of the slices into which the splitting procedure chunks tensors on the splitting path. By default, `xla_tensor_split_size` equals `xla_tensor_size_threshold`. An example: `XLA_FLAGS="--xla_tensor_size_threshold=1GB --xla_tensor_split_size=500MB" python train.py`
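As a usage sketch (purely illustrative, not code from this PR; the function, shapes and sizes are made up), one way to exercise these flags from Python is to set `XLA_FLAGS` before importing TensorFlow and compile a function whose intermediate tensor exceeds the threshold. With stock XLA this program would OOM on most single devices; with the eXLA tensor splitter the path should be rewritten to run in slices:

```python
import os

# Flags introduced by this PR; they only take effect when the eXLA passes are built in.
# XLA_FLAGS must be set before TensorFlow initialises XLA.
os.environ["XLA_FLAGS"] = (
    "--xla_tensor_size_threshold=1GB --xla_tensor_split_size=500MB"
)

import tensorflow as tf


@tf.function(jit_compile=True)
def outer_then_reduce(u, v):
    # u[:, None] * v[None, :] is a "generating" operation: it produces an n x m
    # tensor far above the threshold. tf.exp acts element-wise (linearly/independently
    # along every dimension), and tf.reduce_sum is the "reducing" operation, so the
    # whole path is a candidate for the tensor splitter to evaluate in slices.
    big = tf.exp(u[:, None] * v[None, :])
    return tf.reduce_sum(big)


u = tf.random.normal([50_000])   # the 50k x 50k float32 intermediate is ~10 GB
v = tf.random.normal([50_000])
print(outer_then_reduce(u, v))
```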
We tested eXLA on sparse Gaussian process (SGPR), kNN and Transformer models. In the case of SGPR, the determining factor for performance is the number of inducing points, i.e. the number of points in the input space used to describe the target function. We showed that, without any change to the SGPR code from the GPflow package, it is possible to scale the model to a much larger number of inducing points, making the model perform best compared to the competitors.
In the experiment with Transformer models, we used the configuration from the TensorFlow tutorial and varied the sequence length, and hence the size of the self-attention matrix. With the out-of-the-box standard TensorFlow implementation compiled with XLA, we were able to execute the model up to a sequence length of 2000; for longer sequences the code crashed with OOM. With eXLA, we ran the same code up to a sequence length of 7000.
You can find more details in the paper.
I would appreciate it if someone could give feedback and an opinion on how these features could become part of a future XLA release.
Thanks!