Commit 5caf4a2: squashed upstream_push_0627
jjsjann123 committed Jun 27, 2022 (1 parent: 590d3e5)
Showing 49 changed files with 4,685 additions and 1,312 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -52,6 +52,8 @@ namespace c10 {
_(prim, squeeze_copy) \
_(prim, unsqueeze_copy) \
_(prim, flatten_copy) \
_(prim, expand_copy) \
_(prim, expand_as_copy) \
_(prim, DifferentiableGraph) \
_(prim, TensorExprGroup) \
_(prim, TensorExprDynamicGroup) \
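The two new symbols extend the existing *_copy family above (squeeze_copy, unsqueeze_copy, flatten_copy): out-of-place counterparts of the view ops, which the fuser can presumably swap in only when the view's aliasing is never observed (see the num_fusion=0 case in the test below).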
2 changes: 1 addition & 1 deletion benchmarks/cpp/nvfuser/CMakeLists.txt
@@ -26,7 +26,7 @@ if(USE_CUDA)

target_link_libraries(nvfuser_bench PRIVATE torch_library benchmark)
if(NOT MSVC)
target_compile_options(nvfuser_bench PRIVATE -Wno-unused-variable -Werror)
target_compile_options(nvfuser_bench PRIVATE -Wno-unused-variable -Wno-deprecated-copy -Werror)
endif()

endif()
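The added -Wno-deprecated-copy keeps this -Werror build green on newer compilers, which warn about implicitly-defined copy operations in headers the benchmarks include.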
2 changes: 2 additions & 0 deletions build_variables.bzl
@@ -655,6 +655,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
"torch/csrc/jit/codegen/cuda/grouped_reduction.cpp",
"torch/csrc/jit/codegen/cuda/index_compute.cpp",
"torch/csrc/jit/codegen/cuda/lower_index_compute.cpp",
"torch/csrc/jit/codegen/cuda/index_reference_replay.cpp",
"torch/csrc/jit/codegen/cuda/instrumentation.cpp",
"torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
@@ -698,6 +699,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
"torch/csrc/jit/codegen/cuda/non_divisible_split.cpp",
"torch/csrc/jit/codegen/cuda/ops/alias.cpp",
27 changes: 27 additions & 0 deletions test/test_jit_cuda_fuser.py
@@ -4812,6 +4812,33 @@ def t_cpu(x):

self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0)

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_expand(self):
device = "cuda"
x = torch.randn(3, 5, device=device)
y = torch.randn(4, 2, 3, 5, device=device)

def t(x, y):
with torch.jit.strict_fusion():
x = x.relu()
o0 = x.expand(2, 3, 5)
o1 = x.expand_as(y)
return o0, o1

t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, check_stride=True)

def t2(x, y):
o0 = x.expand(2, 3, 5)
o1 = x.expand_as(y)
x.add_(1)
return o0, o1

t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x, y, check_stride=True, num_fusion=0)

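        # Why t2 expects no fusion: expand()/expand_as() return views that
        # alias x, and t2 then mutates x in-place with add_(); rewriting the
        # views as out-of-place copy ops inside a fusion would change what
        # o0/o1 observe, so the fuser is expected to bail out (num_fusion=0).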
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
151 changes: 111 additions & 40 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -19,6 +19,20 @@ namespace cuda {

namespace {

TensorView* maybe_broadcast_inner_to_rank(TensorView* t, size_t rank) {
size_t t_rank = TensorDomain::noReductions(t->getMaybeRFactorDomain()).size();

// broadcast inner on inp to match rank with other.
if (t_rank < rank) {
const int num_bcast = static_cast<int>(rank - t_rank);
std::vector<bool> inner_bcast_dims(rank, false);
std::fill(
inner_bcast_dims.begin(), inner_bcast_dims.begin() + num_bcast, true);
t = broadcast(t, inner_bcast_dims);
}
return t;
}

Val* simplifiedInt(Val* val) {
TORCH_INTERNAL_ASSERT(
val->isConstInt(), "Expecting Const Int's only in this routine.");
@@ -96,6 +110,49 @@ Val* newScalar(ValType vtype, DataType dtype) {
" in newScalar.");
}

IterType promoteIterType(IterType type1, IterType type2) {
// Iteration: Default
// Reduction: Should not appear here
// Broadcast: Propagated only if type1 and type2 are Broadcast
// Gather: Converted to Iteration
// Stride: Should not appear here
// VectorComponent: Converted to Iteration
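//
// A few concrete promotions implied by these rules:
//   promoteIterType(Iteration, Iteration) -> Iteration
//   promoteIterType(Broadcast, Iteration) -> Iteration
//   promoteIterType(Broadcast, Broadcast) -> Broadcast
//   promoteIterType(Gather, Broadcast)    -> Iteration (Gather decays first)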

TORCH_INTERNAL_ASSERT(
type1 != IterType::Reduction && type1 != IterType::Stride,
"Invalid IterType: ",
type1);
TORCH_INTERNAL_ASSERT(
type2 != IterType::Reduction && type2 != IterType::Stride,
"Invalid IterType: ",
type2);

// Do not propagate Gather and VectorComponent
if (type1 == IterType::Gather || type1 == IterType::VectorComponent) {
type1 = IterType::Iteration;
}
if (type2 == IterType::Gather || type2 == IterType::VectorComponent) {
type2 = IterType::Iteration;
}

// At this point, type1 and type2 must be either Iteration or
// Broadcast
TORCH_INTERNAL_ASSERT(
type1 == IterType::Iteration || type1 == IterType::Broadcast,
"Unexpected IterType: ",
type1);
TORCH_INTERNAL_ASSERT(
type2 == IterType::Iteration || type2 == IterType::Broadcast,
"Unexpected IterType: ",
type2);

if (type1 == IterType::Broadcast) {
return type2;
} else {
return type1;
}
}

TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
std::vector<TensorView*> tvs;
for (auto val : vals) {
@@ -141,12 +198,8 @@ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
}
extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
if (iter_types[i].has_value()) {
// TODO: Enable, see conv tests and gather promotion/gather broadcast
// behavior.
//
// TORCH_INTERNAL_ASSERT(
// iter_types[i].value() == dom[i]->getIterType(),
// "Invalid iter type promotion in newOutputTv for expression.");
iter_types[i] =
promoteIterType(iter_types[i].value(), dom[i]->getIterType());
} else {
iter_types[i] = dom[i]->getIterType();
}
@@ -210,17 +263,7 @@ std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals) {
for (const auto i : c10::irange(vals.size())) {
if (vals[i]->getValType().value() == ValType::TensorView) {
auto tv = vals[i]->as<TensorView>();
size_t tv_dims =
TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size();
if (tv_dims < n_dims) {
std::vector<bool> bcast_flags(n_dims, false);
for (const auto j : c10::irange(n_dims - tv_dims)) {
bcast_flags[j] = true;
}
out_vals[i] = broadcast(tv, bcast_flags);
} else {
out_vals[i] = vals[i];
}
out_vals[i] = maybe_broadcast_inner_to_rank(tv, n_dims);
} else {
out_vals[i] = vals[i];
}
@@ -518,8 +561,9 @@ namespace {
// Helper function to reduce repetitive code
template <typename T1, typename T2>
TensorView* arithOpOverloads(Val* (*func)(Val*, Val*), T1* v1, T2* v2) {
return func(v1->template as<Val>(), v2->template as<Val>())
->template as<TensorView>();
Val* out = func(v1->template as<Val>(), v2->template as<Val>());
TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
return out->as<TensorView>();
}

template <typename T1, typename T2>
@@ -528,9 +572,10 @@ TensorView* arithOpOverloads(
T1* v1,
T2* v2,
DataType common_dtype) {
return binaryOp(
type, v1->template as<Val>(), v2->template as<Val>(), common_dtype)
->template as<TensorView>();
Val* out = binaryOp(
type, v1->template as<Val>(), v2->template as<Val>(), common_dtype);
TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
return out->as<TensorView>();
}

template <typename T1, typename T2, typename T3>
@@ -540,11 +585,12 @@ TensorView* arithOpOverloads(
T2* v2,
T3* v3) {
auto vals = maybeBroadcast({v1, v2, v3});
return func(
vals[0]->template as<Val>(),
vals[1]->template as<Val>(),
vals[2]->template as<Val>())
->template as<TensorView>();
Val* out = func(
vals[0]->template as<Val>(),
vals[1]->template as<Val>(),
vals[2]->template as<Val>());
TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
return out->as<TensorView>();
}

template <typename T1, typename T2, typename T3, typename T4>
@@ -555,12 +601,13 @@ TensorView* arithOpOverloads(
T3* v3,
T4* v4) {
auto vals = maybeBroadcast({v1, v2, v3, v4});
return func(
vals[0]->template as<Val>(),
vals[1]->template as<Val>(),
vals[2]->template as<Val>(),
vals[3]->template as<Val>())
->template as<TensorView>();
Val* out = func(
vals[0]->template as<Val>(),
vals[1]->template as<Val>(),
vals[2]->template as<Val>(),
vals[3]->template as<Val>());
TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
return out->as<TensorView>();
}

// Output type promotion logic for binary operators
@@ -906,6 +953,17 @@ static TensorView* newForReduction(
return IrBuilder::create<TensorView>(td, data_type);
}

namespace {

// PyTorch accepts reductions of zero-dimensional tensors, which are
// just ignored.
TensorView* reductionOpZeroDimTensor(TensorView* inp) {
TORCH_INTERNAL_ASSERT(inp->domain()->noReductions().size() == 0);
return set(inp);
}

} // namespace

TensorView* reductionOp(
BinaryOpType reduction_op_type,
const std::vector<int>& axes,
@@ -921,10 +979,13 @@ TensorView* reductionOp(
TensorDomain::sameAs(tv->getMaybeRFactorDomain(), tv->domain()->domain()),
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");

TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");

TORCH_CHECK(axes.size() > 0, "No reduction axis specified");

// PyTorch allows reduction of 0-dim tensors
if (tv->domain()->noReductions().size() == 0) {
return reductionOpZeroDimTensor(tv);
}

std::vector<unsigned int> uint_axes;
const int ndims = tv->domain()->noReductions().size();
for (int axis : axes) {
@@ -963,7 +1024,6 @@ TensorView* reductionOp(
for (auto axis : uint_axes) {
is_broadcast.at(axis) = true;
}

out = broadcast(out, is_broadcast);
}
return out;
@@ -1081,12 +1141,15 @@ TensorView* expand(TensorView* inp, const std::vector<Val*>& expanded_sizes) {
auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());

TORCH_CHECK(
expanded_sizes.size() == inp_domain.size(),
"Invalid expand, number of sizes provided is expected to be ",
expanded_sizes.size() >= inp_domain.size(),
"Invalid expand, number of sizes provided is expected to be at least ",
inp_domain.size(),
" but received ",
expanded_sizes.size());

inp = maybe_broadcast_inner_to_rank(inp, expanded_sizes.size());
inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());

std::vector<Val*> maybe_expanded_sizes;
maybe_expanded_sizes.resize(inp_domain.size(), nullptr);

@@ -1154,12 +1217,15 @@ TensorView* expand_as(TensorView* inp, TensorView* other) {
TensorDomain::noReductions(other->getMaybeRFactorDomain());

TORCH_CHECK(
inp_domain.size() == other_domain.size(),
"Invalid expand_as, dimensions of inp don't match dimensions of other, expected other to be ",
inp_domain.size() <= other_domain.size(),
"Invalid expand_as, dimensions of inp is higher than dimensions of other, expected other to be at least ",
inp_domain.size(),
" but received ",
other_domain.size());

inp = maybe_broadcast_inner_to_rank(inp, other_domain.size());
inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());

std::vector<IterDomain*> out_domain;
std::vector<Val*> maybe_expanded_sizes;
bool expanded = false;
@@ -1447,6 +1513,11 @@ Val* where(Val* c, Val* v1, Val* v2) {
promote_type(v1->getDataType().value(), v2->getDataType().value());
auto out_vtype =
promote_type(v1->getValType().value(), v2->getValType().value());
// Even when v1 and v2 are scalar, the output is a tensor if the
// conditional input is a tensor.
if (c->getValType() == ValType::TensorView) {
out_vtype = ValType::TensorView;
}
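// e.g. (illustrative) where(tv_pred, scalar_a, scalar_b) now produces a
// TensorView shaped like tv_pred rather than a scalar value.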
auto vals = maybeBroadcast({c, v1, v2});
Val* out = nullptr;
if (out_vtype == ValType::TensorView) {
15 changes: 8 additions & 7 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -256,13 +256,14 @@ TORCH_CUDA_CU_API TensorView* broadcast(
TensorView* inp,
const std::vector<bool>& is_broadcast_dim);

// Expands input based on provided sizes. expand_sizes should be the same size
// as the input's root domain (really rfactor), and should be -1 for any
// dimension that should remain a symbolic size. For dimensions that remain
// broadcast after the expand should be set to 1, any dimension being expanded
// must be marked as a braodcast in the input and will be expanded to the
// provided constant size. Any dimension that's symbolic in the input but
// specified as a non -1 value will be set to that constant value.
// Expands input based on provided sizes. expand_sizes may hold more entries
// than the input's root domain (really rfactor); the input is first padded to
// that rank by prepending broadcast dimensions. An entry of -1 leaves the
// corresponding dimension's symbolic size unchanged. Dimensions that should
// remain broadcast after the expand must be given an entry of 1. Any dimension
// being expanded must be marked as a broadcast in the input and will be
// expanded to the provided constant size. Any dimension that's symbolic in the
// input but specified with a non -1 value will be set to that constant value.
TORCH_CUDA_CU_API TensorView* expand(
TensorView* inp,
const std::vector<Val*>& expanded_sizes);
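For illustration, a minimal call under the relaxed rank rule; this is a sketch, not code from this commit, and makeSymbolicTensor (a fusion test helper) plus the Int literals are assumptions standing in for a real caller's sizes:

// Sketch: rank-2 input expanded to rank 3 under the new padding behavior.
TensorView* tv = makeSymbolicTensor(2);   // root domain [i0, i1]
TensorView* out = expand(
    tv,
    {IrBuilder::create<Int>(4),    // prepended broadcast dim, expanded to 4
     IrBuilder::create<Int>(-1),   // keep i0 symbolic
     IrBuilder::create<Int>(-1)}); // keep i1 symbolic
// out domain: [4 (expanded broadcast), i0, i1]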
