3 changes: 3 additions & 0 deletions ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -1252,20 +1252,23 @@ std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
{GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" },
{GGML_OP_MUL, "GGML_OP_MUL" },
{GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" },
{GGML_OP_MUL_MAT_ID, "GGML_OP_MUL_MAT_ID" },
{GGML_OP_PERMUTE, "GGML_OP_PERMUTE" },
{GGML_OP_RESHAPE, "GGML_OP_RESHAPE" },
{GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" },
{GGML_OP_NORM, "GGML_OP_NORM" },
{GGML_OP_ROPE, "GGML_OP_ROPE" },
{GGML_OP_SCALE, "GGML_OP_SCALE" },
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
{GGML_OP_SUM_ROWS, "GGML_OP_SUM_ROWS" },
{GGML_OP_SUB, "GGML_OP_SUB" },
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" },
{GGML_OP_VIEW, "GGML_OP_VIEW" },
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
{GGML_OP_CPY, "GGML_OP_CPY" },
{GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT" },
{GGML_OP_L2_NORM, "GGML_OP_L2_NORM" },
{GGML_OP_CLAMP, "GGML_OP_CLAMP" },
{GGML_OP_PAD, "GGML_OP_PAD" },
{GGML_OP_SSM_CONV, "GGML_OP_SSM_CONV" },
{GGML_OP_GATED_DELTA_NET, "GGML_OP_GATED_DELTA_NET"}
56 changes: 56 additions & 0 deletions ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -797,6 +797,45 @@ static bool is_supported_flash_attn_pattern(const ggml_tensor * op) {
return true;
}

static bool checked_mul_size(size_t a, size_t b, size_t & out) {
if (a == 0 || b == 0) {
out = 0;
return true;
}
if (a > SIZE_MAX / b) {
return false;
}
out = a * b;
return true;
}

static bool mul_mat_id_requires_large_tmp(const ggml_tensor * op) {
const ggml_tensor * as = op->src[0];
const ggml_tensor * ids = op->src[2];
if (as == nullptr || ids == nullptr) {
return true;
}

// The current OpenVINO translation materializes selected expert weights with
// shape [n_tokens, n_used, rows, k]. Skip cases that would create a very
// large temporary on GPU and let the scheduler fall back instead.
size_t tmp_elems = 1;
if (!checked_mul_size(tmp_elems, static_cast<size_t>(ids->ne[1]), tmp_elems) ||
!checked_mul_size(tmp_elems, static_cast<size_t>(ids->ne[0]), tmp_elems) ||
!checked_mul_size(tmp_elems, static_cast<size_t>(as->ne[1]), tmp_elems) ||
!checked_mul_size(tmp_elems, static_cast<size_t>(as->ne[0]), tmp_elems)) {
return true;
}

size_t tmp_bytes = 0;
if (!checked_mul_size(tmp_elems, sizeof(float), tmp_bytes)) {
return true;
}

static constexpr size_t mul_mat_id_tmp_limit = 1ULL << 30; // 1 GiB
return tmp_bytes > mul_mat_id_tmp_limit;
}
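
As a sanity check on the 1 GiB guard, here is a standalone sketch with hypothetical Mixtral-style MoE dimensions (n_used = 2, rows = 14336, k = 4096; none of these numbers come from this PR): a 512-token prompt would need roughly 224 GiB for the [n_tokens, n_used, rows, k] temporary and falls back, while single-token decode stays around 448 MiB and is kept on OpenVINO.

```cpp
// Back-of-the-envelope check of mul_mat_id_requires_large_tmp's threshold.
// Dimensions are hypothetical (Mixtral-style), purely for illustration.
#include <cstddef>
#include <cstdio>
#include <initializer_list>

int main() {
    const size_t k      = 4096;             // as->ne[0]: columns of each expert matrix
    const size_t rows   = 14336;            // as->ne[1]: rows of each expert matrix
    const size_t n_used = 2;                // ids->ne[0]: experts selected per token
    const size_t limit  = size_t(1) << 30;  // 1 GiB, same as mul_mat_id_tmp_limit

    for (size_t n_tokens : {size_t(1), size_t(512)}) {  // ids->ne[1]
        const size_t tmp_bytes = n_tokens * n_used * rows * k * sizeof(float);
        std::printf("n_tokens=%-4zu tmp ~ %10.1f MiB -> %s\n", n_tokens,
                    tmp_bytes / (1024.0 * 1024.0),
                    tmp_bytes > limit ? "unsupported (falls back)" : "supported");
    }
    return 0;
}
```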

static bool is_op_unsupported_case(const ggml_tensor * op) {
switch (op->op) {
case GGML_OP_GET_ROWS:
@@ -830,6 +869,13 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
}
break;
}
case GGML_OP_SUM_ROWS: {
// Skip when the input comes from a PERMUTE; permuted inputs are not handled by the current SUM_ROWS translation
if (op->src[0]->op == GGML_OP_PERMUTE) {
return true;
}
break;
}
case GGML_OP_FLASH_ATTN_EXT: {
if (op->src[4] != nullptr) {
// GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
Expand Down Expand Up @@ -889,6 +935,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
}
break;
}
case GGML_OP_MUL_MAT_ID: {
if (mul_mat_id_requires_large_tmp(op)) {
return true;
}
break;
}
case GGML_OP_ROPE: {
const int32_t * op_params = op->op_params;
const int n_dims = op_params[1];
Expand Down Expand Up @@ -953,8 +1005,10 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con

static const std::set<ggml_op> supported_ops{GGML_OP_NONE,
GGML_OP_ADD,
GGML_OP_DIV,
GGML_OP_MUL,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
GGML_OP_VIEW,
GGML_OP_CONT,
GGML_OP_RESHAPE,
@@ -970,6 +1024,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_CPY,
GGML_OP_L2_NORM,
GGML_OP_SUM_ROWS,
GGML_OP_CLAMP,
GGML_OP_PAD,
GGML_OP_SSM_CONV,
GGML_OP_GATED_DELTA_NET};
33 changes: 33 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/clamp.cpp
@@ -0,0 +1,33 @@
#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"

#include <cstring>
#include <openvino/op/clamp.hpp>

namespace ov {
namespace frontend {
namespace ggml {
namespace op {

OutputVector translate_clamp(const NodeContext & context) {
num_inputs_check(context, 1, 1);

auto input = process_view_input_new(context, 0);

const int32_t * op_params = context.get_output_op_params();
FRONT_END_CHECK_IMPLEMENTED(op_params != nullptr, "CLAMP requires output op params");

float min;
float max;
std::memcpy(&min, reinterpret_cast<const float *>(op_params) + 0, sizeof(float));
std::memcpy(&max, reinterpret_cast<const float *>(op_params) + 1, sizeof(float));

auto res = std::make_shared<ov::op::v0::Clamp>(input, min, max);
return rename_outputs_with_suffix({res}, context.get_name());
}

} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov
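
For context, the two memcpy's above assume that ggml packs the clamp bounds as two consecutive floats (min, then max) into the node's op_params buffer, which matches how ggml_clamp sets its parameters. A standalone round-trip sketch of that assumption:

```cpp
// Round trip of the {min, max} packing assumed by translate_clamp (illustration only).
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[2] = {};                // stand-in for ggml_tensor::op_params
    const float packed[2] = {-10.0f, 10.0f};  // hypothetical clamp bounds
    std::memcpy(op_params, packed, sizeof(packed));

    float min_v = 0.0f;
    float max_v = 0.0f;
    std::memcpy(&min_v, reinterpret_cast<const float *>(op_params) + 0, sizeof(float));
    std::memcpy(&max_v, reinterpret_cast<const float *>(op_params) + 1, sizeof(float));
    std::printf("min=%g max=%g\n", min_v, max_v);  // min=-10 max=10
    return 0;
}
```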
93 changes: 93 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/div.cpp
@@ -0,0 +1,93 @@
#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"

#include <memory>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/divide.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/tile.hpp>
#include <vector>

namespace ov {
namespace frontend {
namespace ggml {
namespace op {

namespace {

ov::Output<ov::Node> repeat_input_to_match(const NodeContext & context,
const ov::Output<ov::Node> & input,
const ov::Output<ov::Node> & target,
size_t input_index) {
const auto input_shape = context.get_input_shape(input_index);
const auto target_shape = context.get_input_shape(0);

if (input_shape == target_shape) {
return input;
}

if (input_shape.rank().is_static() && target_shape.rank().is_static()) {
const auto rank = static_cast<size_t>(input_shape.rank().get_length());
std::vector<int64_t> repeats(rank, 1);
bool needs_repeat = false;

for (size_t axis = 0; axis < rank; ++axis) {
FRONT_END_OP_CONVERSION_CHECK(input_shape[axis].is_static() && target_shape[axis].is_static(),
"DIV repeat requires static dimensions on both inputs");

const int64_t input_dim = input_shape[axis].get_length();
const int64_t target_dim = target_shape[axis].get_length();

FRONT_END_OP_CONVERSION_CHECK(input_dim > 0 && target_dim > 0 && target_dim % input_dim == 0,
"DIV input shape ", input_shape, " cannot repeat to match ", target_shape);

repeats[axis] = target_dim / input_dim;
needs_repeat = needs_repeat || repeats[axis] != 1;
}

if (!needs_repeat) {
return input;
}

auto repeats_node = ov::op::v0::Constant::create(ov::element::i64, {repeats.size()}, repeats);
return std::make_shared<ov::op::v0::Tile>(input, repeats_node);
}

auto input_shape_node = std::make_shared<ov::op::v3::ShapeOf>(input, ov::element::i64);
auto target_shape_node = std::make_shared<ov::op::v3::ShapeOf>(target, ov::element::i64);
auto repeats_node = std::make_shared<ov::op::v1::Divide>(target_shape_node, input_shape_node);
return std::make_shared<ov::op::v0::Tile>(input, repeats_node);
}

} // namespace

OutputVector translate_div(const NodeContext & context) {
num_inputs_check(context, 2, 2);

auto input_0 = process_view_input_new(context, 0);
auto input_1 = process_view_input_new(context, 1);
input_1 = repeat_input_to_match(context, input_1, input_0, 1);

const auto output_type = context.get_output_type();
const bool use_f32_compute = input_0.get_element_type() != ov::element::f32 ||
input_1.get_element_type() != ov::element::f32 ||
output_type != ov::element::f32;

if (use_f32_compute) {
input_0 = std::make_shared<ov::op::v0::Convert>(input_0, ov::element::f32);
input_1 = std::make_shared<ov::op::v0::Convert>(input_1, ov::element::f32);
}

ov::Output<ov::Node> res = std::make_shared<ov::op::v1::Divide>(input_0, input_1);
if (res.get_element_type() != output_type) {
res = std::make_shared<ov::op::v0::Convert>(res, output_type);
}
return rename_outputs_with_suffix({res}, context.get_name());
}

} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov
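
A minimal sketch of the static-shape branch of repeat_input_to_match, with hypothetical shapes (the real ones come from the ggml graph): a divisor of shape [1, 4, 1, 64] against a dividend of shape [1, 4, 32, 64] yields tile repeats {1, 1, 32, 1}, and any non-divisible dimension is rejected.

```cpp
// Standalone illustration of how the per-axis repeats are derived (shapes are hypothetical).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int64_t> input_shape  = {1, 4,  1, 64};  // divisor  (input 1)
    const std::vector<int64_t> target_shape = {1, 4, 32, 64};  // dividend (input 0)

    std::vector<int64_t> repeats(input_shape.size(), 1);
    for (size_t axis = 0; axis < input_shape.size(); ++axis) {
        // Mirrors the conversion check above: the target dim must be a positive multiple.
        if (input_shape[axis] <= 0 || target_shape[axis] % input_shape[axis] != 0) {
            return 1;  // would trigger FRONT_END_OP_CONVERSION_CHECK in the translator
        }
        repeats[axis] = target_shape[axis] / input_shape[axis];
    }
    for (int64_t r : repeats) {
        std::printf("%lld ", static_cast<long long>(r));  // prints "1 1 32 1"
    }
    std::printf("\n");
    return 0;
}
```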
79 changes: 79 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/mul_mat_id.cpp
@@ -0,0 +1,79 @@
#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"

#include <memory>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/unsqueeze.hpp>

namespace ov {
namespace frontend {
namespace ggml {
namespace op {

OutputVector translate_mul_mat_id(const NodeContext & context) {
num_inputs_check(context, 3, 3);

auto expert_weights = process_view_input_new(context, 0);
auto activations = process_view_input_new(context, 1);
auto ids = process_view_input_new(context, 2);

// OpenVINO sees GGML tensors in reversed dimension order:
// weights: [1, n_expert, m, k]
// activations: [1, n_tokens, n_used_or_1, k]
// ids: [1, 1, n_tokens, n_used]
auto squeeze_weights_axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto squeeze_acts_axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto squeeze_ids_axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1});

expert_weights = std::make_shared<ov::op::v0::Squeeze>(expert_weights, squeeze_weights_axes);
activations = std::make_shared<ov::op::v0::Squeeze>(activations, squeeze_acts_axes);
ids = std::make_shared<ov::op::v0::Squeeze>(ids, squeeze_ids_axes);

if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
}

auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
ov::Output<ov::Node> selected_weights = std::make_shared<ov::op::v8::Gather>(expert_weights, ids, gather_axis);

const auto output_type = context.get_output_type();
if (selected_weights.get_element_type() != ov::element::f32) {
selected_weights = std::make_shared<ov::op::v0::Convert>(selected_weights, ov::element::f32);
}
if (activations.get_element_type() != ov::element::f32) {
activations = std::make_shared<ov::op::v0::Convert>(activations, ov::element::f32);
}

auto selected_weights_shape = std::make_shared<ov::op::v3::ShapeOf>(selected_weights, ov::element::i64);
auto acts_target_dims = get_dimensions(selected_weights_shape, {0, 1, 3});
ov::Output<ov::Node> acts_broadcasted = std::make_shared<ov::op::v3::Broadcast>(activations, acts_target_dims,
ov::op::BroadcastType::BIDIRECTIONAL);

auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto activations_expanded = std::make_shared<ov::op::v0::Unsqueeze>(acts_broadcasted, unsqueeze_axes);

ov::Output<ov::Node> result = std::make_shared<ov::op::v0::MatMul>(activations_expanded, selected_weights, false, true);
result = std::make_shared<ov::op::v0::Squeeze>(result, unsqueeze_axes);

auto restore_batch_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
result = std::make_shared<ov::op::v0::Unsqueeze>(result, restore_batch_axis);

if (result.get_element_type() != output_type) {
result = std::make_shared<ov::op::v0::Convert>(result, output_type);
}

return rename_outputs_with_suffix({result}, context.get_name());
}

} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov
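
A rough shape trace of this translation for a hypothetical single-token decode step (n_expert = 8, n_used = 2, n_tokens = 1, m = 14336, k = 4096; illustrative numbers, not taken from this PR):

```cpp
// weights      [1, 8, 14336, 4096]  --Squeeze{0}-->    [8, 14336, 4096]
// ids          [1, 1, 1, 2]         --Squeeze{0,1}-->  [1, 2]
// Gather(weights, ids, axis=0)                         [1, 2, 14336, 4096]  (selected_weights)
// activations  [1, 1, 1, 4096]      --Squeeze{0}-->    [1, 1, 4096]
// Broadcast to selected_weights dims {0, 1, 3}         [1, 2, 4096]
// Unsqueeze{2}                                         [1, 2, 1, 4096]
// MatMul(acts, selected_weights, transpose_b=true)     [1, 2, 1, 14336]
// Squeeze{2} then Unsqueeze{0}                         [1, 1, 2, 14336]     (GGML's reversed output layout)
```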
6 changes: 4 additions & 2 deletions ggml/src/ggml-openvino/openvino/op/permute.cpp
@@ -40,8 +40,10 @@ OutputVector translate_permute(const NodeContext & context) {
std::vector<int64_t> perm_values{0, 2, 1, 3};
const int32_t* op_params = context.get_output_op_params();
if (op_params != nullptr) {
for (size_t i = 0; i < perm_values.size(); ++i) {
perm_values[i] = static_cast<int64_t>(perm_values.size() - 1 - op_params[perm_values.size() - 1 - i]);
for (size_t input_axis = 0; input_axis < perm_values.size(); ++input_axis) {
const size_t output_axis = static_cast<size_t>(op_params[input_axis]);
perm_values[perm_values.size() - 1 - output_axis] =
static_cast<int64_t>(perm_values.size() - 1 - input_axis);
}
}
auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, perm_values);
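
To make the remapping above concrete, here is a standalone sketch (it assumes, as elsewhere in this backend, that GGML op_params[i] gives the destination axis of source dim i and that OpenVINO sees the four dims in reversed order):

```cpp
// Worked example of the op_params -> transpose-order remapping in translate_permute.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical GGML permute params: source dim i goes to destination axis op_params[i].
    const std::array<int32_t, 4> op_params   = {0, 2, 1, 3};
    std::array<int64_t, 4>       perm_values = {0, 2, 1, 3};  // default used above

    for (size_t input_axis = 0; input_axis < perm_values.size(); ++input_axis) {
        const size_t output_axis = static_cast<size_t>(op_params[input_axis]);
        perm_values[perm_values.size() - 1 - output_axis] =
            static_cast<int64_t>(perm_values.size() - 1 - input_axis);
    }
    // Prints "0 2 1 3": swapping GGML dims 1 and 2 swaps OpenVINO axes 1 and 2.
    for (int64_t v : perm_values) {
        std::printf("%lld ", static_cast<long long>(v));
    }
    std::printf("\n");
    return 0;
}
```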
27 changes: 27 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/sum_rows.cpp
@@ -0,0 +1,27 @@
#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"

#include <memory>
#include <openvino/op/constant.hpp>
#include <openvino/op/reduce_sum.hpp>

namespace ov {
namespace frontend {
namespace ggml {
namespace op {

OutputVector translate_sum_rows(const NodeContext & context) {
num_inputs_check(context, 1, 1);

auto input = process_view_input_new(context, 0);
auto res = std::make_shared<ov::op::v1::ReduceSum>(
input, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true);

return rename_outputs_with_suffix({res}, context.get_name());
}

} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov
5 changes: 4 additions & 1 deletion ggml/src/ggml-openvino/openvino/op_table.cpp
@@ -21,15 +21,17 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
{"GGML_OP_ADD", op::translate_1to1_match_2_inputs<v1::Add> },
{"GGML_OP_ADD1", op::translate_1to1_match_2_inputs<v1::Add> },
{"GGML_OP_CONT", op::translate_cont },
{"GGML_OP_DIV", op::translate_1to1_match_2_inputs<v1::Divide> },
{"GGML_OP_DIV", op::translate_div },
{"GGML_OP_GET_ROWS", op::translate_get_rows },
{"GGML_OP_MUL", op::translate_1to1_match_2_inputs<v1::Multiply>},
{"GGML_OP_MUL_MAT", op::translate_mulmat },
{"GGML_OP_MUL_MAT_ID", op::translate_mul_mat_id },
{"GGML_OP_PERMUTE", op::translate_permute },
{"GGML_OP_RESHAPE", op::translate_reshape },
{"GGML_OP_RMS_NORM", op::translate_rms_norm },
{"GGML_OP_NORM", op::translate_norm },
{"GGML_OP_L2_NORM", op::translate_l2_norm },
{"GGML_OP_SUM_ROWS", op::translate_sum_rows },
{"GGML_OP_ROPE", op::translate_rope },
{"GGML_OP_SCALE", op::translate_scale },
{"GGML_OP_SOFT_MAX", op::translate_soft_max },
@@ -44,6 +46,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
{"GGML_OP_SET_ROWS", op::translate_set_rows },
{"GGML_OP_CPY", op::translate_cpy },
{"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext },
{"GGML_OP_CLAMP", op::translate_clamp },
{"GGML_OP_PAD", op::translate_pad },
{"GGML_OP_SSM_CONV", op::translate_ssm_conv },
{"GGML_OP_GATED_DELTA_NET", op::translate_gated_delta_net },