Skip to content

Commit

Permalink
feat: response filter now uses sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
davidebianchi committed Jun 26, 2023
1 parent 258326c commit 4f0be7c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 103 deletions.
71 changes: 18 additions & 53 deletions core/opa_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

"github.com/rond-authz/rond/internal/mongoclient"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"

"github.com/gorilla/mux"
Expand All @@ -35,39 +34,33 @@ import (
type OPATransport struct {
http.RoundTripper
// FIXME: this overlaps with the req.Context used during RoundTrip.
context context.Context
logger *logrus.Entry
request *http.Request
permission *openapi.RondConfig
partialResultsEvaluators PartialResultsEvaluators

clientHeaderKey string
userHeaders types.UserHeadersKeys
evaluatorOptions *EvaluatorOptions
context context.Context
logger *logrus.Entry
request *http.Request

clientHeaderKey string
userHeaders types.UserHeadersKeys
evaluatorSDK SDKEvaluator
}

func NewOPATransport(
transport http.RoundTripper,
context context.Context,
logger *logrus.Entry,
req *http.Request,
permission *openapi.RondConfig,
partialResultsEvaluators PartialResultsEvaluators,
clientHeaderKey string,
userHeadersKeys types.UserHeadersKeys,
evaluatorOptions *EvaluatorOptions,
evaluatorSDK SDKEvaluator,
) *OPATransport {
return &OPATransport{
RoundTripper: transport,
context: req.Context(),
logger: logger,
request: req,
permission: permission,
partialResultsEvaluators: partialResultsEvaluators,

clientHeaderKey: clientHeaderKey,
userHeaders: userHeadersKeys,
evaluatorOptions: evaluatorOptions,
RoundTripper: transport,
context: req.Context(),
logger: logger,
request: req,

clientHeaderKey: clientHeaderKey,
userHeaders: userHeadersKeys,
evaluatorSDK: evaluatorSDK,
}
}

Expand Down Expand Up @@ -116,42 +109,14 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er

pathParams := mux.Vars(t.request)
rondReq := NewRondInput(t.request, t.clientHeaderKey, pathParams)
input, err := rondReq.FromRequestInfo(userInfo, decodedBody)
if err != nil {
t.responseWithError(resp, err, http.StatusInternalServerError)
return resp, nil
}

regoInput, err := CreateRegoQueryInput(t.logger, input, RegoInputOptions{
EnableResourcePermissionsMapOptimization: t.permission.Options.EnableResourcePermissionsMapOptimization,
})
if err != nil {
t.responseWithError(resp, err, http.StatusInternalServerError)
return resp, nil
}

evaluator, err := t.partialResultsEvaluators.GetEvaluatorFromPolicy(t.context, t.permission.ResponseFlow.PolicyName, regoInput, t.evaluatorOptions)
if err != nil {
t.logger.WithField("error", logrus.Fields{
"policyName": t.permission.ResponseFlow.PolicyName,
"message": err.Error(),
}).Error("RBAC policy evaluation on response failed")
t.responseWithError(resp, err, http.StatusInternalServerError)
return resp, nil
}

bodyToProxy, err := evaluator.Evaluate(t.logger)
responseBody, err := t.evaluatorSDK.EvaluateResponsePolicy(t.context, rondReq, userInfo, decodedBody)
if err != nil {
t.responseWithError(resp, err, http.StatusForbidden)
return resp, nil
}

marshalledBody, err := json.Marshal(bodyToProxy)
if err != nil {
t.responseWithError(resp, err, http.StatusInternalServerError)
return resp, nil
}
overwriteResponse(resp, marshalledBody)
overwriteResponse(resp, responseBody)
return resp, nil
}

Expand Down
79 changes: 40 additions & 39 deletions core/opa_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package core

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -53,17 +52,15 @@ func TestRoundTripErrors(t *testing.T) {
JSON(responseBody)

req := httptest.NewRequest(http.MethodPost, "http://example.com/some-api", nil)
transport := &OPATransport{
transport := NewOPATransport(
http.DefaultTransport,
req.Context(),
logrus.NewEntry(logger),
req,
nil,
nil,
"",
types.UserHeadersKeys{},
nil,
}
)

resp, err := transport.RoundTrip(req)
require.NoError(t, err, "unexpected error")
Expand Down Expand Up @@ -92,17 +89,15 @@ func TestOPATransportResponseWithError(t *testing.T) {

req := httptest.NewRequest(http.MethodPost, "http://example.com/some-api", nil)

transport := &OPATransport{
transport := NewOPATransport(
http.DefaultTransport,
req.Context(),
logrus.NewEntry(logger),
req,
nil,
nil,
"",
types.UserHeadersKeys{},
nil,
}
)

t.Run("generic business error message", func(t *testing.T) {
resp := &http.Response{
Expand Down Expand Up @@ -151,15 +146,14 @@ func TestOPATransportResponseWithError(t *testing.T) {

func TestOPATransportRoundTrip(t *testing.T) {
logger, _ := test.NewNullLogger()
req := httptest.NewRequest(http.MethodPost, "http://example.com/some-api", nil)
req := httptest.NewRequest(http.MethodGet, "/users", nil)

t.Run("returns error on RoundTrip error", func(t *testing.T) {
transport := NewOPATransport(
&MockRoundTrip{Error: fmt.Errorf("some error")},
req.Context(),
logrus.NewEntry(logger),
req,
nil, nil,
"",
types.UserHeadersKeys{
IDHeaderKey: "useridheader",
Expand Down Expand Up @@ -279,35 +273,36 @@ func TestOPATransportRoundTrip(t *testing.T) {
})

t.Run("ok with filter response", func(t *testing.T) {
resp := http.Response{
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(`{"some":"field"}`))),
ContentLength: 16,
Header: http.Header{"Content-Type": []string{"application/json"}},
}

req = req.Clone(metrics.WithValue(openapi.WithRouterInfo(logrus.NewEntry(logger), req.Context(), req), metrics.SetupMetrics("")))
logEntry := logrus.NewEntry(logger)
req = req.Clone(metrics.WithValue(openapi.WithRouterInfo(logEntry, req.Context(), req), metrics.SetupMetrics("")))

partialResult, err := NewPartialResultEvaluator(context.Background(), "my_policy", &OPAModuleConfig{
Content: "package policies my_policy [resources] { resources := input.response.body }",
}, nil)
evaluator := getSdk(t, &sdkOptions{
oasFilePath: "../mocks/rondOasConfig.json",
opaModuleContent: "package policies responsepolicy [resources] { resources := input.response.body }",
})
evaluatorSDK, err := evaluator.FindEvaluator(logEntry, http.MethodGet, "/users/")
require.NoError(t, err)

transport := &OPATransport{
RoundTripper: &MockRoundTrip{Response: &resp},
context: req.Context(),
logger: logrus.NewEntry(logger),
request: req,
permission: &openapi.RondConfig{
ResponseFlow: openapi.ResponseFlow{PolicyName: "my_policy"},
},
partialResultsEvaluators: PartialResultsEvaluators{"my_policy": PartialEvaluator{partialResult}},
userHeaders: types.UserHeadersKeys{
transport := NewOPATransport(
&MockRoundTrip{Response: resp},
req.Context(),
logrus.NewEntry(logger),
req,
"",
types.UserHeadersKeys{
IDHeaderKey: "useridheader",
GroupsHeaderKey: "usergroupsheader",
PropertiesHeaderKey: "userpropertiesheader",
},
}
evaluatorSDK,
)

actualResp, err := transport.RoundTrip(req)
require.NoError(t, err, "response body is not valid")
Expand Down Expand Up @@ -419,22 +414,28 @@ func TestOPATransportRoundTrip(t *testing.T) {
ContentLength: 0,
Header: http.Header{"Content-Type": []string{"application/json"}},
}
transport := &OPATransport{
RoundTripper: &MockRoundTrip{Response: resp},
context: req.Context(),
logger: logrus.NewEntry(logger),
request: req,
permission: &openapi.RondConfig{
ResponseFlow: openapi.ResponseFlow{PolicyName: "my_policy"},
},
partialResultsEvaluators: PartialResultsEvaluators{"my_policy": {}},
userHeaders: types.UserHeadersKeys{
evaluator := getSdk(t, &sdkOptions{
oasFilePath: "../mocks/rondOasConfig.json",
opaModuleContent: "package policies responsepolicy [resources] { resources := input.response.body }",
})
logEntry := logrus.NewEntry(logger)
evaluatorSDK, err := evaluator.FindEvaluator(logEntry, http.MethodGet, "/users/")
require.NoError(t, err)

transport := NewOPATransport(
&MockRoundTrip{Response: resp},
req.Context(),
logrus.NewEntry(logger),
req,
"",
types.UserHeadersKeys{
IDHeaderKey: "useridheader",
GroupsHeaderKey: "usergroupsheader",
PropertiesHeaderKey: "userpropertiesheader",
},
}
resp, err := transport.RoundTrip(req)
evaluatorSDK,
)
resp, err = transport.RoundTrip(req)
require.Nil(t, err)
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
bodyBytes, err := io.ReadAll(resp.Body)
Expand Down
14 changes: 3 additions & 11 deletions service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ func ReverseProxyOrResponse(
evaluatorSdk core.SDKEvaluator,
) {
var permission openapi.RondConfig
var partialResultsEvaluators core.PartialResultsEvaluators
if evaluatorSdk != nil {
permission = evaluatorSdk.Config()
partialResultsEvaluators = evaluatorSdk.PartialResultsEvaluators()
}

if env.Standalone {
Expand All @@ -64,7 +62,7 @@ func ReverseProxyOrResponse(
}
return
}
ReverseProxy(logger, env, w, req, &permission, partialResultsEvaluators)
ReverseProxy(logger, env, w, req, &permission, evaluatorSdk)
}

func rbacHandler(w http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -149,7 +147,7 @@ func ReverseProxy(
w http.ResponseWriter,
req *http.Request,
permission *openapi.RondConfig,
partialResultsEvaluators core.PartialResultsEvaluators,
evaluatorSdk core.SDKEvaluator,
) {
targetHostFromEnv := env.TargetServiceHost
proxy := httputil.ReverseProxy{
Expand All @@ -164,10 +162,6 @@ func ReverseProxy(
},
}

options := &core.EvaluatorOptions{
EnablePrintStatements: env.IsTraceLogLevel(),
}

// Check on nil is performed to proxy the oas documentation path
if permission == nil || permission.ResponseFlow.PolicyName == "" {
proxy.ServeHTTP(w, req)
Expand All @@ -178,16 +172,14 @@ func ReverseProxy(
req.Context(),
logger,
req,
permission,
partialResultsEvaluators,

env.ClientTypeHeader,
types.UserHeadersKeys{
IDHeaderKey: env.UserIdHeader,
GroupsHeaderKey: env.UserGroupsHeader,
PropertiesHeaderKey: env.UserPropertiesHeader,
},
options,
evaluatorSdk,
)
proxy.ServeHTTP(w, req)
}
Expand Down

0 comments on commit 4f0be7c

Please sign in to comment.