Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple predicates #92

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
210 changes: 163 additions & 47 deletions bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"math"
"reflect"
"regexp"
"runtime"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -119,33 +120,12 @@ func (p *Parser) ParseCtx(ctx context.Context, oldTree *Tree, content []byte) (*
close(parseComplete)
C.free(input)

return p.convertTSTree(ctx, BaseTree)
}

// ParseInput produces new Tree by reading from a callback defined in input
// it is useful if your data is stored in specialized data structure
// as it will avoid copying the data into []bytes
// and faster access to edited part of the data
func (p *Parser) ParseInput(oldTree *Tree, input Input) *Tree {
t, _ := p.ParseInputCtx(context.Background(), oldTree, input)
return t
}

// ParseInputCtx produces new Tree by reading from a callback defined in input
// it is useful if your data is stored in specialized data structure
// as it will avoid copying the data into []bytes
// and faster access to edited part of the data
func (p *Parser) ParseInputCtx(ctx context.Context, oldTree *Tree, input Input) (*Tree, error) {
var BaseTree *C.TSTree
if oldTree != nil {
BaseTree = oldTree.c
t, err := p.convertTSTree(ctx, BaseTree)
if err != nil {
return nil, err
}

funcID := readFuncs.register(input.Read)
BaseTree = C.call_ts_parser_parse(p.c, BaseTree, C.int(funcID), C.TSInputEncoding(input.Encoding))
readFuncs.unregister(funcID)

return p.convertTSTree(ctx, BaseTree)
t.input = content
return t, nil
}

// convertTSTree converts the tree-sitter response into a *Tree or an error.
Expand Down Expand Up @@ -271,13 +251,17 @@ type Tree struct {
// Otherwise Parser may be GC'ed (and deleted by the finalizer) while some Tree objects are still in use.
p *Parser

input []byte

// most probably better save node.id
cache map[C.TSNode]*Node
}

// Copy returns a new copy of a tree
func (t *Tree) Copy() *Tree {
return t.p.newTree(C.ts_tree_copy(t.c))
nt := t.p.newTree(C.ts_tree_copy(t.c))
nt.input = t.input
return nt
}

// RootNode returns root node of a tree
Expand All @@ -286,6 +270,11 @@ func (t *Tree) RootNode() *Node {
return t.cachedNode(ptr)
}

// Input returns the input given to parse this tree
func (t *Tree) Input() []byte {
return t.input
}

func (t *Tree) cachedNode(ptr C.TSNode) *Node {
if ptr.id == nil {
return nil
Expand Down Expand Up @@ -564,8 +553,8 @@ func (n Node) Edit(i EditInput) {
}

// Content returns node's source code from input as a string
func (n Node) Content(input []byte) string {
return string(input[n.StartByte():n.EndByte()])
func (n Node) Content() string {
return string(n.t.input[n.StartByte():n.EndByte()])
}

func (n Node) NamedDescendantForPointRange(start Point, end Point) *Node {
Expand Down Expand Up @@ -737,6 +726,49 @@ func NewQuery(pattern []byte, lang *Language) (*Query, error) {
}

q := &Query{c: c}

for i := uint32(0); i < q.PatternCount(); i++ {
steps := q.PredicatesForPattern(i)
if len(steps) == 0 {
continue
}

if steps[0].Type != QueryPredicateStepTypeString {
return nil, errors.New("predicate must begin with a literal value")
}

operator := q.StringValueForId(steps[0].ValueId)
switch operator {
case "eq?", "not-eq?":
if len(steps) != 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeCapture {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
case "match?", "not-match?":
if len(steps) != 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeCapture {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
if steps[2].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
}
case "set!", "is?", "is-not?":
if len(steps) < 3 || len(steps) > 4 {
return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2)
}
if steps[1].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId))
}
if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString {
return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
}
}
}

runtime.SetFinalizer(q, (*Query).Close)

return q, nil
Expand Down Expand Up @@ -797,6 +829,9 @@ func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep {
stepType := QueryPredicateStepType(s._type)
valueId := uint32(s.value_id)
predicateSteps = append(predicateSteps, QueryPredicateStep{stepType, valueId})
if stepType == QueryPredicateStepTypeDone {
break
}
}

return predicateSteps
Expand Down Expand Up @@ -881,32 +916,41 @@ type QueryMatch struct {
// Otherwise, it will populate the QueryMatch with data
// about which pattern matched and which nodes were captured.
func (qc *QueryCursor) NextMatch() (*QueryMatch, bool) {
return qc.nextMatch(true)
}

func (qc *QueryCursor) nextMatch(filterPredicates bool) (*QueryMatch, bool) {
var (
cqm C.TSQueryMatch
cqc []C.TSQueryCapture
)

if ok := C.ts_query_cursor_next_match(qc.c, &cqm); !bool(ok) {
return nil, false
}
for {
if ok := C.ts_query_cursor_next_match(qc.c, &cqm); !bool(ok) {
return nil, false
}

qm := &QueryMatch{
ID: uint32(cqm.id),
PatternIndex: uint16(cqm.pattern_index),
}
qm := &QueryMatch{
ID: uint32(cqm.id),
PatternIndex: uint16(cqm.pattern_index),
}

count := int(cqm.capture_count)
slice := (*reflect.SliceHeader)((unsafe.Pointer(&cqc)))
slice.Cap = count
slice.Len = count
slice.Data = uintptr(unsafe.Pointer(cqm.captures))
for _, c := range cqc {
idx := uint32(c.index)
node := qc.t.cachedNode(c.node)
qm.Captures = append(qm.Captures, QueryCapture{idx, node})
}
count := int(cqm.capture_count)
slice := (*reflect.SliceHeader)((unsafe.Pointer(&cqc)))
slice.Cap = count
slice.Len = count
slice.Data = uintptr(unsafe.Pointer(cqm.captures))
for _, c := range cqc {
idx := uint32(c.index)
node := qc.t.cachedNode(c.node)
qm.Captures = append(qm.Captures, QueryCapture{idx, node})
}

return qm, true
if filterPredicates && !qm.satisfiesTextPredicates(qc.q) {
continue
}
return qm, true
}
}

func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) {
Expand Down Expand Up @@ -939,6 +983,78 @@ func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) {
return qm, uint32(captureIndex), true
}

func (qm *QueryMatch) satisfiesTextPredicates(q *Query) bool {
steps := q.PredicatesForPattern(uint32(qm.PatternIndex))
if len(steps) == 0 {
return true
}

operator := q.StringValueForId(steps[0].ValueId)

switch operator {
case "eq?", "not-eq?":
isPositive := operator == "eq?"

expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)

if steps[2].Type == QueryPredicateStepTypeCapture {
expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)

var nodeLeft, nodeRight *Node

for _, c := range qm.Captures {
captureName := q.CaptureNameForId(c.Index)

if captureName == expectedCaptureNameLeft {
nodeLeft = c.Node
}
if captureName == expectedCaptureNameRight {
nodeRight = c.Node
}

if nodeLeft != nil && nodeRight != nil {
matches := nodeLeft.Content() == nodeRight.Content()
if isPositive {
return matches
}
return !matches
}
}
} else {
expectedValueRight := q.StringValueForId(steps[2].ValueId)

for _, c := range qm.Captures {
captureName := q.CaptureNameForId(c.Index)
if expectedCaptureNameLeft == captureName {
matches := c.Node.Content() == expectedValueRight
if isPositive {
return matches
}
return !matches
}
}
}
case "match?", "not-match?":
isPositive := operator == "match?"

expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))

for _, c := range qm.Captures {
captureName := q.CaptureNameForId(c.Index)
if expectedCaptureName == captureName {
matches := regex.Match([]byte(c.Node.Content()))
if isPositive {
return matches
}
return !matches
}

}
}
return false
}

// keeps callbacks for parser.parse method
type readFuncsMap struct {
sync.Mutex
Expand Down