Skip to content

Commit

Permalink
feat: verify organisation provided in header exists`
Browse files Browse the repository at this point in the history
  • Loading branch information
Muchogoc committed Feb 21, 2023
1 parent 939bfd8 commit b740977
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pkg/clinical/presentation/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func Router(ctx context.Context) (*mux.Router, error) {
//Authenticated routes
gqlR := r.Path("/graphql").Subrouter()
gqlR.Use(authutils.SladeAuthenticationMiddleware(*authClient))
gqlR.Use(rest.TenantIdentifierExtractionMiddleware)
gqlR.Use(rest.TenantIdentifierExtractionMiddleware(usecases))
gqlR.Methods(
http.MethodPost, http.MethodGet, http.MethodOptions,
).HandlerFunc(GQLHandler(ctx, usecases))
Expand Down
73 changes: 50 additions & 23 deletions pkg/clinical/presentation/rest/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,33 @@ import (
"net/http"

"github.com/savannahghi/clinical/pkg/clinical/application/utils"
"github.com/savannahghi/clinical/pkg/clinical/domain"
"github.com/savannahghi/serverutils"
)

// Validators defines the methods used to validate the various identifiers that the api expects
type Validators interface {
FindOrganizationByID(ctx context.Context, organizationID string) (*domain.FHIROrganizationRelayPayload, error)
}

// OrganisationValidator verifies that the provided organisation exists in clinical
// to ensure the request comes from a known/registered organisation
func OrganisationValidator(v Validators, identifier string) error {
_, err := v.FindOrganizationByID(context.Background(), identifier)
if err != nil {
return fmt.Errorf("failed to find provided organisation")
}

return nil
}

// TenantIdentifier is a type representing a header name and a corresponding context key
// The header name is what will be used to extract the specified header and the context key
// Will be the key value used when adding the header in the request context
type TenantIdentifier struct {
HeaderKey string
ContextKey utils.ContextKey
HeaderKey string
ContextKey utils.ContextKey
ValidatorFunc func(v Validators, identifier string) error
}

type errResponse struct {
Expand All @@ -37,27 +55,36 @@ func handleError(w http.ResponseWriter, err error) {
// tasks such as filtering, or database queries
// Note that this middleware assumes that the IDs are included in the request as headers
// and it does not perform any validation or sanitization of the ID values.
func TenantIdentifierExtractionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := []TenantIdentifier{
{
HeaderKey: "OrganizationID",
ContextKey: utils.OrganizationIDContextKey,
},
// TODO: Add more headers here as needed e.g FacilityID, ProgramID
}

for _, header := range headers {
headerValue := r.Header.Get(header.HeaderKey)
if headerValue == "" {
err := fmt.Errorf("expected `%s` header to be included in the request", header.HeaderKey)
handleError(w, err)
return
func TenantIdentifierExtractionMiddleware(validator Validators) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := []TenantIdentifier{
{
HeaderKey: "OrganizationID",
ContextKey: utils.OrganizationIDContextKey,
ValidatorFunc: OrganisationValidator,
},
// TODO: Add more headers here as needed e.g FacilityID, ProgramID
}

ctx := context.WithValue(r.Context(), header.ContextKey, headerValue)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
for _, header := range headers {
headerValue := r.Header.Get(header.HeaderKey)
if headerValue == "" {
err := fmt.Errorf("expected `%s` header to be included in the request", header.HeaderKey)
handleError(w, err)
return
}

err := header.ValidatorFunc(validator, headerValue)
if err != nil {
handleError(w, err)
return
}

ctx := context.WithValue(r.Context(), header.ContextKey, headerValue)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}
}
5 changes: 4 additions & 1 deletion pkg/clinical/presentation/rest/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/savannahghi/clinical/pkg/clinical/application/utils"
"github.com/savannahghi/clinical/pkg/clinical/presentation/rest"
"github.com/savannahghi/clinical/pkg/clinical/usecases/clinical/mock"
)

func TestIDExtractionMiddleware(t *testing.T) {
Expand Down Expand Up @@ -41,7 +42,9 @@ func TestIDExtractionMiddleware(t *testing.T) {
}

res := httptest.NewRecorder()
middleware := rest.TenantIdentifierExtractionMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

usecase := mock.NewFHIRUsecaseMock()
middleware := rest.TenantIdentifierExtractionMiddleware(usecase)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
for expectedKey, expectedValue := range test.expectedContext {
ctxValue := ctx.Value(expectedKey)
Expand Down
20 changes: 12 additions & 8 deletions tests/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/brianvoe/gofakeit"
"github.com/google/uuid"
"github.com/imroc/req"
"github.com/savannahghi/authutils"
"github.com/savannahghi/clinical/pkg/clinical/application/common/helpers"
Expand Down Expand Up @@ -85,12 +84,6 @@ func TestMain(m *testing.M) {
return
}

headers, err = GetGraphQLHeaders(ctx)
if err != nil {
log.Printf("error adding the graphql headers")
return
}

srv, baseURL, serverErr = serverutils.StartTestServer(
ctx,
presentation.PrepareServer,
Expand All @@ -109,6 +102,12 @@ func TestMain(m *testing.M) {

testInteractor = i

headers, err = GetGraphQLHeaders(ctx)
if err != nil {
log.Printf("error adding the graphql headers")
return
}

// run the tests
log.Printf("about to run tests")
code := m.Run()
Expand All @@ -128,11 +127,16 @@ func TestMain(m *testing.M) {
func GetGraphQLHeaders(ctx context.Context) (map[string]string, error) {
accessToken := fmt.Sprintf("Bearer %s", oauthPayload.AccessToken)

orgID, err := testInteractor.GetORCreateOrganization(ctx, testProviderCode)
if err != nil {
return nil, fmt.Errorf("can't get or create test organization : %v", err)
}

return req.Header{
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": accessToken,
"OrganizationID": uuid.New().String(),
"OrganizationID": *orgID,
}, nil
}

Expand Down

0 comments on commit b740977

Please sign in to comment.