In [2]:
import mlflow
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from mlflow.models import infer_signature

In [None]:
X,y = datasets.load_iris(return_X_y=True)
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)
# print(X.shape)        #(150, 4)
# print(type(X))        #<class 'numpy.ndarray'>
# print(type(X_train))  #<class 'numpy.ndarray'>
# print(X[:5])          #[[5.1 3.5 1.4 0.2]
                        #[4.9 3.  1.4 0.2]
                        #[4.7 3.2 1.3 0.2]
                        #[4.6 3.1 1.5 0.2]
                        #[5.  3.6 1.4 0.2]]

In [18]:
params = {
    "penalty":"l2",
    "solver":"lbfgs",
    "max_iter":1000,
    "multi_class":"auto",
    "random_state":8888,}
lr = LogisticRegression(**params)
lr.fit(X_train,y_train)
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test,y_pred)
accuracy



In [43]:
"""MLFLOW tracking
uri
experiment
- params
- metrics
- input/output
- model
"""
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")
mlflow.set_experiment("MLFLOW QuickStart") # experiment name
with mlflow.start_run():
    mlflow.log_params(params)
    mlflow.log_metric("accuracy",1.0) 
    mlflow.set_tag("Training Info3","Basic LR model for iris data") # this is tag associated with our experiment it is always in key value pairs with training infor being the key
    signature = infer_signature(X_train,lr.predict(X_train)) # we are specifying the schema of input and output mlflow will understand from the X_train and our predictions
    model_info = mlflow.sklearn.log_model( # list of basic things we need to specify to save the model
        sk_model = lr,
        artifact_path="Iris_model", # path at ehcih the model will be saved by mlflow
        signature = signature,
        input_example = X_train)
        # registered_model_name="tracking-quickstart")

🏃 View run bemused-moth-117 at: http://127.0.0.1:5000/#/experiments/408467526766651123/runs/77b670f559c84295939a1b1420b7eb9e
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/408467526766651123


In [30]:
params = {
    "solver":"newton-cg",
    "max_iter":1000,
    "multi_class":"auto",
    "random_state":8888,}
lr = LogisticRegression(**params)
lr.fit(X_train,y_train)
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test,y_pred)
accuracy



0.9666666666666667

# Inference code from MLFlow ui 

'runs:/22902720963e4f3bae2df56bebc0833f/Iris_model'

In [33]:
#the code below has been copied from the mlflow ui. Serving payload is the form in whcih this saved moels expects the input. theis is one of the two approaches. Next cell explains the other approach
from mlflow.models import validate_serving_input

model_uri = 'runs:/22902720963e4f3bae2df56bebc0833f/Iris_model'
# or we can write
# model_uri = model_info.model_info.model_uri

# The model is logged with an input example. MLflow converts
# it into the serving payload format for the deployed model endpoint,
# and saves it to 'serving_input_payload.json'

serving_payload = """{
  "inputs": [
    [
      6.3,
      3.4,
      5.6,
      2.4
    ],
    [
      5.6,
      2.5,
      3.9,
      1.1
    ],
    [
      6.7,
      3.0,
      5.2,
      2.3
    ],
    [
      6.8,
      3.0,
      5.5,
      2.1
    ],
    [
      4.6,
      3.1,
      1.5,
      0.2
    ],
    [
      5.4,
      3.9,
      1.7,
      0.4
    ],
    [
      5.5,
      2.4,
      3.8,
      1.1
    ],
    [
      7.4,
      2.8,
      6.1,
      1.9
    ],
    [
      5.8,
      2.7,
      5.1,
      1.9
    ],
    [
      6.1,
      2.6,
      5.6,
      1.4
    ],
    [
      6.1,
      2.8,
      4.7,
      1.2
    ],
    [
      5.0,
      3.5,
      1.6,
      0.6
    ],
    [
      6.6,
      2.9,
      4.6,
      1.3
    ],
    [
      4.4,
      3.2,
      1.3,
      0.2
    ],
    [
      5.6,
      3.0,
      4.1,
      1.3
    ],
    [
      4.9,
      3.0,
      1.4,
      0.2
    ],
    [
      4.6,
      3.2,
      1.4,
      0.2
    ],
    [
      6.4,
      2.7,
      5.3,
      1.9
    ],
    [
      5.1,
      3.3,
      1.7,
      0.5
    ],
    [
      5.1,
      3.4,
      1.5,
      0.2
    ],
    [
      6.7,
      3.1,
      5.6,
      2.4
    ],
    [
      6.2,
      2.9,
      4.3,
      1.3
    ],
    [
      6.7,
      3.3,
      5.7,
      2.5
    ],
    [
      6.4,
      3.2,
      4.5,
      1.5
    ],
    [
      4.8,
      3.4,
      1.9,
      0.2
    ],
    [
      7.7,
      3.0,
      6.1,
      2.3
    ],
    [
      6.2,
      2.8,
      4.8,
      1.8
    ],
    [
      5.7,
      2.6,
      3.5,
      1.0
    ],
    [
      7.9,
      3.8,
      6.4,
      2.0
    ],
    [
      5.1,
      3.8,
      1.5,
      0.3
    ],
    [
      5.0,
      3.4,
      1.5,
      0.2
    ],
    [
      5.5,
      2.3,
      4.0,
      1.3
    ],
    [
      6.7,
      2.5,
      5.8,
      1.8
    ],
    [
      5.6,
      2.8,
      4.9,
      2.0
    ],
    [
      6.5,
      3.0,
      5.8,
      2.2
    ],
    [
      6.6,
      3.0,
      4.4,
      1.4
    ],
    [
      5.0,
      3.6,
      1.4,
      0.2
    ],
    [
      5.6,
      2.9,
      3.6,
      1.3
    ],
    [
      5.1,
      3.5,
      1.4,
      0.3
    ],
    [
      5.6,
      3.0,
      4.5,
      1.5
    ],
    [
      4.8,
      3.0,
      1.4,
      0.3
    ],
    [
      5.0,
      3.3,
      1.4,
      0.2
    ],
    [
      4.7,
      3.2,
      1.6,
      0.2
    ],
    [
      4.5,
      2.3,
      1.3,
      0.3
    ],
    [
      5.8,
      2.7,
      3.9,
      1.2
    ],
    [
      5.3,
      3.7,
      1.5,
      0.2
    ],
    [
      5.4,
      3.7,
      1.5,
      0.2
    ],
    [
      5.1,
      3.5,
      1.4,
      0.2
    ],
    [
      6.3,
      2.8,
      5.1,
      1.5
    ],
    [
      5.8,
      2.7,
      4.1,
      1.0
    ],
    [
      7.3,
      2.9,
      6.3,
      1.8
    ],
    [
      6.1,
      3.0,
      4.9,
      1.8
    ],
    [
      7.2,
      3.0,
      5.8,
      1.6
    ],
    [
      5.5,
      2.4,
      3.7,
      1.0
    ],
    [
      7.2,
      3.2,
      6.0,
      1.8
    ],
    [
      6.3,
      3.3,
      4.7,
      1.6
    ],
    [
      5.2,
      3.5,
      1.5,
      0.2
    ],
    [
      5.9,
      3.0,
      4.2,
      1.5
    ],
    [
      5.0,
      3.4,
      1.6,
      0.4
    ],
    [
      5.6,
      2.7,
      4.2,
      1.3
    ],
    [
      4.6,
      3.6,
      1.0,
      0.2
    ],
    [
      5.5,
      4.2,
      1.4,
      0.2
    ],
    [
      7.7,
      3.8,
      6.7,
      2.2
    ],
    [
      5.7,
      2.8,
      4.5,
      1.3
    ],
    [
      4.6,
      3.4,
      1.4,
      0.3
    ],
    [
      5.7,
      2.8,
      4.1,
      1.3
    ],
    [
      6.9,
      3.1,
      5.4,
      2.1
    ],
    [
      4.4,
      2.9,
      1.4,
      0.2
    ],
    [
      4.8,
      3.0,
      1.4,
      0.1
    ],
    [
      6.1,
      2.9,
      4.7,
      1.4
    ],
    [
      5.4,
      3.9,
      1.3,
      0.4
    ],
    [
      5.5,
      2.5,
      4.0,
      1.3
    ],
    [
      6.0,
      2.2,
      4.0,
      1.0
    ],
    [
      5.4,
      3.4,
      1.7,
      0.2
    ],
    [
      6.9,
      3.1,
      5.1,
      2.3
    ],
    [
      6.4,
      2.8,
      5.6,
      2.2
    ],
    [
      5.0,
      2.3,
      3.3,
      1.0
    ],
    [
      4.9,
      2.4,
      3.3,
      1.0
    ],
    [
      4.4,
      3.0,
      1.3,
      0.2
    ],
    [
      6.7,
      3.1,
      4.4,
      1.4
    ],
    [
      5.1,
      3.7,
      1.5,
      0.4
    ],
    [
      7.7,
      2.8,
      6.7,
      2.0
    ],
    [
      6.3,
      3.3,
      6.0,
      2.5
    ],
    [
      5.2,
      4.1,
      1.5,
      0.1
    ],
    [
      5.0,
      3.0,
      1.6,
      0.2
    ],
    [
      6.0,
      2.7,
      5.1,
      1.6
    ],
    [
      6.2,
      3.4,
      5.4,
      2.3
    ],
    [
      7.1,
      3.0,
      5.9,
      2.1
    ],
    [
      6.0,
      2.2,
      5.0,
      1.5
    ],
    [
      4.9,
      3.1,
      1.5,
      0.1
    ],
    [
      6.3,
      2.3,
      4.4,
      1.3
    ],
    [
      6.4,
      2.9,
      4.3,
      1.3
    ],
    [
      6.0,
      3.4,
      4.5,
      1.6
    ],
    [
      6.1,
      2.8,
      4.0,
      1.3
    ],
    [
      6.1,
      3.0,
      4.6,
      1.4
    ],
    [
      5.9,
      3.2,
      4.8,
      1.8
    ],
    [
      6.4,
      3.1,
      5.5,
      1.8
    ],
    [
      6.3,
      2.5,
      5.0,
      1.9
    ],
    [
      6.8,
      2.8,
      4.8,
      1.4
    ],
    [
      4.3,
      3.0,
      1.1,
      0.1
    ],
    [
      6.5,
      2.8,
      4.6,
      1.5
    ],
    [
      4.9,
      2.5,
      4.5,
      1.7
    ],
    [
      5.1,
      2.5,
      3.0,
      1.1
    ],
    [
      5.8,
      4.0,
      1.2,
      0.2
    ],
    [
      5.2,
      2.7,
      3.9,
      1.4
    ],
    [
      6.3,
      2.5,
      4.9,
      1.5
    ],
    [
      6.4,
      3.2,
      5.3,
      2.3
    ],
    [
      6.7,
      3.3,
      5.7,
      2.1
    ],
    [
      7.2,
      3.6,
      6.1,
      2.5
    ],
    [
      5.7,
      4.4,
      1.5,
      0.4
    ],
    [
      6.0,
      3.0,
      4.8,
      1.8
    ],
    [
      4.9,
      3.1,
      1.5,
      0.2
    ],
    [
      6.5,
      3.2,
      5.1,
      2.0
    ],
    [
      5.0,
      3.5,
      1.3,
      0.3
    ],
    [
      5.5,
      3.5,
      1.3,
      0.2
    ],
    [
      5.2,
      3.4,
      1.4,
      0.2
    ],
    [
      5.4,
      3.4,
      1.5,
      0.4
    ],
    [
      6.2,
      2.2,
      4.5,
      1.5
    ],
    [
      6.9,
      3.1,
      4.9,
      1.5
    ],
    [
      5.0,
      3.2,
      1.2,
      0.2
    ]
  ]
}"""

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

array([2, 1, 2, 2, 0, 0, 1, 2, 2, 2, 1, 0, 1, 0, 1, 0, 0, 2, 0, 0, 2, 1,
       2, 1, 0, 2, 2, 1, 2, 0, 0, 1, 2, 2, 2, 1, 0, 1, 0, 1, 0, 0, 0, 0,
       1, 0, 0, 0, 2, 1, 2, 2, 2, 1, 2, 1, 0, 1, 0, 1, 0, 0, 2, 1, 0, 1,
       2, 0, 0, 1, 0, 1, 1, 0, 2, 2, 1, 1, 0, 1, 0, 2, 2, 0, 0, 2, 2, 2,
       2, 0, 1, 1, 1, 1, 1, 2, 2, 2, 1, 0, 1, 1, 1, 0, 1, 1, 2, 2, 2, 0,
       2, 0, 2, 0, 0, 0, 0, 1, 1, 0])

In [36]:
##load the model back for prediction as a generic python
# this approach dont require specification of input and output
# also this line willbe same nomatter the kind of model or its library be it sklearn or pytorch etc
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)
#lets print the entire thing - not important
results = pd.DataFrame(X_test,columns= datasets.load_iris().feature_names)
results["Actual_class"] = y_test
results["Predicted class"] = predictions
results[:5]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),Actual_class,Predicted class
0,6.8,3.2,5.9,2.3,2,2
1,5.1,3.8,1.6,0.2,0,0
2,4.8,3.4,1.6,0.2,0,0
3,6.5,3.0,5.2,2.0,2,2
4,7.0,3.2,4.7,1.4,1,1


In [55]:
import mlflow.sklearn

model_name = "tracking-quickstart"
model_version = "4"
model_uri = f"models:/{model_name}/{model_version}"

loaded_model = mlflow.sklearn.load_model(model_uri)
loaded_model

In [57]:
loaded_model.predict(X_test)

array([2, 0, 0, 2, 1, 2, 1, 0, 0, 2, 1, 0, 1, 1, 1, 2, 2, 1, 2, 2, 0, 1,
       2, 2, 0, 2, 2, 2, 2, 1])

# MLflow Cheatsheet

## Experiment Tracking

### Run MLflow UI
Run `mlflow ui` in the terminal to launch the UI for visualizing experiments.

### Code Workflow

1. **Import MLflow and infer_signature**  
   import mlflow
   from mlflow.models import infer_signature

2. **Set Tracking URI**  
   `mlflow.set_tracking_uri("<your_tracking_server>")`

3. **Set Experiment**  
   `mlflow.set_experiment("<experiment_name>")`

4. **Start a Run**  
   `with mlflow.start_run():`  
   Add your experiment code inside this block.

5. **Log Parameters**  
   `mlflow.log_param("<param_name>", value)`

6. **Log Metrics**  
   `mlflow.log_metric("<metric_name>", value)`

7. **Set Tags**  
   `mlflow.set_tag("<tag_name>", value)`

8. **Infer Signature** (Optional)  
   Use `from mlflow.models.signature import infer_signature` to infer the input-output schema. Example:  
   `signature = infer_signature(input_data, model_output)`

9. **Log Model**  
   `mlflow.sklearn.log_model(sk_model=<model>, artifact_path="<path>", signature=<signature>, input_example=<example>)`

10. **Register Model**  
    Register the model via the MLflow UI or programmatically if needed.

---

## Load Model from Registry

1. **Load the Model**  
   `from mlflow.pyfunc import load_model`  
   `loaded_model = load_model(model_uri=f"models:/{model_name}/{model_version}")`

2. **Validate Model**  
   Use `predictions = loaded_model.predict(input_data)` to test the model.

---

## Additional Tips

- Use `mlflow.log_artifact("<file_path>")` to log additional files like plots or reports.  
- Compare runs and manage models using the MLflow UI.  
- Regularly test and validate your setup to ensure correct tracking.
