Skip to content

Commit

Permalink
Add API for checking if a pattern in a query is non-local
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbrunsfeld committed Feb 16, 2023
1 parent 9542da6 commit 57508ea
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 12 deletions.
62 changes: 62 additions & 0 deletions cli/src/tests/query_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4084,6 +4084,68 @@ fn test_query_is_pattern_rooted() {
});
}

#[test]
fn test_query_is_pattern_non_local() {
struct Row {
description: &'static str,
pattern: &'static str,
is_non_local: bool,
}

let rows = [
Row {
description: "simple token",
pattern: r#"(identifier)"#,
is_non_local: false,
},
Row {
description: "siblings that can occur in an argument list",
pattern: r#"((identifier) (identifier))"#,
is_non_local: true,
},
Row {
description: "siblings that can occur in a statement block",
pattern: r#"((return_statement) (return_statement))"#,
is_non_local: true,
},
Row {
description: "siblings that can occur in a source file",
pattern: r#"((function_definition) (class_definition))"#,
is_non_local: true,
},
Row {
description: "siblings that can't occur in any repetition",
pattern: r#"("{" "}")"#,
is_non_local: false,
},
];

allocations::record(|| {
eprintln!("");

let language = get_language("python");
for row in &rows {
if let Some(filter) = EXAMPLE_FILTER.as_ref() {
if !row.description.contains(filter.as_str()) {
continue;
}
}
eprintln!(" query example: {:?}", row.description);
let query = Query::new(language, row.pattern).unwrap();
assert_eq!(
query.is_pattern_non_local(0),
row.is_non_local,
"Description: {}, Pattern: {:?}",
row.description,
row.pattern
.split_ascii_whitespace()
.collect::<Vec<_>>()
.join(" "),
)
}
});
}

#[test]
fn test_capture_quantifiers() {
struct Row {
Expand Down
3 changes: 3 additions & 0 deletions lib/binding_rust/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,9 @@ extern "C" {
length: *mut u32,
) -> *const TSQueryPredicateStep;
}
extern "C" {
pub fn ts_query_is_pattern_non_local(self_: *const TSQuery, pattern_index: u32) -> bool;
}
extern "C" {
pub fn ts_query_is_pattern_rooted(self_: *const TSQuery, pattern_index: u32) -> bool;
}
Expand Down
8 changes: 7 additions & 1 deletion lib/binding_rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1736,11 +1736,17 @@ impl Query {
}

/// Check if a given pattern within a query has a single root node.
#[doc(alias = "ts_query_is_pattern_guaranteed_at_step")]
#[doc(alias = "ts_query_is_pattern_rooted")]
pub fn is_pattern_rooted(&self, index: usize) -> bool {
unsafe { ffi::ts_query_is_pattern_rooted(self.ptr.as_ptr(), index as u32) }
}

/// Check if a given pattern within a query has a single root node.
#[doc(alias = "ts_query_is_pattern_non_local")]
pub fn is_pattern_non_local(&self, index: usize) -> bool {
unsafe { ffi::ts_query_is_pattern_non_local(self.ptr.as_ptr(), index as u32) }
}

/// Check if a given step in a query is 'definite'.
///
/// A query step is 'definite' if its parent pattern will be guaranteed to match
Expand Down
27 changes: 19 additions & 8 deletions lib/include/tree_sitter/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -750,15 +750,26 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern(
uint32_t *length
);

bool ts_query_is_pattern_rooted(
const TSQuery *self,
uint32_t pattern_index
);
/*
* Check if the given pattern in the query has a single root node.
*/
bool ts_query_is_pattern_rooted(const TSQuery *self, uint32_t pattern_index);

bool ts_query_is_pattern_guaranteed_at_step(
const TSQuery *self,
uint32_t byte_offset
);
/*
* Check if the given pattern in the query is 'non local'.
*
* A non-local pattern has multiple root nodes and can match within a
* repeating sequence of nodes, as specified by the grammar. Non-local
* patterns disable certain optimizations that would otherwise be possible
* when executing a query on a specific range of a syntax tree.
*/
bool ts_query_is_pattern_non_local(const TSQuery *self, uint32_t pattern_index);

/*
* Check if a given pattern is guaranteed to match once a given step is reached.
* The step is specified by its byte offset in the query's source code.
*/
bool ts_query_is_pattern_guaranteed_at_step(const TSQuery *self, uint32_t byte_offset);

/**
* Get the name and length of one of the query's captures, or one of the
Expand Down
24 changes: 21 additions & 3 deletions lib/src/query.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ typedef struct {
Slice steps;
Slice predicate_steps;
uint32_t start_byte;
bool is_non_local;
} QueryPattern;

typedef struct {
Expand Down Expand Up @@ -1455,7 +1456,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) {
if (!pattern->is_rooted) {
QueryStep *step = &self->steps.contents[pattern->step_index];
if (step->symbol != WILDCARD_SYMBOL) {
array_push(&non_rooted_pattern_start_steps, pattern->step_index);
array_push(&non_rooted_pattern_start_steps, i);
}
}
}
Expand Down Expand Up @@ -1868,7 +1869,8 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) {
// prevent certain optimizations with range restrictions.
analysis.did_abort = false;
for (uint32_t i = 0; i < non_rooted_pattern_start_steps.size; i++) {
uint16_t step_index = non_rooted_pattern_start_steps.contents[i];
uint16_t pattern_entry_index = non_rooted_pattern_start_steps.contents[i];
PatternEntry *pattern_entry = &self->pattern_map.contents[pattern_entry_index];

analysis_state_set__clear(&analysis.states, &analysis.state_pool);
analysis_state_set__clear(&analysis.deeper_states, &analysis.state_pool);
Expand All @@ -1880,7 +1882,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) {
for (uint32_t k = 0; k < subgraph->start_states.size; k++) {
TSStateId parse_state = subgraph->start_states.contents[k];
analysis_state_set__push(&analysis.states, &analysis.state_pool, &((AnalysisState) {
.step_index = step_index,
.step_index = pattern_entry->step_index,
.stack = {
[0] = {
.parse_state = parse_state,
Expand All @@ -1906,6 +1908,10 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) {
&analysis
);

if (analysis.finished_parent_symbols.size > 0) {
self->patterns.contents[pattern_entry->pattern_index].is_non_local = true;
}

for (unsigned k = 0; k < analysis.finished_parent_symbols.size; k++) {
TSSymbol symbol = analysis.finished_parent_symbols.contents[k];
array_insert_sorted_by(&self->repeat_symbols_with_rootless_patterns, , symbol);
Expand Down Expand Up @@ -2697,6 +2703,7 @@ TSQuery *ts_query_new(
.steps = (Slice) {.offset = start_step_index},
.predicate_steps = (Slice) {.offset = start_predicate_step_index},
.start_byte = stream_offset(&stream),
.is_non_local = false,
}));
CaptureQuantifiers capture_quantifiers = capture_quantifiers_new();
*error_type = ts_query__parse_pattern(self, &stream, 0, false, &capture_quantifiers);
Expand Down Expand Up @@ -2876,6 +2883,17 @@ bool ts_query_is_pattern_rooted(
return true;
}

bool ts_query_is_pattern_non_local(
const TSQuery *self,
uint32_t pattern_index
) {
if (pattern_index < self->patterns.size) {
return self->patterns.contents[pattern_index].is_non_local;
} else {
return false;
}
}

bool ts_query_is_pattern_guaranteed_at_step(
const TSQuery *self,
uint32_t byte_offset
Expand Down

0 comments on commit 57508ea

Please sign in to comment.