Skip to content

Commit

Permalink
only create a single session for duplicate Initials
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 29, 2020
1 parent 85c19fb commit dad30e7
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 58 deletions.
2 changes: 0 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,6 @@ func (c *client) dial(ctx context.Context) error {
c.version,
)
c.mutex.Unlock()
// It's not possible to use the stateless reset token for the client's (first) connection ID,
// since there's no way to securely communicate it to the server.
c.packetHandlers.Add(c.srcConnID, c)

errorChan := make(chan error, 1)
Expand Down
14 changes: 14 additions & 0 deletions mock_packet_handler_manager_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions packet_handler_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true
}

func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
sid := string(clientDestConnID)
h.mutex.Lock()
defer h.mutex.Unlock()

if _, ok := h.handlers[sid]; ok {
return false
}

sess := fn()
h.handlers[sid] = sess
h.handlers[string(newConnID)] = sess
return true
}

func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.mutex.Lock()
delete(h.handlers, string(id))
Expand Down
8 changes: 8 additions & 0 deletions packet_handler_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ var _ = Describe("Packet Handler Map", func() {
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
})

It("says if a connection ID is already taken, for AddWithConnID", func() {
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
newConnID1 := protocol.ConnectionID{1, 2, 3, 4}
newConnID2 := protocol.ConnectionID{4, 3, 2, 1}
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
})
})

Context("running a server", func() {
Expand Down
61 changes: 31 additions & 30 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type unknownPacketHandler interface {
}

type packetHandlerManager interface {
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
Destroy() error
sessionRunner
SetServer(unknownPacketHandler)
Expand Down Expand Up @@ -421,39 +422,39 @@ func (s *baseServer) createNewSession(
srcConnID protocol.ConnectionID,
version protocol.VersionNumber,
) quicSession {
var qlogger qlog.Tracer
if s.config.GetLogWriter != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := clientDestConnID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
if w := s.config.GetLogWriter(connID); w != nil {
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
var sess quicSession
if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler {
var qlogger qlog.Tracer
if s.config.GetLogWriter != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := clientDestConnID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
if w := s.config.GetLogWriter(connID); w != nil {
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
}
}
}
sess := s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionHandler,
origDestConnID,
clientDestConnID,
destConnID,
srcConnID,
s.sessionHandler.GetStatelessResetToken(srcConnID),
s.config,
s.tlsConf,
s.tokenGenerator,
s.acceptEarlySessions,
qlogger,
s.logger,
version,
)
if added := s.sessionHandler.Add(clientDestConnID, sess); !added {
// We're already keeping track of this connection ID.
// This might happen if we receive two copies of the Initial at the same time.
sess = s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionHandler,
origDestConnID,
clientDestConnID,
destConnID,
srcConnID,
s.sessionHandler.GetStatelessResetToken(srcConnID),
s.config,
s.tlsConf,
s.tokenGenerator,
s.acceptEarlySessions,
qlogger,
s.logger,
version,
)
return sess
}); !added {
return nil
}
s.sessionHandler.Add(srcConnID, sess)
go sess.run()
go s.handleNewSession(sess)
return sess
Expand Down
81 changes: 55 additions & 26 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,7 @@ var _ = Describe("Server", func() {
var token [16]byte
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})

sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
Expand Down Expand Up @@ -439,9 +436,13 @@ var _ = Describe("Server", func() {
return sess
}

phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
Expect(c).To(Equal(newConnID))
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
fn()
return true
})

Expand Down Expand Up @@ -502,14 +503,20 @@ var _ = Describe("Server", func() {
Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue())
// Then receive the Initial packet.
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
fn()
return true
})
Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue())
Expect(createdSession).To(BeTrue())
})

It("drops packets if the receive queue is full", func() {
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).AnyTimes()

serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
acceptSession := make(chan struct{})
Expand All @@ -532,7 +539,12 @@ var _ = Describe("Server", func() {
) quicSession {
<-acceptSession
atomic.AddUint32(&counter, 1)
return nil
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
}

serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}))
Expand Down Expand Up @@ -577,10 +589,9 @@ var _ = Describe("Server", func() {
}

p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdSession).To(BeTrue())
Expect(createdSession).To(BeFalse())
})

It("rejects new connection attempts if the accept queue is full", func() {
Expand Down Expand Up @@ -612,8 +623,11 @@ var _ = Describe("Server", func() {
return sess
}

phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)

var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
Expand Down Expand Up @@ -673,8 +687,11 @@ var _ = Describe("Server", func() {
return sess
}

phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})

serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
Expand Down Expand Up @@ -772,8 +789,11 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background())
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
Expand Down Expand Up @@ -836,8 +856,11 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background())
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
close(ready)
Expand Down Expand Up @@ -874,8 +897,11 @@ var _ = Describe("Server", func() {
return sess
}

phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
serv.handlePacket(getInitialWithRandomDestConnID())
}
Expand Down Expand Up @@ -927,8 +953,11 @@ var _ = Describe("Server", func() {
return sess
}

phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed())
Expand Down

0 comments on commit dad30e7

Please sign in to comment.