Skip to content

Commit

Permalink
PagedAttention Transformation: Rank alignment for replacements (openv…
Browse files Browse the repository at this point in the history
…inotoolkit#24690)

During the elimination of dependencies on the `beam_idx` input and
`ReadValue`(s), we replace them with the new PA-related inputs and
sub-expressions that depend on the other remaining inputs. In such
replacements we need to guarantee that the shape and element type of the old
and new nodes match. Before this PR this was not guaranteed for the shape, and
sometimes a scalar was replaced by a tensor of rank 1, which led to errors
like `'start' input is not a scalar`. Now the shape is aligned.

---------

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
  • Loading branch information
3 people committed May 27, 2024
1 parent 65c3b17 commit c0d197c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,5 @@ class PrevSequenceLengthPattern;
// Matcher pass declaration; the constructor receives the node(s) that matched sub-graphs are rewired to.
// NOTE(review): this text is a rendered diff hunk (-23,6 +23,5) with the +/- markers stripped, so two
// constructor declarations appear below. Presumably the first (Subtract-typed, unnamed second parameter)
// is the removed declaration and the second (generic ov::Node, named batch_dim) is its replacement —
// confirm against the actual header before relying on either signature.
class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
// Declaration taking prev_max_seq_len specifically as a v1::Subtract node.
explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
std::shared_ptr<ov::Node>);
// Declaration taking both prev_max_seq_len and batch_dim as generic nodes.
explicit PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len, std::shared_ptr<ov::Node> batch_dim);
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

using namespace ov::op;

ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
std::shared_ptr<ov::Node> batch_dim) {
ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len,
std::shared_ptr<ov::Node> batch_dim) {
MATCHER_SCOPE(PrevSequenceLengthPattern);
// The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in
// kv-cache state. In the first case it should replace it by prev_max_seq_len. For the second case, connect to batch_dim.
Expand All @@ -39,30 +38,23 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
auto axis = gather_index->cast_vector<int64_t>().at(0);
auto kv_init_shape = pattern_map.at(kv_past).get_node()->get_input_partial_shape(0);
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<ov::Node> replacement;
if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
// this is a sequence dimension based on how the initialization expression is built for stateful models
std::shared_ptr<ov::Node> replacement;
if (prev_max_seq_len->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(prev_max_seq_len, target_type);
} else {
replacement = prev_max_seq_len;
}
replace_node(
gather,
std::make_shared<v1::Reshape>(replacement, v0::Constant::create(element::i64, Shape{1}, {1}), false));
return true;
} else { // assumption that any other axis should point to batch dimension, precise reasoning is too complex
// (TODO)
// this is a batch dimension
std::shared_ptr<ov::Node> replacement;
if (batch_dim->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(batch_dim, target_type);
} else {
replacement = batch_dim;
}
replace_node(gather, replacement);
return true;
replacement = prev_max_seq_len;
} else {
// assumption that any other axis should point to batch dimension, precise reasoning is too complex
// TODO: provide more reliable check
replacement = batch_dim;
}
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}
auto required_shape = gather->get_output_partial_shape(0);
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
}
replace_node(gather, replacement);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
Expand All @@ -29,11 +30,13 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
// use symbolic infra or look at the constant input
auto gather = m.get_match_root();
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<Node> replacement;
if (max_context_len->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(max_context_len, target_type);
} else {
replacement = max_context_len;
std::shared_ptr<Node> replacement = max_context_len;
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}
auto required_shape = gather->get_output_partial_shape(0);
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
}
replace_node(gather, replacement);
return true;
Expand Down

0 comments on commit c0d197c

Please sign in to comment.