5 changes: 3 additions & 2 deletions backends/cadence/fusion_g3/operators/op_clamp.cpp
@@ -45,6 +45,7 @@ bool is_out_of_bounds(CTYPE_VAL val) {
 }
 
 ET_NODISCARD bool check_bounds(
+    KernelRuntimeContext& ctx,
     const Scalar& val_scalar,
     const ScalarType& val_type,
     const ScalarType& out_type,
@@ -107,14 +108,14 @@ Tensor& clamp_out(
   if (has_min) {
     ET_KERNEL_CHECK(
         ctx,
-        check_bounds(min_opt.value(), min_type, out_type, "minimum"),
+        check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"),
         InvalidArgument,
         out);
   }
   if (has_max) {
     ET_KERNEL_CHECK(
         ctx,
-        check_bounds(max_opt.value(), max_type, out_type, "maximum"),
+        check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"),
         InvalidArgument,
         out);
   }
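[Note] check_bounds now receives the KernelRuntimeContext so that failures inside it can be reported through the context instead of aborting the process. For orientation, here is a rough, simplified sketch of the contract that ET_KERNEL_CHECK builds on; this is not the actual ExecuTorch macro (which also logs the failed expression), just the shape of it:

// Hedged sketch: on a failed check, record the error on the context and
// return the supplied value from the enclosing kernel function.
#define MY_KERNEL_CHECK(ctx, cond, error, retval) \
  do {                                            \
    if (!(cond)) {                                \
      (ctx).fail(torch::executor::Error::error);  \
      return retval;                              \
    }                                             \
  } while (0)

Threading ctx into helpers like check_bounds lets the same reporting path be used below the top-level kernel entry point.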
24 changes: 21 additions & 3 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
@@ -265,9 +265,15 @@ - (NSString *)description {
   auto const count = _tensor->numel();
   os << "\n  count: " << count << ",";
   os << "\n  scalars: [";
+  // Create a minimal context for error handling in ET_SWITCH
+  struct {
+    [[noreturn]] void fail(torch::executor::Error /* error */) {
+      ET_CHECK_MSG(false, "Unsupported dtype in description");
+    }
+  } ctx;
   ET_SWITCH_REALHBBF16_TYPES(
       static_cast<ScalarType>(_tensor->scalar_type()),
-      nullptr,
+      ctx,
       "description",
       CTYPE,
       [&] {
@@ -488,9 +494,15 @@ - (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars
                    "Number of scalars does not match the shape");
   std::vector<uint8_t> data;
   data.resize(count * ExecuTorchSizeOfDataType(dataType));
+  // Create a minimal context for error handling in ET_SWITCH
+  struct {
+    [[noreturn]] void fail(torch::executor::Error /* error */) {
+      ET_CHECK_MSG(false, "Unsupported dtype in initWithScalars");
+    }
+  } ctx;
   for (NSUInteger index = 0; index < count; ++index) {
     ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
-        static_cast<ScalarType>(dataType), nil, "initWithScalars", CTYPE, [&] {
+        static_cast<ScalarType>(dataType), ctx, "initWithScalars", CTYPE, [&] {
           reinterpret_cast<CTYPE *>(data.data())[index] = utils::toType<CTYPE>(scalars[index]);
         }
     );
@@ -801,8 +813,14 @@ + (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
                            dataType:(ExecuTorchDataType)dataType
                      shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
   Scalar fillValue;
+  // Create a minimal context for error handling in ET_SWITCH
+  struct {
+    [[noreturn]] void fail(torch::executor::Error /* error */) {
+      ET_CHECK_MSG(false, "Unsupported dtype in fullTensor");
+    }
+  } ctx;
   ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
-      static_cast<ScalarType>(dataType), nil, "fullTensor", CTYPE, [&] {
+      static_cast<ScalarType>(dataType), ctx, "fullTensor", CTYPE, [&] {
         fillValue = utils::toType<CTYPE>(scalar);
       }
   );
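[Note] The ET_SWITCH_* macros previously tolerated nullptr/nil as their context argument; as of this change they expect any object exposing a fail(torch::executor::Error) member, which the generated switch invokes when the dtype is unsupported. The anonymous local struct above preserves the old behavior (abort with a message) at call sites that have no real KernelRuntimeContext. A self-contained sketch of the duck-typed contract, with a hypothetical struct name:

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/assert.h>

// Hypothetical stand-in: the switch macros only need a member named
// fail() that accepts torch::executor::Error.
struct AbortOnFail {
  [[noreturn]] void fail(torch::executor::Error error) {
    ET_CHECK_MSG(
        false, "Unsupported dtype (error 0x%x)", static_cast<unsigned>(error));
  }
};

// Roughly what the expanded switch's default branch now does:
//   default:
//     ctx.fail(torch::executor::Error::InvalidArgument);

Marking fail() as [[noreturn]] matters at these call sites: the surrounding code assumes the switch always succeeded, and the attribute tells the compiler the failure branch never falls through.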
10 changes: 9 additions & 1 deletion extension/llm/runner/text_decoder_runner.h
@@ -68,12 +68,20 @@ class ET_EXPERIMENTAL TextDecoderRunner {
       const executorch::aten::Tensor& logits_tensor,
       const float temperature = 0.0f) {
     int32_t result = 0;
+
+    // Create a minimal context for error handling in ET_SWITCH
+    struct {
+      [[noreturn]] void fail(torch::executor::Error /* error */) {
+        ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token");
+      }
+    } ctx;
+
     ET_SWITCH_THREE_TYPES(
         Float,
         Half,
         BFloat16,
         logits_tensor.scalar_type(),
-        unused,
+        ctx,
         "logits_to_token",
         CTYPE,
         [&]() {
10 changes: 9 additions & 1 deletion extension/tensor/tensor_ptr.h
@@ -111,7 +111,15 @@ inline TensorPtr make_tensor_ptr(
       runtime::canCast(deduced_type, type),
       "Cannot cast deduced type to specified type.");
   std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
-  ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "make_tensor_ptr", CTYPE, [&] {
+
+  // Create a minimal context for error handling in ET_SWITCH
+  struct {
+    [[noreturn]] void fail(torch::executor::Error /* error */) {
+      ET_CHECK_MSG(false, "Unsupported dtype in make_tensor_ptr");
+    }
+  } ctx;
+
+  ET_SWITCH_REALHBBF16_TYPES(type, ctx, "make_tensor_ptr", CTYPE, [&] {
     std::transform(
         data.begin(),
         data.end(),
18 changes: 16 additions & 2 deletions extension/tensor/tensor_ptr_maker.cpp
@@ -89,7 +89,14 @@ TensorPtr random_strided(
       empty_strided(std::move(sizes), std::move(strides), type, dynamism);
   std::default_random_engine gen{std::random_device{}()};
 
-  ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
+  // Create a minimal context for error handling in ET_SWITCH
+  struct {
+    [[noreturn]] void fail(torch::executor::Error /* error */) {
+      ET_CHECK_MSG(false, "Unsupported dtype in random_strided");
+    }
+  } ctx;
+
+  ET_SWITCH_REALHBBF16_TYPES(type, ctx, "random_strided", CTYPE, [&] {
     std::generate_n(tensor->mutable_data_ptr<CTYPE>(), tensor->numel(), [&]() {
       return static_cast<CTYPE>(distribution(gen));
     });
@@ -124,7 +131,14 @@ TensorPtr full_strided(
     executorch::aten::TensorShapeDynamism dynamism) {
   auto tensor =
       empty_strided(std::move(sizes), std::move(strides), type, dynamism);
-  ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
+  // Create a minimal context for error handling in ET_SWITCH
+  struct {
+    [[noreturn]] void fail(torch::executor::Error /* error */) {
+      ET_CHECK_MSG(false, "Unsupported data type in full_strided");
+    }
+  } ctx;
+
+  ET_SWITCH_REALHBBF16_TYPES(type, ctx, "full_strided", CTYPE, [&] {
     CTYPE value;
     ET_EXTRACT_SCALAR(fill_value, value);
     std::fill(
12 changes: 6 additions & 6 deletions kernels/optimized/cpu/op_add_sub_impl.h
@@ -144,13 +144,13 @@ Tensor& opt_add_sub_out_impl(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     // Cannot apply the trick of -alpha here because alpha is Scalar without
     // support for - operator. At least not right now.
-    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() -> void {
       CTYPE alpha_val;
       ET_KERNEL_CHECK_MSG(
           ctx,
           torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
           InvalidArgument,
-          out,
+          ,
           "Failed to extract scalar alpha.");
       using Vec = at::vec::Vectorized<CTYPE>;
       Vec alpha_val_vec(alpha_val);
@@ -164,13 +164,13 @@
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return y - alpha_val_vec * x;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       } else {
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return x - alpha_val_vec * y;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       }
     } else {
@@ -191,13 +191,13 @@
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return y + alpha_val_vec * x;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       } else {
         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
           return x + alpha_val_vec * y;
         };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
       }
     }
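[Note] Two coupled changes in this file: the ET_SWITCH lambda is now explicitly [&]() -> void and no longer returns the result of handle_broadcast_elementwise (its Tensor& result is simply discarded), and ET_KERNEL_CHECK_MSG receives an empty retval argument so that the return it generates is a bare return;, the only form legal inside a void lambda. Presumably the lambda has to be void so the macro's failure branch, which cannot produce a value, can return cleanly. A reduced, self-contained sketch of why the empty argument works; the real macro also takes a format string plus varargs and logs, none of which is shown here:

#include <cstdio>

// Hypothetical stand-ins for the ExecuTorch types, for illustration only.
enum class Error { InvalidArgument };
struct Ctx {
  void fail(Error) { std::puts("kernel check failed"); }
};

// Simplified check macro: when `retval` is left empty, `return retval;`
// collapses to `return;`, which is what a void lambda requires.
#define CHECK_OR_RETURN(ctx, cond, error, retval, msg) \
  do {                                                 \
    if (!(cond)) {                                     \
      (ctx).fail(Error::error);                        \
      return retval;                                   \
    }                                                  \
  } while (0)

int main() {
  Ctx ctx;
  bool extracted = false;  // pretend extract_scalar() failed
  auto body = [&]() -> void {
    CHECK_OR_RETURN(ctx, extracted, InvalidArgument, , "Failed to extract.");
    std::puts("unreachable when extraction fails");
  };
  body();
  return 0;
}

The dropped returns then follow mechanically: once the lambda is void, `return expr;` with a Tensor& expression would not compile, so the calls stand alone and the enclosing kernel still returns out as before.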
4 changes: 2 additions & 2 deletions kernels/optimized/cpu/op_div.cpp
@@ -130,11 +130,11 @@ Tensor& opt_div_out(
         selected_optimized_path ==
             ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
       auto div_lambda = [](auto x, auto y) { return y / x; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
           ctx, div_lambda, a, b, out, selected_optimized_path);
     } else {
       auto div_lambda = [](auto x, auto y) { return x / y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
           ctx, div_lambda, a, b, out, selected_optimized_path);
     }
   });
2 changes: 1 addition & 1 deletion kernels/optimized/cpu/op_le.cpp
@@ -57,7 +57,7 @@ Tensor& opt_le_tensor_out(
     // Handle optimized broadcast cases
     ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
       auto le_lambda = [](auto x, auto y) { return x.le(y); };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
           ctx, le_lambda, a, b, out, selected_optimized_path);
     });
   } else {
4 changes: 2 additions & 2 deletions kernels/optimized/cpu/op_mul.cpp
@@ -148,13 +148,13 @@ Tensor& opt_mul_out(
 
     ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
           ctx, mul_lambda, a, b, out, selected_optimized_path);
     });
   } else {
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+      torch::executor::handle_broadcast_elementwise<CTYPE>(
           ctx, mul_lambda, a, b, out, selected_optimized_path);
     });
   }
5 changes: 3 additions & 2 deletions kernels/portable/cpu/op_clamp.cpp
@@ -34,6 +34,7 @@ bool is_out_of_bounds(CTYPE_CAST val_cast) {
 }
 
 ET_NODISCARD bool check_bounds(
+    KernelRuntimeContext& ctx,
     const Scalar& val_scalar,
     const torch::executor::native::ScalarType& val_type,
     const torch::executor::native::ScalarType& out_type,
@@ -107,14 +108,14 @@ Tensor& clamp_out(
   if (has_min) {
     ET_KERNEL_CHECK(
         ctx,
-        check_bounds(min_opt.value(), min_type, out_type, "minimum"),
+        check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"),
         InvalidArgument,
         out);
   }
   if (has_max) {
     ET_KERNEL_CHECK(
         ctx,
-        check_bounds(max_opt.value(), max_type, out_type, "maximum"),
+        check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"),
         InvalidArgument,
         out);
   }
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_convolution.cpp
@@ -415,7 +415,7 @@ Tensor& convolution_out(
   ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
     const auto load_bias = bias.has_value()
         ? utils::internal::get_load_to_compute_fn<CTYPE, name>(
-              bias.value(), utils::SupportedTensorDtypes::REALHBF16)
+              ctx, bias.value(), utils::SupportedTensorDtypes::REALHBF16)
         : nullptr;
     convolution_wrapper<CTYPE>(
         in,
4 changes: 2 additions & 2 deletions kernels/portable/cpu/op_cumsum.cpp
@@ -111,10 +111,10 @@ Tensor& cumsum_out(
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "cumsum.out";
 
-  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&]() {
     const auto load_self =
         utils::internal::get_load_to_compute_fn<CTYPE_OUT, op_name>(
-            self, utils::SupportedTensorDtypes::REALHBBF16);
+            ctx, self, utils::SupportedTensorDtypes::REALHBBF16);
     cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
   });
 
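[Note] The op_convolution.cpp and op_cumsum.cpp hunks above reflect the same signature change: utils::internal::get_load_to_compute_fn now takes the runtime context first, so an unsupported input dtype is reported through ctx rather than crashing. A hedged, self-contained miniature of that indirection; the names below are illustrative, not the real ExecuTorch API:

// Pick a loader that converts a stored element to the kernel's compute
// type, reporting unsupported dtypes through the context.
enum class Error { InvalidArgument };
struct Ctx {
  void fail(Error) { /* record the error; real contexts do more */ }
};

enum class ScalarType { Float, Half };

template <typename CTYPE_COMPUTE>
using LoadFn = CTYPE_COMPUTE (*)(const void*);

template <typename CTYPE_COMPUTE>
LoadFn<CTYPE_COMPUTE> get_load_fn(Ctx& ctx, ScalarType st) {
  switch (st) {
    case ScalarType::Float:
      return [](const void* p) {
        return static_cast<CTYPE_COMPUTE>(*static_cast<const float*>(p));
      };
    default:
      ctx.fail(Error::InvalidArgument);  // report instead of aborting
      return nullptr;                    // e.g. the conv bias loader stays null
  }
}

int main() {
  Ctx ctx;
  float stored = 2.5f;
  auto load = get_load_fn<double>(ctx, ScalarType::Float);
  return (load != nullptr && load(&stored) == 2.5) ? 0 : 1;
}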
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_fill.cpp
@@ -90,7 +90,7 @@ Tensor& fill_tensor_out(
   static constexpr const char op_name[] = "fill.Tensor_out";
 
   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, op_name, CTYPE_A, [&] {
-    CTYPE_A b_casted;
+    CTYPE_A b_casted{};
[Contributor comment] this took me a minute -- we now return all-zeros output and signal a failure when there's an unsupported dtype instead of terminating the process. this seems fine.

     ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, op_name, CTYPE_B, [&] {
       CTYPE_B b_val;
       ET_EXTRACT_SCALAR_TENSOR(b, b_val);
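[Note] The braces in CTYPE_A b_casted{}; are what the contributor comment is about: value-initialization guarantees the fill value is zero if the inner ET_SWITCH fails and returns without assigning, whereas the old default-initialized b_casted would have held an indeterminate value. A tiny standalone illustration of the language rule:

#include <cassert>

int main() {
  // Default- vs value-initialization for scalars (illustration only):
  float indeterminate;  // reading this before assignment is undefined
  float zeroed{};       // guaranteed 0.0f
  assert(zeroed == 0.0f);
  (void)indeterminate;
  return 0;
}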
4 changes: 3 additions & 1 deletion kernels/portable/cpu/op_index_put.cpp
@@ -160,6 +160,7 @@ Tensor& index_put_out(
 namespace {
 
 bool check_special_case_in_place_args(
+    KernelRuntimeContext& ctx,
     Tensor& in,
     TensorOptList indices,
     const Tensor& values,
@@ -285,7 +286,8 @@ Tensor& index_put_(
   size_t dim = 0;
   ET_KERNEL_CHECK(
       ctx,
-      check_special_case_in_place_args(in, indices, values, accumulate, &dim),
+      check_special_case_in_place_args(
+          ctx, in, indices, values, accumulate, &dim),
       InvalidArgument,
       in);
 
11 changes: 3 additions & 8 deletions kernels/portable/cpu/op_scatter.cpp
@@ -104,25 +104,20 @@ void scatter_value_helper(
 } // namespace
 
 Tensor& scatter_src_out(
-    KernelRuntimeContext& context,
+    KernelRuntimeContext& ctx,
     const Tensor& in,
     int64_t dim,
     const Tensor& index,
     const Tensor& src,
     Tensor& out) {
-  (void)context;
-
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       check_scatter_src_args(in, dim, index, src, out),
       InvalidArgument,
       out);
 
   ET_KERNEL_CHECK(
-      context,
-      resize_tensor(out, in.sizes()) == Error::Ok,
-      InvalidArgument,
-      out);
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
   constexpr auto name = "scatter.src_out";
 
18 changes: 5 additions & 13 deletions kernels/portable/cpu/op_scatter_add.cpp
@@ -52,38 +52,30 @@ void scatter_add_helper(
 } // namespace
 
 Tensor& scatter_add_out(
-    KernelRuntimeContext& context,
+    KernelRuntimeContext& ctx,
     const Tensor& self,
     int64_t dim,
     const Tensor& index,
     const Tensor& src,
     Tensor& out) {
-  (void)context;
-
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       check_scatter_add_args(self, dim, index, src, out),
       InvalidArgument,
       out);
 
   ET_KERNEL_CHECK(
-      context,
-      tensors_have_same_dim_order(self, src, out),
-      InvalidArgument,
-      out);
+      ctx, tensors_have_same_dim_order(self, src, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(
-      context, tensor_is_default_dim_order(index), InvalidArgument, out);
+      ctx, tensor_is_default_dim_order(index), InvalidArgument, out);
 
   if (dim < 0) {
     dim += nonzero_dim(self);
   }
 
   ET_KERNEL_CHECK(
-      context,
-      resize_tensor(out, self.sizes()) == Error::Ok,
-      InvalidArgument,
-      out);
+      ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out);
 
   ScalarType self_type = self.scalar_type();
 