Skip to content

[AMDGPU] ISel & PEI for whole wave functions #145858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from

Conversation

rovka
Copy link
Collaborator

@rovka rovka commented Jun 26, 2025

Whole wave functions are functions that will run with a full EXEC mask.
They will not be invoked directly, but instead will be launched by way
of a new intrinsic, llvm.amdgcn.call.whole.wave (to be added in
a future patch). These functions are meant as an alternative to the
llvm.amdgcn.init.whole.wave or llvm.amdgcn.strict.wwm intrinsics.

Whole wave functions will set EXEC to -1 in the prologue and restore the
original value of EXEC in the epilogue. They must have a special first
argument, i1 %active, that is going to be mapped to EXEC. They may
have either the default calling convention or amdgpu_gfx. The inactive
lanes need to be preserved for all registers used, active lanes only for
the CSRs.

At the IR level, arguments to a whole wave function (other than
%active) contain poison in their inactive lanes. Likewise, the return
value for the inactive lanes is poison.

This patch contains the following work:

  • 2 new pseudos, SI_SETUP_WHOLE_WAVE_FUNC and SI_WHOLE_WAVE_FUNC_RETURN
    used for managing the EXEC mask. SI_SETUP_WHOLE_WAVE_FUNC will return
    a SReg_1 representing %active, which needs to be passed into
    SI_WHOLE_WAVE_FUNC_RETURN.
  • SelectionDAG support for generating these 2 new pseudos and the
    special handling of %active. Since the return may be in a different
    basic block, it's difficult to add the virtual reg for %active to
    SI_WHOLE_WAVE_FUNC_RETURN, so we initially generate an IMPLICIT_DEF
    which is later replaced via a custom inserter.
  • Expansion of the 2 pseudos during prolog/epilog insertion. PEI also
    marks any used VGPRs are WWM registers, which are then spilled and
    restored with the usual logic.

Future patches will include the llvm.amdgcn.call.whole.wave intrinsic
and a lot of optimization work (especially in order to reduce spills around
function calls).

Copy link
Collaborator Author

rovka commented Jun 26, 2025

@rovka rovka changed the title Add subtarget feature [AMDGPU] ISel & PEI for whole wave functions Jun 26, 2025
@rovka rovka marked this pull request as ready for review June 26, 2025 09:05
@shiltian
Copy link
Contributor

What is the fundamental motivation of doing this?

@rovka
Copy link
Collaborator Author

rovka commented Jun 27, 2025

What is the fundamental motivation of doing this?

The fundamental motivation for doing this is that the WWM intrinsics are fragile and don't handle control flow very well. We have an increasing number of use cases where we need to have some control flow inside a WWM region, and that isn't supported that well. For one of these use cases we've introduced the init.whole.wave intrinsic, but that's kind of hacky. Whole wave functions are meant to replace all that with something more robust (kind of like a MachineIR region).

rovka added 22 commits June 27, 2025 13:55
Whole wave functions are functions that will run with a full EXEC mask.
They will not be invoked directly, but instead will be launched by way
of a new intrinsic, `llvm.amdgcn.call.whole.wave` (to be added in
a future patch). These functions are meant as an alternative to the
`llvm.amdgcn.init.whole.wave` or `llvm.amdgcn.strict.wwm` intrinsics.

Whole wave functions will set EXEC to -1 in the prologue and restore the
original value of EXEC in the epilogue. They must have a special first
argument, `i1 %active`, that is going to be mapped to EXEC. They may
have either the default calling convention or amdgpu_gfx. The inactive
lanes need to be preserved for all registers used, active lanes only for
the CSRs.

At the IR level, arguments to a whole wave function (other than
`%active`) contain poison in their inactive lanes. Likewise, the return
value for the inactive lanes is poison.

This patch contains the following work:
* 2 new pseudos, SI_SETUP_WHOLE_WAVE_FUNC and SI_WHOLE_WAVE_FUNC_RETURN
  used for managing the EXEC mask. SI_SETUP_WHOLE_WAVE_FUNC will return
  a SReg_1 representing `%active`, which needs to be passed into
  SI_WHOLE_WAVE_FUNC_RETURN.
* SelectionDAG support for generating these 2 new pseudos and the
  special handling of %active. Since the return may be in a different
  basic block, it's difficult to add the virtual reg for %active to
  SI_WHOLE_WAVE_FUNC_RETURN, so we initially generate an IMPLICIT_DEF
  which is later replaced via a custom inserter.
* Expansion of the 2 pseudos during prolog/epilog insertion. PEI also
  marks any used VGPRs are WWM registers, which are then spilled and
  restored with the usual logic.

I'm still working on the GlobalISel support and on adding some docs in
AMDGPUUsage.rst.

Future patches will include the `llvm.amdgcn.call.whole.wave` intrinsic,
a codegen prepare patch that looks for the callees of that intrinsic and
marks them as whole wave functions, and probably a lot of optimization
work.
This reverts commit c6e9211d5644061521cbce8edac7c475c83b01d6.
@rovka rovka force-pushed the users/rovka/whole-wave-funcs branch from a666b2d to 8ea4ac9 Compare June 27, 2025 11:59
@@ -1662,6 +1714,21 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
if (MFI->isEntryFunction())
return;

if (MFI->isWholeWaveFunction()) {
// In practice, all the VGPRs are WWM registers, and we will need to save at
// least their inactive lanes. Add them to WWMReservedRegs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we considering all VGPRs to be WWM regs in this calling convention?

@@ -564,6 +564,10 @@ declare riscv_vls_cc(32768) void @riscv_vls_cc_32768()
; CHECK: declare riscv_vls_cc(32768) void @riscv_vls_cc_32768()
declare riscv_vls_cc(65536) void @riscv_vls_cc_65536()
; CHECK: declare riscv_vls_cc(65536) void @riscv_vls_cc_65536()
declare cc124 void @f.cc124(i1)
; CHECK: declare amdgpu_gfx_whole_wave void @f.cc124(i1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does cc124 become amdgpu_gfx_whole_wave here?

@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Diana Picus (rovka)

Changes

Whole wave functions are functions that will run with a full EXEC mask.
They will not be invoked directly, but instead will be launched by way
of a new intrinsic, llvm.amdgcn.call.whole.wave (to be added in
a future patch). These functions are meant as an alternative to the
llvm.amdgcn.init.whole.wave or llvm.amdgcn.strict.wwm intrinsics.

Whole wave functions will set EXEC to -1 in the prologue and restore the
original value of EXEC in the epilogue. They must have a special first
argument, i1 %active, that is going to be mapped to EXEC. They may
have either the default calling convention or amdgpu_gfx. The inactive
lanes need to be preserved for all registers used, active lanes only for
the CSRs.

At the IR level, arguments to a whole wave function (other than
%active) contain poison in their inactive lanes. Likewise, the return
value for the inactive lanes is poison.

This patch contains the following work:

  • 2 new pseudos, SI_SETUP_WHOLE_WAVE_FUNC and SI_WHOLE_WAVE_FUNC_RETURN
    used for managing the EXEC mask. SI_SETUP_WHOLE_WAVE_FUNC will return
    a SReg_1 representing %active, which needs to be passed into
    SI_WHOLE_WAVE_FUNC_RETURN.
  • SelectionDAG support for generating these 2 new pseudos and the
    special handling of %active. Since the return may be in a different
    basic block, it's difficult to add the virtual reg for %active to
    SI_WHOLE_WAVE_FUNC_RETURN, so we initially generate an IMPLICIT_DEF
    which is later replaced via a custom inserter.
  • Expansion of the 2 pseudos during prolog/epilog insertion. PEI also
    marks any used VGPRs are WWM registers, which are then spilled and
    restored with the usual logic.

Future patches will include the llvm.amdgcn.call.whole.wave intrinsic
and a lot of optimization work (especially in order to reduce spills around
function calls).


Patch is 212.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145858.diff

41 Files Affected:

  • (modified) llvm/docs/AMDGPUUsage.rst (+14)
  • (modified) llvm/include/llvm/AsmParser/LLToken.h (+1)
  • (modified) llvm/include/llvm/IR/CallingConv.h (+8)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+3)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+3)
  • (modified) llvm/lib/IR/Function.cpp (+1)
  • (modified) llvm/lib/IR/Verifier.cpp (+10)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (+30-3)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h (+3)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUGISel.td (+4)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+4)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h (+6)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td (+11)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+4)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+4)
  • (modified) llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp (+1-1)
  • (modified) llvm/lib/Target/AMDGPU/SIFrameLowering.cpp (+77-10)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+34-6)
  • (modified) llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.cpp (+14)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.h (+2)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+40)
  • (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp (+7-2)
  • (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h (+6)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp (+2)
  • (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+2-1)
  • (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUPALMetadata.cpp (+1)
  • (modified) llvm/test/Bitcode/compatibility.ll (+4)
  • (added) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-whole-wave-functions.mir (+40)
  • (added) llvm/test/CodeGen/AMDGPU/irtranslator-whole-wave-functions.ll (+103)
  • (added) llvm/test/CodeGen/AMDGPU/isel-whole-wave-functions.ll (+190)
  • (added) llvm/test/CodeGen/AMDGPU/whole-wave-functions-pei.mir (+448)
  • (added) llvm/test/CodeGen/AMDGPU/whole-wave-functions.ll (+2414)
  • (modified) llvm/test/CodeGen/MIR/AMDGPU/long-branch-reg-all-sgpr-used.ll (+2)
  • (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-after-pei.ll (+1)
  • (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-long-branch-reg-debug.ll (+1)
  • (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-long-branch-reg.ll (+1)
  • (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-no-ir.mir (+4)
  • (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info.ll (+4)
  • (modified) llvm/test/Verifier/amdgpu-cc.ll (+33)
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index c5b9bd9de66e1..19357635ecfc1 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1844,6 +1844,20 @@ The AMDGPU backend supports the following calling conventions:
                                      ..TODO::
                                      Describe.
 
+     ``amdgpu_gfx_whole_wave``       Used for AMD graphics targets. Functions with this calling convention
+                                     cannot be used as entry points. They must have an i1 as the first argument,
+                                     which will be mapped to the value of EXEC on entry into the function. Other
+                                     arguments will contain poison in their inactive lanes. Similarly, the return
+                                     value for the inactive lanes is poison.
+
+                                     The function will run with all lanes enabled, i.e. EXEC will be set to -1 in the
+                                     prologue and restored to its original value in the epilogue. The inactive lanes
+                                     will be preserved for all the registers used by the function. Active lanes only
+                                     will only be preserved for the callee saved registers.
+
+                                     In all other respects, functions with this calling convention behave like
+                                     ``amdgpu_gfx`` functions.
+
      ``amdgpu_gs``                   Used for Mesa/AMDPAL geometry shaders.
                                      ..TODO::
                                      Describe.
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index c7e4bdf3ff811..a2311d2ac285d 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -181,6 +181,7 @@ enum Kind {
   kw_amdgpu_cs_chain_preserve,
   kw_amdgpu_kernel,
   kw_amdgpu_gfx,
+  kw_amdgpu_gfx_whole_wave,
   kw_tailcc,
   kw_m68k_rtdcc,
   kw_graalcc,
diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
index d68491eb5535c..ef761eb1aed73 100644
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -284,6 +284,9 @@ namespace CallingConv {
     RISCV_VLSCall_32768 = 122,
     RISCV_VLSCall_65536 = 123,
 
+    // Calling convention for AMDGPU whole wave functions.
+    AMDGPU_Gfx_WholeWave = 124,
+
     /// The highest possible ID. Must be some 2^k - 1.
     MaxID = 1023
   };
@@ -294,8 +297,13 @@ namespace CallingConv {
 /// directly or indirectly via a call-like instruction.
 constexpr bool isCallableCC(CallingConv::ID CC) {
   switch (CC) {
+  // Called with special intrinsics:
+  // llvm.amdgcn.cs.chain
   case CallingConv::AMDGPU_CS_Chain:
   case CallingConv::AMDGPU_CS_ChainPreserve:
+  // llvm.amdgcn.call.whole.wave
+  case CallingConv::AMDGPU_Gfx_WholeWave:
+  // Hardware entry points:
   case CallingConv::AMDGPU_CS:
   case CallingConv::AMDGPU_ES:
   case CallingConv::AMDGPU_GS:
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index ce813e1d7b1c4..520c6a00a9c07 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -679,6 +679,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(amdgpu_cs_chain_preserve);
   KEYWORD(amdgpu_kernel);
   KEYWORD(amdgpu_gfx);
+  KEYWORD(amdgpu_gfx_whole_wave);
   KEYWORD(tailcc);
   KEYWORD(m68k_rtdcc);
   KEYWORD(graalcc);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index c5e166cef6da6..b09696497cc4e 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -2274,6 +2274,9 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) {
     CC = CallingConv::AMDGPU_CS_ChainPreserve;
     break;
   case lltok::kw_amdgpu_kernel:  CC = CallingConv::AMDGPU_KERNEL; break;
+  case lltok::kw_amdgpu_gfx_whole_wave:
+    CC = CallingConv::AMDGPU_Gfx_WholeWave;
+    break;
   case lltok::kw_tailcc:         CC = CallingConv::Tail; break;
   case lltok::kw_m68k_rtdcc:     CC = CallingConv::M68k_RTD; break;
   case lltok::kw_graalcc:        CC = CallingConv::GRAAL; break;
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 7828ba45ec27f..3ce892ecbff19 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -404,6 +404,9 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) {
     break;
   case CallingConv::AMDGPU_KERNEL: Out << "amdgpu_kernel"; break;
   case CallingConv::AMDGPU_Gfx:    Out << "amdgpu_gfx"; break;
+  case CallingConv::AMDGPU_Gfx_WholeWave:
+    Out << "amdgpu_gfx_whole_wave";
+    break;
   case CallingConv::M68k_RTD:      Out << "m68k_rtdcc"; break;
   case CallingConv::RISCV_VectorCall:
     Out << "riscv_vector_cc";
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 3e7fcbb983738..b1e8dd716063f 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1226,6 +1226,7 @@ bool llvm::CallingConv::supportsNonVoidReturnType(CallingConv::ID CC) {
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::WASM_EmscriptenInvoke:
   case CallingConv::AMDGPU_Gfx:
+  case CallingConv::AMDGPU_Gfx_WholeWave:
   case CallingConv::M68k_INTR:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 9cab88b09779a..32ce1880f2fdd 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2975,6 +2975,16 @@ void Verifier::visitFunction(const Function &F) {
           "perfect forwarding!",
           &F);
     break;
+  case CallingConv::AMDGPU_Gfx_WholeWave:
+    Check(F.arg_size() != 0 && F.arg_begin()->getType()->isIntegerTy(1),
+          "Calling convention requires first argument to be i1", &F);
+    Check(!F.arg_begin()->hasInRegAttr(),
+          "Calling convention requires first argument to not be inreg", &F);
+    Check(!F.isVarArg(),
+          "Calling convention does not support varargs or "
+          "perfect forwarding!",
+          &F);
+    break;
   }
 
   // Check that the argument values match the function type for this function...
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
index 14101e57f5143..b4ea3c81b3b6e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -374,8 +374,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
     return true;
   }
 
-  unsigned ReturnOpc =
-      IsShader ? AMDGPU::SI_RETURN_TO_EPILOG : AMDGPU::SI_RETURN;
+  const bool IsWholeWave = MFI->isWholeWaveFunction();
+  unsigned ReturnOpc = IsWholeWave ? AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN
+                       : IsShader  ? AMDGPU::SI_RETURN_TO_EPILOG
+                                   : AMDGPU::SI_RETURN;
   auto Ret = B.buildInstrNoInsert(ReturnOpc);
 
   if (!FLI.CanLowerReturn)
@@ -383,6 +385,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
   else if (!lowerReturnVal(B, Val, VRegs, Ret))
     return false;
 
+  if (IsWholeWave) {
+    addOriginalExecToReturn(B.getMF(), Ret);
+  }
+
   // TODO: Handle CalleeSavedRegsViaCopy.
 
   B.insertInstr(Ret);
@@ -632,6 +638,17 @@ bool AMDGPUCallLowering::lowerFormalArguments(
     if (DL.getTypeStoreSize(Arg.getType()) == 0)
       continue;
 
+    if (Info->isWholeWaveFunction() && Idx == 0) {
+      assert(VRegs[Idx].size() == 1 && "Expected only one register");
+
+      // The first argument for whole wave functions is the original EXEC value.
+      B.buildInstr(AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP)
+          .addDef(VRegs[Idx][0]);
+
+      ++Idx;
+      continue;
+    }
+
     const bool InReg = Arg.hasAttribute(Attribute::InReg);
 
     if (Arg.hasAttribute(Attribute::SwiftSelf) ||
@@ -1347,6 +1364,7 @@ bool AMDGPUCallLowering::lowerTailCall(
   SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
 
   if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
+      Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave &&
       !AMDGPU::isChainCC(Info.CallConv)) {
     // With a fixed ABI, allocate fixed registers before user arguments.
     if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
@@ -1524,7 +1542,8 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // after the ordinary user argument registers.
   SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
 
-  if (Info.CallConv != CallingConv::AMDGPU_Gfx) {
+  if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
+      Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave) {
     // With a fixed ABI, allocate fixed registers before user arguments.
     if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
       return false;
@@ -1592,3 +1611,11 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   return true;
 }
+
+void AMDGPUCallLowering::addOriginalExecToReturn(
+    MachineFunction &MF, MachineInstrBuilder &Ret) const {
+  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
+  const SIInstrInfo *TII = ST.getInstrInfo();
+  const MachineInstr *Setup = TII->getWholeWaveFunctionSetup(MF);
+  Ret.addReg(Setup->getOperand(0).getReg());
+}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
index a6e801f2a547b..e0033d59d10bb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
@@ -37,6 +37,9 @@ class AMDGPUCallLowering final : public CallLowering {
   bool lowerReturnVal(MachineIRBuilder &B, const Value *Val,
                       ArrayRef<Register> VRegs, MachineInstrBuilder &Ret) const;
 
+  void addOriginalExecToReturn(MachineFunction &MF,
+                               MachineInstrBuilder &Ret) const;
+
 public:
   AMDGPUCallLowering(const AMDGPUTargetLowering &TLI);
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..c5063c4de4ad3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -300,6 +300,10 @@ def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_SSHORT, SIsbuffer_load_short>;
 def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_USHORT, SIsbuffer_load_ushort>;
 def : GINodeEquiv<G_AMDGPU_S_BUFFER_PREFETCH, SIsbuffer_prefetch>;
 
+def : GINodeEquiv<G_AMDGPU_WHOLE_WAVE_FUNC_SETUP, AMDGPUwhole_wave_setup>;
+// G_AMDGPU_WHOLE_WAVE_FUNC_RETURN is simpler than AMDGPUwhole_wave_return,
+// so we don't mark it as equivalent.
+
 class GISelSop2Pat <
   SDPatternOperator node,
   Instruction inst,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index d75c7a178b4a8..0421ed87e61f4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -1138,6 +1138,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::Cold:
     return CC_AMDGPU_Func;
   case CallingConv::AMDGPU_Gfx:
+  case CallingConv::AMDGPU_Gfx_WholeWave:
     return CC_SI_Gfx;
   case CallingConv::AMDGPU_KERNEL:
   case CallingConv::SPIR_KERNEL:
@@ -1163,6 +1164,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC,
   case CallingConv::AMDGPU_LS:
     return RetCC_SI_Shader;
   case CallingConv::AMDGPU_Gfx:
+  case CallingConv::AMDGPU_Gfx_WholeWave:
     return RetCC_SI_Gfx;
   case CallingConv::C:
   case CallingConv::Fast:
@@ -5777,6 +5779,8 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(BUFFER_ATOMIC_FMIN)
   NODE_NAME_CASE(BUFFER_ATOMIC_FMAX)
   NODE_NAME_CASE(BUFFER_ATOMIC_COND_SUB_U32)
+  NODE_NAME_CASE(WHOLE_WAVE_SETUP)
+  NODE_NAME_CASE(WHOLE_WAVE_RETURN)
   }
   return nullptr;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 0dd2183b72b24..5716711de3402 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -607,6 +607,12 @@ enum NodeType : unsigned {
   BUFFER_ATOMIC_FMAX,
   BUFFER_ATOMIC_COND_SUB_U32,
   LAST_MEMORY_OPCODE = BUFFER_ATOMIC_COND_SUB_U32,
+
+  // Set up a whole wave function.
+  WHOLE_WAVE_SETUP,
+
+  // Return from a whole wave function.
+  WHOLE_WAVE_RETURN,
 };
 
 } // End namespace AMDGPUISD
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
index ce58e93a15207..e305f08925cc6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
@@ -348,6 +348,17 @@ def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2",
 
 def AMDGPUperm_impl : SDNode<"AMDGPUISD::PERM", AMDGPUDTIntTernaryOp, []>;
 
+// Marks the entry into a whole wave function.
+def AMDGPUwhole_wave_setup : SDNode<
+  "AMDGPUISD::WHOLE_WAVE_SETUP", SDTypeProfile<1, 0, [SDTCisInt<0>]>,
+  [SDNPHasChain, SDNPSideEffect]>;
+
+// Marks the return from a whole wave function.
+def AMDGPUwhole_wave_return : SDNode<
+  "AMDGPUISD::WHOLE_WAVE_RETURN", SDTNone,
+  [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
+>;
+
 // SI+ export
 def AMDGPUExportOp : SDTypeProfile<0, 8, [
   SDTCisInt<0>,       // i8 tgt
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index b632b16f5c198..d86e7735a07bc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -4141,6 +4141,10 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
     return true;
   case AMDGPU::G_AMDGPU_WAVE_ADDRESS:
     return selectWaveAddress(I);
+  case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN: {
+    I.setDesc(TII.get(AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN));
+    return true;
+  }
   case AMDGPU::G_STACKRESTORE:
     return selectStackRestore(I);
   case AMDGPU::G_PHI:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index b20760c356263..a07699ae1eb23 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -5458,6 +5458,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case AMDGPU::G_PREFETCH:
     OpdsMapping[0] = getSGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
     break;
+  case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP:
+  case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN:
+    OpdsMapping[0] = AMDGPU::getValueMapping(AMDGPU::VCCRegBankID, 1);
+    break;
   }
 
   return getInstructionMapping(/*ID*/1, /*Cost*/1,
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index bc95d3f040e1d..098c2dc2405df 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -3155,7 +3155,7 @@ bool GCNHazardRecognizer::fixRequiredExportPriority(MachineInstr *MI) {
   // Check entry priority at each export (as there will only be a few).
   // Note: amdgpu_gfx can only be a callee, so defer to caller setprio.
   bool Changed = false;
-  if (CC != CallingConv::AMDGPU_Gfx)
+  if (CC != CallingConv::AMDGPU_Gfx && CC != CallingConv::AMDGPU_Gfx_WholeWave)
     Changed = ensureEntrySetPrio(MF, NormalPriority, TII);
 
   auto NextMI = std::next(It);
diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
index 6a3867937d57f..b88df50c6c999 100644
--- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
@@ -946,8 +946,18 @@ static Register buildScratchExecCopy(LiveRegUnits &LiveUnits,
 
   initLiveUnits(LiveUnits, TRI, FuncInfo, MF, MBB, MBBI, IsProlog);
 
-  ScratchExecCopy = findScratchNonCalleeSaveRegister(
-      MRI, LiveUnits, *TRI.getWaveMaskRegClass());
+  if (FuncInfo->isWholeWaveFunction()) {
+    // Whole wave functions already have a copy of the original EXEC mask that
+    // we can use.
+    assert(IsProlog && "Epilog should look at return, not setup");
+    ScratchExecCopy =
+        TII->getWholeWaveFunctionSetup(MF)->getOperand(0).getReg();
+    assert(ScratchExecCopy && "Couldn't find copy of EXEC");
+  } else {
+    ScratchExecCopy = findScratchNonCalleeSaveRegister(
+        MRI, LiveUnits, *TRI.getWaveMaskRegClass());
+  }
+
   if (!ScratchExecCopy)
     report_fatal_error("failed to find free scratch register");
 
@@ -996,10 +1006,15 @@ void SIFrameLowering::emitCSRSpillStores(
       };
 
   StoreWWMRegisters(WWMScratchRegs);
+
+  auto EnableAllLanes = [&]() {
+    unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
+    BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
+  };
+
   if (!WWMCalleeSavedRegs.empty()) {
     if (ScratchExecCopy) {
-      unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
-      BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
+      EnableAllLanes();
     } else {
       ScratchExecCopy = buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
                                              /*IsProlog*/ true,
@@ -1008,7 +1023,18 @@ void SIFrameLowering::emitCSRSpillStores(
   }
 
   StoreWWMRegisters(WWMCalleeSavedRegs);
-  if (ScratchExecCopy) {
+  if (FuncInfo->isWholeWaveFunction()) {
+    // SI_WHOLE_WAVE_FUNC_SETUP has outlived its purpose, so we can remove
+    // it now. If we have already saved some WWM CSR registers, then the EXEC is
+    // already -1 and we don't need to do anything else. Otherwise, set EXEC to
+    // -1 here.
+    if (!ScratchExecCopy)
+      buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, /*IsProlog*/ true,
+                           /*EnableInactiveLanes*/ true);
+    else if (WWMCalleeSavedRegs.empty())
+      EnableAllLanes();
+    TII->getWholeWaveFunctionSetup(MF)->eraseFromParent();
+  } else if (ScratchExecCopy) {
     // FIXME: Split block and make terminator.
     unsigned ExecMov = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
     BuildMI(MBB, MBBI, DL, TII->get(ExecMov), TRI.getExec())
@@ -1083,11 +1109,6 @@ void SIFrameLowering::emitCSRSpillRestores(
   Register ScratchExecCopy;
   SmallVector<std::pair<Register, int>, 2> WWMCalleeSavedRegs, WWMScratchRegs;
   FuncInfo->splitWWMSpillRegisters(MF, WWMCalleeSavedRegs, WWMScratchRegs);
-  if (!WWMScratchRegs.empty())
-    ScratchExecCopy =
-        buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
-                             /*IsProlog*/ false, /*EnableInactiveLanes*/ true);
-
   auto RestoreWWMRegisters =
       [&](SmallVectorImpl<std::pair<Register, int>> &WWMRegs) {
         for (const auto &Reg : WWMRegs) {
@@ -1098,6 +1119,36 @@ void SIFrameLowering::emitCSRSpillRestores(
         }
       };
 
+  if (FuncInfo->isWholeWaveFunction()) {
+    // For whole wave functions, the EXEC is already -1 at this point.
+    // Therefore, we can restore the CSR WWM registers right away.
+    RestoreWWMRegisters(WWMCalleeSavedRegs);
+
+    // The original EXEC is the first operand of the return instruction.
+    const MachineInstr &Return = MBB.instr_back();
+    assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
+           "Unexpected return inst");
+    Register OrigExec = Return.getOperand(0).getReg();
+
+    if (!WWMScratchRegs.empty()) {
+      unsigned XorOpc = ST.isWave32() ? AMDGPU::S_XOR_B32 : AMDGPU::S_XOR_B64;
+      BuildMI(MBB, MBBI, DL, TII->get(XorOpc), TRI.getExec())
+          .addReg(OrigExec)
+          .addImm(-1);
+      RestoreWWMRegisters(WWMScratchRegs);
+    }
+
+    // Restore original EXEC.
+    unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
+    BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
+    return;
+  }
+
+  if (!WWMScratchRegs.empty())
+    ScratchExecCopy =
+        buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
+                             /*IsProlog*/ false, /*EnableInactiveLanes*/ true);
+
   RestoreWWMRegisters(WWMScratchRegs);
   if (!WWMCalleeSavedRegs.empty()) {
     if (ScratchExecCopy) {
@@ -1634,6 +1685,7 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
         NeedExecCopyReservedReg = true;
       else if (MI....
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants