Objective
=========

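The objective of this notebook is to package the trained ONNX segmentation model together with its demo code as an MLflow model, log it to the tracking server, and register it in the MLflow Model Registry, so that users can interact with the model from a web page.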

## Import Modules

In [1]:
# Standard Library Imports
import base64
import io
from pathlib import Path

# Third-Party Libraries
from PIL import Image
import numpy as np

# MLflow for Experiment Tracking and Model Management
import mlflow
from mlflow import MlflowClient
from mlflow.types.schema import Schema, ColSpec
from mlflow.types import ParamSchema, ParamSpec
from mlflow.models import ModelSignature
from mlflow.pyfunc.model import PythonModel

# Import Inference API
import onnxruntime as ort

# For user customizations
from userconfig import UserConfig, open_user_config

## Model Definition

In [2]:
# Colour palette for the predicted mask: class index -> RGB triple.
# Indices follow list order, so extend the list to add more classes.
CLASS_COLORS = dict(
    enumerate(
        [
            (0, 0, 0),    # 0: background (black)
            (255, 0, 0),  # 1: class 1 (red)
            (0, 255, 0),  # 2: class 2 (green)
        ]
    )
)


class ONNXSegmentationWrapper(PythonModel):
    """MLflow pyfunc wrapper around an ONNX semantic-segmentation model.

    Each request item is a base64-encoded image. Each response item is a
    base64-encoded PNG in which every pixel is coloured according to the predicted
    class via ``CLASS_COLORS`` (classes missing from the palette stay black).
    """

    def load_context(self, context):
        """Record the ONNX artifact path and build the inference session once.

        The session is cached on the instance so ``predict`` does not reload the
        model from disk on every request.
        """
        self.model = context.artifacts['onnx_model']
        self._get_session(context)

    def _get_session(self, context):
        """Return ``(session, input_name)``, creating the ONNX session on first use."""
        if getattr(self, 'session', None) is None:
            self.session = ort.InferenceSession(context.artifacts['onnx_model'])
            # The model is fed through its first declared input.
            self.input_name = self.session.get_inputs()[0].name
        return self.session, self.input_name

    @staticmethod
    def _as_b64_list(model_input) -> list:
        """Normalise the accepted input containers to a plain list of base64 strings."""
        if isinstance(model_input, (str, bytes)):
            # A bare string would otherwise be iterated character by character.
            return [model_input]
        if hasattr(model_input, 'iloc'):
            if getattr(model_input, 'ndim', 1) == 2:
                # DataFrame: prefer an 'image' column, else fall back to the first column.
                # Iterating the DataFrame itself would yield column names, not rows.
                if 'image' in model_input.columns:
                    column = model_input['image']
                else:
                    column = model_input.iloc[:, 0]
            else:
                column = model_input  # pandas Series
            return column.tolist()
        return list(model_input)

    def _decode_base64_image(self, b64_string):
        """Decode a base64 (optionally data-URL) string into an RGB ``uint8`` array (H, W, 3)."""
        if isinstance(b64_string, str) and b64_string.startswith('data:') and ',' in b64_string:
            # Browsers commonly send "data:image/png;base64,<payload>". b64decode
            # discards only the non-alphabet characters of that prefix and decodes the
            # rest, corrupting the image, so strip the prefix explicitly.
            b64_string = b64_string.split(',', 1)[1]
        image_data = base64.b64decode(b64_string)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return np.array(image)

    def _preprocess(self, np_image: np.ndarray) -> np.ndarray:
        """Convert an HWC ``uint8`` image into a ``[1, C, H, W]`` float32 tensor scaled to [0, 1]."""
        image = np_image.transpose(2, 0, 1).astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def _postprocess(self, output) -> np.ndarray:
        """Turn the ONNX outputs into an RGB colour mask of shape (H, W, 3).

        Args:
            output: the list returned by ``InferenceSession.run``. Its first element
                is taken as the logits, assumed to be ``[1, num_classes, H, W]``
                (``[num_classes, H, W]`` is also accepted). TODO confirm against the
                exported model.

        Raises:
            ValueError: if the logits do not reduce to a 3-D ``(num_classes, H, W)`` array.
        """
        logits = np.asarray(output[0])  # first model output
        if logits.ndim == 4:
            # Drop the batch axis (one image per run). Without this, argmax would
            # run over the batch axis instead of the class axis.
            logits = logits[0]
        if logits.ndim != 3:
            raise ValueError(
                f"Expected logits of shape [1, C, H, W] or [C, H, W], got {np.shape(output[0])}"
            )
        prediction: np.ndarray = np.argmax(logits, axis=0)  # shape: (H, W)

        # Map class indices to RGB; indices missing from the palette stay black.
        h, w = prediction.shape
        rgb_image = np.zeros((h, w, 3), dtype=np.uint8)
        for cls, color in CLASS_COLORS.items():
            rgb_image[prediction == cls] = color

        return rgb_image

    def _encode_base64_image(self, rgb_image: np.ndarray) -> str:
        """Encode an RGB ``uint8`` array as a base64 PNG string."""
        image = Image.fromarray(rgb_image)
        buffered = io.BytesIO()
        image.save(buffered, format='PNG')
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def predict(self, context, model_input: list[str], params=None) -> list[str]:
        """Segment each input image and return one colour mask per image.

        Args:
            context: MLflow context. Its ``artifacts['onnx_model']`` is used to build
                the ONNX session if ``load_context`` has not already done so.
            model_input: base64-encoded images as a list of strings. A single string,
                a pandas Series, or a DataFrame (column 'image', else its first
                column) is also accepted.
            params: unused.

        Returns:
            Base64-encoded PNG masks, in the same order as the inputs.
        """
        session, input_name = self._get_session(context)

        results: list[str] = []
        for b64_string in self._as_b64_list(model_input):
            np_image = self._decode_base64_image(b64_string)
            input_tensor = self._preprocess(np_image)
            outputs = session.run(None, {input_name: input_tensor})
            rgb_output = self._postprocess(outputs)
            results.append(self._encode_base64_image(rgb_output))
        return results

## Log and Register Model

In [3]:
config: UserConfig = open_user_config()

mlflow.set_tracking_uri(config.mlflow_tracking_uri)
mlflow.set_experiment(config.mlflow_experiment_name)

# Files bundled with the logged model. The wrapper reads 'onnx_model' when loading
# and predicting; 'demo' ships the web demo alongside the model.
artifacts: dict[str, str] = {
    'onnx_model': config.best_model_path,
    'demo': config.demo_dir,
}

with mlflow.start_run() as run:
    # Log under config.model_name so the artifact path and the registry name
    # cannot drift apart (previously 'neuron_unet' was hard-coded here).
    model_info = mlflow.pyfunc.log_model(
        artifact_path=config.model_name,
        python_model=ONNXSegmentationWrapper(),
        artifacts=artifacts,
    )
    # Register exactly the URI that was just logged rather than rebuilding it by hand.
    mlflow.register_model(model_uri=model_info.model_uri, name=config.model_name)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/3 [00:00<?, ?it/s]

Successfully registered model 'neuron_unet'.
Created version '1' of model 'neuron_unet'.
