Commit
The refactoring addresses two issues:

1) Expose an API for compiling and autotuning dot fusions. Previously the implementation was hidden in an anonymous namespace and used a global object (the autotuning cache) to store the results. The exposed API might be useful for modelling matmul performance.
2) Reduce code duplication. Previously the filtering of the Triton configs was done separately for the default and the exhaustive search.

Some configs need to be adjusted to be compiled correctly by Triton, for example:
- when there are fewer elements in a tensor than the number of threads (happens for sparse dot metadata with small tiles);
- when compiling for Hopper with `wgmma` enabled (block_m or num_warps may need to be adjusted).

Notably, `kDefaultGemmTiling` could need such adjustments, so returning it as-is may result in fatal compilation errors. With this patch, the same config generator is used for all code paths.

There are a few more changes:
- do not handle the reference implementation (cuBLAS) separately; instead, include it in the `Config` variant along with the other options (cuDNN, Triton).
- do not repeat the cuDNN conditions; put them in one place.
- use the actual tile sizes and SM count for calculating the split-K upper bound (a more accurate heuristic than a global limit).

PiperOrigin-RevId: 625975092
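The following is a minimal sketch of three ideas from the message above: a single `Config` variant covering cuBLAS, cuDNN and Triton, a config adjustment for tiles with fewer elements than threads, and a split-K upper bound derived from tile sizes and SM count. It is not the code from this commit. All struct fields, the helper names (`AdjustForSmallTile`, `MaxSplitK`, `CeilDiv`) and the exact formulas are assumptions made for illustration only.

```cpp
#include <algorithm>
#include <cstdint>
#include <variant>

// Hypothetical stand-ins for the backend options; the real types differ.
struct CuBlasConfig {};
struct CuDnnConfig { int plan_id; };
struct TritonConfig {
  int block_m, block_n, block_k;
  int split_k;
  int num_warps;
};

// One variant for every backend, so the reference implementation (cuBLAS)
// goes through the same code path as the other options.
using Config = std::variant<CuBlasConfig, CuDnnConfig, TritonConfig>;

constexpr int kWarpSize = 32;

int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Assumed adjustment: cap num_warps so the thread count does not exceed
// the number of elements in the tile (always keeping at least one warp).
TritonConfig AdjustForSmallTile(TritonConfig c, int64_t tile_elements) {
  const int64_t max_warps = std::max<int64_t>(1, tile_elements / kWarpSize);
  c.num_warps = static_cast<int>(std::min<int64_t>(c.num_warps, max_warps));
  return c;
}

// Assumed heuristic: splitting K only helps while the output tiles alone
// cannot occupy every SM, so bound split_k by ceil(sm_count / output_tiles).
int64_t MaxSplitK(int64_t m, int64_t n, int block_m, int block_n,
                  int64_t sm_count) {
  const int64_t output_tiles = CeilDiv(m, block_m) * CeilDiv(n, block_n);
  return std::max<int64_t>(1, CeilDiv(sm_count, output_tiles));
}
```

Under these assumptions, a 128x128 output with 64x64 tiles on a device with 108 SMs has 4 output tiles, so split_k is bounded by 27, while a 4096x4096 output with 128x128 tiles already has 1024 tiles and the bound collapses to 1.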