Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Fix][TVMScript] Parse and print tir_vars of call_tir properly #361

Merged

Conversation

MasterJH5574
Copy link
Collaborator

This PR patches both the TVMScript parser and printer on handling the tir_vars parameter of call_tir.

When a TIR function contains symbolic shapes, its signature can have trailing TIR vars in the param list. In such cases, when using call_tir to call into the TIR function, we should provide the instances of those TIR vars through the param tir_vars. tir_vars is the last parameter of call_tir, one behind dtype.

def call_tir(
func: Union[str, Expr],
args: Union[Expr, List[Expr]],
shape: Union[RxTuple, ShapeExpr, List[int]],
dtype: Union[str, List[str]],
tir_vars: Optional[ShapeExpr] = None,
) -> Call:

The specific issues fixed in this PR is:

  • Printer: prior to this PR, when printing call_tir, the printer always explicitly prints "dtype=", while doesn’t print "tir_vars=" when there is. This will make the printed script looks like
    y = R.call_tir(copy, (x,), ((n * 2,)), dtype="float32", (n,))
    which doesn’t conform to Python syntax.
  • call_tir’s handling on TIR vars: the TIR vars will be parsed as a list or a tuple of Python, while on the call_tir side the parsed tuple/list will be treated as a Expr and passed to C++ side through FFI. This will incur FFI error like
    error:   Check failed: (!checked_type.defined()) is false: Expected RelayExpr, but got Array
    

Therefore, this PR fixes both sides. On printer side, we now print tir_vars after dtype and explicitly print "tir_vars=". On call_tir side, we wrap the input tir_vars to ShapeExpr if we find it a tuple or list. One regression test that covers both issues is provided.

cc @YuchenJin

MasterJH5574 added a commit to MasterJH5574/tlc-relax that referenced this pull request Jan 16, 2023
Copy link
Collaborator

@YuchenJin YuchenJin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@MasterJH5574 MasterJH5574 force-pushed the relax-dev/2023-01-16-call-tir-tir-args branch from 31e4479 to 0b0b893 Compare January 16, 2023 20:49
spectrometerHBH referenced this pull request in mlc-ai/relax Jan 16, 2023
* [TIR][Fix] Buffer slicing using index dtype as extent (#13788)

[Fix] Buffer slicing using index dtype as extent

* [TIR][Fix] IndexDataTypeNormalizer not unwrapping float casting (#13789)

* [Fix][TVMScript] Parse and print `tir_vars` of `call_tir` properly (tlc-pack#361)

* [Transform] Operator legalizer

* Documentation
@tqchen tqchen merged commit 15bf541 into tlc-pack:relax Jan 18, 2023
junrushao pushed a commit to junrushao/relax that referenced this pull request Jan 25, 2023
junrushao pushed a commit to junrushao/relax that referenced this pull request Jan 26, 2023
MasterJH5574 referenced this pull request in mlc-ai/relax Jan 28, 2023
* [TIR][Fix] Buffer slicing using index dtype as extent (#13788)

[Fix] Buffer slicing using index dtype as extent

* [TIR][Fix] IndexDataTypeNormalizer not unwrapping float casting (#13789)

* [Fix][TVMScript] Parse and print `tir_vars` of `call_tir` properly (tlc-pack#361)

* [Transform] Operator legalizer

* Documentation
junrushao pushed a commit to junrushao/relax that referenced this pull request Jan 29, 2023
MasterJH5574 referenced this pull request in mlc-ai/relax Jan 31, 2023
* [TIR][Fix] Buffer slicing using index dtype as extent (#13788)

[Fix] Buffer slicing using index dtype as extent

* [TIR][Fix] IndexDataTypeNormalizer not unwrapping float casting (#13789)

* [Fix][TVMScript] Parse and print `tir_vars` of `call_tir` properly (tlc-pack#361)

* [Transform] Operator legalizer

* Documentation
junrushao pushed a commit to junrushao/relax that referenced this pull request Feb 5, 2023
junrushao pushed a commit to junrushao/relax that referenced this pull request Feb 6, 2023
MasterJH5574 referenced this pull request in mlc-ai/relax Feb 8, 2023
* [TIR][Fix] Buffer slicing using index dtype as extent (#13788)

[Fix] Buffer slicing using index dtype as extent

* [TIR][Fix] IndexDataTypeNormalizer not unwrapping float casting (#13789)

* [Fix][TVMScript] Parse and print `tir_vars` of `call_tir` properly (tlc-pack#361)

* [Transform] Operator legalizer

* Documentation
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants