Merge pull request #1259 from weaveworks/1257-proxy-chunking
Proxy chunking should handle arbitrarily large chunks
tomwilkie committed Aug 10, 2015
2 parents fd2a594 + 6d7e784 commit 42debde
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 196 deletions.
123 changes: 123 additions & 0 deletions proxy/chunked.go
@@ -0,0 +1,123 @@
// Based on net/http/internal
package proxy

import (
	"bufio"
	"bytes"
	"errors"
	"io"
	"io/ioutil"
	"strconv"
)

var (
	ErrLineTooLong        = errors.New("header line too long")
	ErrInvalidChunkLength = errors.New("invalid byte in chunk length")
)

// Unlike net/http/internal.chunkedReader, this has an interface where we can
// handle individual chunks. The interface is based on database/sql.Rows.
func NewChunkedReader(r io.Reader) *ChunkedReader {
	br, ok := r.(*bufio.Reader)
	if !ok {
		br = bufio.NewReader(r)
	}
	return &ChunkedReader{r: br}
}

type ChunkedReader struct {
	r     *bufio.Reader
	chunk *io.LimitedReader
	err   error
	buf   [2]byte
}

// Next prepares the next chunk for reading. It returns true on success, or
// false if there is no next chunk or an error happened while preparing
// it. Err should be consulted to distinguish between the two cases.
//
// Every call to Chunk, even the first one, must be preceded by a call to Next.
//
// Calls to Next will discard any unread bytes in the current Chunk.
func (cr *ChunkedReader) Next() bool {
	if cr.err != nil {
		return false
	}

	// Check the termination of the previous chunk
	if cr.chunk != nil {
		// Make sure the remainder is drained, in case the caller of Chunk
		// stopped reading early.
		if _, cr.err = io.Copy(ioutil.Discard, cr.chunk); cr.err != nil {
			return false
		}

		// Check the next two bytes after the chunk are \r\n
		if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err != nil {
			return false
		}
		if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
			cr.err = errors.New("malformed chunked encoding")
			return false
		}
	} else {
		cr.chunk = &io.LimitedReader{R: cr.r}
	}

	// Set up the next chunk
	if n := cr.beginChunk(); n > 0 {
		cr.chunk.N = int64(n)
	} else if cr.err == nil {
		cr.err = io.EOF
	}
	return cr.err == nil
}

// Chunk returns the io.Reader of the current chunk. On each call, this returns
// the same io.Reader for a given chunk.
func (cr *ChunkedReader) Chunk() io.Reader {
	return cr.chunk
}

// Err returns the error, if any, that was encountered during iteration.
func (cr *ChunkedReader) Err() error {
	if cr.err == io.EOF {
		return nil
	}
	return cr.err
}

func (cr *ChunkedReader) beginChunk() uint64 {
	var (
		line []byte
		n    uint64
	)
	// chunk-size CRLF
	line, cr.err = readLine(cr.r)
	if cr.err != nil {
		return 0
	}
	n, cr.err = strconv.ParseUint(string(line), 16, 64)
	if cr.err != nil {
		cr.err = ErrInvalidChunkLength
	}
	return n
}

// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds the buffer size.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLine(b *bufio.Reader) (p []byte, err error) {
	if p, err = b.ReadSlice('\n'); err != nil {
		// We always know when EOF is coming.
		// If the caller asked for a line, there should be a line.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		} else if err == bufio.ErrBufferFull {
			err = ErrLineTooLong
		}
		return nil, err
	}
	return bytes.TrimRight(p, " \t\n\r"), nil
}
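For orientation, and not part of the commit itself: a minimal sketch of how this Rows-style API is driven, assuming it sits in the same proxy package (copyChunks is a hypothetical name; the real caller is doChunkedResponse below).

package proxy

import "io"

// copyChunks is a hypothetical illustration of the iteration contract:
// every call to Chunk is preceded by a call to Next, and Err is consulted
// after the loop to distinguish clean EOF (nil) from a real failure.
func copyChunks(dst io.Writer, body io.Reader) error {
	r := NewChunkedReader(body)
	for r.Next() {
		// Chunk returns the same io.Reader until the next call to Next;
		// io.Copy streams it, so chunk size never affects memory use.
		if _, err := io.Copy(dst, r.Chunk()); err != nil {
			return err
		}
	}
	return r.Err() // nil after the terminating 0-length chunk
}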
174 changes: 174 additions & 0 deletions proxy/chunked_test.go
@@ -0,0 +1,174 @@
// Based on net/http/internal
package proxy

import (
	"bytes"
	"io"
	"io/ioutil"
	"strconv"
	"strings"
	"testing"
)

func TestChunk(t *testing.T) {
	r := NewChunkedReader(bytes.NewBufferString(
		"7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n",
	))

	assertNextChunk(t, r, "hello, ")
	assertNextChunk(t, r, "world! 0123456789abcdef")
	assertNoMoreChunks(t, r)
}

func TestIncompleteReadOfChunk(t *testing.T) {
	r := NewChunkedReader(bytes.NewBufferString(
		"7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n",
	))

	// Incomplete read of first chunk
	{
		if !r.Next() {
			t.Fatalf("Expected chunk, but ran out early: %v", r.Err())
		}
		if r.Err() != nil {
			t.Fatalf("Error reading chunk: %q", r.Err())
		}
		// Read just 2 bytes
		buf := make([]byte, 2)
		if _, err := io.ReadFull(r.Chunk(), buf[:2]); err != nil {
			t.Fatalf("Error reading first bytes of chunk: %q", err)
		}
		if buf[0] != 'h' || buf[1] != 'e' {
			t.Fatalf("Unexpected first 2 bytes of chunk: %q", buf)
		}
	}

	assertNextChunk(t, r, "world! 0123456789abcdef")
	assertNoMoreChunks(t, r)
}

func TestMalformedChunks(t *testing.T) {
	r := NewChunkedReader(bytes.NewBufferString(
		"7\r\nhello, GARBAGEBYTES17\r\nworld! 0123456789abcdef\r\n0\r\n",
	))

	assertNextChunk(t, r, "hello, ")
	assertError(t, r, "malformed chunked encoding")
}

type charReader byte

// Read produces an infinite sequence of a single repeated byte.
func (r *charReader) Read(p []byte) (int, error) {
	b := byte(*r)
	for i := range p {
		p[i] = b
	}
	return len(p), nil
}

func TestLargeChunks(t *testing.T) {
	var expected int64 = 1024 * 1024
	chars := charReader('a')
	r := NewChunkedReader(io.MultiReader(
		strings.NewReader(strconv.FormatInt(expected, 16)+"\r\n"),
		&io.LimitedReader{N: expected, R: &chars},
		strings.NewReader("\r\n0\r\n"),
	))

	if !r.Next() {
		t.Fatalf("Expected chunk, but ran out early: %v", r.Err())
	}
	if r.Err() != nil {
		t.Fatalf("Error reading chunk: %q", r.Err())
	}
	n, err := io.Copy(ioutil.Discard, r.Chunk())
	if n != expected {
		t.Errorf("chunk reader read %d; want %d", n, expected)
	}
	if err != nil {
		t.Fatalf("reading chunk: %v", err)
	}
	assertNoMoreChunks(t, r)
}

func TestInvalidChunkSize(t *testing.T) {
	r := NewChunkedReader(bytes.NewBufferString(
		"foobar\r\nhello, \r\n0\r\n",
	))

	assertError(t, r, "invalid byte in chunk length")
}

func TestChunkSizeLineTooLong(t *testing.T) {
	var (
		maxLineLength = 4096
		chunkSize     string
	)
	for i := 0; i < maxLineLength; i++ {
		chunkSize = chunkSize + "0"
	}
	chunkSize = chunkSize + "7"

	r := NewChunkedReader(bytes.NewBufferString(
		chunkSize + "\r\nhello, \r\n0\r\n",
	))

	assertError(t, r, "header line too long")
}

func TestBytesAfterLastChunkAreIgnored(t *testing.T) {
	r := NewChunkedReader(bytes.NewBufferString(
		"7\r\nhello, \r\n0\r\nGARBAGEBYTES",
	))

	assertNextChunk(t, r, "hello, ")
	assertNoMoreChunks(t, r)
}

func assertNextChunk(t *testing.T, r *ChunkedReader, expected string) {
	if !r.Next() {
		t.Fatalf("Expected chunk, but ran out early: %v", r.Err())
	}
	if r.Err() != nil {
		t.Fatalf("Error reading chunk: %q", r.Err())
	}
	data, err := ioutil.ReadAll(r.Chunk())
	if string(data) != expected {
		t.Errorf("chunk reader read %q; want %q", data, expected)
	}
	if err != nil {
		t.Logf(`data: %q`, data)
		t.Fatalf("reading chunk: %v", err)
	}
}

func assertError(t *testing.T, r *ChunkedReader, e string) {
	if r.Next() {
		t.Errorf("Expected failure when reading chunks, but got a chunk")
	}
	if r.Err() == nil || r.Err().Error() != e {
		t.Errorf("chunk reader errored %q; want %q", r.Err(), e)
	}
	data, err := ioutil.ReadAll(r.Chunk())
	if len(data) != 0 {
		t.Errorf("chunk should have been empty. got %q", data)
	}
	if err != nil {
		t.Logf(`data: %q`, data)
		t.Errorf("reading chunk: %v", err)
	}

	if r.Next() {
		t.Errorf("Expected no more chunks, but found too many")
	}
}

func assertNoMoreChunks(t *testing.T, r *ChunkedReader) {
	if r.Next() {
		t.Errorf("Expected no more chunks, but found too many")
	}
	if r.Err() != nil {
		t.Errorf("Expected no error, but found: %q", r.Err())
	}
}
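To run just these cases from a checkout of the repository (assuming a standard Go toolchain; the -run pattern matches every test name containing "Chunk"):

go test -run Chunk ./proxy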
63 changes: 3 additions & 60 deletions proxy/proxy_intercept.go
@@ -2,32 +2,18 @@ package proxy
 
 import (
 	"bufio"
-	"bytes"
-	"errors"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"net/http/httputil"
-	"strconv"
 	"sync"
 
 	"github.com/fsouza/go-dockerclient"
 	. "github.com/weaveworks/weave/common"
 )
 
-const (
-	maxLineLength = 4096 // assumed <= bufio.defaultBufSize
-	maxChunkSize  = bufio.MaxScanTokenSize
-)
-
-var (
-	ErrChunkTooLong           = errors.New("chunk too long")
-	ErrInvalidChunkLength     = errors.New("invalid byte in chunk length")
-	ErrLineTooLong            = errors.New("header line too long")
-	ErrMalformedChunkEncoding = errors.New("malformed chunked encoding")
-)
-
 func (proxy *Proxy) Intercept(i interceptor, w http.ResponseWriter, r *http.Request) {
 	if err := i.InterceptRequest(r); err != nil {
 		switch err.(type) {
@@ -153,10 +139,9 @@ func doChunkedResponse(w http.ResponseWriter, resp *http.Response, client *httpu
 	defer up.Close()
 
 	var err error
-	chunks := bufio.NewScanner(io.MultiReader(remaining, up))
-	chunks.Split(splitChunks)
-	for chunks.Scan() && err == nil {
-		_, err = wf.Write(chunks.Bytes())
+	chunks := NewChunkedReader(io.MultiReader(remaining, up))
+	for chunks.Next() && err == nil {
+		_, err = io.Copy(wf, chunks.Chunk())
 		wf.Flush()
 	}
 	if err == nil {
@@ -167,48 +152,6 @@ func doChunkedResponse(w http.ResponseWriter, resp *http.Response, client *httpu
 	}
 }
 
-// a bufio.SplitFunc for http chunks
-func splitChunks(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	if atEOF && len(data) == 0 {
-		return 0, nil, nil
-	}
-
-	i := bytes.IndexByte(data, '\n')
-	if i < 0 {
-		return 0, nil, nil
-	}
-	if i > maxLineLength {
-		return 0, nil, ErrLineTooLong
-	}
-
-	chunkSize64, err := strconv.ParseInt(
-		string(bytes.TrimRight(data[:i], " \t\r\n")),
-		16,
-		64,
-	)
-	switch {
-	case err != nil:
-		return 0, nil, ErrInvalidChunkLength
-	case chunkSize64 > maxChunkSize:
-		return 0, nil, ErrChunkTooLong
-	case chunkSize64 == 0:
-		return 0, nil, io.EOF
-	}
-	chunkSize := int(chunkSize64)
-
-	data = data[i+1:]
-
-	if len(data) < chunkSize+2 {
-		return 0, nil, nil
-	}
-
-	if data[chunkSize] != '\r' || data[chunkSize+1] != '\n' {
-		return 0, nil, ErrMalformedChunkEncoding
-	}
-
-	return i + chunkSize + 3, data[:chunkSize], nil
-}
-
 func hijack(w http.ResponseWriter, client *httputil.ClientConn) (down net.Conn, downBuf *bufio.ReadWriter, up net.Conn, remaining io.Reader, err error) {
 	hj, ok := w.(http.Hijacker)
 	if !ok {
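The mechanism behind the fix, in brief: bufio.Scanner must hold each token in memory at once — with splitChunks, an entire chunk plus its trailing CRLF — and the removed code capped chunks at bufio.MaxScanTokenSize via ErrChunkTooLong. The new reader parses only the chunk-size line, then hands the body out as an io.LimitedReader that io.Copy drains through a small fixed buffer, so memory use is flat regardless of chunk size. A hypothetical sketch in the spirit of TestLargeChunks, assuming the same proxy test package (demoLargeChunk is an invented name):

package proxy

import (
	"io"
	"io/ioutil"
	"strconv"
	"strings"
)

// demoLargeChunk streams a single 1 GiB chunk — far beyond the old
// bufio.MaxScanTokenSize ceiling — and returns how many body bytes
// made it through, reusing the test file's charReader as the source.
func demoLargeChunk() (int64, error) {
	const size int64 = 1 << 30 // a 1 GiB chunk
	chars := charReader('a')
	r := NewChunkedReader(io.MultiReader(
		strings.NewReader(strconv.FormatInt(size, 16)+"\r\n"),
		&io.LimitedReader{N: size, R: &chars},
		strings.NewReader("\r\n0\r\n"),
	))
	var total int64
	for r.Next() {
		// Each chunk is drained through io.Copy's fixed buffer; the
		// whole chunk is never resident in memory.
		n, err := io.Copy(ioutil.Discard, r.Chunk())
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, r.Err()
}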
