
feat: batched complex GEMM (C = alpha*A*B + beta*C) #3

@shinaoka


Summary

Add batched complex GEMM support using a 4-real decomposition of A*B = (Ar + i*Ai)(Br + i*Bi): Cr = Ar*Br - Ai*Bi, Ci = Ar*Bi + Ai*Br.
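
A minimal CPU reference sketch of the decomposition, written in Python rather than as a CubeCL kernel. It assumes complex alpha and beta, as in BLAS cgemm/zgemm. A, B and C are split into planar real and imaginary matrices, the four real products are formed, and the scaling is applied afterwards. The names `real_gemm` and `complex_gemm_4m` are hypothetical. A reference like this could serve as a correctness oracle for the tiled implementation, but it does not model the tiling.

```python
# Hypothetical reference sketch (not the cubek-matmul implementation):
# C = alpha*(A@B) + beta*C computed from four real GEMMs on split
# (planar) real/imaginary parts. Matrices are lists of lists of floats.

def real_gemm(x, y):
    """Naive real matrix product: x (m x k) @ y (k x n)."""
    m, k, n = len(x), len(y), len(y[0])
    return [[sum(x[i][p] * y[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]

def complex_gemm_4m(ar, ai, br, bi, cr, ci, alpha=1 + 0j, beta=0j):
    """Return (out_r, out_i) for C = alpha*(A@B) + beta*C via 4 real GEMMs."""
    rr = real_gemm(ar, br)  # Ar*Br
    ii = real_gemm(ai, bi)  # Ai*Bi
    ri = real_gemm(ar, bi)  # Ar*Bi
    ir = real_gemm(ai, br)  # Ai*Br
    m, n = len(rr), len(rr[0])
    out_r = [[0.0] * n for _ in range(m)]
    out_i = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            pr = rr[i][j] - ii[i][j]  # Re(A@B)
            pi = ri[i][j] + ir[i][j]  # Im(A@B)
            # Complex scaling: alpha*(pr + i*pi) + beta*(cr + i*ci)
            out_r[i][j] = (alpha.real * pr - alpha.imag * pi
                           + beta.real * cr[i][j] - beta.imag * ci[i][j])
            out_i[i][j] = (alpha.real * pi + alpha.imag * pr
                           + beta.real * ci[i][j] + beta.imag * cr[i][j])
    return out_r, out_i
```

In this sketch, Re(A*B) and Im(A*B) are combined before scaling. Each of the four GEMMs therefore stays a plain real product, and alpha and beta are applied once per output element.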

Background

Complex GEMM is essential for scientific computing (e.g., quantum mechanics, signal processing). cuBLAS provides cgemm/zgemm with Tensor Core support, but cuBLAS is CUDA-only, so a fallback is needed for portability to other backends.

Approach

  • 4-real decomposition: Complex GEMM → 4 real GEMMs, each using existing CMMA/MMA tile pipeline
  • Tile-level implementation in cubek-matmul (separate repo): decompose complex tiles into real tiles
  • Scalar register fallback for hardware without Tensor Core complex support
  • Batched: support [B, M, K] × [B, K, N] → [B, M, N], with broadcasting over the batch dimension B
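
The batched shape contract can be pinned down with a small reference. This sketch assumes NumPy/PyTorch-style broadcasting, where a batch size of 1 is reused for every output batch entry (equivalently, a batch stride of 0). `broadcast_batch` and `batched_complex_matmul` are hypothetical names. Python's built-in complex type stands in for the 4-real path so that only the batching logic is shown.

```python
# Hypothetical sketch of the batch-broadcasting rule for
# [Ba, M, K] x [Bb, K, N] -> [B, M, N]; not the cubek-matmul code.

def broadcast_batch(batch_a, batch_b):
    """Output batch size: sizes must match, or one side must be 1."""
    if batch_a == batch_b or batch_b == 1:
        return batch_a
    if batch_a == 1:
        return batch_b
    raise ValueError(f"incompatible batch sizes {batch_a} and {batch_b}")

def batched_complex_matmul(a, b):
    """a: [Ba][M][K], b: [Bb][K][N] nested lists of complex -> [B][M][N]."""
    ba, bb = len(a), len(b)
    batch = broadcast_batch(ba, bb)
    out = []
    for t in range(batch):
        # A side with batch size 1 is reused for every output entry.
        x = a[0 if ba == 1 else t]
        y = b[0 if bb == 1 else t]
        m, k, n = len(x), len(y), len(y[0])
        out.append([[sum(x[i][p] * y[p][j] for p in range(k))
                     for j in range(n)] for i in range(m)])
    return out
```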

Scope

  • cubek-matmul: Add MatmulPrecision impl for Complex types, complex-aware tile routines
  • cubecl (this repo): Ensure CMMA load/store work with interleaved complex layout
  • Alpha/beta scaling support
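
In the interleaved layout, element (i, j) of a row-major complex matrix with leading dimension ld (counted in complex elements) has its real part at float offset 2*(i*ld + j) and its imaginary part at 2*(i*ld + j) + 1. This sketch assumes the real tile pipeline consumes planar real/imaginary tiles. It shows that split and its inverse. `deinterleave` and `interleave` are hypothetical helpers, not cubecl APIs.

```python
# Hypothetical sketch: mapping between an interleaved complex buffer
# [re0, im0, re1, im1, ...] and planar real/imaginary tiles.

def deinterleave(buf, rows, cols, ld=None):
    """Split an interleaved row-major complex buffer into planar tiles.

    ld is the leading dimension in complex elements (defaults to cols).
    """
    ld = cols if ld is None else ld
    re = [[buf[2 * (i * ld + j)] for j in range(cols)] for i in range(rows)]
    im = [[buf[2 * (i * ld + j) + 1] for j in range(cols)] for i in range(rows)]
    return re, im

def interleave(re, im, ld=None):
    """Inverse of deinterleave: pack planar tiles back into [re, im] pairs."""
    rows, cols = len(re), len(re[0])
    ld = cols if ld is None else ld
    buf = [0.0] * (2 * rows * ld)  # padding columns (j >= cols) stay 0.0
    for i in range(rows):
        for j in range(cols):
            buf[2 * (i * ld + j)] = re[i][j]
            buf[2 * (i * ld + j) + 1] = im[i][j]
    return buf
```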

References

  • Design doc: docs/plans/2026-04-14-complex-design.md
  • PyTorch stores complex tensors interleaved (real, imaginary pairs) and calls cuBLAS cgemm/zgemm directly
  • cuBLAS complex GEMM: a single call with Tensor Cores, versus 4 real GEMMs in this approach
