Skip to content

Commit

Permalink
Merge pull request #28994 from timcharper/2015.8.1-dev
Browse files Browse the repository at this point in the history
add support to s3 for aws role assumption
  • Loading branch information
Mike Place committed Nov 30, 2015
2 parents 3d16434 + e060986 commit 87e4aa4
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 29 deletions.
48 changes: 30 additions & 18 deletions salt/modules/s3.py
Expand Up @@ -19,6 +19,10 @@
s3.service_url: s3.amazonaws.com
A role_arn may also be specified in the configuration::
s3.role_arn: arn:aws:iam::111111111111:role/my-role-to-assume
If a service_url is not specified, the default is s3.amazonaws.com. This
may appear in various documentation as an "endpoint". A comprehensive list
for Amazon S3 may be found at::
Expand Down Expand Up @@ -67,7 +71,7 @@ def __virtual__():


def delete(bucket, path=None, action=None, key=None, keyid=None,
service_url=None, verify_ssl=None, location=None):
service_url=None, verify_ssl=None, location=None, role_arn=None):
'''
Delete a bucket, or delete an object from a bucket.
Expand All @@ -79,8 +83,8 @@ def delete(bucket, path=None, action=None, key=None, keyid=None,
salt myminion s3.delete mybucket remoteobject
'''
key, keyid, service_url, verify_ssl, location = _get_key(
key, keyid, service_url, verify_ssl, location)
key, keyid, service_url, verify_ssl, location, role_arn = _get_key(
key, keyid, service_url, verify_ssl, location, role_arn)

return salt.utils.s3.query(method='DELETE',
bucket=bucket,
Expand All @@ -90,12 +94,13 @@ def delete(bucket, path=None, action=None, key=None, keyid=None,
keyid=keyid,
service_url=service_url,
verify_ssl=verify_ssl,
location=location)
location=location,
role_arn=role_arn)


def get(bucket=None, path=None, return_bin=False, action=None,
local_file=None, key=None, keyid=None, service_url=None,
verify_ssl=None, location=None):
verify_ssl=None, location=None, role_arn=None):
'''
List the contents of a bucket, or return an object from a bucket. Set
return_bin to True in order to retrieve an object wholesale. Otherwise,
Expand Down Expand Up @@ -147,8 +152,8 @@ def get(bucket=None, path=None, return_bin=False, action=None,
salt myminion s3.get mybucket myfile.png action=acl
'''
key, keyid, service_url, verify_ssl, location = _get_key(
key, keyid, service_url, verify_ssl, location)
key, keyid, service_url, verify_ssl, location, role_arn = _get_key(
key, keyid, service_url, verify_ssl, location, role_arn)

return salt.utils.s3.query(method='GET',
bucket=bucket,
Expand All @@ -160,11 +165,12 @@ def get(bucket=None, path=None, return_bin=False, action=None,
keyid=keyid,
service_url=service_url,
verify_ssl=verify_ssl,
location=location)
location=location,
role_arn=role_arn)


def head(bucket, path=None, key=None, keyid=None, service_url=None,
verify_ssl=None, location=None):
verify_ssl=None, location=None, role_arn=None):
'''
Return the metadata for a bucket, or an object in a bucket.
Expand All @@ -175,8 +181,8 @@ def head(bucket, path=None, key=None, keyid=None, service_url=None,
salt myminion s3.head mybucket
salt myminion s3.head mybucket myfile.png
'''
key, keyid, service_url, verify_ssl, location = _get_key(
key, keyid, service_url, verify_ssl, location)
key, keyid, service_url, verify_ssl, location, role_arn = _get_key(
key, keyid, service_url, verify_ssl, location, role_arn)

return salt.utils.s3.query(method='HEAD',
bucket=bucket,
Expand All @@ -186,11 +192,13 @@ def head(bucket, path=None, key=None, keyid=None, service_url=None,
service_url=service_url,
verify_ssl=verify_ssl,
location=location,
full_headers=True)
full_headers=True,
role_arn=role_arn)


def put(bucket, path=None, return_bin=False, action=None, local_file=None,
key=None, keyid=None, service_url=None, verify_ssl=None, location=None):
key=None, keyid=None, service_url=None, verify_ssl=None, location=None,
role_arn=None):
'''
Create a new bucket, or upload an object to a bucket.
Expand All @@ -206,8 +214,8 @@ def put(bucket, path=None, return_bin=False, action=None, local_file=None,
salt myminion s3.put mybucket remotepath local_file=/path/to/file
'''
key, keyid, service_url, verify_ssl, location = _get_key(
key, keyid, service_url, verify_ssl, location)
key, keyid, service_url, verify_ssl, location, role_arn = _get_key(
key, keyid, service_url, verify_ssl, location, role_arn)

return salt.utils.s3.query(method='PUT',
bucket=bucket,
Expand All @@ -219,10 +227,11 @@ def put(bucket, path=None, return_bin=False, action=None, local_file=None,
keyid=keyid,
service_url=service_url,
verify_ssl=verify_ssl,
location=location)
location=location,
role_arn=role_arn)


def _get_key(key, keyid, service_url, verify_ssl, location):
def _get_key(key, keyid, service_url, verify_ssl, location, role_arn):
'''
Examine the keys, and populate as necessary
'''
Expand All @@ -247,4 +256,7 @@ def _get_key(key, keyid, service_url, verify_ssl, location):
if location is None and __salt__['config.option']('s3.location') is not None:
location = __salt__['config.option']('s3.location')

return key, keyid, service_url, verify_ssl, location
if role_arn is None and __salt__['config.option']('s3.role_arn') is not None:
role_arn = __salt__['config.option']('s3.role_arn')

return key, keyid, service_url, verify_ssl, location, role_arn
67 changes: 61 additions & 6 deletions salt/utils/aws.py
Expand Up @@ -14,10 +14,12 @@
import sys
import time
import binascii
import datetime
from datetime import datetime
import hashlib
import hmac
import logging
import salt.config
import re

# Import Salt libs
import salt.utils.xmlutil as xml
Expand Down Expand Up @@ -53,6 +55,7 @@
__Token__ = ''
__Expiration__ = ''
__Location__ = ''
__AssumeCache__ = {}


def creds(provider):
Expand All @@ -70,7 +73,7 @@ def creds(provider):
if provider['id'] == IROLE_CODE or provider['key'] == IROLE_CODE:
# Check to see if we have cache credentials that are still good
if __Expiration__ != '':
timenow = datetime.datetime.utcnow()
timenow = datetime.utcnow()
timestamp = timenow.strftime('%Y-%m-%dT%H:%M:%SZ')
if timestamp < __Expiration__:
# Current timestamp less than expiration fo cached credentials
Expand Down Expand Up @@ -114,7 +117,7 @@ def sig2(method, endpoint, params, provider, aws_api_version):
http://docs.aws.amazon.com/general/latest/gr/signature-version-2.html
'''
timenow = datetime.datetime.utcnow()
timenow = datetime.utcnow()
timestamp = timenow.strftime('%Y-%m-%dT%H:%M:%SZ')

# Retrieve access credentials from meta-data, or use provided
Expand Down Expand Up @@ -147,9 +150,58 @@ def sig2(method, endpoint, params, provider, aws_api_version):
return params_with_headers


def assumed_creds(prov_dict, role_arn, location=None):
valid_session_name_re = re.compile("[^a-z0-9A-Z+=,.@-]")

now = (datetime.utcnow() - datetime(1970, 1, 1)).total_seconds()
for key, creds in __AssumeCache__.items():
if (creds["Expiration"] - now) <= 120:
__AssumeCache__.delete(key)

if role_arn in __AssumeCache__:
c = __AssumeCache__[role_arn]
return c["AccessKeyId"], c["SecretAccessKey"], c["SessionToken"]

version = "2011-06-15"
session_name = valid_session_name_re.sub('', salt.config.get_id({"root_dir": None})[0])[0:63]

headers, requesturl = sig4(
'GET',
'sts.amazonaws.com',
params={
"Version": version,
"Action": "AssumeRole",
"RoleSessionName": session_name,
"RoleArn": role_arn,
"Policy": '{"Version":"2012-10-17","Statement":[{"Sid":"Stmt1", "Effect":"Allow","Action":"*","Resource":"*"}]}',
"DurationSeconds": "3600"
},
aws_api_version=version,
data='',
uri='/',
prov_dict=prov_dict,
product='sts',
location=location,
requesturl="https://sts.amazonaws.com/"
)
headers["Accept"] = "application/json"
result = requests.request('GET', requesturl, headers=headers,
data='',
verify=True)

if result.status_code >= 400:
LOG.info('AssumeRole response: {0}'.format(result.content))
result.raise_for_status()
resp = result.json()

data = resp["AssumeRoleResponse"]["AssumeRoleResult"]["Credentials"]
__AssumeCache__[role_arn] = data
return data["AccessKeyId"], data["SecretAccessKey"], data["SessionToken"]


def sig4(method, endpoint, params, prov_dict,
aws_api_version=DEFAULT_AWS_API_VERSION, location=None,
product='ec2', uri='/', requesturl=None, data=''):
product='ec2', uri='/', requesturl=None, data='', role_arn=None):
'''
Sign a query against AWS services using Signature Version 4 Signing
Process. This is documented at:
Expand All @@ -158,10 +210,13 @@ def sig4(method, endpoint, params, prov_dict,
http://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
'''
timenow = datetime.datetime.utcnow()
timenow = datetime.utcnow()

# Retrieve access credentials from meta-data, or use provided
access_key_id, secret_access_key, token = creds(prov_dict)
if role_arn is None:
access_key_id, secret_access_key, token = creds(prov_dict)
else:
access_key_id, secret_access_key, token = assumed_creds(prov_dict, role_arn, location=location)

if location is None:
location = get_region_from_metadata()
Expand Down
3 changes: 2 additions & 1 deletion salt/utils/s3.py
Expand Up @@ -28,7 +28,7 @@
def query(key, keyid, method='GET', params=None, headers=None,
requesturl=None, return_url=False, bucket=None, service_url=None,
path='', return_bin=False, action=None, local_file=None,
verify_ssl=True, location=None, full_headers=False):
verify_ssl=True, location=None, full_headers=False, role_arn=None):
'''
Perform a query against an S3-like API. This function requires that a
secret key and the id for that key are passed in. For instance:
Expand Down Expand Up @@ -106,6 +106,7 @@ def query(key, keyid, method='GET', params=None, headers=None,
data=data,
uri='/{0}'.format(path),
prov_dict={'id': keyid, 'key': key},
role_arn=role_arn,
location=location,
product='s3',
requesturl=requesturl,
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/modules/s3_test.py
Expand Up @@ -33,7 +33,7 @@ def test_delete(self):
'''
with patch.object(s3, '_get_key',
return_value=('key', 'keyid', 'service_url',
'verify_ssl', 'location')):
'verify_ssl', 'location', 'role_arn')):
with patch.object(salt.utils.s3, 'query', return_value='A'):
self.assertEqual(s3.delete('bucket'), 'A')

Expand All @@ -44,7 +44,7 @@ def test_get(self):
'''
with patch.object(s3, '_get_key',
return_value=('key', 'keyid', 'service_url',
'verify_ssl', 'location')):
'verify_ssl', 'location', 'role_arn')):
with patch.object(salt.utils.s3, 'query', return_value='A'):
self.assertEqual(s3.get(), 'A')

Expand All @@ -54,7 +54,7 @@ def test_head(self):
'''
with patch.object(s3, '_get_key',
return_value=('key', 'keyid', 'service_url',
'verify_ssl', 'location')):
'verify_ssl', 'location', 'role_arn')):
with patch.object(salt.utils.s3, 'query', return_value='A'):
self.assertEqual(s3.head('bucket'), 'A')

Expand All @@ -64,7 +64,7 @@ def test_put(self):
'''
with patch.object(s3, '_get_key',
return_value=('key', 'keyid', 'service_url',
'verify_ssl', 'location')):
'verify_ssl', 'location', 'role_arn')):
with patch.object(salt.utils.s3, 'query', return_value='A'):
self.assertEqual(s3.put('bucket'), 'A')

Expand Down

0 comments on commit 87e4aa4

Please sign in to comment.