[proxy] rewrote chunked response handler
1) We cannot send "Connection: close", because the fsouza docker client
   expects the TCP socket to stay open between requests.

2) Because we cannot force-close the connection, we can't hijack the
   connection (because Go's net/http doesn't let us un-hijack it).

3) Because we need to maintain the individual chunking of messages (for
   docker-py), we can't just copy the response body, as Go will remove and
   re-add the chunking willy-nilly.

Therefore, we have to read each chunk one-by-one, and flush the
ResponseWriter after each one.
paulbellamy committed Jul 20, 2015
1 parent bbacaa4 commit 11663dd
Showing 2 changed files with 220 additions and 23 deletions.
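
To make the approach in the commit message concrete, here is a minimal, self-contained sketch (not part of the commit) of scanning a chunked body one chunk at a time with bufio.Scanner and a custom bufio.SplitFunc, which is the technique the new doChunkedResponse uses before writing and flushing each chunk downstream. The splitOneChunk helper below is a hypothetical, deliberately simplified stand-in for the committed splitChunks (no line-length or chunk-size limits); returning io.EOF from the split function on the zero-length terminator chunk makes the Scanner stop with a nil Err, which is also how splitChunks signals the end of the body.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"strconv"
	"strings"
)

// splitOneChunk emits the payload of the next "size\r\npayload\r\n" chunk
// and stops (via io.EOF) at the zero-length terminator chunk.
func splitOneChunk(data []byte, atEOF bool) (advance int, token []byte, err error) {
	i := bytes.IndexByte(data, '\n')
	if i < 0 {
		return 0, nil, nil // need more data to see the size line
	}
	size, err := strconv.ParseInt(string(bytes.TrimRight(data[:i], "\r")), 16, 64)
	if err != nil {
		return 0, nil, err
	}
	if size == 0 {
		return 0, nil, io.EOF // "0\r\n" ends the chunked body
	}
	if len(data) < i+1+int(size)+2 {
		return 0, nil, nil // need more data for the payload plus trailing CRLF
	}
	return i + 1 + int(size) + 2, data[i+1 : i+1+int(size)], nil
}

func main() {
	body := "7\r\nhello, \r\n6\r\nworld!\r\n0\r\n"
	chunks := bufio.NewScanner(strings.NewReader(body))
	chunks.Split(splitOneChunk)
	for chunks.Scan() {
		// In the proxy this is where each chunk would be written to the
		// ResponseWriter and then flushed, one chunk per write.
		fmt.Printf("chunk: %q\n", chunks.Bytes())
	}
	if err := chunks.Err(); err != nil {
		fmt.Println("scan error:", err)
	}
}

Running this prints the two chunk payloads ("hello, " and "world!") and reports no error, mirroring what the new tests in proxy_intercept_test.go assert.
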
107 changes: 84 additions & 23 deletions proxy/proxy_intercept.go
@@ -2,18 +2,31 @@ package proxy

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"strconv"

"github.com/fsouza/go-dockerclient"
. "github.com/weaveworks/weave/common"
)

const (
maxLineLength = 4096 // assumed <= bufio.defaultBufSize
maxChunkSize = bufio.MaxScanTokenSize
)

var (
ErrChunkTooLong = errors.New("chunk too long")
ErrInvalidChunkLength = errors.New("invalid byte in chunk length")
ErrLineTooLong = errors.New("header line too long")
ErrMalformedChunkEncoding = errors.New("malformed chunked encoding")
)

func (proxy *Proxy) Intercept(i interceptor, w http.ResponseWriter, r *http.Request) {
if err := i.InterceptRequest(r); err != nil {
switch err.(type) {
@@ -71,7 +84,7 @@ func (proxy *Proxy) Intercept(i interceptor, w http.ResponseWriter, r *http.Requ
}

func doRawStream(w http.ResponseWriter, resp *http.Response, client *httputil.ClientConn) {
down, downBuf, up, rem, err := hijack(w, client)
down, downBuf, up, remaining, err := hijack(w, client)
if err != nil {
http.Error(w, "Unable to hijack connection for raw stream mode", http.StatusInternalServerError)
return
@@ -96,7 +109,7 @@

upDone := make(chan struct{})
downDone := make(chan struct{})
go copyStream(down, io.MultiReader(rem, up), upDone)
go copyStream(down, io.MultiReader(remaining, up), upDone)
go copyStream(up, downBuf, downDone)
<-upDone
<-downDone
@@ -116,33 +129,81 @@ func copyStream(dst io.Writer, src io.Reader, done chan struct{}) {
}
}

type writeFlusher interface {
io.Writer
http.Flusher
}

func doChunkedResponse(w http.ResponseWriter, resp *http.Response, client *httputil.ClientConn) {
// Because we can't go back to request/response after we
// hijack the connection, we need to close it and make the
// client open another.
w.Header().Add("Connection", "close")
wf, ok := w.(writeFlusher)
if !ok {
http.Error(w, "Error forwarding chunked response body: flush not available", http.StatusInternalServerError)
return
}

w.WriteHeader(resp.StatusCode)

down, _, up, rem, err := hijack(w, client)
up, remaining := client.Hijack()
defer up.Close()

var err error
chunks := bufio.NewScanner(io.MultiReader(remaining, up))
chunks.Split(splitChunks)
for chunks.Scan() && err == nil {
_, err = wf.Write(chunks.Bytes())
wf.Flush()
}
if err == nil {
err = chunks.Err()
}
if err != nil {
http.Error(w, "Unable to hijack response stream for chunked response", http.StatusInternalServerError)
return
Log.Errorf("Error forwarding chunked response body: %s", err)
}
defer up.Close()
defer down.Close()
// Copy the chunked response body to downstream,
// stopping at the end of the chunked section.
rawResponseBody := io.MultiReader(rem, up)
if _, err := io.Copy(ioutil.Discard, httputil.NewChunkedReader(io.TeeReader(rawResponseBody, down))); err != nil {
http.Error(w, "Error copying chunked response body", http.StatusInternalServerError)
return
}

// a bufio.SplitFunc for http chunks
func splitChunks(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
resp.Trailer.Write(down)
// a chunked response ends with a CRLF
down.Write([]byte("\r\n"))

i := bytes.IndexByte(data, '\n')
if i < 0 {
return 0, nil, nil
}
if i > maxLineLength {
return 0, nil, ErrLineTooLong
}

chunkSize64, err := strconv.ParseInt(
string(bytes.TrimRight(data[:i], " \t\r\n")),
16,
64,
)
switch {
case err != nil:
return 0, nil, ErrInvalidChunkLength
case chunkSize64 > maxChunkSize:
return 0, nil, ErrChunkTooLong
case chunkSize64 == 0:
return 0, nil, io.EOF
}
chunkSize := int(chunkSize64)

data = data[i+1:]

if len(data) < chunkSize+2 {
return 0, nil, nil
}

if data[chunkSize] != '\r' || data[chunkSize+1] != '\n' {
return 0, nil, ErrMalformedChunkEncoding
}

return i + chunkSize + 3, data[:chunkSize], nil
}

func hijack(w http.ResponseWriter, client *httputil.ClientConn) (down net.Conn, downBuf *bufio.ReadWriter, up net.Conn, rem io.Reader, err error) {
func hijack(w http.ResponseWriter, client *httputil.ClientConn) (down net.Conn, downBuf *bufio.ReadWriter, up net.Conn, remaining io.Reader, err error) {
hj, ok := w.(http.Hijacker)
if !ok {
err = errors.New("Unable to cast to Hijack")
@@ -152,6 +213,6 @@ func hijack(w http.ResponseWriter, client *httputil.ClientConn) (down net.Conn,
if err != nil {
return
}
up, rem = client.Hijack()
up, remaining = client.Hijack()
return
}
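
One more note on the flush-per-chunk side of the change: the writeFlusher assertion above relies on the standard net/http behaviour that, when a handler has not set a Content-Length, each Write followed by a Flush is sent to the client as its own chunk. A small stand-alone sketch of that pattern (not part of the commit; the /stream path and timings are made up for illustration):

package main

import (
	"fmt"
	"log"
	"net/http"
	"time"
)

// stream writes three pieces of data, flushing after each one so that
// each piece leaves the server as a separate HTTP chunk.
func stream(w http.ResponseWriter, r *http.Request) {
	f, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	for i := 0; i < 3; i++ {
		fmt.Fprintf(w, "chunk %d\n", i) // buffered by net/http
		f.Flush()                       // forces the buffered data out as one chunk
		time.Sleep(100 * time.Millisecond)
	}
}

func main() {
	http.HandleFunc("/stream", stream)
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil))
}
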
136 changes: 136 additions & 0 deletions proxy/proxy_intercept_test.go
@@ -0,0 +1,136 @@
// Based on net/http/internal
package proxy

import (
"bufio"
"bytes"
"strconv"
"strings"
"testing"
)

func TestChunk(t *testing.T) {
r := bufio.NewScanner(bytes.NewBufferString(
"7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n",
))
r.Split(splitChunks)

assertNextChunk(t, r, "hello, ")
assertNextChunk(t, r, "world! 0123456789abcdef")
assertNoMoreChunks(t, r)
}

func TestMalformedChunks(t *testing.T) {
r := bufio.NewScanner(bytes.NewBufferString(
"7\r\nhello, GARBAGEBYTES17\r\nworld! 0123456789abcdef\r\n0\r\n",
))
r.Split(splitChunks)

// First chunk fails
{
if r.Scan() {
t.Errorf("Expected failure when reading chunks, but got one")
}
e := "malformed chunked encoding"
if r.Err() == nil || r.Err().Error() != e {
t.Errorf("chunk reader errored %q; want %q", r.Err(), e)
}
data := r.Bytes()
if len(data) != 0 {
t.Errorf("chunk should have been empty. got %q", data)
}
}

if r.Scan() {
t.Errorf("Expected no more chunks, but found too many")
}
}

func TestChunkTooLarge(t *testing.T) {
data := make([]byte, maxChunkSize+1)
r := bufio.NewScanner(bytes.NewBufferString(strings.Join(
[]string{
strconv.FormatInt(maxChunkSize+1, 16), string(data),
"0", "",
},
"\r\n",
)))
r.Split(splitChunks)

// First chunk fails
{
if r.Scan() {
t.Errorf("Expected failure when reading chunks, but got one")
}
e := "chunk too long"
if r.Err() == nil || r.Err().Error() != e {
t.Errorf("chunk reader errored %q; want %q", r.Err(), e)
}
data := r.Bytes()
if len(data) != 0 {
t.Errorf("chunk should have been empty. got %q", data)
}
}

if r.Scan() {
t.Errorf("Expected no more chunks, but found too many")
}
}

func TestInvalidChunkSize(t *testing.T) {
r := bufio.NewScanner(bytes.NewBufferString(
"foobar\r\nhello, \r\n0\r\n",
))
r.Split(splitChunks)

// First chunk fails
{
if r.Scan() {
t.Errorf("Expected failure when reading chunks, but got one")
}
e := "invalid byte in chunk length"
if r.Err() == nil || r.Err().Error() != e {
t.Errorf("chunk reader errored %q; want %q", r.Err(), e)
}
data := r.Bytes()
if len(data) != 0 {
t.Errorf("chunk should have been empty. got %q", data)
}
}

if r.Scan() {
t.Errorf("Expected no more chunks, but found too many")
}
}

func TestBytesAfterLastChunkAreIgnored(t *testing.T) {
r := bufio.NewScanner(bytes.NewBufferString(
"7\r\nhello, \r\n0\r\nGARBAGEBYTES",
))
r.Split(splitChunks)

assertNextChunk(t, r, "hello, ")
assertNoMoreChunks(t, r)
}

func assertNextChunk(t *testing.T, r *bufio.Scanner, expected string) {
if !r.Scan() {
t.Fatalf("Expected chunk, but ran out early: %v", r.Err())
}
if r.Err() != nil {
t.Fatalf("Error reading chunk: %q", r.Err())
}
data := r.Bytes()
if string(data) != expected {
t.Errorf("chunk reader read %q; want %q", data, expected)
}
}

func assertNoMoreChunks(t *testing.T, r *bufio.Scanner) {
if r.Scan() {
t.Errorf("Expected no more chunks, but found too many")
}
if r.Err() != nil {
t.Errorf("Expected no error, but found: %q", r.Err())
}
}
