Goal
Currently we support branching at the symbolic level by evaluating each path and taking a weighted sum of the branches given a conditional. This is fast for small branches whose work is mostly shared through CSE, but cases with significant non-shared work will be slower than code that has a true branch.
The goal is to resolve this with a hard branching option in codegen that doesn't evaluate every branch. A small motivating example here was Rot3.from_rotation_matrix.
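The following minimal SymPy sketch makes the cost of the weighted-sum idea concrete. It builds the indicator from `Max(sign(c), 0)`, a stand-in chosen for this example that is not necessarily what SymForce uses internally. Both branches stay in the final expression, so generated code evaluates the expensive one even when the cheap one is selected.

```python
import sympy as sm

x, c = sm.symbols("x c")

# 1 when c > 0, 0 otherwise (including c == 0). Assumed stand-in for a
# symbolic "is positive" indicator.
is_positive = sm.Max(sm.sign(c), 0)

cheap = x**2
expensive = sm.sin(x) * sm.exp(x) / (1 + x**2)  # stands in for a large, non-shared subtree

# Weighted sum of both branches: nothing here is a real branch, so code printed
# from `result` computes both `cheap` and `expensive` every time.
result = is_positive * cheap + (1 - is_positive) * expensive
print(result)
```

Substituting a numeric `c` collapses the sum to a single branch symbolically. At codegen time, however, `c` is a runtime input, so nothing collapses.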
Piecewise functions
SymPy has Piecewise functions, which print to inline ternaries and are quite relevant here. Side note: these were very slow for me at the symbolic level when I worked with splines in the past, so I implemented a simpler form in our codebase with some stricter assumptions, which performed much better. That said, Piecewise generally works. It prints as nested ternary expressions. This should basically work within SymForce, but it doesn't solve this problem because non-shared CSE terms can't be placed inside the ternaries as separate statements.
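SymPy's `ccode` shows the printing behavior described above. The sketch below prints a Piecewise with the same branches as the `Func` example in the CSE section. Without CSE, each branch is printed as a nested ternary expression and `sin(x)` is simply repeated. There is no place to put a separate statement.

```python
import sympy as sm

x = sm.Symbol("x")
expr = sm.Piecewise(
    (0, x < -1),
    (x**2, x <= 1),
    (x**2 / sm.sin(x) + sm.sin(x), True),  # the last condition must be True for ccode
)

# With no assign_to, ccode prints the Piecewise inline as nested ?: expressions.
code = sm.ccode(expr)
print(code)
```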

Basic Expression Implementation
It's not hard to make a simple custom expression in SymPy (or symengine). I think derivatives, CSE, and numeric substitution work nominally with the implementation below. The main challenge is on the printing side: how to handle CSE. We could also decide to augment Piecewise instead.
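import sympy as sm

class Branch(sm.Function):
    """Two-way branch: evaluates to x if conditional > 0, otherwise to y."""
    nargs = 3

    @classmethod
    def eval(cls, conditional, x, y):
        # Simplify if conditional is not symbolic
        if conditional.is_number:
            if conditional.is_positive:
                return x
            else:
                return y
        # Otherwise fall through and return None, so SymPy keeps Branch unevaluated

    def _eval_derivative(self, s):
        # Treat the conditional as piecewise constant and differentiate each operand
        conditional = self.args[0]
        operands = self.args[1:]
        return self.__class__(conditional, *(v.diff(s) for v in operands))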
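For comparison with augmenting Piecewise: SymPy's built-in Piecewise already handles two of the operations listed above. It differentiates branch by branch, and numeric substitution resolves the conditions to a single branch. Here is a short check:

```python
import sympy as sm

x = sm.Symbol("x")
pw = sm.Piecewise((x**2, x < 0), (sm.sin(x), True))

# Differentiation is applied to each branch; the conditions are left unchanged.
dpw = sm.diff(pw, x)

# Substituting a number resolves the conditions and selects a single branch.
at_neg = pw.subs(x, -2)
at_pos = pw.subs(x, 1)
print(dpw, at_neg, at_pos)
```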

CSE
CSE will pull out intermediate terms across all branches. Say we then write a function that, given the set of temporaries and the symbolic expression, identifies which intermediates are needed in which branches. We could try to run CSE on the inner scopes first, but I think it might be easier to use the current CSE, post-process the result, and pass that to printing somehow.
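Here is a minimal sketch of that post-processing step. It runs the existing `sm.cse` on a Piecewise and then asks which temporaries each branch needs. `temporaries_by_branch` is a hypothetical helper, and the expression mirrors the `Func` example below. Dependencies are followed transitively, since one temporary can be defined in terms of another.

```python
import sympy as sm


def temporaries_by_branch(replacements, piecewise):
    """Hypothetical helper: for each branch of a CSE'd Piecewise, find the temporaries it needs.

    replacements: (symbol, expr) pairs from sm.cse.
    piecewise: the reduced Piecewise returned by sm.cse.
    Returns one set of temporary symbols per branch, including indirect dependencies.
    Conditions are ignored here; anything they need would have to stay at the top level.
    """
    temp_exprs = dict(replacements)
    needed = []
    for pair in piecewise.args:  # ExprCondPair objects with .expr and .cond
        stack = [s for s in pair.expr.free_symbols if s in temp_exprs]
        seen = set()
        while stack:
            sym = stack.pop()
            if sym not in seen:
                seen.add(sym)
                # A temporary may itself refer to earlier temporaries.
                stack.extend(s for s in temp_exprs[sym].free_symbols if s in temp_exprs)
        needed.append(seen)
    return needed


x = sm.Symbol("x")
expr = sm.Piecewise(
    (0, x < -1),
    (x**2, x <= 1),
    (x**2 / sm.sin(x) + sm.sin(x), True),
)
replacements, (reduced,) = sm.cse(expr, symbols=sm.numbered_symbols("tmp"))
needed = temporaries_by_branch(replacements, reduced)

# A temporary used by exactly one branch could be declared inside that branch;
# one used by several branches (like x**2 here) is the harder case.
for sym, rhs in replacements:
    users = [i for i, syms in enumerate(needed) if sym in syms]
    print(sym, "=", rhs, "used by branches", users)
```

With these inputs, only the last branch uses the `sin(x)` temporary, while the last two branches use the `x**2` temporary. That second case is the tmp0 situation called out below.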
During printing, SymPy traverses the expression tree as scalars and assumes it is printing expressions, which makes it hard to put statements inline. For example, how do we compute tmp1 only when it is needed in the example below? (Ignore tmp0 for now. It is complex in its own right, because it's needed in 2 of the 3 branches.)
double Func(double x) {
    double tmp0 = pow(x, 2);
    double tmp1 = sin(x);
    return ((x < -1) ? (
        0
    )
    : ((x <= 1) ? (
        tmp0
    )
    : (
        tmp0/tmp1 + tmp1
    )));
}
Handwritten, we'd hope for something like:
double Func2(double x) {
    double tmp0 = pow(x, 2);
    if (x < -1) {
        return 0;
    } else if (x <= 1) {
        return tmp0;
    } else {
        double tmp1 = sin(x);
        return tmp0/tmp1 + tmp1;
    }
}
One idea is to use lambdas and let the compiler optimize them out:
double Func3(double x) {
    double tmp0 = pow(x, 2);
    return ((x < -1) ? (
        0
    )
    : ((x <= 1) ? (
        tmp0
    )
    : [&](){
        double tmp1 = sin(x);
        return tmp0/tmp1 + tmp1;
    }()));
}
If we're doing that, we could also wrap the whole thing in a lambda and take control of printing that piecewise scalar ourselves:
double Func4(double x) {
    double tmp0 = pow(x, 2);
    return [&]() -> double {
        if (x < -1) {
            return 0;
        } else if (x <= 1) {
            return tmp0;
        } else {
            double tmp1 = sin(x);
            return tmp0/tmp1 + tmp1;
        }
    }();
}
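As a sketch of taking control of the printing ourselves, here is a hypothetical Python printer that emits the `Func4` shape from a CSE'd Piecewise. It assumes the decision about which temporaries live inside which branch has already been made (the post-processing problem discussed in the CSE section). That decision is passed in as `local_temps`, and everything else is hoisted above the lambda. Expressions are printed with SymPy's `ccode`. The function name and argument layout are illustrative only.

```python
import sympy as sm


def print_piecewise_lambda(replacements, reduced, local_temps):
    """Hypothetical printer: emit a C++ body for a CSE'd scalar Piecewise.

    replacements: (symbol, expr) pairs from sm.cse, in dependency order.
    reduced: the Piecewise returned by sm.cse; assumes at least two pieces and
        that the last condition is True (it is printed as a plain else).
    local_temps: maps a branch index to the temporaries declared inside that branch.
        Every other temporary is hoisted above the lambda, so a hoisted temporary
        must not depend on a local one.
    """
    local = {sym for syms in local_temps.values() for sym in syms}
    lines = [
        f"const double {sym} = {sm.ccode(rhs)};"
        for sym, rhs in replacements
        if sym not in local
    ]
    lines.append("return [&]() -> double {")
    pairs = reduced.args  # ExprCondPair objects with .expr and .cond
    for i, pair in enumerate(pairs):
        if i == len(pairs) - 1:
            lines.append("    } else {")
        else:
            keyword = "if" if i == 0 else "} else if"
            lines.append(f"    {keyword} ({sm.ccode(pair.cond)}) {{")
        # Declare this branch's local temporaries, keeping CSE's dependency order.
        for sym, rhs in replacements:
            if sym in local_temps.get(i, ()):
                lines.append(f"        const double {sym} = {sm.ccode(rhs)};")
        lines.append(f"        return {sm.ccode(pair.expr)};")
    lines += ["    }", "}();"]
    return "\n".join(lines)


x = sm.Symbol("x")
expr = sm.Piecewise(
    (0, x < -1),
    (x**2, x <= 1),
    (x**2 / sm.sin(x) + sm.sin(x), True),
)
replacements, (reduced,) = sm.cse(expr, symbols=sm.numbered_symbols("tmp"))
temp_for = {rhs: sym for sym, rhs in replacements}
# Declare the sin(x) temporary only in the final branch, as in Func4.
code = print_piecewise_lambda(replacements, reduced, {2: [temp_for[sm.sin(x)]]})
print(code)
```

Keeping placement separate from printing means the same printer could be reused whether placement comes from a simple "used by exactly one branch" rule or from something smarter for shared temporaries like tmp0.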
I realized that in this simple case the compiler is going to figure it out anyway and not compute sin unless it's needed. You can see this here, with all three versions being equivalent:
https://godbolt.org/z/G11rb1Yh8
So on one hand we could implement the lambda approach. On the other hand, one could ask what the point is if the compiler already does this. Some things to think about here:
- Will the compiler be able to do this in complicated scenarios with chains of branches on top of each other? Does our generated from_rotation_matrix actually compile to more instructions at the assembly level?
- There is potentially some value in making the generated code easy for humans to understand, which pushes toward lambdas over ternaries and toward factoring sub-expressions close to where they are used.
- In many of our use cases we want a branch that applies to a bunch of outputs, not just one scalar output. Those outputs will share a conditional and likely share sub-expressions. In the current formulations, each scalar would generate its own if statements, and any sub-expression needed by two scalars (even within the same logical branch) would still need to live at the top level.
- SymPy draws a fairly firm line that a function takes multiple arguments and returns a scalar. Where could we inject the idea of a vector-valued output and use it effectively in printing?
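One possible place to inject vector-valued outputs is before printing. The sketch below assumes outputs are plain scalar Piecewise expressions and groups them by their exact condition chain, so each group could share one if/else and keep its shared sub-expressions inside the branch. `group_by_conditions` is hypothetical. Matching is purely structural, so conditions that are mathematically equal but written differently would not be grouped.

```python
import sympy as sm


def group_by_conditions(outputs):
    """Hypothetical helper: bucket scalar outputs that share the same condition chain.

    Returns {conditions: [(output_index, branch_values), ...]}. Outputs in the same
    bucket could be printed under one if/else chain instead of one chain per scalar.
    Non-Piecewise outputs go under the empty key, i.e. they are unconditional.
    """
    groups = {}
    for i, out in enumerate(outputs):
        if isinstance(out, sm.Piecewise):
            conds = tuple(pair.cond for pair in out.args)
            values = tuple(pair.expr for pair in out.args)
        else:
            conds, values = (), (out,)
        groups.setdefault(conds, []).append((i, values))
    return groups


x = sm.Symbol("x")
outputs = [
    sm.Piecewise((x, x < 0), (sm.sin(x), True)),
    sm.Piecewise((2 * x, x < 0), (sm.cos(x), True)),
    x**2,
]
groups = group_by_conditions(outputs)
for conds, members in groups.items():
    print(conds, [i for i, _ in members])
```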
Interested to discuss these points, particularly using Rot3.from_rotation_matrix as an example.
Weighted Sum
It might be better and more general to support a WeightedSum instead of a Branch, where you have a vector of weights and arguments and return the weighted sum. A two-way-only branch that relies on nesting seems prone to ugliness. Maybe the two-way case can be a wrapper helper.
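Here is a minimal sketch of that idea as a custom SymPy expression. It assumes weights and values are passed as interleaved `(w0, v0, w1, v1, ...)` arguments and that exactly one weight is 1 at runtime. `WeightedSum` and the two-way `branch` wrapper are hypothetical names. The wrapper uses `Max(sign(c), 0)` as the weight, matching the `Branch.eval` convention that a positive conditional selects the first operand.

```python
import sympy as sm


class WeightedSum(sm.Function):
    """Hypothetical expression: sum of w_i * v_i over interleaved (w0, v0, w1, v1, ...) args.

    Assumed semantics: exactly one weight is 1 at runtime, so codegen could print
    this as an if/else chain instead of evaluating every value.
    """

    @classmethod
    def eval(cls, *args):
        if len(args) % 2 != 0:
            raise ValueError("WeightedSum expects (weight, value) pairs")
        weights, values = args[0::2], args[1::2]
        # Collapse to a plain sum once every weight is a known number.
        if all(w.is_number for w in weights):
            return sm.Add(*(w * v for w, v in zip(weights, values)))
        return None  # keep the expression unevaluated

    def _eval_derivative(self, s):
        # Weights are treated as piecewise constant, as in Branch above.
        new_args = []
        for w, v in zip(self.args[0::2], self.args[1::2]):
            new_args += [w, v.diff(s)]
        return self.func(*new_args)


def branch(condition, if_true, if_false):
    """Two-way wrapper: selects if_true when condition > 0, else if_false."""
    w = sm.Max(sm.sign(condition), 0)
    return WeightedSum(w, if_true, 1 - w, if_false)


x, c = sm.symbols("x c")
e = branch(c, x**2, sm.sin(x))
print(e, e.subs(c, 2), sm.diff(e, x))
```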
Tangent: Naming
Where, Branch, If, Cond