[NNC] Build aggregate stmt for kernel before LoopNest. #53024
Changes from all commits
```diff
@@ -1531,13 +1531,12 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
   }
 }
 
-Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
-  // Find all tensors we need to compute (including dependencies) and put them
-  // in a topological order
-  std::vector<Tensor*> tensors_to_compute =
-      findAllNeededTensors(tensorOutputs_);
-
-  torch::jit::tensorexpr::LoopNest l(tensorOutputs_, tensors_to_compute);
+Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) {
+  std::unordered_set<const Buf*> output_bufs;
+  for (auto t : tensorOutputs_) {
+    output_bufs.insert(t->buf());
+  }
+  torch::jit::tensorexpr::LoopNest l(st, output_bufs);
+  GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
 
   bool hasReduction = NodeFinder<ReduceOp>::find(l.root_stmt()).size() != 0;
```
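The shape of the new contract is worth spelling out: `LoopNest` no longer walks the output `Tensor`s to build the root statement itself; the caller supplies a pre-built `Stmt` plus the set of output `Buf`s. A minimal sketch, using the names from this diff and assuming the `tensor_stmts` vector built in `compile()` below is in scope:

```cpp
// Sketch of the new contract: the caller builds the root Stmt (an aggregate
// Block, assembled in compile() below) and tells LoopNest which Bufs must
// survive as kernel outputs.
Stmt* root = new Block(tensor_stmts);

std::unordered_set<const Buf*> output_bufs;
for (auto t : tensorOutputs_) {
  output_bufs.insert(t->buf()); // the Buf identifies the underlying buffer
}
torch::jit::tensorexpr::LoopNest l(root, output_bufs);
```

Identifying outputs by `Buf` rather than by `Tensor` is also what makes the cleanup suggested in the review thread below possible.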
```diff
@@ -1929,22 +1928,36 @@ Tensor* TensorExprKernel::computeSoftmax(
       },
       {output_dims[softmax_dim]});
   if (!log_softmax) {
-    return Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
-      return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
-    });
+    auto result =
+        Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
+          return e->call(indices) /
+              sum->call(remove_softmax_dim_index(indices));
+        });
+    return new Tensor(
+        result->buf(),
+        new Block({max->stmt(), e->stmt(), sum->stmt(), result->stmt()}));
   }
 
   auto log_sum = Compute(
       "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
         return log(sum->call(indices));
       });
-  return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
-    auto inp = tensorOrConstant(
-        v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
-    auto non_softmax_indices = remove_softmax_dim_index(indices);
-    return inp - max->call(non_softmax_indices) -
-        log_sum->call(non_softmax_indices);
-  });
+  auto result =
+      Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
+        auto inp = tensorOrConstant(
+            v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
+        auto non_softmax_indices = remove_softmax_dim_index(indices);
+        return inp - max->call(non_softmax_indices) -
+            log_sum->call(non_softmax_indices);
+      });
+  return new Tensor(
+      result->buf(),
+      new Block(
+          {max->stmt(),
+           e->stmt(),
+           sum->stmt(),
+           log_sum->stmt(),
+           result->stmt()}));
 }
 
 TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo(
```
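Both branches above repeat one pattern: re-wrap the final `Compute` result in a `Tensor` whose statement is a `Block` carrying every intermediate stmt, so that dependencies such as `max`, `e`, and `sum` travel with the returned tensor instead of having to be rediscovered later. A hypothetical helper capturing the pattern (the helper name is invented; the types are the ones already used in this file):

```cpp
// Hypothetical refactor of the aggregate-Tensor pattern used above: bundle
// the final tensor's buffer with the stmts of all of its intermediates.
Tensor* makeAggregateTensor(
    Tensor* result,
    const std::vector<Tensor*>& intermediates) {
  std::vector<Stmt*> stmts;
  for (auto* t : intermediates) {
    stmts.push_back(t->stmt()); // e.g. max, e, sum (plus log_sum for log_softmax)
  }
  stmts.push_back(result->stmt());
  return new Tensor(result->buf(), new Block(stmts));
}
```

With such a helper, the softmax branch would reduce to `return makeAggregateTensor(result, {max, e, sum});`.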
```diff
@@ -2102,11 +2115,18 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
 void TensorExprKernel::compile() {
   KernelScope kernelScope(&kernelArena_);
   GRAPH_DUMP("TensorExprKernel graph:", graph_);
 
+  // Vector to collect the Stmts corresponding to all tensors.
+  std::vector<Stmt*> tensor_stmts;
+
   // Bind inputs to buffers.
   nInputs_ = graph_->inputs().size();
   for (auto const& input : graph_->inputs()) {
     bindInput(input);
     inputTypes_.push_back(input->type());
+    if (input->type()->kind() == TypeKind::TensorType) {
+      tensor_stmts.push_back(Stmt::clone(tensors_.at(input->unique())->stmt()));
+    }
   }
 
   // Bind nodes to tensor compute expressions.
```
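A likely reason for the `Stmt::clone` calls (my reading; the diff itself does not say): an NNC `Stmt` records a single parent, so the same node cannot hang both under its `Tensor` and under the new kernel-level `Block`. A sketch of the collection pattern under that assumption:

```cpp
// Sketch: why Stmt::clone before aggregation. Each produced Tensor keeps
// ownership of its own stmt, while the kernel-level Block owns independent
// copies, so neither tree mutates the other.
Stmt* buildKernelRoot(const std::vector<Tensor*>& produced_tensors) {
  std::vector<Stmt*> tensor_stmts;
  for (Tensor* t : produced_tensors) {
    tensor_stmts.push_back(Stmt::clone(t->stmt()));
  }
  return new Block(tensor_stmts); // aggregate stmt for the whole kernel
}
```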
```diff
@@ -2117,6 +2137,8 @@ void TensorExprKernel::compile() {
     for (auto const& output : n->outputs()) {
       if (output->hasUses()) {
         tensors_.emplace(output->unique(), computeValue(output));
+        tensor_stmts.push_back(
+            Stmt::clone(tensors_.at(output->unique())->stmt()));
       }
     }
   }
```
```diff
@@ -2137,6 +2159,9 @@ void TensorExprKernel::compile() {
     // since NNC views it as contiguous. Only convert it to the right
     // strides at the end of the kernel (if already contiguous it's a no-op)
     Tensor* properly_strided_output = convertOutputToCorrectStrides(output);
+    if (tensors_.at(output->unique()) != properly_strided_output) {
+      tensor_stmts.push_back(Stmt::clone(properly_strided_output->stmt()));
+    }
     tensors_[output->unique()] = properly_strided_output;
     const auto& tt = output->type()->expect<TensorType>();
     auto sizes = *tt->sizes().concrete_sizes();
```
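The pointer comparison above is doing real work (again my reading): `convertOutputToCorrectStrides` returns the original tensor unchanged when the output is already contiguous, and that tensor's stmt was already appended while lowering the graph nodes, so appending it again would duplicate the computation in the aggregate block. Condensed, with that intent spelled out in comments:

```cpp
Tensor* properly_strided_output = convertOutputToCorrectStrides(output);
// Append only if a distinct re-striding tensor was created; for contiguous
// outputs the original tensor's stmt is already in tensor_stmts.
if (tensors_.at(output->unique()) != properly_strided_output) {
  tensor_stmts.push_back(Stmt::clone(properly_strided_output->stmt()));
}
tensors_[output->unique()] = properly_strided_output;
```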
```diff
@@ -2159,7 +2184,7 @@ void TensorExprKernel::compile() {
   }
 
   BackendType backendType = inferBackendTypeFromDevice(device_);
-  Stmt* stmt = generateStmt(backendType);
+  Stmt* stmt = transformLoops(backendType, new Block(tensor_stmts));
   // Set up formal params (inputs, then outputs) for kernel.
   std::vector<CodeGen::BufferArg> params = prepareBufferArgs();
```
Reviewer: Could we replace `tensorOutputs_` with `bufOutputs_` in `TensorExprKernel` and collect `Buf`s instead of `Tensor`s in the first place? That would get rid of this loop.
Author: Sure. I will do this cleanup as a follow-up.
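For reference, the suggested cleanup might look roughly like this; a sketch only, with `bufOutputs_` being the reviewer's proposed name and its population site assumed:

```cpp
// Hypothetical follow-up sketched from the review thread: the kernel keeps a
// set of output Bufs, so transformLoops needs no Tensor-to-Buf conversion.
class TensorExprKernel {
  // ...
  std::unordered_set<const Buf*> bufOutputs_; // filled wherever an output
                                              // tensor is registered, e.g.
                                              // bufOutputs_.insert(t->buf());
};

Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) {
  torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
  // ... remaining loop transformations unchanged ...
}
```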