From 97cfe272e885d55c83f45a3edcee311e8f361864 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 13 May 2024 14:20:02 +0200 Subject: [PATCH] Use service connector for boto session if possible (#2682) * Use service connector for boto session if possible * Move logic to aws service connector * Fix log message * Improve error message * Docstring --- .../aws_container_registry.py | 52 ++++++++++++++++--- .../aws_service_connector.py | 31 +++++++++++ 2 files changed, 75 insertions(+), 8 deletions(-) diff --git a/src/zenml/integrations/aws/container_registries/aws_container_registry.py b/src/zenml/integrations/aws/container_registries/aws_container_registry.py index 0867570794..075312f6b8 100644 --- a/src/zenml/integrations/aws/container_registries/aws_container_registry.py +++ b/src/zenml/integrations/aws/container_registries/aws_container_registry.py @@ -17,15 +17,23 @@ from typing import List, Optional, cast import boto3 +from botocore.client import BaseClient from botocore.exceptions import BotoCoreError, ClientError +from zenml.client import Client from zenml.container_registries.base_container_registry import ( BaseContainerRegistry, ) from zenml.integrations.aws.flavors.aws_container_registry_flavor import ( AWSContainerRegistryConfig, ) +from zenml.integrations.aws.service_connectors import ( + AWSServiceConnector, +) from zenml.logger import get_logger +from zenml.service_connectors.service_connector_registry import ( + service_connector_registry, +) logger = get_logger(__name__) @@ -61,6 +69,35 @@ def _get_region(self) -> str: return match.group(1) + def _get_ecr_client(self) -> BaseClient: + """Get an ECR client. + + If this container registry is configured with an AWS service connector, + we use that connector to create an authenticated client. Otherwise + local AWS credentials will be used. + + Returns: + An ECR client. + """ + if self.connector: + try: + model = Client().get_service_connector(self.connector) + connector = service_connector_registry.instantiate_connector( + model=model + ) + assert isinstance(connector, AWSServiceConnector) + return connector.get_ecr_client() + except Exception as e: + logger.error( + "Unable to get ECR client from service connector: %s", + str(e), + ) + + return boto3.Session().client( + "ecr", + region_name=self._get_region(), + ) + def prepare_image_push(self, image_name: str) -> None: """Logs warning message if trying to push an image for which no repository exists. @@ -76,10 +113,9 @@ def prepare_image_push(self, image_name: str) -> None: raise ValueError(f"Invalid docker image name '{image_name}'.") repo_name = match.group(1) + client = self._get_ecr_client() try: - response = boto3.client( - "ecr", region_name=self._get_region() - ).describe_repositories() + response = client.describe_repositories() except (BotoCoreError, ClientError): logger.warning( "Amazon ECR requires you to create a repository before you can " @@ -123,9 +159,9 @@ def post_registration_message(self) -> Optional[str]: """ return ( "Amazon ECR requires you to create a repository before you can " - "push an image to it. If you want to for example run a pipeline " - "using our Kubeflow orchestrator, ZenML will automatically build a " - f"docker image called `{self.config.uri}/zenml-kubeflow:` " - f"and try to push it. This will fail unless you create the " - f"repository `zenml-kubeflow` inside your amazon registry." + "push an image to it. If you want to for run a pipeline " + "using a remote orchestrator, ZenML will automatically build a " + f"docker image called `{self.config.uri}/zenml:` " + f"and try to push it. This will fail unless you create a " + f"repository called `zenml` inside your Amazon ECR." ) diff --git a/src/zenml/integrations/aws/service_connectors/aws_service_connector.py b/src/zenml/integrations/aws/service_connectors/aws_service_connector.py index efef014652..851bf3f07a 100644 --- a/src/zenml/integrations/aws/service_connectors/aws_service_connector.py +++ b/src/zenml/integrations/aws/service_connectors/aws_service_connector.py @@ -731,6 +731,37 @@ def get_boto3_session( self._session_cache[key] = (session, expires_at) return session, expires_at + def get_ecr_client(self) -> BaseClient: + """Get an ECR client. + + Raises: + ValueError: If the service connector is not able to instantiate an + ECR client. + + Returns: + An ECR client. + """ + if self.resource_type and self.resource_type not in { + AWS_RESOURCE_TYPE, + DOCKER_REGISTRY_RESOURCE_TYPE, + }: + raise ValueError( + f"Unable to instantiate ECR client for a connector that is " + f"configured to provide access to a '{self.resource_type}' " + "resource type." + ) + + session, _ = self.get_boto3_session( + auth_method=self.auth_method, + resource_type=DOCKER_REGISTRY_RESOURCE_TYPE, + resource_id=self.config.region, + ) + return session.client( + "ecr", + region_name=self.config.region, + endpoint_url=self.config.endpoint_url, + ) + def _get_iam_policy( self, region_id: str,