/
sns.go
56 lines (45 loc) · 1.49 KB
/
sns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
package main
import (
"context"
"fmt"
"log"
"strings"
"github.com/aws/aws-sdk-go-v2/service/sns"
"github.com/udhos/boilerplate/awsconfig"
)
// snsClient is the narrow slice of the AWS SNS API this program relies on:
// publishing a single message. *sns.Client satisfies it (newSnsClient returns
// one as an snsClient); presumably the interface exists so an alternative
// implementation (e.g. a test fake) can be swapped in — confirm with callers.
type snsClient interface {
	Publish(
		ctx context.Context,
		params *sns.PublishInput,
		optFns ...func(*sns.Options),
	) (*sns.PublishOutput, error)
}
// newSnsClient returns the production snsClient implementation: a real AWS
// SNS client built by newSnsClientAws from the same four arguments.
func newSnsClient(sessionName, topicArn, roleArn, endpointURL string) snsClient {
	awsClient := newSnsClientAws(sessionName, topicArn, roleArn, endpointURL)
	return awsClient
}
// newSnsClientFunc is the signature shared by snsClient constructors such as
// newSnsClient: (sessionName, topicArn, roleArn, endpointURL) -> snsClient.
type newSnsClientFunc func(session, topic, role, endpoint string) snsClient
// newSnsClientAws builds a concrete AWS SNS client.
//
// The client's region is taken from topicArn via getTopicRegion. roleArn,
// sessionName and endpointURL are forwarded unchanged to awsconfig.Options
// (RoleArn, RoleSessionName, EndpointURL); how awsconfig uses them is defined
// in that package. If the region cannot be derived or the AWS config cannot
// be loaded, the process exits via log.Fatalf.
func newSnsClientAws(sessionName, topicArn, roleArn, endpointURL string) *sns.Client {
	const me = "snsClient"

	region, err := getTopicRegion(topicArn)
	if err != nil {
		log.Fatalf("%s: topic region error: %v", me, err)
	}

	cfg, err := awsconfig.AwsConfig(awsconfig.Options{
		Region:          region,
		RoleArn:         roleArn,
		RoleSessionName: sessionName,
		EndpointURL:     endpointURL,
	})
	if err != nil {
		log.Fatalf("%s: aws config error: %v", me, err)
	}

	return sns.NewFromConfig(cfg.AwsConfig)
}
// getTopicRegion extracts the region field from an SNS topic ARN.
//
// Topic ARNs look like
//
//	arn:aws:sns:us-east-1:123456789012:mytopic
//
// i.e. arn:partition:service:region:account-id:topic, so the region is the
// fourth colon-separated field.
//
// An error is returned when:
//   - the string does not start with "arn";
//   - it has fewer than five fields;
//   - the region field is empty.
//
// The caller passes the result straight into the AWS config, so an empty
// region would otherwise go unnoticed here.
func getTopicRegion(topicArn string) (string, error) {
	const me = "getTopicRegion"
	// At most 5 pieces: account ID and topic name stay joined in the last
	// field, which is not needed here.
	fields := strings.SplitN(topicArn, ":", 5)
	if len(fields) < 5 || fields[0] != "arn" {
		return "", fmt.Errorf("%s: bad topic arn=[%s]", me, topicArn)
	}
	region := fields[3]
	if region == "" {
		return "", fmt.Errorf("%s: empty region in topic arn=[%s]", me, topicArn)
	}
	log.Printf("%s: topicRegion=[%s]", me, region)
	return region, nil
}