From d7ab817de74eede13b20c3f7009cf5b463bcef81 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 31 Mar 2021 10:09:06 -0600 Subject: [PATCH] authorize: add databroker server and record version to result, force sync via polling (#2024) * authorize: add databroker server and record version to result, force sync via polling * wrap inmem store to take read lock when grabbing databroker versions * address code review comments * reset max to 0 --- authorize/check_response_test.go | 2 +- authorize/evaluator/custom_test.go | 6 +- authorize/evaluator/evaluator.go | 100 +--------------- authorize/evaluator/evaluator_test.go | 16 +-- authorize/evaluator/opa/policy/authz.rego | 4 + authorize/evaluator/opa_test.go | 14 ++- authorize/evaluator/result.go | 115 +++++++++++++++++++ authorize/evaluator/store.go | 81 ++++++++----- authorize/evaluator/store_test.go | 6 +- authorize/grpc.go | 81 +------------ authorize/grpc_test.go | 130 --------------------- authorize/sync.go | 134 +++++++++++++++++++++- authorize/sync_test.go | 116 +++++++++++++++++++ internal/databroker/config_source.go | 2 +- internal/identity/manager/sync.go | 2 +- pkg/grpc/databroker/syncer.go | 6 +- pkg/grpc/databroker/syncer_test.go | 8 +- 17 files changed, 464 insertions(+), 359 deletions(-) create mode 100644 authorize/evaluator/result.go create mode 100644 authorize/sync_test.go diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 50109cc58ee..ba0e2170c89 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -39,7 +39,7 @@ func TestAuthorize_okResponse(t *testing.T) { encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder a.currentOptions.Store(opt) - a.store = evaluator.NewStoreFromProtos( + a.store = evaluator.NewStoreFromProtos(0, &session.Session{ Id: "SESSION_ID", UserId: "USER_ID", diff --git a/authorize/evaluator/custom_test.go b/authorize/evaluator/custom_test.go index 5ee4ce3b811..d915b98cb23 100644 --- a/authorize/evaluator/custom_test.go +++ b/authorize/evaluator/custom_test.go @@ -14,7 +14,7 @@ func TestCustomEvaluator(t *testing.T) { store := NewStore() t.Run("bool deny", func(t *testing.T) { - ce := NewCustomEvaluator(store.opaStore) + ce := NewCustomEvaluator(store) res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{ RegoPolicy: ` package pomerium.custom_policy @@ -29,7 +29,7 @@ func TestCustomEvaluator(t *testing.T) { assert.Empty(t, res.Reason) }) t.Run("set deny", func(t *testing.T) { - ce := NewCustomEvaluator(store.opaStore) + ce := NewCustomEvaluator(store) res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{ RegoPolicy: ` package pomerium.custom_policy @@ -44,7 +44,7 @@ func TestCustomEvaluator(t *testing.T) { assert.Equal(t, "test", res.Reason) }) t.Run("missing package", func(t *testing.T) { - ce := NewCustomEvaluator(store.opaStore) + ce := NewCustomEvaluator(store) res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{ RegoPolicy: `allow = true`, }) diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 083254b8bea..a682c4fd405 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -7,7 +7,6 @@ import ( "encoding/base64" "fmt" "net/http" - "strconv" "github.com/open-policy-agent/opa/rego" "gopkg.in/square/go-jose.v2" @@ -29,7 +28,7 @@ type Evaluator struct { // New creates a new Evaluator. func New(options *config.Options, store *Store) (*Evaluator, error) { e := &Evaluator{ - custom: NewCustomEvaluator(store.opaStore), + custom: NewCustomEvaluator(store), policies: options.GetAllPolicies(), store: store, } @@ -55,7 +54,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) { store.UpdateSigningKey(jwk) e.rego = rego.New( - rego.Store(store.opaStore), + rego.Store(store), rego.Module("pomerium.authz", string(authzPolicy)), rego.Query("result = data.pomerium.authz"), getGoogleCloudServerlessHeadersRegoOption, @@ -91,6 +90,9 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error) MatchingPolicy: getMatchingPolicy(res[0].Bindings.WithoutWildcards(), e.policies), Headers: getHeadersVar(res[0].Bindings.WithoutWildcards()), } + evalResult.DataBrokerServerVersion, evalResult.DataBrokerRecordVersion = getDataBrokerVersions( + res[0].Bindings, + ) allow := getAllowVar(res[0].Bindings.WithoutWildcards()) // evaluate any custom policies @@ -181,95 +183,3 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input i.IsValidClientCertificate = isValidClientCertificate return i } - -// Result is the result of evaluation. -type Result struct { - Status int - Message string - Headers map[string]string - MatchingPolicy *config.Policy -} - -func getMatchingPolicy(vars rego.Vars, policies []config.Policy) *config.Policy { - result, ok := vars["result"].(map[string]interface{}) - if !ok { - return nil - } - - idx, err := strconv.Atoi(fmt.Sprint(result["route_policy_idx"])) - if err != nil { - return nil - } - - if idx >= len(policies) { - return nil - } - - return &policies[idx] -} - -func getAllowVar(vars rego.Vars) bool { - result, ok := vars["result"].(map[string]interface{}) - if !ok { - return false - } - - allow, ok := result["allow"].(bool) - if !ok { - return false - } - return allow -} - -func getDenyVar(vars rego.Vars) []Result { - result, ok := vars["result"].(map[string]interface{}) - if !ok { - return nil - } - - denials, ok := result["deny"].([]interface{}) - if !ok { - return nil - } - - results := make([]Result, 0, len(denials)) - for _, denial := range denials { - denial, ok := denial.([]interface{}) - if !ok || len(denial) != 2 { - continue - } - - status, err := strconv.Atoi(fmt.Sprint(denial[0])) - if err != nil { - log.Error().Err(err).Msg("invalid type in deny") - continue - } - msg := fmt.Sprint(denial[1]) - - results = append(results, Result{ - Status: status, - Message: msg, - }) - } - return results -} - -func getHeadersVar(vars rego.Vars) map[string]string { - headers := make(map[string]string) - - result, ok := vars["result"].(map[string]interface{}) - if !ok { - return headers - } - - m, ok := result["identity_headers"].(map[string]interface{}) - if !ok { - return headers - } - - for k, v := range m { - headers[k] = fmt.Sprint(v) - } - - return headers -} diff --git a/authorize/evaluator/evaluator_test.go b/authorize/evaluator/evaluator_test.go index 033fb09d1fc..392f269cd4d 100644 --- a/authorize/evaluator/evaluator_test.go +++ b/authorize/evaluator/evaluator_test.go @@ -25,7 +25,7 @@ import ( func TestJSONMarshal(t *testing.T) { opt := config.NewDefaultOptions() opt.AuthenticateURLString = "https://authenticate.example.com" - e, err := New(opt, NewStoreFromProtos( + e, err := New(opt, NewStoreFromProtos(0, &session.Session{ UserId: "user1", }, @@ -100,7 +100,7 @@ func TestEvaluator_Evaluate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - store := NewStoreFromProtos() + store := NewStoreFromProtos(0) data, _ := ptypes.MarshalAny(&session.Session{ Version: "1", Id: sessionID, @@ -116,7 +116,7 @@ func TestEvaluator_Evaluate(t *testing.T) { RefreshToken: "REFRESH TOKEN", }, }) - store.UpdateRecord(&databroker.Record{ + store.UpdateRecord(0, &databroker.Record{ Version: 1, Type: "type.googleapis.com/session.Session", Id: sessionID, @@ -127,7 +127,7 @@ func TestEvaluator_Evaluate(t *testing.T) { Id: userID, Email: "foo@example.com", }) - store.UpdateRecord(&databroker.Record{ + store.UpdateRecord(0, &databroker.Record{ Version: 1, Type: "type.googleapis.com/user.User", Id: userID, @@ -189,7 +189,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) { RefreshToken: "REFRESH TOKEN", }, }) - store.UpdateRecord(&databroker.Record{ + store.UpdateRecord(0, &databroker.Record{ Version: uint64(i), Type: "type.googleapis.com/session.Session", Id: sessionID, @@ -199,7 +199,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) { Version: fmt.Sprint(i), Id: userID, }) - store.UpdateRecord(&databroker.Record{ + store.UpdateRecord(0, &databroker.Record{ Version: uint64(i), Type: "type.googleapis.com/user.User", Id: userID, @@ -211,7 +211,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) { Id: userID, GroupIds: []string{"1", "2", "3", "4"}, }) - store.UpdateRecord(&databroker.Record{ + store.UpdateRecord(0, &databroker.Record{ Version: uint64(i), Type: data.TypeUrl, Id: userID, @@ -222,7 +222,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) { Version: fmt.Sprint(i), Id: fmt.Sprint(i), }) - store.UpdateRecord(&databroker.Record{ + store.UpdateRecord(0, &databroker.Record{ Version: uint64(i), Type: data.TypeUrl, Id: fmt.Sprint(i), diff --git a/authorize/evaluator/opa/policy/authz.rego b/authorize/evaluator/opa/policy/authz.rego index e9db31b8dda..13fa0e65779 100644 --- a/authorize/evaluator/opa/policy/authz.rego +++ b/authorize/evaluator/opa/policy/authz.rego @@ -5,6 +5,10 @@ default allow = false # 5 minutes from now in seconds five_minutes := (time.now_ns() / 1e9) + (60 * 5) +# databroker versions to know which version of the data was evaluated +databroker_server_version := data.databroker_server_version +databroker_record_version := data.databroker_record_version + route_policy_idx := first_allowed_route_policy_idx(input.http.url) route_policy := data.route_policies[route_policy_idx] diff --git a/authorize/evaluator/opa_test.go b/authorize/evaluator/opa_test.go index d81d61584cf..eba5df89ae1 100644 --- a/authorize/evaluator/opa_test.go +++ b/authorize/evaluator/opa_test.go @@ -3,6 +3,7 @@ package evaluator import ( "context" "encoding/json" + "math" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" @@ -37,13 +39,13 @@ func TestOPA(t *testing.T) { eval := func(t *testing.T, policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result { authzPolicy, err := readPolicy() require.NoError(t, err) - store := NewStoreFromProtos(data...) + store := NewStoreFromProtos(math.MaxUint64, data...) store.UpdateIssuer("authenticate.example.com") store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user")) store.UpdateRoutePolicies(policies) store.UpdateSigningKey(privateJWK) r := rego.New( - rego.Store(store.opaStore), + rego.Store(store), rego.Module("pomerium.authz", string(authzPolicy)), rego.Query("result = data.pomerium.authz"), getGoogleCloudServerlessHeadersRegoOption, @@ -646,4 +648,12 @@ func TestOPA(t *testing.T) { }, true) assert.True(t, res.Bindings["result"].(M)["allow"].(bool)) }) + t.Run("databroker versions", func(t *testing.T) { + res := eval(t, nil, []proto.Message{ + wrapperspb.String("test"), + }, &Request{}, false) + serverVersion, recordVersion := getDataBrokerVersions(res.Bindings) + assert.Equal(t, uint64(math.MaxUint64), serverVersion) + assert.NotEqual(t, uint64(0), recordVersion) // random + }) } diff --git a/authorize/evaluator/result.go b/authorize/evaluator/result.go new file mode 100644 index 00000000000..3de1f20910c --- /dev/null +++ b/authorize/evaluator/result.go @@ -0,0 +1,115 @@ +package evaluator + +import ( + "fmt" + "strconv" + + "github.com/open-policy-agent/opa/rego" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/log" +) + +// Result is the result of evaluation. +type Result struct { + Status int + Message string + Headers map[string]string + MatchingPolicy *config.Policy + + DataBrokerServerVersion, DataBrokerRecordVersion uint64 +} + +func getMatchingPolicy(vars rego.Vars, policies []config.Policy) *config.Policy { + result, ok := vars["result"].(map[string]interface{}) + if !ok { + return nil + } + + idx, err := strconv.Atoi(fmt.Sprint(result["route_policy_idx"])) + if err != nil { + return nil + } + + if idx >= len(policies) { + return nil + } + + return &policies[idx] +} + +func getAllowVar(vars rego.Vars) bool { + result, ok := vars["result"].(map[string]interface{}) + if !ok { + return false + } + + allow, ok := result["allow"].(bool) + if !ok { + return false + } + return allow +} + +func getDenyVar(vars rego.Vars) []Result { + result, ok := vars["result"].(map[string]interface{}) + if !ok { + return nil + } + + denials, ok := result["deny"].([]interface{}) + if !ok { + return nil + } + + results := make([]Result, 0, len(denials)) + for _, denial := range denials { + denial, ok := denial.([]interface{}) + if !ok || len(denial) != 2 { + continue + } + + status, err := strconv.Atoi(fmt.Sprint(denial[0])) + if err != nil { + log.Error().Err(err).Msg("invalid type in deny") + continue + } + msg := fmt.Sprint(denial[1]) + + results = append(results, Result{ + Status: status, + Message: msg, + }) + } + return results +} + +func getHeadersVar(vars rego.Vars) map[string]string { + headers := make(map[string]string) + + result, ok := vars["result"].(map[string]interface{}) + if !ok { + return headers + } + + m, ok := result["identity_headers"].(map[string]interface{}) + if !ok { + return headers + } + + for k, v := range m { + headers[k] = fmt.Sprint(v) + } + + return headers +} + +func getDataBrokerVersions(vars rego.Vars) (serverVersion, recordVersion uint64) { + result, ok := vars["result"].(map[string]interface{}) + if !ok { + return 0, 0 + } + serverVersion, _ = strconv.ParseUint(fmt.Sprint(result["databroker_server_version"]), 10, 64) + recordVersion, _ = strconv.ParseUint(fmt.Sprint(result["databroker_record_version"]), 10, 64) + return serverVersion, recordVersion +} diff --git a/authorize/evaluator/store.go b/authorize/evaluator/store.go index 228f1f1c758..95f095d9700 100644 --- a/authorize/evaluator/store.go +++ b/authorize/evaluator/store.go @@ -25,7 +25,7 @@ import ( // A Store stores data for the OPA rego policy evaluation. type Store struct { - opaStore storage.Store + storage.Store mu sync.RWMutex dataBrokerData map[string]map[string]proto.Message @@ -34,13 +34,13 @@ type Store struct { // NewStore creates a new Store. func NewStore() *Store { return &Store{ - opaStore: inmem.New(), + Store: inmem.New(), dataBrokerData: make(map[string]map[string]proto.Message), } } // NewStoreFromProtos creates a new Store from an existing set of protobuf messages. -func NewStoreFromProtos(msgs ...proto.Message) *Store { +func NewStoreFromProtos(serverVersion uint64, msgs ...proto.Message) *Store { s := NewStore() for _, msg := range msgs { any, err := anypb.New(msg) @@ -58,11 +58,34 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store { record.Id = hasID.GetId() } - s.UpdateRecord(record) + s.UpdateRecord(serverVersion, record) } return s } +// NewTransaction calls the underlying store NewTransaction and takes the transaction lock. +func (s *Store) NewTransaction(ctx context.Context, params ...storage.TransactionParams) (storage.Transaction, error) { + txn, err := s.Store.NewTransaction(ctx, params...) + if err != nil { + return nil, err + } + s.mu.RLock() + return txn, err +} + +// Commit calls the underlying store Commit and releases the transaction lock. +func (s *Store) Commit(ctx context.Context, txn storage.Transaction) error { + err := s.Store.Commit(ctx, txn) + s.mu.RUnlock() + return err +} + +// Abort calls the underlying store Abort and releases the transaction lock. +func (s *Store) Abort(ctx context.Context, txn storage.Transaction) { + s.Store.Abort(ctx, txn) + s.mu.RUnlock() +} + // ClearRecords removes all the records from the store. func (s *Store) ClearRecords() { s.mu.Lock() @@ -107,10 +130,13 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) { } // UpdateRecord updates a record in the store. -func (s *Store) UpdateRecord(record *databroker.Record) { +func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) { s.mu.Lock() defer s.mu.Unlock() + s.write("/databroker_server_version", fmt.Sprint(serverVersion)) + s.write("/databroker_record_version", fmt.Sprint(record.GetVersion())) + m, ok := s.dataBrokerData[record.GetType()] if !ok { m = make(map[string]proto.Message) @@ -130,36 +156,37 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) { } func (s *Store) write(rawPath string, value interface{}) { - p, ok := storage.ParsePath(rawPath) - if !ok { - log.Error(). - Str("path", rawPath). - Msg("opa-store: invalid path, ignoring data") + err := storage.Txn(context.Background(), s.Store, storage.WriteParams, func(txn storage.Transaction) error { + return s.writeTxn(txn, rawPath, value) + }) + if err != nil { + log.Error().Err(err).Msg("opa-store: error writing data") return } +} - err := storage.Txn(context.Background(), s.opaStore, storage.WriteParams, func(txn storage.Transaction) error { - if len(p) > 1 { - err := storage.MakeDir(context.Background(), s.opaStore, txn, p[:len(p)-1]) - if err != nil { - return err - } - } +func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error { + p, ok := storage.ParsePath(rawPath) + if !ok { + return fmt.Errorf("invalid path") + } - var op storage.PatchOp = storage.ReplaceOp - _, err := s.opaStore.Read(context.Background(), txn, p) - if storage.IsNotFound(err) { - op = storage.AddOp - } else if err != nil { + if len(p) > 1 { + err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1]) + if err != nil { return err } + } - return s.opaStore.Write(context.Background(), txn, op, p, value) - }) - if err != nil { - log.Error().Err(err).Msg("opa-store: error writing data") - return + var op storage.PatchOp = storage.ReplaceOp + _, err := s.Read(context.Background(), txn, p) + if storage.IsNotFound(err) { + op = storage.AddOp + } else if err != nil { + return err } + + return s.Write(context.Background(), txn, op, p, value) } // GetDataBrokerRecordOption returns a function option that can retrieve databroker data. diff --git a/authorize/evaluator/store_test.go b/authorize/evaluator/store_test.go index 6a443e55b6d..6c5268379f5 100644 --- a/authorize/evaluator/store_test.go +++ b/authorize/evaluator/store_test.go @@ -21,7 +21,7 @@ func TestStore(t *testing.T) { Email: "name@example.com", } any, _ := anypb.New(u) - s.UpdateRecord(&databroker.Record{ + s.UpdateRecord(0, &databroker.Record{ Version: 1, Type: any.GetTypeUrl(), Id: u.GetId(), @@ -36,7 +36,7 @@ func TestStore(t *testing.T) { "email": "name@example.com", }, toMap(v)) - s.UpdateRecord(&databroker.Record{ + s.UpdateRecord(0, &databroker.Record{ Version: 2, Type: any.GetTypeUrl(), Id: u.GetId(), @@ -47,7 +47,7 @@ func TestStore(t *testing.T) { v = s.GetRecordData(any.GetTypeUrl(), u.GetId()) assert.Nil(t, v) - s.UpdateRecord(&databroker.Record{ + s.UpdateRecord(0, &databroker.Record{ Version: 3, Type: any.GetTypeUrl(), Id: u.GetId(), diff --git a/authorize/grpc.go b/authorize/grpc.go index f9e19d1e3af..91ab37e6108 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -3,7 +3,6 @@ package authorize import ( "context" "encoding/base64" - "errors" "io/ioutil" "net/http" "net/url" @@ -19,10 +18,7 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" - "github.com/pomerium/pomerium/pkg/grpcutil" envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" ) @@ -83,81 +79,6 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe return a.deniedResponse(in, int32(reply.Status), reply.Message, nil) } -func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (*user.User, error) { - ctx, span := trace.StartSpan(ctx, "authorize.forceSync") - defer span.End() - if ss == nil { - return nil, nil - } - s := a.forceSyncSession(ctx, ss.ID) - if s == nil { - return nil, errors.New("session not found") - } - u := a.forceSyncUser(ctx, s.GetUserId()) - return u, nil -} - -func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) interface{ GetUserId() string } { - ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession") - defer span.End() - - state := a.state.Load() - - s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session) - if ok { - return s - } - - sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount) - if ok { - return sa - } - - res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{ - Type: grpcutil.GetTypeURL(new(session.Session)), - Id: sessionID, - }) - if err != nil { - log.Warn().Err(err).Msg("failed to get session from databroker") - return nil - } - - if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID); current == nil { - a.store.UpdateRecord(res.GetRecord()) - } - s, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session) - - return s -} - -func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User { - ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser") - defer span.End() - - state := a.state.Load() - - u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User) - if ok { - return u - } - - res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{ - Type: grpcutil.GetTypeURL(new(user.User)), - Id: userID, - }) - if err != nil { - log.Warn().Err(err).Msg("failed to get user from databroker") - return nil - } - - if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID); current == nil { - a.store.UpdateRecord(res.GetRecord()) - } - u, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User) - - return u -} - func getForwardAuthURL(r *http.Request) *url.URL { urqQuery := r.URL.Query().Get("uri") u, _ := urlutil.ParseAndValidateURL(urqQuery) @@ -329,6 +250,8 @@ func logAuthorizeCheck( evt = evt.Str("message", reply.Message) evt = evt.Str("user", u.GetId()) evt = evt.Str("email", u.GetEmail()) + evt = evt.Uint64("databroker_server_version", reply.DataBrokerServerVersion) + evt = evt.Uint64("databroker_record_version", reply.DataBrokerRecordVersion) } // potentially sensitive, only log if debug mode diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 2759918d87f..f52cdf12008 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -2,12 +2,10 @@ package authorize import ( "context" - "errors" "net/url" "testing" envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" - "github.com/golang/protobuf/ptypes" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" @@ -21,8 +19,6 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpc/user" ) const certPEM = ` @@ -313,132 +309,6 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { assert.Equal(t, expect, actual) } -func TestSync(t *testing.T) { - mockSession := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, _ := ptypes.MarshalAny(&session.Session{ - Id: in.GetId(), - UserId: "user1", - }) - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: 1, - Type: data.GetTypeUrl(), - Id: in.GetId(), - Data: data, - }, - }, nil - } - mockUser := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, _ := ptypes.MarshalAny(&user.User{Id: in.GetId()}) - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: 1, - Type: data.GetTypeUrl(), - Id: in.GetId(), - Data: data, - }, - }, nil - } - - mockGetByType := map[string]func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error){ - "type.googleapis.com/session.Session": mockSession, - "type.googleapis.com/user.User": mockUser, - } - dbdClient := mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - if in.GetId() == "not-existed-id" { - return nil, errors.New("not found") - } - f, ok := mockGetByType[in.GetType()] - if !ok { - return nil, errors.New("not found") - } - return f(ctx, in, opts...) - }, - } - o := &config.Options{ - AuthenticateURLString: "https://authN.example.com", - DataBrokerURLString: "https://databroker.example.com", - SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", - Policies: testPolicies(t), - } - - ctx := context.Background() - - tests := []struct { - name string - sessionState *sessions.State - databrokerClient mockDataBrokerServiceClient - wantErr bool - }{ - { - "good with data in databroker data", - &sessions.State{ID: "dbd_session_id"}, - mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, _ := ptypes.MarshalAny(&session.Session{ - Id: in.GetId(), - UserId: "dbd_user1", - }) - if in.GetType() == "type.googleapis.com/user.User" { - data, _ = ptypes.MarshalAny(&user.User{ - Id: "dbd_user1", - }) - } - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: 1, - Type: data.GetTypeUrl(), - Id: in.GetId(), - Data: data, - }, - }, nil - }, - }, - false, - }, - {"good", &sessions.State{ID: "SESSION_ID"}, dbdClient, false}, - {"nil session state", nil, dbdClient, false}, - {"not found session state", &sessions.State{ID: "not-existed-id"}, dbdClient, true}, - { - "user not found", - &sessions.State{ID: "session_with_not_found_user"}, - mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - if in.GetType() == "type.googleapis.com/user.User" { - return nil, errors.New("user not found") - } - data, _ := ptypes.MarshalAny(&session.Session{ - Id: in.GetId(), - UserId: "user1", - }) - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: 1, - Type: data.GetTypeUrl(), - Id: in.GetId(), - Data: data, - }, - }, nil - }, - }, - false, - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - a, err := New(&config.Config{Options: o}) - require.NoError(t, err) - a.state.Load().dataBrokerClient = dbdClient - _, err = a.forceSync(ctx, tc.sessionState) - assert.True(t, (err != nil) == tc.wantErr) - }) - } -} - type mockDataBrokerServiceClient struct { databroker.DataBrokerServiceClient diff --git a/authorize/sync.go b/authorize/sync.go index 97e04e1d220..dd5b98269ed 100644 --- a/authorize/sync.go +++ b/authorize/sync.go @@ -2,11 +2,32 @@ package authorize import ( "context" + "errors" "sync" + "time" + "github.com/cenkalti/backoff/v4" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/grpcutil" +) + +const ( + forceSyncRecordMaxWait = 5 * time.Second ) +type sessionOrServiceAccount interface { + GetUserId() string +} + type dataBrokerSyncer struct { *databroker.Syncer authorize *Authorize @@ -29,9 +50,9 @@ func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) { syncer.authorize.store.ClearRecords() } -func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) { +func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { for _, record := range records { - syncer.authorize.store.UpdateRecord(record) + syncer.authorize.store.UpdateRecord(serverVersion, record) } // the first time we update records we signal the initial sync @@ -39,3 +60,112 @@ func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*da close(syncer.authorize.dataBrokerInitialSync) }) } + +func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (*user.User, error) { + ctx, span := trace.StartSpan(ctx, "authorize.forceSync") + defer span.End() + if ss == nil { + return nil, nil + } + s := a.forceSyncSession(ctx, ss.ID) + if s == nil { + return nil, errors.New("session not found") + } + u := a.forceSyncUser(ctx, s.GetUserId()) + return u, nil +} + +func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount { + ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession") + defer span.End() + + ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait) + defer clearTimeout() + + s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session) + if ok { + return s + } + + sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount) + if ok { + return sa + } + + // wait for the session to show up + record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID) + if err != nil { + return nil + } + s, ok = record.(*session.Session) + if !ok { + return nil + } + return s +} + +func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User { + ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser") + defer span.End() + + ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait) + defer clearTimeout() + + u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User) + if ok { + return u + } + + // wait for the user to show up + record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(user.User)), userID) + if err != nil { + return nil + } + u, ok = record.(*user.User) + if !ok { + return nil + } + return u +} + +// waitForRecordSync waits for the first sync of a record to complete +func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, recordID string) (proto.Message, error) { + bo := backoff.NewExponentialBackOff() + bo.InitialInterval = time.Millisecond + bo.MaxElapsedTime = 0 + bo.Reset() + + for { + current := a.store.GetRecordData(recordTypeURL, recordID) + if current != nil { + // record found, so it's already synced + return current, nil + } + + _, err := a.state.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{ + Type: recordTypeURL, + Id: recordID, + }) + if status.Code(err) == codes.NotFound { + // record not found, so no need to wait + return nil, nil + } else if err != nil { + log.Error(). + Err(err). + Str("type", recordTypeURL). + Str("id", recordID). + Msg("authorize: error retrieving record") + return nil, err + } + + select { + case <-ctx.Done(): + log.Warn(). + Str("type", recordTypeURL). + Str("id", recordID). + Msg("authorize: first sync of record did not complete") + return nil, ctx.Err() + case <-time.After(bo.NextBackOff()): + } + } +} diff --git a/authorize/sync_test.go b/authorize/sync_test.go new file mode 100644 index 00000000000..a675b5f7dfc --- /dev/null +++ b/authorize/sync_test.go @@ -0,0 +1,116 @@ +package authorize + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpcutil" +) + +func TestAuthorize_waitForRecordSync(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30) + defer clearTimeout() + + o := &config.Options{ + AuthenticateURLString: "https://authN.example.com", + DataBrokerURLString: "https://databroker.example.com", + SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", + Policies: testPolicies(t), + } + t.Run("skip if exists", func(t *testing.T) { + a, err := New(&config.Config{Options: o}) + require.NoError(t, err) + + a.store.UpdateRecord(0, newRecord(&session.Session{ + Id: "SESSION_ID", + })) + a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + panic("should never be called") + }, + } + a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") + }) + t.Run("skip if not found", func(t *testing.T) { + a, err := New(&config.Config{Options: o}) + require.NoError(t, err) + + callCount := 0 + a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + callCount++ + return nil, status.Error(codes.NotFound, "not found") + }, + } + a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") + assert.Equal(t, 1, callCount, "should be called once") + }) + t.Run("poll", func(t *testing.T) { + a, err := New(&config.Config{Options: o}) + require.NoError(t, err) + + callCount := 0 + a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + callCount++ + switch callCount { + case 1: + s := &session.Session{Id: "SESSION_ID"} + a.store.UpdateRecord(0, newRecord(s)) + return &databroker.GetResponse{Record: newRecord(s)}, nil + default: + panic("should never be called") + } + }, + } + a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") + }) + t.Run("timeout", func(t *testing.T) { + a, err := New(&config.Config{Options: o}) + require.NoError(t, err) + + tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100) + defer clearTimeout() + + callCount := 0 + a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + callCount++ + s := &session.Session{Id: "SESSION_ID"} + return &databroker.GetResponse{Record: newRecord(s)}, nil + }, + } + a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") + assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism + }) +} + +type storableMessage interface { + proto.Message + GetId() string +} + +func newRecord(msg storableMessage) *databroker.Record { + any, err := anypb.New(msg) + if err != nil { + panic(err) + } + return &databroker.Record{ + Version: 1, + Type: any.GetTypeUrl(), + Id: msg.GetId(), + Data: any, + } +} diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 7288ed0d343..b29cbbe72bb 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -206,7 +206,7 @@ func (s *syncerHandler) ClearRecords(ctx context.Context) { s.src.mu.Unlock() } -func (s *syncerHandler) UpdateRecords(ctx context.Context, records []*databroker.Record) { +func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { if len(records) == 0 { return } diff --git a/internal/identity/manager/sync.go b/internal/identity/manager/sync.go index b6e31913e2e..81582a8efa2 100644 --- a/internal/identity/manager/sync.go +++ b/internal/identity/manager/sync.go @@ -50,7 +50,7 @@ func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrok return syncer.cfg.Load().dataBrokerClient } -func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) { +func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { select { case <-ctx.Done(): case syncer.update <- updateRecordsMessage{records: records}: diff --git a/pkg/grpc/databroker/syncer.go b/pkg/grpc/databroker/syncer.go index 573d475269d..c5bd1cff725 100644 --- a/pkg/grpc/databroker/syncer.go +++ b/pkg/grpc/databroker/syncer.go @@ -39,7 +39,7 @@ func WithTypeURL(typeURL string) SyncerOption { type SyncerHandler interface { GetDataBrokerServiceClient() DataBrokerServiceClient ClearRecords(ctx context.Context) - UpdateRecords(ctx context.Context, records []*Record) + UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) } // A Syncer is a helper type for working with Sync and SyncLatest. It will make a call to @@ -122,7 +122,7 @@ func (syncer *Syncer) init(ctx context.Context) error { syncer.recordVersion = recordVersion syncer.serverVersion = serverVersion - syncer.handler.UpdateRecords(ctx, records) + syncer.handler.UpdateRecords(ctx, serverVersion, records) return nil } @@ -157,7 +157,7 @@ func (syncer *Syncer) sync(ctx context.Context) error { } syncer.recordVersion = res.GetRecord().GetVersion() if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() { - syncer.handler.UpdateRecords(ctx, []*Record{res.GetRecord()}) + syncer.handler.UpdateRecords(ctx, syncer.serverVersion, []*Record{res.GetRecord()}) } } } diff --git a/pkg/grpc/databroker/syncer_test.go b/pkg/grpc/databroker/syncer_test.go index 79e493d9b63..af865247325 100644 --- a/pkg/grpc/databroker/syncer_test.go +++ b/pkg/grpc/databroker/syncer_test.go @@ -19,7 +19,7 @@ import ( type testSyncerHandler struct { getDataBrokerServiceClient func() DataBrokerServiceClient clearRecords func(ctx context.Context) - updateRecords func(ctx context.Context, records []*Record) + updateRecords func(ctx context.Context, serverVersion uint64, records []*Record) } func (t testSyncerHandler) GetDataBrokerServiceClient() DataBrokerServiceClient { @@ -30,8 +30,8 @@ func (t testSyncerHandler) ClearRecords(ctx context.Context) { t.clearRecords(ctx) } -func (t testSyncerHandler) UpdateRecords(ctx context.Context, records []*Record) { - t.updateRecords(ctx, records) +func (t testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) { + t.updateRecords(ctx, serverVersion, records) } type testServer struct { @@ -166,7 +166,7 @@ func TestSyncer(t *testing.T) { clearRecords: func(ctx context.Context) { clearCh <- struct{}{} }, - updateRecords: func(ctx context.Context, records []*Record) { + updateRecords: func(ctx context.Context, serverVersion uint64, records []*Record) { updateCh <- records }, })