
[ROCm] support torch._C._set_sm_carveout_experimental - Parity with Nvidia #149280

@functionstackx

🐛 Describe the bug

Hi @hliuca

On Nvidia, PyTorch supports torch._C._set_sm_carveout_experimental for better compute-comms overlapping. This is useful during the backward pass of DDP and the forward/backward passes of FSDP: it ensures there are enough SMs/CUs available for the RCCL comms kernels, so they are not blocked by compute kernels that use up all the SMs/CUs.

Furthermore, it is useful for benchmarking real-world GEMMs as they occur in the backward pass, where a GEMM cannot take up all the available SMs/CUs because RCCL comms kernels occupy some of them. A minimal usage sketch follows.
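For context, a minimal sketch of how the existing Nvidia-side API is used. The carveout value of 8 and the matrix sizes are illustrative, and the assumption that passing None clears the carveout (the context stores an optional) is marked in the comments:

    import torch

    # Sketch of the Nvidia-side API this issue asks to bring to ROCm.
    # The argument is the number of SMs to carve out (reserve for comms);
    # cuBLASLt GEMMs then target multiProcessorCount - carveout SMs.
    torch._C._set_sm_carveout_experimental(8)  # 8 is an illustrative value

    a = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
    c = a @ b  # this GEMM leaves the carved-out SMs free for overlapping comms

    # Assumption: passing None clears the carveout again.
    torch._C._set_sm_carveout_experimental(None)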

Related to #147966.

I was looking into implementing this myself, but it seems it isn't as simple as calling hipblasLtMatmulDescSetAttribute; it requires changes to hipBLASLt itself. Unlike with cublasLtMatmulDescSetAttribute, HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET is not an available option: the hipblasLtMatmulDescAttributes_t enum accepted by hipblasLtMatmulDescSetAttribute does not define it, at least according to the AMD docs:

https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#_CPPv431hipblasLtMatmulDescAttributes_t
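For reference, this is how the CUDA path in PyTorch applies the carveout via cuBLASLt: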

    // CUDA-only: when a carveout is set, tell cuBLASLt to target
    // (total SM count - carveout) SMs for this matmul.
    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
      computeDesc.setAttribute<int32_t>(
          CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
          at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
              at::globalContext()._SMCarveout_EXPERIMENTAL().value());
    }
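For comparison, a purely hypothetical ROCm-side sketch, assuming hipBLASLt one day adds a HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET attribute. That attribute does not exist in hipblasLtMatmulDescAttributes_t today, which is exactly the gap this issue is about:

    #ifdef USE_ROCM
      // Hypothetical: HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET is NOT in
      // hipblasLtMatmulDescAttributes_t today; this only sketches what the
      // carveout would look like once hipBLASLt exposes such an attribute.
      if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
        computeDesc.setAttribute<int32_t>(
            HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET,
            at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
                at::globalContext()._SMCarveout_EXPERIMENTAL().value());
      }
    #endif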

Versions

Any ROCm torch version.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd
