AD Enhancement: Avoid SSA-ing aggregate types by treating pointer-based element & field accesses as if they were differentiable method calls. #4197
Labels
goal:quality & productivity
Quality issues and issues that impact our productivity coding day to day inside slang
kind:enhancement
a desirable new feature, option, or behavior
kind:performance
things we wish were faster
TLDR:

Our 'super-SSA' pass currently turns large data types (arrays, structs) into pure value types by turning all pointer accesses into `GetElement` or `UpdateElement`. This is easy for the auto-diff pass to work with, but in the presence of non-trivial control flow it can lead to heavy duplication of large data structures.

We can avoid this situation by instead treating the setting or getting of `a[i]` as if it were a function call to `load` or `store` with an `IRVar` and an index `i` (or a field id in the case of a struct type). This lets us avoid having to create copies of the array data.

We would likely need a pass to convert chains of `IRGetElementPtr` and `IRLoad`/`IRStore` to something like `IRLoadElement`/`IRStoreElement`. These instructions will need to accept an arbitrary number of operands, as we can have arbitrarily nested types that need multiple lookups to get to a primitive value.

Long form:
Here is an example of a not-uncommon snippet that triggers this problem
This results in (roughly) the following IR for the primal context function and the backward pass function.
The key issue is that the context now contains an array of arrays to hold the state of the array after each iteration, even though only one element is really being mutated at a time. This is because our SSA pass converts the entire array into a loop state variable (loop phi value), and our checkpointing mechanism needs to store all loop phi values into the context var.
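To make the memory cost concrete, here is a minimal Python sketch of the behaviour described above. It is a conceptual model only, not Slang IR or the actual checkpointing code: `update_element`, `primal_value_semantics`, and `context_size` are hypothetical names, and the per-element squaring loop is an assumed stand-in for the kind of function that triggers the problem.

```python
def update_element(arr, idx, value):
    """Value-semantics update: returns a new array, like an UpdateElement result."""
    new_arr = list(arr)
    new_arr[idx] = value
    return new_arr


def primal_value_semantics(a):
    """Models the SSA form, where the whole array is a loop-carried (phi) value.

    Every loop phi value is checkpointed, so the context receives a full
    snapshot of the array on each iteration.
    """
    context = []
    for i in range(len(a)):
        context.append(list(a))                 # checkpoint the entire array
        a = update_element(a, i, a[i] * a[i])   # only one element changes
    return a, context


def context_size(context):
    """Number of scalars held in the context."""
    return sum(len(snapshot) for snapshot in context)


out, ctx = primal_value_semantics([1.0, 2.0, 3.0])
print(out, context_size(ctx))
```

For an array of N elements updated over N iterations, this model stores N × N scalars in the context, even though each iteration touches a single element.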
Instead, suppose we were to write the same function above in this manner, utilizing an indexed load/store.
This is a function that will not use an array of arrays to store anything, dramatically reducing the memory complexity of the resulting function, since the reverse-mode pass will simply deposit the gradient directly into the appropriate element. Further, since only the results of the load operations require caching, this avoids caching the entire array state at each step.
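The following Python sketch illustrates the indexed load/store idea under stated assumptions. It is a hand-written model rather than compiler output: `primal_indexed` and `backward_indexed` are hypothetical names, the squaring loop is an assumed example, and the backward pass is written out manually as the transpose of each `load` and `store`.

```python
def primal_indexed(a):
    """Treats a[i] reads and writes as load/store calls on one mutable buffer.

    Only the scalar returned by each load is cached for the backward pass.
    """
    a = list(a)              # a single buffer, standing in for the IRVar
    loads = []
    for i in range(len(a)):
        x = a[i]             # load(a, i)
        loads.append(x)      # cache just the loaded scalar
        a[i] = x * x         # store(a, i, x * x)
    return a, loads


def backward_indexed(loads, d_out):
    """Reverse pass: gradients are deposited directly into individual elements."""
    d_a = list(d_out)
    for i in reversed(range(len(loads))):
        x = loads[i]
        g = d_a[i]           # transpose of store: take the gradient of the written slot
        d_a[i] = 0.0         # the stored-over value no longer contributes
        d_a[i] += 2.0 * x * g  # transpose of load: accumulate d(x * x)/dx * g
    return d_a


out, loads = primal_indexed([1.0, 2.0, 3.0])
grad = backward_indexed(loads, [1.0, 1.0, 1.0])
print(out, loads, grad)
```

In this model the cached state is one scalar per load (N scalars for N iterations) instead of one array snapshot per iteration.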