Update on "fix inference_mode with torch.compile"
It looks like inference_mode wasn't playing well with functionalization.

If you run torch.compile on a function whose inputs are tensors created outside of inference mode, then the functional tensor wrappers we create for those inputs during compilation need to properly mirror whether or not each original tensor is an inference tensor.

Hopefully fixes #101151
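
To make the scenario concrete, here is a rough sketch of the kind of program this is about. It is an assumed shape of the problem rather than the exact repro from #101151, and the backend choice is purely illustrative: the input is a regular tensor created outside of inference mode, and the compiled function is then run under torch.inference_mode().

import torch

def f(x):
    return x.add(1)

x = torch.ones(4)  # regular (non-inference) tensor, created outside inference_mode

with torch.inference_mode():
    # Compile and run under inference mode; the functional wrapper created for `x`
    # during tracing should mirror the fact that `x` is *not* an inference tensor.
    compiled_f = torch.compile(f, backend="aot_eager")  # backend only for illustration
    print(compiled_f(x))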




[ghstack-poisoned]
bdhirsh committed May 23, 2023
2 parents 0ff62dc + e178403 commit 30ddba1
Showing 168 changed files with 3,784 additions and 837 deletions.
24 changes: 20 additions & 4 deletions .github/actions/teardown-win/action.yml
@@ -26,11 +26,27 @@ runs:
- name: Clean up leftover processes on non-ephemeral Windows runner
uses: pytorch/test-infra/.github/actions/cleanup-runner@main

# Cleaning up Windows workspace sometimes fails flakily with device or resource busy
# error, meaning one or more processes haven't stopped completely yet. So trying to
# retry this step several times, similar to how checkout-pytorch GHA does
- name: Cleanup workspace
if: always()
shell: bash
uses: nick-fields/retry@v2.8.2
env:
EXTRA_DELETE_DIR: ${{ inputs.extra-delete-dir }}
run: |
[ ! -z "${EXTRA_DELETE_DIR}" ] || rm -rf "${EXTRA_DELETE_DIR}"
rm -rf ./*
with:
shell: bash
timeout_minutes: 5
max_attempts: 3
retry_wait_seconds: 90
command: |
set +e
set -x
if [ -n "${EXTRA_DELETE_DIR}" ]; then
# It's ok to fail to clean up the extra directory on Windows as it only contains
# the build artifacts and doesn't take up much space, i.e. /c/5053411580/build-results
rm -rf "${EXTRA_DELETE_DIR}" || true
fi
rm -rf ./*
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
@@ -1 +1 @@
300a90926e88f13abbaf3d8155cdba36aab86ab4
6ccc712b02b4014a087878969b610e486ebc6adf
4 changes: 3 additions & 1 deletion .gitignore
@@ -227,7 +227,9 @@ docs/source/scripts/quantization_backend_configs/
## Caffe2

# build, distribute, and bins (+ python proto bindings)
build
build/
# Allow tools/build/ for build support.
!tools/build/
build_host_protoc
build_android
build_ios
4 changes: 3 additions & 1 deletion BUILD.bazel
@@ -1,6 +1,7 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("@pytorch//third_party:substitution.bzl", "header_template_rule", "template_rule")
load("@pytorch//:tools/bazel.bzl", "rules")
load("@pytorch//tools/rules:cu.bzl", "cu_library")
@@ -1752,7 +1753,7 @@ template_rule(
}),
)

rules.py_library(
py_library(
name = "pytorch_py",
visibility = ["//visibility:public"],
srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
@@ -1763,6 +1764,7 @@ rules.py_library(
rules.requirement("requests"),
rules.requirement("setuptools"),
rules.requirement("six"),
rules.requirement("sympy"),
rules.requirement("typing_extensions"),
"//torchgen",
],
41 changes: 32 additions & 9 deletions WORKSPACE
@@ -38,8 +38,8 @@

http_archive(
name = "pybind11_bazel",
strip_prefix = "pybind11_bazel-992381ced716ae12122360b0fbadbc3dda436dbf",
urls = ["https://github.com/pybind/pybind11_bazel/archive/992381ced716ae12122360b0fbadbc3dda436dbf.zip"],
strip_prefix = "pybind11_bazel-b162c7c88a253e3f6b673df0c621aca27596ce6b",
urls = ["https://github.com/pybind/pybind11_bazel/archive/b162c7c88a253e3f6b673df0c621aca27596ce6b.zip"],
)

new_local_repository(
@@ -192,25 +192,48 @@

http_archive(
name = "rules_python",
sha256 = "aa96a691d3a8177f3215b14b0edc9641787abaaa30363a080165d06ab65e1161",
url = "https://github.com/bazelbuild/rules_python/releases/download/0.0.1/rules_python-0.0.1.tar.gz",
# TODO Fix bazel linter to support hashes for release tarballs.
#
# sha256 = "94750828b18044533e98a129003b6a68001204038dc4749f40b195b24c38f49f",
strip_prefix = "rules_python-0.21.0",
url = "https://github.com/bazelbuild/rules_python/releases/download/0.21.0/rules_python-0.21.0.tar.gz",
)

load("@rules_python//python:repositories.bzl", "py_repositories")

py_repositories()

load("@rules_python//python:repositories.bzl", "python_register_toolchains")

python_register_toolchains(
name = "python3_8",
python_version = "3.8",
)

load("@python3_8//:defs.bzl", "interpreter")
load("@rules_python//python:pip.bzl", "pip_parse")

pip_parse(
name = "pip_deps",
python_interpreter_target = interpreter,
requirements_lock = "//:tools/build/bazel/requirements.txt",
)

load("@pip_deps//:requirements.bzl", "install_deps")

install_deps()

load("@pybind11_bazel//:python_configure.bzl", "python_configure")

python_configure(
name = "local_config_python",
python_version = "3",
python_interpreter_target = interpreter,
)

load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")

protobuf_deps()

load("@rules_python//python:repositories.bzl", "py_repositories")

py_repositories()

new_local_repository(
name = "cuda",
build_file = "@//third_party:cuda.BUILD",
1 change: 1 addition & 0 deletions aten/src/ATen/FunctionalStorageImpl.cpp
@@ -83,6 +83,7 @@ static c10::SymInt get_nbytes(const Tensor& value) {
if (value.key_set().has(c10::DispatchKey::Python)) {
return value.storage().sym_nbytes();
}
return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), value.dtype().itemsize(), value.sym_storage_offset());
}
// XLA storage objects also do not properly track nbytes.
return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
10 changes: 0 additions & 10 deletions aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -140,16 +140,6 @@ namespace {
}
}

Tensor& adaptive_avg_pool2d_backward_out_cpu(
Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input)
{
adaptive_avg_pool2d_backward_out_cpu_template(
grad_input, grad_output, input);
return grad_input;
}

Tensor adaptive_avg_pool2d_backward_cpu(
const Tensor& grad_output,
const Tensor& input)
18 changes: 9 additions & 9 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -1482,23 +1482,23 @@ Tensor& not_equal_(Tensor& self, const Scalar& other) { return self.ne_(other);
Tensor& logical_and_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_and_stub); }
Tensor logical_and(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out)); }
Tensor& logical_and_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out)); }
Tensor& logical_and_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_and_out)); }
Tensor logical_and(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out)); }
Tensor& logical_and_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out)); }
static Tensor& logical_and_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_and_out)); }
static Tensor logical_and(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out)); }
static Tensor& logical_and_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out)); }

Tensor& logical_or_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_or_stub); }
Tensor logical_or(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out)); }
Tensor& logical_or_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out)); }
Tensor& logical_or_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_or_out)); }
Tensor logical_or(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out)); }
Tensor& logical_or_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out)); }
static Tensor& logical_or_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_or_out)); }
static Tensor logical_or(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out)); }
static Tensor& logical_or_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out)); }

Tensor& logical_xor_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_xor_stub); }
Tensor logical_xor(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
Tensor& logical_xor_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_xor_out)); }
Tensor logical_xor(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
Tensor& logical_xor_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
static Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_xor_out)); }
static Tensor logical_xor(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out)); }
static Tensor& logical_xor_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); }

// binary max, alias for maximum
Tensor& max_out(const Tensor& self, const Tensor& other, Tensor& result) {
4 changes: 4 additions & 0 deletions aten/src/ATen/native/ComparisonUtils.cpp
@@ -4,6 +4,10 @@
#include <ATen/core/TensorBody.h>
#include <c10/util/OptionalArrayRef.h>

#ifdef AT_PER_OPERATOR_HEADERS
#include <ATen/ops/_assert_tensor_metadata_native.h>
#endif

namespace at {

class Tensor;
1 change: 1 addition & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -22,6 +22,7 @@
#else
#include <ATen/ops/_copy_from.h>
#include <ATen/ops/_propagate_xla_data.h>
#include <ATen/ops/_propagate_xla_data_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/expand_copy.h>
30 changes: 9 additions & 21 deletions aten/src/ATen/native/Histogram.cpp
@@ -16,8 +16,6 @@
#include <ATen/ops/_histogramdd_from_bin_tensors.h>
#include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/amin.h>
#include <ATen/ops/amax.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/histc_native.h>
#include <ATen/ops/histogram_native.h>
@@ -196,9 +194,8 @@ select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>>
// non-empty input
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
if (input.is_mps()) {
// aminmax has not been implemented on mps.
Tensor min = at::amin(input, 0);
Tensor max = at::amax(input, 0);
Tensor min, max;
std::tie(min, max) = at::aminmax(input, 0);

for (const auto i : c10::irange(N)) {
leftmost_edges[i] = min[i].item().to<scalar_t>();
@@ -239,18 +236,9 @@ std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input,
double rightmost_edge = max.to<double>();

if (leftmost_edge == rightmost_edge && input.numel() > 0) {
if (input.is_mps()) {
// aminmax has not been implemented on mps.
Tensor min = at::amin(input);
Tensor max = at::amax(input);

leftmost_edge = min.item<double>();
rightmost_edge = max.item<double>();
} else {
auto extrema = aminmax(input);
leftmost_edge = std::get<0>(extrema).item<double>();
rightmost_edge = std::get<1>(extrema).item<double>();
}
auto extrema = aminmax(input);
leftmost_edge = std::get<0>(extrema).item<double>();
rightmost_edge = std::get<1>(extrema).item<double>();
}

if (leftmost_edge == rightmost_edge) {
@@ -269,7 +257,7 @@ std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input,

} // namespace

std::vector<Tensor> allocate_bin_edges_tensors(const Tensor& self) {
static std::vector<Tensor> allocate_bin_edges_tensors(const Tensor& self) {
TORCH_CHECK(self.dim() >= 2, "torch.histogramdd: input tensor should have at least 2 dimensions");
const int64_t N = self.size(-1);
std::vector<Tensor> bin_edges_out(N);
@@ -281,7 +269,7 @@ std::vector<Tensor> allocate_bin_edges_tensors(const Tensor& self) {

/* Versions of histogramdd in which bins is a Tensor[] defining the sequences of bin edges.
*/
Tensor& histogramdd_out(const Tensor& self, TensorList bins,
static Tensor& histogramdd_out(const Tensor& self, TensorList bins,
const c10::optional<Tensor>& weight, bool density,
Tensor& hist, TensorList& bin_edges) {
histogramdd_check_inputs(self, bins, weight);
@@ -308,7 +296,7 @@ Tensor _histogramdd(const Tensor& self, TensorList bins,
/* Versions of histogramdd in which bins is an int[]
* defining the number of bins in each dimension.
*/
std::vector<Tensor>& histogramdd_bin_edges_out(const Tensor& self, IntArrayRef bin_ct,
static std::vector<Tensor>& histogramdd_bin_edges_out(const Tensor& self, IntArrayRef bin_ct,
c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density,
std::vector<Tensor>& bin_edges_out) {
@@ -340,7 +328,7 @@ std::vector<Tensor> histogramdd_bin_edges(const Tensor& self, IntArrayRef bin_ct
return histogramdd_bin_edges_out(self, bin_ct, range, weight, density, bin_edges_out);
}

Tensor& histogramdd_out(const Tensor& self, IntArrayRef bin_ct,
static Tensor& histogramdd_out(const Tensor& self, IntArrayRef bin_ct,
c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density,
Tensor& hist, TensorList& bin_edges) {
16 changes: 8 additions & 8 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -326,7 +326,7 @@ DEFINE_DISPATCH(addr_stub);

// As P is a permutation matrix
// det(P) = 1 if it's an even permutation and det(P) = -1 if it's an odd permutation
Tensor lu_det_P(const Tensor& pivots) {
static Tensor lu_det_P(const Tensor& pivots) {
return (at::arange(1, pivots.size(-1) + 1, pivots.options()) != pivots)
.sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong)
.fmod_(2)
@@ -1594,7 +1594,7 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
});
}

void baddbmm_with_gemm_(const Tensor &result, const Tensor &mat1, const Tensor &mat2, const Scalar &beta_, const Scalar &alpha_) {
static void baddbmm_with_gemm_(const Tensor &result, const Tensor &mat1, const Tensor &mat2, const Scalar &beta_, const Scalar &alpha_) {
TORCH_INTERNAL_ASSERT(result.is_contiguous());

const auto result_sizes = result.sizes();
@@ -1766,7 +1766,7 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
return;
}

void conjugate_mutable_input_if_needed(const Tensor& self, bool conjugate) {
static void conjugate_mutable_input_if_needed(const Tensor& self, bool conjugate) {
if (conjugate) {
self.conj_physical_();
}
@@ -1823,7 +1823,7 @@ Tensor& vdot_out(const Tensor& self, const Tensor& other, Tensor& result) {
return result.fill_(self.vdot(other));
}

bool should_fold(const Tensor& tensor1, const Tensor& tensor2) {
static bool should_fold(const Tensor& tensor1, const Tensor& tensor2) {
// We check that we can fold the larger tensor into a matrix and dispatch to mm or mv rather than
// to bmm. We want to make sure we can do so without incurring in any extra copy
const auto tensor1_larger = tensor1.dim() >= tensor2.dim();
@@ -2678,7 +2678,7 @@ TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar
norm_stub(iter.device_type(), iter, ord);
}

void _linalg_matrix_norm_checks(const Tensor& A, std::vector<int64_t>& dim, optional<ScalarType> opt_dtype, bool low_precision) {
static void _linalg_matrix_norm_checks(const Tensor& A, std::vector<int64_t>& dim, optional<ScalarType> opt_dtype, bool low_precision) {
// A
at::native::checkIsMatrix(A, "linalg.matrix_norm");
at::native::checkFloatingOrComplex(A, "linalg.matrix_norm", /*low_precision*/low_precision);
@@ -2950,7 +2950,7 @@ Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tens


// This function helps to dispatch norm computations depending on 'ord' of variant type
Tensor _linalg_cond_helper(const Tensor& self, c10::variant<Scalar, c10::string_view> ord_variant) {
static Tensor _linalg_cond_helper(const Tensor& self, c10::variant<Scalar, c10::string_view> ord_variant) {
Tensor inverse, info;
std::tie(inverse, info) = at::linalg_inv_ex(self);
info.unsqueeze_(-1).unsqueeze_(-1);
@@ -2967,13 +2967,13 @@ Tensor _linalg_cond_helper(const Tensor& self, c10::variant<Scalar, c10::string_
}

// Return zero for each matrix in the batch
Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
static Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
TensorOptions options = self.options().dtype(toRealValueType(self.scalar_type()));
return at::zeros(result_shape, options);
}

void _linalg_cond_check_ord(c10::variant<Scalar, c10::string_view> ord_variant) {
static void _linalg_cond_check_ord(c10::variant<Scalar, c10::string_view> ord_variant) {
if (ord_variant.index() == 0) {
Scalar* ord = c10::get_if<Scalar>(&ord_variant);
double abs_ord = std::abs(ord->toDouble());
2 changes: 1 addition & 1 deletion aten/src/ATen/native/LossNLL.cpp
@@ -668,7 +668,7 @@ Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const c10::op
}

// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages.
Tensor nll_loss(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
static Tensor nll_loss(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
2 changes: 1 addition & 1 deletion aten/src/ATen/native/LossNLL2d.cpp
@@ -495,7 +495,7 @@ Tensor nll_loss2d_symint(const Tensor & self, const Tensor & target, const c10::
}

// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages.
Tensor nll_loss2d(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
static Tensor nll_loss2d(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
2 changes: 1 addition & 1 deletion aten/src/ATen/native/MetaTensor.cpp
@@ -30,7 +30,7 @@ Tensor empty_meta_symint(
}

// Kept only for BC with XLA
Tensor empty_strided_meta(
static Tensor empty_strided_meta(
IntArrayRef size,
IntArrayRef stride,
c10::optional<ScalarType> dtype_opt,
2 changes: 1 addition & 1 deletion aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp
@@ -799,7 +799,7 @@ TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cpu)
dilation);
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose2d_backward_out_cpu(const Tensor& grad_output,
static std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose2d_backward_out_cpu(const Tensor& grad_output,
const Tensor& input,
const Tensor& weight,
IntArrayRef kernel_size,
