
[IR] First-class StructInfo #293

Open
tqchen opened this issue Dec 2, 2022 · 11 comments

Comments

@tqchen
Contributor

tqchen commented Dec 2, 2022

One of Relax's design goals is to enable dynamic shapes and program analysis based on dynamic structural information. Shape propagation helps us build effective dynamic-shape-aware programs.

Shape can be viewed as one kind of "structural value" information: it tells us something about the runtime value.

As we do more development, we have found a few useful lessons and observations that can help us further evolve the design.

O0: Tracking shapes in advanced structural compositions

@R.function
def func(X: R.Tensor((n, m)), Y: R.Object):
    t: R.Tuple[R.Tensor, R.Object] = (X, Y)
    z: R.Tensor = t[0]

In the above program, we compose X and Y into a tuple. To trace the shape of z, we would need to define a shape_ for the tuple; however, Y does not have a shape.

O1: Desire for a single place to group information

The structural information about a value is currently spread between the type and the shape, so the printer and parser need to collect values from both sides. It would be simpler to have one clear location where this information is grouped.

O2: Extension

As part of our research effort, we are considering the extensibility of the system, and we will need to introduce other structural value information besides shape.

Taking these motivations into account, we propose to further formalize structural information deduction as part of Relax.

Design

We will introduce a class called StructInfo, a composite data structure that contains all the structural information deduced by the compiler. A StructInfo describes structural information about the corresponding value. See the pseudocode below for the proposed classes.

class Expr:
    checked_type: Type
    # optional struct info attached to the Expr that provides
    # additional information about the Expr
    struct_info: Optional[StructInfo]

class StructInfo:
    pass

class ShapeStructInfo(StructInfo):
    """Matches a shape value."""
    values: Optional[Array[PrimExpr]]
    ndim: int


class TensorStructInfo(StructInfo):
    """Matches any tensor with the given shape and dtype.

    Matching behavior: populates tir.Vars in shape.
    """
    shape: Optional[Expr]   # None | ShapeExpr | opaque Expr
    dtype: DataType
    ndim: int


class FuncStructInfo(StructInfo):
    """StructInfo of callable values, used for StructInfo deduction."""
    params: Array[StructInfo]
    # must be non-null when opaque_fn_rel is None
    ret: Optional[StructInfo]
    # opaque StructInfo relation that gives an escape hatch to
    # derive the StructInfo of a call expr through an opaque
    # environment packed func; takes precedence over params and ret
    # when it exists.
    # It can be used to set the StructInfo of OpNodes with special intrinsics.
    opaque_fn_rel: Optional[TypedEnvFunc[Call -> StructInfo]]


class ObjectStructInfo(StructInfo):
    """Matches any object."""


class PrimStructInfo(StructInfo):
    """Matches any PrimValue."""


class TupleStructInfo(StructInfo):
    """Matches a tuple."""
    fields: Array[StructInfo]

class MatchCast(Binding):
    """Generalizes match_shape.

    Matches a value against a StructInfo, populating undefined
    variables using match semantics, and produces a runtime
    assertion of equality.
    """
    var: Var
    struct_info: StructInfo
    # only takes leaf values
    value: Expr

# analysis functions defined for StructInfo
def all_vars(p: StructInfo) -> Array[Var]:
    """List the relax.Vars that this StructInfo contains."""

def all_tir_vars(p: StructInfo) -> Array[tir.Var]:
    """List all the tir.Vars that this StructInfo contains."""

def get_static_type(p: StructInfo) -> Type:
    """Get the static type that this StructInfo corresponds to."""

The key design considerations include:

  • There is a clear injective mapping between annotations in TVMScript and StructInfo.
    • Each annotation in TVMScript now corresponds to a StructInfo.
    • This helps simplify the parser and printer implementations.
  • A StructInfo uniquely decides the static type of the Expr (see the sketch after this list).
    • It is important to have a clear static type, because the static type is our safety net.
  • StructInfo provides extra (runtime value) information. This extra information (such as symbolic shape values) on intermediate values can always be safely erased, at any point, down to static type information, and the program should still compile.
  • Checking semantics:
    • We ensure the corresponding static type is always consistent and checked; we can derive the checked_type field from the StructInfo.
    • The extra (runtime value) information is deduced best-effort. When uncertain (during unification or deduction), the compiler/pass can choose to fall back to the static type.
    • Formally, the extra information has the semantics of “assume”.
  • When a StructInfo appears in the argument annotation of a function, the semantics is that of a match_cast.
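
As a concrete sketch of the erasure from StructInfo down to static types, and of the reverse "minimal StructInfo" direction: the type names (DynTensorType, TupleType, ShapeType, ObjectType) mirror the existing relax type hierarchy, but the exact API here is illustrative, not final.

def get_static_type(info: StructInfo) -> Type:
    """Erase a StructInfo down to the static type it uniquely decides."""
    if isinstance(info, TensorStructInfo):
        # symbolic shape values are dropped; only ndim and dtype remain
        return DynTensorType(ndim=info.ndim, dtype=info.dtype)
    if isinstance(info, TupleStructInfo):
        return TupleType([get_static_type(f) for f in info.fields])
    if isinstance(info, ShapeStructInfo):
        return ShapeType()
    # FuncStructInfo maps to FuncType analogously; everything else
    # falls back to the most general static type
    return ObjectType()

def minimal_struct_info(ty: Type) -> StructInfo:
    """Reverse direction: the least informative StructInfo for a given type."""
    if isinstance(ty, DynTensorType):
        return TensorStructInfo(shape=None, dtype=ty.dtype, ndim=ty.ndim)
    if isinstance(ty, TupleType):
        return TupleStructInfo([minimal_struct_info(f) for f in ty.fields])
    return ObjectStructInfo()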

Example Programs

The program below shows how StructInfo flows through a program. Notable items include:

  • FuncStructInfo can help preserve StructInfo deduction information across calls.
    • The particular deduction depends on the strength of the deducer and is best-effort.
    • When unsure, it is OK to fall back to the static-type version.
  • TupleStructInfo helps propagate shape information through tuple construction and destruction.
  • match_cast matches the value against the StructInfo: it asserts that the value matches, populates undefined symbolic shape variables if necessary, and returns a new value with the corresponding StructInfo.

@R.function
def subfunc(
    x: R.Tensor((n,)), y: R.Object
) -> R.Tuple[R.Tensor((n,)), R.Object]:
    return (x, y)

@R.function
def example_tuple(x: R.Tensor(ndim=1),
                  y: R.Object):
    # defines x1, n
    x1: R.Tensor((n,)) = R.match_cast(x, R.Tensor((n,)))

    # var assignment can propagate func StructInfo if necessary
    f: R.Func(args=[R.Tensor((n,))],
              ret=R.Tuple[R.Tensor((n,)), R.Object]) = subfunc
    t: R.Tuple[R.Tensor((n,)), R.Object] = f(x1, y)
    # deduction through tuple
    z: R.Tensor((n,)) = t[0]
    return z

The program below shows a possible extension of StructInfo to support sparse computation.

## Possible future StructInfo
class SparseCSRTensorStructInfo(StructInfo):
    """Matches any sparse CSR tensor with the given indptr and indices."""
    indptr: Optional[Expr]
    indices: Optional[Expr]

def example_sparse(indptr: R.Tensor((m,)),
                   indices: R.Tensor((nnz,)),
                   data0: R.Tensor((nnz,)),
                   data1: R.Tensor((nnz,))):
    y: R.SparseCSR(indptr, indices) = make_csr(indptr, indices, data0)
    x: R.SparseCSR(indptr, indices) = make_csr(indptr, indices, data1)
    z: R.SparseCSR(indptr, indices) = x + y
    return z

Note that the behavior of the sparse addition z = x + y depends on whether x and y share the same indptr and indices. Having this information available at compile time can help compilation optimizations.
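
As an illustration of how a pass might exploit this information (not part of the proposal; same_as is reference equality on IR nodes, and csr_shares_index_structure is a hypothetical helper):

def csr_shares_index_structure(x: SparseCSRTensorStructInfo,
                               y: SparseCSRTensorStructInfo) -> bool:
    """Best-effort check that two CSR tensors provably share index arrays.

    True only when indptr/indices are the same IR value (e.g. the same
    relax.Var); False means "unknown", not "provably different".
    """
    def same(a, b):
        return a is not None and b is not None and a.same_as(b)

    return same(x.indptr, y.indptr) and same(x.indices, y.indices)

A sparse-add lowering could then pick a fast path that reuses the shared index arrays only when this check succeeds.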

Discussions

The extra information in Expr.struct_info does not come for free: because a StructInfo can depend on other values, we should view it as being bundled together with the Expr, and account for it carefully when rewriting code.

def example(x: R.Tensor):
    z = func0(x)
    s = shape_func(z)
    y: R.Tensor[s] = opaque_fn(x)

Consider the above example: if we only look at the input arguments of calls, we would conclude that y does not depend on z or s. One possible optimization might then reorder y to the beginning of the function, or run dead-code elimination to remove everything that is not referenced by y.

def example_wrong_after_reorder(x: R.Tensor):
    # naively reordered y to the beginning because we forgot
    # to look into the struct_info field
    y: R.Tensor[s] = opaque_fn(x)
    z = func0(x)
    s = shape_func(z)

def example_wrong_after_deadcode(x: R.Tensor):
    # naively deleted the shape calculation because we forgot to count
    # references through y.struct_info
    y: R.Tensor[s] = opaque_fn(x)

To track these dependencies, use all_vars(struct_info).
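
For example, a reorder or dead-code pass could compute the full dependence set of a binding roughly as follows (a sketch; free_vars over the bound value is assumed to be available):

def binding_dependencies(binding: Binding) -> set:
    """All relax.Vars a binding depends on, including struct_info deps."""
    deps = set(free_vars(binding.value))  # ordinary data dependencies
    if binding.var.struct_info is not None:
        # dependencies hidden in annotations, e.g. s in y: R.Tensor[s]
        deps.update(all_vars(binding.var.struct_info))
    return deps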

It is important to remember that the extra information of StructInfo takes assume semantics rather than static_assert semantics; we only do best-effort checking. To see why we need assume semantics, consider the following function.

def example(x: R.Tensor):
    z = func0(x)
    s0 = shape_func(z)
    y: R.Tensor[s0] = opaque_fn(x)

Imagine that we want to "recheck" the relation y: R.Tensor[s0] = opaque_fn(x). We would re-trigger the deduction function of opaque_fn and obtain the following program. The second deduction generates a fresh shape function call s1; since shape_func can be an arbitrary sequence of computations, it is impossible in general to prove the equivalence s0 == s1.

def example(x: R.Tensor):
    z = func0(x)
    s0 = shape_func(z)
    s1 = shape_func(z)
    # check s0 == s1
    y: R.Tensor[s0] = opaque_fn(x)

So when we print out TVMScript that already contains compiler-deduced information, we parse that information back as-is, to ensure round-trip capability. To incorporate user-provided information with a runtime check, we can always rely on match_cast.

One can view StructInfo as close to a "dependent type". However, a normal type system usually has the following properties:

  • Types cannot simply be erased from the Expr.
  • They can be checked repeatedly, with semantics that follow static_assert.
  • The result is usually predictable.

The extra value information in StructInfo has the following properties:

  • It uses assume semantics.
  • It can be safely erased down to the static type, and the code will still compile correctly.
    • Such erasure will, of course, insert additional match checks at function input boundaries to ensure the extra StructInfo holds.
  • StructInfo serves as extra information that the compiler can use for optimizations.
  • Compiler analyses can choose to erase StructInfo or to generate more informative StructInfo.

As a result, the extra information is more akin to "extra, optional analysis information that the compiler can use". We acknowledge the difficulty of doing full proofs over the extra runtime information. Instead, because all of the information is available in the runtime (Tensor) values, we use the static type as a "safety net". The static type also remains important and stable across dialects such as TIR and Relax.

The relations between static types and StructInfo are:

  • A StructInfo can always be erased down to a corresponding static type, which it uniquely decides.
  • A static type corresponds to a minimal StructInfo that maps to that type.
  • During compilation, we do strict type checking on static types, and best-effort deduction of the extra information.

Because of the above considerations, we still believe it is important to distinguish the (static) type from the StructInfo, and to call them out separately.

Upgrading to StructInfo

The update to StructInfo can be mechanical: we need to change shape deduction to StructInfo deduction, and match_shape to match_cast. We can also create shape accessor functions that redirect to TensorStructInfo's shape field to obtain the corresponding symbolic shape. We can choose to always run StructInfo deduction and then use get_static_type to set the corresponding static types. For static type checking, keeping some of the existing type deduction may be useful to avoid StructInfo deduction generating additional bindings. The additional structural information can help us simplify the parser, the printer, and propagation across functions and mixed tuple compositions.
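
For instance, the shape accessor mentioned above could be a thin helper along these lines (a sketch; the exact name and placement are implementation details):

def get_shape_of(expr: Expr) -> Optional[Expr]:
    """Redirect the old expr.shape_ access to the new StructInfo field."""
    info = expr.struct_info
    if isinstance(info, TensorStructInfo):
        return info.shape  # None | ShapeExpr | opaque shape Expr
    return None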

@slyubomirsky
Collaborator

To be clear, this is proposing to replace the current use of shape_? I would definitely be in favor of having something that maps more clearly to the annotations in TVMScript, as this proposal discusses.

@YuchenJin
Collaborator

> To be clear, this is proposing to replace the current use of shape_? I would definitely be in favor of having something that maps more clearly to the annotations in TVMScript, as this proposal discusses.

Yes, this proposed Pattern will replace the current shape_, and it opens doors to more useful but erasable information besides shape (termed "structural value information" in Tianqi's proposal).

I also very much like that it directly maps to the annotations in TVMScript, and that it can remove duplicated code when we register an Op -- right now we have FInferType and FInferShape, which have overlapping logic (ndim appears in both type and shape); with Pattern, we would only need FInferPattern.

@slyubomirsky
Collaborator

I'd like to clarify something about the examples in the discussion section: Do we want to permit Relax variables to appear in shape annotations? I wrote the draft spec on the belief that we do not permit it and that shape variables can be introduced only in MatchShape nodes.

@tqchen tqchen changed the title [DISCUSS][IR] First-class Structural Pattern [DISCUSS][IR] First-class StructInfo Dec 5, 2022
@tqchen
Contributor Author

tqchen commented Dec 5, 2022

NOTE: updated the terminology to StructInfo to avoid confusion with the dataflow pattern language.

@tqchen
Contributor Author

tqchen commented Dec 5, 2022

@slyubomirsky thank you for bringing it up. In this particular example, a relax var can indeed appear in the shape in cases where the shape is deduced through an opaque function. It still holds that shape variables are only defined through MatchCast.

@slyubomirsky
Collaborator

Okay, we might need to define how that should work.

@tqchen
Contributor Author

tqchen commented Dec 5, 2022

Agreed. More broadly, we should clarify what the "match_cast" semantics implies here.

@slyubomirsky
Collaborator

slyubomirsky commented Dec 6, 2022

Since I won't be able to be at the community meeting tomorrow, I'll give some of my thoughts in advance in writing (would be happy to discuss further based on what is said at the meeting).

Drawing on my draft rules for shape inference, I think StructInfo could work like this:

  1. match_cast dynamically checks that the value at run time matches the specified information
  2. If StructInfo is annotated on the function arguments or return value, there will be an implicit match_cast at the beginning and end of the function to dynamically check these

It is a little harder to decide what to do with annotated StructInfo. In the draft specification, I said that if the compiler cannot statically prove a shape annotation matches the computed shape_, the compiler should raise an error and require a dynamic cast. This is an approach we can use for StructInfo, but it does not match the intent of being "best-effort." I think we could use the following policy:

  1. Do nothing if the annotated StructInfo matches the StructInfo computed statically
  2. Give a compile-time warning if the annotated StructInfo might not match (i.e., cannot conclude equality or disequality) the computed StructInfo. Give the variable the annotated StructInfo (trust the annotation, per "assume semantics")
  3. Give a compile-time error if the annotated StructInfo definitely does not match the computed StructInfo

In these cases, I think there should be no run-time semantics for the annotation (i.e., there will be a dynamic check only if there is an explicit match_cast). Alternatively, we could have a compiler flag that turns all instances of case 2 into implicit match_casts (or make that the default).
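
For concreteness, a rough sketch of this three-way policy (can_prove_equal, can_prove_unequal, CompileError, and warn are hypothetical here, not existing API):

def check_annotation(annotated: StructInfo, computed: StructInfo) -> StructInfo:
    if can_prove_equal(annotated, computed):
        return annotated  # case 1: annotation provably matches, nothing to do
    if can_prove_unequal(annotated, computed):
        # case 3: definite mismatch is a compile-time error
        raise CompileError("annotation definitely mismatches deduced StructInfo")
    # case 2: neither equality nor disequality can be concluded
    warn("cannot prove annotation matches deduced StructInfo")
    return annotated  # trust the annotation, per "assume semantics"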

My only worry is about error reporting if a shape mismatch is detected late in compilation, e.g., after several passes have transformed the AST. How would we convey that to the user? Would we expect users to keep track of which passes have been applied? For example, it is possible that there is not enough information in the initial program to conclude that a shape mismatches, but after applying function inlining, the compiler is able to conclude that there is a mismatch. It is good to detect errors, but my concern is how to report them to the user.

I think the relationship of StructInfo to type should be clearly specified as well. I think all StructInfo should be associated with a Relax type.

In general, I like this idea and I would love to spend time whiteboarding out rules for the different kinds of expressions and how we should process the StructInfo. I think we should be careful about what sorts of expressions we permit to appear inside StructInfo and what the scoping rules will be.

@tqchen
Contributor Author

tqchen commented Dec 6, 2022

Thank you @slyubomirsky for bringing up great points. I agree with your points on how annotations work on arguments and returns, and with the policy, especially points 1 and 3 (where 3 is best-effort).

One thing to consider for policy 2 (whether a warning should be issued) is the case where we use TVMScript as an IR to store intermediate output whose struct_info was already deduced by the compiler. To enable round-trip capability in such cases, the best approach is to not re-run the deduction, to avoid additional bindings being created by general re-deduction (especially around the opaque shape example), and to directly take the "assume semantics".

In the case of a user-provided program with only partial annotations, I agree that providing some form of implicit match_cast (or a warning) would make sense. Perhaps we could have some syntax to distinguish the two use cases. Alternatively, we can always recommend that users use match_cast, which is more explicit and clearly states that the annotation is an assumption.

One way to think about best-effort compile-time errors is that, without the rich information, they would likely turn into runtime errors, and in some sense an error at compile time is better. Indeed, we can think a bit more about error reporting here. My guess is that the operator context might help.

I also agree about mapping StructInfo onto types; the get_static_type() function provides such a mapping.

@YuchenJin
Collaborator

Thanks everyone for the proposal and discussion!

At the Relax open dev meeting on Dec 7 (recording, passcode: j$qkF+D2), the community agreed to bring in the proposed StructInfo, and we will proceed with the implementation.

@tqchen
Contributor Author

tqchen commented Dec 13, 2022

One thing to note is that StructInfo deduction is something we can continue to refine further.

To help us get onto the new infra quickly, the first iteration of the implementation will likely only seek to match the original best-effort shape deduction results that we currently have (and not add smarter deductions), so we have a basis on the new infra to iterate from.

  • For example, we may not implement cross-function-call shape deduction that involves symbolic shapes, and instead erase to static shapes.
  • We may also erase the dynamic information between branches.

Here is an overall sketch of how StructInfo deduction can work. We do it with the following helper functions:

def unify_to_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo:
    """Find the LCA of lhs and rhs."""

def erase_to_well_defined(
    info: StructInfo,
    shape_var_in_scope: List[tir.Var],
    var_in_scope: List[Var],
) -> StructInfo:
    """Erase info so that it excludes vars that are not in scope."""

unify_to_lca helps us find the least common ancestor (LCA) of two StructInfo by erasing information.
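
For example (illustrative behavior only; shapes are written as tuples for brevity, and the exact unification rules are part of the deduction design):

# two tensors that agree on ndim/dtype but may differ in the second dim
a = TensorStructInfo(shape=(n, m), dtype="float32", ndim=2)
b = TensorStructInfo(shape=(n, k), dtype="float32", ndim=2)
unify_to_lca(a, b)  # -> TensorStructInfo(shape=None, dtype="float32", ndim=2)

# a tensor and an arbitrary object only share the top of the lattice
unify_to_lca(a, ObjectStructInfo())  # -> ObjectStructInfo()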

erase_to_well_defined is another function introduced to ensure correctness. Consider the following code example:

def f(x: R.Tensor[(n, m)]):
    k = tir.Var("k", "int64")
    v0 = opaque_fn(x)
    v1 = match_cast(v0, R.Tensor[(n, k)])
    v2: R.Tensor[(n+1, k+2)] = pad(v1)
    return v2

In the above code, the return value v2 has shape (n + 1, k + 2). However, at the level of the function signature, only n and m are defined by the parameters; k is undefined once we go outside the scope of the function body and look only at the parameters. In this case:

  • We can call erase_to_well_defined(R.Tensor[(n+1, k+2)], defined=[n, m]).
  • The result will become R.Tensor(ndim=2), a more coarse-grained StructInfo that does not contain an undefined var.

Let us look at another example:

def f(x: R.Tensor[(n, m)]):
    v2: R.Tensor[(n, m+2)] = pad(x)
    return v2

In this case, erase_to_well_defined(R.Tensor[(n, m+2)], defined={n, m}) will give us R.Tensor[(n, m+2)], because both n and m can be picked up from the function parameters.

erase_to_well_defined should be used in scenarios where we return values from a scope to the outside, to ensure the resulting struct_info is well-defined.

  • It should be used in Function, SeqExpr, and If (a sketch of the tensor-case erasure follows).
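
A minimal sketch of the erasure for the tensor case (used_tir_vars is an assumed helper that collects the tir.Vars a PrimExpr references):

def erase_to_well_defined(info: StructInfo,
                          shape_var_in_scope: Set[tir.Var]) -> StructInfo:
    """Drop symbolic shape values that reference out-of-scope vars."""
    if isinstance(info, TensorStructInfo) and isinstance(info.shape, ShapeExpr):
        used = set()
        for value in info.shape.values:
            used |= used_tir_vars(value)
        if not used <= shape_var_in_scope:
            # fall back to the coarser info without the symbolic shape
            return TensorStructInfo(shape=None, dtype=info.dtype, ndim=info.ndim)
    return info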

Here is a rough set of deduction rules (a rough sketch to ensure consistency with the current shape behavior):

  • Call: look at call.op.struct_info, which should be a FuncStructInfo, and apply the function StructInfo deduction rule.
  • If:
    if_node.struct_info = unify_to_lca(
        erase_to_well_defined(if_node.then_case, parent_scope_vars),
        erase_to_well_defined(if_node.else_case, parent_scope_vars)
    )
  • SeqExpr:
    seq_node.struct_info = erase_to_well_defined(seq_node.body, parent_scope_vars)
  • Function:
    • When ret is specified, check that ret is well-defined by looking at the vars in params, and use ret directly.
    • When ret needs to be deduced:
      ret_struct_info = erase_to_well_defined(func.body, param_scope_vars)
  • Tuple: compose as usual.
  • TupleGetItem: compose as usual.

Note that the initial implementation will mainly aim to reach parity with the original shape deduction on simpler infra, rather than realizing the full best-effort deduction.

When uncertain, we can call erase_to_well_defined with no provided vars, or only with the vars defined in the params. This gives us a good-enough baseline that matches the current shape behavior and gets the initial infra in place, while leaving room for further refinement of the deduction.

We will do another round of iteration to further strengthen the best effort deduction rules.

@tqchen tqchen changed the title [DISCUSS][IR] First-class StructInfo [IR] First-class StructInfo Dec 16, 2022