Skip to content

Commit

Permalink
fix: resolve issues with mongo client in context during evaluation in…
Browse files Browse the repository at this point in the history
… sdk
  • Loading branch information
davidebianchi committed Jul 12, 2023
1 parent 4cc2222 commit 3af9272
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 101 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
push:
jobs:
tests:
name: Test with go version ${{ matrix.go_version }} on OS ${{matrix.os}}
name: Test with go version ${{ matrix.go_version }} on OS ${{matrix.os}}
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand All @@ -22,14 +22,14 @@ jobs:
- name: Go get dependencies
run: go get -v -t -d ./...
- name: Run tests
run: make test
run: make coverage
- name: Send the coverage output
uses: shogo82148/actions-goveralls@v1
with:
path-to-profile: coverage.out

bench:
name: Bench with go version ${{ matrix.go_version }} on OS ${{matrix.os}}
name: Bench with go version ${{ matrix.go_version }} on OS ${{matrix.os}}
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand Down Expand Up @@ -61,11 +61,11 @@ jobs:
needs: tests
runs-on: ubuntu-latest
if: ${{ startsWith(github.ref, 'refs/tags/') || github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Configure docker metadata
id: meta
uses: docker/metadata-action@v4
Expand All @@ -81,7 +81,7 @@ jobs:
labels: |
org.opencontainers.image.documentation=https://rond-authz.io
org.opencontainers.image.vendor=rond authz
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2

Expand All @@ -97,15 +97,15 @@ jobs:
with:
username: ${{ secrets.BOT_DOCKER_USERNAME }}
password: ${{ secrets.BOT_DOCKER_TOKEN }}

- name: Prepare build cache
uses: actions/cache@v3
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-
- name: Build and push
uses: docker/build-push-action@v4
with:
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ mongo-start:

.PHONY: test
test: clean mongo-start
go test ./... -cover
$(MAKE) clean

.PHONY: coverage
coverage: clean mongo-start
go test ./... -coverprofile coverage.out
$(MAKE) clean

Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ make test

Please note that in order to run tests you need Docker to be installed; tests need a local instance of MongoDB to be up and running, the `make test` command will take care of it by creating a new `mongodb` container. The container is auomatically removed at the end of tests; if it remains leaked simply run `make clean`.

#### With coverage

To run test with coverage file in output, run

```sh
make coverage
```

### Contributing

Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for further details about the process for submitting pull requests.
Expand Down
112 changes: 55 additions & 57 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ var Unknowns = []string{"data.resources"}
type OPAEvaluator struct {
PolicyEvaluator Evaluator
PolicyName string
Context context.Context

m *metrics.Metrics
routerInfo openapi.RouterInfo
context context.Context
mongoClient types.IMongoClient
}
type PartialResultsEvaluatorConfigKey struct{}

Expand All @@ -63,11 +62,11 @@ type PartialEvaluator struct {
PartialEvaluator *rego.PartialResult
}

func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (*PartialEvaluator, error) {
func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *OPAEvaluatorOptions) (*PartialEvaluator, error) {
logger.WithField("policyName", policy).Info("precomputing rego policy")

policyEvaluatorTime := time.Now()
partialResultEvaluator, err := NewPartialResultEvaluator(ctx, policy, opaModuleConfig, options)
partialResultEvaluator, err := newPartialResultEvaluator(ctx, policy, opaModuleConfig, options)
if err != nil {
return nil, err
}
Expand All @@ -82,7 +81,7 @@ func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy st
return &PartialEvaluator{PartialEvaluator: partialResultEvaluator}, nil
}

func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (PartialResultsEvaluators, error) {
func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *OPAEvaluatorOptions) (PartialResultsEvaluators, error) {
if oas == nil {
return nil, fmt.Errorf("oas must not be nil")
}
Expand Down Expand Up @@ -169,27 +168,14 @@ func (h printHook) Print(_ print.Context, message string) error {
return err
}

type EvaluatorOptions struct {
type OPAEvaluatorOptions struct {
EnablePrintStatements bool
MongoClient types.IMongoClient

Metrics *metrics.Metrics
RouterInfo openapi.RouterInfo
}

func (e *EvaluatorOptions) WithMetrics(metrics metrics.Metrics) *EvaluatorOptions {
e.Metrics = &metrics
return e
}

func (e *EvaluatorOptions) WithRouterInfo(routerInfo openapi.RouterInfo) *EvaluatorOptions {
e.RouterInfo = routerInfo
return e
}

func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *EvaluatorOptions) (*OPAEvaluator, error) {
func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
if options == nil {
options = &EvaluatorOptions{}
options = &OPAEvaluatorOptions{}
}
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
Expand All @@ -211,19 +197,16 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod
custom_builtins.MongoFindMany,
)

ctx = mongoclient.WithMongoClient(ctx, options.MongoClient)

return &OPAEvaluator{
PolicyEvaluator: query,
PolicyName: policy,
Context: ctx,

m: options.Metrics,
routerInfo: options.RouterInfo,
context: ctx,
mongoClient: options.MongoClient,
}, nil
}

func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger *logrus.Entry, policy string, input []byte, options *EvaluatorOptions) (*OPAEvaluator, error) {
func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger *logrus.Entry, policy string, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
// TODO: remove logger and set in sdk
logger.WithFields(logrus.Fields{
"policyName": policy,
Expand All @@ -241,9 +224,9 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger
return evaluator, nil
}

func NewPartialResultEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, evaluatorOptions *EvaluatorOptions) (*rego.PartialResult, error) {
func newPartialResultEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, evaluatorOptions *OPAEvaluatorOptions) (*rego.PartialResult, error) {
if evaluatorOptions == nil {
evaluatorOptions = &EvaluatorOptions{}
evaluatorOptions = &OPAEvaluatorOptions{}
}
if opaModuleConfig == nil {
return nil, fmt.Errorf("OPAModuleConfig must not be nil")
Expand Down Expand Up @@ -271,9 +254,9 @@ func NewPartialResultEvaluator(ctx context.Context, policy string, opaModuleConf
return &results, err
}

func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx context.Context, policy string, input []byte, options *EvaluatorOptions) (*OPAEvaluator, error) {
func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx context.Context, policy string, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
if options == nil {
options = &EvaluatorOptions{}
options = &OPAEvaluatorOptions{}
}

if eval, ok := partialEvaluators[policy]; ok {
Expand All @@ -291,32 +274,24 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
return &OPAEvaluator{
PolicyName: policy,
PolicyEvaluator: evaluator,
Context: ctx,

m: options.Metrics,
routerInfo: options.RouterInfo,
context: ctx,
mongoClient: options.MongoClient,
}, nil
}
return nil, fmt.Errorf("%w: %s", ErrEvaluatorNotFound, policy)
}

func (evaluator *OPAEvaluator) metrics() metrics.Metrics {
if evaluator.m != nil {
return *evaluator.m
}
return metrics.SetupMetrics("rond")
}

func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitive.M, error) {
func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry, options *PolicyEvaluationOptions) (primitive.M, error) {
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.Context)
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.getContext())
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)

evaluator.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{
options.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{
"policy_name": evaluator.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

Expand All @@ -325,9 +300,9 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv
"policyName": evaluator.PolicyName,
"partialEval": true,
"allowed": true,
"matchedPath": evaluator.routerInfo.MatchedPath,
"requestedPath": evaluator.routerInfo.RequestedPath,
"method": evaluator.routerInfo.Method,
"matchedPath": options.RouterInfo.MatchedPath,
"requestedPath": options.RouterInfo.RequestedPath,
"method": options.RouterInfo.Method,
}).Debug("policy evaluation completed")

client := opatranslator.OPAClient{}
Expand All @@ -344,16 +319,16 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv
return q, nil
}

func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, error) {
func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry, options *PolicyEvaluationOptions) (interface{}, error) {
opaEvaluationTimeStart := time.Now()

results, err := evaluator.PolicyEvaluator.Eval(evaluator.Context)
results, err := evaluator.PolicyEvaluator.Eval(evaluator.getContext())
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
evaluator.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{
options.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{
"policy_name": evaluator.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

Expand All @@ -364,9 +339,9 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro
"partialEval": false,
"allowed": allowed,
"resultsLength": len(results),
"matchedPath": evaluator.routerInfo.MatchedPath,
"requestedPath": evaluator.routerInfo.RequestedPath,
"method": evaluator.routerInfo.Method,
"matchedPath": options.RouterInfo.MatchedPath,
"requestedPath": options.RouterInfo.RequestedPath,
"method": options.RouterInfo.Method,
}).Debug("policy evaluation completed")

logger.WithFields(logrus.Fields{
Expand All @@ -380,12 +355,35 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro
return nil, ErrPolicyEvalFailed
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig) (interface{}, primitive.M, error) {
func (evaluator *OPAEvaluator) getContext() context.Context {
ctx := evaluator.context
if ctx == nil {
ctx = context.Background()
}
if evaluator.mongoClient != nil {
return mongoclient.WithMongoClient(ctx, evaluator.mongoClient)
}
return ctx
}

type PolicyEvaluationOptions struct {
Metrics *metrics.Metrics
RouterInfo openapi.RouterInfo
}

func (evaluator *PolicyEvaluationOptions) metrics() metrics.Metrics {
if evaluator.Metrics != nil {
return *evaluator.Metrics
}
return metrics.SetupMetrics("rond")
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
if permission.RequestFlow.GenerateQuery {
query, err := evaluator.partiallyEvaluate(logger)
query, err := evaluator.partiallyEvaluate(logger, options)
return nil, query, err
}
dataFromEvaluation, err := evaluator.Evaluate(logger)
dataFromEvaluation, err := evaluator.Evaluate(logger, options)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func entrypoint(shutdown chan os.Signal) {
registry := prometheus.NewRegistry()
sdk, err := sdk.NewFromOAS(ctx, opaModuleConfig, oas, &sdk.FromOASOptions{
Registry: registry,
EvaluatorOptions: &core.EvaluatorOptions{
EvaluatorOptions: &core.OPAEvaluatorOptions{
EnablePrintStatements: env.IsTraceLogLevel(),
MongoClient: mongoClient,
},
Expand Down
4 changes: 2 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ filter_policy {
registry := prometheus.NewRegistry()
logger, _ := test.NewNullLogger()
sdk, err := sdk.NewFromOAS(context.Background(), opa, oas, &sdk.FromOASOptions{
EvaluatorOptions: &core.EvaluatorOptions{
EvaluatorOptions: &core.OPAEvaluatorOptions{
MongoClient: mongoClient,
},
Registry: registry,
Expand Down Expand Up @@ -1957,7 +1957,7 @@ filter_policy {
registry := prometheus.NewRegistry()
logger, _ := test.NewNullLogger()
sdk, err := sdk.NewFromOAS(context.Background(), opa, oas, &sdk.FromOASOptions{
EvaluatorOptions: &core.EvaluatorOptions{
EvaluatorOptions: &core.OPAEvaluatorOptions{
MongoClient: mongoClient,
},
Registry: registry,
Expand Down
Loading

0 comments on commit 3af9272

Please sign in to comment.