Skip to content

Commit

Permalink
internal: Clean up request validation API
Browse files Browse the repository at this point in the history
This simplifies the request validation APIs slightly. The intention is to
expose `ValidateUnary` and `ValidateOneway` as public-facing APIs in a
different diff so that transport authors can perform validation easily.
  • Loading branch information
abhinav committed Dec 16, 2016
1 parent e52251c commit 689d2e0
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 66 deletions.
63 changes: 31 additions & 32 deletions internal/request/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,30 @@ import (
// Validator helps validate requests.
//
// v := Validator{Request: request}
// v.Validate()
// v.ValidateCommon(ctx)
// ...
// v.ParseTTL(ttlstring)
// request, err := v.ValidateUnary()
// request, err := v.ValidateUnary(ctx)
type Validator struct {
Request *transport.Request
errTTL error
}

// Validate is a shortcut for the case where a request needs to be validated
// without changing the TTL. This should be used to validate all request types
func Validate(ctx context.Context, req *transport.Request) (*transport.Request, error) {
v := Validator{Request: req}
return v.Validate(ctx)
}

// ValidateUnary validates a unary request. This should be used after a successful Validate()
func ValidateUnary(ctx context.Context, req *transport.Request) (*transport.Request, error) {
// ValidateUnary validates a unary request.
func ValidateUnary(ctx context.Context, req *transport.Request) error {
v := Validator{Request: req}
if err := v.ValidateCommon(ctx); err != nil {
return err
}
return v.ValidateUnary(ctx)
}

// ValidateOneway validates a oneway request. This should be used after a successful Validate()
func ValidateOneway(ctx context.Context, req *transport.Request) (*transport.Request, error) {
// ValidateOneway validates a oneway request.
func ValidateOneway(ctx context.Context, req *transport.Request) error {
v := Validator{Request: req}
if err := v.ValidateCommon(ctx); err != nil {
return err
}
return v.ValidateOneway(ctx)
}

Expand All @@ -65,8 +64,8 @@ func ValidateOneway(ctx context.Context, req *transport.Request) (*transport.Req
// parse and validate that TTL. Should only be used for unary requests
func (v *Validator) ParseTTL(ctx context.Context, ttl string) (context.Context, func()) {
if ttl == "" {
// The TTL is missing so set it to 0 and let Validate() fail with the
// correct error message.
// The TTL is missing so set it to 0 and let ValidateUnary() fail with
// the correct error message.
return ctx, func() {}
}

Expand All @@ -92,10 +91,10 @@ func (v *Validator) ParseTTL(ctx context.Context, ttl string) (context.Context,
return context.WithTimeout(ctx, time.Duration(ttlms)*time.Millisecond)
}

// Validate checks that the request inside this validator is valid and returns
// either the validated request or an error. This should be used to check all
// requests, prior to the RPC type specifc validation.
func (v *Validator) Validate(ctx context.Context) (*transport.Request, error) {
// ValidateCommon checks validity of the common attributes of the request.
// This should be used to check ALL requests prior to calling
// RPC-type-specific validators.
func (v *Validator) ValidateCommon(ctx context.Context) error {
// check missing params
var missingParams []string
if v.Request.Service == "" {
Expand All @@ -111,29 +110,29 @@ func (v *Validator) Validate(ctx context.Context) (*transport.Request, error) {
missingParams = append(missingParams, "encoding")
}
if len(missingParams) > 0 {
return nil, missingParametersError{Parameters: missingParams}
return missingParametersError{Parameters: missingParams}
}

return v.Request, nil
return nil
}

// ValidateUnary validates a unary request. This should be used after a successful v.Validate()
func (v *Validator) ValidateUnary(ctx context.Context) (*transport.Request, error) {
// ValidateUnary validates a unary request. This should be used after a
// successful v.ValidateCommon()
func (v *Validator) ValidateUnary(ctx context.Context) error {
if v.errTTL != nil {
return nil, v.errTTL
return v.errTTL
}

_, hasDeadline := ctx.Deadline()

if !hasDeadline {
return nil, missingParametersError{Parameters: []string{"TTL"}}
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
return missingParametersError{Parameters: []string{"TTL"}}
}

return v.Request, nil
return nil
}

// ValidateOneway validates a oneway request. This should be used after a successful Validate()
func (v *Validator) ValidateOneway(ctx context.Context) (*transport.Request, error) {
// ValidateOneway validates a oneway request. This should be used after a
// successful ValidateCommon()
func (v *Validator) ValidateOneway(ctx context.Context) error {
// Currently, no extra checks for oneway requests are required
return v.Request, nil
return nil
}
16 changes: 2 additions & 14 deletions internal/request/validator_outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ type OnewayValidatorOutbound struct{ transport.OnewayOutbound }

// Call performs the given request, failing early if the request is invalid.
func (o UnaryValidatorOutbound) Call(ctx context.Context, request *transport.Request) (*transport.Response, error) {
request, err := Validate(ctx, request)
if err != nil {
return nil, err
}

request, err = ValidateUnary(ctx, request)
if err != nil {
if err := ValidateUnary(ctx, request); err != nil {
return nil, err
}

Expand All @@ -49,13 +43,7 @@ func (o UnaryValidatorOutbound) Call(ctx context.Context, request *transport.Req

// CallOneway performs the given request, failing early if the request is invalid.
func (o OnewayValidatorOutbound) CallOneway(ctx context.Context, request *transport.Request) (transport.Ack, error) {
request, err := Validate(ctx, request)
if err != nil {
return nil, err
}

request, err = ValidateOneway(ctx, request)
if err != nil {
if err := ValidateOneway(ctx, request); err != nil {
return nil, err
}

Expand Down
6 changes: 3 additions & 3 deletions internal/request/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ func TestValidator(t *testing.T) {
v := Validator{Request: tt.req}

ctx := context.Background()
_, err := v.Validate(ctx)
err := v.ValidateCommon(ctx)

if err == nil && tt.transportType == transport.Oneway {
_, err = v.ValidateOneway(ctx)
err = v.ValidateOneway(ctx)
} else if err == nil { // default to unary
var cancel func()

Expand All @@ -173,7 +173,7 @@ func TestValidator(t *testing.T) {
defer cancel()
}

_, err = v.ValidateUnary(ctx)
err = v.ValidateUnary(ctx)
}

if tt.wantErr != nil {
Expand Down
11 changes: 3 additions & 8 deletions transport/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (h handler) callHandler(w http.ResponseWriter, req *http.Request, start tim

ctx, span := h.createSpan(ctx, req, treq, start)

treq, err := v.Validate(ctx)
err := v.ValidateCommon(ctx)
if err != nil {
return err
}
Expand All @@ -106,18 +106,13 @@ func (h handler) callHandler(w http.ResponseWriter, req *http.Request, start tim
case transport.Unary:
defer span.Finish()

ctx, cancel := v.ParseTTL(ctx, popHeader(req.Header, TTLMSHeader))
defer cancel()

treq, err = v.ValidateUnary(ctx)
if err != nil {
if err := v.ValidateUnary(ctx); err != nil {
return err
}
err = transport.DispatchUnaryHandler(ctx, spec.Unary(), start, treq, newResponseWriter(w))

case transport.Oneway:
treq, err = v.ValidateOneway(ctx)
if err != nil {
if err := v.ValidateOneway(ctx); err != nil {
return err
}
err = handleOnewayRequest(span, treq, spec.Oneway())
Expand Down
10 changes: 5 additions & 5 deletions transport/tchannel/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ func (h handler) callHandler(ctx context.Context, call inboundCall, start time.T
rw := newResponseWriter(treq, call)
defer rw.Close() // TODO(abg): log if this errors

treq, err = request.Validate(ctx, treq)
if err != nil {
v := request.Validator{Request: treq}
if err := v.ValidateCommon(ctx); err != nil {
return err
}

Expand All @@ -159,10 +159,10 @@ func (h handler) callHandler(ctx context.Context, call inboundCall, start time.T

switch spec.Type() {
case transport.Unary:
treq, err = request.ValidateUnary(ctx, treq)
if err == nil {
err = transport.DispatchUnaryHandler(ctx, spec.Unary(), start, treq, rw)
if err := v.ValidateUnary(ctx); err != nil {
return err
}
err = transport.DispatchUnaryHandler(ctx, spec.Unary(), start, treq, rw)

default:
err = errors.UnsupportedTypeError{Transport: "TChannel", Type: string(spec.Type())}
Expand Down
6 changes: 2 additions & 4 deletions transport/x/redis/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ func (i *Inbound) handle() error {
defer span.Finish()

v := request.Validator{Request: req}
req, err = v.Validate(ctx)
if err != nil {
if err := v.ValidateCommon(ctx); err != nil {
return transport.UpdateSpanWithErr(span, err)
}

Expand All @@ -166,8 +165,7 @@ func (i *Inbound) handle() error {
return transport.UpdateSpanWithErr(span, err)
}

req, err = v.ValidateOneway(ctx)
if err != nil {
if err := v.ValidateOneway(ctx); err != nil {
return transport.UpdateSpanWithErr(span, err)
}

Expand Down

0 comments on commit 689d2e0

Please sign in to comment.