Commit

Update on "[ONNX] Remove unnecessary deepcopy on args in 'DynamoExport'"
The comment is outdated. There should be no side-effects on `args` and `kwargs`.

[ghstack-poisoned]
BowenBao committed Jul 7, 2023
2 parents 964d6ee + 5636db6 commit 9194480
Showing 74 changed files with 1,926 additions and 1,363 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/generate_binary_build_matrix.py
@@ -241,7 +241,7 @@ def generate_wheels_matrix(
                 "pytorch_extra_install_requirements": "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "  # noqa: B950
                 "nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
                 "nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
-                "nvidia-cudnn-cu12==8.8.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
+                "nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | "
                 "nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
                 "nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | "
                 "nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -732,7 +732,7 @@ if(DEBUG_CUDA)
     string(APPEND CMAKE_CUDA_FLAGS_DEBUG " --source-in-ptx")
     string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " --source-in-ptx")
   endif()
-endif(NOT MSVC)
+endif(DEBUG_CUDA)


 if(USE_FBGEMM)
5 changes: 5 additions & 0 deletions aten/src/ATen/native/Resize.cpp
@@ -239,6 +239,7 @@ const Tensor& _resize_(
     ArrayRef<T> size,
     c10::optional<MemoryFormat> optional_memory_format) {
   auto* self_ = self.unsafeGetTensorImpl();
+  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
   // NOLINTNEXTLINE(bugprone-argument-comment)
   _resize_impl_<T>(self_, size, /*strides=*/c10::nullopt, true);
   if (optional_memory_format.has_value()) {
@@ -250,6 +251,10 @@ const Tensor& _resize_(
         memory_format);
     self_->empty_tensor_restride(memory_format);
   }
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
+  }
   return self;
 }

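The hunk above records the storage size before the resize so that any newly allocated region can be filled afterwards. A minimal Python sketch of the intended user-visible behavior, assuming deterministic mode is enabled via torch.use_deterministic_algorithms (not part of the diff):

# Newly exposed elements of a grown tensor should come back NaN for
# floating-point dtypes instead of arbitrary uninitialized memory.
import torch

torch.use_deterministic_algorithms(True)
t = torch.zeros(2)
t.resize_(4)   # storage grows; elements 2 and 3 are newly allocated
print(t[2:])   # expected: tensor([nan, nan])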
28 changes: 28 additions & 0 deletions aten/src/ATen/native/ResizeCommon.h
@@ -1,9 +1,16 @@
 #pragma once

 #include <ATen/core/Tensor.h>
+#include <ATen/native/TensorFactories.h>
 #include <ATen/NamedTensorUtils.h>
 #include <c10/util/irange.h>

+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/empty.h>
+#endif
+
 namespace at::native {

 template <typename T>
@@ -44,4 +51,25 @@ inline const Tensor& resize_named_tensor_(
       optional_memory_format.value());
   return self;
 }
+
+// For deterministic output, fill new elements that were added after a storage
+// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
+// before the resize happened.
+inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
+  const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
+  int64_t new_storage_nbytes = storage.nbytes();
+  int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
+  int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
+  if (new_storage_numel > old_storage_numel) {
+    at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
+    tensor_view.set_(
+        storage,
+        /*storage_offset=*/old_storage_numel,
+        /*size=*/{new_storage_numel - old_storage_numel},
+        /*stride=*/{1});
+    at::native::fill_empty_deterministic_(tensor_view);
+  }
+  return tensor;
+}

 } // namespace at::native
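fill_resize_deterministic_ views only the region between the old and new storage extents and fills that, so existing contents survive a growing resize. A sketch of that invariant, under the same deterministic-mode assumption as above:

# Old bytes are untouched; only the newly added tail is filled.
import torch

torch.use_deterministic_algorithms(True)
t = torch.arange(3.0)   # storage holds [0., 1., 2.]
t.resize_(5)
print(t[:3])   # expected: tensor([0., 1., 2.])
print(t[3:])   # expected: tensor([nan, nan])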
19 changes: 19 additions & 0 deletions aten/src/ATen/native/TensorFactories.h
@@ -3,6 +3,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/TensorIterator.h>
+#include <ATen/Dispatch.h>
 #include <ATen/native/DispatchStub.h>

 #ifndef AT_PER_OPERATOR_HEADERS
@@ -96,6 +97,24 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
   }
 }

+// Called by `empty*` functions when deterministic algorithms are enabled to
+// fill the tensor with NaN if it is floating point or complex type, or fill
+// with max value if it is integer type
+inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
+  if (tensor.is_floating_point() || tensor.is_complex()) {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+        kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
+          tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES_AND(
+        kBool, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
+          tensor.fill_(std::numeric_limits<scalar_t>::max());
+        });
+  }
+  return tensor;
+}
+
 // The ZeroTensor allocator ignores whatever allocation is requested and always
 // gives you nullptr
 struct ZeroTensorAllocator final : public at::Allocator {
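A sketch of the per-dtype fill policy fill_empty_deterministic_ implements, assuming the empty* call sites wired up elsewhere in this stack invoke it when deterministic algorithms are on:

# NaN for floating point and complex, max value for integral, True for bool.
import torch

torch.use_deterministic_algorithms(True)
print(torch.empty(2, dtype=torch.float32))  # expected: tensor([nan, nan])
print(torch.empty(2, dtype=torch.int32))    # expected: both elements 2147483647
print(torch.empty(2, dtype=torch.bool))     # expected: tensor([True, True])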
5 changes: 5 additions & 0 deletions aten/src/ATen/native/cuda/Resize.cpp
@@ -53,6 +53,7 @@ const Tensor& resize_cuda_(
     return resize_named_tensor_(self, size, optional_memory_format);
   }
   auto* self_ = self.unsafeGetTensorImpl();
+  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
   resize_impl_cuda_(self_, size, /*strides=*/c10::nullopt);
   if (optional_memory_format.has_value()) {
     auto memory_format =
@@ -63,6 +64,10 @@ const Tensor& resize_cuda_(
         memory_format);
     self_->empty_tensor_restride(memory_format);
   }
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
+  }
   return self;
 }
 } // namespace at::native
6 changes: 6 additions & 0 deletions aten/src/ATen/native/mps/TensorFactory.cpp
@@ -7,6 +7,7 @@
 #include <ATen/mps/EmptyTensor.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/ResizeCommon.h>
 #include <ATen/native/mps/Copy.h>
 #include <ATen/native/mps/TensorFactory.h>
 namespace at::native {
@@ -99,6 +100,7 @@ const Tensor& resize_mps_(
     return resize_named_tensor_(self, size, optional_memory_format);
   }
   auto* self_ = self.unsafeGetTensorImpl();
+  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
   resize_impl_mps_(self_, size, /*strides=*/c10::nullopt);
   if (optional_memory_format.has_value()) {
     auto memory_format =
@@ -109,6 +111,10 @@ const Tensor& resize_mps_(
         memory_format);
     self_->empty_tensor_restride(memory_format);
   }
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
+  }
   return self;
 }

3 changes: 3 additions & 0 deletions aten/src/ATen/native/quantized/cpu/TensorOperators.cpp
@@ -82,6 +82,9 @@ const Tensor& quantized_resize_cpu_(
     const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
+  // See Note [Writing Nondeterministic Operations]
+  // Nondeterministic because if storage is resized, new elements are uninitialized
+  globalContext().alertNotDeterministic("quantized_resize_cpu_");
   TORCH_CHECK(
       !optional_memory_format.has_value(),
       "Unsupported memory format for quantized tensor resize ",
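Since quantized resize leaves new elements uninitialized rather than filling them, this hunk alerts instead of filling. A sketch of the expected effect, assuming alertNotDeterministic raises under torch.use_deterministic_algorithms(True):

# Resizing a quantized CPU tensor should now error in deterministic mode.
import torch

torch.use_deterministic_algorithms(True)
q = torch.quantize_per_tensor(torch.zeros(2), scale=1.0, zero_point=0, dtype=torch.qint8)
try:
    q.resize_(4)
except RuntimeError as e:
    print(e)   # expected to mention quantized_resize_cpu_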
3 changes: 3 additions & 0 deletions aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h
@@ -41,6 +41,7 @@

 #include <ATen/native/transformers/cuda/flash_attn/fmha_utils.h>

+namespace pytorch_fmha {

 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
@@ -209,3 +210,5 @@ void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const
 void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

 void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
+
+}; // namespace pytorch_fmha
@@ -4,9 +4,13 @@

 #include <ATen/native/transformers/cuda/flash_attn/fmha_bwd_launch_template.h>

+namespace pytorch_fmha {
+
 void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
   FP16_SWITCH(params.is_bf16, ([&] {
     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
     run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
   }));
 }
+
+}; // namespace pytorch_fmha
