From bf75b033df55248f47d75d322d8e0464009b8fe3 Mon Sep 17 00:00:00 2001 From: Casey Marshall Date: Tue, 4 Jan 2022 15:34:56 -0600 Subject: [PATCH] feat: openapi request and response validation middleware This introduces a simple validation middleware which makes use of the recently added openapi3filter.Validator (see https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3filter#example-Validator) --- go.sum | 1 + versionware/export_test.go | 3 + versionware/handler.go | 8 +- versionware/handler_test.go | 32 +-- versionware/validator.go | 125 ++++++++++ versionware/validator_test.go | 444 ++++++++++++++++++++++++++++++++++ 6 files changed, 594 insertions(+), 19 deletions(-) create mode 100644 versionware/export_test.go create mode 100644 versionware/validator.go create mode 100644 versionware/validator_test.go diff --git a/go.sum b/go.sum index 30499082..efb11144 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,7 @@ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= diff --git a/versionware/export_test.go b/versionware/export_test.go new file mode 100644 index 00000000..ad4fe6e1 --- /dev/null +++ b/versionware/export_test.go @@ -0,0 +1,3 @@ +package versionware + +var DefaultValidatorConfig = defaultValidatorConfig diff --git a/versionware/handler.go b/versionware/handler.go index 4c100967..ca15ab6b 100644 --- a/versionware/handler.go +++ b/versionware/handler.go @@ -44,7 +44,7 @@ func NewHandler(vhs ...VersionHandler) *Handler { h := &Handler{ handlers: make([]http.Handler, len(vhs)), versions: make([]vervet.Version, len(vhs)), - errFunc: defaultErrorHandler, + errFunc: DefaultVersionError, } handlerVersions := map[string]http.Handler{} for i := range vhs { @@ -59,8 +59,10 @@ func NewHandler(vhs ...VersionHandler) *Handler { return h } -func defaultErrorHandler(w http.ResponseWriter, r *http.Request, status int, err error) { - http.Error(w, err.Error(), status) +// DefaultVersionError provides a basic implementation of VersionErrorHandler +// that uses http.Error. +func DefaultVersionError(w http.ResponseWriter, r *http.Request, status int, err error) { + http.Error(w, http.StatusText(status), status) } // HandleErrors changes the default error handler to the provided function. It diff --git a/versionware/handler_test.go b/versionware/handler_test.go index 40d50af3..6781b9c3 100644 --- a/versionware/handler_test.go +++ b/versionware/handler_test.go @@ -75,19 +75,16 @@ func TestHandler(t *testing.T) { c.Assert(err, qt.IsNil) }), }}...) - s := httptest.NewServer(h) - c.Cleanup(s.Close) - tests := []struct { requested, resolved string contents string status int }{{ - "2021-08-31", "", "no matching version\n", 404, + "2021-08-31", "", "Not Found\n", 404, }, { - "bad wolf", "", "400 Bad Request", 400, + "", "", "Bad Request\n", 400, }, { - "", "", "missing required query parameter 'version'\n", 400, + "bad wolf", "", "400 Bad Request", 400, }, { "2021-09-16", "2021-09-01", "sept", 200, }, { @@ -100,15 +97,18 @@ func TestHandler(t *testing.T) { "2023-02-05", "2021-11-01", "nov", 200, }} for i, test := range tests { - c.Logf("test#%d: requested %s resolved %s", i, test.requested, test.resolved) - req, err := http.NewRequest("GET", s.URL+"?version="+test.requested, nil) - c.Assert(err, qt.IsNil) - resp, err := s.Client().Do(req) - c.Assert(err, qt.IsNil) - defer resp.Body.Close() - c.Assert(resp.StatusCode, qt.Equals, test.status) - contents, err := ioutil.ReadAll(resp.Body) - c.Assert(err, qt.IsNil) - c.Assert(string(contents), qt.Equals, test.contents) + c.Run(fmt.Sprintf("%d requested %s resolved %s", i, test.requested, test.resolved), func(c *qt.C) { + s := httptest.NewServer(h) + c.Cleanup(s.Close) + req, err := http.NewRequest("GET", s.URL+"?version="+test.requested, nil) + c.Assert(err, qt.IsNil) + resp, err := s.Client().Do(req) + c.Assert(err, qt.IsNil) + defer resp.Body.Close() + c.Assert(resp.StatusCode, qt.Equals, test.status) + contents, err := ioutil.ReadAll(resp.Body) + c.Assert(err, qt.IsNil) + c.Assert(string(contents), qt.Equals, test.contents) + }) } } diff --git a/versionware/validator.go b/versionware/validator.go new file mode 100644 index 00000000..5fc2db29 --- /dev/null +++ b/versionware/validator.go @@ -0,0 +1,125 @@ +package versionware + +import ( + "fmt" + "net/http" + "sort" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" + + "github.com/snyk/vervet" +) + +// Validator provides versioned OpenAPI validation middleware for HTTP requests +// and responses. +type Validator struct { + versions vervet.VersionSlice + validators []*openapi3filter.Validator + errFunc VersionErrorHandler +} + +// ValidatorConfig defines how a new Validator may be configured. +type ValidatorConfig struct { + // ServerURL overrides the server URLs in the given OpenAPI specs to match + // the URL of requests reaching the backend service. If unset, requests + // must match the servers defined in OpenAPI specs. + ServerURL string + + // VersionError is called on any error that occurs when trying to resolve the + // API version. + VersionError VersionErrorHandler + + // Options further configure the request and response validation. See + // https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3filter#ValidatorOption + // for available options. + Options []openapi3filter.ValidatorOption +} + +var defaultValidatorConfig = ValidatorConfig{ + VersionError: DefaultVersionError, + Options: []openapi3filter.ValidatorOption{ + openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, _ error) { + statusText := http.StatusText(http.StatusInternalServerError) + switch code { + case openapi3filter.ErrCodeCannotFindRoute: + statusText = "Not Found" + case openapi3filter.ErrCodeRequestInvalid: + statusText = "Bad Request" + } + http.Error(w, statusText, status) + }), + }, +} + +// NewValidator returns a new validation middleware, which validates versioned +// requests according to the given OpenAPI spec versions. For configuration +// defaults, a nil config may be used. +func NewValidator(config *ValidatorConfig, docs ...*openapi3.T) (*Validator, error) { + if config == nil { + config = &defaultValidatorConfig + } + if config.ServerURL != "" { + for i := range docs { + docs[i].Servers = []*openapi3.Server{{URL: config.ServerURL}} + } + } + v := &Validator{ + versions: make([]vervet.Version, len(docs)), + validators: make([]*openapi3filter.Validator, len(docs)), + errFunc: config.VersionError, + } + validatorVersions := map[string]*openapi3filter.Validator{} + for i := range docs { + if config.ServerURL != "" { + docs[i].Servers = []*openapi3.Server{{URL: config.ServerURL}} + } + versionStr, err := vervet.ExtensionString(docs[i].ExtensionProps, vervet.ExtSnykApiVersion) + if err != nil { + return nil, err + } + version, err := vervet.ParseVersion(versionStr) + if err != nil { + return nil, err + } + v.versions[i] = *version + router, err := gorillamux.NewRouter(docs[i]) + if err != nil { + return nil, err + } + validatorVersions[version.String()] = openapi3filter.NewValidator(router, config.Options...) + } + sort.Sort(v.versions) + for i := range v.versions { + v.validators[i] = validatorVersions[v.versions[i].String()] + } + return v, nil +} + +// Middleware returns an http.Handler which wraps the given handler with +// request and response validation according to the requested API version. +func (v *Validator) Middleware(h http.Handler) http.Handler { + handlers := make([]http.Handler, len(v.validators)) + for i := range v.versions { + handlers[i] = v.validators[i].Middleware(h) + } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + versionParam := req.URL.Query().Get("version") + if versionParam == "" { + v.errFunc(w, req, http.StatusBadRequest, fmt.Errorf("missing required query parameter 'version'")) + return + } + requested, err := vervet.ParseVersion(versionParam) + if err != nil { + v.errFunc(w, req, http.StatusBadRequest, err) + return + } + resolvedIndex, err := v.versions.ResolveIndex(*requested) + if err != nil { + v.errFunc(w, req, http.StatusNotFound, err) + return + } + handlers[resolvedIndex].ServeHTTP(w, req) + }) +} diff --git a/versionware/validator_test.go b/versionware/validator_test.go new file mode 100644 index 00000000..9a86424e --- /dev/null +++ b/versionware/validator_test.go @@ -0,0 +1,444 @@ +package versionware_test + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "regexp" + "testing" + + qt "github.com/frankban/quicktest" + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + + "github.com/snyk/vervet/versionware" +) + +const ( + v20210820 = ` +openapi: 3.0.0 +x-snyk-api-version: 2021-08-20 +info: + title: 'Validator' + version: '0.0.0' +paths: + /test/{id}: + get: + operationId: getTest + description: get a test + parameters: + - in: path + name: id + schema: + type: string + required: true + - in: query + name: version + schema: + type: string + required: true + responses: + '200': + description: 'respond with test resource' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '404': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } +components: + schemas: + TestContents: + type: object + properties: + name: + type: string + expected: + type: number + actual: + type: number + required: [name, expected, actual] + additionalProperties: false + TestResource: + type: object + properties: + id: + type: string + contents: + { $ref: '#/components/schemas/TestContents' } + required: [id, contents] + additionalProperties: false + Error: + type: object + properties: + code: + type: string + message: + type: string + required: [code, message] + additionalProperties: false + responses: + ErrorResponse: + description: 'an error occurred' + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } +` + v20210916 = ` +openapi: 3.0.0 +x-snyk-api-version: 2021-09-16 +info: + title: 'Validator' + version: '0.0.0' +paths: + /test: + post: + operationId: newTest + description: create a new test + parameters: + - in: query + name: version + schema: + type: string + required: true + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/TestContents' } + responses: + '201': + description: 'created test' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } + /test/{id}: + get: + operationId: getTest + description: get a test + parameters: + - in: path + name: id + schema: + type: string + required: true + - in: query + name: version + schema: + type: string + required: true + responses: + '200': + description: 'respond with test resource' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '404': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } +components: + schemas: + TestContents: + type: object + properties: + name: + type: string + expected: + type: number + actual: + type: number + noodles: + type: boolean + required: [name, expected, actual, noodles] + additionalProperties: false + TestResource: + type: object + properties: + id: + type: string + contents: + { $ref: '#/components/schemas/TestContents' } + required: [id, contents] + additionalProperties: false + Error: + type: object + properties: + code: + type: string + message: + type: string + required: [code, message] + additionalProperties: false + responses: + ErrorResponse: + description: 'an error occurred' + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } +` +) + +type validatorTestHandler struct { + contentType string + getBody, postBody string + errBody string + errStatusCode int +} + +const v20210820_Body = `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}}` +const v20210916_Body = `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10, "noodles": true}}` + +func (h validatorTestHandler) withDefaults() validatorTestHandler { + if h.contentType == "" { + h.contentType = "application/json" + } + if h.getBody == "" { + h.getBody = v20210916_Body + } + if h.postBody == "" { + h.postBody = v20210916_Body + } + if h.errBody == "" { + h.errBody = `{"code":"bad","message":"bad things"}` + } + return h +} + +var testUrlRE = regexp.MustCompile(`^/test(/\d+)?$`) + +func (h *validatorTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", h.contentType) + if h.errStatusCode != 0 { + w.WriteHeader(h.errStatusCode) + w.Write([]byte(h.errBody)) + return + } + if !testUrlRE.MatchString(r.URL.Path) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(h.errBody)) + return + } + switch r.Method { + case "GET": + w.WriteHeader(http.StatusOK) + w.Write([]byte(h.getBody)) + case "POST": + w.WriteHeader(http.StatusCreated) + w.Write([]byte(h.postBody)) + default: + http.Error(w, h.errBody, http.StatusMethodNotAllowed) + } +} + +func TestValidator(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + docs := make([]*openapi3.T, 2) + for i, specStr := range []string{v20210820, v20210916} { + doc, err := openapi3.NewLoader().LoadFromData([]byte(specStr)) + c.Assert(err, qt.IsNil) + err = doc.Validate(ctx) + c.Assert(err, qt.IsNil) + docs[i] = doc + } + + type testRequest struct { + method, path, body, contentType string + } + type testResponse struct { + statusCode int + body string + } + tests := []struct { + name string + handler validatorTestHandler + options []openapi3filter.ValidatorOption + request testRequest + response testResponse + strict bool + }{{ + name: "valid GET", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=2021-09-17", + }, + response: testResponse{ + 200, v20210916_Body, + }, + strict: true, + }, { + name: "valid POST", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=2021-09-17", + body: `{"name": "foo", "expected": 9, "actual": 10, "noodles": true}`, + contentType: "application/json", + }, + response: testResponse{ + 201, v20210916_Body, + }, + strict: true, + }, { + name: "not found; no GET operation for /test", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test?version=2021-09-17", + }, + response: testResponse{ + 404, "Not Found\n", + }, + strict: true, + }, { + name: "not found; no POST operation for /test/42", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test/42?version=2021-09-17", + }, + response: testResponse{ + 404, "Not Found\n", + }, + strict: true, + }, { + name: "invalid request; missing version", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42", + }, + response: testResponse{ + 400, "Bad Request\n", + }, + strict: true, + }, { + name: "invalid POST request; wrong property type", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=2021-09-17", + body: `{"name": "foo", "expected": "nine", "actual": "ten", "noodles": false}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "Bad Request\n", + }, + strict: true, + }, { + name: "invalid POST request; missing property", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=2021-09-17", + body: `{"name": "foo", "expected": 9}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "Bad Request\n", + }, + strict: true, + }, { + name: "invalid POST request; extra property", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=2021-09-17", + body: `{"name": "foo", "expected": 9, "actual": 10, "noodles": false, "ideal": 8}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "Bad Request\n", + }, + strict: true, + }, { + name: "valid response; 404 error", + handler: validatorTestHandler{ + contentType: "application/json", + errBody: `{"code": "404", "message": "not found"}`, + errStatusCode: 404, + }.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=2021-09-17", + }, + response: testResponse{ + 404, `{"code": "404", "message": "not found"}`, + }, + strict: true, + }, { + name: "invalid response; invalid error", + handler: validatorTestHandler{ + errBody: `"not found"`, + errStatusCode: 404, + }.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=2021-09-17", + }, + response: testResponse{ + 500, "Internal Server Error\n", + }, + strict: true, + }, { + name: "invalid POST response; not strict", + handler: validatorTestHandler{ + postBody: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=2021-09-17", + body: `{"name": "foo", "expected": 9, "actual": 10, "noodles": true}`, + contentType: "application/json", + }, + response: testResponse{ + statusCode: 201, + body: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }, + strict: false, + }} + for i, test := range tests { + c.Run(fmt.Sprintf("%d %s", i, test.name), func(c *qt.C) { + // Set up a test HTTP server + var h http.Handler + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + })) + defer s.Close() + + config := versionware.DefaultValidatorConfig + config.ServerURL = s.URL + config.Options = append(config.Options, append(test.options, openapi3filter.Strict(test.strict))...) + v, err := versionware.NewValidator(&config, docs...) + c.Assert(err, qt.IsNil) + h = v.Middleware(&test.handler) + + // Test: make a client request + var requestBody io.Reader + if test.request.body != "" { + requestBody = bytes.NewBufferString(test.request.body) + } + req, err := http.NewRequest(test.request.method, s.URL+test.request.path, requestBody) + c.Assert(err, qt.IsNil) + + if test.request.contentType != "" { + req.Header.Set("Content-Type", test.request.contentType) + } + resp, err := s.Client().Do(req) + c.Assert(err, qt.IsNil) + defer resp.Body.Close() + c.Assert(test.response.statusCode, qt.Equals, resp.StatusCode) + + body, err := ioutil.ReadAll(resp.Body) + c.Assert(err, qt.IsNil) + c.Assert(test.response.body, qt.Equals, string(body)) + }) + } +}