# In [None]:
from typing import NamedTuple
from kfp import dsl, compiler

# Define the prediction component
@dsl.component(
    base_image='tensorflow/tensorflow:latest',
    # Only packages the function body actually imports are installed at
    # container start-up (pandas and s3fs were never used by this component).
    packages_to_install=['boto3', 'numpy', 'tensorflow', 'Pillow']
)
def predict_image_from_minio(
    image_name: str,
    s3_bucket: str,
    s3_endpoint: str,
    access_key: str,
    secret_key: str
) -> int:
    """Predict the digit shown in an image stored in MinIO.

    Downloads the Keras model stored under ``sushant_model/detect-digits.h5``
    and the requested image from the same bucket, converts the image to a
    batch of one 28x28 grayscale array scaled by 1/255, and runs the model
    on it.

    Args:
        image_name: Object key of the image inside ``s3_bucket``.
        s3_bucket: Bucket that holds both the model file and the image.
        s3_endpoint: Endpoint URL passed to the boto3 S3 client.
        access_key: Access key id used to authenticate against the endpoint.
        secret_key: Secret access key used to authenticate against the
            endpoint.

    Returns:
        Index of the largest value in the model output, as a native int.
    """
    # All imports stay inside the function: KFP ships only this function's
    # source to the container, so module-level imports are not available.
    import os
    import tempfile

    import boto3
    import numpy as np
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing.image import img_to_array, load_img

    model_key = 'sushant_model/detect-digits.h5'

    def preprocess_image(image_path: str) -> np.ndarray:
        """Load an image file as a normalized (1, 28, 28, 1) float array."""
        img = load_img(image_path, color_mode="grayscale", target_size=(28, 28))
        img_array = img_to_array(img) / 255.0  # Scale pixel values to [0, 1]
        return np.expand_dims(img_array, axis=0)  # Add batch dimension

    s3 = boto3.client(
        's3',
        endpoint_url=s3_endpoint,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key
    )

    # A private scratch directory replaces the fixed /tmp paths: it cannot
    # clash with other files and is removed even if a later step raises.
    with tempfile.TemporaryDirectory() as work_dir:
        # download_file streams the object to disk instead of buffering the
        # whole body in memory first. The ".h5" suffix is kept so Keras can
        # recognise the file format.
        model_path = os.path.join(work_dir, 'mnist_model.h5')
        s3.download_file(s3_bucket, model_key, model_path)
        mnist_model = load_model(model_path)

        # The "image_" prefix keeps the local name non-empty when the key
        # ends with "/" and prevents the image from overwriting the model
        # file; the original extension is preserved at the end of the name.
        image_path = os.path.join(
            work_dir, 'image_' + os.path.basename(image_name)
        )
        s3.download_file(s3_bucket, image_name, image_path)
        preprocessed_image = preprocess_image(image_path)

        predictions = mnist_model.predict(preprocessed_image)

    # Convert the numpy integer to a native Python int for the KFP output.
    return int(np.argmax(predictions))

# Define the pipeline
@dsl.pipeline(
    name="mnist-prediction-pipeline",
    description="A pipeline to predict digits using MNIST model from MinIO",
)
def mnist_pipeline(
    image_name: str,
    s3_bucket: str,
    s3_endpoint: str,
    access_key: str,
    secret_key: str
):
    """Single-step pipeline that runs the digit-prediction component.

    Every pipeline parameter is forwarded unchanged, under the same name,
    to ``predict_image_from_minio``.
    """
    # Gather the pass-through arguments once, then create the single task.
    component_inputs = dict(
        image_name=image_name,
        s3_bucket=s3_bucket,
        s3_endpoint=s3_endpoint,
        access_key=access_key,
        secret_key=secret_key,
    )
    digit_task = predict_image_from_minio(**component_inputs)

    # Give the task a readable display name; this does not print the result.
    digit_task.set_display_name("Digit Prediction")

# Compile the pipeline function into a YAML definition. The relative path
# is resolved against the current working directory.
pipeline_file_path = "latest_prediction_pipeline.yaml"
pipeline_compiler = compiler.Compiler()
pipeline_compiler.compile(pipeline_func=mnist_pipeline, package_path=pipeline_file_path)

# Report where the compiled definition was written.
print(f"Pipeline definition saved to {pipeline_file_path}")
