30 changes: 30 additions & 0 deletions test/cpp/tensorexpr/test_kernel.cpp
@@ -206,6 +206,36 @@ TEST_F(Kernel, _3) {
}
}

TEST_F(Kernel, ParallelStrided) {
KernelScope kernel_scope;

const auto graph_string = R"IR(
graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
%1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
%2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
%3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
return (%3))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
.index(
{Slice(None, None, 2),
Slice(None, None, 2),
Slice(None, None, 2)});
auto ref = a * (a * b);
auto o = at::zeros_like(ref);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b};
std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
o = stack[0].toTensor();
for (size_t i = 0; i < 5 * 3; i++) {
CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
}
}

TEST_F(Kernel, DISABLED_Shape_Inference) {
// disabled: doesn't do stride propagation, and isn't being used currently

87 changes: 87 additions & 0 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -2,6 +2,7 @@
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/TensorGeometry.h>
#include <c10/util/irange.h>
#include <c10/util/string_utils.h>
@@ -2487,6 +2488,86 @@ void fuseAllLoops(StmtPtr st) {
}
}

// Compute the trip count of a loop if it is a constant.
c10::optional<int64_t> tripCount(ForPtr loop) {
auto tc = IRSimplifier::simplify(
cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
if (auto val = to<LongImm>(tc.node())) {
return val->value();
}
return c10::nullopt;
}

// Prune innermost loops until the accumulated iteration count satisfies a
// minimum grain size.
static void pruneByGrainSize(std::vector<ForPtr>& loops) {
constexpr int64_t minGrainSize = 32768;
int64_t grainSize = 1;
for (int64_t i = loops.size(); i > 0; i--) {
auto tc = tripCount(loops[i - 1]);
if (!tc) {
break;
}
grainSize *= *tc;
if (grainSize < minGrainSize) {
loops.pop_back();
}
}
}

// Retain enough outermost loops to fill the number of threads.
static void pruneByThreadCount(std::vector<ForPtr>& loops) {
int64_t trips = 1;
auto threads = at::get_num_threads();
auto it = loops.begin();
for (; it != loops.end(); it++) {
if (trips >= threads) {
break;
}
auto tc = tripCount(*it);
if (!tc) {
break;
}
trips *= *tc;
}
loops.erase(it, loops.end());
}

// Flatten and parallelize outer loops, subject to a minimum number of elements
// in the inner loop, and a maximum level of thread-level parallelism in the
// outer loops.
template <typename Bufs>
static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
for (auto const& buf : bufs) {
auto loops = l.getLoopStmtsFor(buf);
Contributor:
Since this function is called after fuseAllLoops, it is possible that multiple buffers belong to the same loopnest. So we could be repeating this loop multiple times for the same loopnest. I understand that may not be incorrect at this point, but it could lead to bugs in the future.

IMO, we shouldn't be looking at output buffers and their loopnests. Instead, we should just take the root_stmt in the given LoopNest and apply parallelization to all loopnests in that stmt. Wdyt?

Contributor Author:
That seems OK to me, sure.

Contributor Author:
I guess where things could get a little weird is if multiple buffers are updated at different levels of the loopnest, e.g.:

for i:
  y1[] = ...
  for j:
    y2[] = ...

The current approach sort of gives each buffer an "independent" chance to affect the loop parallelization. That doesn't seem terrible, tbh.
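
For concreteness, a sketch of what each buffer's pass would operate on in that example, assuming getLoopStmtsFor returns the loops enclosing the buffer's store, outermost first:

  getLoopStmtsFor(y1) -> {for i}
  getLoopStmtsFor(y2) -> {for i, for j}

So the i loop is visited once per buffer, and each pass independently decides whether to prune, flatten, and parallelize it.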

Contributor Author:
Actually the more I think about it, is there really any advantage to starting with the root stmt and working down? From my POV it just makes the code a lot more complicated; the way things happen now I just get a nice vector of loops leading to a buffer and try to flatten them. If it's not flattenable, it simply fails and I give up.

Although, maybe it's not too much work to build up my own vector starting from the root. Idk.

Contributor:
If you are okay with having the same set of loops being handled here for different bufs, then I have no objections to it.

Personally, I felt starting from the root_stmt might be better. We might need another API to extract all loops in the root_stmt. So maybe we can do this in the future.

Contributor:
If I understand correctly, going through all buffers would not miss parallelism opportunities like the following

for i 
  for j1
      y1 = ... [data dependence exists between iterations]
  for j2
     y2 = ... [no data dependence between iterations]

i+j1 cannot be parallelized because there's data dependence between iterations for y1; but i+j2 can be parallelized and we should not miss it. If this is the thing we try to do here, I guess we need a distribute transformation before flatten. flatten currently only handles perfectly nested loops.

Contributor Author:
Yeah the approach here will definitely miss opportunities where some nested loops are parallelizable. It's really kind of a best-effort thing to get simple elementwise fusions right, not a general solution to parallelism.
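
A minimal sketch of the distribute-then-flatten idea raised above, not part of this change. It assumes LoopNest::distributeLoopOverInnerLoops and LoopNest::getLoopStmtsInLoopNest exist with roughly the shapes their names suggest, and outer_i stands for the shared outer loop in the example:

  // Hypothetical follow-up: split the shared outer loop so each inner loop ends
  // up in its own perfectly nested loopnest, then run the same
  // prune/flatten/parallelize steps on each nest independently.
  for (ForPtr root : LoopNest::distributeLoopOverInnerLoops(outer_i)) {
    auto nest = LoopNest::getLoopStmtsInLoopNest(root, 2); // e.g. {i, j1} or {i, j2}
    if (LoopNest::hasLoopCarriedDependence(root)) {
      continue; // nests like the y1 one above stay serial
    }
    ForPtr flattened = nullptr;
    LoopNest::flatten(nest, &flattened);
    if (flattened) {
      flattened->set_parallel(); // nests like the y2 one can go parallel
    }
  }

With something like this, a loop-carried dependence in one inner nest would no longer block parallelizing its siblings.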

pruneByGrainSize(loops);
pruneByThreadCount(loops);

// There are no loops to parallelize; give up.
if (loops.size() == 0) {
continue;
}
// The loop nest contains a reduction; give up.
auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
if (reductions.size() > 0) {
continue;
}
// The loop nest has loop carried dependences; give up.
if (LoopNest::hasLoopCarriedDependence(loops[0])) {
continue;
}
// Try to flatten the outer loops and parallelize them if successful.
ForPtr flattened = nullptr;
if (loops.size() == 1) {
flattened = loops[0];
} else {
LoopNest::flatten(loops, &flattened);
}
if (flattened) {
flattened->set_parallel();
}
}
}

StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
@@ -2528,6 +2609,8 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
if (backendType == kLLVMCodeGen) {
fuseAllLoops(l.root_stmt());
GRAPH_DEBUG("after fuse", *l.root_stmt());
parallelizeOuterLoops(l, bufOutputs_);
GRAPH_DEBUG("after parallelize", *l.root_stmt());
}

if (backendType == kCudaCodeGen) {
@@ -2602,9 +2685,13 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
}

l.prepareForCodegen();
GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
l.simplify();
GRAPH_DEBUG("after simplification", *l.root_stmt());

if (backendType == kLLVMCodeGen && !hasReduction) {
l.vectorizeInnerLoops();
GRAPH_DEBUG("after vectorization", *l.root_stmt());
}

StmtPtr stmt = l.root_stmt();
24 changes: 17 additions & 7 deletions torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -274,15 +274,24 @@ class LLVMCodeGenImpl : public IRVisitor {
}
};

extern "C" {
typedef void (*ParallelCallee)(int index, int8_t* packed_data);
void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
void DispatchParallel(
int8_t* func,
int start,
int stop,
int8_t* packed_data) noexcept {
// TODO: preserve the func type.
ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
for (int index = f_begin; index < f_end; index++) {
callee(index, packed_data);
}
});
try {
ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
for (int index = f_begin; index < f_end; index++) {
callee(index, packed_data);
}
});
} catch (...) {
Contributor:
Kinda curious about this place: why not terminate if there's an exception? I guess executing the remaining stmts would ultimately lead to wrong results?

Contributor Author:
Interesting point... if an exception happens here things are really screwed up, because we don't know how to unwind past llvm-generated frames. But no exceptions should be possible here, since we're just parallel-dispatching to our own kernel, which doesn't throw exceptions. So I was mainly putting the try-catch here to ensure that the compiler knew that it wouldn't need to unwind this frame.
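
A tiny standalone illustration of the contract being relied on here (not part of the change; demo_dispatch is a made-up name): an exception escaping a noexcept function calls std::terminate instead of unwinding, so the blanket catch(...) guarantees no unwind ever has to walk the LLVM-generated frames above this call.

  #include <stdexcept>

  // Hypothetical demo: the catch-all keeps exceptions from reaching the
  // noexcept boundary, so callers (including JIT-compiled code without
  // unwind tables) never observe stack unwinding.
  extern "C" void demo_dispatch(bool poke) noexcept {
    try {
      if (poke) {
        throw std::runtime_error("handled locally, never propagates");
      }
    } catch (...) {
      // Swallow: if this escaped, the noexcept boundary would terminate the
      // process rather than unwind.
    }
  }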

}
}
}

} // namespace tensorexpr
@@ -1287,6 +1296,7 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
llvm::Function* dispatcher =
llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
irb_.CreateCall(
dispatcher, {func_value, start, stop, packed_caller_args_ptr});
value_ = llvm::ConstantInt::get(IntTy_, 0);
8 changes: 7 additions & 1 deletion torch/csrc/jit/tensorexpr/llvm_jit.h
@@ -17,7 +17,13 @@ namespace torch {
namespace jit {
namespace tensorexpr {

void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data);
extern "C" {
void DispatchParallel(
int8_t* func,
int start,
int stop,
int8_t* packed_data) noexcept;
}

inline std::string formatError(llvm::Error&& err, const char* msg) {
static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT";
7 changes: 7 additions & 0 deletions torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -177,6 +177,13 @@ class Vectorizer : public IRMutator {
});
}

ExprPtr mutate(ModPtr v) override {
std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
});
}

ExprPtr mutate(AndPtr v) override {
std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {