Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http3: use the connection, not the stream context, on the server side #4510

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ func (c *SingleDestinationRoundTripper) Start() Connection {
func (c *SingleDestinationRoundTripper) init() {
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
c.requestWriter = newRequestWriter()
c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
c.hconn = newConnection(
c.Connection.Context(),
c.Connection,
c.EnableDatagrams,
protocol.PerspectiveClient,
c.Logger,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(c.hconn); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ var _ = Describe("Client", func() {
return len(b), nil
})
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
<-settingsFrameWritten
Expand Down Expand Up @@ -360,6 +361,7 @@ var _ = Describe("Client", func() {
<-done
return nil, errors.New("test done")
}).MaxTimes(1)
conn.EXPECT().Context().Return(context.Background())
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{ExtendedConnect: true}).Append(b)
r := bytes.NewReader(b)
Expand Down Expand Up @@ -392,6 +394,7 @@ var _ = Describe("Client", func() {
wg.Done()
return nil, errors.New("test done")
}).MaxTimes(1)
conn.EXPECT().Context().Return(context.Background())
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{ExtendedConnect: true}).Append(b)
r := bytes.NewReader(b)
Expand Down Expand Up @@ -427,6 +430,7 @@ var _ = Describe("Client", func() {
var wg sync.WaitGroup
wg.Add(2)
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) {
<-done
wg.Done()
Expand Down Expand Up @@ -507,6 +511,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Context().Return(context.Background()).AnyTimes()
str.EXPECT().StreamID().AnyTimes()
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
Expand Down
8 changes: 6 additions & 2 deletions http3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Connection interface {

type connection struct {
quic.Connection
ctx context.Context

perspective protocol.Perspective
logger *slog.Logger
Expand All @@ -53,12 +54,14 @@ type connection struct {
}

func newConnection(
ctx context.Context,
quicConn quic.Connection,
enableDatagrams bool,
perspective protocol.Perspective,
logger *slog.Logger,
) *connection {
c := &connection{
return &connection{
ctx: ctx,
Connection: quicConn,
perspective: perspective,
logger: logger,
Expand All @@ -67,7 +70,6 @@ func newConnection(
receivedSettings: make(chan struct{}),
streams: make(map[protocol.StreamID]*datagrammer),
}
return c
}

func (c *connection) clearStream(id quic.StreamID) {
Expand Down Expand Up @@ -264,3 +266,5 @@ func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSetti
// Settings returns the settings received on this connection.
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
func (c *connection) Settings() *Settings { return c.settings }

func (c *connection) Context() context.Context { return c.ctx }
10 changes: 10 additions & 0 deletions http3/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var _ = Describe("Connection", func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("no datagrams"))
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveServer,
Expand Down Expand Up @@ -56,6 +57,7 @@ var _ = Describe("Connection", func() {
It("rejects duplicate control streams", func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveServer,
Expand Down Expand Up @@ -97,6 +99,7 @@ var _ = Describe("Connection", func() {
It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveClient,
Expand Down Expand Up @@ -125,6 +128,7 @@ var _ = Describe("Connection", func() {
It(fmt.Sprintf("rejects duplicate QPACK %s streams", name), func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveClient,
Expand Down Expand Up @@ -160,6 +164,7 @@ var _ = Describe("Connection", func() {
It("resets streams other than the control stream and the QPACK streams", func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveServer,
Expand All @@ -185,6 +190,7 @@ var _ = Describe("Connection", func() {
It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveServer,
Expand Down Expand Up @@ -215,6 +221,7 @@ var _ = Describe("Connection", func() {
It("errors when parsing the frame on the control stream fails", func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
protocol.PerspectiveServer,
Expand Down Expand Up @@ -252,6 +259,7 @@ var _ = Describe("Connection", func() {
It(fmt.Sprintf("errors when parsing the %s opens a push stream", pers), func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
false,
pers.Opposite(),
Expand Down Expand Up @@ -281,6 +289,7 @@ var _ = Describe("Connection", func() {
It("errors when the server advertises datagram support (and we enabled support for it)", func() {
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
conn := newConnection(
context.Background(),
qconn,
true,
protocol.PerspectiveClient,
Expand Down Expand Up @@ -319,6 +328,7 @@ var _ = Describe("Connection", func() {
BeforeEach(func() {
qconn = mockquic.NewMockEarlyConnection(mockCtrl)
conn = newConnection(
context.Background(),
qconn,
true,
protocol.PerspectiveClient,
Expand Down
5 changes: 3 additions & 2 deletions http3/http_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http3

import (
"bytes"
"context"
"io"
"math"
"net/http"
Expand Down Expand Up @@ -42,7 +43,7 @@ var _ = Describe("Stream", func() {
errorCbCalled = true
return nil
}).AnyTimes()
str = newStream(qstr, newConnection(conn, false, protocol.PerspectiveClient, nil), nil)
str = newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil), nil)
})

It("reads DATA frames in a single run", func() {
Expand Down Expand Up @@ -170,7 +171,7 @@ var _ = Describe("Request Stream", func() {
requestWriter := newRequestWriter()
conn := mockquic.NewMockEarlyConnection(mockCtrl)
str = newRequestStream(
newStream(qstr, newConnection(conn, false, protocol.PerspectiveClient, nil), nil),
newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil), nil),
requestWriter,
make(chan struct{}),
qpack.NewDecoder(func(qpack.HeaderField) {}),
Expand Down
30 changes: 17 additions & 13 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,8 @@ type Server struct {
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
// has a ServerContextKey value.
// ConnContext optionally specifies a function that modifies the context used for a new connection c.
// The provided ctx has a ServerContextKey value.
ConnContext func(ctx context.Context, c quic.Connection) context.Context

Logger *slog.Logger
Expand Down Expand Up @@ -436,7 +435,19 @@ func (s *Server) handleConn(conn quic.Connection) error {
}).Append(b)
str.Write(b)

ctx := conn.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}

hconn := newConnection(
ctx,
conn,
s.EnableDatagrams,
protocol.PerspectiveServer,
Expand Down Expand Up @@ -533,17 +544,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
s.Logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
}

ctx := str.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn.Connection)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
ctx, cancel := context.WithCancel(conn.Context())
req = req.WithContext(ctx)
context.AfterFunc(str.Context(), cancel)

r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.Logger)
handler := s.Handler
if handler == nil {
Expand Down
12 changes: 8 additions & 4 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ var _ = Describe("Server", func() {
qconn.EXPECT().LocalAddr().AnyTimes()
qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
conn = newConnection(qconn, false, protocol.PerspectiveServer, nil)
conn = newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil)
})

It("calls the HTTP handler function", func() {
Expand All @@ -169,8 +169,6 @@ var _ = Describe("Server", func() {
Eventually(requestChan).Should(Receive(&req))
Expect(req.Host).To(Equal("www.example.com"))
Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
Expect(req.Context().Value(testConnContextKey("test"))).To(Equal(conn.Connection))
})

It("returns 200 with an empty handler", func() {
Expand Down Expand Up @@ -555,6 +553,7 @@ var _ = Describe("Server", func() {
conn = mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
Expand Down Expand Up @@ -734,7 +733,9 @@ var _ = Describe("Server", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(r.Context().Done()).To(BeClosed())
// The context is canceled via context.AfterFunc,
// which performs the cancellation in a new Go routine.
Eventually(r.Context().Done()).Should(BeClosed())
Expect(r.Context().Err()).To(MatchError(context.Canceled))
close(handlerCalled)
})
Expand Down Expand Up @@ -1159,6 +1160,9 @@ var _ = Describe("Server", func() {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
conn.EXPECT().LocalAddr()
conn.EXPECT().RemoteAddr()
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
testDone := make(chan struct{})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
Expand Down
44 changes: 40 additions & 4 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ var _ = Describe("HTTP tests", func() {
mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
defer close(handlerCalled)
// TODO(4508): check for request context cancellations
for {
if _, err := w.Write([]byte("foobar")); err != nil {
Expect(r.Context().Done()).To(BeClosed())
var http3Err *http3.Error
Expect(errors.As(err, &http3Err)).To(BeTrue())
Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c)))
Expand Down Expand Up @@ -570,7 +570,7 @@ var _ = Describe("HTTP tests", func() {
tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return ctx
}
mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/http3-conn-context", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
v, ok := r.Context().Value(ctxKey(0)).(string)
Expect(ok).To(BeTrue())
Expand All @@ -589,9 +589,45 @@ var _ = Describe("HTTP tests", func() {
Expect(id).To(Equal(tracingID))
})

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/http3-conn-context", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

It("uses the QUIC connection context", func() {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context {
//nolint:staticcheck
return context.WithValue(context.Background(), "foo", "bar")
},
}
defer tr.Close()
tlsConf := getTLSConfig()
tlsConf.NextProtos = []string{http3.NextProtoH3}
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

mux.HandleFunc("/quic-conn-context", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
v, ok := r.Context().Value("foo").(string)
Expect(ok).To(BeTrue())
Expect(v).To(Equal("bar"))
})
go func() {
defer GinkgoRecover()
c, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
server.ServeQUICConn(c)
}()

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/quic-conn-context", conn.LocalAddr().(*net.UDPAddr).Port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

It("checks the server's settings", func() {
Expand Down
Loading