Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions core/conversion/converters/impl/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,28 @@ void _batch_norm(
const torch::Tensor& mean,
const torch::Tensor& var,
const float eps) {
auto scale = gamma / torch::sqrt(var + eps);
auto bias = beta - mean * scale;
auto orig_dtype = var.dtype();
// perform compile-time weight calculations in float to improve accuracy
// resulting weights will be embedded as the original dtype
auto calculation_gamma = gamma;
auto calculation_beta = beta;
auto calculation_mean = mean;
auto calculation_var = var;
if (orig_dtype == torch::kHalf) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question is this different than the normal pytorch behavior? If so can we add a debug message here saying that we are doing this to improve accuracy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the cudnn implementation at least asserts that the weight is fp32 which would force similar calculations to fp32:
https://github.com/pytorch/pytorch/blob/4bfe2a24505049fa4fe43d24c2e3a5f5d99d9f00/aten/src/ATen/native/cudnn/BatchNorm.cpp#L110

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

calculation_gamma = calculation_gamma.to(torch::kFloat);
calculation_beta = calculation_beta.to(torch::kFloat);
calculation_mean = calculation_mean.to(torch::kFloat);
calculation_var = calculation_var.to(torch::kFloat);
}
auto scale = calculation_gamma / torch::sqrt(calculation_var + eps);
auto bias = calculation_beta - calculation_mean * scale;
LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());

auto scale_weights = Weights(ctx, scale);
auto bias_weights = Weights(ctx, bias);
auto scale_weights = Weights(ctx, scale.to(orig_dtype));
auto bias_weights = Weights(ctx, bias.to(orig_dtype));

auto power = Weights(ctx, at::ones_like(scale));
auto power = Weights(ctx, at::ones_like(scale).to(orig_dtype));
auto bn =
ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
bn->setName(util::node_info(n).c_str());
Expand Down
30 changes: 30 additions & 0 deletions tests/core/conversion/converters/test_batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,33 @@ TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

// Verifies that aten::batch_norm with fp16 (Half) frozen stats converts
// correctly: the converter does its compile-time scale/bias arithmetic in
// fp32 and embeds the results back as half, so the TensorRT engine output
// should match the TorchScript reference within fp16 tolerance.
TEST(Converters, ATenBatchNormHalfConvertsCorrectly) {
  // NOTE(review): the IR feeds %running_var as the weight and %running_mean
  // as the bias of aten::batch_norm(input, weight, bias, mean, var, ...) —
  // presumably intentional, to exercise half-precision parameters without
  // extra graph inputs; confirm against the converter tests' conventions.
  const auto graph = R"IR(
graph(%input : Tensor, %running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0), %running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0)):
%5 : bool = prim::Constant[value=0]()
%4 : float = prim::Constant[value=0.01]()
%3 : float = prim::Constant[value=0.001]()
%2 : bool = prim::Constant[value=1]()
%8 : Tensor = aten::batch_norm(%input, %running_var, %running_mean, %running_mean, %running_var, %5, %4, %3, %2)
return (%8))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // Half-precision activation plus per-channel (32) stats on CUDA.
  auto jit_in = at::randn({2, 32, 5, 5}, {at::kCUDA}).to(at::kHalf);
  auto jit_mean = at::ones({32}, {at::kCUDA}).to(at::kHalf);
  auto jit_var = at::zeros({32}, {at::kCUDA}).to(at::kHalf);

  // Independent clones so the TRT run cannot observe mutations from the
  // JIT reference run (and vice versa).
  auto trt_in = at::clone(jit_in);
  auto trt_mean = at::clone(jit_mean);
  auto trt_var = at::clone(jit_var);

  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {jit_mean, jit_var});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_in});

  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {trt_mean, trt_var});
  auto trt_results =
      torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_in}, {nvinfer1::DataType::kHALF});

  // Wider tolerance (2e-2) than the fp32 batch-norm tests above (2e-6),
  // consistent with half-precision engine execution.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-2));
}