Skip to content

Commit

Permalink
remove shutdown method on the Connection (#4249)
Browse files Browse the repository at this point in the history
There's no need to have a dedicated shutdown method, as the use case
(shutting down an outgoing connection attempt on context cancellation)
can be achieved by using Connection.destroy.
  • Loading branch information
marten-seemann committed Jan 19, 2024
1 parent d3c2020 commit b3eb375
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 130 deletions.
2 changes: 1 addition & 1 deletion client.go
Expand Up @@ -232,7 +232,7 @@ func (c *client) dial(ctx context.Context) error {

select {
case <-ctx.Done():
c.conn.shutdown()
c.conn.destroy(nil)
return context.Cause(ctx)
case err := <-errorChan:
return err
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Expand Up @@ -87,7 +87,7 @@ var _ = Describe("Client", func() {

AfterEach(func() {
if s, ok := cl.conn.(*connection); ok {
s.shutdown()
s.destroy(nil)
}
Eventually(areConnsRunning).Should(BeFalse())
})
Expand Down
2 changes: 0 additions & 2 deletions closed_conn.go
Expand Up @@ -41,7 +41,6 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.sendPacket(p.remoteAddr, p.info)
}

func (c *closedLocalConn) shutdown() {}
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }

Expand All @@ -59,6 +58,5 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
}

func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
2 changes: 0 additions & 2 deletions closed_conn_test.go
Expand Up @@ -14,8 +14,6 @@ var _ = Describe("Closed local connection", func() {
It("tells its perspective", func() {
conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger)
Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient))
// stop the connection
conn.shutdown()
})

It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
Expand Down
7 changes: 0 additions & 7 deletions connection.go
Expand Up @@ -1572,13 +1572,6 @@ func (s *connection) closeRemote(e error) {
})
}

// Close the connection. It sends a NO_ERROR application error.
// It waits until the run loop has stopped before returning
func (s *connection) shutdown() {
s.closeLocal(nil)
<-s.ctx.Done()
}

func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
s.closeLocal(&qerr.ApplicationError{
ErrorCode: code,
Expand Down
59 changes: 17 additions & 42 deletions connection_test.go
Expand Up @@ -465,7 +465,7 @@ var _ = Describe("Connection", func() {
}),
tracer.EXPECT().Close(),
)
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(areConnsRunning).Should(BeFalse())
Expect(conn.Context().Done()).To(BeClosed())
})
Expand All @@ -479,8 +479,8 @@ var _ = Describe("Connection", func() {
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.shutdown()
conn.CloseWithError(0, "")
conn.CloseWithError(0, "")
Eventually(areConnsRunning).Should(BeFalse())
Expect(conn.Context().Done()).To(BeClosed())
})
Expand Down Expand Up @@ -551,29 +551,6 @@ var _ = Describe("Connection", func() {
}
})

It("cancels the context when the run loop exists", func() {
runConn()
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
returned := make(chan struct{})
go func() {
defer GinkgoRecover()
ctx := conn.Context()
<-ctx.Done()
Expect(ctx.Err()).To(MatchError(context.Canceled))
close(returned)
}()
Consistently(returned).ShouldNot(BeClosed())
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
Eventually(returned).Should(BeClosed())
Expect(context.Cause(conn.Context())).To(MatchError(context.Canceled))
})

It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() {
unpacker := NewMockUnpacker(mockCtrl)
conn.handshakeConfirmed = true
Expand Down Expand Up @@ -964,7 +941,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -1219,7 +1196,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
sender.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
Eventually(connDone).Should(BeClosed())
})
Expand Down Expand Up @@ -1422,7 +1399,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
sender.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -1811,7 +1788,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().Close()
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -1937,7 +1914,7 @@ var _ = Describe("Connection", func() {
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -2059,7 +2036,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().Close()
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand All @@ -2073,13 +2050,11 @@ var _ = Describe("Connection", func() {
close(done)
}()
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.destroy(nil)
Eventually(done).Should(BeClosed())
Expect(context.Cause(conn.Context())).To(MatchError(context.Canceled))
})
Expand Down Expand Up @@ -2165,7 +2140,7 @@ var _ = Describe("Connection", func() {
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -2316,7 +2291,7 @@ var _ = Describe("Connection", func() {
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -2403,7 +2378,7 @@ var _ = Describe("Connection", func() {
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
})

Expand Down Expand Up @@ -2623,7 +2598,7 @@ var _ = Describe("Client Connection", func() {
mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
time.Sleep(200 * time.Millisecond)
})
Expand Down Expand Up @@ -2930,7 +2905,7 @@ var _ = Describe("Client Connection", func() {
}

AfterEach(func() {
conn.shutdown()
conn.CloseWithError(0, "")
Eventually(conn.Context().Done()).Should(BeClosed())
Eventually(errChan).Should(BeClosed())
})
Expand Down Expand Up @@ -2974,7 +2949,7 @@ var _ = Describe("Client Connection", func() {
Eventually(processed).Should(BeClosed())
// close first
expectClose(true, false)
conn.shutdown()
conn.CloseWithError(0, "")
// then check. Avoids race condition when accessing idleTimeout
Expect(conn.idleTimeout).To(Equal(18 * time.Second))
})
Expand Down
20 changes: 20 additions & 0 deletions integrationtests/self/handshake_test.go
Expand Up @@ -82,6 +82,26 @@ var _ = Describe("Handshake tests", func() {
}()
}

It("returns the context cancellation error on timeouts", func() {
ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
defer cancel()
errChan := make(chan error, 1)
go func() {
_, err := quic.DialAddr(
ctx,
"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
getTLSClientConfig(),
getQuicConfig(nil),
)
errChan <- err
}()

var err error
Eventually(errChan).Should(Receive(&err))
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(context.DeadlineExceeded))
})

It("returns the cancellation reason when a dial is canceled", func() {
ctx, cancel := context.WithCancelCause(context.Background())
errChan := make(chan error, 1)
Expand Down
36 changes: 0 additions & 36 deletions mock_packet_handler_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 0 additions & 36 deletions mock_quic_conn_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion packet_handler_map.go
Expand Up @@ -191,7 +191,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p

time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock()
handler.shutdown()
for _, id := range ids {
delete(h.handlers, id)
}
Expand Down
2 changes: 0 additions & 2 deletions server.go
Expand Up @@ -24,7 +24,6 @@ var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets
type packetHandler interface {
handlePacket(receivedPacket)
shutdown()
destroy(error)
getPerspective() protocol.Perspective
}
Expand All @@ -45,7 +44,6 @@ type quicConn interface {
getPerspective() protocol.Perspective
run() error
destroy(error)
shutdown()
}

type zeroRTTQueue struct {
Expand Down

0 comments on commit b3eb375

Please sign in to comment.