Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/ggml-openvino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
// op->src[0]->ne[0]);
return true;
}
- if (op->type != GGML_TYPE_F32) {
+ if (op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) {
// GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type));
return true;
}
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ OutputVector translate_rope(const NodeContext & context) {
}
}

auto output_type = context.get_output_type();
if (data_node->get_element_type() != ov::element::f32) {
data_node = std::make_shared<ov::op::v0::Convert>(data_node, ov::element::f32);
}

if (mode == ROPE_TYPE_NORMAL) {
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
Expand Down Expand Up @@ -139,6 +144,10 @@ OutputVector translate_rope(const NodeContext & context) {
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{sub, add}, 3);
}

if (res.get_element_type() != output_type) {
res = std::make_shared<ov::op::v0::Convert>(res, output_type);
}

return rename_outputs_with_suffix({res}, context.get_name());
}

Expand Down
20 changes: 15 additions & 5 deletions ggml/src/ggml-openvino/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,23 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
auto & core = ov_singleton_core();

auto get_prefill_chunk_size = [] {
- const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE");
- if (chunk_size_str && atoi(chunk_size_str) > 0) {
- return atoi(chunk_size_str);
+ static int chunk_size = -1;
+ if (chunk_size == -1) {
+ const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE");
+ if (chunk_size_str && atoi(chunk_size_str) > 0) {
+ chunk_size = atoi(chunk_size_str);
+ } else {
+ chunk_size = 256;
+ }
  }
- return 256;
+ return chunk_size;
};

static std::string device = "NPU";
static auto is_static = true;
static auto stateful = false;
- static auto prefill_chunk_size = get_prefill_chunk_size();

+ auto prefill_chunk_size = get_prefill_chunk_size();
const auto & config = ggml_openvino_get_compile_config();

if (is_naive(cgraph)) {
Expand Down Expand Up @@ -357,6 +363,10 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<o
std::shared_ptr<ov::Model> model;
auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);

if (m_params.n_heads == -1) {
// graph is not an LLM, e.g. context-shift graph
prefill_chunk_size = inp_pos->ne[0];
}
auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
is_static, stateful, false, true, prefill_chunk_size);
auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static,
Expand Down