Update on "[export] ExportedProgram"
cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx

[ghstack-poisoned]
angelayi committed May 26, 2023
2 parents 39ba44c + 3110058 commit d57b238
Showing 119 changed files with 3,347 additions and 1,611 deletions.
4 changes: 2 additions & 2 deletions .ci/docker/common/install_onnx.sh
@@ -16,15 +16,15 @@ pip_install \
onnx==1.14.0

pip_install \
onnxruntime==1.14.0 \
onnxruntime==1.15.0 \
parameterized==0.8.1 \
pytest-cov==4.0.0 \
pytest-subtests==0.10.0 \
tabulate==0.9.0 \
transformers==4.25.1

# TODO: change this when onnx-script is on testPypi
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@bf502680231e4b134a71f74e812c84ddd7efffbe"
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@68adea42fb9b7353148e7ab289b76f9b89890e1c"

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
@@ -1 +1 @@
4125d3a02b15faf4b19767a91797320151ce8bc6
4a51822ca20027b6e03ec4fb582c31cc9545ba4e
1 change: 1 addition & 0 deletions .gitignore
@@ -42,6 +42,7 @@ docs/cpp/build
docs/cpp/source/api
docs/cpp/source/html/
docs/cpp/source/latex/
docs/source/compile/generated/
docs/source/generated/
log
usage_log.txt
14 changes: 12 additions & 2 deletions CONTRIBUTING.md
@@ -83,12 +83,22 @@ Follow the instructions for [installing PyTorch from source](https://github.com/

* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.

* When installing with `python setup.py develop` (in contrast to `python setup.py install`) you will symlink
the Python files from the current local source-tree into the Python install.
* When installing with `python setup.py develop` (in contrast to `python setup.py install`), the Python runtime will use
the current local source tree when importing the `torch` package. (This is done by creating an [`.egg-link`](https://wiki.python.org/moin/PythonPackagingTerminology#egg-link) file in the `site-packages` folder.)
This way you do not need to reinstall after every change to Python files (`.py`).
However, you still need to reinstall if you modify the Python interface files (`.pyi`, `.pyi.in`) or
non-Python files (`.cpp`, `.cc`, `.cu`, `.h`, ...).


One way to avoid running `python setup.py develop` every time you change C++/CUDA/Objective-C files on Linux/Mac
is to symlink the built libraries from the `build` folder into `torch/lib`, for example by issuing the following:
```bash
pushd torch/lib; sh -c "ln -sf ../../build/lib/libtorch_cpu.* ."; popd
```
Afterwards, rebuilding a library (for example, to rebuild `libtorch_cpu.so`, issue `ninja torch_cpu` from the `build` folder)
is sufficient to make the change visible in the `torch` package.


To reinstall, first uninstall all existing PyTorch installs. You may need to run `pip
uninstall torch` multiple times. You'll know `torch` is fully
uninstalled when you see `WARNING: Skipping torch as it is not
10 changes: 5 additions & 5 deletions aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -161,13 +161,13 @@ static_assert(
__m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0));
return _mm256_movemask_epi8(cmp);
}
static Vectorized<T> loadu(const void* ptr) {
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
}
static Vectorized<T> loadu(const void* ptr, int16_t count) {
static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
if (count == size())
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));

__at_align__ int16_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
return loadu(tmp_values);
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(tmp_values));
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
10 changes: 5 additions & 5 deletions aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
@@ -168,13 +168,13 @@
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
}
static Vectorized<T> loadu(const void* ptr) {
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
}
static Vectorized<T> loadu(const void* ptr, int16_t count) {
static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
if (count == size())
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));

__at_align__ int16_t tmp_values[size()];
std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
return loadu(tmp_values);
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(tmp_values));
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
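Both `loadu` hunks above (AVX2 and AVX-512) apply the same change: the separate full-width overload is folded into a single `loadu` with `count = size()` as the default, partial loads keep going through an aligned scratch buffer so the vector load never reads past the end of `ptr`, and the recursive `loadu(tmp_values)` call is replaced by the raw intrinsic so the defaulted `count == size()` check is not re-entered. A minimal standalone sketch of that pattern, assuming an AVX2 target and 16-bit lanes (names are illustrative, not the ATen `Vectorized` API):

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstring>

// Load up to 16 int16_t lanes; lanes beyond `count` come from a zeroed,
// aligned scratch buffer instead of out-of-bounds memory. (The ATen version
// leaves the tail lanes uninitialized; zeroing here just makes the sketch
// deterministic.)
__m256i loadu_partial_epi16(const void* ptr, int count = 16) {
  if (count == 16) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  alignas(32) int16_t tmp[16] = {};
  std::memcpy(tmp, ptr, count * sizeof(int16_t));
  return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(tmp));
}
```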
21 changes: 18 additions & 3 deletions aten/src/ATen/native/BinaryOps.h
@@ -4,7 +4,16 @@
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>
#include <c10/util/TypeSafeSignMath.h>
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_copysign c10::cuda::compat::copysign
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_copysign c10::hip::compat::copysign
#else
#include <c10/util/copysign.h>
#define compat_copysign c10::copysign
#endif


namespace at {
@@ -43,6 +52,12 @@ inline void sub_check(const TensorBase& self, const Scalar& scalar) {
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}

#if defined(__CUDACC__) || defined(__HIPCC__)
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

// NOTE: [Floor Division in Python]
// Python's __floordiv__ operator is more complicated than just floor(a / b).
// It aims to maintain the property: a == (a // b) * b + remainder(a, b)
@@ -54,7 +69,7 @@ inline void sub_check(const TensorBase& self, const Scalar& scalar) {
// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636

template <typename scalar_t>
inline scalar_t div_floor_floating(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ {
inline HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ {
if (C10_UNLIKELY(b == 0)) {
// Divide by zero: return standard IEEE result
return a / b;
@@ -73,13 +88,13 @@ inline scalar_t div_floor_floating(scalar_t a, scalar_t b) __ubsan_ignore_float_
floordiv += scalar_t(1.0);
}
} else {
floordiv = c10::copysign(scalar_t(0), a / b);
floordiv = compat_copysign(scalar_t(0), a / b);
}
return floordiv;
}

template <typename scalar_t>
inline scalar_t div_floor_integer(scalar_t a, scalar_t b) {
inline HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
if (c10::signs_differ(a, b)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
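The `HOST_DEVICE` and `compat_copysign` additions above let these floor-division helpers compile for CPU, CUDA, and HIP from a single definition. As a reference for the semantics described in NOTE: [Floor Division in Python], here is a host-only sketch using standalone names rather than the ATen templates; it follows the same steps and rounds the quotient toward negative infinity, matching Python's `//`:

```cpp
#include <cmath>
#include <cstdio>

// Python-style floor division for doubles:
// a == div_floor(a, b) * b + remainder(a, b), quotient rounded toward -inf.
double div_floor(double a, double b) {
  if (b == 0) {
    return a / b;  // divide by zero: keep the IEEE result (inf/nan)
  }
  double mod = std::fmod(a, b);
  double div = (a - mod) / b;
  if (mod != 0 && (b < 0) != (mod < 0)) {
    div -= 1.0;  // the remainder must take the divisor's sign
  }
  if (div == 0) {
    return std::copysign(0.0, a / b);  // keep the signed zero, like CPython
  }
  double floordiv = std::floor(div);
  if (div - floordiv > 0.5) {
    floordiv += 1.0;  // undo floor() when rounding pushed div just below an integer
  }
  return floordiv;
}

int main() {
  std::printf("%g %g\n", div_floor(-7.0, 2.0), div_floor(7.0, -2.0));  // -4 -4
}
```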
36 changes: 4 additions & 32 deletions aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu
@@ -29,17 +29,8 @@ void div_floor_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cuda", [&]() {
gpu_kernel_with_scalars(
iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
if (c10::signs_differ(a, b)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the
// remainder of the division is nonzero
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}

return a / b;
});
return div_floor_integer(a, b);
});
});
} else if (iter.is_cpu_scalar(2)) {
// optimization for floating-point types: if the second operand is a CPU
@@ -79,26 +70,7 @@ void div_floor_kernel_cuda(TensorIteratorBase& iter) {
kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() {
gpu_kernel_with_scalars(
iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
if (C10_UNLIKELY(b == 0)) {
return a / b;
}

auto mod = std::fmod(a, b);
auto div = (a - mod) / b;
if ((mod != 0) && (b < 0) != (mod < 0)) {
div -= scalar_t(1);
}

scalar_t floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > scalar_t(0.5)) {
floordiv += scalar_t(1.0);
}
} else {
floordiv = c10::cuda::compat::copysign(scalar_t(0), a / b);
}
return floordiv;
return div_floor_floating(a, b);
});
});
}
@@ -107,4 +79,4 @@

REGISTER_DISPATCH(div_floor_stub, &binary_internal::div_floor_kernel_cuda);

} // namespace at::native
} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -14132,7 +14132,7 @@
CUDA: _efficient_attention_forward
tags: nondeterministic_seeded

- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor rng_seed, Tensor rng_offset, int custom_mask_type, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor rng_seed, Tensor rng_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
device_check: NoCheck
variants: function
dispatch:
48 changes: 45 additions & 3 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -115,7 +115,8 @@ _efficient_attention_backward(
const at::Tensor& rng_seed_tensor, // seed using for generating random numbers for dropout
const at::Tensor& rng_offset_tensor, // offset into random number sequence
int64_t custom_mask_type,
const c10::optional<double> scale) {
const c10::optional<double> scale,
c10::optional <int64_t> num_splits_key) {
#if defined(USE_FLASH_ATTENTION)
if (!grad_out_.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
@@ -126,7 +127,7 @@ _efficient_attention_backward(
int64_t rng_seed = *rng_seed_tensor.data_ptr<int64_t>();
int64_t rng_offset = *rng_offset_tensor.data_ptr<int64_t>();

// ndim
// ndim
TORCH_CHECK(query.dim() == grad_out_.dim());
TORCH_CHECK(query.dim() == key.dim());
TORCH_CHECK(query.dim() == value.dim());
@@ -361,11 +362,51 @@ _efficient_attention_backward(
p.dropout_prob = dropout_p;
}

// Heuristic for finding optimal number of splits
auto parallelism_without_split_key =
p.getBlocksGrid().x * p.getBlocksGrid().y * p.getBlocksGrid().z;
p.num_splits_key = cutlass::ceil_div(p.num_keys, Kernel::kBlockSizeJ);
if (num_splits_key.has_value()) { // Skip heuristic, if user provided an explicit value
p.num_splits_key = std::max<int64_t>(p.num_splits_key, num_splits_key.value());
// If we already have enough parallelism, split-keys can help
// better use L2 cache.
// This is negligible when the seqlen is too small tho
if (parallelism_without_split_key >= 256 &&
p.num_keys <= 2 * Kernel::kBlockSizeJ) {
p.num_splits_key = 1;
}
// Increasing `split_keys` leads to using more gmem for temporary storage
// when we need a staging area for gK/gV. let's avoid that
if (Kernel::kNeedsAccumGradK || Kernel::kNeedsAccumGradV) {
p.num_splits_key = std::min(
int(p.num_splits_key), 200 / (p.num_batches * p.num_heads));
}
}
if (!Kernel::kEnableSplitKeys || p.num_splits_key < 1) {
p.num_splits_key = 1;
}

auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (ctx.deterministicAlgorithmsWarnOnly()) {
TORCH_WARN_ONCE(
"Memory Efficient attention defaults to a non-deterministic algorithm. ",
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
} else {
TORCH_CHECK(
num_splits_key.value_or(1) <= 1,
"Using `num_splits_key > 1` makes the algorithm non-deterministic, and pytorch's deterministic mode is enabled");
p.num_splits_key = 1;
}
}
int64_t size_bytes = p.workspace_size();
if (size_bytes) {
workspace =
at::empty({size_bytes}, query.options().dtype(at::ScalarType::Byte));
p.workspace = (float*)workspace.data_ptr();
if (p.should_zero_workspace()) {
workspace.zero_();
}
}
Kernel::check_supported(p);

@@ -535,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_att
seed_t,
offset_t,
static_cast<int64_t>(custom_mask_type),
scale);
scale,
c10::nullopt); // num_split_keys
return std::make_tuple(
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2));
}
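The new `num_splits_key` argument above either overrides or feeds a heuristic that trades extra key-split parallelism against temporary global-memory use, and the deterministic-algorithms guard either warns or clamps the split count to one, since split-key accumulation is non-deterministic. A simplified, host-only rendering of that sizing logic (the free function, parameter names, and the default `block_size_j` are illustrative; the 256-block threshold and the 200-entry workspace budget are taken from the hunk):

```cpp
#include <algorithm>
#include <cstdint>

int64_t pick_num_splits_key(
    int64_t num_keys,
    int64_t blocks_without_split,  // gridDim.x * gridDim.y * gridDim.z
    int64_t num_batches,
    int64_t num_heads,
    bool needs_accum_grad_kv,      // gK/gV need a gmem staging area
    bool deterministic,
    int64_t block_size_j = 64) {   // stand-in for Kernel::kBlockSizeJ
  // Start from one split per key block (ceil division).
  int64_t splits = (num_keys + block_size_j - 1) / block_size_j;
  // Plenty of parallelism and a short key sequence: splitting buys little.
  if (blocks_without_split >= 256 && num_keys <= 2 * block_size_j) {
    splits = 1;
  }
  // More splits means more temporary global memory when K/V gradients are
  // accumulated out of place, so cap the total workspace.
  if (needs_accum_grad_kv) {
    splits = std::min<int64_t>(splits, 200 / (num_batches * num_heads));
  }
  // Split-key accumulation order is not deterministic; fall back to one split.
  if (deterministic) {
    splits = 1;
  }
  return std::max<int64_t>(splits, 1);
}
```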
@@ -180,8 +180,17 @@ struct call_conditional<false, TA, TB> {
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
struct {
union {
T value;
uint32_t asInt;
};
} p;
p.value = value;
p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
return p.value;
}

template <typename T>
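The templated `warp_uniform` above extends the lane-0 broadcast from `int32_t` to any 32-bit type by punning the value's bits through an anonymous union before and after `__shfl_sync`. A host-side illustration of just that reinterpretation step (the shuffle itself needs device code; `std::memcpy` is used here as the strictly conforming host equivalent of the union trick, and all names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret a 32-bit value as the raw integer a warp shuffle would carry.
template <typename T>
uint32_t bits_of(T value) {
  static_assert(sizeof(T) == sizeof(uint32_t), "sketch assumes a 32-bit T");
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Rebuild the typed value from the shuffled integer.
template <typename T>
T from_bits(uint32_t bits) {
  static_assert(sizeof(T) == sizeof(uint32_t), "sketch assumes a 32-bit T");
  T value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

int main() {
  float x = 1.5f;
  uint32_t raw = bits_of(x);  // 0x3fc00000
  std::printf("0x%08x -> %g\n", static_cast<unsigned>(raw), from_bits<float>(raw));
}
```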
