From d8fe4c014a33327c1b3c9f6bd93ac962fa72ebfe Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 13 May 2024 18:11:21 +0800 Subject: [PATCH] http3: use the connection, not the stream context, on the server side --- http3/client.go | 8 +++++- http3/client_test.go | 5 ++++ http3/conn.go | 8 ++++-- http3/conn_test.go | 10 +++++++ http3/http_stream_test.go | 5 ++-- http3/server.go | 30 +++++++++++--------- http3/server_test.go | 12 +++++--- integrationtests/self/http_test.go | 44 +++++++++++++++++++++++++++--- 8 files changed, 96 insertions(+), 26 deletions(-) diff --git a/http3/client.go b/http3/client.go index f5370549bab..23ac088fe95 100644 --- a/http3/client.go +++ b/http3/client.go @@ -82,7 +82,13 @@ func (c *SingleDestinationRoundTripper) Start() Connection { func (c *SingleDestinationRoundTripper) init() { c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {}) c.requestWriter = newRequestWriter() - c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger) + c.hconn = newConnection( + c.Connection.Context(), + c.Connection, + c.EnableDatagrams, + protocol.PerspectiveClient, + c.Logger, + ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { if err := c.setupConn(c.hconn); err != nil { diff --git a/http3/client_test.go b/http3/client_test.go index f368ffa9a1b..4e3b0f23e1b 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -326,6 +326,7 @@ var _ = Describe("Client", func() { return len(b), nil }) conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { <-settingsFrameWritten @@ -360,6 +361,7 @@ var _ = Describe("Client", func() { <-done return nil, errors.New("test done") }).MaxTimes(1) + conn.EXPECT().Context().Return(context.Background()) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{ExtendedConnect: true}).Append(b) r := bytes.NewReader(b) @@ -392,6 +394,7 @@ var _ = Describe("Client", func() { wg.Done() return nil, errors.New("test done") }).MaxTimes(1) + conn.EXPECT().Context().Return(context.Background()) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{ExtendedConnect: true}).Append(b) r := bytes.NewReader(b) @@ -427,6 +430,7 @@ var _ = Describe("Client", func() { var wg sync.WaitGroup wg.Add(2) conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { <-done wg.Done() @@ -507,6 +511,7 @@ var _ = Describe("Client", func() { str.EXPECT().Context().Return(context.Background()).AnyTimes() str.EXPECT().StreamID().AnyTimes() conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone diff --git a/http3/conn.go b/http3/conn.go index 7ea4b292918..df7fb2825bc 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -37,6 +37,7 @@ type Connection interface { type connection struct { quic.Connection + ctx context.Context perspective protocol.Perspective logger *slog.Logger @@ -53,12 +54,14 @@ type connection struct { } func newConnection( + ctx context.Context, quicConn quic.Connection, enableDatagrams bool, perspective protocol.Perspective, logger *slog.Logger, ) *connection { - c := &connection{ + return &connection{ + ctx: ctx, Connection: quicConn, perspective: perspective, logger: logger, @@ -67,7 +70,6 @@ func newConnection( receivedSettings: make(chan struct{}), streams: make(map[protocol.StreamID]*datagrammer), } - return c } func (c *connection) clearStream(id quic.StreamID) { @@ -264,3 +266,5 @@ func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSetti // Settings returns the settings received on this connection. // It is only valid to call this function after the channel returned by ReceivedSettings was closed. func (c *connection) Settings() *Settings { return c.settings } + +func (c *connection) Context() context.Context { return c.ctx } diff --git a/http3/conn_test.go b/http3/conn_test.go index 8af906cd0e3..4be8b7ae549 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -24,6 +24,7 @@ var _ = Describe("Connection", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("no datagrams")) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveServer, @@ -56,6 +57,7 @@ var _ = Describe("Connection", func() { It("rejects duplicate control streams", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveServer, @@ -97,6 +99,7 @@ var _ = Describe("Connection", func() { It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveClient, @@ -125,6 +128,7 @@ var _ = Describe("Connection", func() { It(fmt.Sprintf("rejects duplicate QPACK %s streams", name), func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveClient, @@ -160,6 +164,7 @@ var _ = Describe("Connection", func() { It("resets streams other than the control stream and the QPACK streams", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveServer, @@ -185,6 +190,7 @@ var _ = Describe("Connection", func() { It("errors when the first frame on the control stream is not a SETTINGS frame", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveServer, @@ -215,6 +221,7 @@ var _ = Describe("Connection", func() { It("errors when parsing the frame on the control stream fails", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, protocol.PerspectiveServer, @@ -252,6 +259,7 @@ var _ = Describe("Connection", func() { It(fmt.Sprintf("errors when parsing the %s opens a push stream", pers), func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, false, pers.Opposite(), @@ -281,6 +289,7 @@ var _ = Describe("Connection", func() { It("errors when the server advertises datagram support (and we enabled support for it)", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) conn := newConnection( + context.Background(), qconn, true, protocol.PerspectiveClient, @@ -319,6 +328,7 @@ var _ = Describe("Connection", func() { BeforeEach(func() { qconn = mockquic.NewMockEarlyConnection(mockCtrl) conn = newConnection( + context.Background(), qconn, true, protocol.PerspectiveClient, diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index a97dcb7b46f..a4c889e2cf5 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "context" "io" "math" "net/http" @@ -42,7 +43,7 @@ var _ = Describe("Stream", func() { errorCbCalled = true return nil }).AnyTimes() - str = newStream(qstr, newConnection(conn, false, protocol.PerspectiveClient, nil), nil) + str = newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil), nil) }) It("reads DATA frames in a single run", func() { @@ -170,7 +171,7 @@ var _ = Describe("Request Stream", func() { requestWriter := newRequestWriter() conn := mockquic.NewMockEarlyConnection(mockCtrl) str = newRequestStream( - newStream(qstr, newConnection(conn, false, protocol.PerspectiveClient, nil), nil), + newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil), nil), requestWriter, make(chan struct{}), qpack.NewDecoder(func(qpack.HeaderField) {}), diff --git a/http3/server.go b/http3/server.go index 5d7aec8a558..18853b89d5f 100644 --- a/http3/server.go +++ b/http3/server.go @@ -194,9 +194,8 @@ type Server struct { // In that case, the stream type will not be set. UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) - // ConnContext optionally specifies a function that modifies - // the context used for a new connection c. The provided ctx - // has a ServerContextKey value. + // ConnContext optionally specifies a function that modifies the context used for a new connection c. + // The provided ctx has a ServerContextKey value. ConnContext func(ctx context.Context, c quic.Connection) context.Context Logger *slog.Logger @@ -436,7 +435,19 @@ func (s *Server) handleConn(conn quic.Connection) error { }).Append(b) str.Write(b) + ctx := conn.Context() + ctx = context.WithValue(ctx, ServerContextKey, s) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) + ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr()) + if s.ConnContext != nil { + ctx = s.ConnContext(ctx, conn) + if ctx == nil { + panic("http3: ConnContext returned nil") + } + } + hconn := newConnection( + ctx, conn, s.EnableDatagrams, protocol.PerspectiveServer, @@ -533,17 +544,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat s.Logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI) } - ctx := str.Context() - ctx = context.WithValue(ctx, ServerContextKey, s) - ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) - ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr()) - if s.ConnContext != nil { - ctx = s.ConnContext(ctx, conn.Connection) - if ctx == nil { - panic("http3: ConnContext returned nil") - } - } + ctx, cancel := context.WithCancel(conn.Context()) req = req.WithContext(ctx) + context.AfterFunc(str.Context(), cancel) + r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.Logger) handler := s.Handler if handler == nil { diff --git a/http3/server_test.go b/http3/server_test.go index f6b635fed29..7e2138d5c1e 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -148,7 +148,7 @@ var _ = Describe("Server", func() { qconn.EXPECT().LocalAddr().AnyTimes() qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() qconn.EXPECT().Context().Return(context.Background()).AnyTimes() - conn = newConnection(qconn, false, protocol.PerspectiveServer, nil) + conn = newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil) }) It("calls the HTTP handler function", func() { @@ -169,8 +169,6 @@ var _ = Describe("Server", func() { Eventually(requestChan).Should(Receive(&req)) Expect(req.Host).To(Equal("www.example.com")) Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) - Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) - Expect(req.Context().Value(testConnContextKey("test"))).To(Equal(conn.Connection)) }) It("returns 200 with an empty handler", func() { @@ -555,6 +553,7 @@ var _ = Describe("Server", func() { conn = mockquic.NewMockEarlyConnection(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone @@ -734,7 +733,9 @@ var _ = Describe("Server", func() { handlerCalled := make(chan struct{}) s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() - Expect(r.Context().Done()).To(BeClosed()) + // The context is canceled via context.AfterFunc, + // which performs the cancellation in a new Go routine. + Eventually(r.Context().Done()).Should(BeClosed()) Expect(r.Context().Err()).To(MatchError(context.Canceled)) close(handlerCalled) }) @@ -1159,6 +1160,9 @@ var _ = Describe("Server", func() { conn := mockquic.NewMockEarlyConnection(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().LocalAddr() + conn.EXPECT().RemoteAddr() + conn.EXPECT().Context().Return(context.Background()) conn.EXPECT().OpenUniStream().Return(controlStr, nil) testDone := make(chan struct{}) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 0264d1a4e4a..db5523d69af 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -350,9 +350,9 @@ var _ = Describe("HTTP tests", func() { mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() defer close(handlerCalled) + // TODO(4508): check for request context cancellations for { if _, err := w.Write([]byte("foobar")); err != nil { - Expect(r.Context().Done()).To(BeClosed()) var http3Err *http3.Error Expect(errors.As(err, &http3Err)).To(BeTrue()) Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c))) @@ -570,7 +570,7 @@ var _ = Describe("HTTP tests", func() { tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) return ctx } - mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/http3-conn-context", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() v, ok := r.Context().Value(ctxKey(0)).(string) Expect(ok).To(BeTrue()) @@ -589,9 +589,45 @@ var _ = Describe("HTTP tests", func() { Expect(id).To(Equal(tracingID)) }) - resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port)) + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/http3-conn-context", port)) Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + It("uses the QUIC connection context", func() { + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + tr := &quic.Transport{ + Conn: conn, + ConnContext: func() context.Context { + //nolint:staticcheck + return context.WithValue(context.Background(), "foo", "bar") + }, + } + defer tr.Close() + tlsConf := getTLSConfig() + tlsConf.NextProtos = []string{http3.NextProtoH3} + ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + mux.HandleFunc("/quic-conn-context", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + v, ok := r.Context().Value("foo").(string) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("bar")) + }) + go func() { + defer GinkgoRecover() + c, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + server.ServeQUICConn(c) + }() + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/quic-conn-context", conn.LocalAddr().(*net.UDPAddr).Port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) }) It("checks the server's settings", func() {