diff --git a/instrumentation/nethttp/client.go b/instrumentation/nethttp/client.go index 75fa6ff1..1f72ae9c 100644 --- a/instrumentation/nethttp/client.go +++ b/instrumentation/nethttp/client.go @@ -222,43 +222,73 @@ func (t *Transport) doRoundTrip(req *http.Request) (*http.Response, error) { // Gets the request payload func getRequestPayload(req *http.Request, bufferSize int) string { - var rqPayload string - if req != nil && req.Body != nil && req.Body != http.NoBody && req.GetBody != nil { - rqBody, rqErr := req.GetBody() - if rqErr == nil { - rqBodyBuffer := make([]byte, bufferSize) - if len, err := rqBody.Read(rqBodyBuffer); err == nil && len > 0 { - if len < bufferSize { - rqBodyBuffer = rqBodyBuffer[:len] - } - rqRunes := bytes.Runes(rqBodyBuffer) - rqPayload = string(rqRunes) - } + if req == nil || req.Body == nil || req.Body == http.NoBody { + return "" + } + if req.GetBody == nil { + // GetBody is nil in server requests + nBody, payload := getBodyPayload(req.Body, bufferSize) + req.Body = nBody + return payload + } + rqBody, rqErr := req.GetBody() + if rqErr != nil { + return "" + } + rqBodyBuffer := make([]byte, bufferSize) + if ln, err := rqBody.Read(rqBodyBuffer); err == nil && ln > 0 { + if ln < bufferSize { + rqBodyBuffer = rqBodyBuffer[:ln] } + return string(bytes.Runes(rqBodyBuffer)) } - return rqPayload + return "" +} + +// Gets the payload from a body +func getBodyPayload(body io.ReadCloser, bufferSize int) (io.ReadCloser, string) { + if body == nil { + return body, "" + } + rsBodyBuffer := make([]byte, bufferSize) + ln, _ := body.Read(rsBodyBuffer) + if ln == 0 { + return body, "" + } + if ln < bufferSize { + rsBodyBuffer = rsBodyBuffer[:ln] + } + rsPayload := string(bytes.Runes(rsBodyBuffer)) + rBody := struct { + io.Reader + io.Closer + }{ + io.MultiReader(bytes.NewReader(rsBodyBuffer), body), + body, + } + return rBody, rsPayload } // Gets the response payload func getResponsePayload(resp *http.Response, bufferSize int) string { - var rsPayload string - if resp != nil && resp.Body != nil && resp.Body != http.NoBody { - rsBodyBuffer := make([]byte, bufferSize) - len, _ := resp.Body.Read(rsBodyBuffer) - if len > 0 { - if len < bufferSize { - rsBodyBuffer = rsBodyBuffer[:len] - } - rsRunes := bytes.Runes(rsBodyBuffer) - rsPayload = string(rsRunes) - resp.Body = struct { - io.Reader - io.Closer - }{ - io.MultiReader(bytes.NewReader(rsBodyBuffer), resp.Body), - resp.Body, - } - } + if resp == nil || resp.Body == nil || resp.Body == http.NoBody { + return "" + } + rsBodyBuffer := make([]byte, bufferSize) + ln, _ := resp.Body.Read(rsBodyBuffer) + if ln == 0 { + return "" + } + if ln < bufferSize { + rsBodyBuffer = rsBodyBuffer[:ln] + } + rsPayload := string(bytes.Runes(rsBodyBuffer)) + resp.Body = struct { + io.Reader + io.Closer + }{ + io.MultiReader(bytes.NewReader(rsBodyBuffer), resp.Body), + resp.Body, } return rsPayload } diff --git a/instrumentation/nethttp/nethttp_test.go b/instrumentation/nethttp/nethttp_test.go index 61cb9de2..ffbedc01 100644 --- a/instrumentation/nethttp/nethttp_test.go +++ b/instrumentation/nethttp/nethttp_test.go @@ -1,6 +1,7 @@ package nethttp import ( + "bytes" "fmt" "net/http" "net/http/httptest" @@ -15,7 +16,7 @@ import ( var r *tracer.InMemorySpanRecorder func TestMain(m *testing.M) { - PatchHttpDefaultClient() + PatchHttpDefaultClient(WithPayloadInstrumentation()) // Test tracer r = tracer.NewInMemoryRecorder() @@ -64,10 +65,10 @@ func TestHttpServer(t *testing.T) { return } }) - server := httptest.NewServer(Middleware(nil)) + server := httptest.NewServer(Middleware(nil, MWPayloadInstrumentation())) url := fmt.Sprintf("%s/hello", server.URL) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest("POST", url, bytes.NewReader([]byte("Hello world request"))) if err != nil { t.Fatalf("%+v", err) } @@ -87,18 +88,22 @@ func TestHttpServer(t *testing.T) { t.Fatalf("there aren't the right number of spans: %d", len(spans)) } checkTags(t, spans[0].Tags, map[string]string{ - "component": "net/http", - "http.method": "GET", - "http.url": "/hello", - "span.kind": "server", - "http.status_code": "200", + "component": "net/http", + "http.method": "POST", + "http.url": "/hello", + "span.kind": "server", + "http.status_code": "200", + "http.request_payload": "Hello world request", + "http.response_payload": "Hello world", }) checkTags(t, spans[1].Tags, map[string]string{ - "component": "net/http", - "http.method": "GET", - "http.url": url, - "peer.ipv4": "127.0.0.1", - "span.kind": "client", + "component": "net/http", + "http.method": "POST", + "http.url": url, + "peer.ipv4": "127.0.0.1", + "span.kind": "client", + "http.request_payload": "Hello world request", + "http.response_payload": "Hello world", }) }