Skip to content

Commit

Permalink
GODRIVER-2348 Use context listener to cancel HB check
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed May 30, 2024
1 parent a719b0a commit c58c276
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 160 deletions.
16 changes: 6 additions & 10 deletions x/mongo/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,16 @@ import (
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

// ServerSelectionTimeoutGetter returns a timeout that should be used to set a
// deadline for server selection. This logic is not handleded internally by the
// ServerSelector, as a resulting deadline may be applicable by follow-up
// operations, such as checking out a connection.
type ServerSelectionTimeoutGetter interface {
GetServerSelectionTimeout() time.Duration
}

// Deployment is implemented by types that can select a server from a deployment.
type Deployment interface {
ServerSelectionTimeoutGetter

SelectServer(context.Context, description.ServerSelector) (Server, error)
Kind() description.TopologyKind

// GetServerSelectionTimeout returns a timeout that should be used to set a
// deadline for server selection. This logic is not handleded internally by
// the ServerSelector, as a resulting deadline may be applicable by follow-up
// operations such as checking out a connection.
GetServerSelectionTimeout() time.Duration
}

// Connector represents a type that can connect to a server.
Expand Down
14 changes: 0 additions & 14 deletions x/mongo/driver/topology/cancellation_listener.go

This file was deleted.

57 changes: 11 additions & 46 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ type connection struct {
connectContextMade chan struct{}
canStream bool
currentlyStreaming bool
cancellationListener cancellationListener
cancellationListener contextListener
serverConnectionID *int64 // the server's ID for this client's connection
cancelConnSig chan struct{}
prevCanceled atomic.Value

// pool related fields
pool *pool
Expand All @@ -89,7 +90,7 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
connectDone: make(chan struct{}),
config: cfg,
connectContextMade: make(chan struct{}),
cancellationListener: newCancellListener(),
cancellationListener: newContextDoneListener(),
}
// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
// at any point during connection establishment can be processed without the connection being considered stale.
Expand Down Expand Up @@ -527,6 +528,14 @@ func (c *connection) ServerConnectionID() *int64 {
return c.serverConnectionID
}

func (c *connection) previousCanceled() bool {
if val := c.prevCanceled.Load(); val != nil {
return val.(bool)
}

return false
}

// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
Expand Down Expand Up @@ -830,47 +839,3 @@ func configureTLS(ctx context.Context,
}
return client, nil
}

// TODO: Naming?

// cancellListener listens for context cancellation and notifies listeners via a
// callback function.
type cancellListener struct {
aborted bool
done chan struct{}
}

// newCancellListener constructs a cancellListener.
func newCancellListener() *cancellListener {
return &cancellListener{
done: make(chan struct{}),
}
}

// Listen blocks until the provided context is cancelled or listening is aborted
// via the StopListening function. If this detects that the context has been
// cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback is
// called to abort in-progress work. Even if the context expires, this function
// will block until StopListening is called.
func (c *cancellListener) Listen(ctx context.Context, abortFn func()) {
c.aborted = false

select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
c.aborted = true
abortFn()
}

<-c.done
case <-c.done:
}
}

// StopListening stops the in-progress Listen call. This blocks if there is no
// in-progress Listen call. This function will return true if the provided abort
// callback was called when listening for cancellation on the previous context.
func (c *cancellListener) StopListening() bool {
c.done <- struct{}{}
return c.aborted
}
4 changes: 2 additions & 2 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ func (d *dialer) lenclosed() int {
}

type testCancellationListener struct {
listener *cancellListener
listener *contextDoneListener
numListen int
numStopListening int
aborted bool
Expand All @@ -1058,7 +1058,7 @@ type testCancellationListener struct {
// returned by the StopListening method.
func newTestCancellationListener(aborted bool) *testCancellationListener {
return &testCancellationListener{
listener: newCancellListener(),
listener: newContextDoneListener(),
aborted: aborted,
}
}
Expand Down
91 changes: 91 additions & 0 deletions x/mongo/driver/topology/context_listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
"context"
"errors"
"sync/atomic"
)

type contextListener interface {
Listen(context.Context, func())
StopListening() bool
}

// contextDoneListener listens for context-ending eventsand notifies listeners
// via a callback function.
type contextDoneListener struct {
aborted atomic.Value
done chan struct{}
blockOnDone bool
}

var _ contextListener = &contextDoneListener{}

// newContextDoneListener constructs a contextDoneListener that will block
// when a context is done until StopListening is called.
func newContextDoneListener() *contextDoneListener {
return &contextDoneListener{
done: make(chan struct{}),
blockOnDone: true,
}
}

// newNonBlockingContextDoneLIstener constructs a contextDoneListener that
// will not block when a context is done. In this case there are two ways to
// unblock the listener: a finished context or a call to StopListening.
func newNonBlockingContextDoneListener() *contextDoneListener {
return &contextDoneListener{
done: make(chan struct{}),
blockOnDone: false,
}
}

// Listen blocks until the provided context is cancelled or listening is aborted
// via the StopListening function. If this detects that the context has been
// cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback
// is called to abort in-progress work. If blockOnDone is true, this function
// will block until StopListening is called, even if the context expires.
func (c *contextDoneListener) Listen(ctx context.Context, abortFn func()) {
c.aborted.Store(false)

select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
c.aborted.Store(true)

abortFn()
}

if c.blockOnDone {
<-c.done
}
case <-c.done:
}
}

// StopListening stops the in-progress Listen call. If blockOnDone is true, then
// this blocks if there is no in-progress Listen call. This function will return
// true if the provided abort callback was called when listening for
// cancellation on the previous context.
func (c *contextDoneListener) StopListening() bool {
if c.blockOnDone {
c.done <- struct{}{}
} else {
select {
case c.done <- struct{}{}:
default:
}
}

if aborted := c.aborted.Load(); aborted != nil {
return aborted.(bool)
}

return false
}
Loading

0 comments on commit c58c276

Please sign in to comment.