Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http3: automatically add content-length for small responses #3989

Merged
merged 9 commits into from
Aug 21, 2023
97 changes: 66 additions & 31 deletions http3/response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,61 @@ import (
"github.com/quic-go/qpack"
)

// The maximum length of an encoded HTTP/3 frame header is 16:
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
const frameHeaderLen = 16

// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
type headerWriter struct {
str quic.Stream
header http.Header
status int // status code passed to WriteHeader
written bool

logger utils.Logger
}

// writeHeader encodes and flush header to the stream
func (hw *headerWriter) writeHeader() error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})

for k, v := range hw.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}

buf := make([]byte, 0, frameHeaderLen+headers.Len())
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
hw.logger.Infof("Responding with %d", hw.status)
buf = append(buf, headers.Bytes()...)

_, err := hw.str.Write(buf)
return err
}

// first Write will trigger flushing header
func (hw *headerWriter) Write(p []byte) (int, error) {
if !hw.written {
if err := hw.writeHeader(); err != nil {
return 0, err
}
hw.written = true
}
return hw.str.Write(p)
}

type responseWriter struct {
*headerWriter
conn quic.Connection
str quic.Stream
bufferedStr *bufio.Writer
buf []byte

header http.Header
status int // status code passed to WriteHeader
headerWritten bool
contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written

logger utils.Logger
}

var (
Expand All @@ -37,13 +79,16 @@ var (
)

func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
hw := &headerWriter{
str: str,
header: http.Header{},
logger: logger,
}
return &responseWriter{
header: http.Header{},
buf: make([]byte, 16),
conn: conn,
str: str,
bufferedStr: bufio.NewWriter(str),
logger: logger,
headerWriter: hw,
buf: make([]byte, frameHeaderLen),
conn: conn,
bufferedStr: bufio.NewWriter(hw),
}
}

Expand Down Expand Up @@ -83,27 +128,8 @@ func (w *responseWriter) WriteHeader(status int) {
}
w.status = status

var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})

for k, v := range w.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}

w.buf = w.buf[:0]
w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf)
w.logger.Infof("Responding with %d", status)
if _, err := w.bufferedStr.Write(w.buf); err != nil {
w.logger.Errorf("could not write headers frame: %s", err.Error())
}
if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil {
w.logger.Errorf("could not write header frame payload: %s", err.Error())
}
if !w.headerWritten {
w.Flush()
w.writeHeader()
}
}

Expand Down Expand Up @@ -146,6 +172,15 @@ func (w *responseWriter) Write(p []byte) (int, error) {
}

func (w *responseWriter) FlushError() error {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
if !w.written {
if err := w.writeHeader(); err != nil {
return err
}
w.written = true
}
return w.bufferedStr.Flush()
}

Expand Down
8 changes: 7 additions & 1 deletion http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -627,7 +628,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q

// only write response when there is no panic
if !panicked {
r.WriteHeader(http.StatusOK)
// response not written to the client yet, set Content-Length
if !r.written {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
}
r.Flush()
}
// If the EOF was read by the handler, CancelRead() is a no-op.
Expand Down
41 changes: 41 additions & 0 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,47 @@ var _ = Describe("Server", func() {
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
})

It("sets Content-Length when the handler doesn't flush to the client", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foobar"))
})

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
// status, content-length, date, content-type
Expect(hfs).To(HaveLen(4))
})

It("not sets Content-Length when the handler flushes to the client", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foobar"))
// force flush
w.(http.Flusher).Flush()
})

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
// status, date, content-type
Expect(hfs).To(HaveLen(3))
})

It("handles a aborting handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(http.ErrAbortHandler)
Expand Down
12 changes: 12 additions & 0 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ var _ = Describe("HTTP tests", func() {
Expect(string(body)).To(Equal("Hello, World!\n"))
})

It("sets content-length for small response", func() {
mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
w.Write([]byte("foobar"))
})

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar"))))
})

It("requests to different servers with the same udpconn", func() {
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port))
Expect(err).ToNot(HaveOccurred())
Expand Down
Loading