Skip to content

Commit

Permalink
Merge 8621ed6 into 913abee
Browse files Browse the repository at this point in the history
  • Loading branch information
prashantv committed Jun 8, 2017
2 parents 913abee + 8621ed6 commit f493ff9
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 17 deletions.
4 changes: 3 additions & 1 deletion bench_method_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"time"

"github.com/yarpc/yab/encoding"
"github.com/yarpc/yab/merge"
"github.com/yarpc/yab/transport"

"github.com/opentracing/opentracing-go"
Expand All @@ -45,10 +46,11 @@ func benchmarkMethodForTest(t *testing.T, procedure string, p transport.Protocol
serializer, err := NewSerializer(rOpts)
require.NoError(t, err, "Failed to create Thrift serializer")

serializer = withTransportSerializer(p, serializer, rOpts)
tHeaders, serializer := withTransportSerializer(p, serializer, rOpts)

req, err := serializer.Request(nil)
require.NoError(t, err, "Failed to serialize Thrift body")
req.TransportHeaders = merge.Headers(req.TransportHeaders, tHeaders)

req.Timeout = time.Second
return benchmarkMethod{serializer, req}
Expand Down
39 changes: 39 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/yarpc/yab/testdata/gen-go/integration"
yintegration "github.com/yarpc/yab/testdata/yarpc/integration"
"github.com/yarpc/yab/testdata/yarpc/integration/fooserver"
"github.com/yarpc/yab/transport"

athrift "github.com/apache/thrift/lib/go/thrift"
"github.com/opentracing/opentracing-go"
Expand Down Expand Up @@ -263,6 +264,44 @@ func TestIntegrationProtocols(t *testing.T) {
}
}

func TestIntegrationThriftEnvelope(t *testing.T) {
tests := []struct {
disableEnvelope bool
wantEnvelopeHeader string
}{
{true, "false"},
{false, "true"},
}

for _, tt := range tests {

var called bool
server := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
called = true
envelope := r.Header.Get(transport.HTTPThriftEnvelopeheader)
assert.Equal(t, tt.wantEnvelopeHeader, envelope)
}))
defer server.Close()

opts := Options{
ROpts: RequestOptions{
ThriftFile: "testdata/integration.thrift",
Procedure: "Foo::bar",
Timeout: timeMillisFlag(time.Second),
ThriftDisableEnvelopes: tt.disableEnvelope,
},
TOpts: TransportOptions{
ServiceName: "foo",
Peers: []string{server.URL},
Jaeger: true,
},
}

runTestWithOpts(opts)
assert.True(t, called, "Server did not receive any call")
}
}

// runTestWithOpts runs with the given options and returns the
// output buffer, as well as the error buffer.
func runTestWithOpts(opts Options) (string, string) {
Expand Down
26 changes: 19 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ import (
"log"
"os"
"regexp"
"strconv"
"strings"
"time"

"github.com/yarpc/yab/encoding"
"github.com/yarpc/yab/merge"
"github.com/yarpc/yab/peerprovider"
"github.com/yarpc/yab/transport"

Expand Down Expand Up @@ -309,7 +311,7 @@ func runWithOptions(opts Options, out output) {
out.Fatalf("Failed while parsing options: %v\n", err)
}

serializer = withTransportSerializer(transport.Protocol(), serializer, opts.ROpts)
tHeaders, serializer := withTransportSerializer(transport.Protocol(), serializer, opts.ROpts)

// req is the transport.Request that will be used to make a call.
req, err := serializer.Request(reqInput)
Expand All @@ -318,7 +320,7 @@ func runWithOptions(opts Options, out output) {
}

req.Headers = headers
req.TransportHeaders = opts.TOpts.TransportHeaders
req.TransportHeaders = merge.Headers(tHeaders, opts.TOpts.TransportHeaders)
req.Timeout = opts.ROpts.Timeout.Duration()
if req.Timeout == 0 {
req.Timeout = time.Second
Expand Down Expand Up @@ -355,13 +357,23 @@ func getTracer(opts Options, out output) (opentracing.Tracer, io.Closer) {

// withTransportSerializer may modify the serializer for the transport used.
// E.g. Thrift payloads are not enveloped when used with TChannel.
func withTransportSerializer(p transport.Protocol, s encoding.Serializer, rOpts RequestOptions) encoding.Serializer {
switch {
case p == transport.TChannel && s.Encoding() == encoding.Thrift,
rOpts.ThriftDisableEnvelopes:
// It also returns any additional transport headers for this protocol/encoding.
func withTransportSerializer(p transport.Protocol, s encoding.Serializer, rOpts RequestOptions) (tHeaders map[string]string, _ encoding.Serializer) {
if s.Encoding() != encoding.Thrift {
return
}

disableEnvelope := p == transport.TChannel || rOpts.ThriftDisableEnvelopes
if disableEnvelope {
s = s.(noEnveloper).WithoutEnvelopes()
}
return s

if p == transport.HTTP {
tHeaders = map[string]string{
transport.HTTPThriftEnvelopeheader: strconv.FormatBool(!disableEnvelope),
}
}
return tHeaders, s
}

// makeRequest makes a request using the given transport.
Expand Down
16 changes: 12 additions & 4 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,10 @@ func TestWithTransportSerializer(t *testing.T) {
noEnvelopeOpts.ThriftDisableEnvelopes = true

tests := []struct {
protocol transport.Protocol
rOpts RequestOptions
want []byte
protocol transport.Protocol
rOpts RequestOptions
want []byte
wantTransportHeaders map[string]string
}{
{
protocol: transport.HTTP,
Expand All @@ -786,11 +787,17 @@ func TestWithTransportSerializer(t *testing.T) {
Type: wire.Call,
Value: wire.NewValueStruct(wire.Struct{}),
}),
wantTransportHeaders: map[string]string{
"RPC-Thrift-Envelope": "true",
},
},
{
protocol: transport.HTTP,
rOpts: noEnvelopeOpts,
want: []byte{0},
wantTransportHeaders: map[string]string{
"RPC-Thrift-Envelope": "false",
},
},
{
protocol: transport.TChannel,
Expand All @@ -808,13 +815,14 @@ func TestWithTransportSerializer(t *testing.T) {
serializer, err := NewSerializer(tt.rOpts)
require.NoError(t, err, "Failed to create serializer for %+v", tt.rOpts)

serializer = withTransportSerializer(tt.protocol, serializer, tt.rOpts)
tHeaders, serializer := withTransportSerializer(tt.protocol, serializer, tt.rOpts)
req, err := serializer.Request(nil)
if !assert.NoError(t, err, "Failed to serialize request for %+v", tt.rOpts) {
continue
}

assert.Equal(t, tt.want, req.Body, "Body mismatch for %+v", tt.rOpts)
assert.Equal(t, tt.wantTransportHeaders, tHeaders, "Transport headers mismatch for %+v", tt.rOpts)
}
}

Expand Down
21 changes: 21 additions & 0 deletions merge/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package merge

// Headers merges the set of headers, preferring values in right
// over left if a key exists in both maps.
func Headers(left, right map[string]string) map[string]string {
if len(left) == 0 {
return right
}
if len(right) == 0 {
return left
}

merged := make(map[string]string, len(left)+len(right))
for k, v := range left {
merged[k] = v
}
for k, v := range right {
merged[k] = v
}
return merged
}
41 changes: 41 additions & 0 deletions merge/headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package merge

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestHeaders(t *testing.T) {
tests := []struct {
left map[string]string
right map[string]string
want map[string]string
}{
{
left: nil,
right: nil,
want: nil,
},
{
left: nil,
right: map[string]string{},
want: map[string]string{},
},
{
left: map[string]string{"a": "1"},
right: nil,
want: map[string]string{"a": "1"},
},
{
left: map[string]string{"a": "1", "b": "1"},
right: map[string]string{"a": "2", "c": "2"},
want: map[string]string{"a": "2", "b": "1", "c": "2"},
},
}

for _, tt := range tests {
got := Headers(tt.left, tt.right)
assert.Equal(t, tt.want, got)
}
}
6 changes: 3 additions & 3 deletions template.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ func readYAMLRequest(base string, contents []byte, templateArgs map[string]strin

// Baggage and headers specified with command line flags override those
// specified in YAML templates.
opts.ROpts.Headers = merge(opts.ROpts.Headers, t.Headers)
opts.ROpts.Baggage = merge(opts.ROpts.Baggage, t.Baggage)
opts.ROpts.Headers = mergeInto(opts.ROpts.Headers, t.Headers)
opts.ROpts.Baggage = mergeInto(opts.ROpts.Baggage, t.Baggage)
if t.Jaeger {
opts.TOpts.Jaeger = true
}
Expand Down Expand Up @@ -197,7 +197,7 @@ type headers map[string]string

// In these cases, the existing item (target, from flags) overrides the source
// (template).
func merge(target, source headers) headers {
func mergeInto(target, source headers) headers {
if len(source) == 0 {
return target
}
Expand Down
4 changes: 2 additions & 2 deletions template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func TestAbsPeerListTemplate(t *testing.T) {
assert.Equal(t, "file:///peers.json", opts.TOpts.PeerList)
}

func TestMerge(t *testing.T) {
func TestMergeInto(t *testing.T) {
tests := []struct {
msg string
left, right headers
Expand Down Expand Up @@ -198,7 +198,7 @@ func TestMerge(t *testing.T) {

for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
assert.Equal(t, merge(tt.left, tt.right), tt.want, "merge properly")
assert.Equal(t, mergeInto(tt.left, tt.right), tt.want, "merge properly")
})
}
}
Expand Down
2 changes: 2 additions & 0 deletions transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (
"golang.org/x/net/context"
)

const HTTPThriftEnvelopeheader = "RPC-Thrift-Envelope"

type httpTransport struct {
opts HTTPOptions
client *http.Client
Expand Down

0 comments on commit f493ff9

Please sign in to comment.