Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoscaler][AWS] Make sure subnets belong to same VPC as user-specified security groups #13558

Merged
merged 10 commits into from
Jan 28, 2021
56 changes: 53 additions & 3 deletions python/ray/autoscaler/_private/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import os
import time
from typing import Any, Dict, List
import logging

import boto3
Expand Down Expand Up @@ -357,9 +358,23 @@ def _configure_subnet(config):
ec2 = _resource("ec2", config)
use_internal_ips = config["provider"].get("use_internal_ips", False)

# If head or worker security group is specified, filter down to subnets
# belonging to the same VPC as the security group.
sg_ids = (config["head_node"].get("SecurityGroupIds", []) +
config["worker_nodes"].get("SecurityGroupIds", []))
if sg_ids:
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
else:
vpc_id_of_sg = None

try:
candidate_subnets = ec2.subnets.all()
if vpc_id_of_sg:
candidate_subnets = [
s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
]
subnets = sorted(
(s for s in ec2.subnets.all() if s.state == "available" and (
(s for s in candidate_subnets if s.state == "available" and (
use_internal_ips or s.map_public_ip_on_launch)),
reverse=True, # sort from Z-A
key=lambda subnet: subnet.availability_zone)
Expand Down Expand Up @@ -414,6 +429,34 @@ def _configure_subnet(config):
return config


def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Return the id of the one VPC shared by the given security groups.

    Args:
        sg_ids: Security group ids taken from the cluster config. Duplicate
            ids are collapsed before the lookup.
        config: Cluster config, passed to ``_resource`` to obtain the EC2
            resource.

    Returns:
        The VPC id common to every security group found for ``sg_ids``.

    Raises:
        AssertionError: If the security groups found belong to more than one
            VPC, or if no security group matches any of the provided ids.
    """
    # De-duplicate, then sort so the "group-id" filter values are always sent
    # in the same order (supports more precise stub-based boto3 unit testing).
    sg_ids = sorted(set(sg_ids))

    ec2 = _resource("ec2", config)
    filters = [{"Name": "group-id", "Values": sg_ids}]
    security_groups = ec2.security_groups.filter(Filters=filters)
    vpc_ids = list({sg.vpc_id for sg in security_groups})

    multiple_vpc_msg = ("All security groups specified in the cluster config "
                        "should belong to the same VPC.")
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    # Review note (DmitriGekhtman, Jan 22, 2021): if the user specifies
    # security groups for both head and worker, and those groups belong to
    # different VPCs, this raises an error. Supporting different VPCs for
    # head and workers would match the behavior for subnets (#8374) but is
    # not attempted here; going further down that path would eventually mean
    # supporting a different VPC for each node type in a multi node-type
    # config.

    no_sg_msg = ("Failed to detect a security group with id equal to any of "
                 "the configured SecurityGroupIds.")
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]


def _configure_security_group(config):
_set_config_info(
head_security_group_src="config", workers_security_group_src="config")
Expand Down Expand Up @@ -566,6 +609,13 @@ def _create_security_group(config, vpc_id, group_name):

def _upsert_security_group_rules(conf, security_groups):
sgids = {sg.id for sg in security_groups.values()}

# Update sgids to include user-specified security groups.
# This is necessary if the user specifies the head node type's security
# groups but not the worker's, or vice-versa.
for node_type in NODE_KIND_CONFIG_KEYS.values():
sgids.update(conf[node_type].get("SecurityGroupIds", []))

# sort security group items for deterministic inbound rule config order
# (mainly supports more precise stub-based boto3 unit testing)
for node_type, sg in sorted(security_groups.items()):
Expand All @@ -583,7 +633,7 @@ def _update_inbound_rules(target_security_group, sgids, config):


def _create_default_inbound_rules(sgids, extended_rules=[]):
intracluster_rules = _create_default_instracluster_inbound_rules(sgids)
intracluster_rules = _create_default_intracluster_inbound_rules(sgids)
ssh_rules = _create_default_ssh_inbound_rules()
merged_rules = itertools.chain(
intracluster_rules,
Expand All @@ -593,7 +643,7 @@ def _create_default_inbound_rules(sgids, extended_rules=[]):
return list(merged_rules)


def _create_default_instracluster_inbound_rules(intracluster_sgids):
def _create_default_intracluster_inbound_rules(intracluster_sgids):
return [{
"FromPort": -1,
"ToPort": -1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Example cluster config with user-specified security groups for both the
# head node and the worker nodes.
cluster_name: sg

max_workers: 1

provider:
    type: aws
    region: us-west-2
    availability_zone: us-west-2a

auth:
    ssh_user: ubuntu

# If required, head and worker nodes can exist on subnets in different VPCs and
# communicate via VPC peering.

# VPC peering overview: https://docs.aws.amazon.com/vpc/latest/userguide/vpc-peering.html.
# Setup VPC peering: https://docs.aws.amazon.com/vpc/latest/peering/create-vpc-peering-connection.html.
# Configure VPC peering route tables: https://docs.aws.amazon.com/vpc/latest/peering/vpc-peering-routing.html.

# To enable external SSH connectivity, you should also ensure that your VPC
# is configured to assign public IPv4 addresses to every EC2 instance
# assigned to it.
head_node:
    SecurityGroupIds:
        - sg-1234abcd  # Replace with an actual security group id.

worker_nodes:
    SecurityGroupIds:
        - sg-1234abcd  # Replace with an actual security group id.


20 changes: 20 additions & 0 deletions python/ray/tests/aws/test_autoscaler_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub,
ec2_client_stub.assert_no_pending_responses()


def test_subnet_given_head_and_worker_sg(iam_client_stub, ec2_client_stub):
    """Bootstrapping with head and worker security groups set should fill in
    only the subnet that lives in the security group's VPC."""
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # EC2 reports one security group, then a thousand subnets that are each
    # in a different VPC.
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    config = helpers.bootstrap_aws_example_config_file(
        "example-head-and-worker-security-group.yaml")

    # Both node kinds must end up with just the single subnet in the right
    # VPC.
    expected_subnet_ids = [DEFAULT_SUBNET["SubnetId"]]
    for node_key in ("head_node", "worker_nodes"):
        assert config[node_key]["SubnetIds"] == expected_subnet_ids

    # Every stubbed response must have been consumed.
    for client_stub in (iam_client_stub, ec2_client_stub):
        client_stub.assert_no_pending_responses()


if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))
13 changes: 13 additions & 0 deletions python/ray/tests/aws/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@
"VpcId": "vpc-0000000",
}


def subnet_in_vpc(vpc_num):
    """Build a shallow copy of DEFAULT_SUBNET placed in a numbered VPC.

    The copy's "VpcId" is "vpc-" followed by vpc_num zero-padded to seven
    digits; DEFAULT_SUBNET itself is left unmodified.
    """
    vpc_id = "vpc-{:07d}".format(vpc_num)
    subnet_copy = copy.copy(DEFAULT_SUBNET)
    subnet_copy["VpcId"] = vpc_id
    return subnet_copy


# 999 copies of DEFAULT_SUBNET with VPC ids vpc-0000001 .. vpc-0000999,
# followed by DEFAULT_SUBNET itself as the last entry.
A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS = [
    *(subnet_in_vpc(n) for n in range(1, 1000)),
    DEFAULT_SUBNET,
]

# Secondary EC2 subnet to expose to tests as required.
AUX_SUBNET = {
"AvailabilityZone": "us-west-2a",
Expand Down
21 changes: 20 additions & 1 deletion python/ray/tests/aws/utils/stubs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ray
from ray.tests.aws.utils.mocks import mock_path_exists_key_pair
from ray.tests.aws.utils.constants import DEFAULT_INSTANCE_PROFILE, \
DEFAULT_KEY_PAIR, DEFAULT_SUBNET
DEFAULT_KEY_PAIR, DEFAULT_SUBNET, A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS

from unittest import mock

Expand Down Expand Up @@ -41,6 +41,13 @@ def configure_subnet_default(ec2_client_stub):
service_response={"Subnets": [DEFAULT_SUBNET]})


def describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub):
    """Add a ``describe_subnets`` response to the EC2 stub that lists every
    subnet in A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS.

    The stubbed call is expected to be made with no request parameters.
    """
    response = {"Subnets": A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS}
    ec2_client_stub.add_response(
        "describe_subnets", expected_params={}, service_response=response)


def skip_to_configure_sg(ec2_client_stub, iam_client_stub):
configure_iam_role_default(iam_client_stub)
configure_key_pair_default(ec2_client_stub)
Expand All @@ -66,6 +73,18 @@ def describe_no_security_groups(ec2_client_stub):
service_response={})


def describe_a_security_group(ec2_client_stub, security_group):
    """Add a ``describe_security_groups`` response to the EC2 stub that
    returns exactly ``security_group``.

    The stubbed call is expected to filter on "group-id" with the single
    value ``security_group["GroupId"]``.
    """
    group_id_filter = {
        "Name": "group-id",
        "Values": [security_group["GroupId"]],
    }
    ec2_client_stub.add_response(
        "describe_security_groups",
        expected_params={"Filters": [group_id_filter]},
        service_response={"SecurityGroups": [security_group]})


def create_sg_echo(ec2_client_stub, security_group):
ec2_client_stub.add_response(
"create_security_group",
Expand Down