Mma op integration on ampere (#1440)
shmsong committed May 23, 2022
1 parent fade8da commit 7093e39
Showing 37 changed files with 2,503 additions and 103 deletions.
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
"torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
"torch/csrc/jit/codegen/cuda/runtime/memory.cu",
"torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
"torch/csrc/jit/codegen/cuda/runtime/tuple.cu",
82 changes: 80 additions & 2 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -477,6 +477,33 @@ class CudaKernelGenerator : private OptOutConstDispatch {
TORCH_INTERNAL_ASSERT(false, "Unreachable");
}

//! Utility for generating vectorized pointer access in ldsm and
//! cpasync.
//! TODO: this access pattern as-is could be merged with existing
//! vectorization handling logic, but this path will be updated in
//! follow-ups to optimize the generated assembly, so keep it as a
//! separate path for now.
std::string genVectorPointer(Val* val, DataType dtype, int vec_size) {
std::stringstream ss;

ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
<< vec_size << ">*>(&" << gen(val) << ")";

return ss.str();
}

void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
auto dtype = ldst->in()->getDataType().value();
indent() << "Turing::ldMatrix";
if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) {
code_ << "T";
}
code_ << " (";
code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
<< ","
<< "&" << gen(ldst->in()) << ");\n";
}
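// For illustration only (hypothetical tensor names and an assumed vector
// width of 8 fp16 elements), the call emitted by genLdMatrix would look
// roughly like:
//
//   Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T2[i]),&T1[j]);
//
// where T1 is the shared-memory source and T2 the register destination.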

void handle(const UnaryOp* uop) final {
bool is_vector_op = false;
size_t vector_word_size = 1;
@@ -918,7 +945,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
if (init) {
ss << "init";
}
ss << toString(options.macro) << toString(options.operand_layout);
ss << toString(options.macro);

if (isVolta(options.macro)) {
ss << toString(options.operand_layout);
} else if (isTuring(options.macro) || isAmpere(options.macro)) {
// mma on Turing and Ampere is TN only; transpose is handled either
// via ldmatrix for fp16 or explicitly for other types.
ss << "TN";
}
// TODO: additional parameter could be removed by swizzling iterdomain
auto acc_stride = mma->accStride();
TORCH_INTERNAL_ASSERT(acc_stride > 0);
@@ -1123,6 +1158,49 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const LoadStoreOp* ldst) {
// TODO:
// Need to gradually merge this code path with UnaryOp::Set
// for vectorization. There is quite a bit of possible cleanup.
bool vectorize_op = false;
size_t vector_word_size = 1;
auto ti = ldst->out()->as<kir::TensorIndex>();

// Check vectorization and set vector word size
for (auto id : ti->view()->domain()->domain()) {
if (!isParallelTypeVectorize(id->getParallelType())) {
continue;
}

ExpressionEvaluator expr_eval(id->fusion());
auto vector_size_optional = expr_eval.evaluate(id->extent());

TORCH_INTERNAL_ASSERT(
vector_size_optional.has_value(),
"Could not evaluate constant value bound to vectorized dim.");

TORCH_INTERNAL_ASSERT(
id->getParallelType() != ParallelType::MisalignedVectorize,
"LoadStoreOp: no support yet for mis-aligned vectorization");
vector_word_size = vector_size_optional.value();
vectorize_op = true;
break;
}

// Dispatch instruction generation:
switch (ldst->opType()) {
case LoadStoreOpType::LdMatrix:
case LoadStoreOpType::LdMatrixTranspose:
TORCH_INTERNAL_ASSERT(
vectorize_op, "LdMatrix: Vectorization required: ", ldst);
genLdMatrix(ldst, vector_word_size);
break;
default:
TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
}
}
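// Note: for LdMatrix the scheduler is expected to mark the innermost
// iterdomain of the output with ParallelType::Vectorize and a constant
// extent, which the check above picks up as vector_word_size. A
// hypothetical scheduling sketch (names illustrative only):
//
//   auto tv_reg = tv_smem->cacheAfter(LoadStoreOpType::LdMatrix);
//   tv_reg->axis(-1)->parallelize(ParallelType::Vectorize);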

void handle(const WelfordOp* wop) final {
TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());

@@ -2033,7 +2111,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const kir::BlockSync*) final {
void handle(const kir::BlockSync* sync) final {
// Use a custom synchronization method if enabled
if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
indent() << "block_sync::sync();\n";
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -110,6 +110,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::WelfordOp:
ptr(handler)->handle(expr->as<WelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(handler)->handle(expr->as<LoadStoreOp>());
return;
case ExprType::MmaOp:
ptr(handler)->handle(expr->as<MmaOp>());
return;
@@ -260,6 +263,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::WelfordOp:
ptr(handler)->handle(expr->as<WelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(handler)->handle(expr->as<LoadStoreOp>());
return;
case ExprType::MmaOp:
ptr(handler)->handle(expr->as<MmaOp>());
return;
@@ -418,6 +424,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::WelfordOp:
ptr(mutator)->mutate(expr->as<WelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(mutator)->mutate(expr->as<LoadStoreOp>());
return;
case ExprType::MmaOp:
ptr(mutator)->mutate(expr->as<MmaOp>());
return;
@@ -641,6 +650,9 @@ void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
void OptOutConstDispatch::handle(const WelfordOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const MmaOp* stmt) {
unhandled(stmt);
}
@@ -761,6 +773,9 @@ void OptOutDispatch::handle(GroupedReductionOp* stmt) {
void OptOutDispatch::handle(WelfordOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(LoadStoreOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(MmaOp* stmt) {
unhandled(stmt);
}
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -74,6 +74,7 @@ class TernaryOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
class LoadStoreOp;
class MmaOp;
class BroadcastOp;
class TransposeOp;
@@ -136,6 +137,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const ReductionOp* stmt);
virtual void handle(const GroupedReductionOp* stmt);
virtual void handle(const WelfordOp* stmt);
virtual void handle(const LoadStoreOp* stmt);
virtual void handle(const MmaOp* stmt);
virtual void handle(const BroadcastOp* stmt);

@@ -191,6 +193,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(ReductionOp* stmt);
virtual void handle(GroupedReductionOp* stmt);
virtual void handle(WelfordOp* stmt);
virtual void handle(LoadStoreOp* stmt);
virtual void handle(MmaOp* stmt);
virtual void handle(BroadcastOp* stmt);

@@ -287,6 +290,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(ReductionOp*);
virtual void mutate(GroupedReductionOp*);
virtual void mutate(WelfordOp*);
virtual void mutate(LoadStoreOp*);
virtual void mutate(MmaOp*);
virtual void mutate(BroadcastOp*);

2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -28,6 +28,7 @@
#include <nvfuser_resources/grid_sync.h>
#include <nvfuser_resources/helpers.h>
#include <nvfuser_resources/index_utils.h>
#include <nvfuser_resources/memory.h>
#include <nvfuser_resources/random_numbers.h>
#include <nvfuser_resources/tensor.h>
#include <nvfuser_resources/tensorcore.h>
@@ -98,6 +99,7 @@ std::string kernelPreamble() {
ss << nvfuser_resources::welford_cu;
ss << nvfuser_resources::warp_cu;
ss << nvfuser_resources::tensorcore_cu;
ss << nvfuser_resources::memory_cu;
ss << nvfuser_resources::fused_reduction_cu;

// Random utilities
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_builder.cpp
@@ -64,6 +64,7 @@ IR_BUILDER_INSTANTIATE(TernaryOp)
IR_BUILDER_INSTANTIATE(ReductionOp)
IR_BUILDER_INSTANTIATE(GroupedReductionOp)
IR_BUILDER_INSTANTIATE(WelfordOp)
IR_BUILDER_INSTANTIATE(LoadStoreOp)
IR_BUILDER_INSTANTIATE(MmaOp)
IR_BUILDER_INSTANTIATE(BroadcastOp)

4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -116,6 +116,10 @@ void IrCloner::handle(const WelfordOp* op) {
clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const LoadStoreOp* op) {
clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const MmaOp* op) {
clone_ = IrBuilder::clone(op, this);
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -75,6 +75,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
void handle(const ReductionOp*) override;
void handle(const GroupedReductionOp*) override;
void handle(const WelfordOp*) override;
void handle(const LoadStoreOp*) override;
void handle(const MmaOp*) override;
void handle(const TransposeOp*) override;
void handle(const ShiftOp*) override;
33 changes: 15 additions & 18 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -402,14 +402,22 @@ class TORCH_CUDA_CU_API TensorView : public Val {
const std::vector<int>& axes,
const std::vector<TensorView*>& tvs);

// Create a TensorView before the original tensor. A common use case is to
// write results into shared memory or registers before moving to global
// memory. Analogous to TVM Cache_Write
TensorView* cacheBefore();
//! Create a TensorView before the original tensor. A common use case is to
//! write results into shared memory or registers before moving to global
//! memory. Analogous to TVM Cache_Write
//!
//! @param cache_op: memory operator to use for the inserted op between
//! the data tensor and the cache tensor
TensorView* cacheBefore(
c10::optional<LoadStoreOpType> cache_op = c10::nullopt);

// Create a TensorView after the original tensor. A common use case is to
// read tensor into shared memory or registers. Analogous to TVM Cache_Read
TensorView* cacheAfter();
//! Create a TensorView after the original tensor. A common use case is to
//! read tensor into shared memory or registers. Analogous to TVM Cache_Read
//!
//! @param cache_op: memory operator to use for the inserted op between
//! the data tensor and the cache tensor
TensorView* cacheAfter(
c10::optional<LoadStoreOpType> cache_op = c10::nullopt);
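//! A typical (hypothetical) use of the cache_op parameter on the mma path:
//! stage an operand in shared memory, then read it into registers via
//! ldmatrix:
//!
//!   auto tv0_smem = tv0->cacheAfter();
//!   tv0_smem->setMemoryType(MemoryType::Shared);
//!   auto tv0_reg = tv0_smem->cacheAfter(LoadStoreOpType::LdMatrix);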

// For a fusion output with other uses, we want to avoid writing to global
// memory and then reading the output again. We write to global memory
@@ -438,17 +446,6 @@ class TORCH_CUDA_CU_API TensorView : public Val {
return is_double_buffered_;
}

//! Fill in mma options in scheduling time.
//! Each mma op in Fusion IR must be configured once before lowering.
//! Mma options are configuration parameters used in lowering to mma
//! instrinsics, mainly the type of mma macro to use and input data layout
//! etc.
//!
//! TODO: This step will very likely be removed in a follow up PR. All of
//! the options configured here could actually be inferred from fusion IR
//! once we are feature complete.
void configureMma(MmaOptions options);

//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are either inputs/outputs of an
//! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
30 changes: 30 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -611,6 +611,36 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {
TensorView* const in_ = nullptr;
};

//! This operator explicitly models data movement between
//! state spaces on the GPU. Currently the modeled state spaces include
//! global memory, shared memory, and registers.
//!
//! The main usage of this op is to facilitate generation of hardware-
//! accelerated memory ops, e.g. ldmatrix, cp.async, and more to come.
class TORCH_CUDA_CU_API LoadStoreOp : public Expr {
public:
LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in);

LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);

Val* out() const {
return out_;
}

Val* in() const {
return in_;
}

LoadStoreOpType opType() const {
return load_store_type_;
}

private:
LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix;
Val* const out_ = nullptr;
Val* const in_ = nullptr;
};
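// A LoadStoreOp is constructed like any other Expr through the IR builder;
// a hypothetical sketch (cacheBefore/cacheAfter with a cache_op argument
// create these nodes on the user's behalf):
//
//   IrBuilder::create<LoadStoreOp>(LoadStoreOpType::LdMatrix, out_tv, in_tv);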

// Friends for direct access to split
class TensorDomain;
class ReplayTransformations;
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -430,6 +430,11 @@ void IrPrinter::handle(const WelfordOp* wop) {
os_ << " )\n";
}

void IrPrinter::handle(const LoadStoreOp* ldst) {
indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in()
<< " )\n";
}
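// With this handler, a LoadStoreOp node prints roughly as
//
//   T2 = LdMatrix( T1 )
//
// assuming the op type streams as its enum name (illustrative only).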

void IrPrinter::handle(const BroadcastOp* bop) {
indent() << bop->out() << "\n";
indent() << " = broadcast( " << bop->in() << " )\n";
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -88,6 +88,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
void handle(const ReductionOp*) final;
void handle(const GroupedReductionOp*) final;
void handle(const WelfordOp*) final;
void handle(const LoadStoreOp*) final;
void handle(const MmaOp*) final;
void handle(const BroadcastOp*) final;
void handle(const TransposeOp*) final;
31 changes: 29 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -891,6 +891,25 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner)
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}

LoadStoreOp::LoadStoreOp(
IrBuilderPasskey passkey,
LoadStoreOpType op_type,
Val* out,
Val* in)
: Expr(passkey, ExprType::LoadStoreOp),
load_store_type_(op_type),
out_(out),
in_(in) {
addOutput(out);
addInput(in);
}

LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
load_store_type_(src->load_store_type_),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}

IterDomain::IterDomain(
IrBuilderPasskey passkey,
Val* start,
@@ -1183,9 +1202,17 @@ void IterDomain::parallelize(ParallelType t) {
}

if (isMmaSwizzled()) {
// Mma swizzled axes represent the data layout within a warp,
// so only allow updates that keep the parallelization within
// a warp.
// Note && TODO: this check is actually used to allow the indexing path
// to make copies of the iterdomains. We might eventually just want
// to lock these parallel types and not allow any changes once
// they are swizzled.
TORCH_CHECK(
t == ParallelType::Vectorize,
"Parallel type other than vectorize not allowed for warp mapped ids");
t == ParallelType::Vectorize || t == ParallelType::TIDx ||
t == ParallelType::Serial,
"Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids");
}

parallel_type_ = t;