# MLflow XGBoost Iris Model Registration and Prediction Logging

This notebook demonstrates how to train an XGBoost classifier on the Iris dataset, wrap it as an MLflow Python model, and register it in the MLflow Model Registry. The wrapper also captures each prediction with metadata such as an event ID and a timestamp, using Domino's prediction data capture client.

In this notebook, you will:
* Train an **XGBoost** classifier using the Iris dataset.
* Wrap the trained model in a custom **MLflow Pyfunc** model so that each prediction can be captured along with its metadata.
* Register the wrapped model in the **MLflow Model Registry** under a name built from the current Domino username and project name.
* Capture prediction events with metadata like **event IDs** and **timestamps** for traceability.

## Steps Covered:
1. Load the **Iris dataset** and split it into training and test sets.
2. Train an **XGBoost Classifier** with predefined hyperparameters.
3. Wrap the model in a custom **MLflow Pyfunc** class (`mlflow.pyfunc.PythonModel`) that adds prediction capture logic.
4. Log the model to MLflow and register it under a name built from the current user and project.
5. Record each prediction event along with associated metadata (event ID, timestamp, input values).

## Dataset Overview:
The **Iris dataset** consists of 150 samples from three species of iris flowers (**Setosa**, **Versicolor**, and **Virginica**). Each sample has four features:
1. **Sepal length** (in centimeters)
2. **Sepal width** (in centimeters)
3. **Petal length** (in centimeters)
4. **Petal width** (in centimeters)

The model classifies each flower as one of the three species. The predicted class indices correspond to:
- **0**: Setosa
- **1**: Versicolor
- **2**: Virginica
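The model returns these integer indices, and turning them back into species names takes a single array lookup. The sketch below hard-codes the names in the order used by `load_iris().target_names`, which is an assumption you can check with `print(data.target_names)`. The helper name `decode_species` is hypothetical.

```python
import numpy as np

# Assumed to match load_iris().target_names (indices 0, 1, 2)
SPECIES = np.array(["setosa", "versicolor", "virginica"])


def decode_species(class_indices):
    """Map integer class predictions (0, 1, 2) to species names (hypothetical helper)."""
    # Fancy indexing returns one name per predicted index; tolist() yields plain Python str
    return SPECIES[np.asarray(class_indices, dtype=int)].tolist()


print(decode_species([0, 2, 1]))  # ['setosa', 'virginica', 'versicolor']
```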

### Example Test Data:
After the model is deployed as a Domino model API, you can test it with the following JSON payload. Each inner list holds the four feature values in the order listed above:

```json
{
  "data": [
    [5.1, 3.5, 1.4, 0.2], 
    [4.9, 3.0, 1.4, 0.2]
  ]
}
```
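Below is a minimal sketch of sending this payload with only the Python standard library. The endpoint URL and token are placeholders. The basic-auth scheme, with the access token used as both user name and password, is an assumption: check it against the calling instructions shown for your model API in Domino. `build_prediction_request` is a hypothetical helper.

```python
import base64
import json
import urllib.request


def build_prediction_request(url, token, rows):
    """Build (but do not send) a POST request carrying the {"data": [...]} payload."""
    body = json.dumps({"data": rows}).encode("utf-8")
    # Assumption: basic auth with the access token as both user name and password
    credentials = base64.b64encode(f"{token}:{token}".encode("utf-8")).decode("ascii")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Basic {credentials}",
        },
        method="POST",
    )


# Hypothetical endpoint and token; replace them with the values from your model API page
req = build_prediction_request(
    "https://example.com/models/MODEL_ID/latest/model",
    "MY_TOKEN",
    [[5.1, 3.5, 1.4, 0.2], [4.9, 3.0, 1.4, 0.2]],
)
# To send it: response = urllib.request.urlopen(req); result = json.load(response)
```

Building the request separately from `urllib.request.urlopen` lets you inspect the body and headers before anything is sent.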

## Reference Project:
* https://github.com/dominodatalab/reference-project-domino-mlflow-supported-models/blob/main/domino-mlflow-model-xgboost-imm.ipynb

In [None]:
# Import libraries
import os
import mlflow
from mlflow.exceptions import MlflowException

In [None]:
# Load Data
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
# A fixed random_state makes the train/test split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    data["data"], data["target"], test_size=0.2, random_state=123
)


In [None]:
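# Create an XGBoost model and wrap it as an MLflow Python model
import datetime
import uuid

import numpy as np
from xgboost import XGBClassifier
from domino_data_capture.data_capture_client import DataCaptureClient

xgb_classifier = XGBClassifier(
    n_estimators=10,
    max_depth=3,
    learning_rate=1,
    objective="multi:softprob",  # Iris has three classes, so use a multiclass objective
    random_state=123,
)

# Train the model
xgb_classifier.fit(X_train, y_train)


class IrisModel(mlflow.pyfunc.PythonModel):
    def __init__(self, model, feature_names, class_names):
        self.model = model
        # Store names as plain lists so the wrapper does not depend on notebook globals
        self.feature_names = list(feature_names)
        self.class_names = list(class_names)

    def predict(self, context, model_input, params=None):
        # Accept a list of rows, a NumPy array, or a pandas DataFrame; index rows positionally
        model_input = np.asarray(model_input, dtype=float)
        event_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
        prediction = self.model.predict(model_input)

        # Created per call rather than in __init__ so the client is not pickled with the model;
        # the prediction column name "species" is an illustrative choice
        data_capture_client = DataCaptureClient(self.feature_names, ["species"])

        for i in range(len(prediction)):
            # Record a unique event ID for each prediction (str() makes it JSON-serializable)
            event_id = str(uuid.uuid4())
            # Convert NumPy types to Python built-in types so the capture library can serialize them
            model_input_value = [float(x) for x in model_input[i]]
            prediction_value = [self.class_names[int(prediction[i])]]
            data_capture_client.capturePrediction(
                model_input_value, prediction_value, event_id=event_id, timestamp=event_time
            )

        return prediction


model = IrisModel(xgb_classifier, data.feature_names, data.target_names)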
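`DataCaptureClient` is only available inside Domino. Outside it, you can still check the per-row conversion that `IrisModel.predict` performs before handing values to the capture library. The sketch below builds one record per prediction and confirms that the result survives `json.dumps`. The helper `build_capture_records` and its field names (`event_id`, `timestamp`, `features`, `prediction`) are hypothetical and not necessarily the schema `DataCaptureClient` uses.

```python
import datetime
import json
import uuid

import numpy as np


def build_capture_records(model_input, predictions, class_names):
    """Return one JSON-serializable dict per prediction (hypothetical schema)."""
    rows = np.asarray(model_input, dtype=float)
    # One timestamp per batch, as in IrisModel.predict
    event_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
    records = []
    for row, pred in zip(rows, predictions):
        records.append({
            "event_id": str(uuid.uuid4()),               # UUID objects are not JSON-serializable
            "timestamp": event_time,
            "features": [float(x) for x in row],          # NumPy floats -> built-in float
            "prediction": str(class_names[int(pred)]),    # NumPy integer index -> species name
        })
    return records


records = build_capture_records(
    np.array([[5.1, 3.5, 1.4, 0.2], [6.3, 3.3, 6.0, 2.5]]),
    np.array([0, 2]),
    ["setosa", "versicolor", "virginica"],
)
print(json.dumps(records, indent=2))
```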

In [None]:
# Get the username and project name for the dynamic model name
username = os.environ['DOMINO_STARTING_USERNAME']
project_name = os.environ['DOMINO_PROJECT_NAME']
model_name = f"pyfunc-xgboost-model-{username}-{project_name}"

# Initialize the MLflow client
client = mlflow.tracking.MlflowClient()

# Check whether the model is already registered. This only changes the message printed:
# log_model(registered_model_name=...) creates the registered model if it does not exist
# and otherwise adds a new version to it.
try:
    client.get_registered_model(model_name)
    print(f"Model '{model_name}' already exists. Creating a new version.")
except MlflowException as e:
    if "RESOURCE_DOES_NOT_EXIST" in str(e):
        print(f"Model '{model_name}' does not exist. Creating a new model.")
    else:
        # Any other error (for example, authentication or connectivity) stops the cell
        # instead of failing later because model_info was never assigned
        raise

# Log the wrapped model and register it under model_name
with mlflow.start_run() as run:
    model_info = mlflow.pyfunc.log_model(
        artifact_path="test-model",
        python_model=model,
        registered_model_name=model_name,
    )

# Print the model information
print(model_info)
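The registration cell reads `DOMINO_STARTING_USERNAME` and `DOMINO_PROJECT_NAME` with `os.environ[...]`, which raises `KeyError` when the notebook runs outside Domino. The sketch below builds the same model name but falls back to placeholder values. The helper `build_model_name` and the fallbacks `local-user` and `local-project` are hypothetical.

```python
import os


def build_model_name(environ=os.environ):
    """Build the registry name used in this notebook, with placeholders outside Domino."""
    # Hypothetical fallbacks for when the Domino environment variables are not set
    username = environ.get("DOMINO_STARTING_USERNAME", "local-user")
    project_name = environ.get("DOMINO_PROJECT_NAME", "local-project")
    return f"pyfunc-xgboost-model-{username}-{project_name}"


print(build_model_name({"DOMINO_STARTING_USERNAME": "alice", "DOMINO_PROJECT_NAME": "iris"}))
# pyfunc-xgboost-model-alice-iris
```

Passing the environment as a mapping makes the helper easy to test without modifying `os.environ`.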
