61 changes: 43 additions & 18 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1531,13 +1531,12 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
}
}

Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
// Find all tensors we need to compute (including dependencies) and put them
// in a topological order
std::vector<Tensor*> tensors_to_compute =
findAllNeededTensors(tensorOutputs_);

torch::jit::tensorexpr::LoopNest l(tensorOutputs_, tensors_to_compute);
Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) {
std::unordered_set<const Buf*> output_bufs;
for (auto t : tensorOutputs_) {

Reviewer comment:
Could we replace tensorOutputs_ with bufOutputs_ in TensorExprKernel and collect Bufs instead of Tensors in the first place? That would get rid of this loop. (See the sketch after this hunk.)

Author reply:
Sure. I will do this cleanup as a follow-up.

output_bufs.insert(t->buf());
}
torch::jit::tensorexpr::LoopNest l(st, output_bufs);
GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");

bool hasReduction = NodeFinder<ReduceOp>::find(l.root_stmt()).size() != 0;
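
A minimal sketch of the cleanup suggested in the comment above, assuming a hypothetical bufOutputs_ member (not part of this PR) that is filled with each output's Buf wherever tensorOutputs_ is populated today:

    // Hypothetical replacement for std::vector<Tensor*> tensorOutputs_ in TensorExprKernel:
    //   std::unordered_set<const Buf*> bufOutputs_;
    //
    // Populated once per kernel output, e.g.:
    //   bufOutputs_.insert(tensors_.at(output->unique())->buf());

    Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) {
      // No per-call conversion loop needed; the output Bufs are already collected.
      torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
      GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
      // ... rest of the transformations unchanged ...
    }
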
@@ -1929,22 +1928,36 @@ Tensor* TensorExprKernel::computeSoftmax(
},
{output_dims[softmax_dim]});
if (!log_softmax) {
return Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
});
auto result =
Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
return e->call(indices) /
sum->call(remove_softmax_dim_index(indices));
});
return new Tensor(
result->buf(),
new Block({max->stmt(), e->stmt(), sum->stmt(), result->stmt()}));
}

auto log_sum = Compute(
"aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
return log(sum->call(indices));
});
return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
auto inp = tensorOrConstant(
v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
auto non_softmax_indices = remove_softmax_dim_index(indices);
return inp - max->call(non_softmax_indices) -
log_sum->call(non_softmax_indices);
});
auto result =
Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
auto inp = tensorOrConstant(
v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
auto non_softmax_indices = remove_softmax_dim_index(indices);
return inp - max->call(non_softmax_indices) -
log_sum->call(non_softmax_indices);
});
return new Tensor(
result->buf(),
new Block(
{max->stmt(),
e->stmt(),
sum->stmt(),
log_sum->stmt(),
result->stmt()}));
}

TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo(
Expand Down Expand Up @@ -2102,11 +2115,18 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
void TensorExprKernel::compile() {
KernelScope kernelScope(&kernelArena_);
GRAPH_DUMP("TensorExprKernel graph:", graph_);

// Vector to collect the Stmts corresponding to all tensors.
std::vector<Stmt*> tensor_stmts;

// Bind inputs to buffers.
nInputs_ = graph_->inputs().size();
for (auto const& input : graph_->inputs()) {
bindInput(input);
inputTypes_.push_back(input->type());
if (input->type()->kind() == TypeKind::TensorType) {
tensor_stmts.push_back(Stmt::clone(tensors_.at(input->unique())->stmt()));
}
}

// Bind nodes to tensor compute expressions.
@@ -2117,6 +2137,8 @@ void TensorExprKernel::compile() {
for (auto const& output : n->outputs()) {
if (output->hasUses()) {
tensors_.emplace(output->unique(), computeValue(output));
tensor_stmts.push_back(
Stmt::clone(tensors_.at(output->unique())->stmt()));
}
}
}
@@ -2137,6 +2159,9 @@ void TensorExprKernel::compile() {
// since NNC views it as contiguous. Only convert it to the right
// strides at the end of the kernel (if already contiguous it's a no-op)
Tensor* properly_strided_output = convertOutputToCorrectStrides(output);
if (tensors_.at(output->unique()) != properly_strided_output) {
tensor_stmts.push_back(Stmt::clone(properly_strided_output->stmt()));

Reviewer comment:
Instead of having a vector of Stmts and cloning tensor stmts into it, why can't we have a Block and append tensor stmts to it immediately, without cloning anything? (See the sketch after this hunk.)

Author reply:
I ran into some issues when I tried without cloning the stmts, but I wasn't adding them to a Block then. Let me try this cleanup again in a follow-up PR.

}
tensors_[output->unique()] = properly_strided_output;
const auto& tt = output->type()->expect<TensorType>();
auto sizes = *tt->sizes().concrete_sizes();
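
A rough sketch of the Block-based alternative discussed above (a possible follow-up, not part of this PR). It assumes Block::append_stmt is usable here and that the issues the author ran into can be resolved so the clones become unnecessary:

    // In compile(): build the kernel body as a Block up front instead of
    // collecting clones into std::vector<Stmt*> tensor_stmts.
    Block* body = new Block({});

    // When binding a tensor input:
    //   body->append_stmt(tensors_.at(input->unique())->stmt());
    // When computing a node's output:
    //   body->append_stmt(tensors_.at(output->unique())->stmt());
    // When an output needs re-striding:
    //   body->append_stmt(properly_strided_output->stmt());

    // Finally, hand the Block itself to the loop transformations:
    Stmt* stmt = transformLoops(backendType, body);
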
@@ -2159,7 +2184,7 @@ void TensorExprKernel::compile() {
}

BackendType backendType = inferBackendTypeFromDevice(device_);
Stmt* stmt = generateStmt(backendType);
Stmt* stmt = transformLoops(backendType, new Block(tensor_stmts));
// Set up formal params (inputs, then outputs) for kernel.
std::vector<CodeGen::BufferArg> params = prepareBufferArgs();

2 changes: 1 addition & 1 deletion torch/csrc/jit/tensorexpr/kernel.h
@@ -141,7 +141,7 @@ class TORCH_API TensorExprKernel {

Tensor* computeValue(const torch::jit::Value* v);

Stmt* generateStmt(BackendType backendType);
Stmt* transformLoops(BackendType backendType, Stmt* st);
std::vector<CodeGen::BufferArg> prepareBufferArgs();

std::string getCodeGenName(BackendType backendType);