2 changes: 1 addition & 1 deletion wavefront/server/apps/floconsole/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"sqlalchemy>=2.0.40,<3.0.0",
"python-dotenv>=1.1.0,<2.0.0",
"dependency-injector>=4.46.0,<5.0.0",
"psycopg2>=2.9.10,<3.0.0",
"psycopg2-binary>=2.9.10,<3.0.0",
"python-jose[cryptography]>=3.3.0,<4.0.0",
"async-lru>=2.0.5",
]
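Note on the dependency swap: psycopg2-binary ships precompiled wheels but installs under the same psycopg2 import name, so no application code changes are needed. A minimal sanity check, assuming a local Postgres and a placeholder DSN:

import psycopg2  # psycopg2-binary is imported as plain psycopg2

conn = psycopg2.connect('postgresql://user:secret@localhost:5432/floconsole')  # hypothetical DSN
with conn.cursor() as cur:
    cur.execute('SELECT version()')
    print(cur.fetchone())
conn.close()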
7 changes: 0 additions & 7 deletions wavefront/server/apps/inference_app/inference_app/config.ini
@@ -1,8 +1 @@
[aws]
model_storage_bucket=${MODEL_STORAGE_BUCKET}

[gcp]
model_storage_bucket=${MODEL_STORAGE_BUCKET}

[cloud_config]
cloud_provider=${CLOUD_PROVIDER}
wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py
@@ -1,169 +1,77 @@
import base64
from typing import Any, Dict
import binascii

from common_module.common_container import CommonContainer
from common_module.log.logger import logger
from common_module.response_formatter import ResponseFormatter
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from inference_app.inference_app_container import InferenceAppContainer
from inference_app.service.image_analyser import ImageClarityService
from inference_app.service.model_inference import (
ModelInferenceService,
PreprocessingStep,
)
from inference_app.service.model_repository import ModelRepository
from inference_app.service.image_embedding import ImageEmbedding
from pydantic import BaseModel, Field
from pydantic import BaseModel


class InferencePayload(BaseModel):
data: str
payload_type: str
model_info: dict
preprocessing_steps: list[PreprocessingStep]
max_expected_variance: int = Field(default=1000)
resize_width: int = Field(default=224)
resize_height: int = Field(default=224)
gaussian_blur_kernel: int = Field(default=3)
min_threshold: int = Field(default=50)
max_threshold: int = Field(default=150)
normalize_mean: str = Field(default='0.485,0.456,0.406')
normalize_std: str = Field(default='0.229,0.224,0.225')


class InferenceResult(BaseModel):
results: Dict[str, Any] = Field(..., description='Dictionary of inference results')
class ImagePayload(BaseModel):
image_data: str # base64 encoded image data


class ImagePayload(BaseModel):
image_data: str
class ImageBatchPayload(BaseModel):
image_batch: list[str] # list of base64 encoded image data


inference_app_router = APIRouter()


@inference_app_router.post('/v1/model-repository/model/{model_id}/infer')
@inference_app_router.post('/v1/query/embeddings')
@inject
async def generic_inference_handler(
payload: InferencePayload,
async def image_embedding(
payload: ImagePayload,
response_formatter: ResponseFormatter = Depends(
Provide[CommonContainer.response_formatter]
),
model_repository: ModelRepository = Depends(
Provide[InferenceAppContainer.model_repository]
),
image_analyser: ImageClarityService = Depends(
Provide[InferenceAppContainer.image_analyser]
),
config: dict = Depends(Provide[InferenceAppContainer.config]),
model_inference: ModelInferenceService = Depends(
Provide[InferenceAppContainer.model_inference]
image_embedding_service: ImageEmbedding = Depends(
Provide[InferenceAppContainer.image_embedding]
),
):
try:
provider = config['cloud_config']['cloud_provider']
model_storage_bucket = (
config['gcp']['model_storage_bucket']
if provider.lower() == 'gcp'
else config['aws']['model_storage_bucket']
)

logger.info(
f'Loading model from bucket: {model_storage_bucket}, model_info: {payload.model_info}'
)
model = await model_repository.load_model(
model_info=payload.model_info, bucket_name=model_storage_bucket
)
logger.debug('Model loaded successfully for model_id')

if payload.payload_type.lower() == 'image':
base64_data_uri = payload.data
parts = base64_data_uri.split(',')
if len(parts) == 2:
base64_data = parts[1]
image_bytes = base64.b64decode(base64_data)

clarity_score = image_analyser.laplacian_detection(
image_bytes, payload.max_expected_variance
)

infer_data = model_inference.model_infer_score(
model,
image_bytes,
payload.resize_width,
payload.resize_height,
payload.normalize_mean,
payload.normalize_std,
payload.gaussian_blur_kernel,
payload.min_threshold,
payload.max_threshold,
preprocessing_steps=payload.preprocessing_steps,
)
logger.debug('Model inference completed successfully for model_id')

inference_results = InferenceResult(
results={
'clarity_score': clarity_score,
'infer_data': infer_data,
'data_type': payload.payload_type.lower(),
}
)

logger.info('Inference request completed successfully for model_id')
return JSONResponse(
status_code=status.HTTP_201_CREATED,
content=response_formatter.buildSuccessResponse(
inference_results.dict()
),
)
else:
error_msg = (
"Input data is not in expected Data URI format (missing 'base64,')."
)
logger.error(
f"Expected Data URI format with 'base64,' prefix. "
f'Data length: {len(base64_data_uri)}'
)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(error_msg),
)
else:
error_msg = f"Invalid payload_type: {payload.payload_type}. Accepted values are 'image'"
logger.error(f'{error_msg}')
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'Invalid payload_type. Accepted values are "image"'
),
)
except Exception as e:
logger.error(f'Error in generic_inference_handler {str(e)}')
# 1. Decode Base64 string
image_data = extract_decoded_image_data(payload.image_data)
embeddings = image_embedding_service.query_embed(image_data)
Comment on lines +36 to +38
⚠️ Potential issue | 🟠 Major

Single endpoint missing binascii.Error handling.

The batch endpoint (lines 63-73) wraps the decode in a try/except for binascii.Error, but the single endpoint calls extract_decoded_image_data without similar protection. Malformed base64 input will propagate to the global exception handler and return HTTP 500 instead of a user-friendly 400.

🛡️ Proposed fix for consistent error handling
 @inference_app_router.post('/v1/query/embeddings')
 @inject
 async def image_embedding(
     payload: ImagePayload,
     response_formatter: ResponseFormatter = Depends(
         Provide[CommonContainer.response_formatter]
     ),
     image_embedding_service: ImageEmbedding = Depends(
         Provide[InferenceAppContainer.image_embedding]
     ),
 ):
-    # 1. Decode Base64 string
-    image_data = extract_decoded_image_data(payload.image_data)
+    try:
+        image_data = extract_decoded_image_data(payload.image_data)
+    except binascii.Error:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content=response_formatter.buildErrorResponse(
+                'Invalid base64 image data'
+            ),
+        )
     embeddings = image_embedding_service.query_embed(image_data)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py
around lines 36 - 38, the single-image endpoint calls
extract_decoded_image_data(payload.image_data) and then
image_embedding_service.query_embed(image_data) without handling malformed
base64; wrap the call to extract_decoded_image_data in a try/except that catches
binascii.Error (same as the batch endpoint) and convert it into a user-facing
400 response (e.g., raise HTTPException(status_code=400, detail="Malformed
base64 image data") or return the same error object used by the batch path) so
malformed inputs return 400 instead of propagating to the global 500 handler.

if not embeddings:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=response_formatter.buildErrorResponse('Internal server error'),
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'No Embedding data is present'
),
)
return JSONResponse(
status_code=status.HTTP_200_OK,
content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
)
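For reference, a minimal client call against the new single-image route; the base URL and image file are assumptions for illustration, not part of this change:

import base64
import requests

# Hypothetical call to POST /v1/query/embeddings
with open('sample.jpg', 'rb') as f:
    payload = {'image_data': base64.b64encode(f.read()).decode()}

resp = requests.post('http://localhost:8000/v1/query/embeddings', json=payload)
print(resp.status_code, resp.json())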


@inference_app_router.post('/v1/query/embeddings')
@inference_app_router.post('/v1/query/embeddings/batch')
@inject
async def image_embedding(
payload: ImagePayload,
async def image_embedding_batch(
payload: ImageBatchPayload,
response_formatter: ResponseFormatter = Depends(
Provide[CommonContainer.response_formatter]
),
image_embedding_service: ImageEmbedding = Depends(
Provide[InferenceAppContainer.image_embedding]
),
):
# 1. Decode Base64 string
base64_data_uri = payload.image_data
parts = base64_data_uri.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
image_data = base64.b64decode(base64_data)
embeddings = image_embedding_service.query_embed(image_data)
try:
image_batch = [
extract_decoded_image_data(image_data) for image_data in payload.image_batch
]
except binascii.Error:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'Invalid base64 image data in batch'
),
)
embeddings = image_embedding_service.query_embed_batch(image_batch)
if not embeddings:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -175,3 +83,9 @@ async def image_embedding(
status_code=status.HTTP_200_OK,
content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
)
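And the analogous sketch for the batch route, again with an assumed base URL:

import base64
import requests

def encode_image(path: str) -> str:
    # Read a file and return its base64 encoding as a str
    with open(path, 'rb') as f:
        return base64.b64encode(f.read()).decode()

# Hypothetical call to POST /v1/query/embeddings/batch
payload = {'image_batch': [encode_image('a.jpg'), encode_image('b.jpg')]}
resp = requests.post('http://localhost:8000/v1/query/embeddings/batch', json=payload)
print(resp.status_code, resp.json())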


def extract_decoded_image_data(image_data: str) -> bytes:
parts = image_data.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
return base64.b64decode(base64_data)
Comment on lines +88 to +91
⚠️ Potential issue | 🟡 Minor

extract_decoded_image_data can raise unhandled binascii.Error on malformed base64.

base64.b64decode raises binascii.Error for invalid base64 input. This function is called directly without try/except in both endpoints. For the single-image endpoint (line 36), this would cause an HTTP 500. Consider adding error handling here or at call sites.

💡 Option: Add validation with a clearer error
+import binascii
+
+class InvalidImageDataError(Exception):
+    pass
+
 def extract_decoded_image_data(image_data: str) -> bytes:
     parts = image_data.split(',')
     base64_data = parts[1] if len(parts) == 2 else parts[0]
-    return base64.b64decode(base64_data)
+    try:
+        return base64.b64decode(base64_data)
+    except binascii.Error as e:
+        raise InvalidImageDataError('Invalid base64 encoded image data') from e
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py
around lines 72 - 75, the helper extract_decoded_image_data currently calls
base64.b64decode which can raise binascii.Error on malformed input; wrap the
decode call in a try/except that catches binascii.Error (import binascii) and
re-raise a clear ValueError or custom exception (e.g., "Invalid base64 image
data") so caller endpoints can return a 4xx response instead of an unhandled
500; update any call sites (the single-image and multi-image endpoints) to catch
that ValueError and convert it to an appropriate HTTP error response.
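
To make the helper's contract concrete, a small behavior sketch: it accepts either a bare base64 string or a full data URI and returns the decoded bytes.

# Both forms decode to b'hello'; malformed padding raises binascii.Error,
# which is why the review asks callers to guard the decode.
assert extract_decoded_image_data('aGVsbG8=') == b'hello'
assert extract_decoded_image_data('data:image/png;base64,aGVsbG8=') == b'hello'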

wavefront/server/apps/inference_app/inference_app/inference_app_container.py
@@ -1,31 +1,9 @@
from dependency_injector import containers
from dependency_injector import providers
from inference_app.service.image_analyser import ImageClarityService
from flo_cloud.cloud_storage import CloudStorageManager
from inference_app.service.model_repository import ModelRepository
from inference_app.service.model_inference import ModelInferenceService
from inference_app.service.image_embedding import ImageEmbedding


class InferenceAppContainer(containers.DeclarativeContainer):
config = providers.Configuration(ini_files=['config.ini'])
cache_manager = providers.Dependency()

cloud_storage_manager = providers.Singleton(
CloudStorageManager, provider=config.cloud_config.cloud_provider
)

model_repository = providers.Singleton(
ModelRepository,
cloud_storage_manager=cloud_storage_manager,
)

model_inference = providers.Singleton(ModelInferenceService)

image_analyser = providers.Singleton(
ImageClarityService,
)

image_embedding = providers.Singleton(
ImageEmbedding,
)
image_embedding = providers.Singleton(ImageEmbedding)
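Worth noting on the wiring: providers.Singleton memoizes the first instantiation, which is what lets the server's lifespan hook preload the embedding model once and share it across requests. A quick sketch of that guarantee:

# dependency-injector returns the same ImageEmbedding instance on every call
container = InferenceAppContainer()
assert container.image_embedding() is container.image_embedding()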
15 changes: 12 additions & 3 deletions wavefront/server/apps/inference_app/inference_app/server.py
@@ -1,5 +1,6 @@
import glob
import os
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI
@@ -22,14 +23,22 @@

# Initialize dependency containers
common_container = CommonContainer(cache_manager=None)
inference_app_container = InferenceAppContainer(
cache_manager=None,
)
inference_app_container = InferenceAppContainer()


@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info('Preloading ML models...')
inference_app_container.image_embedding()
logger.info('ML models loaded and ready.')
yield


app = FastAPI(
title='FloConsole API',
description='Console application for RootFlo platform',
version='1.0.0',
lifespan=lifespan,
)
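
As a quick way to confirm the preload runs before traffic is served: FastAPI's TestClient enters the lifespan on context entry, so the model-loading log lines should appear before the first request. The request body below is a placeholder:

from fastapi.testclient import TestClient

with TestClient(app) as client:  # lifespan (and the preload) runs here
    resp = client.post('/v1/query/embeddings', json={'image_data': 'aGVsbG8='})
    print(resp.status_code)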

