Skip to content

Commit

Permalink
share aws session
Browse files Browse the repository at this point in the history
  • Loading branch information
pjdufour-truss committed Jun 3, 2019
1 parent de0cca4 commit d127fed
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 20 deletions.
18 changes: 16 additions & 2 deletions cmd/milmove/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"sync"
"syscall"

awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/dgrijalva/jwt-go"
"github.com/gobuffalo/pop"
"github.com/gorilla/csrf"
Expand Down Expand Up @@ -352,6 +353,19 @@ func serveFunction(cmd *cobra.Command, args []string) error {
pop.Debug = true
}

var session *awssession.Session
if v.GetString(cli.EmailBackendFlag) == "ses" || v.GetString(cli.StorageBackendFlag) == "s3" {
c, errorConfig := cli.GetAWSConfig(v, v.GetBool(cli.VerboseFlag))
if errorConfig != nil {
logger.Fatal(errors.Wrap(errorConfig, "error creating aws config").Error())
}
s, errorSession := awssession.NewSession(c)
if errorSession != nil {
logger.Fatal(errors.Wrap(errorSession, "error creating aws session").Error())
}
session = s
}

// Create a connection to the DB
dbConnection, err := cli.InitDatabase(v, logger)
if err != nil {
Expand Down Expand Up @@ -398,7 +412,7 @@ func serveFunction(cmd *cobra.Command, args []string) error {
}

// Email
notificationSender := cli.InitEmail(v, logger)
notificationSender := cli.InitEmail(v, session, logger)
handlerContext.SetNotificationSender(notificationSender)

build := v.GetString(cli.BuildFlag)
Expand All @@ -415,7 +429,7 @@ func serveFunction(cmd *cobra.Command, args []string) error {
handlerContext.SetSendProductionInvoice(v.GetBool(cli.GEXSendProdInvoiceFlag))

// Storage
storer := cli.InitStorage(v, logger)
storer := cli.InitStorage(v, session, logger)
handlerContext.SetFileStorer(storer)

certificates, rootCAs, err := cli.InitDoDCertificates(v, logger)
Expand Down
11 changes: 2 additions & 9 deletions pkg/cli/email.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cli
import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ses"
"github.com/pkg/errors"
Expand Down Expand Up @@ -51,7 +50,7 @@ func CheckEmail(v *viper.Viper) error {
}

// InitEmail initializes the email backend
func InitEmail(v *viper.Viper, logger Logger) notifications.NotificationSender {
func InitEmail(v *viper.Viper, sess *awssession.Session, logger Logger) notifications.NotificationSender {
if v.GetString(EmailBackendFlag) == "ses" {
// Setup Amazon SES (email) service
// TODO: This might be able to be combined with the AWS Session that we're using for S3 down
Expand All @@ -61,13 +60,7 @@ func InitEmail(v *viper.Viper, logger Logger) notifications.NotificationSender {
logger.Info("Using ses email backend",
zap.String("region", awsSESRegion),
zap.String("domain", awsSESDomain))
sesSession, newSessionErr := awssession.NewSession(&aws.Config{
Region: aws.String(awsSESRegion),
})
if newSessionErr != nil {
logger.Fatal("Failed to create a new AWS client config provider", zap.Error(newSessionErr))
}
sesService := ses.New(sesSession)
sesService := ses.New(sess)
return notifications.NewNotificationSender(sesService, awsSESDomain, logger)
}

Expand Down
8 changes: 2 additions & 6 deletions pkg/cli/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"path"
"path/filepath"

"github.com/aws/aws-sdk-go/aws"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/pkg/errors"
Expand Down Expand Up @@ -65,7 +64,7 @@ func CheckStorage(v *viper.Viper) error {
}

// InitStorage initializes the storage backend
func InitStorage(v *viper.Viper, logger Logger) storage.FileStorer {
func InitStorage(v *viper.Viper, sess *awssession.Session, logger Logger) storage.FileStorer {
storageBackend := v.GetString(StorageBackendFlag)
localStorageRoot := v.GetString(LocalStorageRootFlag)
localStorageWebRoot := v.GetString(LocalStorageWebRootFlag)
Expand All @@ -88,10 +87,7 @@ func InitStorage(v *viper.Viper, logger Logger) storage.FileStorer {
if len(awsS3KeyNamespace) == 0 {
logger.Fatal("Must provide aws_s3_key_namespace parameter, exiting")
}
aws := awssession.Must(awssession.NewSession(&aws.Config{
Region: aws.String(awsS3Region),
}))
storer = storage.NewS3(awsS3Bucket, awsS3KeyNamespace, logger, aws)
storer = storage.NewS3(awsS3Bucket, awsS3KeyNamespace, logger, sess)
} else if storageBackend == "memory" {
logger.Info("Using memory storage backend",
zap.String(LocalStorageRootFlag, path.Join(localStorageRoot, localStorageWebRoot)),
Expand Down
6 changes: 3 additions & 3 deletions pkg/cli/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ func CheckVault(v *viper.Viper) error {
return nil
}

// GetAWSCredentials uses aws-vault to return AWS credentials
func GetAWSCredentials(keychainName string, awsProfile string) (*credentials.Credentials, error) {
// GetAWSCredentialsFromKeyring uses aws-vault to return AWS credential from a system keyring.
func GetAWSCredentialsFromKeyring(keychainName string, awsProfile string) (*credentials.Credentials, error) {

// Open the keyring which holds the credentials
ring, err := keyring.Open(keyring.Config{
Expand Down Expand Up @@ -151,7 +151,7 @@ func GetAWSConfig(v *viper.Viper, verbose bool) (*aws.Config, error) {
keychainName := v.GetString(VaultAWSKeychainNameFlag)
awsProfile := v.GetString(VaultAWSProfileFlag)
if len(keychainName) > 0 && len(awsProfile) > 0 {
creds, getAWSCredsErr := GetAWSCredentials(keychainName, awsProfile)
creds, getAWSCredsErr := GetAWSCredentialsFromKeyring(keychainName, awsProfile)
if getAWSCredsErr != nil {
return nil, errors.Wrap(getAWSCredsErr,
fmt.Sprintf("Unable to get AWS credentials from the keychain %s and profile %s", keychainName, awsProfile))
Expand Down

0 comments on commit d127fed

Please sign in to comment.