From 6908f7bac4b0bf0d27496e36af337ad4f8918f71 Mon Sep 17 00:00:00 2001 From: Daniel Cadenas Date: Tue, 13 Feb 2024 21:15:36 -0300 Subject: [PATCH] Shared rate limiter for multiple conns --- service/domain/relay_address.go | 29 ++++---- service/domain/relay_address_test.go | 6 +- .../rate_limit_notice_backoff_manager.go | 69 +++++++++++++++++++ service/domain/relays/relay_connection.go | 57 +-------------- .../domain/relays/relay_connection_test.go | 3 +- service/domain/relays/relay_connections.go | 28 +++++--- 6 files changed, 112 insertions(+), 80 deletions(-) create mode 100644 service/domain/relays/rate_limit_notice_backoff_manager.go diff --git a/service/domain/relay_address.go b/service/domain/relay_address.go index d42b7b9..5e393c1 100644 --- a/service/domain/relay_address.go +++ b/service/domain/relay_address.go @@ -9,7 +9,8 @@ import ( ) type RelayAddress struct { - original string + original string + hostWithoutPort string } func NewRelayAddress(s string) (RelayAddress, error) { @@ -25,8 +26,16 @@ func NewRelayAddress(s string) (RelayAddress, error) { return RelayAddress{}, errors.New("invalid protocol") } + u.Host = strings.ToLower(u.Host) + hostWithoutPort, _, err := net.SplitHostPort(u.Host) + if err != nil { + hostWithoutPort = u.Host + } + normalizedURI := u.String() + return RelayAddress{ - original: s, + original: normalizedURI, + hostWithoutPort: hostWithoutPort, }, nil } @@ -43,22 +52,12 @@ func NewRelayAddressFromMaybeAddress(maybe MaybeRelayAddress) (RelayAddress, err } func (r RelayAddress) IsLoopbackOrPrivate() bool { - hostWithoutPort := r.getHostWithoutPort() - ip := net.ParseIP(hostWithoutPort) + ip := net.ParseIP(r.hostWithoutPort) return ip.IsLoopback() || ip.IsPrivate() } -func (r RelayAddress) getHostWithoutPort() string { - u, err := url.Parse(r.original) - if err != nil { - panic(err) // checked in constructor - } - - hostWithoutPort, _, err := net.SplitHostPort(u.Host) - if err != nil { - return u.Host - } - return hostWithoutPort +func (r RelayAddress) HostWithoutPort() string { + return r.hostWithoutPort } func (r RelayAddress) String() string { diff --git a/service/domain/relay_address_test.go b/service/domain/relay_address_test.go index df391a1..b0ba40b 100644 --- a/service/domain/relay_address_test.go +++ b/service/domain/relay_address_test.go @@ -43,9 +43,13 @@ func TestRelayAddress(t *testing.T) { Input: "wss://example.com/ ", Output: "wss://example.com", }, + { + Input: "wss://EXAMPLE.com/FooBar ", + Output: "wss://example.com/FooBar", + }, { Input: "wss://example1.com/ wss://example2.com", - Output: "wss://example1.com/ wss://example2.com", + Output: "wss://example1.com/%20wss://example2.com", }, { Input: "wss:// wss://example.com", diff --git a/service/domain/relays/rate_limit_notice_backoff_manager.go b/service/domain/relays/rate_limit_notice_backoff_manager.go new file mode 100644 index 0000000..2322ac7 --- /dev/null +++ b/service/domain/relays/rate_limit_notice_backoff_manager.go @@ -0,0 +1,69 @@ +package relays + +import ( + "math" + "sync/atomic" + "time" +) + +type RateLimitNoticeBackoffManager struct { + rateLimitNoticeCount int32 + lastBumpTime atomic.Value +} + +func NewRateLimitNoticeBackoffManager() *RateLimitNoticeBackoffManager { + r := &RateLimitNoticeBackoffManager{ + rateLimitNoticeCount: 0, + } + + r.updateLastBumpTime() + return r +} + +func (r *RateLimitNoticeBackoffManager) Bump() { + timeSinceLastBump := time.Since(r.getLastBumpTime()) + if timeSinceLastBump < 500*time.Millisecond { + // Give some time for the rate limit to be lifted before increasing the counter + return + } + + atomic.AddInt32(&r.rateLimitNoticeCount, 1) + r.updateLastBumpTime() +} + +const maxBackoffMs = 10000 +const secondsToDecreaseRateLimitNoticeCount = 60 * 5 // 5 minutes = 300 seconds + +func (r *RateLimitNoticeBackoffManager) Wait() { + rateLimitNoticeCount := atomic.LoadInt32(&r.rateLimitNoticeCount) + if rateLimitNoticeCount <= 0 { + return + } + + backoffMs := int(math.Min(float64(maxBackoffMs), math.Pow(2, float64(r.rateLimitNoticeCount))*50)) + + timeSinceLastBump := time.Since(r.getLastBumpTime()) + if timeSinceLastBump > secondsToDecreaseRateLimitNoticeCount*time.Second { + atomic.AddInt32(&r.rateLimitNoticeCount, -1) + r.updateLastBumpTime() + } + + if backoffMs > 0 { + time.Sleep(time.Duration(backoffMs) * time.Millisecond) + } +} + +func (r *RateLimitNoticeBackoffManager) updateLastBumpTime() time.Time { + t := time.Now() + r.lastBumpTime.Store(t) + return t +} + +func (r *RateLimitNoticeBackoffManager) getLastBumpTime() time.Time { + val := r.lastBumpTime.Load() + if t, ok := val.(time.Time); ok { + return t + } + + return r.updateLastBumpTime() +} diff --git a/service/domain/relays/relay_connection.go b/service/domain/relays/relay_connection.go index 14cfe23..5fbfb88 100644 --- a/service/domain/relays/relay_connection.go +++ b/service/domain/relays/relay_connection.go @@ -8,7 +8,6 @@ import ( "regexp" "strings" "sync" - "sync/atomic" "time" "github.com/boreq/errors" @@ -54,59 +53,6 @@ type ConnectionFactory interface { Address() domain.RelayAddress } -type RateLimitNoticeBackoffManager struct { - address domain.RelayAddress - rateLimitNoticeCount int32 - lastBumpTime atomic.Value // Use atomic.Value for time.Time -} - -func (r *RateLimitNoticeBackoffManager) updateLastBumpTime() { - r.lastBumpTime.Store(time.Now()) -} - -func (r *RateLimitNoticeBackoffManager) getLastBumpTime() time.Time { - return r.lastBumpTime.Load().(time.Time) -} - -func NewRateLimitNoticeBackoffManager(address domain.RelayAddress) *RateLimitNoticeBackoffManager { - r := &RateLimitNoticeBackoffManager{ - address: address, - rateLimitNoticeCount: 0, - } - - r.updateLastBumpTime() - return r -} - -func (r *RateLimitNoticeBackoffManager) Bump() { - timeSinceLastBump := time.Since(r.getLastBumpTime()) - if timeSinceLastBump < 500*time.Millisecond { - // Give some time for the rate limit to be lifted before increasing the counter - return - } - - atomic.AddInt32(&r.rateLimitNoticeCount, 1) - r.updateLastBumpTime() -} - -func (r *RateLimitNoticeBackoffManager) Wait() { - if r.rateLimitNoticeCount <= 0 { - return - } - - backoffMs := int(math.Pow(2, float64(r.rateLimitNoticeCount))) * 100 - - timeSinceLastBump := time.Since(r.getLastBumpTime()) - if timeSinceLastBump > 5*time.Second { - atomic.AddInt32(&r.rateLimitNoticeCount, -1) - r.updateLastBumpTime() - } - - if backoffMs > 0 { - time.Sleep(time.Duration(backoffMs) * time.Millisecond) - } -} - type RelayConnection struct { connectionFactory ConnectionFactory logger logging.Logger @@ -128,6 +74,7 @@ type RelayConnection struct { func NewRelayConnection( connectionFactory ConnectionFactory, + rateLimitNoticeBackoffManager *RateLimitNoticeBackoffManager, logger logging.Logger, metrics Metrics, ) *RelayConnection { @@ -141,7 +88,7 @@ func NewRelayConnection( subscriptionsUpdatedCh: make(chan struct{}), eventsToSend: make(map[domain.EventId]*eventToSend), newEventsCh: make(chan domain.Event), - rateLimitNoticeBackoffManager: NewRateLimitNoticeBackoffManager(connectionFactory.Address()), + rateLimitNoticeBackoffManager: rateLimitNoticeBackoffManager, } } diff --git a/service/domain/relays/relay_connection_test.go b/service/domain/relays/relay_connection_test.go index 8fcdded..47256bd 100644 --- a/service/domain/relays/relay_connection_test.go +++ b/service/domain/relays/relay_connection_test.go @@ -135,9 +135,10 @@ type testConnection struct { func newTestConnection(tb testing.TB, ctx context.Context) *testConnection { connection := newMockConnection() factory := newMockConnectionFactory(connection) + backoffManager := relays.NewRateLimitNoticeBackoffManager() metrics := newMockMetrics() logger := logging.NewDevNullLogger() - relayConnection := relays.NewRelayConnection(factory, logger, metrics) + relayConnection := relays.NewRelayConnection(factory, backoffManager, logger, metrics) go relayConnection.Run(ctx) return &testConnection{ diff --git a/service/domain/relays/relay_connections.go b/service/domain/relays/relay_connections.go index b7c97ae..4bf308b 100644 --- a/service/domain/relays/relay_connections.go +++ b/service/domain/relays/relay_connections.go @@ -45,17 +45,19 @@ type RelayConnections struct { longCtx context.Context - connections map[domain.RelayAddress]*RelayConnection - connectionsLock sync.Mutex + connections map[domain.RelayAddress]*RelayConnection + rateLimitNoticeBackoffManagers map[string]*RateLimitNoticeBackoffManager + connectionsLock sync.Mutex } func NewRelayConnections(ctx context.Context, logger logging.Logger, metrics Metrics) *RelayConnections { v := &RelayConnections{ - logger: logger.New("relayConnections"), - metrics: metrics, - longCtx: ctx, - connections: make(map[domain.RelayAddress]*RelayConnection), - connectionsLock: sync.Mutex{}, + logger: logger.New("relayConnections"), + metrics: metrics, + longCtx: ctx, + connections: make(map[domain.RelayAddress]*RelayConnection), + rateLimitNoticeBackoffManagers: make(map[string]*RateLimitNoticeBackoffManager), + connectionsLock: sync.Mutex{}, } go v.storeMetricsLoop(ctx) return v @@ -104,7 +106,17 @@ func (r *RelayConnections) getConnection(relayAddress domain.RelayAddress) *Rela } factory := NewWebsocketConnectionFactory(relayAddress, r.logger) - connection := NewRelayConnection(factory, r.logger, r.metrics) + + // Sometimes different addreses can point to the same relay. Example is + // wss://feeds.nostr.band/video and wss://feeds.nostr.band/audio. For these + // cases, we want to share the rate limit notice backoff manager. + rateLimitNoticeBackoffManager, exists := r.rateLimitNoticeBackoffManagers[relayAddress.HostWithoutPort()] + if !exists { + rateLimitNoticeBackoffManager = NewRateLimitNoticeBackoffManager() + r.rateLimitNoticeBackoffManagers[relayAddress.HostWithoutPort()] = rateLimitNoticeBackoffManager + } + + connection := NewRelayConnection(factory, rateLimitNoticeBackoffManager, r.logger, r.metrics) go connection.Run(r.longCtx) r.connections[relayAddress] = connection