Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ int32_t aoti_torch_dtype_bfloat16() {
return 15; // PyTorch's bfloat16 dtype code
}

int32_t aoti_torch_dtype_int32() {
return 3; // PyTorch's int32 dtype code
}

int32_t aoti_torch_dtype_int64() {
return 4; // PyTorch's int64 dtype code
}
Expand Down
1 change: 1 addition & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ int32_t aoti_torch_device_type_cpu();
int32_t aoti_torch_layout_strided();
int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int32();
int32_t aoti_torch_dtype_int64();

// Autograd mode functions
Expand Down
2 changes: 2 additions & 0 deletions backends/aoti/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
// Convert based on known PyTorch dtype codes (without CUDA-specific
// dependency)
switch (dtype) {
case 3: // PyTorch's int32 dtype code
return executorch::aten::ScalarType::Int;
case 4: // PyTorch's int64 dtype code
return executorch::aten::ScalarType::Long;
case 6: // PyTorch's float32 dtype code
Expand Down
14 changes: 13 additions & 1 deletion backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,20 @@ find_package_torch()
set(_aoti_cuda_sources
runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
runtime/shims/cuda_guard.cpp
runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
)
# Set default CUDA architectures if not specified (int4mm requires sm_80+)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES
"80;86;89;90"
CACHE STRING "CUDA architectures"
)
message(
STATUS
"CMAKE_CUDA_ARCHITECTURES not set, using default: 80;86;89;90 (Ampere+)"
)
message(STATUS " Override with: cmake -DCMAKE_CUDA_ARCHITECTURES=<arch> ...")
endif()
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
aoti_cuda
Expand Down
4 changes: 3 additions & 1 deletion backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from torch.nn.attention import SDPBackend

# exist fallback operators in et namespace;
supported_fallback_kernels: Dict[str, Any] = {}
supported_fallback_kernels: Dict[str, Any] = {
"at::_ops::_weight_int4pack_mm::call": None,
}

# required fallback kernels but not supported
missing_fallback_kernels: Set[str] = set()
Expand Down
3 changes: 3 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ runtime.cxx_library(
srcs = [
"guard.cpp",
"shims/cuda_guard.cpp",
"shims/int4mm.cu",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/cuda_guard.h",
"shims/int4mm.cuh",
"shims/int4mm.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
Expand Down
57 changes: 57 additions & 0 deletions backends/cuda/runtime/shims/int4mm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cuda.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/int4mm.h>
#include <executorch/backends/cuda/runtime/shims/int4mm.cuh>
#include <executorch/runtime/platform/log.h>

namespace executorch::backends::cuda {
#ifdef __cplusplus
extern "C" {
#endif

AOTITorchError aoti_torch_cuda__weight_int4pack_mm(
Tensor* self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should check whether self is bfloat16?

Tensor* mat2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check whether mat2 is int32

int64_t qGroupSize,
Tensor* qScaleAndZeros,
Tensor** ret0) {
// Validate input parameters first
ET_CHECK_OR_RETURN_ERROR(
self != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null");

ET_CHECK_OR_RETURN_ERROR(
mat2 != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null");

ET_CHECK_OR_RETURN_ERROR(
qScaleAndZeros != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null");

ET_CHECK_OR_RETURN_ERROR(
ret0 != nullptr,
InvalidArgument,
"aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null");

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ET_CHECK_OR_RETURN_ERROR(
        qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || qGroupSize == 256,
        InvalidArgument,
        "aoti_torch_cuda__weight_int4pack_mm: qGroupSize must be 32/64/128/256, got %lld",
        static_cast<long long>(qGroupSize));

*ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros);
ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
return Error::Ok;
}

#ifdef __cplusplus
}
#endif
} // namespace executorch::backends::cuda
Loading
Loading