Use LLs in AsyncCopyGlobalToLocalOp lowering. #4070

Merged: 6 commits, Jun 6, 2024
19 changes: 19 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1235,6 +1235,25 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
return multiDimIdx;
}

// Emits IR to load data from shared memory into registers, or to store data
// from registers into shared memory.
//
// You supply perVectorCallback, which is called once per group of register
// elements to transfer. Use this callback to emit the IR that loads the
// group from, or stores it to, shared memory.
//
// elemLlvmTy should be registerTy's element type converted to an LLVM-dialect
// type.
//
// If maxVecElems is provided, we won't vectorize more than this many elements.
//
// Returns true on success.
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
Location loc, const TargetInfoBase &target, unsigned inVec,
RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout,
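For orientation, here is a minimal sketch of how a lowering might drive this helper on the load path. It mirrors loadSharedToRegistersUsingLinearLayouts from Utility.cpp below; the load(...), extract_element(...), and i32_val(...) helpers are the usual TritonGPUToLLVM utility macros, and the surrounding variables (registerTy, sharedTy, elemLlvmTy, smemObj, a function returning std::optional) are assumed to be in scope.

// Sketch only, not part of this PR.
SmallVector<Value> unpacked;
bool ok = emitTransferBetweenRegistersAndShared(
    registerTy, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt,
    smemObj.getBase(), smemObj.getStrides(), loc, rewriter, target,
    [&](VectorType vecTy, Value vecAddr) {
      // One vectorized shared-memory access per callback invocation.
      Value vec = load(vecTy, vecAddr);
      for (int i = 0; i < vecTy.getNumElements(); i++)
        unpacked.push_back(extract_element(elemLlvmTy, vec, i32_val(i)));
    });
if (!ok)
  return std::nullopt; // let the caller fall back to the legacy path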
58 changes: 45 additions & 13 deletions include/triton/Tools/LinearLayout.h
@@ -322,7 +322,7 @@ class LinearLayout {
/*size=getInDimSizeLog2(inDim)*/>
bases;

llvm::SetVector<StringAttr> outDimNames;
llvm::MapVector<StringAttr, int32_t /*size*/> outDims;
bool surjective;

public:
@@ -349,9 +349,38 @@
// Creates a LinearLayout from a list of bases. These are interpreted
// according to the rules written for the member variable `bases`.
//
// Assert-fails if requireSurjective is true and the bases are not surjective.
explicit LinearLayout(BasesT bases, ArrayRef<StringAttr> outDimNames,
bool requireSurjective = true);
// Calculates the out-dim sizes according to the bases. Consider the
// following example.
//
// L(in1=1) = (out1=1, out2=0)
// L(in1=2) = (out1=5, out2=1)
// L(in1=4) = (out1=2, out2=2)
//
// To calculate the out-dim sizes, we first find the largest values for out1
// and out2, namely 5 and 2, then round these up to the next power of 2,
// namely 8 and 4. These are the out-dim sizes.
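//
// Because the layout is linear over xor, the bases determine every other
// input's image; for example:
//
//   L(in1=3) = L(in1=1) xor L(in1=2) = (out1=1 xor 5, out2=0 xor 1)
//            = (out1=4, out2=1)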
//
// Assert-fails if the layout is not surjective given these out-dim sizes.
// That is, every value in each out-dim's range [0, size) must be produced by
// xor'ing some combination of bases.
explicit LinearLayout(BasesT bases, ArrayRef<StringAttr> outDimNames);

// Creates a LinearLayout given a list of bases and the explicit out-dimension
// sizes. Allows the layout to be non-surjective.
//
// To see why we need to explicitly pass out-dim sizes when creating a
// non-surjective layout, consider the following example.
//
// L(in1=1) = 1
// L(in1=2) = 4
//
// If we naively infer the out-dim sizes from these bases, we'd infer a size
// of nextPow2(4) = 8. But given that the layout is non-surjective, who is to
// say that the codomain is not (say) [0,32)? We can't tell, thus we need to
// be explicit about the sizes.
explicit LinearLayout(BasesT bases,
ArrayRef<std::pair<StringAttr, int32_t>> outDims,
bool requireSurjective);

// Construct a LinearLayout from an explicit list of bases. (This constructor
// is needed because llvm::MapVector does not have a constructor that accepts
@@ -373,10 +402,14 @@
// },
// {"out1", "out2"})
//
// Assert-fails if requireSurjective is true and the bases are not surjective.
// The overload that infers out-dim sizes assert-fails if the layout is not
// surjective.
explicit LinearLayout(
ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases,
ArrayRef<StringAttr> outDimNames);
explicit LinearLayout(
ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases,
ArrayRef<StringAttr> outDimNames, bool requireSurjective = true);
ArrayRef<std::pair<StringAttr, int32_t>> outDims, bool requireSurjective);

bool isSurjective() const { return surjective; }

@@ -399,21 +432,17 @@
// These are in minor-to-major order, although if you don't flatten the dims
// (e.g. by reshaping) then the order doesn't really affect anything.
auto getInDimNames() const { return llvm::make_first_range(bases); }
ArrayRef<StringAttr> getOutDimNames() const {
return outDimNames.getArrayRef();
}
auto getOutDimNames() const { return llvm::make_first_range(outDims); }

// Gets the position that this outDim occupies in getOutDimNames(). Asserts
// if the dim is not present.
int32_t getOutDimIndex(StringAttr outDim) const;

bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); }
bool hasOutDim(StringAttr outDim) const {
return outDimNames.contains(outDim);
}
bool hasOutDim(StringAttr outDim) const { return outDims.contains(outDim); }

int32_t getNumInDims() const { return bases.size(); }
int32_t getNumOutDims() const { return outDimNames.size(); }
int32_t getNumOutDims() const { return outDims.size(); }

// Asserts if the dimension is not present.
int32_t getInDimSizeLog2(StringAttr inDim) const;
@@ -613,6 +642,9 @@
friend bool operator!=(LinearLayout lhs, LinearLayout rhs) {
return !(lhs == rhs);
}

private:
void checkInvariants(bool requireSurjective);
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
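To make the two construction modes concrete, here is a small sketch in the style of the LinearLayout unit tests. The S(...) StringAttr helper and the local MLIRContext are illustrative assumptions, not part of this diff.

// Sketch only, assuming an MLIRContext and a StringAttr helper.
MLIRContext ctx;
auto S = [&](StringRef str) { return StringAttr::get(&ctx, str); };

// Surjective: with bases {1, 2} for in1, xor combinations cover
// {0, 1, 2, 3}, so the out-dim size is inferred as nextPow2(2) = 4.
LinearLayout inferred({{S("in1"), {{1}, {2}}}}, {S("out1")});

// Non-surjective: bases {1, 4} only reach {0, 1, 4, 5}, so the codomain
// size must be given explicitly (here [0, 32), as in the comment above).
LinearLayout explicitSizes({{S("in1"), {{1}, {4}}}},
                           {{S("out1"), 32}},
                           /*requireSurjective=*/false);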
55 changes: 31 additions & 24 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -248,16 +248,11 @@ emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter,
return ret;
}

// elemLlvmTy should be dstTy's element type converted to an LLVM-dialect type.
//
// Calls perVectorCallback once for each group of register elems to transfer,
// and passes the shmem address for that group.
//
// Returns true on success.
static bool emitTransferBetweenRegistersAndShared(
bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
Value shmemBase, ArrayRef<Value> shmemStrides, Location loc,
RewriterBase &rewriter, const TargetInfoBase &target,
std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

@@ -301,17 +296,26 @@

// TODO(jlebar): We don't currently support loading from shared memory in a
// different CTA. We'd need to emit `mapa.shared::cluster` instructions.
if (regToSharedLayout.getInDimSize(kBlock) !=
regToSharedLayout.getOutDimSize(kBlock)) {
return false;
}
for (int i = 1; i < regToSharedLayout.getInDimSize(kBlock); i *= 2) {
for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
inBlock *= 2) {
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, i}})));
auto offsets = ArrayRef(idx).drop_back(1);
int32_t block = idx.back();
if (!llvm::all_of(offsets, [&](auto offset) { return offset == 0; }) ||
block != i) {
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
// offsetX1, ..., offsetXN must all be 0.
if (!llvm::all_of(ArrayRef(idx).drop_back(1),
[&](auto offset) { return offset == 0; })) {
return false;
}

// We now have
// regToSharedLayout(0, ..., block=inBlock) => (0, ..., block=outBlock).
// To confirm that there's no cross-block communication, we must also have
// outBlock == inBlock or outBlock == 0.
//
// The fact that outBlock == 0 works is nonobvious. It occurs when the
// shared layout is broadcasted in its block dim, i.e. multiple blocks
// contain the same data.
int32_t outBlock = idx.back();
if (outBlock != inBlock && outBlock != 0) {
return false;
}
}
@@ -327,7 +331,9 @@
// calling getNumConsecutiveInOut(), we could flatten consecutive out-dims
// which have known strides. This would allow us to vectorize across multiple
// shmem out dimensions where possible.
const int vecElems = regToSharedLayout.getNumConsecutiveInOut();
const int vecElems =
std::min(regToSharedLayout.getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
@@ -369,8 +375,9 @@ std::optional<SmallVector<Value>> loadSharedToRegistersUsingLinearLayouts(
const TargetInfoBase &target) {
SmallVector<Value> ret;
bool success = emitTransferBetweenRegistersAndShared(
dstTy, srcTy, elemLlvmTy, smemObj.getBase(), smemObj.getStrides(), loc,
rewriter, target, [&](VectorType vecTy, Value vecAddr) {
dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(),
smemObj.getStrides(), loc, rewriter, target,
[&](VectorType vecTy, Value vecAddr) {
auto vecVal = load(vecTy, vecAddr);
vecVal.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
@@ -391,8 +398,8 @@ bool storeDistributedToSharedUsingLinearLayouts(
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target) {
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, smemBase, dstStrides, loc, rewriter, target,
[&](VectorType vecTy, Value vecAddr) {
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
srcVals = srcVals.drop_front(vecTy.getNumElements());

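The new maxVecElems parameter is what the AsyncCopyGlobalToLocalOp lowering in the PR title needs: cp.async moves at most 16 bytes per instruction, so the caller must be able to cap vectorization. A hedged sketch of such a call follows; the actual lowering is not shown in this diff, and srcTy, dstTy, elemLlvmTy, smemBase, and smemStrides are assumed to be in scope.

// Hypothetical caller: clamp each transfer to a legal cp.async width.
int maxBytes = 16; // cp.async supports 4-, 8-, and 16-byte copies
int elemBytes = elemLlvmTy.getIntOrFloatBitWidth() / 8;
bool ok = emitTransferBetweenRegistersAndShared(
    srcTy, dstTy, elemLlvmTy,
    /*maxVecElems=*/maxBytes / elemBytes, smemBase, smemStrides, loc,
    rewriter, target, [&](VectorType vecTy, Value shmemAddr) {
      // Emit one async copy per vectorized group here.
    });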
13 changes: 8 additions & 5 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -191,7 +191,8 @@ LinearLayout shrinkCodomain(const LinearLayout &layout, StringAttr inDimName,
}

// Construct O'.
LinearLayout transform(std::move(newBases), layout.getOutDimNames());
LinearLayout transform(std::move(newBases),
llvm::to_vector(layout.getOutDimNames()));

// Compose O' with L.
return layout.compose(transform);
@@ -306,12 +307,13 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,

LinearLayout cgaLayout =
ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape)
.transposeOuts(ctaLayout.getOutDimNames());
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

// Calculate the shape of the ctaLayout, which is `shape` divided by the
// cgaLayout's size.
llvm::SmallDenseMap<StringAttr, int64_t> ctaShape;
assert(ctaLayout.getOutDimNames() == cgaLayout.getOutDimNames());
assert(llvm::to_vector(ctaLayout.getOutDimNames()) ==
llvm::to_vector(cgaLayout.getOutDimNames()));
for (auto dim : ctaLayout.getOutDimNames()) {
ctaShape[dim] =
std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim));
@@ -404,7 +406,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
// this really does seem to be correct.
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
{S("dim0"), S("dim1")})
.transposeOuts(ctaLayout.getOutDimNames());
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
@@ -455,7 +457,8 @@ std::optional<LinearLayout> sliceToLinearLayout(ArrayRef<int64_t> shape,
}
bases[S("register")] = newRegBases;

LinearLayout ret = LinearLayout(std::move(bases), sliceLL.getOutDimNames());
LinearLayout ret =
LinearLayout(std::move(bases), llvm::to_vector(sliceLL.getOutDimNames()));

// Match a hack in the legacy code that ensures that the number of registers
// matches getTotalElemsPerThread. Yup: We just removed all the zeros, now
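The llvm::to_vector wrappers scattered through this file all stem from the same change: getOutDimNames() now returns a lazy range over the MapVector keys rather than an ArrayRef, so call sites that need a real container (or element-wise comparison) must materialize it first. Roughly, assuming two LinearLayout values layout and other:

// Before: ArrayRef<StringAttr> names = layout.getOutDimNames();
// After: materialize the lazy key range when a container is required.
SmallVector<StringAttr> names = llvm::to_vector(layout.getOutDimNames());
auto transposed = other.transposeOuts(names); // takes ArrayRef<StringAttr>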