Skip to content

Commit

Permalink
ruleguard: give error message when filter uses undefined var (#282)
Browse files Browse the repository at this point in the history
Fixes #159
  • Loading branch information
quasilyte committed Oct 14, 2021
1 parent 7b21d77 commit 451d089
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 54 deletions.
19 changes: 14 additions & 5 deletions internal/gogrep/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ type compiler struct {
strict bool
fset *token.FileSet

info *PatternInfo

insideStmtList bool
}

func (c *compiler) Compile(fset *token.FileSet, root ast.Node, strict bool) (p *program, err error) {
func (c *compiler) Compile(fset *token.FileSet, root ast.Node, info *PatternInfo, strict bool) (p *program, err error) {
defer func() {
if err != nil {
return
Expand All @@ -38,6 +40,7 @@ func (c *compiler) Compile(fset *token.FileSet, root ast.Node, strict bool) (p *
panic(rv) // Not our panic
}()

c.info = info
c.fset = fset
c.strict = strict
c.prog = &program{
Expand Down Expand Up @@ -68,6 +71,12 @@ func (c *compiler) toUint8(n ast.Node, v int) uint8 {
return uint8(v)
}

func (c *compiler) internVar(n ast.Node, s string) uint8 {
c.info.Vars[s] = struct{}{}
index := c.internString(n, s)
return index
}

func (c *compiler) internString(n ast.Node, s string) uint8 {
if index, ok := c.stringIndexes[s]; ok {
return index
Expand Down Expand Up @@ -156,7 +165,7 @@ func (c *compiler) compileOptFieldList(n *ast.FieldList) {
} else {
c.emitInst(instruction{
op: opNamedFieldNode,
valueIndex: c.internString(n, info.Name),
valueIndex: c.internVar(n, info.Name),
})
}
return
Expand Down Expand Up @@ -406,10 +415,10 @@ func (c *compiler) compileWildIdent(n *ast.Ident, optional bool) {
inst.op = pickOp(optional, opOptNode, opNodeSeq)
case info.Name != "_" && !info.Seq:
inst.op = opNamedNode
inst.valueIndex = c.internString(n, info.Name)
inst.valueIndex = c.internVar(n, info.Name)
default:
inst.op = pickOp(optional, opNamedOptNode, opNamedNodeSeq)
inst.valueIndex = c.internString(n, info.Name)
inst.valueIndex = c.internVar(n, info.Name)
}
c.prog.insts = append(c.prog.insts, inst)
}
Expand Down Expand Up @@ -771,7 +780,7 @@ func (c *compiler) compileIfStmt(n *ast.IfStmt) {
if info.Seq {
c.prog.insts = append(c.prog.insts, instruction{
op: pickOp(n.Else == nil, opIfNamedOptStmt, opIfNamedOptElseStmt),
valueIndex: c.internString(ident, info.Name),
valueIndex: c.internVar(ident, info.Name),
})
c.compileStmt(n.Body)
if n.Else != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/gogrep/compile_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestCompileError(t *testing.T) {
for input, want := range tests {
fset := token.NewFileSet()
testPattern := unwrapPattern(input)
_, err := Compile(fset, testPattern, isStrict(input))
_, _, err := Compile(fset, testPattern, isStrict(input))
if err == nil {
t.Errorf("compile `%s`: expected error, got none", input)
continue
Expand Down
5 changes: 3 additions & 2 deletions internal/gogrep/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func TestCompileWildcard(t *testing.T) {
input := test.input
want := test.output
fset := token.NewFileSet()
p, err := Compile(fset, input, false)
p, _, err := Compile(fset, input, false)
if err != nil {
t.Errorf("compile `%s`: %v", input, err)
return
Expand Down Expand Up @@ -992,7 +992,8 @@ func TestCompile(t *testing.T) {
fset := token.NewFileSet()
n := testParseNode(t, fset, input)
var c compiler
p, err := c.Compile(fset, n, false)
info := newPatternInfo()
p, err := c.Compile(fset, n, &info, false)
if err != nil {
t.Errorf("compile `%s`: %v", input, err)
return
Expand Down
21 changes: 16 additions & 5 deletions internal/gogrep/gogrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type Pattern struct {
m *matcher
}

type PatternInfo struct {
Vars map[string]struct{}
}

func (p *Pattern) NodeTag() nodetag.Value {
return operationInfoTable[p.m.prog.insts[0].op].Tag
}
Expand All @@ -55,16 +59,23 @@ func (p *Pattern) Clone() *Pattern {
return &clone
}

func Compile(fset *token.FileSet, src string, strict bool) (*Pattern, error) {
func Compile(fset *token.FileSet, src string, strict bool) (*Pattern, PatternInfo, error) {
info := newPatternInfo()
n, err := parseExpr(fset, src)
if err != nil {
return nil, err
return nil, info, err
}
var c compiler
prog, err := c.Compile(fset, n, strict)
prog, err := c.Compile(fset, n, &info, strict)
if err != nil {
return nil, err
return nil, info, err
}
m := newMatcher(prog)
return &Pattern{m: m}, nil
return &Pattern{m: m}, info, nil
}

func newPatternInfo() PatternInfo {
return PatternInfo{
Vars: map[string]struct{}{},
}
}
2 changes: 1 addition & 1 deletion internal/gogrep/match_perf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func BenchmarkMatch(b *testing.B) {
test := tests[i]
b.Run(test.name, func(b *testing.B) {
fset := token.NewFileSet()
pat, err := Compile(fset, test.pat, true)
pat, _, err := Compile(fset, test.pat, true)
if err != nil {
b.Errorf("parse `%s`: %v", test.pat, err)
return
Expand Down
2 changes: 1 addition & 1 deletion internal/gogrep/match_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ func TestMatch(t *testing.T) {
t.Run(fmt.Sprintf("test%d", i), func(t *testing.T) {
fset := token.NewFileSet()
testPattern := unwrapPattern(test.pat)
pat, err := Compile(fset, testPattern, isStrict(test.pat))
pat, _, err := Compile(fset, testPattern, isStrict(test.pat))
if err != nil {
t.Errorf("compile `%s`: %v", test.pat, err)
return
Expand Down
37 changes: 27 additions & 10 deletions ruleguard/ir/filter_op.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 22 additions & 18 deletions ruleguard/ir/gen_filter_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type opInfo struct {
const (
flagIsBinaryExpr uint64 = 1 << iota
flagIsBasicLit
flagHasVar
)

func main() {
Expand All @@ -38,24 +39,24 @@ func main() {
{name: "GtEq", comment: "$Args[0] >= $Args[1]", flags: flagIsBinaryExpr},
{name: "LtEq", comment: "$Args[0] <= $Args[1]", flags: flagIsBinaryExpr},

{name: "VarAddressable", comment: "m[$Value].Addressable", valueType: "string"},
{name: "VarPure", comment: "m[$Value].Pure", valueType: "string"},
{name: "VarConst", comment: "m[$Value].Const", valueType: "string"},
{name: "VarConstSlice", comment: "m[$Value].ConstSlice", valueType: "string"},
{name: "VarText", comment: "m[$Value].Text", valueType: "string"},
{name: "VarLine", comment: "m[$Value].Line", valueType: "string"},
{name: "VarValueInt", comment: "m[$Value].Value.Int()", valueType: "string"},
{name: "VarTypeSize", comment: "m[$Value].Type.Size", valueType: "string"},

{name: "VarFilter", comment: "m[$Value].Filter($Args[0])", valueType: "string"},
{name: "VarNodeIs", comment: "m[$Value].Node.Is($Args[0])", valueType: "string"},
{name: "VarObjectIs", comment: "m[$Value].Object.Is($Args[0])", valueType: "string"},
{name: "VarTypeIs", comment: "m[$Value].Type.Is($Args[0])", valueType: "string"},
{name: "VarTypeUnderlyingIs", comment: "m[$Value].Type.Underlying().Is($Args[0])", valueType: "string"},
{name: "VarTypeConvertibleTo", comment: "m[$Value].Type.ConvertibleTo($Args[0])", valueType: "string"},
{name: "VarTypeAssignableTo", comment: "m[$Value].Type.AssignableTo($Args[0])", valueType: "string"},
{name: "VarTypeImplements", comment: "m[$Value].Type.Implements($Args[0])", valueType: "string"},
{name: "VarTextMatches", comment: "m[$Value].Text.Matches($Args[0])", valueType: "string"},
{name: "VarAddressable", comment: "m[$Value].Addressable", valueType: "string", flags: flagHasVar},
{name: "VarPure", comment: "m[$Value].Pure", valueType: "string", flags: flagHasVar},
{name: "VarConst", comment: "m[$Value].Const", valueType: "string", flags: flagHasVar},
{name: "VarConstSlice", comment: "m[$Value].ConstSlice", valueType: "string", flags: flagHasVar},
{name: "VarText", comment: "m[$Value].Text", valueType: "string", flags: flagHasVar},
{name: "VarLine", comment: "m[$Value].Line", valueType: "string", flags: flagHasVar},
{name: "VarValueInt", comment: "m[$Value].Value.Int()", valueType: "string", flags: flagHasVar},
{name: "VarTypeSize", comment: "m[$Value].Type.Size", valueType: "string", flags: flagHasVar},

{name: "VarFilter", comment: "m[$Value].Filter($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarNodeIs", comment: "m[$Value].Node.Is($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarObjectIs", comment: "m[$Value].Object.Is($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarTypeIs", comment: "m[$Value].Type.Is($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarTypeUnderlyingIs", comment: "m[$Value].Type.Underlying().Is($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarTypeConvertibleTo", comment: "m[$Value].Type.ConvertibleTo($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarTypeAssignableTo", comment: "m[$Value].Type.AssignableTo($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarTypeImplements", comment: "m[$Value].Type.Implements($Args[0])", valueType: "string", flags: flagHasVar},
{name: "VarTextMatches", comment: "m[$Value].Text.Matches($Args[0])", valueType: "string", flags: flagHasVar},

{name: "Deadcode", comment: "m.Deadcode()"},

Expand Down Expand Up @@ -117,6 +118,9 @@ func main() {
if op.flags&flagIsBasicLit != 0 {
parts = append(parts, "flagIsBasicLit")
}
if op.flags&flagHasVar != 0 {
parts = append(parts, "flagHasVar")
}
fmt.Fprintf(&buf, "Filter%sOp: %s,\n", op.name, strings.Join(parts, " | "))
}
buf.WriteString("}\n")
Expand Down
2 changes: 2 additions & 0 deletions ruleguard/ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (e FilterExpr) IsValid() bool { return e.Op != FilterInvalidOp }

func (e FilterExpr) IsBinaryExpr() bool { return filterOpFlags[e.Op]&flagIsBinaryExpr != 0 }
func (e FilterExpr) IsBasicLit() bool { return filterOpFlags[e.Op]&flagIsBasicLit != 0 }
func (e FilterExpr) HasVar() bool { return filterOpFlags[e.Op]&flagHasVar != 0 }

func (e FilterExpr) String() string {
switch e.Op {
Expand All @@ -107,4 +108,5 @@ func (e FilterExpr) String() string {
const (
flagIsBinaryExpr uint64 = 1 << iota
flagIsBasicLit
flagHasVar
)
43 changes: 32 additions & 11 deletions ruleguard/ir_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,16 +270,19 @@ func (l *irLoader) loadRule(rule ir.Rule) error {
location: rule.LocationVar,
}

info := filterInfo{
Vars: make(map[string]struct{}),
}
if rule.WhereExpr.IsValid() {
filter, err := l.newFilter(rule.WhereExpr)
filter, err := l.newFilter(rule.WhereExpr, &info)
if err != nil {
return err
}
proto.filter = filter
}

for _, pat := range rule.SyntaxPatterns {
if err := l.loadSyntaxRule(proto, rule, pat.Value, pat.Line); err != nil {
if err := l.loadSyntaxRule(proto, info, rule, pat.Value, pat.Line); err != nil {
return err
}
}
Expand Down Expand Up @@ -309,16 +312,26 @@ func (l *irLoader) loadCommentRule(resultProto goRule, rule ir.Rule, src string,
return nil
}

func (l *irLoader) loadSyntaxRule(resultProto goRule, rule ir.Rule, src string, line int) error {
func (l *irLoader) loadSyntaxRule(resultProto goRule, filterInfo filterInfo, rule ir.Rule, src string, line int) error {
result := resultProto
result.line = line

pat, err := gogrep.Compile(l.gogrepFset, src, false)
pat, info, err := gogrep.Compile(l.gogrepFset, src, false)
if err != nil {
return l.errorf(rule.Line, err, "parse match pattern")
}
result.pat = pat

for filterVar := range filterInfo.Vars {
if filterVar == "$$" {
continue // OK: a predefined var for the "entire match"
}
_, ok := info.Vars[filterVar]
if !ok {
return l.errorf(rule.Line, nil, "filter refers to a non-existing var %s", filterVar)
}
}

dst := l.res.universal
var dstTags []nodetag.Value
switch tag := pat.NodeTag(); tag {
Expand Down Expand Up @@ -441,16 +454,20 @@ func (l *irLoader) unwrapStringExpr(filter ir.FilterExpr) string {
return ""
}

func (l *irLoader) newFilter(filter ir.FilterExpr) (matchFilter, error) {
func (l *irLoader) newFilter(filter ir.FilterExpr, info *filterInfo) (matchFilter, error) {
if filter.HasVar() {
info.Vars[filter.Value.(string)] = struct{}{}
}

if filter.IsBinaryExpr() {
return l.newBinaryExprFilter(filter)
return l.newBinaryExprFilter(filter, info)
}

result := matchFilter{src: filter.Src}

switch filter.Op {
case ir.FilterNotOp:
x, err := l.newFilter(filter.Args[0])
x, err := l.newFilter(filter.Args[0], info)
if err != nil {
return result, err
}
Expand Down Expand Up @@ -600,14 +617,14 @@ func (l *irLoader) newFilter(filter ir.FilterExpr) (matchFilter, error) {
return result, nil
}

func (l *irLoader) newBinaryExprFilter(filter ir.FilterExpr) (matchFilter, error) {
func (l *irLoader) newBinaryExprFilter(filter ir.FilterExpr, info *filterInfo) (matchFilter, error) {
if filter.Op == ir.FilterAndOp || filter.Op == ir.FilterOrOp {
result := matchFilter{src: filter.Src}
lhs, err := l.newFilter(filter.Args[0])
lhs, err := l.newFilter(filter.Args[0], info)
if err != nil {
return result, err
}
rhs, err := l.newFilter(filter.Args[1])
rhs, err := l.newFilter(filter.Args[1], info)
if err != nil {
return result, err
}
Expand All @@ -631,7 +648,7 @@ func (l *irLoader) newBinaryExprFilter(filter ir.FilterExpr) (matchFilter, error
// Simple commutative ops. Just swap the args.
newFilter := filter
newFilter.Args = []ir.FilterExpr{filter.Args[1], filter.Args[0]}
return l.newBinaryExprFilter(newFilter)
return l.newBinaryExprFilter(newFilter, info)
}
}
}
Expand Down Expand Up @@ -697,3 +714,7 @@ func (l *irLoader) newBinaryExprFilter(filter ir.FilterExpr) (matchFilter, error

return result, nil
}

type filterInfo struct {
Vars map[string]struct{}
}
Loading

0 comments on commit 451d089

Please sign in to comment.