Skip to content

Commit

Permalink
Merge fa00b7c into 689d2e0
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinav committed Dec 16, 2016
2 parents 689d2e0 + fa00b7c commit 8cb4a01
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 107 deletions.
19 changes: 0 additions & 19 deletions internal/request/errors.go
Expand Up @@ -57,22 +57,3 @@ func (e missingParametersError) Error() string {
s += fmt.Sprintf(", and %s", ps[len(ps)-1])
return s
}

// invalidTTLError is a failure to process a request because the TTL was in an
// invalid format.
type invalidTTLError struct {
Service string
Procedure string
TTL string
}

func (e invalidTTLError) AsHandlerError() errors.HandlerError {
return errors.HandlerBadRequestError(e)
}

func (e invalidTTLError) Error() string {
return fmt.Sprintf(
`invalid TTL %q for procedure %q of service %q: must be positive integer`,
e.TTL, e.Procedure, e.Service,
)
}
43 changes: 1 addition & 42 deletions internal/request/validator.go
Expand Up @@ -22,9 +22,6 @@ package request

import (
"context"
"fmt"
"strconv"
"time"

"go.uber.org/yarpc/api/transport"
)
Expand All @@ -34,11 +31,9 @@ import (
// v := Validator{Request: request}
// v.ValidateCommon(ctx)
// ...
// v.ParseTTL(ttlstring)
// request, err := v.ValidateUnary(ctx)
// err := v.ValidateUnary(ctx)
type Validator struct {
Request *transport.Request
errTTL error
}

// ValidateUnary validates a unary request.
Expand All @@ -59,38 +54,6 @@ func ValidateOneway(ctx context.Context, req *transport.Request) error {
return v.ValidateOneway(ctx)
}

// ParseTTL takes a context parses the given TTL, clamping the context to that TTL
// and as a side-effect, tracking any errors encountered while attempting to
// parse and validate that TTL. Should only be used for unary requests
func (v *Validator) ParseTTL(ctx context.Context, ttl string) (context.Context, func()) {
if ttl == "" {
// The TTL is missing so set it to 0 and let ValidateUnary() fail with
// the correct error message.
return ctx, func() {}
}

ttlms, err := strconv.Atoi(ttl)
if err != nil {
v.errTTL = invalidTTLError{
Service: v.Request.Service,
Procedure: v.Request.Procedure,
TTL: ttl,
}
return ctx, func() {}
}
// negative TTLs are invalid
if ttlms < 0 {
v.errTTL = invalidTTLError{
Service: v.Request.Service,
Procedure: v.Request.Procedure,
TTL: fmt.Sprint(ttlms),
}
return ctx, func() {}
}

return context.WithTimeout(ctx, time.Duration(ttlms)*time.Millisecond)
}

// ValidateCommon checks validity of the common attributes of the request.
// This should be used to check ALL requests prior to calling
// RPC-type-specific validators.
Expand Down Expand Up @@ -119,10 +82,6 @@ func (v *Validator) ValidateCommon(ctx context.Context) error {
// ValidateUnary validates a unary request. This should be used after a
// successful v.ValidateCommon()
func (v *Validator) ValidateUnary(ctx context.Context) error {
if v.errTTL != nil {
return v.errTTL
}

if _, hasDeadline := ctx.Deadline(); !hasDeadline {
return missingParametersError{Parameters: []string{"TTL"}}
}
Expand Down
41 changes: 1 addition & 40 deletions internal/request/validator_test.go
Expand Up @@ -34,9 +34,7 @@ func TestValidator(t *testing.T) {
tests := []struct {
req *transport.Request
transportType transport.Type

ttl time.Duration
ttlString string // set to try parseTTL
ttl time.Duration

wantErr error
wantMessage string
Expand Down Expand Up @@ -118,38 +116,6 @@ func TestValidator(t *testing.T) {
},
wantMessage: "missing service name, procedure, caller name, and encoding",
},
{
req: &transport.Request{
Caller: "caller",
Service: "service",
Encoding: "raw",
Procedure: "hello",
},
transportType: transport.Unary,
ttlString: "-1000",
wantErr: invalidTTLError{
Service: "service",
Procedure: "hello",
TTL: "-1000",
},
wantMessage: `invalid TTL "-1000" for procedure "hello" of service "service": must be positive integer`,
},
{
req: &transport.Request{
Caller: "caller",
Service: "service",
Encoding: "raw",
Procedure: "hello",
},
transportType: transport.Unary,
ttlString: "not an integer",
wantErr: invalidTTLError{
Service: "service",
Procedure: "hello",
TTL: "not an integer",
},
wantMessage: `invalid TTL "not an integer" for procedure "hello" of service "service": must be positive integer`,
},
}

for _, tt := range tests {
Expand All @@ -168,11 +134,6 @@ func TestValidator(t *testing.T) {
defer cancel()
}

if tt.ttlString != "" {
ctx, cancel = v.ParseTTL(ctx, tt.ttlString)
defer cancel()
}

err = v.ValidateUnary(ctx)
}

Expand Down
12 changes: 7 additions & 5 deletions transport/http/handler.go
Expand Up @@ -85,15 +85,14 @@ func (h handler) callHandler(w http.ResponseWriter, req *http.Request, start tim
}

ctx := req.Context()

v := request.Validator{Request: treq}
ctx, cancel := v.ParseTTL(ctx, popHeader(req.Header, TTLMSHeader))
ctx, cancel, parseTTLErr := parseTTL(ctx, treq, popHeader(req.Header, TTLMSHeader))
// parseTTLErr != nil is a problem only if the request is unary.
defer cancel()

ctx, span := h.createSpan(ctx, req, treq, start)

err := v.ValidateCommon(ctx)
if err != nil {
v := request.Validator{Request: treq}
if err := v.ValidateCommon(ctx); err != nil {
return err
}

Expand All @@ -105,6 +104,9 @@ func (h handler) callHandler(w http.ResponseWriter, req *http.Request, start tim
switch spec.Type() {
case transport.Unary:
defer span.Finish()
if parseTTLErr != nil {
return parseTTLErr
}

if err := v.ValidateUnary(ctx); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion transport/http/handler_test.go
Expand Up @@ -265,7 +265,7 @@ func TestHandlerFailures(t *testing.T) {

code := rw.Code
assert.True(t, code >= 400 && code < 500, "expected 400 level code")
assert.Equal(t, rw.Body.String(), tt.msg)
assert.Equal(t, tt.msg, rw.Body.String())
}
}

Expand Down
82 changes: 82 additions & 0 deletions transport/http/ttl.go
@@ -0,0 +1,82 @@
// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package http

import (
"context"
"fmt"
"strconv"
"time"

"go.uber.org/yarpc/api/transport"
"go.uber.org/yarpc/internal/errors"
)

// parseTTL takes a context parses the given TTL, clamping the context to that
// TTL and as a side-effect, tracking any errors encountered while attempting
// to parse and validate that TTL.
//
// Leaves the context unchanged if the TTL is empty.
func parseTTL(ctx context.Context, req *transport.Request, ttl string) (_ context.Context, cancel func(), _ error) {
if ttl == "" {
return ctx, func() {}, nil
}

ttlms, err := strconv.Atoi(ttl)
if err != nil {
return ctx, func() {}, invalidTTLError{
Service: req.Service,
Procedure: req.Procedure,
TTL: ttl,
}
}

// negative TTLs are invalid
if ttlms < 0 {
return ctx, func() {}, invalidTTLError{
Service: req.Service,
Procedure: req.Procedure,
TTL: fmt.Sprint(ttlms),
}
}

ctx, cancel = context.WithTimeout(ctx, time.Duration(ttlms)*time.Millisecond)
return ctx, cancel, nil
}

// invalidTTLError is a failure to process a request because the TTL was in an
// invalid format.
type invalidTTLError struct {
Service string
Procedure string
TTL string
}

func (e invalidTTLError) AsHandlerError() errors.HandlerError {
return errors.HandlerBadRequestError(e)
}

func (e invalidTTLError) Error() string {
return fmt.Sprintf(
`invalid TTL %q for procedure %q of service %q: must be positive integer`,
e.TTL, e.Procedure, e.Service,
)
}
58 changes: 58 additions & 0 deletions transport/http/ttl_test.go
@@ -0,0 +1,58 @@
package http

import (
"context"
"testing"

"go.uber.org/yarpc/api/transport"

"github.com/stretchr/testify/assert"
)

func TestParseTTL(t *testing.T) {
req := &transport.Request{
Caller: "caller",
Service: "service",
Procedure: "hello",
Encoding: "raw",
}

tests := []struct {
ttlString string
wantErr error
wantMessage string
}{
{
ttlString: "-1000",
wantErr: invalidTTLError{
Service: "service",
Procedure: "hello",
TTL: "-1000",
},
wantMessage: `invalid TTL "-1000" for procedure "hello" of service "service": must be positive integer`,
},
{
ttlString: "not an integer",
wantErr: invalidTTLError{
Service: "service",
Procedure: "hello",
TTL: "not an integer",
},
wantMessage: `invalid TTL "not an integer" for procedure "hello" of service "service": must be positive integer`,
},
}

for _, tt := range tests {
ctx, cancel, err := parseTTL(context.Background(), req, tt.ttlString)
defer cancel()

if tt.wantErr != nil && assert.Error(t, err) {
assert.Equal(t, tt.wantErr, err)
assert.Equal(t, tt.wantMessage, err.Error())
} else {
assert.NoError(t, err)
_, ok := ctx.Deadline()
assert.True(t, ok)
}
}
}

0 comments on commit 8cb4a01

Please sign in to comment.