diff --git a/storage/s3.go b/storage/s3.go index f6ebcffd3..5676c0e57 100644 --- a/storage/s3.go +++ b/storage/s3.go @@ -1235,9 +1235,9 @@ func (sc *SessionCache) newSession(ctx context.Context, opts Options) (*session. WithLogger(sdkLogger{}) } - awsCfg.Retryer = newCustomRetryer(opts.MaxRetries) + awsCfg.Retryer = newCustomRetryer(sc, opts.MaxRetries) - useSharedConfig := session.SharedConfigEnable + useSharedConfig := session.SharedConfigDisable { // Reverse of what the SDK does: if AWS_SDK_LOAD_CONFIG is 0 (or a // falsy value) disable shared configs @@ -1276,7 +1276,7 @@ func (sc *SessionCache) newSession(ctx context.Context, opts Options) (*session. return sess, nil } -func (sc *SessionCache) clear() { +func (sc *SessionCache) Clear() { sc.Lock() defer sc.Unlock() sc.sessions = map[Options]*session.Session{} @@ -1324,10 +1324,12 @@ func setSessionRegion(ctx context.Context, sess *session.Session, bucket string) // error codes. Such as, retry for S3 InternalError code. type customRetryer struct { client.DefaultRetryer + sc *SessionCache } -func newCustomRetryer(maxRetries int) *customRetryer { +func newCustomRetryer(sc *SessionCache, maxRetries int) *customRetryer { return &customRetryer{ + sc: sc, DefaultRetryer: client.DefaultRetryer{ NumMaxRetries: maxRetries, }, @@ -1337,13 +1339,27 @@ func newCustomRetryer(maxRetries int) *customRetryer { // ShouldRetry overrides SDK's built in DefaultRetryer, adding custom retry // logics that are not included in the SDK. func (c *customRetryer) ShouldRetry(req *request.Request) bool { - shouldRetry := errHasCode(req.Error, "InternalError") || errHasCode(req.Error, "RequestTimeTooSkewed") || errHasCode(req.Error, "SlowDown") || strings.Contains(req.Error.Error(), "connection reset") || strings.Contains(req.Error.Error(), "connection timed out") + log.Error(log.ErrorMessage{ + Command: "retrier", + Err: req.Error.Error(), + }) + + shouldRetry := errHasCode(req.Error, "InternalError") || errHasCode(req.Error, "RequestTimeTooSkewed") || errHasCode(req.Error, "SlowDown") || strings.Contains(req.Error.Error(), "connection reset") || strings.Contains(req.Error.Error(), "connection timed out") || errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") + + if errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") { + log.Debug(log.DebugMessage{ + Err: "Clearing the token", + }) + + c.sc.Clear() + } + if !shouldRetry { shouldRetry = c.DefaultRetryer.ShouldRetry(req) } // Errors related to tokens - if errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") || errHasCode(req.Error, "InvalidToken") { + if errHasCode(req.Error, "InvalidToken") { return false } diff --git a/storage/s3_test.go b/storage/s3_test.go index aee355b1c..792470310 100644 --- a/storage/s3_test.go +++ b/storage/s3_test.go @@ -97,7 +97,7 @@ func TestNewSessionPathStyle(t *testing.T) { } func TestNewSessionWithRegionSetViaEnv(t *testing.T) { - globalSessionCache.clear() + globalSessionCache.Clear() const expectedRegion = "us-west-2" @@ -116,7 +116,7 @@ func TestNewSessionWithRegionSetViaEnv(t *testing.T) { } func TestNewSessionWithNoSignRequest(t *testing.T) { - globalSessionCache.clear() + globalSessionCache.Clear() sess, err := globalSessionCache.newSession(context.Background(), Options{ NoSignRequest: true, @@ -190,7 +190,7 @@ aws_secret_access_key = p2_profile_access_key` } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - globalSessionCache.clear() + globalSessionCache.Clear() sess, err := globalSessionCache.newSession(context.Background(), Options{ Profile: tc.profileName, CredentialFile: tc.fileName, @@ -1041,7 +1041,7 @@ func TestSessionRegionDetection(t *testing.T) { opts.bucket = tc.bucket } - globalSessionCache.clear() + globalSessionCache.Clear() sess, err := globalSessionCache.newSession(context.Background(), opts) if err != nil { @@ -1241,7 +1241,7 @@ func TestAWSLogLevel(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - globalSessionCache.clear() + globalSessionCache.Clear() sess, err := globalSessionCache.newSession(context.Background(), Options{ LogLevel: log.LevelFromString(tc.level), })