diff --git a/pkg/clinical/presentation/config.go b/pkg/clinical/presentation/config.go index 816f69d5..d2d848e9 100644 --- a/pkg/clinical/presentation/config.go +++ b/pkg/clinical/presentation/config.go @@ -172,7 +172,7 @@ func Router(ctx context.Context) (*mux.Router, error) { //Authenticated routes gqlR := r.Path("/graphql").Subrouter() gqlR.Use(authutils.SladeAuthenticationMiddleware(*authClient)) - gqlR.Use(rest.TenantIdentifierExtractionMiddleware) + gqlR.Use(rest.TenantIdentifierExtractionMiddleware(usecases)) gqlR.Methods( http.MethodPost, http.MethodGet, http.MethodOptions, ).HandlerFunc(GQLHandler(ctx, usecases)) diff --git a/pkg/clinical/presentation/rest/middleware.go b/pkg/clinical/presentation/rest/middleware.go index df46cbcd..11d7bce9 100644 --- a/pkg/clinical/presentation/rest/middleware.go +++ b/pkg/clinical/presentation/rest/middleware.go @@ -6,15 +6,33 @@ import ( "net/http" "github.com/savannahghi/clinical/pkg/clinical/application/utils" + "github.com/savannahghi/clinical/pkg/clinical/domain" "github.com/savannahghi/serverutils" ) +// Validators defines the methods used to validate the various identifiers that the api expects +type Validators interface { + FindOrganizationByID(ctx context.Context, organizationID string) (*domain.FHIROrganizationRelayPayload, error) +} + +// OrganisationValidator verifies that the provided organisation exists in clinical +// to ensure the request comes from a known/registered organisation +func OrganisationValidator(v Validators, identifier string) error { + _, err := v.FindOrganizationByID(context.Background(), identifier) + if err != nil { + return fmt.Errorf("failed to find provided organisation") + } + + return nil +} + // TenantIdentifier is a type representing a header name and a corresponding context key // The header name is what will be used to extract the specified header and the context key // Will be the key value used when adding the header in the request context type TenantIdentifier struct { - HeaderKey string - ContextKey utils.ContextKey + HeaderKey string + ContextKey utils.ContextKey + ValidatorFunc func(v Validators, identifier string) error } type errResponse struct { @@ -37,27 +55,36 @@ func handleError(w http.ResponseWriter, err error) { // tasks such as filtering, or database queries // Note that this middleware assumes that the IDs are included in the request as headers // and it does not perform any validation or sanitization of the ID values. -func TenantIdentifierExtractionMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - headers := []TenantIdentifier{ - { - HeaderKey: "OrganizationID", - ContextKey: utils.OrganizationIDContextKey, - }, - // TODO: Add more headers here as needed e.g FacilityID, ProgramID - } - - for _, header := range headers { - headerValue := r.Header.Get(header.HeaderKey) - if headerValue == "" { - err := fmt.Errorf("expected `%s` header to be included in the request", header.HeaderKey) - handleError(w, err) - return +func TenantIdentifierExtractionMiddleware(validator Validators) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers := []TenantIdentifier{ + { + HeaderKey: "OrganizationID", + ContextKey: utils.OrganizationIDContextKey, + ValidatorFunc: OrganisationValidator, + }, + // TODO: Add more headers here as needed e.g FacilityID, ProgramID } - ctx := context.WithValue(r.Context(), header.ContextKey, headerValue) - r = r.WithContext(ctx) - } - next.ServeHTTP(w, r) - }) + for _, header := range headers { + headerValue := r.Header.Get(header.HeaderKey) + if headerValue == "" { + err := fmt.Errorf("expected `%s` header to be included in the request", header.HeaderKey) + handleError(w, err) + return + } + + err := header.ValidatorFunc(validator, headerValue) + if err != nil { + handleError(w, err) + return + } + + ctx := context.WithValue(r.Context(), header.ContextKey, headerValue) + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) + }) + } } diff --git a/pkg/clinical/presentation/rest/middleware_test.go b/pkg/clinical/presentation/rest/middleware_test.go index 63f7e08d..db3edb2c 100644 --- a/pkg/clinical/presentation/rest/middleware_test.go +++ b/pkg/clinical/presentation/rest/middleware_test.go @@ -7,6 +7,7 @@ import ( "github.com/savannahghi/clinical/pkg/clinical/application/utils" "github.com/savannahghi/clinical/pkg/clinical/presentation/rest" + "github.com/savannahghi/clinical/pkg/clinical/usecases/clinical/mock" ) func TestIDExtractionMiddleware(t *testing.T) { @@ -41,7 +42,9 @@ func TestIDExtractionMiddleware(t *testing.T) { } res := httptest.NewRecorder() - middleware := rest.TenantIdentifierExtractionMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + usecase := mock.NewFHIRUsecaseMock() + middleware := rest.TenantIdentifierExtractionMiddleware(usecase)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() for expectedKey, expectedValue := range test.expectedContext { ctxValue := ctx.Value(expectedKey) diff --git a/tests/config_test.go b/tests/config_test.go index 7ba2fff7..78bfab00 100644 --- a/tests/config_test.go +++ b/tests/config_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/brianvoe/gofakeit" - "github.com/google/uuid" "github.com/imroc/req" "github.com/savannahghi/authutils" "github.com/savannahghi/clinical/pkg/clinical/application/common/helpers" @@ -85,12 +84,6 @@ func TestMain(m *testing.M) { return } - headers, err = GetGraphQLHeaders(ctx) - if err != nil { - log.Printf("error adding the graphql headers") - return - } - srv, baseURL, serverErr = serverutils.StartTestServer( ctx, presentation.PrepareServer, @@ -109,6 +102,12 @@ func TestMain(m *testing.M) { testInteractor = i + headers, err = GetGraphQLHeaders(ctx) + if err != nil { + log.Printf("error adding the graphql headers") + return + } + // run the tests log.Printf("about to run tests") code := m.Run() @@ -128,11 +127,16 @@ func TestMain(m *testing.M) { func GetGraphQLHeaders(ctx context.Context) (map[string]string, error) { accessToken := fmt.Sprintf("Bearer %s", oauthPayload.AccessToken) + orgID, err := testInteractor.GetORCreateOrganization(ctx, testProviderCode) + if err != nil { + return nil, fmt.Errorf("can't get or create test organization : %v", err) + } + return req.Header{ "Accept": "application/json", "Content-Type": "application/json", "Authorization": accessToken, - "OrganizationID": uuid.New().String(), + "OrganizationID": *orgID, }, nil }