-
Notifications
You must be signed in to change notification settings - Fork 30
Cu-86d2e1ka4: simplify inference module for image embedding generation only #262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1 @@ | ||
| [aws] | ||
| model_storage_bucket=${MODEL_STORAGE_BUCKET} | ||
|
|
||
| [gcp] | ||
| model_storage_bucket=${MODEL_STORAGE_BUCKET} | ||
|
|
||
| [cloud_config] | ||
| cloud_provider=${CLOUD_PROVIDER} |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,169 +1,77 @@ | ||||||||||||||||||||||||||||||||||
| import base64 | ||||||||||||||||||||||||||||||||||
| from typing import Any, Dict | ||||||||||||||||||||||||||||||||||
| import binascii | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from common_module.common_container import CommonContainer | ||||||||||||||||||||||||||||||||||
| from common_module.log.logger import logger | ||||||||||||||||||||||||||||||||||
| from common_module.response_formatter import ResponseFormatter | ||||||||||||||||||||||||||||||||||
| from dependency_injector.wiring import Provide, inject | ||||||||||||||||||||||||||||||||||
| from fastapi import APIRouter, Depends, status | ||||||||||||||||||||||||||||||||||
| from fastapi.responses import JSONResponse | ||||||||||||||||||||||||||||||||||
| from inference_app.inference_app_container import InferenceAppContainer | ||||||||||||||||||||||||||||||||||
| from inference_app.service.image_analyser import ImageClarityService | ||||||||||||||||||||||||||||||||||
| from inference_app.service.model_inference import ( | ||||||||||||||||||||||||||||||||||
| ModelInferenceService, | ||||||||||||||||||||||||||||||||||
| PreprocessingStep, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| from inference_app.service.model_repository import ModelRepository | ||||||||||||||||||||||||||||||||||
| from inference_app.service.image_embedding import ImageEmbedding | ||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel, Field | ||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class InferencePayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| data: str | ||||||||||||||||||||||||||||||||||
| payload_type: str | ||||||||||||||||||||||||||||||||||
| model_info: dict | ||||||||||||||||||||||||||||||||||
| preprocessing_steps: list[PreprocessingStep] | ||||||||||||||||||||||||||||||||||
| max_expected_variance: int = Field(default=1000) | ||||||||||||||||||||||||||||||||||
| resize_width: int = Field(default=224) | ||||||||||||||||||||||||||||||||||
| resize_height: int = Field(default=224) | ||||||||||||||||||||||||||||||||||
| gaussian_blur_kernel: int = Field(default=3) | ||||||||||||||||||||||||||||||||||
| min_threshold: int = Field(default=50) | ||||||||||||||||||||||||||||||||||
| max_threshold: int = Field(default=150) | ||||||||||||||||||||||||||||||||||
| normalize_mean: str = Field(default='0.485,0.456,0.406') | ||||||||||||||||||||||||||||||||||
| normalize_std: str = Field(default='0.229,0.224,0.225') | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class InferenceResult(BaseModel): | ||||||||||||||||||||||||||||||||||
| results: Dict[str, Any] = Field(..., description='Dictionary of inference results') | ||||||||||||||||||||||||||||||||||
| class ImagePayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| image_data: str # base64 encoded image data | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class ImagePayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| image_data: str | ||||||||||||||||||||||||||||||||||
| class ImageBatchPayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| image_batch: list[str] # list of base64 encoded image data | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| inference_app_router = APIRouter() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/model-repository/model/{model_id}/infer') | ||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/query/embeddings') | ||||||||||||||||||||||||||||||||||
| @inject | ||||||||||||||||||||||||||||||||||
| async def generic_inference_handler( | ||||||||||||||||||||||||||||||||||
| payload: InferencePayload, | ||||||||||||||||||||||||||||||||||
| async def image_embedding( | ||||||||||||||||||||||||||||||||||
| payload: ImagePayload, | ||||||||||||||||||||||||||||||||||
| response_formatter: ResponseFormatter = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[CommonContainer.response_formatter] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| model_repository: ModelRepository = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.model_repository] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| image_analyser: ImageClarityService = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.image_analyser] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| config: dict = Depends(Provide[InferenceAppContainer.config]), | ||||||||||||||||||||||||||||||||||
| model_inference: ModelInferenceService = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.model_inference] | ||||||||||||||||||||||||||||||||||
| image_embedding_service: ImageEmbedding = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.image_embedding] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| provider = config['cloud_config']['cloud_provider'] | ||||||||||||||||||||||||||||||||||
| model_storage_bucket = ( | ||||||||||||||||||||||||||||||||||
| config['gcp']['model_storage_bucket'] | ||||||||||||||||||||||||||||||||||
| if provider.lower() == 'gcp' | ||||||||||||||||||||||||||||||||||
| else config['aws']['model_storage_bucket'] | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||
| f'Loading model from bucket: {model_storage_bucket}, model_info: {payload.model_info}' | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| model = await model_repository.load_model( | ||||||||||||||||||||||||||||||||||
| model_info=payload.model_info, bucket_name=model_storage_bucket | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| logger.debug('Model loaded successfully for model_id') | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if payload.payload_type.lower() == 'image': | ||||||||||||||||||||||||||||||||||
| base64_data_uri = payload.data | ||||||||||||||||||||||||||||||||||
| parts = base64_data_uri.split(',') | ||||||||||||||||||||||||||||||||||
| if len(parts) == 2: | ||||||||||||||||||||||||||||||||||
| base64_data = parts[1] | ||||||||||||||||||||||||||||||||||
| image_bytes = base64.b64decode(base64_data) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| clarity_score = image_analyser.laplacian_detection( | ||||||||||||||||||||||||||||||||||
| image_bytes, payload.max_expected_variance | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| infer_data = model_inference.model_infer_score( | ||||||||||||||||||||||||||||||||||
| model, | ||||||||||||||||||||||||||||||||||
| image_bytes, | ||||||||||||||||||||||||||||||||||
| payload.resize_width, | ||||||||||||||||||||||||||||||||||
| payload.resize_height, | ||||||||||||||||||||||||||||||||||
| payload.normalize_mean, | ||||||||||||||||||||||||||||||||||
| payload.normalize_std, | ||||||||||||||||||||||||||||||||||
| payload.gaussian_blur_kernel, | ||||||||||||||||||||||||||||||||||
| payload.min_threshold, | ||||||||||||||||||||||||||||||||||
| payload.max_threshold, | ||||||||||||||||||||||||||||||||||
| preprocessing_steps=payload.preprocessing_steps, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| logger.debug('Model inference completed successfully for model_id') | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| inference_results = InferenceResult( | ||||||||||||||||||||||||||||||||||
| results={ | ||||||||||||||||||||||||||||||||||
| 'clarity_score': clarity_score, | ||||||||||||||||||||||||||||||||||
| 'infer_data': infer_data, | ||||||||||||||||||||||||||||||||||
| 'data_type': payload.payload_type.lower(), | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logger.info('Inference request completed successfully for model_id') | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_201_CREATED, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildSuccessResponse( | ||||||||||||||||||||||||||||||||||
| inference_results.dict() | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| error_msg = ( | ||||||||||||||||||||||||||||||||||
| "Input data is not in expected Data URI format (missing 'base64,')." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||
| f"Expected Data URI format with 'base64,' prefix. " | ||||||||||||||||||||||||||||||||||
| f'Data length: {len(base64_data_uri)}' | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse(error_msg), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| error_msg = f"Invalid payload_type: {payload.payload_type}. Accepted values are 'image'" | ||||||||||||||||||||||||||||||||||
| logger.error(f'{error_msg}') | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse( | ||||||||||||||||||||||||||||||||||
| 'Invalid payload_type. Accepted values are "image"' | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||
| logger.error(f'Error in generic_inference_handler {str(e)}') | ||||||||||||||||||||||||||||||||||
| # 1. Decode Base64 string | ||||||||||||||||||||||||||||||||||
| image_data = extract_decoded_image_data(payload.image_data) | ||||||||||||||||||||||||||||||||||
| embeddings = image_embedding_service.query_embed(image_data) | ||||||||||||||||||||||||||||||||||
| if not embeddings: | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse('Internal server error'), | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse( | ||||||||||||||||||||||||||||||||||
| 'No Embedding data is present' | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_200_OK, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildSuccessResponse(data={'response': embeddings}), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/query/embeddings') | ||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/query/embeddings/batch') | ||||||||||||||||||||||||||||||||||
| @inject | ||||||||||||||||||||||||||||||||||
| async def image_embedding( | ||||||||||||||||||||||||||||||||||
| payload: ImagePayload, | ||||||||||||||||||||||||||||||||||
| async def image_embedding_batch( | ||||||||||||||||||||||||||||||||||
| payload: ImageBatchPayload, | ||||||||||||||||||||||||||||||||||
| response_formatter: ResponseFormatter = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[CommonContainer.response_formatter] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| image_embedding_service: ImageEmbedding = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.image_embedding] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| # 1. Decode Base64 string | ||||||||||||||||||||||||||||||||||
| base64_data_uri = payload.image_data | ||||||||||||||||||||||||||||||||||
| parts = base64_data_uri.split(',') | ||||||||||||||||||||||||||||||||||
| base64_data = parts[1] if len(parts) == 2 else parts[0] | ||||||||||||||||||||||||||||||||||
| image_data = base64.b64decode(base64_data) | ||||||||||||||||||||||||||||||||||
| embeddings = image_embedding_service.query_embed(image_data) | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| image_batch = [ | ||||||||||||||||||||||||||||||||||
| extract_decoded_image_data(image_data) for image_data in payload.image_batch | ||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||
| except binascii.Error: | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse( | ||||||||||||||||||||||||||||||||||
| 'Invalid base64 image data in batch' | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| embeddings = image_embedding_service.query_embed_batch(image_batch) | ||||||||||||||||||||||||||||||||||
vizsatiz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||
| if not embeddings: | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
|
|
@@ -175,3 +83,9 @@ async def image_embedding( | |||||||||||||||||||||||||||||||||
| status_code=status.HTTP_200_OK, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildSuccessResponse(data={'response': embeddings}), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def extract_decoded_image_data(image_data: str) -> bytes: | ||||||||||||||||||||||||||||||||||
| parts = image_data.split(',') | ||||||||||||||||||||||||||||||||||
| base64_data = parts[1] if len(parts) == 2 else parts[0] | ||||||||||||||||||||||||||||||||||
| return base64.b64decode(base64_data) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+88
to
+91
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
💡 Option: Add validation with a clearer error+import binascii
+
+class InvalidImageDataError(Exception):
+ pass
+
def extract_decoded_image_data(image_data: str) -> bytes:
parts = image_data.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
- return base64.b64decode(base64_data)
+ try:
+ return base64.b64decode(base64_data)
+ except binascii.Error as e:
+ raise InvalidImageDataError('Invalid base64 encoded image data') from e📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,31 +1,9 @@ | ||
| from dependency_injector import containers | ||
| from dependency_injector import providers | ||
| from inference_app.service.image_analyser import ImageClarityService | ||
| from flo_cloud.cloud_storage import CloudStorageManager | ||
| from inference_app.service.model_repository import ModelRepository | ||
| from inference_app.service.model_inference import ModelInferenceService | ||
| from inference_app.service.image_embedding import ImageEmbedding | ||
|
|
||
|
|
||
| class InferenceAppContainer(containers.DeclarativeContainer): | ||
| config = providers.Configuration(ini_files=['config.ini']) | ||
| cache_manager = providers.Dependency() | ||
|
|
||
| cloud_storage_manager = providers.Singleton( | ||
| CloudStorageManager, provider=config.cloud_config.cloud_provider | ||
| ) | ||
|
|
||
| model_repository = providers.Singleton( | ||
| ModelRepository, | ||
| cloud_storage_manager=cloud_storage_manager, | ||
| ) | ||
|
|
||
| model_inference = providers.Singleton(ModelInferenceService) | ||
|
|
||
| image_analyser = providers.Singleton( | ||
| ImageClarityService, | ||
| ) | ||
|
|
||
| image_embedding = providers.Singleton( | ||
| ImageEmbedding, | ||
| ) | ||
| image_embedding = providers.Singleton(ImageEmbedding) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Single endpoint missing
binascii.Errorhandling.The batch endpoint (lines 63-73) wraps the decode in a try/except for
binascii.Error, but the single endpoint callsextract_decoded_image_datawithout similar protection. Malformed base64 input will propagate to the global exception handler and return HTTP 500 instead of a user-friendly 400.🛡️ Proposed fix for consistent error handling
`@inference_app_router.post`('/v1/query/embeddings') `@inject` async def image_embedding( payload: ImagePayload, response_formatter: ResponseFormatter = Depends( Provide[CommonContainer.response_formatter] ), image_embedding_service: ImageEmbedding = Depends( Provide[InferenceAppContainer.image_embedding] ), ): - # 1. Decode Base64 string - image_data = extract_decoded_image_data(payload.image_data) + try: + image_data = extract_decoded_image_data(payload.image_data) + except binascii.Error: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=response_formatter.buildErrorResponse( + 'Invalid base64 image data' + ), + ) embeddings = image_embedding_service.query_embed(image_data)🤖 Prompt for AI Agents