# LOAD LIBRARIES

In [1]:
import mlflow
import mlflow.xgboost
import xgboost as xgb

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

import matplotlib.pyplot as plt
import pandas as pd
import os

# SETTINGS

In [2]:
# Tracking server and experiment for this notebook.
# MLflow's standard environment variables take precedence, so the notebook can be
# pointed at another server/experiment without editing it (and without committing an
# account-specific ARN). An unset or empty variable falls back to the original values.
mlflow_arn = (
    os.environ.get("MLFLOW_TRACKING_URI")
    or "arn:aws:sagemaker:eu-west-1:575618486322:mlflow-tracking-server/dev-mlflow"
)
mlflow_experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or "02-sample-experiment"

# CONFIGURE MLFLOW TRACKING

In [3]:
# Point the MLflow client at the tracking server, then activate the experiment
# that subsequent runs are recorded under. The Experiment returned by
# set_experiment is the cell's last expression, so Jupyter displays it.
mlflow.set_tracking_uri(uri=mlflow_arn)
mlflow.set_experiment(experiment_name=mlflow_experiment_name)

<Experiment: artifact_location='s3://ipf-sds-datalake-dev-data-science-bucket/mlflow/2', creation_time=1730285259326, experiment_id='2', last_update_time=1730285259326, lifecycle_stage='active', name='02-sample-experiment', tags={}>

# LOAD DATA

In [4]:
# Diabetes regression dataset bundled with scikit-learn; `data` is kept because a
# later cell reads `data.feature_names`.
data = load_diabetes()
X, y = data.data, data.target

# Hold out 20% of the rows for evaluation; the fixed random_state makes the
# split identical on every re-run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# TRAIN AND EVALUATE THE MODEL

In [5]:
# Gradient-boosted tree regressor; random_state fixes XGBoost's seed for repeatable runs.
model = xgb.XGBRegressor(n_estimators=100, max_depth=3, random_state=42)
model.fit(X_train, y_train)

# Evaluate on the held-out split only (rows not passed to fit above).
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

# Display the evaluation so the reader sees the same numbers that get logged to MLflow.
pd.Series({"mse": mse, "r2": r2}, name="test_metrics")

In [8]:
# Start a new MLflow run; everything logged inside the `with` block is attached to it.
with mlflow.start_run(run_name="input-example") as run:
    run_id = run.info.run_id
    print(f"Run ID: {run_id}")

    # Log every hyperparameter reported by the model in a single batch call,
    # instead of one log_param call per parameter.
    params = model.get_params()
    mlflow.log_params(params)

    # Log the held-out evaluation metrics in a single batch call.
    mlflow.log_metrics({"mse": mse, "r2": r2})

    mlflow.set_tag("model_type", "xgboost")

    # One-row example with the dataset's feature names. It is not logged on its own:
    # log_model stores it with the model and uses it to infer the model signature.
    # NOTE(review): the model appears to be fit on a plain positional array (no feature
    # names) while this example carries column names; consider fitting on a DataFrame
    # so the two agree.
    input_example = pd.DataFrame(X_test[:1], columns=data.feature_names)

    # Log the model under the "mymodel" artifact path of this run.
    mlflow.xgboost.log_model(
        xgb_model=model,
        artifact_path="mymodel",
        input_example=input_example,
    )

print("Logging completed.")

Run ID: 57119b7eb46d4fbd9b5bafb37673bd98




Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]

2024/10/30 12:32:07 INFO mlflow.tracking._tracking_service.client: 🏃 View run input-example at: https://eu-west-1.experiments.sagemaker.aws/#/experiments/2/runs/57119b7eb46d4fbd9b5bafb37673bd98.
2024/10/30 12:32:07 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://eu-west-1.experiments.sagemaker.aws/#/experiments/2.


Logging completed.
