Skip to content

Commit

Permalink
Merge pull request #221 from thedrow/topic/filters
Browse files Browse the repository at this point in the history
More tagging filters and refactorings
  • Loading branch information
spulec committed Oct 1, 2014
2 parents 27ef345 + efa687f commit 0a99aae
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 53 deletions.
98 changes: 46 additions & 52 deletions moto/ec2/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import unicode_literals
import six
import copy
import itertools
from collections import defaultdict

import six
import boto
from boto.ec2.instance import Instance as BotoInstance, Reservation
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
Expand Down Expand Up @@ -70,7 +70,7 @@
random_volume_id,
random_vpc_id,
random_vpc_peering_connection_id,
)
generic_filter)


class InstanceState(object):
Expand All @@ -88,11 +88,19 @@ def get_filter_value(self, filter_name):
tags = self.get_tags()

if filter_name.startswith('tag:'):
tagname = filter_name.split('tag:')[1]
tagname = filter_name.replace('tag:', '', 1)
for tag in tags:
if tag['key'] == tagname:
return tag['value']

return ''

if filter_name == 'tag-key':
return [tag['key'] for tag in tags]

if filter_name == 'tag-value':
return [tag['value'] for tag in tags]


class NetworkInterface(object):
def __init__(self, subnet, private_ip_address, device_index=0, public_ip_auto_assign=True, group_ids=None):
Expand Down Expand Up @@ -615,13 +623,14 @@ def get_filter_value(self, filter_name):
return self.id
elif filter_name == 'state':
return self.state
elif filter_name.startswith('tag:'):
tag_name = filter_name.replace('tag:', '', 1)
tags = dict((tag['key'], tag['value']) for tag in self.get_tags())
return tags.get(tag_name)
else:

filter_value = super(Ami, self).get_filter_value(filter_name)

if filter_value is None:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeImages".format(filter_name))

return filter_value


class AmiBackend(object):
def __init__(self):
Expand All @@ -639,9 +648,8 @@ def create_image(self, instance_id, name, description):
def describe_images(self, ami_ids=(), filters=None):
if filters:
images = self.amis.values()
for (_filter, _filter_value) in filters.items():
images = [ ami for ami in images if ami.get_filter_value(_filter) in _filter_value ]
return images

return generic_filter(filters, images)
else:
images = []
for ami_id in ami_ids:
Expand Down Expand Up @@ -1159,11 +1167,8 @@ def get_filter_value(self, filter_name):

filter_value = super(VPC, self).get_filter_value(filter_name)

if not filter_value:
msg = "The filter '{0}' for DescribeVPCs has not been" \
" implemented in Moto yet. Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(filter_name)
raise NotImplementedError(msg)
if filter_value is None:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeVPCs".format(filter_name))

return filter_value

Expand Down Expand Up @@ -1198,11 +1203,7 @@ def get_all_vpcs(self, vpc_ids=None, filters=None):
else:
vpcs = self.vpcs.values()

if filters:
for (_filter, _filter_value) in filters.items():
vpcs = [ vpc for vpc in vpcs if vpc.get_filter_value(_filter) in _filter_value ]

return vpcs
return generic_filter(filters, vpcs)

def delete_vpc(self, vpc_id):
# Delete route table if only main route table remains.
Expand Down Expand Up @@ -1346,11 +1347,13 @@ def get_filter_value(self, filter_name):
return self.vpc_id
elif filter_name == 'subnet-id':
return self.id
else:
msg = "The filter '{0}' for DescribeSubnets has not been" \
" implemented in Moto yet. Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(filter_name)
raise NotImplementedError(msg)

filter_value = super(Subnet, self).get_filter_value(filter_name)

if filter_value is None:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSubnets".format(filter_name))

return filter_value


class SubnetBackend(object):
Expand All @@ -1374,11 +1377,7 @@ def create_subnet(self, vpc_id, cidr_block):
def get_all_subnets(self, filters=None):
subnets = self.subnets.values()

if filters:
for (_filter, _filter_value) in filters.items():
subnets = [ subnet for subnet in subnets if subnet.get_filter_value(_filter) in _filter_value ]

return subnets
return generic_filter(filters, subnets)

def delete_subnet(self, subnet_id):
deleted = self.subnets.pop(subnet_id, None)
Expand Down Expand Up @@ -1417,7 +1416,7 @@ def create_subnet_association(self, route_table_id, subnet_id):
return subnet_association


class RouteTable(object):
class RouteTable(TaggedEC2Instance):
def __init__(self, route_table_id, vpc_id, main=False):
self.id = route_table_id
self.vpc_id = vpc_id
Expand Down Expand Up @@ -1450,11 +1449,13 @@ def get_filter_value(self, filter_name):
return 'false'
elif filter_name == "vpc-id":
return self.vpc_id
else:
msg = "The filter '{0}' for DescribeRouteTables has not been" \
" implemented in Moto yet. Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(filter_name)
raise NotImplementedError(msg)

filter_value = super(RouteTable, self).get_filter_value(filter_name)

if filter_value is None:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeRouteTables".format(filter_name))

return filter_value


class RouteTableBackend(object):
Expand Down Expand Up @@ -1488,11 +1489,7 @@ def get_all_route_tables(self, route_table_ids=None, filters=None):
invalid_id = list(set(route_table_ids).difference(set([route_table.id for route_table in route_tables])))[0]
raise InvalidRouteTableIdError(invalid_id)

if filters:
for (_filter, _filter_value) in filters.items():
route_tables = [ route_table for route_table in route_tables if route_table.get_filter_value(_filter) in _filter_value ]

return route_tables
return generic_filter(filters, route_tables)

def delete_route_table(self, route_table_id):
deleted = self.route_tables.pop(route_table_id, None)
Expand Down Expand Up @@ -1718,13 +1715,14 @@ def __init__(self, spot_request_id, price, image_id, type, valid_from,
def get_filter_value(self, filter_name):
if filter_name == 'state':
return self.state
elif filter_name.startswith('tag:'):
tag_name = filter_name.replace('tag:', '', 1)
tags = dict((tag['key'], tag['value']) for tag in self.get_tags())
return tags.get(tag_name)
else:

filter_value = super(SpotInstanceRequest, self).get_filter_value(filter_name)

if filter_value is None:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSpotInstanceRequests".format(filter_name))

return filter_value


@six.add_metaclass(Model)
class SpotRequestBackend(object):
Expand Down Expand Up @@ -1754,11 +1752,7 @@ def request_spot_instances(self, price, image_id, count, type, valid_from,
def describe_spot_instance_requests(self, filters=None):
requests = self.spot_instance_requests.values()

if filters:
for (_filter, _filter_value) in filters.items():
requests = [ request for request in requests if request.get_filter_value(_filter) in _filter_value ]

return requests
return generic_filter(filters, requests)

def cancel_spot_instance_requests(self, request_ids):
requests = []
Expand Down
21 changes: 21 additions & 0 deletions moto/ec2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,27 @@ def filter_reservations(reservations, filter_dict):
return result


def is_filter_matching(obj, filter, filter_value):
value = obj.get_filter_value(filter)

if isinstance(value, six.string_types):
return value in filter_value

try:
value = set(value)
return (value and value.issubset(filter_value)) or value.issuperset(filter_value)
except TypeError:
return value in filter_value


def generic_filter(filters, objects):
if filters:
for (_filter, _filter_value) in filters.items():
objects = [obj for obj in objects if is_filter_matching(obj, _filter, _filter_value)]

return objects


# not really random ( http://xkcd.com/221/ )
def random_key_pair():
return {
Expand Down
81 changes: 80 additions & 1 deletion tests/test_ec2/test_vpcs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import unicode_literals
# Ensure 'assert_raises' context manager support for Python 2.6
#import tests.backport_assert_raises
import tests.backport_assert_raises
from nose.tools import assert_raises

import boto
Expand Down Expand Up @@ -128,4 +128,83 @@ def test_vpc_get_by_tag():
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)


@mock_ec2
def test_vpc_get_by_tag_key_superset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")

vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpc3.add_tag('Key', 'TestVPC2')

vpcs = conn.get_all_vpcs(filters={'tag-key': 'Name'})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)


@mock_ec2
def test_vpc_get_by_tag_key_subset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")

vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpc3.add_tag('Test', 'TestVPC2')

vpcs = conn.get_all_vpcs(filters={'tag-key': ['Name', 'Key']})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)


@mock_ec2
def test_vpc_get_by_tag_value_superset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")

vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpc3.add_tag('Key', 'TestVPC2')

vpcs = conn.get_all_vpcs(filters={'tag-value': 'TestVPC'})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)


@mock_ec2
def test_vpc_get_by_tag_value_subset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")

vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')

vpcs = conn.get_all_vpcs(filters={'tag-value': ['TestVPC', 'TestVPC2']})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)

0 comments on commit 0a99aae

Please sign in to comment.