In [1]:
import pandas as pd

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
import pickle

In [4]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso

from sklearn.metrics import mean_squared_error

In [5]:
import mlflow

# Tracking configuration: every run in this notebook is stored in the local
# SQLite database and grouped under a single experiment name.
TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "nyc-taxi-experiment"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='/workspaces/mlops-zoomcamp/02-experiment_tracking/mlruns/1', creation_time=1721214263552, experiment_id='1', last_update_time=1721214263552, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [6]:
def read_dataframe(filename):
    """Load a green-taxi trip file and prepare it for modelling.

    Parameters
    ----------
    filename : str or path-like
        Parquet file handed straight to ``pd.read_parquet``. It must provide
        the ``lpep_pickup_datetime``, ``lpep_dropoff_datetime``,
        ``PULocationID`` and ``DOLocationID`` columns.

    Returns
    -------
    pandas.DataFrame
        Only the trips whose duration lies between 1 and 60 minutes
        (inclusive), with an added ``duration`` column expressed in minutes
        and both location-ID columns converted to strings.
    """
    df = pd.read_parquet(filename)

    # Vectorised timedelta -> minutes conversion (no per-row Python lambda).
    trip_time = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df["duration"] = trip_time.dt.total_seconds() / 60

    # Take an explicit copy of the filtered rows so that the column
    # assignment below writes to an independent frame instead of a slice of
    # the original (which raises SettingWithCopyWarning).
    df = df[(df.duration >= 1) & (df.duration <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    return df

In [7]:
# January 2021 is the training month, February 2021 the validation month.
train_path = "data/green_tripdata_2021-01.parquet"
val_path = "data/green_tripdata_2021-02.parquet"

df_train, df_val = read_dataframe(train_path), read_dataframe(val_path)

In [8]:
# Combine pickup and drop-off zone into a single "<PU>_<DO>" route feature,
# applied identically to the training and the validation frame.
for trips in (df_train, df_val):
    trips['PU_DO'] = trips['PULocationID'] + "_" + trips['DOLocationID']

In [9]:
# Columns fed to the model: the combined route feature plus trip distance.
categorical = ["PU_DO"]  # alternative: ['PULocationID', 'DOLocationID']
numerical = ['trip_distance']
feature_columns = categorical + numerical

# One dict per row, in the shape DictVectorizer expects.
train_dicts = df_train[feature_columns].to_dict(orient="records")
val_dicts = df_val[feature_columns].to_dict(orient="records")

# Fit the vectorizer on the training rows only, then reuse it for validation
# so both matrices are produced by the same fitted transformer.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [10]:
# Regression target: the trip duration column, as plain arrays.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [11]:
# Baseline: ordinary least squares on the vectorized features.
lr = LinearRegression()
lr.fit(X_train, y_train)

y_preds = lr.predict(X_val)

# Validation RMSE, computed as sqrt(MSE). The `squared=False` argument of
# mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6,
# so taking the root explicitly works on every version.
mean_squared_error(y_val, y_preds) ** 0.5



7.758715209663881

In [13]:
# Persist the fitted vectorizer together with the model so that inference can
# reproduce the exact same feature encoding.
bundle = (dv, lr)
with open("models/lin-reg.bin", "wb") as f_out:
    pickle.dump(bundle, f_out)

In [17]:
# Train a Lasso model inside an MLflow run and record its provenance,
# hyper-parameter, validation metric and serialized model.
with mlflow.start_run():

    mlflow.set_tag("developer", "sam")

    mlflow.log_param("train-data-path", "data/green_tripdata_2021-01.parquet")
    mlflow.log_param("val-data-path", "data/green_tripdata_2021-02.parquet")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_preds = lr.predict(X_val)

    # sqrt(MSE) instead of `squared=False`, which was deprecated in
    # scikit-learn 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_preds) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Serialize the model trained in THIS run before logging it. Logging
    # "models/lin-reg.bin" here would attach the earlier LinearRegression
    # pickle to a Lasso run, so the artifact would not match the logged
    # parameters and metric.
    with open("models/lasso.bin", "wb") as f_out:
        pickle.dump((dv, lr), f_out)

    mlflow.log_artifact(local_path="models/lasso.bin", artifact_path="models_pickle")



Experiment Tracking with MLflow

In [13]:
import xgboost as xgb

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [14]:
# Wrap the feature matrices and targets in XGBoost's DMatrix container.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [22]:
def objective(params):
    """Hyperopt objective: train one XGBoost model and report validation RMSE.

    Each call opens its own MLflow run, tags it with ``model=xgboost``, logs
    the sampled ``params`` and the resulting ``rmse`` metric.

    Parameters
    ----------
    params : dict
        XGBoost training parameters sampled by hyperopt; passed unchanged to
        ``xgb.train``.

    Returns
    -------
    dict
        ``{"loss": rmse, "status": STATUS_OK}`` as expected by ``fmin``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, "validation")],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        # sqrt(MSE) instead of `squared=False`, which was deprecated in
        # scikit-learn 1.4 and removed in 1.6.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {"loss": rmse, "status": STATUS_OK}

In [23]:
# Hyper-parameter search space. The loguniform bounds are exponents, i.e. the
# sampled value lies in [exp(low), exp(high)].
search_space = {
    "max_depth": scope.int(hp.quniform("max_depth", 4, 100, 1)),
    "learning_rate": hp.loguniform("learning_rate", -3, 0),
    "reg_alpha": hp.loguniform("reg_alpha", -5, -1),
    "reg_lambda": hp.loguniform("reg_lambda", -6, -1),
    "min_child_weight": hp.loguniform("min_child_weight", -1, 3),
    # "reg:linear" is a deprecated alias; "reg:squarederror" is the current
    # name of the same squared-error regression objective.
    "objective": "reg:squarederror",
    "seed": 42,
}

# Minimise the validation RMSE returned by `objective` with TPE; every
# evaluation is logged as its own MLflow run.
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:11.35101                          
[1]	validation-rmse:10.61095                          
[2]	validation-rmse:9.97736                           
[3]	validation-rmse:9.43859                           
[4]	validation-rmse:8.98106                           
[5]	validation-rmse:8.59677                           
[6]	validation-rmse:8.27478                           
[7]	validation-rmse:8.00217                           
[8]	validation-rmse:7.77389                           
[9]	validation-rmse:7.58463                           
[10]	validation-rmse:7.42911                          
[11]	validation-rmse:7.29590                          
[12]	validation-rmse:7.18432                          
[13]	validation-rmse:7.09500                          
[14]	validation-rmse:7.01815                          
[15]	validation-rmse:6.95581                          
[16]	validation-rmse:6.89998                          
[17]	validation-rmse:6.85578                          
[18]	valid





[0]	validation-rmse:10.08791                                                      
[1]	validation-rmse:8.72171                                                       
[2]	validation-rmse:7.86803                                                       
[3]	validation-rmse:7.34887                                                       
[4]	validation-rmse:7.03388                                                       
[5]	validation-rmse:6.84566                                                       
[6]	validation-rmse:6.72772                                                       
[7]	validation-rmse:6.65451                                                       
[8]	validation-rmse:6.60724                                                       
[9]	validation-rmse:6.57418                                                       
[10]	validation-rmse:6.55178                                                      
[11]	validation-rmse:6.53452                                                      
[12]





[0]	validation-rmse:11.06880                                                      
[1]	validation-rmse:10.13589                                                      
[2]	validation-rmse:9.38185                                                       
[3]	validation-rmse:8.77609                                                       
[4]	validation-rmse:8.29295                                                       
[5]	validation-rmse:7.91297                                                       
[6]	validation-rmse:7.61187                                                       
[7]	validation-rmse:7.37692                                                       
[8]	validation-rmse:7.19291                                                       
[9]	validation-rmse:7.04849                                                       
[10]	validation-rmse:6.93594                                                      
[11]	validation-rmse:6.84533                                                      
[12]





[0]	validation-rmse:11.58306                                                       
[1]	validation-rmse:11.01649                                                       
[2]	validation-rmse:10.50761                                                       
[3]	validation-rmse:10.05159                                                       
[4]	validation-rmse:9.64453                                                        
[5]	validation-rmse:9.28100                                                        
[6]	validation-rmse:8.95723                                                        
[7]	validation-rmse:8.67070                                                        
[8]	validation-rmse:8.41616                                                        
[9]	validation-rmse:8.19136                                                        
[10]	validation-rmse:7.99407                                                       
[11]	validation-rmse:7.81902                                                





[0]	validation-rmse:11.38438                                                      
[1]	validation-rmse:10.66396                                                      
[2]	validation-rmse:10.04254                                                      
[3]	validation-rmse:9.50589                                                       
[4]	validation-rmse:9.04313                                                       
[5]	validation-rmse:8.64902                                                       
[6]	validation-rmse:8.31400                                                       
[7]	validation-rmse:8.02919                                                       
[8]	validation-rmse:7.78782                                                       
[9]	validation-rmse:7.58243                                                       
[10]	validation-rmse:7.41070                                                      
[11]	validation-rmse:7.26567                                                      
[12]





[0]	validation-rmse:11.64935                                                      
[1]	validation-rmse:11.13637                                                      
[2]	validation-rmse:10.67028                                                      
[3]	validation-rmse:10.24767                                                      
[4]	validation-rmse:9.86524                                                       
[5]	validation-rmse:9.51984                                                       
[6]	validation-rmse:9.20814                                                       
[7]	validation-rmse:8.92737                                                       
[8]	validation-rmse:8.67519                                                       
[9]	validation-rmse:8.44917                                                       
[10]	validation-rmse:8.24754                                                      
[11]	validation-rmse:8.06652                                                      
[12]





[0]	validation-rmse:8.14568                                                       
[1]	validation-rmse:7.03012                                                       
[2]	validation-rmse:6.72977                                                       
[3]	validation-rmse:6.63506                                                       
[4]	validation-rmse:6.59398                                                       
[5]	validation-rmse:6.57054                                                       
[6]	validation-rmse:6.56053                                                       
[7]	validation-rmse:6.55209                                                       
[8]	validation-rmse:6.54409                                                       
[9]	validation-rmse:6.54077                                                       
[10]	validation-rmse:6.53418                                                      
[11]	validation-rmse:6.53429                                                      
[12]





[0]	validation-rmse:11.80945                                                      
[1]	validation-rmse:11.43192                                                      
[2]	validation-rmse:11.07920                                                      
[3]	validation-rmse:10.75006                                                      
[4]	validation-rmse:10.44313                                                      
[5]	validation-rmse:10.15708                                                      
[6]	validation-rmse:9.89087                                                       
[7]	validation-rmse:9.64338                                                       
[8]	validation-rmse:9.41374                                                       
[9]	validation-rmse:9.20064                                                       
[10]	validation-rmse:9.00296                                                      
[11]	validation-rmse:8.81972                                                      
[12]





[0]	validation-rmse:11.32023                                                      
[1]	validation-rmse:10.55722                                                      
[2]	validation-rmse:9.90940                                                       
[3]	validation-rmse:9.36084                                                       
[4]	validation-rmse:8.89932                                                       
[5]	validation-rmse:8.51263                                                       
[6]	validation-rmse:8.19119                                                       
[7]	validation-rmse:7.92436                                                       
[8]	validation-rmse:7.70231                                                       
[9]	validation-rmse:7.51909                                                       
[10]	validation-rmse:7.36759                                                      
[11]	validation-rmse:7.24098                                                      
[12]





[0]	validation-rmse:11.57351                                                      
[1]	validation-rmse:10.99936                                                      
[2]	validation-rmse:10.48565                                                      
[3]	validation-rmse:10.02698                                                      
[4]	validation-rmse:9.61789                                                       
[5]	validation-rmse:9.25473                                                       
[6]	validation-rmse:8.93345                                                       
[7]	validation-rmse:8.64839                                                       
[8]	validation-rmse:8.39687                                                       
[9]	validation-rmse:8.17540                                                       
[10]	validation-rmse:7.98040                                                      
[11]	validation-rmse:7.80916                                                      
[12]

KeyboardInterrupt: 

In [15]:
# Re-train with a fixed parameter set and let MLflow's XGBoost autologging
# create the run and record parameters, metrics and the model.
params = {
    "learning_rate": 0.07779422828008767,
    "max_depth": 45,
    "min_child_weight": 1.2986450371090257,
    # "reg:linear" is a deprecated alias of the squared-error objective.
    "objective": "reg:squarederror",
    "reg_alpha": 0.0632450975172789,
    "reg_lambda": 0.009296790888165441,
    # Same seed as the other training cells in this notebook.
    "seed": 42,
}

mlflow.xgboost.autolog()

booster = xgb.train(
    params=params,
    dtrain=train,
    num_boost_round=1000,
    evals=[(valid, "validation")],
    early_stopping_rounds=50
)

2024/07/17 12:05:59 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'e2629be2a6fa499181df7e88f13a59b7', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


: 

In [15]:
# Autologging may have been switched on by an earlier cell; turn it off so
# this run contains only what is logged explicitly below (otherwise params
# and the model would be recorded twice).
mlflow.xgboost.autolog(disable=True)

with mlflow.start_run():
    params = {
        "learning_rate": 0.07779422828008767,
        "max_depth": 45,
        "min_child_weight": 1.2986450371090257,
        # "reg:linear" is a deprecated alias of the squared-error objective.
        "objective": "reg:squarederror",
        "reg_alpha": 0.0632450975172789,
        "reg_lambda": 0.009296790888165441,
        "seed": 42
    }

    mlflow.log_params(params)

    booster = xgb.train(
        params=params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, "validation")],
        early_stopping_rounds=50
    )

    y_preds = booster.predict(valid)
    # sqrt(MSE) instead of `squared=False`, which was deprecated in
    # scikit-learn 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_preds) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # The fitted DictVectorizer is needed to encode features at inference
    # time, so it is stored alongside the model in the same run.
    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)

    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:11.58306
[1]	validation-rmse:11.01649
[2]	validation-rmse:10.50761
[3]	validation-rmse:10.05159
[4]	validation-rmse:9.64453
[5]	validation-rmse:9.28100
[6]	validation-rmse:8.95723
[7]	validation-rmse:8.67070
[8]	validation-rmse:8.41616
[9]	validation-rmse:8.19136
[10]	validation-rmse:7.99407
[11]	validation-rmse:7.81902
[12]	validation-rmse:7.66522
[13]	validation-rmse:7.53066
[14]	validation-rmse:7.41162
[15]	validation-rmse:7.30582
[16]	validation-rmse:7.21455
[17]	validation-rmse:7.13239
[18]	validation-rmse:7.06094
[19]	validation-rmse:6.99793
[20]	validation-rmse:6.94187
[21]	validation-rmse:6.89282
[22]	validation-rmse:6.84954
[23]	validation-rmse:6.81077
[24]	validation-rmse:6.77764
[25]	validation-rmse:6.74742
[26]	validation-rmse:6.72144
[27]	validation-rmse:6.69833
[28]	validation-rmse:6.67653
[29]	validation-rmse:6.65721
[30]	validation-rmse:6.63989
[31]	validation-rmse:6.62368
[32]	validation-rmse:6.60989
[33]	validation-rmse:6.59743
[34]	validation-rmse



In [16]:
# URI of the model logged under "models_mlflow" by a previous run; the run ID
# is specific to the tracking database this notebook was executed against.
logged_model = "runs:/3545158565b946f399a4d65886d3a71f/models_mlflow"

# Load it through the generic pyfunc flavour.
loaded_model = mlflow.pyfunc.load_model(model_uri=logged_model)

In [18]:
# Load the same logged model through the native XGBoost flavour, reusing the
# URI defined for the pyfunc load instead of repeating the literal.
xgboost_model = mlflow.xgboost.load_model(logged_model)

In [20]:
# Score the validation DMatrix with the model loaded back from MLflow.
y_pred = xgboost_model.predict(valid)

In [21]:
# Preview the first ten predictions only, rather than the whole array.
y_pred[:10]

array([14.450237,  6.998118, 15.637997, 24.363546,  9.460321, 17.118723,
       10.913568,  8.458502,  8.984916, 19.83681 ], dtype=float32)