Refactor GEMM autotuner
This refactoring addresses two issues:
1) Expose an API for compiling and autotuning dot fusions. Previously, the implementation was hidden in an anonymous namespace and used a global object to store the results (the autotuning cache). The exposed API might also be useful for modeling matmul performance.
2) Reduce code duplication. Previously, the filtering of Triton configs was done separately for the default and exhaustive searches. Some configs need to be adjusted to compile correctly with Triton, for example:
- when a tensor has fewer elements than the number of threads (happens for sparse dot metadata with small tiles);
- when compiling for Hopper with `wgmma` enabled (block_m or num_warps may need to be adjusted).
Notably, `kDefaultGemmTiling` may itself need such adjustments, so returning it as-is can cause fatal compilation errors. With this patch, the same config generator is used on all code paths.
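The shared adjustment step described above can be sketched as follows. This is a minimal illustration, not the actual XLA implementation: the names `TritonGemmConfig` and `AdjustConfig`, the field layout, and the halving strategy are assumptions; the real code handles more cases (e.g. the Hopper `wgmma` constraints).

```cpp
#include <cassert>

// Hypothetical config struct mirroring the tiling parameters discussed above.
struct TritonGemmConfig {
  int block_m, block_n, block_k;
  int split_k;
  int num_warps;
};

constexpr int kThreadsPerWarp = 32;

// Illustrative adjustment: if the tensor has fewer elements than the number
// of launched threads (e.g. sparse dot metadata with small tiles), shrink
// num_warps so that the config compiles and every thread has work.
TritonGemmConfig AdjustConfig(TritonGemmConfig config, long tensor_elements) {
  while (config.num_warps > 1 &&
         static_cast<long>(config.num_warps) * kThreadsPerWarp >
             tensor_elements) {
    config.num_warps /= 2;
  }
  return config;
}
```

Running every candidate, including the default tiling, through one such generator is what removes the duplicated filtering between the default and exhaustive search paths.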

There are a few more changes:
- do not handle the reference implementation (cuBLAS) separately; instead, treat it as one `Config` variant alongside the other options (cuDNN, Triton).
- do not repeat the cuDNN conditions; keep them in one place.
- use the actual tile sizes and SM count to compute the split-K upper bound (a more accurate heuristic than a global limit).
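The split-K bound in the last bullet can be sketched like this. A minimal illustration only: the function name `SplitKLimit` and the exact formula are assumptions; the idea is that splitting the K dimension pays off only until all SMs are occupied, so the bound follows from the output tile count rather than a fixed global cap.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical heuristic: with block_m x block_n output tiles, a problem of
// size m x n produces `tiles` independent thread blocks. Split K only far
// enough to give every SM at least one block; beyond that, extra splits add
// reduction overhead without improving occupancy.
int SplitKLimit(int m, int n, int block_m, int block_n, int sm_count) {
  int tiles = ((m + block_m - 1) / block_m) * ((n + block_n - 1) / block_n);
  return std::max(1, sm_count / std::max(1, tiles));
}
```

For example, a 128x128 problem with 64x64 tiles yields only 4 tiles, so on a 108-SM GPU a large split-K is worthwhile, whereas a 4096x4096 problem already saturates the device and gets a bound of 1.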

PiperOrigin-RevId: 625975092
sergeykozub authored and tensorflower-gardener committed Apr 18, 2024
1 parent 6d3d608 commit 698eaa5
Showing 3 changed files with 585 additions and 674 deletions.
