1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -637,6 +637,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
"torch/csrc/jit/codegen/cuda/type.cpp",
"torch/csrc/jit/codegen/cuda/utils.cpp",
"torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp",
"torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
"torch/csrc/jit/runtime/register_cuda_ops.cpp",
]
107 changes: 11 additions & 96 deletions torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp
@@ -15,105 +15,20 @@
namespace torch {
namespace jit {

namespace {
void fuseFrozenConvAddReluImpl(std::shared_ptr<Graph>& graph) {
#ifdef USE_CUDA
#if AT_CUDNN_ENABLED()
SubgraphRewriter rewriter;

// CUDNN does not support conv1d
std::array<std::string, 2> conv_operators = {"conv2d", "conv3d"};
std::array<std::string, 2> add_operators = {"add", "add_"};
std::array<std::string, 2> relu_operators = {"relu", "relu_"};

auto conv_relu_rstring = CodeTemplate(R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
%res = aten::${relu}(%x)
return (%res))");

std::string conv_relu_fused = R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%res = aten::cudnn_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
return (%res))";

auto conv_add_relu_rstring = CodeTemplate(R"(
graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
%y = aten::${add}(%x, %z, %alpha)
%res = aten::${relu}(%y)
return (%res))");

std::string conv_add_relu_fused = R"(
graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%res = aten::cudnn_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
return (%res))";

for (const auto& conv : conv_operators) {
for (const auto& relu : relu_operators) {
TemplateEnv env;
env.s("conv", conv);
env.s("relu", relu);
rewriter.RegisterRewritePattern(
conv_relu_rstring.format(env), conv_relu_fused);
for (const auto& add : add_operators) {
env.s("add", add);
rewriter.RegisterRewritePattern(
conv_add_relu_rstring.format(env), conv_add_relu_fused);
}
}
}

auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
auto weight = toIValue(match.values_map.at(vmap.at("weight")));
if (!weight.has_value() || !weight.value().isTensor()) {
return false;
}
const at::Tensor& weight_t = weight.value().toTensor();
if (!weight_t.device().is_cuda() || !weight_t.is_contiguous()) {
return false;
}

// bias is optional
if (vmap.find("bias") != vmap.end()) {
auto bias = toIValue(match.values_map.at(vmap.at("bias")));
if (bias.has_value() && bias.value().isTensor()) {
const at::Tensor& bias_t = bias.value().toTensor();
if (bias_t.dtype() != weight_t.dtype() || bias_t.ndimension() != 1 ||
bias_t.size(0) != weight_t.size(0) || !bias_t.device().is_cuda()) {
return false;
}
}
}

// z is optional
if (vmap.find("z") != vmap.end()) {
auto z = toIValue(match.values_map.at(vmap.at("z")));
if (z.has_value() && z.value().isTensor()) {
const at::Tensor& z_t = z.value().toTensor();
if (z_t.dtype() != weight_t.dtype() ||
z_t.size(0) != weight_t.size(0) || !z_t.is_contiguous() ||
!z_t.device().is_cuda()) {
return false;
}
}
}
return true;
};

// Convert _convolution and in-place operators for simpler replacement pattern
// matching
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

rewriter.runOnGraph(graph, filter);
#endif
#endif
std::function<void(std::shared_ptr<Graph>&)>& getFuseFrozenConvAddReluImpl() {
static std::function<void(std::shared_ptr<Graph>&)> impl;
return impl;
}
} // namespace

// The CUDA implementation lives in frozen_conv_add_relu_fusion_cuda.cpp; at
// runtime it registers itself through getFuseFrozenConvAddReluImpl(). This
// allows the GPU code to be built separately from CPU-only code.
void FuseFrozenConvAddRelu(std::shared_ptr<Graph>& graph) {
fuseFrozenConvAddReluImpl(graph);
if (getFuseFrozenConvAddReluImpl()) {
getFuseFrozenConvAddReluImpl()(graph);
} else {
TORCH_WARN("No CUDA implementation of fuseFrozenConvAddReluImpl was registered; FuseFrozenConvAddRelu is a no-op");
}
}

} // namespace jit
3 changes: 3 additions & 0 deletions torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h
@@ -6,6 +6,9 @@
namespace torch {
namespace jit {

TORCH_API extern std::function<void(std::shared_ptr<Graph>&)>&
getFuseFrozenConvAddReluImpl();

TORCH_API void FuseFrozenConvAddRelu(std::shared_ptr<Graph>& graph);

} // namespace jit
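For reference, a minimal sketch of how a caller might drive this API on a frozen module, assuming a scripted model saved at model.pt (hypothetical path) and the TorchScript C++ headers shown below; illustrative only:

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/serialization/import.h>

int main() {
  // Load a scripted model and freeze it (hypothetical file name).
  torch::jit::Module module = torch::jit::load("model.pt");
  module.eval();
  torch::jit::Module frozen = torch::jit::freeze_module(module);

  // Run the pass on the forward graph. On a CUDA + cuDNN build, constant-weight
  // conv2d/conv3d -> add -> relu chains are rewritten to the fused cuDNN ops;
  // in a CPU-only build no implementation is registered and a warning is emitted.
  std::shared_ptr<torch::jit::Graph> graph = frozen.get_method("forward").graph();
  torch::jit::FuseFrozenConvAddRelu(graph);
  return 0;
}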
Expand Down
118 changes: 118 additions & 0 deletions torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#include <ATen/Utils.h>

#include <ATen/cuda/CUDAConfig.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

namespace {
void fuseFrozenConvAddReluImpl(std::shared_ptr<Graph>& graph) {
#if AT_CUDNN_ENABLED()
SubgraphRewriter rewriter;

// CUDNN does not support conv1d
std::array<std::string, 2> conv_operators = {"conv2d", "conv3d"};
std::array<std::string, 2> add_operators = {"add", "add_"};
std::array<std::string, 2> relu_operators = {"relu", "relu_"};

auto conv_relu_rstring = CodeTemplate(R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
%res = aten::${relu}(%x)
return (%res))");

std::string conv_relu_fused = R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%res = aten::cudnn_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
return (%res))";

auto conv_add_relu_rstring = CodeTemplate(R"(
graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
%y = aten::${add}(%x, %z, %alpha)
%res = aten::${relu}(%y)
return (%res))");

std::string conv_add_relu_fused = R"(
graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%res = aten::cudnn_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
return (%res))";

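// Register one rewrite per (conv, relu) pair and one per (conv, add, relu)
// combination so that both the functional and in-place variants are matched.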
for (const auto& conv : conv_operators) {
for (const auto& relu : relu_operators) {
TemplateEnv env;
env.s("conv", conv);
env.s("relu", relu);
rewriter.RegisterRewritePattern(
conv_relu_rstring.format(env), conv_relu_fused);
for (const auto& add : add_operators) {
env.s("add", add);
rewriter.RegisterRewritePattern(
conv_add_relu_rstring.format(env), conv_add_relu_fused);
}
}
}

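// Only fuse when the inputs look compatible with the fused cuDNN ops: weight
// must be a constant, contiguous CUDA tensor; the optional bias must be a 1-D
// CUDA tensor matching weight's dtype and weight.size(0); the optional z must
// be a contiguous CUDA tensor matching weight's dtype and weight.size(0).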
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
auto weight = toIValue(match.values_map.at(vmap.at("weight")));
if (!weight.has_value() || !weight.value().isTensor()) {
return false;
}
const at::Tensor& weight_t = weight.value().toTensor();
if (!weight_t.device().is_cuda() || !weight_t.is_contiguous()) {
return false;
}

// bias is optional
if (vmap.find("bias") != vmap.end()) {
auto bias = toIValue(match.values_map.at(vmap.at("bias")));
if (bias.has_value() && bias.value().isTensor()) {
const at::Tensor& bias_t = bias.value().toTensor();
if (bias_t.dtype() != weight_t.dtype() || bias_t.ndimension() != 1 ||
bias_t.size(0) != weight_t.size(0) || !bias_t.device().is_cuda()) {
return false;
}
}
}

// z is optional
if (vmap.find("z") != vmap.end()) {
auto z = toIValue(match.values_map.at(vmap.at("z")));
if (z.has_value() && z.value().isTensor()) {
const at::Tensor& z_t = z.value().toTensor();
if (z_t.dtype() != weight_t.dtype() ||
z_t.size(0) != weight_t.size(0) || !z_t.is_contiguous() ||
!z_t.device().is_cuda()) {
return false;
}
}
}
return true;
};

// Convert _convolution and in-place operators for simpler replacement pattern
// matching
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

rewriter.runOnGraph(graph, filter);
#endif
}

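// Immediately-invoked lambda: at static-initialization time, store the CUDA
// implementation in the function-local static returned by
// getFuseFrozenConvAddReluImpl(), so FuseFrozenConvAddRelu picks it up
// whenever this translation unit is linked in.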
auto dummyInitializer = []() {
getFuseFrozenConvAddReluImpl() = fuseFrozenConvAddReluImpl;
return true;
}();

} // namespace

} // namespace jit
} // namespace torch
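The split above follows a common registration pattern: the CPU-only core exposes a mutable std::function slot, and the optionally linked CUDA object fills that slot from a static initializer. A standalone sketch of the pattern (illustrative only; the names are invented, this is not PyTorch code):

#include <functional>
#include <iostream>
#include <memory>

struct Graph {}; // stand-in for torch::jit::Graph

// "core" translation unit: always built, knows nothing about CUDA.
std::function<void(std::shared_ptr<Graph>&)>& getImpl() {
  static std::function<void(std::shared_ptr<Graph>&)> impl; // empty by default
  return impl;
}

void RunPass(std::shared_ptr<Graph>& graph) {
  if (getImpl()) {
    getImpl()(graph); // dispatch to whatever was registered
  } else {
    std::cerr << "pass unavailable in this build\n";
  }
}

// "gpu" translation unit: linked only into GPU builds.
namespace {
void gpuImpl(std::shared_ptr<Graph>& /*graph*/) {
  std::cout << "running the fused rewrite\n";
}
// Static initializer registers the implementation as a side effect of linking.
const bool registered = []() {
  getImpl() = gpuImpl;
  return true;
}();
} // namespace

int main() {
  auto graph = std::make_shared<Graph>();
  RunPass(graph); // prints the fused-rewrite message because the "gpu" code is linked in
  return 0;
}

One caveat of this approach: when the registering object file sits in a static library, the linker may drop it unless it is pulled in explicitly (e.g. via whole-archive or, as here, by listing the file in libtorch_cuda_core_sources so it is always compiled into the CUDA library).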