Changes to add user scratch pad for matmul primitive to fix OOM issue in Transformer LT #54381
Conversation
Thank you for the PR!
This reduces memory footprint of the primitive.
Could you please help clarify how providing a user-allocated scratchpad helps save memory in this case? The user-allocated memory doesn't seem to be reused (it is only allocated once per `Compute` call). Does the primitive over-allocate memory when not given a user-allocated scratchpad? The primitive is not cached, i.e., a new one is created each time the op is called, so shouldn't it be able to allocate just the amount it needs?
@penpornk We are working to enable the oneDNN scratchpad "user mode" for conv, inner-product, and matmul ops. To answer the last question: the scratchpad buffer is allocated in `Compute()`, just before invoking oneDNN execution. Thanks!
@gzmkl Thank you for the quick reply! I understand that the user scratchpad is allocated in `Compute()`.
@penpornk With `UserScratchPad`, the TF framework controls the scratchpad by creating and releasing the scratchpad memory as needed, i.e. (as of now) it creates the scratchpad before execution of the primitive and releases it after execution. This keeps memory consumption low for the entire model.
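The lifetime pattern described above can be sketched as follows. This is a minimal illustration, not the actual TF/oneDNN code: `MatMulPrimitive` and `ExecuteWithUserScratchpad` are hypothetical stand-ins, and the scratchpad is modeled as a plain byte vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a cached primitive that reports how much
// scratchpad memory it needs per execution.
struct MatMulPrimitive {
  std::size_t scratchpad_bytes;
};

// User-mode scratchpad pattern: the framework allocates the buffer
// immediately before executing the primitive and releases it right
// after, so no scratchpad memory is held between op invocations.
std::size_t ExecuteWithUserScratchpad(const MatMulPrimitive& prim) {
  std::vector<char> scratchpad(prim.scratchpad_bytes);  // create before run
  // ... execute the primitive, handing it scratchpad.data() ...
  return scratchpad.size();
}  // scratchpad released here, keeping model-wide memory use low
```

The key design point is that the buffer's lifetime is a single `Compute` call, rather than the (potentially long) lifetime of a cached primitive.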
@jojivk73 Thank you for the detailed reply!
I misread the `do_not_cache` parameter. I thought that `false` means the primitive is not cached (and assumed that the primitive is destroyed along with its auto-allocated scratchpad at the end of the kernel call, hence having a similar lifetime as a user-allocated scratchpad). I understand now that the primitive here is cached and will therefore hold on to the scratchpad memory for as long as it stays in the LRU.
tensorflow/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
Lines 151 to 154 in 598681c
// Create or retrieve matmul primitive from cache.
MklMatMulPrimitive<Tlhs, Trhs, Toutput>* matmul_prim =
    MklMatMulPrimitiveFactory<float, Tlhs, Trhs, Toutput>::Get(
        *params, false /* value for do_not_cache */);
Rolling back PR #54381 because it broke the Windows continuous build. PiperOrigin-RevId: 434790571
Adding changes for the matmul primitive to use a user scratch pad. This reduces the memory footprint of the primitive and fixes an out-of-memory issue when running Transformer LT with multiple instances and a large total thread count. Managing the scratch pad for the primitive from the framework fixes the out-of-memory issue and reduces the memory footprint without affecting performance. The changes:
Creates a new struct that holds the Tensor for the scratch pad arg.
Allocates memory based on the scratch pad size queried from the primitive description.
Sets the user scratch pad in post ops for the primitive.
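The three steps above can be sketched as follows. This is a hedged illustration, not the PR's actual code: `Tensor`, `MatMulPrimDesc`, `kScratchpadArg`, and `PrepareScratchpad` are hypothetical stand-ins (in real oneDNN the size comes from the primitive descriptor's scratchpad query, and the buffer is passed as a dedicated execution argument).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Stand-in for tensorflow::Tensor: just a framework-owned byte buffer.
using Tensor = std::vector<char>;

// Step 1: a struct that holds the scratch pad Tensor, so the framework
// owns the buffer's lifetime instead of the cached primitive.
struct UserScratchPad {
  Tensor tensor;
};

// Stand-in for the primitive description, which reports the required
// scratchpad size.
struct MatMulPrimDesc {
  std::size_t scratchpad_bytes;
};

// Placeholder argument index; not the real oneDNN scratchpad arg value.
constexpr int kScratchpadArg = -1;

void PrepareScratchpad(const MatMulPrimDesc& pd, UserScratchPad& sp,
                       std::map<int, void*>& exec_args) {
  // Step 2: allocate based on the size queried from the primitive desc.
  sp.tensor.resize(pd.scratchpad_bytes);
  // Step 3: hand the user scratch pad to the primitive's execution args.
  exec_args[kScratchpadArg] = sp.tensor.data();
}
```

Because the framework allocates the `UserScratchPad` per `Compute` call and lets it go out of scope afterwards, a primitive sitting in the cache no longer pins scratchpad memory between calls.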