Skip to content

Commit

Permalink
topdown: memoize partial set/object results (open-policy-agent#3492)
Browse files Browse the repository at this point in the history
We can either cache individual elements (`data.foo.p["bar"]`), or the
full extent of a partial set/object. A cached full extent of the partial
would be used when evaluating individual elements of the partial.

If the first encounter with a partial set/object has to materialize the
full extent with a variable key, like `data.foo.p[x]`, then we cache the
fully-evaluated result for `data.foo.p`.

Fixes open-policy-agent#822.

Co-authored-by: Torin Sandall <torinsandall@gmail.com>
Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus and tsandall committed Jul 14, 2021
1 parent 44436dc commit f3284cf
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 148 deletions.
2 changes: 1 addition & 1 deletion ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ func (term *Term) Hash() int {
return term.Value.Hash()
}

// IsGround returns true if this terms' Value is ground.
// IsGround returns true if this term's Value is ground.
func (term *Term) IsGround() bool {
return term.Value.IsGround()
}
Expand Down
103 changes: 71 additions & 32 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,12 @@ type evalVirtualPartial struct {
empty *ast.Term
}

type evalVirtualPartialCacheHint struct {
key ast.Ref
hit bool
full bool
}

func (e evalVirtualPartial) eval(iter unifyIterator) error {

unknown := e.e.unknown(e.ref[:e.pos+1], e.bindings)
Expand Down Expand Up @@ -2095,17 +2101,25 @@ func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error
return nil
}

key, hit, err := e.evalCache(iter)
hint, err := e.evalCache(iter)
if err != nil {
return err
} else if hit {
} else if hint.hit {
return nil
}

result := e.empty
if hint.full {
result, err := e.evalAllRulesNoCache(e.ir.Rules)
if err != nil {
return err
}
e.e.virtualCache.Put(hint.key, result)
return e.evalTerm(iter, e.pos+1, result, e.bindings)
}

result := e.empty
for _, rule := range e.ir.Rules {
if err := e.evalOneRulePreUnify(iter, rule, key, result, unknown); err != nil {
if err := e.evalOneRulePreUnify(iter, rule, hint, result, unknown); err != nil {
return err
}
}
Expand All @@ -2115,12 +2129,33 @@ func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error

func (e evalVirtualPartial) evalAllRules(iter unifyIterator, rules []*ast.Rule) error {

cacheKey := e.plugged[:e.pos+1]
result := e.e.virtualCache.Get(cacheKey)
if result != nil {
e.e.instr.counterIncr(evalOpVirtualCacheHit)
return e.e.biunify(result, e.rterm, e.bindings, e.rbindings, iter)
}

e.e.instr.counterIncr(evalOpVirtualCacheMiss)

result, err := e.evalAllRulesNoCache(rules)
if err != nil {
return err
}

if cacheKey != nil {
e.e.virtualCache.Put(cacheKey, result)
}

return e.e.biunify(result, e.rterm, e.bindings, e.rbindings, iter)
}

func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, error) {
result := e.empty

for _, rule := range rules {
child := e.e.child(rule.Body)
child.traceEnter(rule)

err := child.eval(func(*eval) error {
child.traceExit(rule)
var err error
Expand All @@ -2134,14 +2169,14 @@ func (e evalVirtualPartial) evalAllRules(iter unifyIterator, rules []*ast.Rule)
})

if err != nil {
return err
return nil, err
}
}

return e.e.biunify(result, e.rterm, e.bindings, e.rbindings, iter)
return result, nil
}

func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, cacheKey ast.Ref, result *ast.Term, unknown bool) error {
func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, hint evalVirtualPartialCacheHint, result *ast.Term, unknown bool) error {

key := e.ref[e.pos+1]
child := e.e.child(rule.Body)
Expand All @@ -2158,9 +2193,9 @@ func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Ru
term = rule.Head.Key
}

if cacheKey != nil {
if hint.key != nil {
result := child.bindings.Plug(term)
e.e.virtualCache.Put(cacheKey, result)
e.e.virtualCache.Put(hint.key, result)
}

// NOTE(tsandall): if the rule set depends on any unknowns then do
Expand All @@ -2181,7 +2216,7 @@ func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Ru

child.traceExit(rule)
term, termbindings := child.bindings.apply(term)
err := e.evalTerm(iter, term, termbindings)
err := e.evalTerm(iter, e.pos+2, term, termbindings)
if err != nil {
return err
}
Expand Down Expand Up @@ -2239,7 +2274,7 @@ func (e evalVirtualPartial) evalOneRuleContinue(iter unifyIterator, rule *ast.Ru
}

term, termbindings := child.bindings.apply(term)
err := e.evalTerm(iter, term, termbindings)
err := e.evalTerm(iter, e.pos+2, term, termbindings)
if err != nil {
return err
}
Expand Down Expand Up @@ -2330,11 +2365,11 @@ func (e evalVirtualPartial) partialEvalSupportRule(rule *ast.Rule, path ast.Ref)
return defined, err
}

func (e evalVirtualPartial) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error {
func (e evalVirtualPartial) evalTerm(iter unifyIterator, pos int, term *ast.Term, termbindings *bindings) error {
eval := evalTerm{
e: e.e,
ref: e.ref,
pos: e.pos + 2,
pos: pos,
bindings: e.bindings,
term: term,
termbindings: termbindings,
Expand All @@ -2344,34 +2379,38 @@ func (e evalVirtualPartial) evalTerm(iter unifyIterator, term *ast.Term, termbin
return eval.eval(iter)
}

func (e evalVirtualPartial) evalCache(iter unifyIterator) (ast.Ref, bool, error) {
func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCacheHint, error) {

var hint evalVirtualPartialCacheHint

if e.e.unknown(e.ref[:e.pos+1], e.bindings) {
return nil, false, nil
return hint, nil
}

var cacheKey ast.Ref

if e.ir.Kind == ast.PartialObjectDoc {

plugged := e.bindings.Plug(e.ref[e.pos+1])
if cached := e.e.virtualCache.Get(e.plugged[:e.pos+1]); cached != nil { // have full extent cached
e.e.instr.counterIncr(evalOpVirtualCacheHit)
hint.hit = true
return hint, e.evalTerm(iter, e.pos+1, cached, e.bindings)
}

if plugged.IsGround() {
path := e.plugged[:e.pos+2]
path[len(path)-1] = plugged
cached := e.e.virtualCache.Get(path)
plugged := e.bindings.Plug(e.ref[e.pos+1])

if cached != nil {
e.e.instr.counterIncr(evalOpVirtualCacheHit)
return nil, true, e.evalTerm(iter, cached, e.bindings)
}
if plugged.IsGround() {
hint.key = append(e.plugged[:e.pos+1], plugged)

e.e.instr.counterIncr(evalOpVirtualCacheMiss)
cacheKey = path
if cached := e.e.virtualCache.Get(hint.key); cached != nil {
e.e.instr.counterIncr(evalOpVirtualCacheHit)
hint.hit = true
return hint, e.evalTerm(iter, e.pos+2, cached, e.bindings)
}
} else if _, ok := plugged.Value.(ast.Var); ok {
hint.full = true
hint.key = e.plugged[:e.pos+1]
}

return cacheKey, false, nil
e.e.instr.counterIncr(evalOpVirtualCacheMiss)

return hint, nil
}

func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, bool, error) {
Expand Down
91 changes: 90 additions & 1 deletion topdown/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func TestContainsNestedRefOrCall(t *testing.T) {
}
}

func TestTopdownVirtualCacheFunctions(t *testing.T) {
func TestTopdownVirtualCache(t *testing.T) {
ctx := context.Background()
store := inmem.New()

Expand Down Expand Up @@ -316,6 +316,95 @@ func TestTopdownVirtualCacheFunctions(t *testing.T) {
hit: 0,
miss: 2,
},
{
note: "partial object: simple",
module: `package p
s["foo"] = true { true }
s["bar"] = true { true }`,
query: `data.p.s["foo"]; data.p.s["foo"]`,
hit: 1,
miss: 1,
},
{
note: "partial set: simple",
module: `package p
s["foo"] { true }
s["bar"] { true }`,
query: `data.p.s["foo"]; data.p.s["foo"]`,
hit: 1,
miss: 1,
},
{
note: "partial set: object",
module: `package p
s[z] { z := {"foo": "bar"} }`,
query: `x = {"foo": "bar"}; data.p.s[x]; data.p.s[x]`,
hit: 1,
miss: 1,
},
{
note: "partial set: miss",
module: `package p
s[z] { z = true }`,
query: `data.p.s[true]; not data.p.s[false]`,
hit: 0,
miss: 2,
},
{
note: "partial set: full extent cached",
module: `package test
p[x] { x = 1 }
p[x] { x = 2 }
`,
query: "data.test.p = x; data.test.p = y",
hit: 1,
miss: 1,
},
{
note: "partial set: all rules + each rule (non-ground var) cached",
module: `package test
p { data.test.q = x; data.test.q[y] = z; data.test.q[a] = b }
q[x] { x = 1 }
q[x] { x = 2 }
`,
query: "data.test.p = true",
hit: 3, // 'data.test.q[y] = z' + 2x 'data.test.q[a] = b'
miss: 2, // 'data.test.p = true' + 'data.test.q = x'
},
{
note: "partial set: all rules + each rule (non-ground composite) cached",
module: `package test
p { data.test.q = x; data.test.q[[y, 1]] = z; data.test.q[[a, 2]] = b }
q[[x, x]] { x = 1 }
q[[x, x]] { x = 2 }
`,
query: "data.test.p = true",
hit: 2, // 'data.test.q[[y,1]] = z' + 'data.test.q[[a, 2]] = b'
miss: 2, // 'data.test.p = true' + 'data.test.q = x'
},
{
note: "partial set: each rule (non-ground var), full extent cached",
module: `package test
p { data.test.q[y] = z; data.test.q = x }
q[x] { x = 1 }
q[x] { x = 2 }
`,
query: "data.test.p = x",
hit: 2, // 2x 'data.test.q = x'
miss: 2, // 'data.test.p = true' + 'data.test.q[y] = z'
},
{
note: "partial set: each rule (non-ground composite), full extent cached",
module: `package test
p = y { data.test.q[[y, 1]] = z; data.test.q = x }
q[[x, x]] { x = 1 }
q[[x, x]] { x = 2 }
`,
query: "data.test.p = x",
hit: 0,
miss: 3, // 'data.test.p = true' + 'data.test.q[[y, 1]] = z' + 'data.test.q = x'
exp: 1,
},
}

for _, tc := range tests {
Expand Down
4 changes: 2 additions & 2 deletions topdown/topdown_partial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2246,13 +2246,13 @@ func TestTopDownPartialEval(t *testing.T) {
`
package partial.test
p { not data.partial.__not1_1_3__ }
p { not data.partial.__not1_1_4__ }
p { not data.partial.__not1_1_5__ }
`,
`
package partial
__not1_1_3__ { input[1] = x_term_3_01; x_term_3_01 }
__not1_1_4__ { input[1] = x_term_4_01; x_term_4_01 }
__not1_1_5__ { input[2] = x_term_5_01; x_term_5_01 }
`,
},
Expand Down
Loading

0 comments on commit f3284cf

Please sign in to comment.