Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions backends/apple/metal/runtime/shims/tensor_attribute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <iostream>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Metal-specific device type constant
__attribute__((__visibility__("default"))) int32_t
aoti_torch_device_type_mps() {
// Let's use 2 for MPS
return 2;
}

// Override aoti_torch_get_device_type to return MPS device type
AOTITorchError aoti_torch_get_device_type(
AOTITensorHandle tensor,
int32_t* ret_device_type) {
*ret_device_type = aoti_torch_device_type_mps();
return Error::Ok;
}

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
32 changes: 32 additions & 0 deletions backends/apple/metal/runtime/shims/tensor_attribute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/apple/metal/runtime/shims/types.h>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Metal-specific device type function
int32_t aoti_torch_device_type_mps();

// Override aoti_torch_get_device_type to return MPS device type
AOTITorchError aoti_torch_get_device_type(
AOTITensorHandle tensor,
int32_t* ret_device_type);

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
35 changes: 35 additions & 0 deletions backends/apple/metal/runtime/shims/types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/error.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

// Common using declarations for ExecutorTorch types
using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

extern "C" {

// Common AOTI type aliases
// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility
using AOTITensorHandle = Tensor*;
using AOTIRuntimeError = Error;
using AOTITorchError = Error;

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
92 changes: 92 additions & 0 deletions backends/apple/metal/runtime/shims/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Helper function to check if a dtype is supported in Metal backend
bool is_dtype_supported_in_et_metal(int32_t dtype) {
switch (dtype) {
case static_cast<int32_t>(SupportedDTypes::INT64):
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
return true;
default:
return false;
}
}

// Metal-specific dtype validation utility function
AOTITorchError validate_dtype(int32_t dtype) {
if (is_dtype_supported_in_et_metal(dtype)) {
return Error::Ok;
}

ET_LOG(
Error,
"Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
dtype,
static_cast<int32_t>(SupportedDTypes::INT64),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
return Error::InvalidArgument;
}

} // extern "C"

// Utility function to convert sizes pointer to vector
std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
int64_t ndim,
const int64_t* sizes_ptr) {
std::vector<executorch::aten::SizesType> sizes(ndim);
for (int i = 0; i < ndim; i++) {
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
}
return sizes;
}

// Utility function to convert strides pointer to vector or calculate from sizes
std::vector<executorch::aten::StridesType> convert_strides_to_vector(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr) {
std::vector<executorch::aten::StridesType> strides(ndim);

if (strides_ptr != nullptr) {
// Use provided strides.
for (int64_t i = 0; i < ndim; i++) {
strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
}
} else {
// Calculate strides from sizes.
if (ndim > 0) {
strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
1); // Last dimension has stride 1
for (int64_t i = ndim - 2; i >= 0; i--) {
if (sizes_ptr[i + 1] == 0) {
strides[i] = strides[i + 1]; // Copy stride when size is 0
} else {
strides[i] = static_cast<executorch::aten::StridesType>(
static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
}
}
}
}
return strides;
}

} // namespace metal
} // namespace backends
} // namespace executorch
74 changes: 74 additions & 0 deletions backends/apple/metal/runtime/shims/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/apple/metal/runtime/shims/types.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

// Enum for supported data types in et-metal backend
enum class SupportedDTypes : int32_t {
// UINT8 = 0, // PyTorch's uint8 dtype code
// INT8 = 1, // PyTorch's int8 dtype code
// INT16 = 2, // PyTorch's int16 dtype code
// INT32 = 3, // PyTorch's int32 dtype code
INT64 = 4, // PyTorch's int64 dtype code
// FLOAT16 = 5, // PyTorch's float16 dtype code
FLOAT32 = 6, // PyTorch's float32 dtype code
// FLOAT64 = 7, // PyTorch's float64 dtype code
// BOOL = 11, // PyTorch's bool dtype code
BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
};

extern "C" {

// Helper function to check if a dtype is supported in Metal backend
bool is_dtype_supported_in_et_metal(int32_t dtype);

// Metal-specific dtype validation utility function
AOTITorchError validate_dtype(int32_t dtype);

} // extern "C"

// Utility function to convert sizes pointer to vector
std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
int64_t ndim,
const int64_t* sizes_ptr);

// Utility function to convert strides pointer to vector or calculate from sizes
std::vector<executorch::aten::StridesType> convert_strides_to_vector(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr);

// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
inline bool is_contiguous_tensor(
std::vector<executorch::aten::SizesType> sizes,
std::vector<executorch::aten::StridesType> strides) {
int64_t ndim = static_cast<int64_t>(strides.size());
int64_t expected_stride = 1;
for (int64_t i = ndim - 1; i >= 0; i--) {
if (strides[i] != expected_stride) {
return false;
}
expected_stride *= sizes[i];
}
return true;
}

} // namespace metal
} // namespace backends
} // namespace executorch
Loading