46 changes: 38 additions & 8 deletions model-engine/model_engine_server/api/dependencies.py
@@ -64,6 +64,9 @@
    CeleryTaskQueueGateway,
    DatadogMonitoringMetricsGateway,
    FakeMonitoringMetricsGateway,
+    GCSFileStorageGateway,
+    GCSFilesystemGateway,
+    GCSLLMArtifactGateway,
    LiveAsyncModelEndpointInferenceGateway,
    LiveBatchJobOrchestrationGateway,
    LiveBatchJobProgressGateway,
@@ -100,6 +103,9 @@
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
    QueueEndpointResourceDelegate,
)
+from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import (
+    RedisQueueEndpointResourceDelegate,
+)
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
    SQSQueueEndpointResourceDelegate,
)
@@ -115,6 +121,9 @@
    DbTriggerRepository,
    ECRDockerRepository,
    FakeDockerRepository,
+    GARDockerRepository,
+    GCSFileLLMFineTuneEventsRepository,
+    GCSFileLLMFineTuneRepository,
    LiveTokenizerRepository,
    LLMFineTuneRepository,
    OnPremDockerRepository,
@@ -224,13 +233,18 @@ def _get_external_interfaces(
        read_only=read_only,
    )

+    redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool())
+
    queue_delegate: QueueEndpointResourceDelegate
    if CIRCLECI:
        queue_delegate = FakeQueueEndpointResourceDelegate()
    elif infra_config().cloud_provider == "onprem":
        queue_delegate = OnPremQueueEndpointResourceDelegate()
    elif infra_config().cloud_provider == "azure":
        queue_delegate = ASBQueueEndpointResourceDelegate()
+    elif infra_config().cloud_provider == "gcp":
+        # GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
+        queue_delegate = RedisQueueEndpointResourceDelegate(redis_client=redis_client)
    else:
        queue_delegate = SQSQueueEndpointResourceDelegate(
            sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -245,13 +259,13 @@
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().celery_broker_type_redis:
elif infra_config().cloud_provider == "gcp" or infra_config().celery_broker_type_redis:
# GCP uses Redis (Memorystore) for Celery broker
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
else:
inference_task_queue_gateway = sqs_task_queue_gateway
infra_task_queue_gateway = sqs_task_queue_gateway
redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool())
inference_autoscaling_metrics_gateway = (
ASBInferenceAutoscalingMetricsGateway()
if infra_config().cloud_provider == "azure"
@@ -286,6 +300,9 @@ def _get_external_interfaces(
if infra_config().cloud_provider == "azure":
filesystem_gateway = ABSFilesystemGateway()
llm_artifact_gateway = ABSLLMArtifactGateway()
elif infra_config().cloud_provider == "gcp":
filesystem_gateway = GCSFilesystemGateway()
llm_artifact_gateway = GCSLLMArtifactGateway()
else:
# AWS uses S3, on-prem uses MinIO (S3-compatible)
filesystem_gateway = S3FilesystemGateway()
@@ -337,6 +354,11 @@ def _get_external_interfaces(
if infra_config().cloud_provider == "azure":
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(file_path=file_path)
llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
elif infra_config().cloud_provider == "gcp":
llm_fine_tune_repository = GCSFileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_events_repository = GCSFileLLMFineTuneEventsRepository()
else:
# AWS uses S3, on-prem uses MinIO (S3-compatible)
llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path)
@@ -354,6 +376,8 @@ def _get_external_interfaces(
    file_storage_gateway: FileStorageGateway
    if infra_config().cloud_provider == "azure":
        file_storage_gateway = ABSFileStorageGateway()
+    elif infra_config().cloud_provider == "gcp":
+        file_storage_gateway = GCSFileStorageGateway()
    else:
        # AWS uses S3, on-prem uses MinIO (S3-compatible)
        file_storage_gateway = S3FileStorageGateway()
@@ -365,6 +389,8 @@
        docker_repository = OnPremDockerRepository()
    elif infra_config().cloud_provider == "azure":
        docker_repository = ACRDockerRepository()
+    elif infra_config().cloud_provider == "gcp":
+        docker_repository = GARDockerRepository()
    else:
        docker_repository = ECRDockerRepository()

@@ -417,11 +443,13 @@ async def get_external_interfaces():
    try:
        from plugins.dependencies import get_external_interfaces as get_custom_external_interfaces

-        yield get_custom_external_interfaces()
+        ei = get_custom_external_interfaces()
    except ModuleNotFoundError:
-        yield get_default_external_interfaces()
+        ei = get_default_external_interfaces()
+    try:
+        yield ei
    finally:
-        pass
+        await ei.file_storage_gateway.close()


async def get_external_interfaces_read_only():
Expand All @@ -430,11 +458,13 @@ async def get_external_interfaces_read_only():
get_external_interfaces_read_only as get_custom_external_interfaces_read_only,
)

yield get_custom_external_interfaces_read_only()
ei = get_custom_external_interfaces_read_only()
except ModuleNotFoundError:
yield get_default_external_interfaces_read_only()
ei = get_default_external_interfaces_read_only()
try:
yield ei
finally:
pass
await ei.file_storage_gateway.close()


def get_default_auth_repository() -> AuthenticationRepository:
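The rewritten generators above are what make the new close() hook fire: FastAPI drives dependency generators, advancing to the yield to hand the interfaces to the endpoint and resuming them after the response so the finally block runs. A minimal sketch of how a route would consume one of these dependencies (the endpoint itself is illustrative, not part of this diff):

from fastapi import Depends, FastAPI

app = FastAPI()

# Hypothetical endpoint; only the Depends() wiring mirrors the real app.
@app.get("/v1/files")
async def list_files(ei=Depends(get_external_interfaces_read_only)):
    # `ei` is the yielded interfaces object. After this handler returns,
    # FastAPI resumes the generator, so the finally block awaits
    # ei.file_storage_gateway.close().
    files = await ei.file_storage_gateway.list_files(owner="demo-owner")
    return {"files": [f.id for f in files]}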
@@ -43,6 +43,7 @@ def excluded_namespaces():


ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master"
+GCP_MEMORYSTORE_REDIS_BROKER = "redis-gcp-memorystore-message-broker-master"
SQS_BROKER = "sqs-message-broker-master"
SERVICEBUS_BROKER = "servicebus-message-broker-master"

@@ -589,6 +590,8 @@ async def main():

    BROKER_NAME_TO_CLASS = {
        ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True),
+        # GCP Memorystore also doesn't support CONFIG GET
+        GCP_MEMORYSTORE_REDIS_BROKER: RedisBroker(use_elasticache=True),
        SQS_BROKER: SQSBroker(),
        SERVICEBUS_BROKER: ASBBroker(),
    }
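Reusing RedisBroker(use_elasticache=True) for Memorystore is deliberate: like ElastiCache, Memorystore rejects admin commands such as CONFIG GET, so the autoscaler must derive queue depth from ordinary data commands. A rough sketch of the idea with plain redis-py (not the autoscaler's actual implementation):

import redis


def celery_queue_depth(client: redis.Redis, queue_name: str) -> int:
    # Celery's Redis transport keeps ready tasks in a plain list named
    # after the queue, so LLEN reports the backlog. This works on both
    # ElastiCache and Memorystore, where CONFIG GET raises an error.
    return client.llen(queue_name)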
@@ -92,3 +92,7 @@ async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
        The content of the file, or None if it does not exist.
        """
        pass

+    async def close(self) -> None:
+        """Release any resources held by this gateway. No-op by default."""
+        pass
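Adding close() to the ABC as a concrete no-op, rather than a new abstract method, keeps the existing S3 and ABS gateways source-compatible while letting the GCS gateway override it. Callers can then shut any backend down uniformly; a small sketch:

async def shutdown_gateway(gateway: FileStorageGateway) -> None:
    # Safe for every implementation: S3/ABS inherit the no-op default,
    # while GCSFileStorageGateway overrides close() to release its
    # underlying HTTP session.
    await gateway.close()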
6 changes: 6 additions & 0 deletions model-engine/model_engine_server/infra/gateways/__init__.py
@@ -10,6 +10,9 @@
from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway
from .fake_model_primitive_gateway import FakeModelPrimitiveGateway
from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway
+from .gcs_file_storage_gateway import GCSFileStorageGateway
+from .gcs_filesystem_gateway import GCSFilesystemGateway
+from .gcs_llm_artifact_gateway import GCSLLMArtifactGateway
from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway
from .live_batch_job_orchestration_gateway import LiveBatchJobOrchestrationGateway
from .live_batch_job_progress_gateway import LiveBatchJobProgressGateway
@@ -37,6 +40,9 @@
"DatadogMonitoringMetricsGateway",
"FakeModelPrimitiveGateway",
"FakeMonitoringMetricsGateway",
"GCSFileStorageGateway",
"GCSFilesystemGateway",
"GCSLLMArtifactGateway",
"LiveAsyncModelEndpointInferenceGateway",
"LiveBatchJobOrchestrationGateway",
"LiveBatchJobProgressGateway",
@@ -0,0 +1,113 @@
import asyncio
import os
from datetime import timedelta
from typing import List, Optional

from gcloud.aio.storage import Storage
from model_engine_server.core.config import infra_config
from model_engine_server.domain.gateways.file_storage_gateway import (
    FileMetadata,
    FileStorageGateway,
)
from model_engine_server.infra.gateways.gcs_storage_client import get_gcs_sync_client, parse_gcs_uri


def _get_gcs_key(owner: str, file_id: str) -> str:
    return os.path.join(owner, file_id)


def _get_gcs_url(owner: str, file_id: str) -> str:
    return f"gs://{infra_config().s3_bucket}/{_get_gcs_key(owner, file_id)}"


def _generate_signed_url_sync(uri: str, expiration: int = 3600) -> str:
    """Generate a V4 signed URL synchronously (gcloud-aio-storage does not support this)."""
    bucket_name, blob_name = parse_gcs_uri(uri)
    client = get_gcs_sync_client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    return blob.generate_signed_url(
        version="v4",
        expiration=timedelta(seconds=expiration),
        method="GET",
    )


class GCSFileStorageGateway(FileStorageGateway):
    """
    Concrete implementation of a file storage gateway backed by GCS,
    using gcloud-aio-storage for async-native operations.
    """

    def __init__(self) -> None:
        self._storage = Storage()

    async def close(self) -> None:
        await self._storage.close()

    async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
        uri = _get_gcs_url(owner, file_id)
        return await asyncio.to_thread(_generate_signed_url_sync, uri)

    async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
        bucket_name = infra_config().s3_bucket
        blob_name = _get_gcs_key(owner, file_id)
        try:
            metadata = await self._storage.download_metadata(bucket_name, blob_name)
            return FileMetadata(
                id=file_id,
                filename=file_id,
                size=int(metadata.get("size", 0)),
                owner=owner,
                updated_at=metadata.get("updated"),
            )
        except Exception:
            return None

    async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
        bucket_name = infra_config().s3_bucket
        blob_name = _get_gcs_key(owner, file_id)
        try:
            content = await self._storage.download(bucket_name, blob_name)
            return content.decode("utf-8")
        except Exception:
            return None

    async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
        bucket_name = infra_config().s3_bucket
        blob_name = _get_gcs_key(owner, filename)
        await self._storage.upload(bucket_name, blob_name, content)
        return filename

    async def delete_file(self, owner: str, file_id: str) -> bool:
        bucket_name = infra_config().s3_bucket
        blob_name = _get_gcs_key(owner, file_id)
        try:
            await self._storage.delete(bucket_name, blob_name)
            return True
        except Exception:
            return False

    async def list_files(self, owner: str) -> List[FileMetadata]:
        bucket_name = infra_config().s3_bucket
        files: List[FileMetadata] = []
        params = {"prefix": owner}
        while True:
            response = await self._storage.list_objects(bucket_name, params=params)
            for item in response.get("items", []):
                blob_name = item.get("name", "")
                file_id = blob_name.replace(f"{owner}/", "", 1)
                files.append(
                    FileMetadata(
                        id=file_id,
                        filename=file_id,
                        size=int(item.get("size", 0)),
                        owner=owner,
                        updated_at=item.get("updated"),
                    )
                )
            next_token = response.get("nextPageToken")
            if not next_token:
                break
            params = {"prefix": owner, "pageToken": next_token}
        return files
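A quick usage sketch of the gateway above (the owner and filename are made up; the bucket name comes from infra_config().s3_bucket, which this PR reuses for GCS):

import asyncio


async def demo() -> None:
    gateway = GCSFileStorageGateway()
    try:
        file_id = await gateway.upload_file(owner="acme", filename="notes.txt", content=b"hello gcs")
        print(await gateway.get_file_content("acme", file_id))  # -> "hello gcs"
        print(await gateway.get_url_from_id("acme", file_id))  # -> V4 signed URL
        print([f.id for f in await gateway.list_files("acme")])
        await gateway.delete_file("acme", file_id)
    finally:
        await gateway.close()  # releases the gcloud-aio-storage session


asyncio.run(demo())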
@@ -0,0 +1,50 @@
import asyncio
from datetime import timedelta
from typing import IO

import smart_open
from gcloud.aio.storage import Storage
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway
from model_engine_server.infra.gateways.gcs_storage_client import get_gcs_sync_client, parse_gcs_uri


class GCSFilesystemGateway(FilesystemGateway):
    """
    Concrete implementation for interacting with a filesystem backed by Google Cloud Storage.

    Provides both sync methods (required by FilesystemGateway ABC) and async-native
    counterparts using gcloud-aio-storage for use in async contexts.
    """

    def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
        client = get_gcs_sync_client()
        transport_params = {"client": client}
        return smart_open.open(uri, mode, transport_params=transport_params)

    def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
        bucket_name, blob_name = parse_gcs_uri(uri)
        client = get_gcs_sync_client()
        bucket = client.bucket(bucket_name)
        blob = bucket.blob(blob_name)
        return blob.generate_signed_url(
            version="v4",
            expiration=timedelta(seconds=expiration),
            method="GET",
            **kwargs,
        )

    async def async_read(self, uri: str) -> bytes:
        """Async-native download of blob content."""
        bucket_name, blob_name = parse_gcs_uri(uri)
        async with Storage() as storage:
            return await storage.download(bucket_name, blob_name)

    async def async_write(self, uri: str, content: bytes) -> None:
        """Async-native upload of blob content."""
        bucket_name, blob_name = parse_gcs_uri(uri)
        async with Storage() as storage:
            await storage.upload(bucket_name, blob_name, content)

    async def async_generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
        """Async wrapper for signed URL generation (offloaded to a thread)."""
        return await asyncio.to_thread(self.generate_signed_url, uri, expiration, **kwargs)
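Both new gateways import get_gcs_sync_client and parse_gcs_uri from gcs_storage_client, a module this diff does not include. For reference, a plausible minimal version of those helpers, assuming the module just caches one google-cloud-storage client and splits gs:// URIs (an assumption, not the PR's actual file):

from functools import lru_cache
from typing import Tuple

from google.cloud import storage


@lru_cache(maxsize=1)
def get_gcs_sync_client() -> storage.Client:
    # Assumed helper: one cached client per process; credentials resolve
    # from the environment (GOOGLE_APPLICATION_CREDENTIALS or workload identity).
    return storage.Client()


def parse_gcs_uri(uri: str) -> Tuple[str, str]:
    # Assumed helper: "gs://bucket/path/to/blob" -> ("bucket", "path/to/blob")
    if not uri.startswith("gs://"):
        raise ValueError(f"Not a GCS URI: {uri}")
    bucket_name, _, blob_name = uri[len("gs://"):].partition("/")
    return bucket_name, blob_name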