[AMDGPU] ISel & PEI for whole wave functions #145858
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking.
What is the fundamental motivation of doing this?
The fundamental motivation for doing this is that the WWM intrinsics are fragile and don't handle control flow very well. We have an increasing number of use cases where we need to have some control flow inside a WWM region, and that isn't supported that well. For one of these use cases we've introduced the
Whole wave functions are functions that will run with a full EXEC mask. They will not be invoked directly, but instead will be launched by way of a new intrinsic, `llvm.amdgcn.call.whole.wave` (to be added in a future patch). These functions are meant as an alternative to the `llvm.amdgcn.init.whole.wave` and `llvm.amdgcn.strict.wwm` intrinsics.

Whole wave functions set EXEC to -1 in the prologue and restore the original value of EXEC in the epilogue. They must have a special first argument, `i1 %active`, which is mapped to the value of EXEC on entry. They use the new `amdgpu_gfx_whole_wave` calling convention, which otherwise behaves like `amdgpu_gfx`. The inactive lanes need to be preserved for all registers used by the function; the active lanes only need to be preserved for the CSRs.

At the IR level, arguments to a whole wave function (other than `%active`) contain poison in their inactive lanes. Likewise, the return value for the inactive lanes is poison.

This patch contains the following work:
* 2 new pseudos, SI_WHOLE_WAVE_FUNC_SETUP and SI_WHOLE_WAVE_FUNC_RETURN, used for managing the EXEC mask. SI_WHOLE_WAVE_FUNC_SETUP returns an SReg_1 representing `%active`, which needs to be passed into SI_WHOLE_WAVE_FUNC_RETURN.
* SelectionDAG support for generating these 2 new pseudos and the special handling of `%active`. Since the return may be in a different basic block, it's difficult to add the virtual reg for `%active` to SI_WHOLE_WAVE_FUNC_RETURN, so we initially generate an IMPLICIT_DEF which is later replaced via a custom inserter.
* Expansion of the 2 pseudos during prolog/epilog insertion. PEI also marks any used VGPRs as WWM registers, which are then spilled and restored with the usual logic.

I'm still working on the GlobalISel support and on adding some docs in AMDGPUUsage.rst. Future patches will include the `llvm.amdgcn.call.whole.wave` intrinsic, a codegen prepare patch that looks for the callees of that intrinsic and marks them as whole wave functions, and probably a lot of optimization work.
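For illustration, here is a minimal sketch of what a whole wave function could look like at the IR level with this patch applied. The function name and body are hypothetical; only the calling convention, the `i1 %active` first argument, and the poison semantics are taken from the patch.

; Hypothetical example. The verifier changes in this patch require the first
; argument to be i1 (and not inreg) and reject varargs.
define amdgpu_gfx_whole_wave i32 @ww_example(i1 %active, i32 %x) {
  ; The body runs with all lanes enabled (EXEC = -1). %active holds the
  ; original EXEC mask, and %x contains poison in the lanes that were
  ; inactive on entry, as does the return value.
  %r = add i32 %x, 1
  ret i32 %r
}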
@@ -1662,6 +1714,21 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
   if (MFI->isEntryFunction())
     return;

   if (MFI->isWholeWaveFunction()) {
     // In practice, all the VGPRs are WWM registers, and we will need to save at
     // least their inactive lanes. Add them to WWMReservedRegs.
Are we considering all VGPRs to be WWM regs in this calling convention?
@@ -564,6 +564,10 @@ declare riscv_vls_cc(32768) void @riscv_vls_cc_32768()
 ; CHECK: declare riscv_vls_cc(32768) void @riscv_vls_cc_32768()
 declare riscv_vls_cc(65536) void @riscv_vls_cc_65536()
 ; CHECK: declare riscv_vls_cc(65536) void @riscv_vls_cc_65536()
 declare cc124 void @f.cc124(i1)
 ; CHECK: declare amdgpu_gfx_whole_wave void @f.cc124(i1)
How does cc124 become amdgpu_gfx_whole_wave here?
@llvm/pr-subscribers-backend-amdgpu
Author: Diana Picus (rovka)
Changes: see the description above.
Patch is 212.50 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/145858.diff
41 Files Affected:
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index c5b9bd9de66e1..19357635ecfc1 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1844,6 +1844,20 @@ The AMDGPU backend supports the following calling conventions:
..TODO::
Describe.
+ ``amdgpu_gfx_whole_wave`` Used for AMD graphics targets. Functions with this calling convention
+ cannot be used as entry points. They must have an i1 as the first argument,
+ which will be mapped to the value of EXEC on entry into the function. Other
+ arguments will contain poison in their inactive lanes. Similarly, the return
+ value for the inactive lanes is poison.
+
+ The function will run with all lanes enabled, i.e. EXEC will be set to -1 in the
+ prologue and restored to its original value in the epilogue. The inactive lanes
+ will be preserved for all the registers used by the function. Active lanes
+ will only be preserved for the callee saved registers.
+
+ In all other respects, functions with this calling convention behave like
+ ``amdgpu_gfx`` functions.
+
``amdgpu_gs`` Used for Mesa/AMDPAL geometry shaders.
..TODO::
Describe.
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index c7e4bdf3ff811..a2311d2ac285d 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -181,6 +181,7 @@ enum Kind {
kw_amdgpu_cs_chain_preserve,
kw_amdgpu_kernel,
kw_amdgpu_gfx,
+ kw_amdgpu_gfx_whole_wave,
kw_tailcc,
kw_m68k_rtdcc,
kw_graalcc,
diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
index d68491eb5535c..ef761eb1aed73 100644
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -284,6 +284,9 @@ namespace CallingConv {
RISCV_VLSCall_32768 = 122,
RISCV_VLSCall_65536 = 123,
+ // Calling convention for AMDGPU whole wave functions.
+ AMDGPU_Gfx_WholeWave = 124,
+
/// The highest possible ID. Must be some 2^k - 1.
MaxID = 1023
};
@@ -294,8 +297,13 @@ namespace CallingConv {
/// directly or indirectly via a call-like instruction.
constexpr bool isCallableCC(CallingConv::ID CC) {
switch (CC) {
+ // Called with special intrinsics:
+ // llvm.amdgcn.cs.chain
case CallingConv::AMDGPU_CS_Chain:
case CallingConv::AMDGPU_CS_ChainPreserve:
+ // llvm.amdgcn.call.whole.wave
+ case CallingConv::AMDGPU_Gfx_WholeWave:
+ // Hardware entry points:
case CallingConv::AMDGPU_CS:
case CallingConv::AMDGPU_ES:
case CallingConv::AMDGPU_GS:
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index ce813e1d7b1c4..520c6a00a9c07 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -679,6 +679,7 @@ lltok::Kind LLLexer::LexIdentifier() {
KEYWORD(amdgpu_cs_chain_preserve);
KEYWORD(amdgpu_kernel);
KEYWORD(amdgpu_gfx);
+ KEYWORD(amdgpu_gfx_whole_wave);
KEYWORD(tailcc);
KEYWORD(m68k_rtdcc);
KEYWORD(graalcc);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index c5e166cef6da6..b09696497cc4e 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -2274,6 +2274,9 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) {
CC = CallingConv::AMDGPU_CS_ChainPreserve;
break;
case lltok::kw_amdgpu_kernel: CC = CallingConv::AMDGPU_KERNEL; break;
+ case lltok::kw_amdgpu_gfx_whole_wave:
+ CC = CallingConv::AMDGPU_Gfx_WholeWave;
+ break;
case lltok::kw_tailcc: CC = CallingConv::Tail; break;
case lltok::kw_m68k_rtdcc: CC = CallingConv::M68k_RTD; break;
case lltok::kw_graalcc: CC = CallingConv::GRAAL; break;
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 7828ba45ec27f..3ce892ecbff19 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -404,6 +404,9 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) {
break;
case CallingConv::AMDGPU_KERNEL: Out << "amdgpu_kernel"; break;
case CallingConv::AMDGPU_Gfx: Out << "amdgpu_gfx"; break;
+ case CallingConv::AMDGPU_Gfx_WholeWave:
+ Out << "amdgpu_gfx_whole_wave";
+ break;
case CallingConv::M68k_RTD: Out << "m68k_rtdcc"; break;
case CallingConv::RISCV_VectorCall:
Out << "riscv_vector_cc";
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 3e7fcbb983738..b1e8dd716063f 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1226,6 +1226,7 @@ bool llvm::CallingConv::supportsNonVoidReturnType(CallingConv::ID CC) {
case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::WASM_EmscriptenInvoke:
case CallingConv::AMDGPU_Gfx:
+ case CallingConv::AMDGPU_Gfx_WholeWave:
case CallingConv::M68k_INTR:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 9cab88b09779a..32ce1880f2fdd 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2975,6 +2975,16 @@ void Verifier::visitFunction(const Function &F) {
"perfect forwarding!",
&F);
break;
+ case CallingConv::AMDGPU_Gfx_WholeWave:
+ Check(F.arg_size() != 0 && F.arg_begin()->getType()->isIntegerTy(1),
+ "Calling convention requires first argument to be i1", &F);
+ Check(!F.arg_begin()->hasInRegAttr(),
+ "Calling convention requires first argument to not be inreg", &F);
+ Check(!F.isVarArg(),
+ "Calling convention does not support varargs or "
+ "perfect forwarding!",
+ &F);
+ break;
}
// Check that the argument values match the function type for this function...
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
index 14101e57f5143..b4ea3c81b3b6e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -374,8 +374,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
return true;
}
- unsigned ReturnOpc =
- IsShader ? AMDGPU::SI_RETURN_TO_EPILOG : AMDGPU::SI_RETURN;
+ const bool IsWholeWave = MFI->isWholeWaveFunction();
+ unsigned ReturnOpc = IsWholeWave ? AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN
+ : IsShader ? AMDGPU::SI_RETURN_TO_EPILOG
+ : AMDGPU::SI_RETURN;
auto Ret = B.buildInstrNoInsert(ReturnOpc);
if (!FLI.CanLowerReturn)
@@ -383,6 +385,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
else if (!lowerReturnVal(B, Val, VRegs, Ret))
return false;
+ if (IsWholeWave) {
+ addOriginalExecToReturn(B.getMF(), Ret);
+ }
+
// TODO: Handle CalleeSavedRegsViaCopy.
B.insertInstr(Ret);
@@ -632,6 +638,17 @@ bool AMDGPUCallLowering::lowerFormalArguments(
if (DL.getTypeStoreSize(Arg.getType()) == 0)
continue;
+ if (Info->isWholeWaveFunction() && Idx == 0) {
+ assert(VRegs[Idx].size() == 1 && "Expected only one register");
+
+ // The first argument for whole wave functions is the original EXEC value.
+ B.buildInstr(AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP)
+ .addDef(VRegs[Idx][0]);
+
+ ++Idx;
+ continue;
+ }
+
const bool InReg = Arg.hasAttribute(Attribute::InReg);
if (Arg.hasAttribute(Attribute::SwiftSelf) ||
@@ -1347,6 +1364,7 @@ bool AMDGPUCallLowering::lowerTailCall(
SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
+ Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave &&
!AMDGPU::isChainCC(Info.CallConv)) {
// With a fixed ABI, allocate fixed registers before user arguments.
if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
@@ -1524,7 +1542,8 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// after the ordinary user argument registers.
SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
- if (Info.CallConv != CallingConv::AMDGPU_Gfx) {
+ if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
+ Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave) {
// With a fixed ABI, allocate fixed registers before user arguments.
if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
return false;
@@ -1592,3 +1611,11 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
return true;
}
+
+void AMDGPUCallLowering::addOriginalExecToReturn(
+ MachineFunction &MF, MachineInstrBuilder &Ret) const {
+ const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
+ const SIInstrInfo *TII = ST.getInstrInfo();
+ const MachineInstr *Setup = TII->getWholeWaveFunctionSetup(MF);
+ Ret.addReg(Setup->getOperand(0).getReg());
+}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
index a6e801f2a547b..e0033d59d10bb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
@@ -37,6 +37,9 @@ class AMDGPUCallLowering final : public CallLowering {
bool lowerReturnVal(MachineIRBuilder &B, const Value *Val,
ArrayRef<Register> VRegs, MachineInstrBuilder &Ret) const;
+ void addOriginalExecToReturn(MachineFunction &MF,
+ MachineInstrBuilder &Ret) const;
+
public:
AMDGPUCallLowering(const AMDGPUTargetLowering &TLI);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..c5063c4de4ad3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -300,6 +300,10 @@ def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_SSHORT, SIsbuffer_load_short>;
def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_USHORT, SIsbuffer_load_ushort>;
def : GINodeEquiv<G_AMDGPU_S_BUFFER_PREFETCH, SIsbuffer_prefetch>;
+def : GINodeEquiv<G_AMDGPU_WHOLE_WAVE_FUNC_SETUP, AMDGPUwhole_wave_setup>;
+// G_AMDGPU_WHOLE_WAVE_FUNC_RETURN is simpler than AMDGPUwhole_wave_return,
+// so we don't mark it as equivalent.
+
class GISelSop2Pat <
SDPatternOperator node,
Instruction inst,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index d75c7a178b4a8..0421ed87e61f4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -1138,6 +1138,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForCall(CallingConv::ID CC,
case CallingConv::Cold:
return CC_AMDGPU_Func;
case CallingConv::AMDGPU_Gfx:
+ case CallingConv::AMDGPU_Gfx_WholeWave:
return CC_SI_Gfx;
case CallingConv::AMDGPU_KERNEL:
case CallingConv::SPIR_KERNEL:
@@ -1163,6 +1164,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC,
case CallingConv::AMDGPU_LS:
return RetCC_SI_Shader;
case CallingConv::AMDGPU_Gfx:
+ case CallingConv::AMDGPU_Gfx_WholeWave:
return RetCC_SI_Gfx;
case CallingConv::C:
case CallingConv::Fast:
@@ -5777,6 +5779,8 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(BUFFER_ATOMIC_FMIN)
NODE_NAME_CASE(BUFFER_ATOMIC_FMAX)
NODE_NAME_CASE(BUFFER_ATOMIC_COND_SUB_U32)
+ NODE_NAME_CASE(WHOLE_WAVE_SETUP)
+ NODE_NAME_CASE(WHOLE_WAVE_RETURN)
}
return nullptr;
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 0dd2183b72b24..5716711de3402 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -607,6 +607,12 @@ enum NodeType : unsigned {
BUFFER_ATOMIC_FMAX,
BUFFER_ATOMIC_COND_SUB_U32,
LAST_MEMORY_OPCODE = BUFFER_ATOMIC_COND_SUB_U32,
+
+ // Set up a whole wave function.
+ WHOLE_WAVE_SETUP,
+
+ // Return from a whole wave function.
+ WHOLE_WAVE_RETURN,
};
} // End namespace AMDGPUISD
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
index ce58e93a15207..e305f08925cc6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
@@ -348,6 +348,17 @@ def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2",
def AMDGPUperm_impl : SDNode<"AMDGPUISD::PERM", AMDGPUDTIntTernaryOp, []>;
+// Marks the entry into a whole wave function.
+def AMDGPUwhole_wave_setup : SDNode<
+ "AMDGPUISD::WHOLE_WAVE_SETUP", SDTypeProfile<1, 0, [SDTCisInt<0>]>,
+ [SDNPHasChain, SDNPSideEffect]>;
+
+// Marks the return from a whole wave function.
+def AMDGPUwhole_wave_return : SDNode<
+ "AMDGPUISD::WHOLE_WAVE_RETURN", SDTNone,
+ [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
+>;
+
// SI+ export
def AMDGPUExportOp : SDTypeProfile<0, 8, [
SDTCisInt<0>, // i8 tgt
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index b632b16f5c198..d86e7735a07bc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -4141,6 +4141,10 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
return true;
case AMDGPU::G_AMDGPU_WAVE_ADDRESS:
return selectWaveAddress(I);
+ case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN: {
+ I.setDesc(TII.get(AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN));
+ return true;
+ }
case AMDGPU::G_STACKRESTORE:
return selectStackRestore(I);
case AMDGPU::G_PHI:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index b20760c356263..a07699ae1eb23 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -5458,6 +5458,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case AMDGPU::G_PREFETCH:
OpdsMapping[0] = getSGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
break;
+ case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP:
+ case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN:
+ OpdsMapping[0] = AMDGPU::getValueMapping(AMDGPU::VCCRegBankID, 1);
+ break;
}
return getInstructionMapping(/*ID*/1, /*Cost*/1,
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index bc95d3f040e1d..098c2dc2405df 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -3155,7 +3155,7 @@ bool GCNHazardRecognizer::fixRequiredExportPriority(MachineInstr *MI) {
// Check entry priority at each export (as there will only be a few).
// Note: amdgpu_gfx can only be a callee, so defer to caller setprio.
bool Changed = false;
- if (CC != CallingConv::AMDGPU_Gfx)
+ if (CC != CallingConv::AMDGPU_Gfx && CC != CallingConv::AMDGPU_Gfx_WholeWave)
Changed = ensureEntrySetPrio(MF, NormalPriority, TII);
auto NextMI = std::next(It);
diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
index 6a3867937d57f..b88df50c6c999 100644
--- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
@@ -946,8 +946,18 @@ static Register buildScratchExecCopy(LiveRegUnits &LiveUnits,
initLiveUnits(LiveUnits, TRI, FuncInfo, MF, MBB, MBBI, IsProlog);
- ScratchExecCopy = findScratchNonCalleeSaveRegister(
- MRI, LiveUnits, *TRI.getWaveMaskRegClass());
+ if (FuncInfo->isWholeWaveFunction()) {
+ // Whole wave functions already have a copy of the original EXEC mask that
+ // we can use.
+ assert(IsProlog && "Epilog should look at return, not setup");
+ ScratchExecCopy =
+ TII->getWholeWaveFunctionSetup(MF)->getOperand(0).getReg();
+ assert(ScratchExecCopy && "Couldn't find copy of EXEC");
+ } else {
+ ScratchExecCopy = findScratchNonCalleeSaveRegister(
+ MRI, LiveUnits, *TRI.getWaveMaskRegClass());
+ }
+
if (!ScratchExecCopy)
report_fatal_error("failed to find free scratch register");
@@ -996,10 +1006,15 @@ void SIFrameLowering::emitCSRSpillStores(
};
StoreWWMRegisters(WWMScratchRegs);
+
+ auto EnableAllLanes = [&]() {
+ unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
+ BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
+ };
+
if (!WWMCalleeSavedRegs.empty()) {
if (ScratchExecCopy) {
- unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
- BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
+ EnableAllLanes();
} else {
ScratchExecCopy = buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
/*IsProlog*/ true,
@@ -1008,7 +1023,18 @@ void SIFrameLowering::emitCSRSpillStores(
}
StoreWWMRegisters(WWMCalleeSavedRegs);
- if (ScratchExecCopy) {
+ if (FuncInfo->isWholeWaveFunction()) {
+ // SI_WHOLE_WAVE_FUNC_SETUP has outlived its purpose, so we can remove
+ // it now. If we have already saved some WWM CSR registers, then the EXEC is
+ // already -1 and we don't need to do anything else. Otherwise, set EXEC to
+ // -1 here.
+ if (!ScratchExecCopy)
+ buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, /*IsProlog*/ true,
+ /*EnableInactiveLanes*/ true);
+ else if (WWMCalleeSavedRegs.empty())
+ EnableAllLanes();
+ TII->getWholeWaveFunctionSetup(MF)->eraseFromParent();
+ } else if (ScratchExecCopy) {
// FIXME: Split block and make terminator.
unsigned ExecMov = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
BuildMI(MBB, MBBI, DL, TII->get(ExecMov), TRI.getExec())
@@ -1083,11 +1109,6 @@ void SIFrameLowering::emitCSRSpillRestores(
Register ScratchExecCopy;
SmallVector<std::pair<Register, int>, 2> WWMCalleeSavedRegs, WWMScratchRegs;
FuncInfo->splitWWMSpillRegisters(MF, WWMCalleeSavedRegs, WWMScratchRegs);
- if (!WWMScratchRegs.empty())
- ScratchExecCopy =
- buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
- /*IsProlog*/ false, /*EnableInactiveLanes*/ true);
-
auto RestoreWWMRegisters =
[&](SmallVectorImpl<std::pair<Register, int>> &WWMRegs) {
for (const auto &Reg : WWMRegs) {
@@ -1098,6 +1119,36 @@ void SIFrameLowering::emitCSRSpillRestores(
}
};
+ if (FuncInfo->isWholeWaveFunction()) {
+ // For whole wave functions, the EXEC is already -1 at this point.
+ // Therefore, we can restore the CSR WWM registers right away.
+ RestoreWWMRegisters(WWMCalleeSavedRegs);
+
+ // The original EXEC is the first operand of the return instruction.
+ const MachineInstr &Return = MBB.instr_back();
+ assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
+ "Unexpected return inst");
+ Register OrigExec = Return.getOperand(0).getReg();
+
+ if (!WWMScratchRegs.empty()) {
+ unsigned XorOpc = ST.isWave32() ? AMDGPU::S_XOR_B32 : AMDGPU::S_XOR_B64;
+ BuildMI(MBB, MBBI, DL, TII->get(XorOpc), TRI.getExec())
+ .addReg(OrigExec)
+ .addImm(-1);
+ RestoreWWMRegisters(WWMScratchRegs);
+ }
+
+ // Restore original EXEC.
+ unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
+ BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
+ return;
+ }
+
+ if (!WWMScratchRegs.empty())
+ ScratchExecCopy =
+ buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
+ /*IsProlog*/ false, /*EnableInactiveLanes*/ true);
+
RestoreWWMRegisters(WWMScratchRegs);
if (!WWMCalleeSavedRegs.empty()) {
if (ScratchExecCopy) {
@@ -1634,6 +1685,7 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
NeedExecCopyReservedReg = true;
else if (MI....
[truncated]
Future patches will include the `llvm.amdgcn.call.whole.wave` intrinsic and a lot of optimization work (especially in order to reduce spills around function calls).