From 5335aec94254e5b7cdb4e3106c3afb223940e9d7 Mon Sep 17 00:00:00 2001 From: tidwall Date: Wed, 13 Mar 2019 15:41:49 -0700 Subject: [PATCH] Allow for standard SQS URLs Both now work: https://sqs.us-east-1.amazonaws.com/349840735605/TestTile38Queue sqs://us-east-1:349840735605/TestTile38Queue --- internal/endpoint/endpoint.go | 48 ++++++++++++++++++++--------------- internal/endpoint/sqs.go | 32 +++++++++++++++++++++-- 2 files changed, 58 insertions(+), 22 deletions(-) diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 452dcdbbc..49299c05b 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -90,6 +90,7 @@ type Endpoint struct { KeyFile string } SQS struct { + PlainURL string QueueID string Region string CredPath string @@ -217,7 +218,12 @@ func parseEndpoint(s string) (Endpoint, error) { case strings.HasPrefix(s, "http:"): endpoint.Protocol = HTTP case strings.HasPrefix(s, "https:"): - endpoint.Protocol = HTTP + if probeSQS(s) { + endpoint.SQS.PlainURL = s + endpoint.Protocol = SQS + } else { + endpoint.Protocol = HTTP + } case strings.HasPrefix(s, "disque:"): endpoint.Protocol = Disque case strings.HasPrefix(s, "grpc:"): @@ -469,22 +475,28 @@ func parseEndpoint(s string) (Endpoint, error) { // credpath - path where aws credentials are located // credprofile - credential profile if endpoint.Protocol == SQS { - // Parsing connection from URL string - hp := strings.Split(s, ":") - switch len(hp) { - default: - return endpoint, errors.New("invalid SQS url") - case 2: - endpoint.SQS.Region = hp[0] - endpoint.SQS.QueueID = hp[1] - } + if endpoint.SQS.PlainURL == "" { + // Parsing connection from URL string + hp := strings.Split(s, ":") + switch len(hp) { + default: + return endpoint, errors.New("invalid SQS url") + case 2: + endpoint.SQS.Region = hp[0] + endpoint.SQS.QueueID = hp[1] + } - // Parsing SQS queue name - if len(sp) > 1 { - var err error - endpoint.SQS.QueueName, err = url.QueryUnescape(sp[1]) - if err != nil { - return endpoint, errors.New("invalid SQS queue name") + // Parsing SQS queue name + if len(sp) > 1 { + var err error + endpoint.SQS.QueueName, err = url.QueryUnescape(sp[1]) + if err != nil { + return endpoint, errors.New("invalid SQS queue name") + } + } + // Throw error if we not provide any queue name + if endpoint.SQS.QueueName == "" { + return endpoint, errors.New("missing SQS queue name") } } @@ -512,10 +524,6 @@ func parseEndpoint(s string) (Endpoint, error) { } } } - // Throw error if we not provide any queue name - if endpoint.SQS.QueueName == "" { - return endpoint, errors.New("missing SQS queue name") - } } // Basic AMQP connection strings in HOOKS interface diff --git a/internal/endpoint/sqs.go b/internal/endpoint/sqs.go index 77f4280c6..363196418 100644 --- a/internal/endpoint/sqs.go +++ b/internal/endpoint/sqs.go @@ -3,6 +3,7 @@ package endpoint import ( "errors" "fmt" + "strings" "sync" "time" @@ -31,7 +32,11 @@ type SQSConn struct { } func (conn *SQSConn) generateSQSURL() string { - return "https://sqs." + conn.ep.SQS.Region + "amazonaws.com/" + conn.ep.SQS.QueueID + "/" + conn.ep.SQS.QueueName + if conn.ep.SQS.PlainURL != "" { + return conn.ep.SQS.PlainURL + } + return "https://sqs." + conn.ep.SQS.Region + ".amazonaws.com/" + + conn.ep.SQS.QueueID + "/" + conn.ep.SQS.QueueName } // Expired returns true if the connection has expired @@ -74,8 +79,14 @@ func (conn *SQSConn) Send(msg string) error { } creds = credentials.NewSharedCredentials(credPath, credProfile) } + var region string + if conn.ep.SQS.Region != "" { + region = conn.ep.SQS.Region + } else { + region = sqsRegionFromPlainURL(conn.ep.SQS.PlainURL) + } sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(conn.ep.SQS.Region), + Region: ®ion, Credentials: creds, MaxRetries: aws.Int(5), })) @@ -114,3 +125,20 @@ func newSQSConn(ep Endpoint) *SQSConn { t: time.Now(), } } + +func probeSQS(s string) bool { + // https://sqs.eu-central-1.amazonaws.com/123456789/myqueue + return strings.HasPrefix(s, "https://sqs.") && + strings.Contains(s, ".amazonaws.com/") +} + +func sqsRegionFromPlainURL(s string) string { + parts := strings.Split(s, "https://sqs.") + if len(parts) > 1 { + parts = strings.Split(parts[1], ".amazonaws.com/") + if len(parts) > 1 { + return parts[0] + } + } + return "" +}