# Experiment tracking with MLflow

In this notebook we download trip data from the [NYC TLC trip record data website](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page) and train a model to predict the duration of a taxi ride.

We use MLflow to track the parameters, metrics, and models of each experiment run.

## Import libraries

In [3]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge
import pickle

In [5]:
import mlflow

mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("nyc_taxi_experiment")

<Experiment: artifact_location='/home/sgrodriguez/Formación/mlops_dtc/02_experiment_tracking/code/mlruns/1', creation_time=1685268445418, experiment_id='1', last_update_time=1685268445418, lifecycle_stage='active', name='nyc_taxi_experiment', tags={}>

## Download data from website

As an example, we will use Green Taxi data from January 2021 for training, and from February 2021 for validation.

In [6]:
!wget -NP ./data https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-01.parquet
!wget -NP ./data https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-02.parquet

--2023-05-28 12:09:12--  https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-01.parquet
Resolving d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)... 18.67.246.167, 18.67.246.47, 18.67.246.176, ...
Connecting to d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)[18.67.246.167]:443... connected.
HTTP request sent, awaiting response... 304 Not Modified
File ‘../data/green_tripdata_2021-01.parquet’ not modified on server. Omitting download.

--2023-05-28 12:09:13--  https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-02.parquet
Resolving d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)... 18.67.246.167, 18.67.246.47, 18.67.246.176, ...
Connecting to d37ci6vzurychx.cloudfront.net (d37ci6vzurychx.cloudfront.net)[18.67.246.167]:443... connected.
HTTP request sent, awaiting response... 304 Not Modified
File ‘../data/green_tripdata_2021-02.parquet’ not modified on server. Omitting download.
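The two `wget` commands differ only in the month. As a minimal sketch, assuming the file naming pattern seen in these two URLs also holds for other months (an assumption, not something the TLC site guarantees), a small helper can build the URL for a given year and month. `trip_data_url` and `BASE_URL` are hypothetical names introduced here.

```python
BASE_URL = "https://d37ci6vzurychx.cloudfront.net/trip-data"  # host taken from the wget commands


def trip_data_url(year: int, month: int, color: str = "green") -> str:
    """Build a URL following the pattern <color>_tripdata_<YYYY>-<MM>.parquet."""
    if not 1 <= month <= 12:
        raise ValueError(f"month must be in 1..12, got {month}")
    return f"{BASE_URL}/{color}_tripdata_{year:04d}-{month:02d}.parquet"


train_url = trip_data_url(2021, 1)  # January 2021, used for training
val_url = trip_data_url(2021, 2)    # February 2021, used for validation
print(train_url)
print(val_url)
```

The helper only formats strings; downloading is still left to `wget` or any HTTP client.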

## Read and preprocess data

In [7]:
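def read_dataframe(filename):
    df = pd.read_parquet(filename)

    # trip duration in minutes
    df["duration"] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df["duration"] = df["duration"].dt.total_seconds() / 60

    # keep trips between 1 and 60 minutes; copy so later assignments do not target a view
    df = df[(df.duration >= 1) & (df.duration <= 60)].copy()

    # cast location IDs to strings so DictVectorizer one-hot encodes them;
    # plain column assignment replaces the dtype, whereas df.loc[:, cols] = ...
    # may keep the original integer dtype in recent pandas versions
    categorical = ["PULocationID", "DOLocationID"]
    df[categorical] = df[categorical].astype(str)

    return df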
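To make the two preprocessing rules in `read_dataframe` concrete, here is a minimal standard-library sketch of the same arithmetic on a single trip. The helper names and the timestamps are illustrative and do not come from the dataset.

```python
from datetime import datetime


def duration_minutes(pickup: datetime, dropoff: datetime) -> float:
    """Trip duration in minutes, mirroring total_seconds() / 60 in read_dataframe."""
    return (dropoff - pickup).total_seconds() / 60


def keep_trip(minutes: float) -> bool:
    """Same bounds as the notebook's filter: between 1 and 60 minutes inclusive."""
    return 1 <= minutes <= 60


# Illustrative timestamps, not rows from the dataset.
pickup = datetime(2021, 1, 1, 0, 15, 0)
dropoff = datetime(2021, 1, 1, 0, 27, 30)

minutes = duration_minutes(pickup, dropoff)
print(minutes, keep_trip(minutes))
```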

In [8]:
df_train = read_dataframe("./data/green_tripdata_2021-01.parquet")
df_val = read_dataframe("./data/green_tripdata_2021-02.parquet")

In [9]:
categorical = ["PULocationID", "DOLocationID"]
numerical = ["trip_distance"]

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient="records")
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient="records")
X_val = dv.transform(val_dicts)
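The sketch below illustrates, in plain Python, the kind of encoding `DictVectorizer` applies to these records: string values become one-hot columns named `feature=value`, numeric values are passed through, and features not seen during fitting are ignored at transform time. It is a simplified illustration with dense lists, not scikit-learn's implementation; `fit_vocabulary`, `transform`, and the sample records are hypothetical.

```python
def fit_vocabulary(records):
    """Map each feature to a column index: 'name=value' for strings, 'name' for numbers."""
    names = set()
    for record in records:
        for key, value in record.items():
            names.add(f"{key}={value}" if isinstance(value, str) else key)
    return {name: index for index, name in enumerate(sorted(names))}


def transform(records, vocabulary):
    """Encode records as dense rows; features unseen during fitting are ignored."""
    rows = []
    for record in records:
        row = [0.0] * len(vocabulary)
        for key, value in record.items():
            name = f"{key}={value}" if isinstance(value, str) else key
            if name in vocabulary:
                row[vocabulary[name]] = 1.0 if isinstance(value, str) else float(value)
        rows.append(row)
    return rows


# Illustrative records, not rows from the dataset.
train_records = [
    {"PULocationID": "43", "DOLocationID": "151", "trip_distance": 1.01},
    {"PULocationID": "166", "DOLocationID": "43", "trip_distance": 2.53},
]
vocabulary = fit_vocabulary(train_records)
print(vocabulary)
print(transform(train_records, vocabulary))
```

This also shows why the location IDs are cast to strings beforehand: left as integers, they would be treated as a single numeric column instead of one column per location.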

## Set target variable

Next we create an array with the values of our target variable.

In [10]:
target = "duration"

y_train = df_train[target].values
y_val = df_val[target].values

## Experiment tracking

In [11]:
with mlflow.start_run():
    
    mlflow.set_tag("developer", "sergiogrz")
    
    mlflow.log_param("train_data_file", "green_tripdata_2021-01.parquet")
    mlflow.log_param("valid_data_file", "green_tripdata_2021-02.parquet")
    
    alpha = 0.6
    mlflow.log_param("alpha", alpha)
    
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)
    
    y_pred = lr.predict(X_val)
    
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)
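The metric logged in this run is the root mean squared error, which is what `mean_squared_error(..., squared=False)` returns. As a minimal sketch of the formula in plain Python (the sample values are illustrative, not results from this notebook):

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y_true - y_pred)^2))."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Illustrative durations in minutes, not values from the notebook.
y_true = [10.0, 20.0, 30.0]
y_pred = [12.0, 18.0, 33.0]
print(rmse(y_true, y_pred))
```

Because the target is the trip duration in minutes, the RMSE is expressed in minutes as well.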

## Hyperparameter tuning

In [12]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope
import xgboost as xgb

### Defining the objective function

In [13]:
# assuming the feature matrices and targets are already in memory, create the DMatrix objects for XGBoost
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [14]:
# params contains the XGBoost hyperparameters for a single run
def objective(params):

    with mlflow.start_run():

        # set a tag for easier filtering and log the hyperparameters
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)

        # define and train the model; stop after 50 rounds without improvement
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )

        # predict on the validation set
        y_pred = booster.predict(valid)

        # compute and log the validation RMSE
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    # hyperopt expects a dict with the loss and a status flag
    return {'loss': rmse, 'status': STATUS_OK}
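The `early_stopping_rounds=50` argument stops training once the validation metric has not improved for 50 consecutive rounds. The sketch below shows that rule in plain Python on a short list of scores; it is an illustration of the idea, not XGBoost's implementation, and `best_iteration` is a hypothetical helper.

```python
def best_iteration(scores, patience):
    """Return (index, score) of the best round, stopping after `patience` rounds without improvement."""
    best_index, best_score = 0, float("inf")
    for index, score in enumerate(scores):
        if score < best_score:
            best_index, best_score = index, score
        elif index - best_index >= patience:
            break  # no improvement for `patience` consecutive rounds
    return best_index, best_score


# Illustrative validation RMSE values; patience is shortened from 50 to 3 for readability.
scores = [7.4, 6.5, 6.4, 6.38, 6.37, 6.375, 6.378, 6.385, 6.30]
print(best_iteration(scores, patience=3))
```

Note that the final score in the list is never evaluated: training stops three rounds after the best one, which is the trade-off early stopping makes.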

### Defining the search space

In [15]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),  # [exp(-3), exp(0)] = [0.05, 1]
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:squarederror',  # current name of the deprecated 'reg:linear'
    'seed': 42
}
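To see what ranges the search space covers, here is a minimal standard-library sketch of the two sampling rules used above, following their documented definitions: `hp.loguniform(label, low, high)` draws `exp(uniform(low, high))`, and `hp.quniform(label, low, high, q)` draws `round(uniform(low, high) / q) * q`. The helper names are hypothetical, and this is not hyperopt's own code.

```python
import math
import random


def sample_loguniform(rng, low, high):
    """Draw exp(u) with u uniform in [low, high], so the log of the value is uniform."""
    return math.exp(rng.uniform(low, high))


def sample_quniform(rng, low, high, q):
    """Draw round(u / q) * q with u uniform in [low, high]."""
    return round(rng.uniform(low, high) / q) * q


rng = random.Random(42)  # fixed seed so the sketch is reproducible
learning_rates = [sample_loguniform(rng, -3, 0) for _ in range(1000)]
max_depths = [int(sample_quniform(rng, 4, 100, 1)) for _ in range(1000)]

print(min(learning_rates), max(learning_rates))  # within [exp(-3), 1]
print(min(max_depths), max(max_depths))          # integers within [4, 100]
```

Sampling on a log scale spends as many draws between 0.05 and 0.1 as between 0.5 and 1, which suits parameters such as the learning rate whose effect varies by order of magnitude.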

### Minimizing the objective function

In [13]:
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:7.41660                                                                                                         
[1]	validation-rmse:6.49510                                                                                                         
[2]	validation-rmse:6.40147                                                                                                         
[3]	validation-rmse:6.38287                                                                                                         
[4]	validation-rmse:6.37344                                                                                                         
[5]	validation-rmse:6.37555                                                                                                         
[6]	validation-rmse:6.37868                                                                                                         
[7]	validation-rmse:6.38542                                          

[66]	validation-rmse:6.20442                                                                                                        
[67]	validation-rmse:6.20529                                                                                                        
[68]	validation-rmse:6.20586                                                                                                        
[69]	validation-rmse:6.20612                                                                                                        
[70]	validation-rmse:6.20683                                                                                                        
[71]	validation-rmse:6.20791                                                                                                        
[72]	validation-rmse:6.20833                                                                                                        
[73]	validation-rmse:6.20855                                         

[0]	validation-rmse:19.05667                                                                                                        
[1]	validation-rmse:17.19999                                                                                                        
[2]	validation-rmse:15.59132                                                                                                        
[3]	validation-rmse:14.20029                                                                                                        
[4]	validation-rmse:13.00144                                                                                                        
[5]	validation-rmse:11.97136                                                                                                        
[6]	validation-rmse:11.09072                                                                                                        
[7]	validation-rmse:10.33976                                         

[121]	validation-rmse:6.16022                                                                                                       
[122]	validation-rmse:6.16033                                                                                                       
[123]	validation-rmse:6.15973                                                                                                       
[124]	validation-rmse:6.15947                                                                                                       
[125]	validation-rmse:6.15963                                                                                                       
[126]	validation-rmse:6.15914                                                                                                       
[127]	validation-rmse:6.15900                                                                                                       
[128]	validation-rmse:6.15947                                        

[243]	validation-rmse:6.14967                                                                                                       
[244]	validation-rmse:6.14983                                                                                                       
[245]	validation-rmse:6.14957                                                                                                       
[246]	validation-rmse:6.14935                                                                                                       
[247]	validation-rmse:6.14967                                                                                                       
[248]	validation-rmse:6.14943                                                                                                       
[249]	validation-rmse:6.14952                                                                                                       
[250]	validation-rmse:6.14958                                        

[106]	validation-rmse:6.20259                                                                                                       
[107]	validation-rmse:6.20202                                                                                                       
[108]	validation-rmse:6.20243                                                                                                       
[109]	validation-rmse:6.20242                                                                                                       
[110]	validation-rmse:6.20191                                                                                                       
[111]	validation-rmse:6.20398                                                                                                       
[112]	validation-rmse:6.20364                                                                                                       
[113]	validation-rmse:6.20448                                        

[82]	validation-rmse:6.60473                                                                                                        
[83]	validation-rmse:6.60487                                                                                                        
[84]	validation-rmse:6.60531                                                                                                        
[85]	validation-rmse:6.60558                                                                                                        
[0]	validation-rmse:16.67335                                                                                                        
[1]	validation-rmse:13.44400                                                                                                        
[2]	validation-rmse:11.15841                                                                                                        
[3]	validation-rmse:9.58141                                          

[2]	validation-rmse:6.84064                                                                                                         
[3]	validation-rmse:6.75620                                                                                                         
[4]	validation-rmse:6.72625                                                                                                         
[5]	validation-rmse:6.72383                                                                                                         
[6]	validation-rmse:6.73809                                                                                                         
[7]	validation-rmse:6.73938                                                                                                         
[8]	validation-rmse:6.74408                                                                                                         
[9]	validation-rmse:6.75926                                          

[6]	validation-rmse:9.48086                                                                                                         
[7]	validation-rmse:8.81818                                                                                                         
[8]	validation-rmse:8.28571                                                                                                         
[9]	validation-rmse:7.87072                                                                                                         
[10]	validation-rmse:7.54103                                                                                                        
[11]	validation-rmse:7.28445                                                                                                        
[12]	validation-rmse:7.08415                                                                                                        
[13]	validation-rmse:6.92973                                         

[43]	validation-rmse:6.42312                                                                                                        
[44]	validation-rmse:6.41705                                                                                                        
[45]	validation-rmse:6.41237                                                                                                        
[46]	validation-rmse:6.40751                                                                                                        
[47]	validation-rmse:6.40367                                                                                                        
[48]	validation-rmse:6.40038                                                                                                        
[49]	validation-rmse:6.39790                                                                                                        
[50]	validation-rmse:6.39542                                         

KeyboardInterrupt: 
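Conceptually, `fmin` repeatedly proposes a set of parameters, calls the objective, and keeps track of the lowest loss. The sketch below shows that loop with plain random search rather than TPE (a deliberately simpler stand-in, not what `tpe.suggest` does) and with a toy objective in place of the XGBoost training run. All names here are hypothetical.

```python
import random

STATUS_OK = "ok"  # stand-in for hyperopt.STATUS_OK in this standalone sketch


def toy_objective(params):
    """Toy loss with its minimum at x = 2; replaces the XGBoost training run."""
    return {"loss": (params["x"] - 2.0) ** 2, "status": STATUS_OK}


def random_search(objective, sample_params, max_evals, seed=42):
    """Evaluate `max_evals` random points and return the best parameters and all results."""
    rng = random.Random(seed)
    trials = []
    for _ in range(max_evals):
        params = sample_params(rng)
        result = objective(params)
        if result["status"] == STATUS_OK:
            trials.append((result["loss"], params))
    best_loss, best_params = min(trials, key=lambda trial: trial[0])
    return best_params, best_loss, trials


best_params, best_loss, trials = random_search(
    toy_objective,
    sample_params=lambda rng: {"x": rng.uniform(-10.0, 10.0)},
    max_evals=50,
)
print(best_params, best_loss)
```

TPE differs in that it uses the losses observed so far to decide where to sample next, so it usually needs fewer evaluations than random search to reach a comparable loss.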

### Retraining with the optimal hyperparameters and automatic logging

In [16]:
params = {
    "learning_rate": 0.1126860623846719,
    "max_depth": 11,
    "min_child_weight": 7.128461099684721,
    "objective": "reg:squarederror",  # current name of the deprecated "reg:linear"
    "reg_alpha": 0.04429046957254972,
    "reg_lambda": 0.09902356874800584,
    "seed": 42,
}

mlflow.xgboost.autolog()

booster = xgb.train(
    params=params,
    dtrain=train,
    num_boost_round=1000,
    evals=[(valid, 'validation')],
    early_stopping_rounds=50
)

2023/05/27 12:02:27 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'a7b6560075ff464889e7a5ca4a546c92', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:19.18916
[1]	validation-rmse:17.43343
[2]	validation-rmse:15.89883
[3]	validation-rmse:14.55999
[4]	validation-rmse:13.39704
[5]	validation-rmse:12.38651
[6]	validation-rmse:11.51434
[7]	validation-rmse:10.75978
[8]	validation-rmse:10.11576
[9]	validation-rmse:9.56565
[10]	validation-rmse:9.09455
[11]	validation-rmse:8.69152
[12]	validation-rmse:8.35168
[13]	validation-rmse:8.06299
[14]	validation-rmse:7.82050
[15]	validation-rmse:7.61407
[16]	validation-rmse:7.43876
[17]	validation-rmse:7.28958
[18]	validation-rmse:7.16372
[19]	validation-rmse:7.05627
[20]	validation-rmse:6.96427
[21]	validation-rmse:6.88679
[22]	validation-rmse:6.82021
[23]	validation-rmse:6.76296
[24]	validation-rmse:6.71292
[25]	validation-rmse:6.67016
[26]	validation-rmse:6.63300
[27]	validation-rmse:6.59970
[28]	validation-rmse:6.57240
[29]	validation-rmse:6.54777
[30]	validation-rmse:6.52569
[31]	validation-rmse:6.50544
[32]	validation-rmse:6.48907
[33]	validation-rmse:6.47239
[34]	validation

[273]	validation-rmse:6.15399
[274]	validation-rmse:6.15380
[275]	validation-rmse:6.15321
[276]	validation-rmse:6.15285
[277]	validation-rmse:6.15306
[278]	validation-rmse:6.15304
[279]	validation-rmse:6.15330
[280]	validation-rmse:6.15337
[281]	validation-rmse:6.15305
[282]	validation-rmse:6.15252
[283]	validation-rmse:6.15272
[284]	validation-rmse:6.15225
[285]	validation-rmse:6.15223
[286]	validation-rmse:6.15202
[287]	validation-rmse:6.15186
[288]	validation-rmse:6.15211
[289]	validation-rmse:6.15132
[290]	validation-rmse:6.15085
[291]	validation-rmse:6.15054
[292]	validation-rmse:6.15043
[293]	validation-rmse:6.15029
[294]	validation-rmse:6.15019
[295]	validation-rmse:6.15003
[296]	validation-rmse:6.14991
[297]	validation-rmse:6.14975
[298]	validation-rmse:6.14958
[299]	validation-rmse:6.14924
[300]	validation-rmse:6.14928
[301]	validation-rmse:6.14900
[302]	validation-rmse:6.14853
[303]	validation-rmse:6.14792
[304]	validation-rmse:6.14712
[305]	validation-rmse:6.14710
[306]	vali



## Model management

### Log model as an artifact

In [17]:
import os
os.makedirs("./models", exist_ok=True)  # make sure the output directory exists
with open("./models/lin_reg.bin", "wb") as f_out:
    pickle.dump((dv, lr), f_out)

In [18]:
with mlflow.start_run():
    
    mlflow.set_tag("developer", "sergiogrz")
    
    mlflow.log_param("train_data_file", "green_tripdata_2021-01.parquet")
    mlflow.log_param("valid_data_file", "green_tripdata_2021-02.parquet")
    
    alpha = 0.6
    mlflow.log_param("alpha", alpha)
    
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)
    
    y_pred = lr.predict(X_val)
    
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)
    
    # track model
    mlflow.log_artifact(local_path="./models/lin_reg.bin", artifact_path="models_pickle")
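A pickled artifact is only useful if it can be read back in the same shape it was written. The sketch below shows the round trip for a `(preprocessor, model)` tuple using plain dictionaries as stand-ins for the fitted `DictVectorizer` and `Lasso` objects; the stand-ins and the temporary path are assumptions made so the example runs without scikit-learn.

```python
import pickle
import tempfile
from pathlib import Path

# Plain-Python stand-ins for the fitted DictVectorizer and Lasso objects.
preprocessor = {"feature_names": ["PULocationID", "DOLocationID", "trip_distance"]}
model = {"kind": "lasso", "alpha": 0.6}

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = Path(tmp_dir) / "lin_reg.bin"

    # Save both objects as one tuple, as the notebook does with (dv, lr).
    with open(model_path, "wb") as f_out:
        pickle.dump((preprocessor, model), f_out)

    # Load them back in the same order they were saved.
    with open(model_path, "rb") as f_in:
        loaded_preprocessor, loaded_model = pickle.load(f_in)

print(loaded_preprocessor, loaded_model)
```

Whoever loads the file has to know that it holds a two-element tuple and which libraries the objects need, which is one limitation of logging a raw pickle as an artifact.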

### Log model using `log_model` method

In [19]:
# for this example, we turn off autologging
mlflow.xgboost.autolog(disable=True)

In [20]:
with mlflow.start_run():
    
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        "learning_rate": 0.1126860623846719,
        "max_depth": 11,
        "min_child_weight": 7.128461099684721,
        "objective": "reg:squarederror",  # current name of the deprecated "reg:linear"
        "reg_alpha": 0.04429046957254972,
        "reg_lambda": 0.09902356874800584,
        "seed": 42,
    }
    
    mlflow.log_params(best_params)

    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)

    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)
    
    # log preprocessor
    with open("./models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact(local_path="./models/preprocessor.b", artifact_path="preprocessor")
    
    # log model
    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:19.18916
[1]	validation-rmse:17.43343
[2]	validation-rmse:15.89883
[3]	validation-rmse:14.55999
[4]	validation-rmse:13.39704
[5]	validation-rmse:12.38651
[6]	validation-rmse:11.51434
[7]	validation-rmse:10.75978
[8]	validation-rmse:10.11576
[9]	validation-rmse:9.56565
[10]	validation-rmse:9.09455
[11]	validation-rmse:8.69152
[12]	validation-rmse:8.35168
[13]	validation-rmse:8.06299
[14]	validation-rmse:7.82050
[15]	validation-rmse:7.61407
[16]	validation-rmse:7.43876
[17]	validation-rmse:7.28958
[18]	validation-rmse:7.16372
[19]	validation-rmse:7.05627
[20]	validation-rmse:6.96427
[21]	validation-rmse:6.88679
[22]	validation-rmse:6.82021
[23]	validation-rmse:6.76296
[24]	validation-rmse:6.71292
[25]	validation-rmse:6.67016
[26]	validation-rmse:6.63300
[27]	validation-rmse:6.59970
[28]	validation-rmse:6.57240
[29]	validation-rmse:6.54777
[30]	validation-rmse:6.52569
[31]	validation-rmse:6.50544
[32]	validation-rmse:6.48907
[33]	validation-rmse:6.47239
[34]	validation

[273]	validation-rmse:6.15399
[274]	validation-rmse:6.15380
[275]	validation-rmse:6.15321
[276]	validation-rmse:6.15285
[277]	validation-rmse:6.15306
[278]	validation-rmse:6.15304
[279]	validation-rmse:6.15330
[280]	validation-rmse:6.15337
[281]	validation-rmse:6.15305
[282]	validation-rmse:6.15252
[283]	validation-rmse:6.15272
[284]	validation-rmse:6.15225
[285]	validation-rmse:6.15223
[286]	validation-rmse:6.15202
[287]	validation-rmse:6.15186
[288]	validation-rmse:6.15211
[289]	validation-rmse:6.15132
[290]	validation-rmse:6.15085
[291]	validation-rmse:6.15054
[292]	validation-rmse:6.15043
[293]	validation-rmse:6.15029
[294]	validation-rmse:6.15019
[295]	validation-rmse:6.15003
[296]	validation-rmse:6.14991
[297]	validation-rmse:6.14975
[298]	validation-rmse:6.14958
[299]	validation-rmse:6.14924
[300]	validation-rmse:6.14928
[301]	validation-rmse:6.14900
[302]	validation-rmse:6.14853
[303]	validation-rmse:6.14792
[304]	validation-rmse:6.14712
[305]	validation-rmse:6.14710
[306]	vali

### Making predictions

In [21]:
logged_model = 'runs:/f9130c479da14355b9918b7a2b8b5569/models_mlflow'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(model_uri=logged_model)

 - mlflow (current: 2.3.2, required: mlflow==2.3)
 - pandas (current: 1.5.3, required: pandas==2.0.1)
 - typing-extensions (current: 4.6.2, required: typing-extensions==4.5.0)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
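The warning lists packages whose installed version differs from the version pinned when the model was logged. As a rough sketch of that kind of check, the function below compares two dictionaries with an exact string comparison. This is a simplification for illustration, not MLflow's implementation; `find_mismatches` and `example-package` are hypothetical, while the other versions are copied from the warning.

```python
def find_mismatches(installed, required):
    """Return lines describing packages whose installed version differs from the pinned one."""
    mismatches = []
    for package, pinned in sorted(required.items()):
        current = installed.get(package)
        if current != pinned:
            mismatches.append(f"{package} (current: {current}, required: {package}=={pinned})")
    return mismatches


# Versions copied from the warning; "example-package" is a hypothetical matching entry.
installed = {"mlflow": "2.3.2", "pandas": "1.5.3", "typing-extensions": "4.6.2", "example-package": "1.0.0"}
required = {"mlflow": "2.3", "pandas": "2.0.1", "typing-extensions": "4.5.0", "example-package": "1.0.0"}

for line in find_mismatches(installed, required):
    print(" - " + line)
```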




In [22]:
loaded_model

mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: f9130c479da14355b9918b7a2b8b5569

In [23]:
xgboost_model = mlflow.xgboost.load_model(model_uri=logged_model)



In [24]:
xgboost_model

<xgboost.core.Booster at 0x7faca9a88a00>

In [25]:
y_pred = xgboost_model.predict(valid)

In [26]:
y_pred[:10]

array([15.17474 ,  6.258108, 18.17024 , 22.549145,  9.738305, 14.481582,
       13.229766,  9.051595,  8.43366 , 17.937893], dtype=float32)

## Model registry

In [27]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        mlflow.log_param("train_data_file", "green_tripdata_2021-01.parquet")
        mlflow.log_param("valid_data_file", "green_tripdata_2021-02.parquet")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)
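Each iteration of the loop produces one run with an `rmse` metric, and a natural next step before registering a model is to compare those runs. The sketch below picks the run with the lowest RMSE from a list of summaries. The run IDs and metric values are invented placeholders, not results from this notebook, and the `model` artifact path in the URI is an assumption about where autologging stores the model.

```python
# Hypothetical run summaries; IDs and RMSE values are placeholders for illustration.
runs = [
    {"run_id": "run-a", "model": "RandomForestRegressor", "rmse": 3.0},
    {"run_id": "run-b", "model": "GradientBoostingRegressor", "rmse": 1.0},
    {"run_id": "run-c", "model": "ExtraTreesRegressor", "rmse": 2.0},
    {"run_id": "run-d", "model": "LinearSVR", "rmse": 4.0},
]


def best_run(runs, metric="rmse"):
    """Return the run with the lowest value of `metric`."""
    return min(runs, key=lambda run: run[metric])


candidate = best_run(runs)
model_uri = f"runs:/{candidate['run_id']}/model"  # "model" is an assumed artifact path
print(candidate["model"], model_uri)
```

In practice the choice may also weigh training time and model size, so the lowest RMSE is a starting point for the comparison and not the only criterion.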

