[nvfuser_upstream_push] nvfuser code base bump 052422 (#78244)
Syncing the nvfuser devel branch to upstream master: https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support for cp.async and cp.async.wait: csarofeen#1619 (see the first sketch after this list)
2. Emulate Ampere's mma 16816 with Turing's mma 1688 for a unified interface: csarofeen#1643 (see the second sketch after this list)
3. Extend the infrastructure to support mma operators on Turing and Ampere architectures: csarofeen#1440
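
As context for items 1 and 2, the two sketches below illustrate the underlying PTX. They are minimal, hedged CUDA illustrations written for this summary, not the code added by this PR (the generated kernels call runtime helpers such as `Ampere::cpAsync`, `Ampere::cpAsyncBarrier`, and `Turing::ldMatrix`, presumably defined in the newly added `runtime/memory.cu`); the helper names, the fixed 16-byte copy size, and the fragment split across the two mma issues are illustrative assumptions.

```cuda
// Hedged sketch of an sm_80+ async global->shared copy (cp.async) plus the
// matching wait. Helper names and the fixed 16-byte size are illustrative.
__device__ inline void copy_16_bytes_async(void* smem_dst, const void* gmem_src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // cp.async addresses the destination in the shared state space.
  unsigned smem_addr =
      static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
               "l"(gmem_src));
#endif
}

__device__ inline void wait_all_async_copies() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Wait until all previously committed cp.async groups have completed.
  asm volatile("cp.async.commit_group;\n");
  asm volatile("cp.async.wait_group 0;\n");
#endif
}
```

And a sketch of the emulation idea behind item 2: an Ampere m16n8k16 HMMA is replaced by two Turing m16n8k8 HMMAs that accumulate into the same fragment. The exact ordering of the A/B fragment halves across the two issues is an assumption here; in practice the fragment layout comes from ldmatrix.

```cuda
__device__ inline void mma_16816_via_two_1688(
    float d[4],           // 16x8 fp32 accumulator fragment (4 regs per thread)
    const unsigned a[4],  // 16x16 fp16 A fragment packed as 4 x b32
    const unsigned b[2],  // 16x8  fp16 B fragment packed as 2 x b32
    const float c[4]) {   // input accumulator
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  d[0] = c[0]; d[1] = c[1]; d[2] = c[2]; d[3] = c[3];
  // First k=8 slice.
  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n"
      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(b[0]));
  // Second k=8 slice accumulates on top of the first.
  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n"
      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
      : "r"(a[2]), "r"(a[3]), "r"(b[1]));
#endif
}
```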

Commits actually included in this PR from the csarofeen branch:
```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```
Pull Request resolved: #78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
jjsjann123 authored and malfet committed Jun 8, 2022
1 parent c2a3c81 commit 9e52ad2
Showing 69 changed files with 5,805 additions and 743 deletions.
2 changes: 2 additions & 0 deletions build_variables.bzl
@@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
"torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
"torch/csrc/jit/codegen/cuda/runtime/memory.cu",
"torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
"torch/csrc/jit/codegen/cuda/runtime/tuple.cu",
@@ -679,6 +680,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/lower_index.cpp",
"torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp",
"torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
"torch/csrc/jit/codegen/cuda/lower_instrument.cpp",
"torch/csrc/jit/codegen/cuda/lower_loops.cpp",
"torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp",
"torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp",
56 changes: 51 additions & 5 deletions test/test_jit_cuda_fuser.py
@@ -43,6 +43,9 @@

os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng'
os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
# TODO: enable complex when we fix the extremal cases in OpInfo
# see issue https://github.com/csarofeen/pytorch/issues/1730
# os.environ['PYTORCH_NVFUSER_ENABLE'] = 'complex'

if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_texpr_fuser_enabled(False)
@@ -159,7 +162,9 @@ def setUp(self):
torch.float16,
torch.float32,
torch.float64,
torch.bool
torch.bool,
torch.complex64,
torch.complex128,
]
if TEST_BF16:
self.support_tensor_dtypes.append(torch.bfloat16)
@@ -595,7 +600,9 @@ def t(x: torch.Tensor, y: torch.Tensor):
# bfloat16 kernels instead of eager mode
# implementation, since mismatch in cast
# adds excessive noise.
o = t(x.to(torch.float64), y.to(torch.float64)).to(torch.bfloat16)
o = t(x.to(torch.float64), y.to(torch.float64))
if o.dtype.is_floating_point:
o = o.to(torch.bfloat16)
else:
o = t(x, y)

@@ -609,7 +616,11 @@ def test_unary_ops(self):
*self.int_types,
torch.float16,
torch.float32,
torch.float64
torch.float64,
# TODO: revert this
# see issue https://github.com/csarofeen/pytorch/issues/1730
# torch.cfloat,
# torch.cdouble,
]
if TEST_BF16:
data_types.append(torch.bfloat16)
@@ -654,7 +665,10 @@ def test_unary_ops(self):
torch.tan,
torch.tanh,
torch.nn.functional.silu]
skip_complex = {torch.rsqrt, torch.reciprocal}
for op, dtype in itertools.product(operations, data_types):
if dtype.is_complex and op in skip_complex:
continue
self._unary_test_helper(op, dtype, False) # test special numbers
self._unary_test_helper(op, dtype, True) # test random data

@@ -760,12 +774,20 @@ def t_doublex_tensory(x: float, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o

def t_cdoublex_tensory(x: complex, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o

# Omit both scalar cases and swap cases
assert category1 == "scalar" and category2 != "scalar"
if dtype_arg1.is_floating_point:
return t_doublex_tensory
if dtype_arg1 == torch.int64 or dtype_arg1 == torch.int32:
return t_intx_tensory
if dtype_arg1.is_complex:
return t_cdoublex_tensory
raise NotImplementedError

def _binary_test_helper(self, operation, dtypes, random_data, categories="ndim"):
@@ -917,13 +939,12 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops(self):
# disabled bf16 / fp16 data types because of accuracy tolerance
data_types = [
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64
torch.float64,
]
if TEST_BF16:
data_types.append(torch.bfloat16)
@@ -958,6 +979,31 @@ def test_binary_ops(self):
for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, False) # special numbers

# TODO: revert this
@unittest.skipIf(True, "see issue https://github.com/csarofeen/pytorch/issues/1730")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_complex(self):
data_types = [torch.cfloat, torch.cdouble]
operations = [torch.mul, torch.div, torch.pow, torch.eq, torch.ne]

category_types = [
"scalar",
"0dim",
"0dimcpu",
"ndim"
]

binary_dtype_combinations = list(itertools.combinations(data_types, 2))
category_combinations = list(itertools.combinations(category_types, 2))

for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations):
self._binary_test_helper(op, dtypes, True, categories) # random data

for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, False) # special numbers

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
119 changes: 117 additions & 2 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -477,6 +477,42 @@ class CudaKernelGenerator : private OptOutConstDispatch {
TORCH_INTERNAL_ASSERT(false, "Unreachable");
}

//! Utility for generating vectorized pointer access in ldsm and
//! cpasync.
//! TODO: this access pattern as-is could be merged with the existing
//! vectorization handling logic, but this path will be updated in
//! follow-ups to optimize the generated assembly, so keeping it a
//! separate path for now.
std::string genVectorPointer(Val* val, DataType dtype, int vec_size) {
std::stringstream ss;

ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
<< vec_size << ">*>(&" << gen(val) << ")";

return ss.str();
}

// Utility function to emit a cp.async intrinsic
void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
auto dtype = ldst->in()->getDataType().value();

indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
}

void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
auto dtype = ldst->in()->getDataType().value();
indent() << "Turing::ldMatrix";
if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) {
code_ << "T";
}
code_ << " (";
code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
<< ","
<< "&" << gen(ldst->in()) << ");\n";
}
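
// Illustrative example (hypothetical tensor names T0/T1 and indices): for a
// 16-byte vectorized copy of __half data (vec_size == 8), genCpAsync above is
// expected to emit something roughly like
//   Ampere::cpAsync(reinterpret_cast<Array<__half,8,8>*>(&T1[...]),
//                   reinterpret_cast<Array<__half,8,8>*>(&T0[...]));
// while genLdMatrix emits the analogous Turing::ldMatrix / ldMatrixT call on
// the vectorized destination and the address of the source tensor.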

void handle(const UnaryOp* uop) final {
bool is_vector_op = false;
size_t vector_word_size = 1;
@@ -918,7 +954,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
if (init) {
ss << "init";
}
ss << toString(options.macro) << toString(options.operand_layout);
ss << toString(options.macro);

if (isVolta(options.macro)) {
ss << toString(options.operand_layout);
} else if (isTuring(options.macro) || isAmpere(options.macro)) {
// mma on Turing and Ampere is TN only; transpose is handled either
// via ldmatrix for fp16 or explicitly for other types.
ss << "TN";
}
// TODO: additional parameter could be removed by swizzling iterdomain
auto acc_stride = mma->accStride();
TORCH_INTERNAL_ASSERT(acc_stride > 0);
@@ -1123,6 +1167,52 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const LoadStoreOp* ldst) {
// TODO:
// Need to gradually merge the code path of this
// with UnaryOp::Set for vectorization.
// There is quite a bit of possible clean up.
bool vectorize_op = false;
size_t vector_word_size = 1;
auto ti = ldst->out()->as<kir::TensorIndex>();

// Check vectorization and set vector word size
for (auto id : ti->view()->domain()->domain()) {
if (!isParallelTypeVectorize(id->getParallelType())) {
continue;
}

ExpressionEvaluator expr_eval(id->fusion());
auto vector_size_optional = expr_eval.evaluate(id->extent());

TORCH_INTERNAL_ASSERT(
vector_size_optional.has_value(),
"Could not evaluate constant value bound to vectorized dim.");

TORCH_INTERNAL_ASSERT(
id->getParallelType() != ParallelType::MisalignedVectorize,
"LoadStoreOp: no support yet for mis-aligned vectorization");
vector_word_size = vector_size_optional.value();
vectorize_op = true;
break;
}

// Dispatch instruction generation:
switch (ldst->opType()) {
case LoadStoreOpType::LdMatrix:
case LoadStoreOpType::LdMatrixTranspose:
TORCH_INTERNAL_ASSERT(
vectorize_op, "LdMatrix: Vectorization required: ", ldst);
genLdMatrix(ldst, vector_word_size);
break;
case LoadStoreOpType::CpAsync:
genCpAsync(ldst, vector_word_size);
break;
default:
TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
}
}

void handle(const WelfordOp* wop) final {
TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());

@@ -1288,6 +1378,19 @@ class CudaKernelGenerator : private OptOutConstDispatch {
return flags.str();
}

void addProfileArguments(ArgumentBuilder& func_args, const Expr* expr) {
if (isEnabled(EnableOption::KernelProfile) &&
kernel_->profile().isProfiled(expr)) {
const auto& buffer_indices =
kernel_->profile().getIndicesInProfileBuffer(expr);
auto buffer = kernel_->profile().getBuffer();
TORCH_INTERNAL_ASSERT(buffer != nullptr);
for (const auto& index : buffer_indices) {
func_args.arg(varName(buffer)).append("[").append(index).append("]");
}
}
}

void handle(const kir::GridReduction* grop) final {
TORCH_INTERNAL_ASSERT(grop->out()->isA<kir::TensorIndex>());

@@ -1345,6 +1448,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
func_args.arg(genInline(grop->entrance_index()));
func_args.arg(genInline(grop->entrances()));

addProfileArguments(func_args, grop);

indent() << "reduction::gridReduce<" << template_args << ">(\n";
indent() << kTab << func_args << ");\n";
}
@@ -1412,6 +1517,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
// reduction_op
func_args.arg(genReductionOp(op_type, out->dtype()));

addProfileArguments(func_args, grop);

indent() << kTab << func_args << ");\n";
}

@@ -1483,6 +1590,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
func_args.arg(read_pred);
}

addProfileArguments(func_args, grouped_grop);

indent() << "reduction::gridReduceGroup<" << template_args << ">(\n";
indent() << kTab << func_args << ");\n";
}
@@ -1543,6 +1652,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
func_args.arg(read_pred);
}

addProfileArguments(func_args, grouped_grop);

indent() << genFusedReductionName(ir_utils::getTvOutput(grouped_grop))
<< ".reduceGroup(\n";
indent() << kTab << func_args << ");\n";
@@ -2012,7 +2123,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const kir::BlockSync*) final {
void handle(const kir::BlockSync* sync) final {
// Use a custom synchronization method if enabled
if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
indent() << "block_sync::sync();\n";
@@ -2021,6 +2132,10 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const kir::CpAsyncWait* cpasync_wait) final {
indent() << "Ampere::cpAsyncBarrier();\n";
}
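
// Illustrative note: Ampere::cpAsyncBarrier() is the runtime helper emitted
// above for kir::CpAsyncWait; it presumably wraps a cp.async wait (e.g.
// cp.async.wait_all) so that previously issued async copies complete before
// their results are consumed.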

void handle(const kir::GridSync* sync) final {
// Use a custom synchronization method if enabled
bool bidx = sync->syncDims().get(ParallelType::BIDx);