
AD Enhancement: Avoid SSA-ing aggregate types by treating pointer-based element & field accesses as if they were differentiable method calls. #4197

Open
saipraveenb25 opened this issue May 20, 2024 · 0 comments
Labels: goal:quality & productivity (quality issues and issues that impact our productivity coding day to day inside slang), kind:enhancement (a desirable new feature, option, or behavior), kind:performance (things we wish were faster)

TLDR:
Our 'super-SSA' pass currently turns large data types (arrays, structs) into pure value types by converting all pointer accesses into GetElement or UpdateElement instructions. This is easy for the auto-diff pass to work with, but in the presence of non-trivial control flow it can lead to heavy duplication of large data structures.
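To illustrate, in the same IR pseudocode used below (instruction names are approximate), a single-element write like `a[i] = x` goes from a one-element pointer store to a whole-array value update:

```
// Before SSA: write through a pointer; only one element is touched.
%p = OpGetElementPtr(%var_a, %i) : %ptr(%float)
OpStore(%p, %x)

// After SSA: the array is a value, so every write conceptually
// produces a fresh copy of the entire array.
%a_next = OpUpdateElement(%a_prev, %i, %x) : %array(%float, 10)
```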

We can avoid this situation by instead treating a get or set of a[i] as if it were a call to a load or store function that takes an IRVar and an index i (or a field id, in the case of a struct type). This lets us avoid creating copies of the array data.

We would likely need a pass to convert chains of IRGetElementPtr and IRLoad/IRStore into something like IRLoadElement/IRStoreElement. These instructions will need to accept an arbitrary number of index operands, since arbitrarily nested types require multiple lookups to reach a primitive value.
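As a rough sketch of what such a pass might emit (the exact instruction names and signatures are hypothetical), a nested write like s.arr[i][j] = x would collapse from a pointer chain into a single variadic instruction:

```
// Today: a chain of pointer ops ending in a store.
%p0 = OpGetElementPtr(%var_s, %field_arr) : %ptr(%array(%array(%float, 10), 10))
%p1 = OpGetElementPtr(%p0, %i) : %ptr(%array(%float, 10))
%p2 = OpGetElementPtr(%p1, %j) : %ptr(%float)
OpStore(%p2, %x)

// Proposed: one instruction carrying the whole access chain as operands.
OpStoreElement(%var_s, %field_arr, %i, %j, %x)
%v = OpLoadElement(%var_s, %field_arr, %i, %j) : %float
```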

Long form:

Here is an example of a not-uncommon snippet that triggers this problem:

[Differentiable]
float[10] g(float[10] a)
{
    for (int i = 0; i < 10; i++)
    {
        a[i] = f(a[i]);
    }

    return a;
}

This results in (roughly) the following IR for the primal context function and the backward pass function.
The key issue is that the context now contains an array of arrays to hold the state of the array after each iteration, even though only one element is really being mutated at a time. This is because our SSA pass converts the entire array into a loop state variable (loop phi value), and our checkpointing mechanism needs to store all loop phi values into the context var.

//// Primal context pass. ////

// Entry block
%t = OpBlock
{
    // Context storage for all loop phi variables (n_iters + 1)
    %ctx_a = IRVar : %array(%array(%float, 10), 11) // Catastrophically large amount of storage.
    %ctx_i = IRVar : %array(%int, 11)

    OpLoop %c %br %c 0
}

// Condition
%c = OpBlock
{
    %i = OpParam : %int
    %a = OpParam : %array(%float, 10)

    // Context store operations.
    %ctx_i_ptr = OpGetElementPtr(%ctx_i, %i) : %ptr(%int)
    OpStore(%ctx_i_ptr, %i)
    %ctx_a_ptr = OpGetElementPtr(%ctx_a, %i) : %ptr(%array(%float, 10))
    OpStore(%ctx_a_ptr, %a)
    
    %2 = OpLesser(%i, 10) : %bool

    OpIfElse(%2, %b, %br, %br)
}

// Loop body.
%b = OpBlock 
{ /*...*/ }

// Break block
%br = OpBlock
{ /*...*/ }

//// Backprop pass /////

// Entry block
%t_rev = OpBlock
{
    // Variable to hold the derivative of %a
    %var_da_rev = OpVar : %ptr(%array(%float, 10))

    // Count down from the end
    OpLoop %c_rev %br_rev %c_rev 9
}

// Condition
%c_rev = OpBlock
{
    // rev-mode loop counter (runs backwards from limit to 0)
    %dc = OpParam : %int
    
    // Continue while the reversed counter has not run past the start.
    %2 = OpGreaterEqual(%dc, 0) : %bool

    OpIfElse(%2, %b_rev, %br_rev, %br_rev)
}

// Loop body.
%b_rev = OpBlock 
{
    // Context load operations.
    %ctx_i_ptr = OpGetElementPtr(%ctx_i, %dc) : %ptr(%int)
    %i_saved = OpLoad(%ctx_i_ptr) : %int

    %ctx_a_ptr = OpGetElementPtr(%ctx_a, %dc) : %ptr(%array(%float, 10))
    %a_saved = OpLoad(%ctx_a_ptr) : %array(%float, 10)

    %a_i = OpGetElement(%a_saved, %i_saved) : %float
    %a_pair_i = OpMakeDifferentialPair(%a_i, 0) : %diff_pair(%float)

    %da_rev_ptr = OpGetElementPtr(%var_da_rev, %i_saved) : %ptr(%float)
    %df_output = OpLoad(%da_rev_ptr) : %float

    // Call rev-mode of f to propagate derivative of output of f to input of f. (Assume f has no context requirement)
    %var_a_pair_i = OpVar : %ptr(%diff_pair(%float))
    OpStore(%var_a_pair_i, %a_pair_i)
    OpCall(f_rev, %var_a_pair_i, %df_output) : %void

    // Load derivative for a_i
    %a_pair_i_loaded = OpLoad(%var_a_pair_i) : %diff_pair(%float)
    %da_rev_i = OpDifferentialPairGetDifferential(%a_pair_i_loaded) : %float

    // Create derivative array for backpropagation (this happens during gradient materialization)
    %da_rev_local_var = OpVar : %ptr(%array(%float, 10))
    %da_rev_init_zero = OpMakeArray(0, 0, 0, 0, 0, 0, 0, 0, 0, 0) : %array(%float, 10)
    OpStore(%da_rev_local_var, %da_rev_init_zero)

    %da_rev_var_i = OpGetElementPtr(%da_rev_local_var, %i_saved) : %ptr(%float)
    %curr_dval = OpLoad(%da_rev_var_i) : %float
    %acc_dval = OpAdd(%curr_dval, %da_rev_i) : %float
    OpStore(%da_rev_var_i, %acc_dval)

    // Add derivative array to the global var.
    %curr_dval_a = OpLoad(%var_da_rev) : %array(%float, 10)
    %new_dval_a = OpLoad(%da_rev_local_var) : %array(%float, 10)
    %acc_dval_a = OpCall('array_dadd', %curr_dval_a, %new_dval_a) : %array(%float, 10)
    OpStore(%var_da_rev, %acc_dval_a)

    %dc_next = OpAdd(%dc, -1)

    OpUnconditionalBranch(%c_rev, %dc_next)
}

// Break block
%br_rev = OpBlock
{ /*...*/ }

Instead, if we were to write the same function in the following manner, utilizing an indexed load/store:

groupshared float shared[10];
groupshared float d_shared[10];

float load(int i)
{
    return shared[i];
}

[BackwardDerivativeOf(load)]
void rev_d_load(int i, float value)
{
    d_shared[i] += value; // Accumulate, in case the same element is loaded more than once.
}

void store(int i, float value)
{
    shared[i] = value;
}

[BackwardDerivativeOf(store)]
void rev_d_store(int i, inout DifferentialPair<float> dpv)
{
    dpv = DifferentialPair<float>(dpv.getPrimal(), d_shared[i]);
}

[Differentiable]
void g()
{
    for (int i = 0; i < 10; i++)
    {
        store(i, f(load(i)));
    }
}

This version will not use an array of arrays to store anything, dramatically reducing the memory complexity of the resulting function, since the reverse-mode gradient is simply deposited directly into the appropriate element. Further, since only the results of the load operations require caching, this avoids caching the entire array state at each step.
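To make the savings concrete, here is a small back-of-the-envelope model (plain Python, function names are hypothetical) comparing the checkpoint storage of the two strategies for an array of N floats over K loop iterations: checkpointing every loop phi value stores the whole array per iteration, while the indexed load/store approach only needs to cache the one loaded scalar per iteration.

```python
def phi_checkpoint_floats(n: int, iters: int) -> int:
    # SSA'd array becomes a loop phi value: the context stores the
    # full array state for every iteration, plus the final state.
    return n * (iters + 1)

def indexed_checkpoint_floats(n: int, iters: int) -> int:
    # Indexed load/store: only the scalar read by load(i) is cached
    # per iteration; gradients go straight into d_shared[i].
    return iters

# For the example above (N = 10, 10 iterations):
print(phi_checkpoint_floats(10, 10))      # 110 floats
print(indexed_checkpoint_floats(10, 10))  # 10 floats
```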

saipraveenb25 added this to the Q3 2024 (Summer) milestone and self-assigned this issue on May 20, 2024.