Skip to content
This repository has been archived by the owner on Aug 14, 2018. It is now read-only.

Commit

Permalink
Merge pull request #16 from smartystreets/variadic-credentials
Browse files Browse the repository at this point in the history
Credentials as a parameter to all Sign functions
  • Loading branch information
mholt committed Jun 2, 2014
2 parents db9dd1c + 3c21e95 commit 41cdffe
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 97 deletions.
118 changes: 76 additions & 42 deletions awsauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ import (
"time"
)

// Keys stores the authentication credentials to be used when signing requests.
// You can set them manually or leave it to awsauth to use environment variables.
var Keys *Credentials

// Credentials stores the information necessary to authorize with AWS and it
// is from this information that requests are signed.
type Credentials struct {
Expand All @@ -24,33 +20,51 @@ type Credentials struct {

// Sign signs a request bound for AWS. It automatically chooses the best
// authentication scheme based on the service the request is going to.
func Sign(req *http.Request) *http.Request {
func Sign(req *http.Request, cred ...Credentials) *http.Request {
service, _ := serviceAndRegion(req.URL.Host)
sigVersion := awsSignVersion[service]

switch sigVersion {
case 2:
return Sign2(req)
case 3:
return Sign3(req)
case 4:
return Sign4(req)
case -1:
return SignS3(req)
if len(cred) == 0 {
switch sigVersion {
case 2:
return Sign2(req)
case 3:
return Sign3(req)
case 4:
return Sign4(req)
case -1:
return SignS3(req)
}
} else {
switch sigVersion {
case 2:
return Sign2(req, cred[0])
case 3:
return Sign3(req, cred[0])
case 4:
return Sign4(req, cred[0])
case -1:
return SignS3(req, cred[0])
}
}

return nil
}

// Sign4 signs a request with Signed Signature Version 4.
func Sign4(req *http.Request) *http.Request {
func Sign4(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
checkKeys()
var keys Credentials
if len(cred) == 0 {
keys = newKeys()
} else {
keys = cred[0]
}

// Add the X-Amz-Security-Token header when using STS
if Keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", Keys.SecurityToken)
if keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", keys.SecurityToken)
}

prepareRequestV4(req)
Expand All @@ -63,24 +77,29 @@ func Sign4(req *http.Request) *http.Request {
stringToSign := stringToSignV4(req, hashedCanonReq, meta)

// Task 3
signingKey := signingKeyV4(Keys.SecretAccessKey, meta.date, meta.region, meta.service)
signingKey := signingKeyV4(keys.SecretAccessKey, meta.date, meta.region, meta.service)
signature := signatureV4(signingKey, stringToSign)

req.Header.Set("Authorization", buildAuthHeaderV4(signature, meta))
req.Header.Set("Authorization", buildAuthHeaderV4(signature, meta, keys))

return req
}

// Sign3 signs a request with Signed Signature Version 3.
// If the service you're accessing supports Version 4, use that instead.
func Sign3(req *http.Request) *http.Request {
func Sign3(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
checkKeys()
var keys Credentials
if len(cred) == 0 {
keys = newKeys()
} else {
keys = cred[0]
}

// Add the X-Amz-Security-Token header when using STS
if Keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", Keys.SecurityToken)
if keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", keys.SecurityToken)
}

prepareRequestV3(req)
Expand All @@ -89,34 +108,39 @@ func Sign3(req *http.Request) *http.Request {
stringToSign := stringToSignV3(req)

// Task 2
signature := signatureV3(stringToSign)
signature := signatureV3(stringToSign, keys)

// Task 3
req.Header.Set("X-Amzn-Authorization", buildAuthHeaderV3(signature))
req.Header.Set("X-Amzn-Authorization", buildAuthHeaderV3(signature, keys))

return req
}

// Sign2 signs a request with Signed Signature Version 2.
// If the service you're accessing supports Version 4, use that instead.
func Sign2(req *http.Request) *http.Request {
func Sign2(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
checkKeys()
var keys Credentials
if len(cred) == 0 {
keys = newKeys()
} else {
keys = cred[0]
}

// Add the SecurityToken parameter when using STS
// This must be added before the signature is calculated
if Keys.SecurityToken != "" {
if keys.SecurityToken != "" {
v := url.Values{}
v.Set("SecurityToken", Keys.SecurityToken)
v.Set("SecurityToken", keys.SecurityToken)
augmentRequestQuery(req, v)

}

prepareRequestV2(req)
prepareRequestV2(req, keys)

stringToSign := stringToSignV2(req)
signature := signatureV2(stringToSign)
signature := signatureV2(stringToSign, keys)

values := url.Values{}
values.Set("Signature", signature)
Expand All @@ -128,22 +152,27 @@ func Sign2(req *http.Request) *http.Request {

// SignS3 signs a request bound for Amazon S3 using their custom
// HTTP authentication scheme.
func SignS3(req *http.Request) *http.Request {
func SignS3(req *http.Request, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
checkKeys()
var keys Credentials
if len(cred) == 0 {
keys = newKeys()
} else {
keys = cred[0]
}

// Add the X-Amz-Security-Token header when using STS
if Keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", Keys.SecurityToken)
if keys.SecurityToken != "" {
req.Header.Set("X-Amz-Security-Token", keys.SecurityToken)
}

prepareRequestS3(req)

stringToSign := stringToSignS3(req)
signature := signatureS3(stringToSign)
signature := signatureS3(stringToSign, keys)

authHeader := "AWS " + Keys.AccessKeyID + ":" + signature
authHeader := "AWS " + keys.AccessKeyID + ":" + signature
req.Header.Set("Authorization", authHeader)

return req
Expand All @@ -153,16 +182,21 @@ func SignS3(req *http.Request) *http.Request {
// query string parameters containing credentials and signature. You must
// specify an expiration date for these signed requests. After that date,
// a request signed with this method will be rejected by S3.
func SignS3Url(req *http.Request, expire time.Time) *http.Request {
func SignS3Url(req *http.Request, expire time.Time, cred ...Credentials) *http.Request {
signMutex.Lock()
defer signMutex.Unlock()
checkKeys()
var keys Credentials
if len(cred) == 0 {
keys = newKeys()
} else {
keys = cred[0]
}

stringToSign := stringToSignS3Url("GET", expire, req.URL.Path)
signature := signatureS3(stringToSign)
signature := signatureS3(stringToSign, keys)

qs := req.URL.Query()
qs.Set("AWSAccessKeyId", Keys.AccessKeyID)
qs.Set("AWSAccessKeyId", keys.AccessKeyID)
qs.Set("Signature", signature)
qs.Set("Expires", timeToUnixEpochString(expire))
req.URL.RawQuery = qs.Encode()
Expand Down
41 changes: 39 additions & 2 deletions awsauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,42 @@ func TestSign(t *testing.T) {
So(signedReq.Header.Get("Authorization"), ShouldContainSubstring, ", Signature=")
}
})

var keys Credentials
keys = newKeys()
Convey("Requests to services using existing credentials Version 2 should be signed accordingly", t, func() {
reqs := []*http.Request{
newRequest("GET", "https://ec2.amazonaws.com", url.Values{}),
newRequest("GET", "https://elasticache.amazonaws.com/", url.Values{}),
}
for _, req := range reqs {
signedReq := Sign(req, keys)
So(signedReq.URL.Query().Get("SignatureVersion"), ShouldEqual, "2")
}
})

Convey("Requests to services using existing credentials Version 3 should be signed accordingly", t, func() {
reqs := []*http.Request{
newRequest("GET", "https://route53.amazonaws.com", url.Values{}),
newRequest("GET", "https://email.us-east-1.amazonaws.com/", url.Values{}),
}
for _, req := range reqs {
signedReq := Sign(req, keys)
So(signedReq.Header.Get("X-Amzn-Authorization"), ShouldNotBeBlank)
}
})

Convey("Requests to services using existing credentials Version 4 should be signed accordingly", t, func() {
reqs := []*http.Request{
newRequest("POST", "https://sqs.amazonaws.com/", url.Values{}),
newRequest("GET", "https://iam.amazonaws.com", url.Values{}),
newRequest("GET", "https://s3.amazonaws.com", url.Values{}),
}
for _, req := range reqs {
signedReq := Sign(req, keys)
So(signedReq.Header.Get("Authorization"), ShouldContainSubstring, ", Signature=")
}
})
}

func TestExpiration(t *testing.T) {
Expand All @@ -201,8 +237,9 @@ func TestExpiration(t *testing.T) {
}

func credentialsSet() bool {
checkKeys()
if Keys.AccessKeyID == "" {
var keys Credentials
keys = newKeys()
if keys.AccessKeyID == "" {
return false
} else {
return true
Expand Down
22 changes: 10 additions & 12 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,22 @@ func serviceAndRegion(host string) (service string, region string) {
return
}

func checkKeys() {
if Keys == nil {
Keys = &Credentials{
AccessKeyID: os.Getenv(envAccessKeyID),
SecretAccessKey: os.Getenv(envSecretAccessKey),
SecurityToken: os.Getenv(envSecurityToken),
}
}
func newKeys() (newCredentials Credentials) {

newCredentials.AccessKeyID = os.Getenv(envAccessKeyID)
newCredentials.SecretAccessKey = os.Getenv(envSecretAccessKey)
newCredentials.SecurityToken = os.Getenv(envSecurityToken)

// If there is no Access Key and you are on EC2, get the key from the role
if Keys.AccessKeyID == "" && onEC2() {
Keys = getIAMRoleCredentials()
if newCredentials.AccessKeyID == "" && onEC2() {
newCredentials = *getIAMRoleCredentials()
}

// If the key is expiring, get a new key
if Keys.expired() && onEC2() {
Keys = getIAMRoleCredentials()
if newCredentials.expired() && onEC2() {
newCredentials = *getIAMRoleCredentials()
}
return newCredentials
}

// onEC2 checks to see if the program is running on an EC2 instance.
Expand Down
4 changes: 2 additions & 2 deletions s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"time"
)

func signatureS3(stringToSign string) string {
hashed := hmacSHA1([]byte(Keys.SecretAccessKey), stringToSign)
func signatureS3(stringToSign string, keys Credentials) string {
hashed := hmacSHA1([]byte(keys.SecretAccessKey), stringToSign)
return base64.StdEncoding.EncodeToString(hashed)
}

Expand Down
14 changes: 7 additions & 7 deletions s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestSignatureS3(t *testing.T) {
// (but signed URL requests still utilize a lot of the same functionality)

Convey("Given a GET request to Amazon S3", t, func() {
Keys = testCredS3
keys := *testCredS3
req := test_plainRequestS3()

// Mock time
Expand Down Expand Up @@ -46,13 +46,13 @@ func TestSignatureS3(t *testing.T) {
})

Convey("The final signature string should be exactly correct", func() {
actual := signatureS3(stringToSignS3(req))
actual := signatureS3(stringToSignS3(req), keys)
So(actual, ShouldEqual, "bWq2s1WEIj+Ydj0vQ697zp+IXMU=")
})
})

Convey("Given a GET request for a resource on S3 for query string authentication", t, func() {
Keys = testCredS3
keys := *testCredS3
req, _ := http.NewRequest("GET", "https://johnsmith.s3.amazonaws.com/johnsmith/photos/puppy.jpg", nil)

now = func() time.Time {
Expand All @@ -66,13 +66,13 @@ func TestSignatureS3(t *testing.T) {
})

Convey("The signature of string to sign should be correct", func() {
actual := signatureS3(expectedStringToSignS3Url)
actual := signatureS3(expectedStringToSignS3Url, keys)
So(actual, ShouldEqual, "R2K/+9bbnBIbVDCs7dqlz3XFtBQ=")
})

Convey("The finished signed URL should be correct", func() {
expiry := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
So(SignS3Url(req, expiry).URL.String(), ShouldEqual, expectedSignedS3Url)
So(SignS3Url(req, expiry, keys).URL.String(), ShouldEqual, expectedSignedS3Url)
})
})
}
Expand All @@ -82,10 +82,10 @@ func TestS3STSRequestPreparer(t *testing.T) {
req := test_plainRequestS3()

Convey("And a set of credentials with an STS token", func() {
Keys = testCredS3WithSTS
keys := *testCredS3WithSTS

Convey("It should include an X-Amz-Security-Token when the request is signed", func() {
actualSigned := SignS3(req)
actualSigned := SignS3(req, keys)
actual := actualSigned.Header.Get("X-Amz-Security-Token")

So(actual, ShouldNotBeBlank)
Expand Down
Loading

0 comments on commit 41cdffe

Please sign in to comment.