Update on "[Inductor] Add triton.autotune support for user defined tr…
Browse files Browse the repository at this point in the history
…iton kernels with constant/simple grids"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen committed Oct 28, 2023
2 parents b3d02e9 + 02301e8 commit 356c5b1
Showing 154 changed files with 4,781 additions and 2,629 deletions.
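For context, here is a minimal sketch of the usage pattern the commit title refers to: a user-defined Triton kernel decorated with `triton.autotune` and launched over a simple grid from inside a `torch.compile`-d function. This is not code from the PR; the kernel, configs, and block sizes are illustrative assumptions.

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Plain elementwise add; BLOCK_SIZE is chosen by the autotuner.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@torch.compile
def add(x, y):
    out = torch.empty_like(x)
    n_elements = out.numel()
    # Simple grid, expressed over the autotuned meta-parameters.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements)
    return out


if torch.cuda.is_available():
    a = torch.randn(4096, device="cuda")
    b = torch.randn(4096, device="cuda")
    torch.testing.assert_close(add(a, b), a + b)
```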
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
@@ -32,7 +32,7 @@ pip_install coloredlogs packaging
retry pip_install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ --no-cache-dir --no-input ort-nightly==1.17.0.dev20231005006

pip_install -i https://test.pypi.org/simple/ onnx==1.15.0rc2
pip_install onnxscript==0.1.0.dev20231006 --no-deps
pip_install onnxscript==0.1.0.dev20231025 --no-deps

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
3 changes: 3 additions & 0 deletions .lintrunner.toml
@@ -180,10 +180,13 @@ code = 'MYPYNOFOLLOW'
include_patterns = [
'torch/_dynamo/allowed_functions.py',
'torch/_dynamo/codegen.py',
'torch/_dynamo/compiled_autograd.py',
'torch/_dynamo/eval_frame.py',
'torch/_dynamo/exc.py',
'torch/_dynamo/funcname_cache.py',
'torch/_dynamo/convert_frame.py',
'torch/_dynamo/symbolic_convert.py',
'torch/_dynamo/testing.py',
'torch/_dynamo/types.py',
'torch/_dynamo/output_graph.py',
'torch/_dynamo/guards.py',
10 changes: 8 additions & 2 deletions RELEASE.md
@@ -16,6 +16,7 @@
- [Release Candidate Storage](#release-candidate-storage)
- [Release Candidate health validation](#release-candidate-health-validation)
- [Cherry Picking Fixes](#cherry-picking-fixes)
- [Cherry Picking Reverts](#cherry-picking-reverts)
- [Promoting RCs to Stable](#promoting-rcs-to-stable)
- [Additional Steps to prepare for release day](#additional-steps-to-prepare-for-release-day)
- [Modify release matrix](#modify-release-matrix)
@@ -132,8 +133,8 @@ them:
* Example: https://github.com/pytorch/pytorch/pull/77983 and https://github.com/pytorch/pytorch/pull/77986
* A release branches should also be created in [`pytorch/xla`](https://github.com/pytorch/xla) and [`pytorch/builder`](https://github.com/pytorch/builder) repos and pinned in `pytorch/pytorch`
* Example: https://github.com/pytorch/pytorch/pull/86290 and https://github.com/pytorch/pytorch/pull/90506
* Update branch used in composite actions from trunk to release (for example, can be done by running `for i in .github/workflows/*.yml; do sed -i -e s#@master#@release/2.0# $i; done`
* Example: https://github.com/pytorch/pytorch/commit/51b42d98d696a9a474bc69f9a4c755058809542f
* Update branch used in composite actions from trunk to release (for example, can be done by running `for i in .github/workflows/*.yml; do sed -i -e s#@main#@release/2.0# $i; done`
* Example: https://github.com/pytorch/pytorch/commit/17f400404f2ca07ea5ac864428e3d08149de2304

These are examples of changes that should be made to the *default* branch after a release branch is cut

@@ -211,6 +212,11 @@ Please also make sure to add milestone target to the PR/issue, especially if it

**NOTE**: The cherry pick process is not an invitation to add new features, it is mainly there to fix regressions

### Cherry Picking Reverts

If a PR that has been cherry-picked into the release branch is subsequently reverted, its cherry-pick must be reverted as well.

Reverts of changes that were committed to the main branch prior to the branch cut must be propagated to the release branch as well.

## Promoting RCs to Stable

4 changes: 1 addition & 3 deletions aten/src/ATen/cpu/vec/vec_base.h
@@ -70,9 +70,7 @@ struct is_floating_point:
std::integral_constant<bool,
std::is_floating_point<T>::value ||
std::is_same<T, at::Half>::value ||
std::is_same<T, at::BFloat16>::value ||
std::is_same<T, at::Float8_e5m2>::value ||
std::is_same<T, at::Float8_e4m3fn>::value> {
std::is_same<T, at::BFloat16>::value> {
};

template<typename T>
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSGeneratorImpl.mm
@@ -21,7 +21,7 @@ Generator createMPSGenerator(uint64_t seed_val) {
} // namespace mps::detail

MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
: c10::GeneratorImpl{Device(DeviceType::MPS, 0), DispatchKeySet(c10::DispatchKey::MPS)},
data_({.seed = seed_in}),
engine_(seed_in, 0, 0) {}

2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSStream.mm
@@ -248,7 +248,7 @@ @interface MPSGraphExecutionDescriptor ()

MPSStream* MPSStreamImpl::getInstance() {
if (_stream == nullptr) {
_stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0));
_stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS, 0), 0));
}
return _stream;
}
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/ReplicationPadding.cu
@@ -268,7 +268,7 @@ void replication_pad2d_backward_out_cuda_template(
}
gradInput.zero_();

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "replication_pad2d_backward_cuda", [&] {

auto gradInput_ = gradInput;
@@ -383,7 +383,7 @@ void replication_pad3d_backward_out_cuda_template(
}
gradInput.zero_();

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "replication_pad3d_backward_cuda", [&] {
auto gradInput_ = gradInput;
auto gradOutput_ = gradOutput;
@@ -437,7 +437,7 @@ TORCH_IMPL_FUNC(replication_pad1d_out_cuda) (
return;
}

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf,
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
input.scalar_type(), "replication_pad1d_cuda", [&] {
at::Tensor input_ = input;
at::Tensor output_ = output;
@@ -499,7 +499,7 @@ TORCH_IMPL_FUNC(replication_pad1d_backward_out_cuda) (
}
gradInput.zero_();
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "replication_pad1d_backward_cuda", [&] {
auto gradInput_ = gradInput;
@@ -543,7 +543,7 @@ TORCH_IMPL_FUNC(replication_pad2d_out_cuda) (
// const auto padR = paddingSize[1]; // This padding is ignored here
const auto padT = paddingSize[2];
// const auto padB = paddingSize[3]; // This padding is ignored here
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf,
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
input.scalar_type(), "replication_pad2d_cuda", [&] {
at::Tensor input_ = input;
at::Tensor output_ = output;
@@ -635,7 +635,7 @@ TORCH_IMPL_FUNC(replication_pad3d_out_cuda) (
return;
}
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kHalf,
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
input.scalar_type(), "replication_pad3d_cuda", [&] {
at::Tensor input_ = input;
at::Tensor output_ = output;
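The switch from `AT_DISPATCH_..._AND1(kHalf, ...)` to `AT_DISPATCH_..._AND2(kHalf, kBFloat16, ...)` above extends the CUDA replication-padding kernels to bfloat16. A minimal sketch of what this enables (illustrative, assuming a CUDA build that includes these kernels):

```python
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    x = torch.randn(1, 3, 8, 8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    # Forward and backward replication padding now dispatch for bfloat16 as well as half.
    y = F.pad(x, (2, 2, 2, 2), mode="replicate")
    y.sum().backward()
    assert x.grad.dtype == torch.bfloat16
```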
@@ -149,6 +149,9 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
@@ -170,6 +173,9 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
@@ -213,6 +219,9 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// if (params.h == params.h_k) {
@@ -240,6 +249,9 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// if (params.h == params.h_k) {
@@ -276,6 +288,9 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
@@ -293,6 +308,9 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 136 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
@@ -318,6 +336,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bo
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
@@ -299,6 +299,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
@@ -327,6 +330,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
@@ -297,9 +297,11 @@ struct AttentionKernel {
// 15/16th of tensor core compute In that case :
// - we only launch kernels for head_id % kQueriesPerBlock == 0
// - we iterate over heads instead of queries (strideM = strideH)
if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
if (head_id % kQueriesPerBlock != 0)
if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 &&
logsumexp_ptr == nullptr) {
if (head_id % kQueriesPerBlock != 0) {
return false;
}
q_strideM = q_strideH;
num_queries = num_heads;
num_heads = 1; // unused but here for intent
@@ -40,6 +40,9 @@ binary_op_tensor:
- NAME: pow
IS_DIV: 0
OPERATOR: pow(X, Y)
- NAME: floor_divide
IS_DIV: 1
OPERATOR: floor(X / Y)

binary_op_tensor_inplace:
parameter_names_with_default_values:
@@ -59,3 +62,6 @@ binary_op_tensor_inplace:
- NAME: pow_
IS_DIV: 0
OPERATOR: pow(X, Y)
- NAME: floor_divide_
IS_DIV: 1
OPERATOR: floor(X / Y)
16 changes: 16 additions & 0 deletions aten/src/ATen/native/vulkan/ops/BinaryOp.cpp
@@ -539,6 +539,16 @@ Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) {
VK_KERNEL(floor_mul_scalar_));
}

Tensor floor_divide_tensor(const Tensor& self, const Tensor& other) {
return binary_op_tensor(
self, other, c10::optional<Scalar>(), VK_KERNEL(floor_divide));
}

Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) {
return binary_op_tensor_(
self, other_arg, c10::optional<Scalar>(), VK_KERNEL(floor_divide_));
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
@@ -572,6 +582,12 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(
TORCH_SELECTIVE_NAME("aten::floor_divide_.Scalar"),
TORCH_FN(floor_divide_scalar_));
m.impl(
TORCH_SELECTIVE_NAME("aten::floor_divide"),
TORCH_FN(floor_divide_tensor));
m.impl(
TORCH_SELECTIVE_NAME("aten::floor_divide_.Tensor"),
TORCH_FN(floor_divide_tensor_));
}

#endif /* USE_VULKAN_API */
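The registrations above add tensor-tensor `aten::floor_divide` and the in-place `aten::floor_divide_.Tensor` to the Vulkan backend, backed by the `floor(X / Y)` shader entries added in the template file earlier. A rough sketch of the operator semantics being implemented (shown on CPU; the Vulkan path additionally requires a Vulkan-enabled build):

```python
import torch

a = torch.tensor([7.0, -7.0, 5.5])
b = torch.tensor([2.0, 2.0, 2.0])
# floor_divide rounds the elementwise quotient toward negative infinity,
# matching the floor(X / Y) operator used by the new shader.
print(torch.floor_divide(a, b))  # tensor([ 3., -4.,  2.])
a.floor_divide_(b)               # in-place variant
```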
5 changes: 5 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Sum.cpp
@@ -135,6 +135,11 @@ Tensor sum_dim_IntList(
Tensor sum(const Tensor& self, const c10::optional<ScalarType> dtype) {
std::vector<int64_t> dims;
for (int64_t d = 0; d < self.dim(); d++) {
// If any dimension has zero elements, we will shortcut to a zero-dim.
if (self.size(d) == 0) {
return self.new_zeros({}, at::device(at::kVulkan).dtype(self.dtype()));
}

dims.push_back(d);
}

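The early return added above makes the Vulkan `sum` of a tensor with any zero-sized dimension produce a zero-valued, zero-dimensional result, matching the behaviour on other backends (illustrative, shown on CPU):

```python
import torch

x = torch.empty(0, 3)  # dimension 0 has zero elements
print(x.sum())         # tensor(0.) -- zero-dim, zero-valued
```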
76 changes: 76 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Var.cpp
@@ -0,0 +1,76 @@
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

Tensor var_dim_IntList(
const at::Tensor& self_arg,
const OptionalIntArrayRef opt_dim,
bool unbiased = true, // correction=1 in version 2.0
bool keepdim = false) {
TORCH_CHECK(
self_arg.dim() >= 2 && self_arg.dim() <= 4,
"Vulkan var.dim_IntList only supports 2d, 3d, 4d tensors as input!");

TORCH_CHECK(
opt_dim.has_value(), "Vulkan var without a dim arg is not implemented");

const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();

std::set<int64_t> dims_set;
if (opt_dim.has_value()) {
int sample_size = 1;
auto dims = opt_dim.value();

for (const auto& d : dims) {
TORCH_CHECK(d >= -self.dim() || d < self.dim(), "Dimension out of range");

int64_t dim_normalized = utils::normalize(d, self.dim());
if (dims_set.find(dim_normalized) != dims_set.end()) {
TORCH_CHECK(
false,
"dim ",
dim_normalized,
" appears multiple times in the list of dims")
}
dims_set.insert(dim_normalized);

sample_size *= self.sizes().vec()[dim_normalized];
}

at::Tensor self_mean = self.mean(opt_dim, true);
at::Tensor self_minus_mean = self.sub(self_mean);
// We write `self_minus_mean.mul(self_minus_mean)` instead of
// `self.sub(self_mean).pow(2)` because Vulkan driver on Android doesn't
// support negative input: "The result is undefined if x<0 or if x=0 and
// y≤0" see https://registry.khronos.org/OpenGL-Refpages/gl4/html/pow.xhtml
at::Tensor output =
self_minus_mean.mul(self_minus_mean).mean(opt_dim, keepdim);
if (unbiased == true) {
output = output.mul(sample_size * 1.0 / (sample_size - 1));
}
return output;
}
return self;
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(TORCH_SELECTIVE_NAME("aten::var.dim"), TORCH_FN(var_dim_IntList));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at
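The new Vulkan `var.dim` implementation above composes existing ops: subtract the mean, square via `mul` (rather than `pow`, per the comment about the Android driver), take the mean over the requested dims, then apply the Bessel correction `n / (n - 1)` when `unbiased` is true. A quick CPU sketch of the same composition, checked against `torch.var` (illustrative only, not part of this commit):

```python
import torch

x = torch.randn(2, 5, 7)
dims = (1, 2)
n = x.size(1) * x.size(2)  # sample size over the reduced dims

mean = x.mean(dim=dims, keepdim=True)
diff = x.sub(mean)
var_biased = diff.mul(diff).mean(dim=dims)   # square via mul, as in the Vulkan path
var_unbiased = var_biased * (n / (n - 1))    # Bessel correction (unbiased=True)

torch.testing.assert_close(var_unbiased, x.var(dim=dims, unbiased=True))
```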
