diff --git a/compile/compile.go b/compile/compile.go index 5b08504ac5..ddcef95bb5 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -326,6 +326,12 @@ func (c *Compiler) Build(ctx context.Context) error { return err } + // Dedup entrypoint refs, if both CLI and entrypoint metadata annotations + // were used. + if err := c.dedupEntrypointRefs(); err != nil { + return err + } + if err := c.optimize(ctx); err != nil { return err } @@ -429,6 +435,30 @@ func (c *Compiler) checkNumEntrypoints() error { return nil } +// Note(philipc): When an entrypoint is provided on the CLI and from an +// entrypoint annotation, it can lead to duplicates in the slice of +// entrypoint refs. This can cause panics down the line due to c.entrypoints +// being a different length than c.entrypointrefs. As a result, we have to +// trim out the duplicates. +func (c *Compiler) dedupEntrypointRefs() error { + // Discover the distinct entrypoint refs. + entrypointRefSet := make(map[string]int, len(c.entrypointrefs)) + for i, r := range c.entrypointrefs { + refString := r.String() + // Store only the first index in the list that matches. + if _, ok := entrypointRefSet[refString]; !ok { + entrypointRefSet[refString] = i + } + } + // Build list of entrypoint refs, without duplicates. + newEntrypointRefs := make([]*ast.Term, 0, len(entrypointRefSet)) + for _, idx := range entrypointRefSet { + newEntrypointRefs = append(newEntrypointRefs, c.entrypointrefs[idx]) + } + c.entrypointrefs = newEntrypointRefs + return nil +} + // Bundle returns the compiled bundle. This function can be called to retrieve the // output of the compiler (as an alternative to having the bundle written to a stream.) func (c *Compiler) Bundle() *bundle.Bundle { diff --git a/compile/compile_test.go b/compile/compile_test.go index 93ba2b4f8d..787977bd69 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -978,9 +978,7 @@ update { func modulesToString(modules []bundle.ModuleFile) string { var buf bytes.Buffer - //result := make([]string, len(modules)) for i, m := range modules { - //result[i] = m.Parsed.String() buf.WriteString(strconv.Itoa(i)) buf.WriteString(":\n") buf.WriteString(string(m.Raw)) @@ -1623,6 +1621,28 @@ q[3] "test/p": {}, }, }, + { + note: "overlapping manual entrypoints + annotation entrypoints", + entrypoints: []string{"test/p"}, + modules: map[string]string{ + "test.rego": ` +package test + +# METADATA +# entrypoint: true +p { + q[input.x] +} + +q[1] +q[2] +q[3] + `, + }, + wantEntrypoints: map[string]struct{}{ + "test/p": {}, + }, + }, { note: "ref head rule annotation", entrypoints: []string{},