Skip to content

Commit

Permalink
Add backend checks for batch norm
Browse files Browse the repository at this point in the history
  • Loading branch information
vishwakftw committed Jan 11, 2019
1 parent 07ea3e0 commit 29d45de
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorUtils.cpp
Expand Up @@ -196,7 +196,7 @@ void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {

void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) {
AT_CHECK(
t.type().backend() == backend,
!t.defined() || t.type().backend() == backend,
"Expected tensor to have ", toString(backend),
" Backend, but got tensor with ", toString(t.type().backend()), " Backend ",
"(while checking arguments for ", c, ")");
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/Normalization.cpp
Expand Up @@ -461,6 +461,8 @@ std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias,
const Tensor& running_mean, const Tensor& running_var,
bool train, double momentum, double eps) {
checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU);

return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
if (!train) {
return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cuda/Normalization.cuh
Expand Up @@ -395,6 +395,14 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda_template(const Tensor& input_
const Tensor& running_mean_, const Tensor& running_var_,
bool train, double momentum, double epsilon) {

TensorArg input_arg{ input_, "input", 1 },
weight_arg{ weight_, "weight", 2 },
bias_arg{ bias_, "bias", 3 },
run_mean_arg{ running_mean_, "running_mean", 4 },
run_var_arg{ running_var_, "running_var", 5 };
CheckedFrom c = "batch_norm_cuda";
checkAllSameGPU(c, {input_arg, weight_arg, bias_arg, run_mean_arg, run_var_arg});

using accscalar_t = at::acc_type<scalar_t, true>;
int64_t n_input = input_.size(1);
Tensor save_mean_;
Expand Down

0 comments on commit 29d45de

Please sign in to comment.