61 changes: 61 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -335,6 +335,42 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntArrayRef dim, ScalarType
return at::native::sum_out(result, self, dim, false, dtype);
}

int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
int64_t size = 1;
if (sizes.size() == 0) {
return 1;
}
for (auto d : dim) {
d = at::maybe_wrap_dim(d, sizes.size());
size *= sizes[d];
}
return size;
}

Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
Tensor res = t;
for (size_t i = 0; i < n_dims; i++){
if (dims_to_unsqueeze[i]) {
res = res.unsqueeze(i);
}
}
return res;
}

Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
if (!keepdim && sizes.size() > 0) {
if (dims.size()==1) {
return grad.unsqueeze(dims[0]).expand(sizes);
} else {
Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
return res.expand(sizes);
}
} else {
return grad.expand(sizes);
}
}
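
The gradient of a sum reduction is just the upstream gradient broadcast back to the input shape; when keepdim is false, the reduced dimensions are first re-inserted as size-1 dims (via unsqueeze_multiple) so that expand is legal. A minimal Python sketch of the same logic, using plain torch ops for illustration:

import torch

def sum_backward_sketch(grad, sizes, dims, keepdim):
    # Mirror of the C++ above: restore reduced dims, then broadcast.
    if not keepdim and len(sizes) > 0:
        for d in sorted(d % len(sizes) for d in dims):
            grad = grad.unsqueeze(d)
    return grad.expand(sizes)

x = torch.randn(2, 3, 4, requires_grad=True)
x.sum(dim=(0, 2)).backward(torch.ones(3))
manual = sum_backward_sketch(torch.ones(3), x.shape, (0, 2), keepdim=False)
assert torch.equal(x.grad, manual)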

Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
return at::native::prod_out(
result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
@@ -416,6 +452,16 @@ Tensor logsumexp(const Tensor &self, IntArrayRef dims, bool keepdim) {
return at::native::logsumexp_out(result, self, dims, keepdim);
}

Tensor logsumexp_backward(const Tensor& grad, const Tensor & self, const Tensor& res, IntArrayRef dim, bool keepdim) {
Tensor grad_input = grad;
Tensor fwd_res = res;
if (!keepdim && self.dim() != 0) {
grad_input = unsqueeze_multiple(grad, dim, self.sizes().size());
fwd_res = unsqueeze_multiple(res, dim, self.sizes().size());
}
return grad_input * (self - fwd_res).exp();
}
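
Since d/dx logsumexp(x) = exp(x - logsumexp(x)), i.e. the softmax of x along dim, the backward is the upstream gradient times that softmax, with the reduced dim re-inserted when keepdim is false. A quick Python check of the same formula, for illustration:

import torch

x = torch.randn(2, 5, requires_grad=True)
out = x.logsumexp(dim=1)                       # keepdim=False
out.backward(torch.ones_like(out))

# grad_input * exp(self - res), with the reduced dim restored first.
manual = torch.ones(2, 1) * (x.detach() - out.detach().unsqueeze(1)).exp()
assert torch.allclose(x.grad, manual)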

static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto p = opt_p.value_or(2.0);
@@ -628,6 +674,21 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias
return std_var_out(result, self, dim, unbiased, keepdim, false);
}

Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
}

Tensor var_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
if (self.dim() == 0) {
return at::var_backward(grad, self, unbiased);
}
Tensor unsqueezed_grad = grad;
if (!keepdim && self.dim() > 1) {
unsqueezed_grad = unsqueeze_multiple(grad, dim, self.sizes().size());
}
return (2.0 / (at::_safe_size(self.sizes(), dim) - unbiased)) * unsqueezed_grad * (self - self.mean(dim, true));
}
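
Both overloads implement d/dx var(x) = 2 * (x - mean(x)) / (N - unbiased), where N is the number of elements reduced over (_safe_size for the dim overload) and the mean is kept broadcastable via keepdim=true. A small Python check of the per-dim formula, illustrative only:

import torch

x = torch.randn(3, 4, requires_grad=True)
x.var(dim=1).backward(torch.ones(3))           # unbiased by default

n = x.size(1)                                  # number of elements reduced over
manual = (2.0 / (n - 1)) * torch.ones(3, 1) * (x.detach() - x.detach().mean(dim=1, keepdim=True))
assert torch.allclose(x.grad, manual)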

Tensor std(const Tensor& self, bool unbiased) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
8 changes: 8 additions & 0 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -34,6 +34,14 @@ namespace at { namespace native {
DEFINE_DISPATCH(max_kernel);
DEFINE_DISPATCH(min_kernel);

Tensor index_select_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
Tensor res = at::zeros(sizes, grad.options());
if (!keepdim && sizes.size() > 0) {
return res.scatter_(dim, indices.unsqueeze(dim), grad.unsqueeze(dim));
}
return res.scatter_(dim, indices, grad);
}
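
Despite the name, this helper looks like the backward for the dim-wise max/min reductions defined in this file: it routes the upstream gradient back to the winning positions by scattering it into a zero tensor at the saved indices (unsqueezing first when keepdim was false). A Python sketch of that pattern, checked against autograd for max; an illustration, not the exact call path:

import torch

x = torch.randn(3, 4, requires_grad=True)
vals, idx = x.max(dim=1)                       # keepdim=False
vals.backward(torch.ones_like(vals))

# zeros of the input size, then scatter the (unsqueezed) grad at the indices.
manual = torch.zeros_like(x).scatter_(1, idx.unsqueeze(1), torch.ones(3, 1))
assert torch.equal(x.grad, manual)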

bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
}
44 changes: 44 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -384,6 +384,16 @@ Tensor permute(const Tensor& self, IntArrayRef dims) {
return self.as_strided(newSizes, newStrides);
}

Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
// invert the permutation
auto ndims = fwd_dims.size();
std::vector<int64_t> dims(ndims);
for (size_t i = 0; i < ndims; i++) {
dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
}
return grad.permute(dims);
}
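
The backward of permute applies the inverse permutation to the incoming gradient; the loop above builds that inverse by sending output position i back to input dim fwd_dims[i]. A short Python check of the same inversion, for illustration:

import torch

x = torch.randn(2, 3, 4, requires_grad=True)
y = x.permute(2, 0, 1)
g = torch.arange(24.).reshape(4, 2, 3)         # upstream gradient
y.backward(g)

fwd = (2, 0, 1)
inv = [0] * len(fwd)
for i, d in enumerate(fwd):
    inv[d] = i                                 # inv == [1, 2, 0]
assert torch.equal(x.grad, g.permute(*inv))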

Tensor repeat(const Tensor& self, IntArrayRef repeats) {
AT_CHECK(repeats.size() >= (size_t)self.dim(),
"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
@@ -451,6 +461,12 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
return self.as_strided(sizes, strides, storage_offset);
}

Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
auto grad_input = at::zeros(input_sizes, grad.options());
grad_input.select(dim, index).copy_(grad);
return grad_input;
}
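
select_backward materializes a zero tensor of the original input size and copies the gradient into the selected slice, leaving zeros everywhere else. The equivalent in plain torch, as an illustration:

import torch

x = torch.randn(3, 4, requires_grad=True)
g = torch.arange(4.)
x.select(0, 1).backward(g)                     # gradient flows only into row 1

manual = torch.zeros(3, 4)
manual.select(0, 1).copy_(g)
assert torch.equal(x.grad, manual)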

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
int64_t ndim = self.dim();
if (ndim == 0) {
@@ -484,6 +500,12 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_
return self.as_strided(sizes, strides, storage_offset);
}

Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
auto grad_input = at::zeros(input_sizes, grad.options());
grad_input.slice(dim, start, end, step).copy_(grad);
return grad_input;
}
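
slice_backward follows the same pattern: zeros of the input size, then copy the gradient into the sliced view. Plain-torch equivalent, illustrative only:

import torch

x = torch.randn(6, requires_grad=True)
g = torch.tensor([10., 20.])
x[1:5:2].backward(g)                           # slice(dim=0, start=1, end=5, step=2)

manual = torch.zeros(6)
manual[1:5:2] = g
assert torch.equal(x.grad, manual)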

std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
AT_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
AT_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
@@ -690,6 +712,28 @@ Tensor squeeze(const Tensor& self) {
return self.as_strided(std::get<0>(g), std::get<1>(g));
}

Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
auto result = self;

int64_t nDims = sizes.size();
for (int64_t dim = 0; dim < nDims; dim++) {
if (sizes[dim] == 1) {
result = result.unsqueeze(dim);
}
}
return result;
}

Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
dim = at::maybe_wrap_dim(dim, sizes.size());
// In NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
// unsqueezing it in the backward.
if (sizes.size() > 0 && sizes[dim] == 1) {
return self.unsqueeze(dim);
}
return self;
}
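
unsqueeze_to is effectively the inverse of squeeze for the backward pass: the first overload re-inserts every size-1 dim of the original shape, the second re-inserts only the squeezed dim (and only if that dim really had size 1, so squeezing a scalar stays a no-op). A Python sketch of the first overload checked against autograd, for illustration:

import torch

x = torch.randn(2, 1, 3, 1, requires_grad=True)
g = torch.arange(6.).reshape(2, 3)
x.squeeze().backward(g)                        # squeeze removes both size-1 dims

manual = g
for d, s in enumerate(x.shape):
    if s == 1:
        manual = manual.unsqueeze(d)           # restore each size-1 dim in place
assert torch.equal(x.grad, manual)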

Tensor squeeze(const Tensor& self, int64_t dim) {
int64_t dims = self.dim();
dim = maybe_wrap_dim(dim, dims);
22 changes: 22 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -59,6 +59,10 @@
dispatch:
CUDA: _cudnn_init_dropout_state

- func: index_select_backward(Tensor grad, int64_t dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor

- func: select_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t index) -> Tensor

- func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function
@@ -1320,6 +1324,8 @@
- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True

- func: logsumexp_backward(Tensor grad, Tensor self, Tensor res, int[1] dim, bool keepdim) -> Tensor

- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
matches_jit_signature: True

@@ -1392,6 +1398,8 @@
- func: mean(Tensor self, int[1] dim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True

- func: sum_backward(Tensor grad, int[] sizes, int[] dims, bool keepdim) -> Tensor

- func: median(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function, method
@@ -1611,6 +1619,8 @@
matches_jit_signature: True
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.

- func: permute_backwards(Tensor grad, int[] fwd_dims) -> Tensor

- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
matches_jit_signature: True

@@ -1885,11 +1895,15 @@
variants: function, method
device_guard: False

- func: _safe_size(int[] sizes, int[] dim) -> int64_t

- func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
matches_jit_signature: True
variants: function, method
device_guard: False

- func: slice_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) -> Tensor

- func: slogdet(Tensor self) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function, method
@@ -1980,6 +1994,10 @@
variants: function, method
device_guard: False

- func: unsqueeze_to(Tensor self, int[] sizes) -> Tensor

- func: unsqueeze_to(Tensor self, int64_t dim, int[] sizes) -> Tensor

- func: squeeze_(Tensor(a!) self) -> Tensor(a!)
matches_jit_signature: True
variants: method
@@ -2276,6 +2294,10 @@
- func: var(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True

- func: var_backward(Tensor grad, Tensor self, bool unbiased) -> Tensor

- func: var_backward(Tensor grad, Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor

- func: view_as(Tensor self, Tensor other) -> Tensor
matches_jit_signature: True
variants: method
3 changes: 3 additions & 0 deletions test/common_methods_invocations.py
@@ -209,6 +209,7 @@ def method_tests():
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'),
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
('expand', (), (1, 3, 2), 'scalar_to_dims'),
('expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
('exp', (S, S, S), NO_ARGS),
('exp', (), NO_ARGS, 'scalar'),
('expm1', (S, S, S), NO_ARGS),
@@ -1020,6 +1021,8 @@ def unpack_variables(args):
'test_det_dim2_null',
'test_det_rank1',
'test_det_rank2',
# `other` in expand_as(self, other) is not used in autograd.
'test_expand_as',
'test_logdet',
'test_logdet_1x1',
'test_logdet_symmetric',
40 changes: 37 additions & 3 deletions test/cpp/jit/test_misc.h
@@ -848,7 +848,24 @@ void testDifferentiate(std::ostream& out = std::cout) {

auto grad_spec = differentiate(graph);
std::vector<size_t> expected_captured_inputs = {0, 1};
std::vector<size_t> expected_captured_outputs = {1, 2};
// With add/mul implemented in TorchScript, we pass the sizes of
// self & other instead of passing the tensors themselves.
// The forward graph is now
// graph(%0 : Float(2, 3, 4)
// %1 : Float(2, 3, 4)) {
// %2 : Float(2, 3, 4) = aten::mul(%0, %1)
// %self_size.4 : int[] = aten::size(%0)
// %other_size.4 : int[] = aten::size(%1)
// %3 : Float(2, 3, 4) = aten::mul(%2, %0)
// %self_size.2 : int[] = aten::size(%2)
// %4 : int = prim::Constant[value=1]()
// %7 : int[] = aten::size(%3)
// %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
// return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7);
//}
// Thus all the size info added to the forward outputs is saved
// in grad_spec.df_input_captured_outputs.
std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5};
std::vector<size_t> expected_input_vjps = {0, 1};
std::vector<size_t> expected_output_vjps = {0, 1};
ASSERT_EQ(grad_spec.f_real_outputs, 1);
@@ -880,12 +897,29 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
PropagateInputShapes(graph);
PropagateRequiresGrad(graph);

// With add/mul implemented in TorchScript, we pass the sizes of
// self & other instead of passing the tensors themselves.
// The forward graph is now
// graph(%0 : Float(*)
// %1 : Float(*)) {
// %2 : Float(*) = aten::mul(%1, %1)
// %3 : int = prim::Constant[value=1]()
// %4 : Float(*) = aten::add(%2, %1, %3)
// %39 : int[] = aten::size(%0)
// %6 : Float(*) = aten::add(%4, %0, %3)
// %7 : Float(*) = aten::mul(%6, %0)
// %self_size.2 : int[] = aten::size(%6)
// %11 : int[] = aten::size(%7)
// %9 : Float(*) = aten::add(%7, %1, %3)
// return (%4, %9, %39, %6, %self_size.2, %11);
// }

auto grad_spec = differentiate(graph);
std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %6 = (d + a)
std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
ASSERT_EQ(grad_spec.f_real_outputs, 2);
ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
out << "testDifferentiateWithRequiresGrad\n";
32 changes: 16 additions & 16 deletions test/expect/TestFuser.test_lstm_cuda-backward.expect
@@ -22,20 +22,20 @@ graph(%0 : Float(*, *),
%forgetgate : Float(*, *),
%cellgate : Float(*, *),
%outgate : Float(*, *),
%24 : int[],
%25 : int[],
%26 : Float(*, *)):
%27 : int = prim::Constant[value=1]()
%28 : int[] = aten::size(%outgate)
%29 : int[] = aten::size(%26)
%30 : int[] = aten::size(%ingate)
%31 : int[] = aten::size(%cellgate)
%32 : int[] = aten::size(%forgetgate)
%33 : int[] = aten::size(%9)
%34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
%grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
%self_size.5 : int[],
%other_size.5 : int[],
%self_size.3 : int[],
%other_size.3 : int[],
%28 : int[],
%29 : int[],
%30 : Float(*, *),
%self_size.1 : int[],
%other_size.1 : int[]):
%33 : int = prim::Constant[value=1]()
%34 : Tensor = prim::FusionGroup_0(%outgate, %0, %30, %self_size.1)
%grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %30, %0, %outgate, %other_size.5, %self_size.5, %28, %other_size.3, %self_size.3, %29, %other_size.1)
%39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
%40 : Tensor = aten::cat(%39, %27)
%40 : Tensor = aten::cat(%39, %33)
%41 : Tensor = aten::_grad_sum_to_size(%40, %19)
%42 : Tensor = aten::_grad_sum_to_size(%40, %17)
%43 : Tensor = aten::_grad_sum_to_size(%40, %14)
@@ -44,13 +44,13 @@
%46 : Float(*, *) = aten::mm(%44, %45)
%47 : Float(*, *) = aten::t(%10)
%48 : Float(*, *) = aten::mm(%47, %44)
%49 : Float(*, *) = aten::t(%48)
%grad_self.7 : Float(*, *) = aten::t(%48)
%50 : Float(*, *) = aten::t(%12)
%51 : Float(*, *) = aten::mm(%43, %50)
%52 : Float(*, *) = aten::t(%11)
%53 : Float(*, *) = aten::mm(%52, %43)
%54 : Float(*, *) = aten::t(%53)
return (%grad_other.5, %41, %42, %46, %49, %51, %54)
%grad_self.9 : Float(*, *) = aten::t(%53)
return (%grad_other.5, %41, %42, %46, %grad_self.7, %51, %grad_self.9)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Float(*, *),
%2 : Float(*, *),
18 changes: 10 additions & 8 deletions test/expect/TestFuser.test_lstm_cuda-forward.expect
@@ -28,14 +28,16 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
%17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
%21 : int[] = prim::BroadcastSizes(%11, %12)
%22 : int[] = prim::BroadcastSizes(%21, %13)
%hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
%30 : int[] = aten::size(%0)
%31 : int[] = aten::size(%cellgate.1)
%32 : int[] = aten::size(%forgetgate.1)
%33 : int[] = aten::size(%ingate.1)
%34 : int[] = prim::BroadcastSizes(%32, %30)
%35 : int[] = prim::BroadcastSizes(%33, %31)
return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24)
%other_size.6 : int[] = aten::size(%0)
%hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
%31 : int[] = aten::size(%25)
%32 : int[] = aten::size(%outgate.1)
%33 : int[] = aten::size(%cellgate.1)
%34 : int[] = aten::size(%forgetgate.1)
%35 : int[] = aten::size(%ingate.1)
%36 : int[] = prim::BroadcastSizes(%34, %other_size.6)
%37 : int[] = prim::BroadcastSizes(%35, %33)
return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %other_size.6, %35, %33, %36, %37, %25, %32, %31)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
%1 : Tensor,
%2 : Tensor,