Skip to content

Commit

Permalink
Add turing mma support and test (#1643)
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed May 24, 2022
1 parent d6d6b7d commit 5e6a8da
Show file tree
Hide file tree
Showing 8 changed files with 826 additions and 39 deletions.
8 changes: 8 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,14 @@ void validateMma(Fusion* fusion) {
case MmaOptions::MacroType::Volta_16_16_4:
validateMinimumArch(7, 0);
break;
case MmaOptions::MacroType::Turing_16_8_16:
validateMinimumArch(7, 5);

// Check that operands come from ldmatrix; this can be
// relaxed once swizzles can be labeled on iterdomains.
validateTuringMmaInput(mma->inA()->as<TensorView>());
validateTuringMmaInput(mma->inB()->as<TensorView>());
break;
case MmaOptions::MacroType::Ampere_16_8_16:
validateMinimumArch(8, 0);

Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/mma_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ MmaBuilder::MmaBuilder(
case MmaOptions::MacroType::Volta_16_16_4:
option_.accumulator_stride = outer_stride * 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
option_.accumulator_stride = outer_stride * 2;
break;
Expand Down Expand Up @@ -58,6 +59,7 @@ namespace {
LoadStoreOpType getLdMatrixType(MmaOptions options) {
bool transpose = false;
switch (options.macro) {
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
// Turing mma assumes TN as default
transpose = (options.operand == MmaOptions::Operand::A &&
Expand All @@ -84,7 +86,7 @@ bool isVolta(MmaOptions::MacroType macro) {
}

bool isTuring(MmaOptions::MacroType macro) {
return false;
return macro == MmaOptions::MacroType::Turing_16_8_16;
}

bool isAmpere(MmaOptions::MacroType macro) {
Expand All @@ -96,6 +98,7 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 8;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
break;
Expand All @@ -111,6 +114,7 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 8;
break;
Expand All @@ -126,6 +130,7 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
default:
Expand Down Expand Up @@ -176,6 +181,7 @@ std::string toString(MmaOptions::MacroType mt) {
case MmaOptions::MacroType::Volta_16_16_4:
ss << "M16N16K4";
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
ss << "M16N8K16";
break;
Expand Down
5 changes: 3 additions & 2 deletions torch/csrc/jit/codegen/cuda/mma_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct MmaOptions {
NoMMA = 0,
Volta_16_16_4,
Ampere_16_8_16,
Turing_16_8_16,
Ampere_16_8_8 // place holder for tf32
};

Expand All @@ -73,7 +74,7 @@ struct MmaOptions {
enum class MmaInputLayout { NT = 0, TT, TN };

//! Utility to annotate which input of mma this option struct describes
enum class Operand { NotOperand = 0, A, B };
enum class Operand { Accumulator = 0, A, B };

//! Utility to annotate which mma macro this config uses.
MacroType macro = MacroType::NoMMA;
Expand Down Expand Up @@ -117,7 +118,7 @@ class TORCH_CUDA_CU_API MmaBuilder {
//! Specifies which element in the mma op this builder is generating
//! parameters for, i.e. A or B. This is useful when generating
//! data swizzles for different elements of mma.
//! - Operand::NotOperand means the parameters describe accumulator in mma
//! - Operand::Accumulator means the parameters describe accumulator in mma
//! op.
//! - This option is ignored when configuring the mma operator itself.
MmaBuilder& operand(MmaOptions::Operand a_or_b);
Expand Down
36 changes: 36 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/memory.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,40 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
return smem_ptr_uint;
}

// LdMatrix has .x1, .x2 and .x4 options; currently we actively use .x2 and
// .x4. In the .x2 option, the address registers of the upper half warp
// (lanes 16-31) are unused, but on the Turing [sm75,sm80) architecture
// these unused addresses still need to be valid, in the sense that:
// 1. The data they point to has to be within the allocated shared mem buffer.
// 2. The address needs to be aligned to 16 bytes.
// See also:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
// This function addresses 2. above by masking out the sub-16B component
// of the address in the upper half warp; 1. is guaranteed by the ldmatrix
// swizzle util.
// This will **not** affect any functionality. This is just a modification
// of unused pointers to satisfy the alignment requirement on Turing
// hardware.
// The alignment requirement is lifted on sm80+,
// so this function is a no-op on Ampere or above.
DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) {
#if (__CUDA_ARCH__ < 800)
const unsigned thread_id = threadIdx.x;
// Upper half warp has 8 bytes offset from aligned in .x2 option
// of ldmatrix. Currently no support for .x1 so assume always
// adjust by half warp.
constexpr unsigned half_warp = 16;
// Need to adjust to 16 byte alignment, mask out un-aligned component.
constexpr unsigned mask_out = 16 - 1;
// Adjust only in upper half warp.
// use bit math to reduce strength
if (thread_id & half_warp) {
// mask out the bits where adjust_mask has 1.
addr_in_byte &= (~mask_out);
}
#endif //(__CUDA_ARCH__ < 800)
}

} // namespace util

// Load Matrix (per warp instruction) is to take data from SMEM to Local Memory.
Expand All @@ -36,6 +70,7 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
uint2& val = reinterpret_cast<uint2&>(out);
unsigned addr = util::toSmem(ptr);
util::adjustPartialLdMatrixAddrInTuring(addr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "r"(addr));
Expand All @@ -47,6 +82,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) {
uint2& val = reinterpret_cast<uint2&>(out);
unsigned addr = util::toSmem(ptr);
util::adjustPartialLdMatrixAddrInTuring(addr);
asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "r"(addr));
Expand Down
68 changes: 68 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,74 @@ DEVICE_INLINE void initM16N16K4NT(Array<float, 8, 8>* accumulator) {

} // namespace Volta

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))

namespace Turing {

namespace util {
// MMA instruction wrappers (sm_75+):
DEVICE_INLINE void m16n8k16TN(
Array<float, 4, 4>* C,
Array<__half, 8, 8>* A,
Array<__half, 4, 4>* B) {
unsigned const* _A = reinterpret_cast<unsigned const*>(A);
unsigned const* _B = reinterpret_cast<unsigned const*>(B);
unsigned* _C = reinterpret_cast<unsigned*>(C);
const unsigned* _D = reinterpret_cast<const unsigned*>(C);

asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
: "r"(_A[0]),
"r"(_A[1]),
"r"(_B[0]),
"r"(_D[0]),
"r"(_D[1]),
"r"(_D[2]),
"r"(_D[3]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
: "r"(_A[2]),
"r"(_A[3]),
"r"(_B[1]),
"r"(_D[0]),
"r"(_D[1]),
"r"(_D[2]),
"r"(_D[3]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
float* _C = reinterpret_cast<float*>(accumulator);
_C[0] = 0;
_C[1] = 0;
_C[acc_stride] = 0;
_C[acc_stride + 1] = 0;
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N8K16TN(
Array<float, 4, 4>* C,
Array<__half, 8, 8>* A,
Array<__half, 4, 4>* B) {
// TODO: in a follow up,
// lift this fused swizzle onto iterdomain
float* _C = reinterpret_cast<float*>(C);
float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};

util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);

_C[0] = C_data[0];
_C[1] = C_data[1];
_C[acc_stride] = C_data[2];
_C[acc_stride + 1] = C_data[3];
}

} // namespace Turing

#endif // Arch 75

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

namespace Ampere {
Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
setWarpMapped(tv, 5);
}
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
scheduleTuringM16N8K16MmaWarpOutput(tv, options);
if (tv->definition()->isA<MmaOp>()) {
Expand All @@ -240,6 +241,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
case MmaOptions::MacroType::Volta_16_16_4:
scheduleVoltaOperandRead(tv, options);
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
scheduleTuringOperandRead(tv, options);
break;
Expand Down Expand Up @@ -415,7 +417,8 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
: isOperandTransposed(options);
// Check mma option is supported
TORCH_CHECK(
options.macro == MmaOptions::MacroType::Ampere_16_8_16,
options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
options.macro == MmaOptions::MacroType::Turing_16_8_16,
"scheduleLdMatrix: unknown macro for ldmatrix");

if (options.operand == MmaOptions::Operand::A) {
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ bool TensorView::isEmptyTensor() const {

void TensorView::applyMmaSwizzle(MmaOptions options) {
switch (options.operand) {
case MmaOptions::Operand::NotOperand:
case MmaOptions::Operand::Accumulator:
mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options);
break;
case MmaOptions::Operand::A:
Expand Down
Loading

0 comments on commit 5e6a8da

Please sign in to comment.