Skip to content

Commit

Permalink
feat: new FromConfig SDK (#219)
Browse files Browse the repository at this point in the history
* feat: remove openapi package from core

* feat: remove rondConfig params from evaluation policy function

* test: add more openapi tests

* feat: add NewWithConfig function in sdk

* Update core/errors.go

* Update openapi/openapi_utils.go

---------

Co-authored-by: Federico Maggi <7142570+fredmaggiowski@users.noreply.github.com>
  • Loading branch information
davidebianchi and fredmaggiowski committed Jul 14, 2023
1 parent 53900d7 commit 8ce2333
Show file tree
Hide file tree
Showing 31 changed files with 1,236 additions and 508 deletions.
2 changes: 2 additions & 0 deletions core/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import "fmt"
var (
ErrMissingRegoModules = fmt.Errorf("no rego module found in directory")
ErrRegoModuleReadFailed = fmt.Errorf("failed rego file read")
ErrInvalidConfig = fmt.Errorf("invalid rond configuration")

ErrEvaluatorCreationFailed = fmt.Errorf("error during evaluator creation")
ErrEvaluatorNotFound = fmt.Errorf("evaluator not found")
Expand All @@ -31,4 +32,5 @@ var (
ErrFailedInputEncode = fmt.Errorf("failed input encode")
ErrFailedInputRequestParse = fmt.Errorf("failed request body parse")
ErrFailedInputRequestDeserialization = fmt.Errorf("failed request body deserialization")
ErrRondConfigNotExists = fmt.Errorf("rond config does not exist")
)
241 changes: 50 additions & 191 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ package core

import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
Expand All @@ -29,151 +27,62 @@ import (
"github.com/rond-authz/rond/internal/mongoclient"
"github.com/rond-authz/rond/internal/opatranslator"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson/primitive"
)

type Evaluator interface {
Eval(ctx context.Context) (rego.ResultSet, error)
Partial(ctx context.Context) (*rego.PartialQueries, error)
type RondConfig struct {
RequestFlow RequestFlow `json:"requestFlow"`
ResponseFlow ResponseFlow `json:"responseFlow"`
Options PermissionOptions `json:"options"`
}

var Unknowns = []string{"data.resources"}

type OPAEvaluator struct {
PolicyEvaluator Evaluator
PolicyName string

context context.Context
mongoClient types.IMongoClient
type QueryOptions struct {
HeaderName string `json:"headerName"`
}
type PartialResultsEvaluatorConfigKey struct{}

type PartialResultsEvaluators map[string]PartialEvaluator

type PartialEvaluator struct {
PartialEvaluator *rego.PartialResult
type RequestFlow struct {
PolicyName string `json:"policyName"`
GenerateQuery bool `json:"generateQuery"`
QueryOptions QueryOptions `json:"queryOptions"`
}

func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *OPAEvaluatorOptions) (*PartialEvaluator, error) {
logger.WithField("policyName", policy).Info("precomputing rego policy")

policyEvaluatorTime := time.Now()
partialResultEvaluator, err := newPartialResultEvaluator(ctx, policy, opaModuleConfig, options)
if err != nil {
return nil, err
}

logger.
WithFields(logrus.Fields{
"policyName": policy,
"computationTimeMicroserconds": time.Since(policyEvaluatorTime).Microseconds,
}).
Info("precomputation time")

return &PartialEvaluator{PartialEvaluator: partialResultEvaluator}, nil
type ResponseFlow struct {
PolicyName string `json:"policyName"`
}

func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *OPAEvaluatorOptions) (PartialResultsEvaluators, error) {
if oas == nil {
return nil, fmt.Errorf("oas must not be nil")
}

policyEvaluators := PartialResultsEvaluators{}
for path, OASContent := range oas.Paths {
for verb, verbConfig := range OASContent {
if verbConfig.PermissionV2 == nil {
continue
}

allowPolicy := verbConfig.PermissionV2.RequestFlow.PolicyName
responsePolicy := verbConfig.PermissionV2.ResponseFlow.PolicyName

logger.
WithFields(logrus.Fields{
"verb": verb,
"policyName": allowPolicy,
"path": path,
"responsePolicyName": responsePolicy,
}).
Info("precomputing rego queries for API")

if allowPolicy == "" {
// allow policy is required, if missing assume the API has no valid x-rond configuration.
continue
}

if _, ok := policyEvaluators[allowPolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, allowPolicy, oas, opaModuleConfig, options)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[allowPolicy] = *evaluator
}

if responsePolicy != "" {
if _, ok := policyEvaluators[responsePolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, responsePolicy, oas, opaModuleConfig, options)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[responsePolicy] = *evaluator
}
}
}
}
return policyEvaluators, nil
type PermissionOptions struct {
EnableResourcePermissionsMapOptimization bool `json:"enableResourcePermissionsMapOptimization"`
IgnoreTrailingSlash bool `json:"ignoreTrailingSlash,omitempty"`
}

func NewPrintHook(w io.Writer, policy string) print.Hook {
return printHook{
w: w,
policyName: policy,
}
type Evaluator interface {
Eval(ctx context.Context) (rego.ResultSet, error)
Partial(ctx context.Context) (*rego.PartialQueries, error)
}

type printHook struct {
w io.Writer
policyName string
}
var Unknowns = []string{"data.resources"}

type LogPrinter struct {
Level int `json:"level"`
Message string `json:"msg"`
Time int64 `json:"time"`
PolicyName string `json:"policyName"`
}
type OPAEvaluator struct {
PolicyEvaluator Evaluator
PolicyName string

func (h printHook) Print(_ print.Context, message string) error {
structMessage := LogPrinter{
Level: 10,
Message: message,
Time: time.Now().UnixNano() / 1000,
PolicyName: h.policyName,
}
msg, err := json.Marshal(structMessage)
if err != nil {
return err
}
_, err = fmt.Fprintln(h.w, string(msg))
return err
context context.Context
mongoClient types.IMongoClient
generateQuery bool
}

type OPAEvaluatorOptions struct {
EnablePrintStatements bool
MongoClient types.IMongoClient
}

func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
func newQueryOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
if options == nil {
options = &OPAEvaluatorOptions{}
}
Expand Down Expand Up @@ -201,8 +110,9 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod
PolicyEvaluator: query,
PolicyName: policy,

context: ctx,
mongoClient: options.MongoClient,
context: ctx,
mongoClient: options.MongoClient,
generateQuery: true,
}, nil
}

Expand All @@ -213,7 +123,7 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger
}).Info("Policy to be evaluated")

opaEvaluatorInstanceTime := time.Now()
evaluator, err := NewOPAEvaluator(ctx, policy, config, input, options)
evaluator, err := newQueryOPAEvaluator(ctx, policy, config, input, options)
if err != nil {
logger.WithError(err).Error(ErrEvaluatorCreationFailed)
return nil, err
Expand All @@ -224,65 +134,10 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger
return evaluator, nil
}

func newPartialResultEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, evaluatorOptions *OPAEvaluatorOptions) (*rego.PartialResult, error) {
if evaluatorOptions == nil {
evaluatorOptions = &OPAEvaluatorOptions{}
}
if opaModuleConfig == nil {
return nil, fmt.Errorf("OPAModuleConfig must not be nil")
}

sanitizedPolicy := strings.Replace(policy, ".", "_", -1)
queryString := fmt.Sprintf("data.policies.%s", sanitizedPolicy)

options := []func(*rego.Rego){
rego.Query(queryString),
rego.Module(opaModuleConfig.Name, opaModuleConfig.Content),
rego.Unknowns(Unknowns),
rego.EnablePrintStatements(evaluatorOptions.EnablePrintStatements),
rego.PrintHook(NewPrintHook(os.Stdout, policy)),
rego.Capabilities(ast.CapabilitiesForThisVersion()),
custom_builtins.GetHeaderFunction,
}
if evaluatorOptions.MongoClient != nil {
ctx = mongoclient.WithMongoClient(ctx, evaluatorOptions.MongoClient)
options = append(options, custom_builtins.MongoFindOne, custom_builtins.MongoFindMany)
}
regoInstance := rego.New(options...)

results, err := regoInstance.PartialResult(ctx)
return &results, err
}

func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx context.Context, policy string, input []byte, options *OPAEvaluatorOptions) (*OPAEvaluator, error) {
func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry, options *PolicyEvaluationOptions) (primitive.M, error) {
if options == nil {
options = &OPAEvaluatorOptions{}
}

if eval, ok := partialEvaluators[policy]; ok {
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

evaluator := eval.PartialEvaluator.Rego(
rego.ParsedInput(inputTerm.Value),
rego.EnablePrintStatements(options.EnablePrintStatements),
rego.PrintHook(NewPrintHook(os.Stdout, policy)),
)

return &OPAEvaluator{
PolicyName: policy,
PolicyEvaluator: evaluator,

context: ctx,
mongoClient: options.MongoClient,
}, nil
options = &PolicyEvaluationOptions{}
}
return nil, fmt.Errorf("%w: %s", ErrEvaluatorNotFound, policy)
}

func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry, options *PolicyEvaluationOptions) (primitive.M, error) {
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.getContext())
if err != nil {
Expand All @@ -295,15 +150,15 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry, options *
"policy_name": evaluator.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

logger.WithFields(logrus.Fields{
fields := logrus.Fields{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"partialEval": true,
"allowed": true,
"matchedPath": options.RouterInfo.MatchedPath,
"requestedPath": options.RouterInfo.RequestedPath,
"method": options.RouterInfo.Method,
}).Debug("policy evaluation completed")
}
addDataToLogFields(fields, options.AdditionalLogFields)

logger.WithFields(fields).Debug("policy evaluation completed")

client := opatranslator.OPAClient{}
q, err := client.ProcessQuery(partialResults)
Expand All @@ -320,6 +175,10 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry, options *
}

func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry, options *PolicyEvaluationOptions) (interface{}, error) {
if options == nil {
options = &PolicyEvaluationOptions{}
}

opaEvaluationTimeStart := time.Now()

results, err := evaluator.PolicyEvaluator.Eval(evaluator.getContext())
Expand All @@ -333,16 +192,16 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry, options *PolicyEva
}).Observe(float64(opaEvaluationTime.Milliseconds()))

allowed, responseBodyOverwriter := processResults(results)
logger.WithFields(logrus.Fields{
fields := logrus.Fields{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"partialEval": false,
"allowed": allowed,
"resultsLength": len(results),
"matchedPath": options.RouterInfo.MatchedPath,
"requestedPath": options.RouterInfo.RequestedPath,
"method": options.RouterInfo.Method,
}).Debug("policy evaluation completed")
}
addDataToLogFields(fields, options.AdditionalLogFields)

logger.WithFields(fields).Debug("policy evaluation completed")

logger.WithFields(logrus.Fields{
"policyName": evaluator.PolicyName,
Expand All @@ -367,8 +226,8 @@ func (evaluator *OPAEvaluator) getContext() context.Context {
}

type PolicyEvaluationOptions struct {
Metrics *metrics.Metrics
RouterInfo openapi.RouterInfo
Metrics *metrics.Metrics
AdditionalLogFields map[string]string
}

func (evaluator *PolicyEvaluationOptions) metrics() metrics.Metrics {
Expand All @@ -378,8 +237,8 @@ func (evaluator *PolicyEvaluationOptions) metrics() metrics.Metrics {
return metrics.SetupMetrics("rond")
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
if permission.RequestFlow.GenerateQuery {
func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, options *PolicyEvaluationOptions) (interface{}, primitive.M, error) {
if evaluator.generateQuery {
query, err := evaluator.partiallyEvaluate(logger, options)
return nil, query, err
}
Expand Down
Loading

0 comments on commit 8ce2333

Please sign in to comment.