Skip to content

Commit

Permalink
http3: implement on the HTTPStreamer on the ResponseWriter, flush hea…
Browse files Browse the repository at this point in the history
…der (#4469)

Currently the HTTPStreamer is implemented on the http.Request.Body. This
complicates usage, since it's not easily possible to flush the HTTP
header, requiring users to manually flash the header before taking over
the stream.

With this change, the HTTP header is now flushed automatically as soon
as HTTPStream is called.
  • Loading branch information
marten-seemann committed Apr 27, 2024
1 parent 083ceb4 commit 34f4d14
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 30 deletions.
24 changes: 1 addition & 23 deletions http3/body.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@ import (
"github.com/quic-go/quic-go"
)

// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by:
// * for the server: the http.Request.Body
// * for the client: the http.Response.Body
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
HTTPStream() Stream
}

// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
// It is used by WebTransport to create WebTransport streams after a session has been established.
type Hijacker interface {
Expand All @@ -32,8 +23,6 @@ type body struct {
remainingContentLength int64
violatedContentLength bool
hasContentLength bool

wasHijacked bool // set when HTTPStream is called
}

func newBody(str *stream, contentLength int64) *body {
Expand All @@ -45,15 +34,7 @@ func newBody(str *stream, contentLength int64) *body {
return b
}

func (r *body) HTTPStream() Stream {
r.wasHijacked = true
return r.str
}

func (r *body) StreamID() quic.StreamID { return r.str.StreamID() }
func (r *body) wasStreamHijacked() bool {
return r.wasHijacked
}

func (r *body) checkContentLengthViolation() error {
if !r.hasContentLength {
Expand Down Expand Up @@ -97,10 +78,7 @@ type requestBody struct {
getSettings func() *Settings
}

var (
_ io.ReadCloser = &requestBody{}
_ HTTPStreamer = &requestBody{}
)
var _ io.ReadCloser = &requestBody{}

func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
return &requestBody{
Expand Down
18 changes: 18 additions & 0 deletions http3/response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ import (
"github.com/quic-go/qpack"
)

// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body.
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
HTTPStream() Stream
}

// The maximum length of an encoded HTTP/3 frame header is 16:
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
const frameHeaderLen = 16
Expand All @@ -36,13 +43,16 @@ type responseWriter struct {
headerWritten bool // set once the response header has been serialized to the stream
isHead bool

hijacked bool // set on HTTPStream is called

logger *slog.Logger
}

var (
_ http.ResponseWriter = &responseWriter{}
_ http.Flusher = &responseWriter{}
_ Hijacker = &responseWriter{}
_ HTTPStreamer = &responseWriter{}
)

func newResponseWriter(str *stream, conn Connection, isHead bool, logger *slog.Logger) *responseWriter {
Expand Down Expand Up @@ -220,6 +230,14 @@ func (w *responseWriter) Flush() {
}
}

func (w *responseWriter) HTTPStream() Stream {
w.hijacked = true
w.Flush()
return w.str
}

func (w *responseWriter) wasStreamHijacked() bool { return w.hijacked }

func (w *responseWriter) Connection() Connection {
return w.conn
}
Expand Down
2 changes: 1 addition & 1 deletion http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
handler.ServeHTTP(r, req)
}()

if body.wasStreamHijacked() {
if r.wasStreamHijacked() {
return
}

Expand Down
20 changes: 18 additions & 2 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ var _ = Describe("Server", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(handlerCalled)
r.Body.(HTTPStreamer).HTTPStream()
w.(HTTPStreamer).HTTPStream()
str.Write([]byte("foobar"))
})

Expand All @@ -590,10 +590,26 @@ var _ = Describe("Server", func() {
b = append(b, []byte("foobar")...)
setRequest(append(requestData, b...))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write([]byte("foobar")).Return(6, nil)
var buf bytes.Buffer
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()

s.handleConn(conn)
Eventually(handlerCalled).Should(BeClosed())

// The buffer is expected to contain:
// 1. The response header (in a HEADERS frame)
// 2. the "foobar" (unframed)
frame, err := parseNextFrame(&buf, nil)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
df := frame.(*headersFrame)
data := make([]byte, df.Length)
_, err = io.ReadFull(&buf, data)
Expect(err).ToNot(HaveOccurred())
hdrs, err := qpackDecoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdrs).To(ContainElement(qpack.HeaderField{Name: ":status", Value: "200"}))
Expect(buf.Bytes()).To(Equal([]byte("foobar")))
})

It("errors when the client sends a too large header frame", func() {
Expand Down
6 changes: 2 additions & 4 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,8 @@ var _ = Describe("HTTP tests", func() {
defer GinkgoRecover()
close(handlerCalled)
w.WriteHeader(http.StatusOK)
w.(http.Flusher).Flush()

str := r.Body.(http3.HTTPStreamer).HTTPStream()
str := w.(http3.HTTPStreamer).HTTPStream()
str.Write([]byte("foobar"))

// Do this in a Go routine, so that the handler returns early.
Expand Down Expand Up @@ -734,9 +733,8 @@ var _ = Describe("HTTP tests", func() {
Eventually(conn.ReceivedSettings()).Should(BeClosed())
Expect(conn.Settings().EnableDatagrams).To(BeTrue())
w.WriteHeader(http.StatusOK)
w.(http.Flusher).Flush()

str := r.Body.(http3.HTTPStreamer).HTTPStream()
str := w.(http3.HTTPStreamer).HTTPStream()
go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions

for {
Expand Down

0 comments on commit 34f4d14

Please sign in to comment.