In [2]:
!python -V

Python 3.9.19


In [3]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge
import xgboost as xgb

from sklearn.metrics import mean_squared_error, root_mean_squared_error

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pickle

In [4]:
import mlflow

mlflow.set_tracking_uri("sqlite:///mlruns.db")
mlflow.set_experiment("nyc-taxi-experiments")

<Experiment: artifact_location='/Users/rajitsanghvi/Library/CloudStorage/OneDrive-OPEDGMBH/General/01_Github/MLOps/02-experiment-tracking/mlruns/2', creation_time=1716478091983, experiment_id='2', last_update_time=1716478091983, lifecycle_stage='active', name='nyc-taxi-experiments', tags={}>

In [6]:
def read_dataframe(filename):
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)

        df.lpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
        df.lpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)

    df['duration'] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)

    df = df[(df.duration >= 1) & (df.duration <= 60)]

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    
    return df

In [7]:
df_train = read_dataframe('./data/green_tripdata_2021-01.parquet')
df_val = read_dataframe('./data/green_tripdata_2021-02.parquet')

In [8]:
len(df_train), len(df_val)

(73908, 61921)

In [9]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_val['PU_DO'] = df_val['PULocationID'] + '_' + df_val['DOLocationID']

In [10]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [11]:
target = 'duration'
y_train = df_train[target].values
y_val = df_val[target].values

In [17]:
with mlflow.start_run():

    mlflow.set_tag("develop", "R.Sanghvi")
    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
    mlflow.log_param("val-data-path", "./data/green_tripdata_2021-02.parquet")

    alpha = 0.001
    lr = Lasso(alpha)
    mlflow.log_param("alpha", alpha)

    lr.fit(X_train, y_train)
    y_pred = lr.predict(X_val)

    rmse = root_mean_squared_error(y_val, y_pred)
    mlflow.log_metric("rmse", rmse)


In [18]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [19]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [20]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = root_mean_squared_error(y_val, y_pred)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [21]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:11.64459                          
[1]	validation-rmse:11.12765                          
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[2]	validation-rmse:10.65883                          
[3]	validation-rmse:10.23484                          
[4]	validation-rmse:9.85072                           
[5]	validation-rmse:9.50490                           
[6]	validation-rmse:9.19317                           
[7]	validation-rmse:8.91281                           
[8]	validation-rmse:8.66155                           
[9]	validation-rmse:8.43655                           
[10]	validation-rmse:8.23471                          
[11]	validation-rmse:8.05433                          
[12]	validation-rmse:7.89365                          
[13]	validation-rmse:7.75090                          
[14]	validation-rmse:7.62319                          
[15]	validation-rmse:7.50985                          
[16]	validation-rmse:7.40843                          
[17]	validation-rmse:7.31772                          
[18]	validation-rmse:7.23774                          
[19]	validation-rmse:7.16607                          
[20]	valid




[0]	validation-rmse:7.06806                                                    
[1]	validation-rmse:6.60073                                                    
[2]	validation-rmse:6.53279                                                    
[3]	validation-rmse:6.51155                                                    
[4]	validation-rmse:6.50486                                                    
[5]	validation-rmse:6.49580                                                    
[6]	validation-rmse:6.49251                                                    
[7]	validation-rmse:6.48355                                                    
[8]	validation-rmse:6.47871                                                    
[9]	validation-rmse:6.46805                                                    
[10]	validation-rmse:6.46327                                                   
[11]	validation-rmse:6.45728                                                   
[12]	validation-rmse:6.45269            




[3]	validation-rmse:9.71966                                                    
[4]	validation-rmse:9.28701                                                    
[5]	validation-rmse:8.91419                                                    
[6]	validation-rmse:8.59418                                                    
[7]	validation-rmse:8.31983                                                    
[8]	validation-rmse:8.08575                                                    
[9]	validation-rmse:7.88622                                                    
[10]	validation-rmse:7.71643                                                   
[11]	validation-rmse:7.57202                                                   
[12]	validation-rmse:7.44990                                                   
[13]	validation-rmse:7.34563                                                   
[14]	validation-rmse:7.25692                                                   
[15]	validation-rmse:7.18173            




[2]	validation-rmse:7.07420                                                    
[3]	validation-rmse:6.82184                                                    
[4]	validation-rmse:6.70746                                                    
[5]	validation-rmse:6.64607                                                    
[6]	validation-rmse:6.61635                                                    
[7]	validation-rmse:6.59928                                                    
[8]	validation-rmse:6.59129                                                    
[9]	validation-rmse:6.58961                                                    
[10]	validation-rmse:6.58166                                                   
[11]	validation-rmse:6.57698                                                   
[12]	validation-rmse:6.57226                                                   
[13]	validation-rmse:6.56810                                                   
[14]	validation-rmse:6.56538            




[1]	validation-rmse:11.01748                                                   
[2]	validation-rmse:10.50948                                                   
[3]	validation-rmse:10.05462                                                   
[4]	validation-rmse:9.64929                                                    
[5]	validation-rmse:9.28865                                                    
[6]	validation-rmse:8.96693                                                    
[7]	validation-rmse:8.68144                                                    
[8]	validation-rmse:8.43073                                                    
[9]	validation-rmse:8.20663                                                    
[10]	validation-rmse:8.01033                                                   
[11]	validation-rmse:7.83788                                                   
[12]	validation-rmse:7.68473                                                   
[13]	validation-rmse:7.55076            




[2]	validation-rmse:6.78865                                                    
[3]	validation-rmse:6.68392                                                    
[4]	validation-rmse:6.64048                                                    
[5]	validation-rmse:6.62033                                                    
[6]	validation-rmse:6.60382                                                    
[7]	validation-rmse:6.59496                                                    
[8]	validation-rmse:6.59114                                                    
[9]	validation-rmse:6.58755                                                    
[10]	validation-rmse:6.58551                                                   
[11]	validation-rmse:6.58320                                                   
[12]	validation-rmse:6.58066                                                   
[13]	validation-rmse:6.57706                                                   
[14]	validation-rmse:6.57475            




[0]	validation-rmse:11.58974                                                   
[1]	validation-rmse:11.02816                                                   
[2]	validation-rmse:10.52284                                                   
[3]	validation-rmse:10.06843                                                   
[4]	validation-rmse:9.66144                                                    
[5]	validation-rmse:9.29811                                                    
[6]	validation-rmse:8.97235                                                    
[7]	validation-rmse:8.68374                                                    
[8]	validation-rmse:8.42655                                                    
[9]	validation-rmse:8.19909                                                    
[10]	validation-rmse:7.99788                                                   
[11]	validation-rmse:7.82065                                                   
[12]	validation-rmse:7.66497            




[1]	validation-rmse:6.76836                                                    
[2]	validation-rmse:6.63362                                                    
[3]	validation-rmse:6.59691                                                    
[4]	validation-rmse:6.57914                                                    
[5]	validation-rmse:6.57566                                                    
[6]	validation-rmse:6.56872                                                    
[7]	validation-rmse:6.56188                                                    
[8]	validation-rmse:6.55194                                                    
[9]	validation-rmse:6.54510                                                    
[10]	validation-rmse:6.53991                                                   
[11]	validation-rmse:6.53261                                                   
[12]	validation-rmse:6.52471                                                   
[13]	validation-rmse:6.51851            




[3]	validation-rmse:9.41605                                                    
[4]	validation-rmse:8.95800                                                    
[5]	validation-rmse:8.57340                                                    
[6]	validation-rmse:8.25110                                                    
[7]	validation-rmse:7.98114                                                    
[8]	validation-rmse:7.75635                                                    
[9]	validation-rmse:7.56924                                                    
[10]	validation-rmse:7.41403                                                   
[11]	validation-rmse:7.28528                                                   
[12]	validation-rmse:7.17813                                                   
[13]	validation-rmse:7.09018                                                   
[14]	validation-rmse:7.01619                                                   
[15]	validation-rmse:6.95401            




[10]	validation-rmse:6.67707                                                   
[11]	validation-rmse:6.67573                                                   
[12]	validation-rmse:6.67211                                                   
[13]	validation-rmse:6.66294                                                   
[14]	validation-rmse:6.65932                                                   
[15]	validation-rmse:6.65442                                                   
[16]	validation-rmse:6.65210                                                   
[17]	validation-rmse:6.64526                                                   
[18]	validation-rmse:6.64070                                                   
[19]	validation-rmse:6.63513                                                   
[20]	validation-rmse:6.63408                                                   
[21]	validation-rmse:6.63309                                                   
[22]	validation-rmse:6.62449            




[1]	validation-rmse:10.72167                                                    
[2]	validation-rmse:10.11912                                                    
[3]	validation-rmse:9.59861                                                     
[4]	validation-rmse:9.15106                                                     
[5]	validation-rmse:8.76591                                                     
[6]	validation-rmse:8.43718                                                     
[7]	validation-rmse:8.15749                                                     
[8]	validation-rmse:7.91966                                                     
[9]	validation-rmse:7.71850                                                     
[10]	validation-rmse:7.54778                                                    
[11]	validation-rmse:7.40466                                                    
[12]	validation-rmse:7.28276                                                    
[13]	validation-rmse:7.17943




[3]	validation-rmse:7.08493                                                     
[4]	validation-rmse:6.91042                                                     
[5]	validation-rmse:6.81663                                                     
[6]	validation-rmse:6.76051                                                     
[7]	validation-rmse:6.73003                                                     
[8]	validation-rmse:6.71105                                                     
[9]	validation-rmse:6.69756                                                     
[10]	validation-rmse:6.68346                                                    
[11]	validation-rmse:6.67516                                                    
[12]	validation-rmse:6.67043                                                    
[13]	validation-rmse:6.66764                                                    
[14]	validation-rmse:6.66435                                                    
[15]	validation-rmse:6.66333




[1]	validation-rmse:7.82750                                                     
[2]	validation-rmse:7.15089                                                     
[3]	validation-rmse:6.84223                                                     
[4]	validation-rmse:6.69672                                                     
[5]	validation-rmse:6.62128                                                     
[6]	validation-rmse:6.57728                                                     
[7]	validation-rmse:6.55324                                                     
[8]	validation-rmse:6.53668                                                     
[9]	validation-rmse:6.52593                                                     
[10]	validation-rmse:6.51700                                                    
[11]	validation-rmse:6.51308                                                    
[12]	validation-rmse:6.50881                                                    
[13]	validation-rmse:6.50479




[1]	validation-rmse:8.09297                                                     
[2]	validation-rmse:7.34420                                                     
[3]	validation-rmse:6.96968                                                     
[4]	validation-rmse:6.77534                                                     
[5]	validation-rmse:6.67468                                                     
[6]	validation-rmse:6.61530                                                     
[7]	validation-rmse:6.57900                                                     
[8]	validation-rmse:6.55266                                                     
[9]	validation-rmse:6.53403                                                     
[10]	validation-rmse:6.52318                                                    
[11]	validation-rmse:6.51200                                                    
[12]	validation-rmse:6.50095                                                    
[13]	validation-rmse:6.49604




[1]	validation-rmse:10.56512                                                    
[2]	validation-rmse:9.91606                                                     
[3]	validation-rmse:9.36434                                                     
[4]	validation-rmse:8.89930                                                     
[5]	validation-rmse:8.50849                                                     
[6]	validation-rmse:8.18315                                                     
[7]	validation-rmse:7.91196                                                     
[8]	validation-rmse:7.68522                                                     
[9]	validation-rmse:7.49580                                                     
[10]	validation-rmse:7.33888                                                    
[11]	validation-rmse:7.20580                                                    
[12]	validation-rmse:7.09734                                                    
[13]	validation-rmse:7.00805




[0]	validation-rmse:7.96592                                                     
[1]	validation-rmse:6.95614                                                     
[2]	validation-rmse:6.72422                                                     
[3]	validation-rmse:6.64709                                                     
[4]	validation-rmse:6.61410                                                     
[5]	validation-rmse:6.59997                                                     
[6]	validation-rmse:6.58670                                                     
[7]	validation-rmse:6.58350                                                     
[8]	validation-rmse:6.57752                                                     
[9]	validation-rmse:6.57379                                                     
[10]	validation-rmse:6.56814                                                    
[11]	validation-rmse:6.56651                                                    
[12]	validation-rmse:6.56035




[4]	validation-rmse:8.87316                                                     
[5]	validation-rmse:8.49448                                                     
[6]	validation-rmse:8.18233                                                     
[7]	validation-rmse:7.92536                                                     
[8]	validation-rmse:7.71388                                                     
[9]	validation-rmse:7.54090                                                     
[10]	validation-rmse:7.39945                                                    
[11]	validation-rmse:7.28345                                                    
[12]	validation-rmse:7.18831                                                    
[13]	validation-rmse:7.11081                                                    
[14]	validation-rmse:7.04715                                                    
[15]	validation-rmse:6.99521                                                    
[16]	validation-rmse:6.95207




[2]	validation-rmse:10.78367                                                    
[3]	validation-rmse:10.38679                                                    
[4]	validation-rmse:10.02511                                                    
[5]	validation-rmse:9.69546                                                     
[6]	validation-rmse:9.39600                                                     
[7]	validation-rmse:9.12448                                                     
[8]	validation-rmse:8.87851                                                     
[9]	validation-rmse:8.65566                                                     
[10]	validation-rmse:8.45481                                                    
[11]	validation-rmse:8.27356                                                    
[12]	validation-rmse:8.11016                                                    
[13]	validation-rmse:7.96324                                                    
[14]	validation-rmse:7.83131




[11]	validation-rmse:8.22345
[12]	validation-rmse:8.06711                                                    
[13]	validation-rmse:7.92772                                                    
[14]	validation-rmse:7.80381                                                    
[15]	validation-rmse:7.69314                                                    
[16]	validation-rmse:7.59414                                                    
[17]	validation-rmse:7.50658                                                    
[18]	validation-rmse:7.42846                                                    
[19]	validation-rmse:7.35847                                                    
[20]	validation-rmse:7.29607                                                    
[21]	validation-rmse:7.24050                                                    
[22]	validation-rmse:7.19078                                                    
[23]	validation-rmse:7.14620                                                    




[2]	validation-rmse:6.63996                                                     
[3]	validation-rmse:6.61010                                                     
[4]	validation-rmse:6.59648                                                     
[5]	validation-rmse:6.58768                                                     
[6]	validation-rmse:6.57702                                                     
[7]	validation-rmse:6.56835                                                     
[8]	validation-rmse:6.56384                                                     
[9]	validation-rmse:6.55989                                                     
[10]	validation-rmse:6.55833                                                    
[11]	validation-rmse:6.55391                                                    
[12]	validation-rmse:6.55325                                                    
[13]	validation-rmse:6.55209                                                    
[14]	validation-rmse:6.54996




[2]	validation-rmse:8.95652                                                     
[3]	validation-rmse:8.35622                                                     
[4]	validation-rmse:7.91188                                                     
[5]	validation-rmse:7.58702                                                     
[6]	validation-rmse:7.34548                                                     
[7]	validation-rmse:7.17189                                                     
[8]	validation-rmse:7.04601                                                     
[9]	validation-rmse:6.94678                                                     
[10]	validation-rmse:6.87833                                                    
[11]	validation-rmse:6.82748                                                    
[12]	validation-rmse:6.78605                                                    
[13]	validation-rmse:6.75674                                                    
[14]	validation-rmse:6.73402




[0]	validation-rmse:10.75300                                                    
[1]	validation-rmse:9.65124                                                     
[2]	validation-rmse:8.83116                                                     
[3]	validation-rmse:8.22897                                                     
[4]	validation-rmse:7.78221                                                     
[5]	validation-rmse:7.46164                                                     
[6]	validation-rmse:7.23236                                                     
[7]	validation-rmse:7.05988                                                     
[8]	validation-rmse:6.93984                                                     
[9]	validation-rmse:6.85356                                                     
[10]	validation-rmse:6.78435                                                    
[11]	validation-rmse:6.73523                                                    
[12]	validation-rmse:6.69512




[3]	validation-rmse:8.77998                                                     
[4]	validation-rmse:8.31302                                                     
[5]	validation-rmse:7.95093                                                     
[6]	validation-rmse:7.66457                                                     
[7]	validation-rmse:7.44303                                                     
[8]	validation-rmse:7.26933                                                     
[9]	validation-rmse:7.13913                                                     
[10]	validation-rmse:7.03573                                                    
[11]	validation-rmse:6.95241                                                    
[12]	validation-rmse:6.88934                                                    
[13]	validation-rmse:6.83781                                                    
[14]	validation-rmse:6.79757                                                    
[15]	validation-rmse:6.76597




[0]	validation-rmse:10.20712                                                    
[1]	validation-rmse:8.88377                                                     
[2]	validation-rmse:8.03954                                                     
[3]	validation-rmse:7.51251                                                     
[4]	validation-rmse:7.17625                                                     
[5]	validation-rmse:6.97103                                                     
[6]	validation-rmse:6.83571                                                     
[7]	validation-rmse:6.74584                                                     
[8]	validation-rmse:6.69164                                                     
[9]	validation-rmse:6.65413                                                     
[10]	validation-rmse:6.62860                                                    
[11]	validation-rmse:6.60897                                                    
[12]	validation-rmse:6.59532




[0]	validation-rmse:10.48129                                                    
[1]	validation-rmse:9.24506                                                     
[2]	validation-rmse:8.38169                                                     
[3]	validation-rmse:7.79016                                                     
[4]	validation-rmse:7.38732                                                     
[5]	validation-rmse:7.11717                                                     
[6]	validation-rmse:6.93663                                                     
[7]	validation-rmse:6.81350                                                     
[8]	validation-rmse:6.72794                                                     
[9]	validation-rmse:6.66810                                                     
[10]	validation-rmse:6.62637                                                    
[11]	validation-rmse:6.59337                                                    
[12]	validation-rmse:6.56834




[2]	validation-rmse:11.04361                                                    
[3]	validation-rmse:10.70516                                                    
[4]	validation-rmse:10.39129                                                    
[5]	validation-rmse:10.09918                                                    
[6]	validation-rmse:9.82818                                                     
[7]	validation-rmse:9.57621                                                     
[8]	validation-rmse:9.34428                                                     
[9]	validation-rmse:9.12821                                                     
[10]	validation-rmse:8.92885                                                    
[11]	validation-rmse:8.74533                                                    
[12]	validation-rmse:8.57557                                                    
[13]	validation-rmse:8.41867                                                    
[14]	validation-rmse:8.27449




[6]	validation-rmse:6.66140                                                     
[7]	validation-rmse:6.65850                                                     
[8]	validation-rmse:6.64864                                                     
[9]	validation-rmse:6.64472                                                     
[10]	validation-rmse:6.63839                                                    
[11]	validation-rmse:6.63329                                                    
[12]	validation-rmse:6.62748                                                    
[13]	validation-rmse:6.62182                                                    
[14]	validation-rmse:6.61710                                                    
[15]	validation-rmse:6.61372                                                    
[16]	validation-rmse:6.61115                                                    
[17]	validation-rmse:6.60859                                                    
[18]	validation-rmse:6.60700




[1]	validation-rmse:8.90305                                                     
[2]	validation-rmse:8.04828                                                     
[3]	validation-rmse:7.51209                                                     
[4]	validation-rmse:7.17505                                                     
[5]	validation-rmse:6.96571                                                     
[6]	validation-rmse:6.83139                                                     
[7]	validation-rmse:6.74325                                                     
[8]	validation-rmse:6.68449                                                     
[9]	validation-rmse:6.64467                                                     
[10]	validation-rmse:6.61698                                                    
[11]	validation-rmse:6.59610                                                    
[12]	validation-rmse:6.58142                                                    
[13]	validation-rmse:6.56773




[1]	validation-rmse:8.51745                                                     
[2]	validation-rmse:7.68261                                                     
[3]	validation-rmse:7.20532                                                     
[4]	validation-rmse:6.93095                                                     
[5]	validation-rmse:6.77114                                                     
[6]	validation-rmse:6.67535                                                     
[7]	validation-rmse:6.61728                                                     
[8]	validation-rmse:6.57992                                                     
[9]	validation-rmse:6.55294                                                     
[10]	validation-rmse:6.53604                                                    
[11]	validation-rmse:6.52219                                                    
[12]	validation-rmse:6.51296                                                    
[13]	validation-rmse:6.50625




[1]	validation-rmse:10.20357                                                    
[2]	validation-rmse:9.46881                                                     
[3]	validation-rmse:8.88427                                                     
[4]	validation-rmse:8.40982                                                     
[5]	validation-rmse:8.03817                                                     
[6]	validation-rmse:7.74295                                                     
[7]	validation-rmse:7.51180                                                     
[8]	validation-rmse:7.32686                                                     
[9]	validation-rmse:7.18520                                                     
[10]	validation-rmse:7.06986                                                    
[11]	validation-rmse:6.98215                                                    
[12]	validation-rmse:6.91241                                                    
[13]	validation-rmse:6.85152




[0]	validation-rmse:10.42324                                                    
[1]	validation-rmse:9.17238                                                     
[2]	validation-rmse:8.32053                                                     
[3]	validation-rmse:7.73792                                                     
[4]	validation-rmse:7.35234                                                     
[5]	validation-rmse:7.09693                                                     
[6]	validation-rmse:6.92692                                                     
[7]	validation-rmse:6.81836                                                     
[8]	validation-rmse:6.73912                                                     
[9]	validation-rmse:6.68028                                                     
[10]	validation-rmse:6.64296                                                    
[11]	validation-rmse:6.61470                                                    
[12]	validation-rmse:6.59167




[12]	validation-rmse:7.04871                                                    
[13]	validation-rmse:6.99633                                                    
[14]	validation-rmse:6.95435                                                    
[15]	validation-rmse:6.92293                                                    
[16]	validation-rmse:6.89699                                                    
[17]	validation-rmse:6.87551                                                    
[18]	validation-rmse:6.85714                                                    
[19]	validation-rmse:6.84264                                                    
[20]	validation-rmse:6.83003                                                    
[21]	validation-rmse:6.82135                                                    
[22]	validation-rmse:6.81202                                                    
[23]	validation-rmse:6.80493                                                    
[24]	validation-rmse:6.80132




[1]	validation-rmse:9.64383                                                     
[2]	validation-rmse:8.81145                                                     
[3]	validation-rmse:8.20208                                                     
[4]	validation-rmse:7.75495                                                     
[5]	validation-rmse:7.43189                                                     
[6]	validation-rmse:7.19946                                                     
[7]	validation-rmse:7.03132                                                     
[8]	validation-rmse:6.90697                                                     
[9]	validation-rmse:6.81717                                                     
[10]	validation-rmse:6.75209                                                    
[11]	validation-rmse:6.70276                                                    
[12]	validation-rmse:6.66465                                                    
[13]	validation-rmse:6.63658




[0]	validation-rmse:11.79684                                                    
[1]	validation-rmse:11.40813                                                    
[2]	validation-rmse:11.04584                                                    
[3]	validation-rmse:10.70736                                                    
[4]	validation-rmse:10.39229                                                    
[5]	validation-rmse:10.09873                                                    
[6]	validation-rmse:9.82624                                                     
[7]	validation-rmse:9.57275                                                     
[8]	validation-rmse:9.33810                                                     
[9]	validation-rmse:9.12021                                                     
[10]	validation-rmse:8.91910                                                    
[11]	validation-rmse:8.73251                                                    
[12]	validation-rmse:8.56011




[0]	validation-rmse:10.01370                                                    
[1]	validation-rmse:8.62358                                                     
[2]	validation-rmse:7.77300                                                     
[3]	validation-rmse:7.26671                                                     
[4]	validation-rmse:6.96878                                                     
[5]	validation-rmse:6.78908                                                     
[6]	validation-rmse:6.67895                                                     
[7]	validation-rmse:6.60965                                                     
[8]	validation-rmse:6.56263                                                     
[9]	validation-rmse:6.53156                                                     
[10]	validation-rmse:6.51114                                                    
[11]	validation-rmse:6.49561                                                    
[12]	validation-rmse:6.48238




[1]	validation-rmse:10.84531                                                    
[2]	validation-rmse:10.27958                                                    
[3]	validation-rmse:9.78301                                                     
[4]	validation-rmse:9.34831                                                     
[5]	validation-rmse:8.96814                                                     
[6]	validation-rmse:8.63706                                                     
[7]	validation-rmse:8.34974                                                     
[8]	validation-rmse:8.10116                                                     
[9]	validation-rmse:7.88721                                                     
[10]	validation-rmse:7.70174                                                    
[11]	validation-rmse:7.54215                                                    
[12]	validation-rmse:7.40522                                                    
[13]	validation-rmse:7.28708




[2]	validation-rmse:6.95912                                                     
[3]	validation-rmse:6.75300                                                     
[4]	validation-rmse:6.66147                                                     
[5]	validation-rmse:6.61737                                                     
[6]	validation-rmse:6.58857                                                     
[7]	validation-rmse:6.57672                                                     
[8]	validation-rmse:6.57021                                                     
[9]	validation-rmse:6.56697                                                     
[10]	validation-rmse:6.55921                                                    
[11]	validation-rmse:6.55508                                                    
[12]	validation-rmse:6.55122                                                    
[13]	validation-rmse:6.54542                                                    
[14]	validation-rmse:6.53985




[6]	validation-rmse:9.61146                                                     
[7]	validation-rmse:9.35017                                                     
[8]	validation-rmse:9.11152                                                     
[9]	validation-rmse:8.89357                                                     
[10]	validation-rmse:8.69469                                                    
[11]	validation-rmse:8.51337                                                    
[12]	validation-rmse:8.34763                                                    
[13]	validation-rmse:8.19729                                                    
[14]	validation-rmse:8.05946                                                    
[15]	validation-rmse:7.93503                                                    
[16]	validation-rmse:7.82235                                                    
[17]	validation-rmse:7.71905                                                    
[18]	validation-rmse:7.62693




[0]	validation-rmse:11.53452                                                    
[1]	validation-rmse:10.92893                                                    
[2]	validation-rmse:10.39083                                                    
[3]	validation-rmse:9.91285                                                     
[4]	validation-rmse:9.49024                                                     
[5]	validation-rmse:9.11566                                                     
[6]	validation-rmse:8.78666                                                     
[7]	validation-rmse:8.49761                                                     
[8]	validation-rmse:8.24424                                                     
[9]	validation-rmse:8.02234                                                     
[10]	validation-rmse:7.82956                                                    
[11]	validation-rmse:7.66171                                                    
[12]	validation-rmse:7.51445




[0]	validation-rmse:11.21003                                                    
[1]	validation-rmse:10.37319                                                    
[2]	validation-rmse:9.68309                                                     
[3]	validation-rmse:9.10358                                                     
[4]	validation-rmse:8.63522                                                     
[5]	validation-rmse:8.24597                                                     
[6]	validation-rmse:7.93266                                                     
[7]	validation-rmse:7.67613                                                     
[8]	validation-rmse:7.47134                                                     
[9]	validation-rmse:7.30809                                                     
[10]	validation-rmse:7.17161                                                    
[11]	validation-rmse:7.06762                                                    
[12]	validation-rmse:6.97668




[1]	validation-rmse:9.89035                                                     
[2]	validation-rmse:9.09103                                                     
[3]	validation-rmse:8.47763                                                     
[4]	validation-rmse:8.00934                                                     
[5]	validation-rmse:7.65467                                                     
[6]	validation-rmse:7.38775                                                     
[7]	validation-rmse:7.18542                                                     
[8]	validation-rmse:7.03281                                                     
[9]	validation-rmse:6.91706                                                     
[10]	validation-rmse:6.82890                                                    
[11]	validation-rmse:6.76061                                                    
[12]	validation-rmse:6.70780                                                    
[13]	validation-rmse:6.66705




[1]	validation-rmse:9.85624                                                     
[2]	validation-rmse:9.05152                                                     
[3]	validation-rmse:8.43742                                                     
[4]	validation-rmse:7.97210                                                     
[5]	validation-rmse:7.61920                                                     
[6]	validation-rmse:7.35733                                                     
[7]	validation-rmse:7.15763                                                     
[8]	validation-rmse:7.00764                                                     
[9]	validation-rmse:6.89532                                                     
[10]	validation-rmse:6.81118                                                    
[11]	validation-rmse:6.74477                                                    
[12]	validation-rmse:6.69363                                                    
[13]	validation-rmse:6.65413




[0]	validation-rmse:10.46796                                                    
[1]	validation-rmse:9.22352                                                     
[2]	validation-rmse:8.35081                                                     
[3]	validation-rmse:7.75868                                                     
[4]	validation-rmse:7.35624                                                     
[5]	validation-rmse:7.08423                                                     
[6]	validation-rmse:6.90317                                                     
[7]	validation-rmse:6.78029                                                     
[8]	validation-rmse:6.69414                                                     
[9]	validation-rmse:6.63361                                                     
[10]	validation-rmse:6.58938                                                    
[11]	validation-rmse:6.55716                                                    
[12]	validation-rmse:6.53379




[1]	validation-rmse:10.69979                                                    
[2]	validation-rmse:10.08994                                                    
[3]	validation-rmse:9.56370                                                     
[4]	validation-rmse:9.11120                                                     
[5]	validation-rmse:8.72344                                                     
[6]	validation-rmse:8.39314                                                     
[7]	validation-rmse:8.11365                                                     
[8]	validation-rmse:7.87549                                                     
[9]	validation-rmse:7.67339                                                     
[10]	validation-rmse:7.50330                                                    
[11]	validation-rmse:7.35948                                                    
[12]	validation-rmse:7.23832                                                    
[13]	validation-rmse:7.13611




[0]	validation-rmse:11.56827                                                    
[1]	validation-rmse:10.98927                                                    
[2]	validation-rmse:10.47046                                                    
[3]	validation-rmse:10.00577                                                    
[4]	validation-rmse:9.59184                                                     
[5]	validation-rmse:9.22299                                                     
[6]	validation-rmse:8.89497                                                     
[7]	validation-rmse:8.60604                                                     
[8]	validation-rmse:8.34921                                                     
[9]	validation-rmse:8.12270                                                     
[10]	validation-rmse:7.92311                                                    
[11]	validation-rmse:7.74869                                                    
[12]	validation-rmse:7.59527




[1]	validation-rmse:10.44015                                                    
[2]	validation-rmse:9.75892                                                     
[3]	validation-rmse:9.19243                                                     
[4]	validation-rmse:8.71939                                                     
[5]	validation-rmse:8.33174                                                     
[6]	validation-rmse:8.00958                                                     
[7]	validation-rmse:7.74824                                                     
[8]	validation-rmse:7.53300                                                     
[9]	validation-rmse:7.35857                                                     
[10]	validation-rmse:7.21438                                                    
[11]	validation-rmse:7.09826                                                    
[12]	validation-rmse:7.00230                                                    
[13]	validation-rmse:6.92430




[4]	validation-rmse:8.86882                                                     
[5]	validation-rmse:8.48746                                                     
[6]	validation-rmse:8.17038                                                     
[7]	validation-rmse:7.91036                                                     
[8]	validation-rmse:7.69632                                                     
[9]	validation-rmse:7.51915                                                     
[10]	validation-rmse:7.37241                                                    
[11]	validation-rmse:7.25440                                                    
[12]	validation-rmse:7.15602                                                    
[13]	validation-rmse:7.07440                                                    
[14]	validation-rmse:7.00796                                                    
[15]	validation-rmse:6.95068                                                    
[16]	validation-rmse:6.90416




[1]	validation-rmse:11.21592                                                    
[2]	validation-rmse:10.77939                                                    
[3]	validation-rmse:10.38039                                                    
[4]	validation-rmse:10.01605                                                    
[5]	validation-rmse:9.68556                                                     
[6]	validation-rmse:9.38328                                                     
[7]	validation-rmse:9.10930                                                     
[8]	validation-rmse:8.86038                                                     
[9]	validation-rmse:8.63494                                                     
[10]	validation-rmse:8.43179                                                    
[11]	validation-rmse:8.24844                                                    
[12]	validation-rmse:8.08150                                                    
[13]	validation-rmse:7.93152




[3]	validation-rmse:6.70119                                                     
[4]	validation-rmse:6.69011                                                     
[5]	validation-rmse:6.68725                                                     
[6]	validation-rmse:6.68692                                                     
[7]	validation-rmse:6.68064                                                     
[8]	validation-rmse:6.67954                                                     
[9]	validation-rmse:6.67887                                                     
[10]	validation-rmse:6.67378                                                    
[11]	validation-rmse:6.66942                                                    
[12]	validation-rmse:6.66444                                                    
[13]	validation-rmse:6.65923                                                    
[14]	validation-rmse:6.66098                                                    
[15]	validation-rmse:6.65776




[0]	validation-rmse:11.47037                                                    
[1]	validation-rmse:10.81894                                                    
[2]	validation-rmse:10.24637                                                    
[3]	validation-rmse:9.74742                                                     
[4]	validation-rmse:9.31375                                                     
[5]	validation-rmse:8.93798                                                     
[6]	validation-rmse:8.61216                                                     
[7]	validation-rmse:8.32464                                                     
[8]	validation-rmse:8.07929                                                     
[9]	validation-rmse:7.87057                                                     
[10]	validation-rmse:7.68507                                                    
[11]	validation-rmse:7.53407                                                    
[12]	validation-rmse:7.40129

In [22]:
# Training model with best hyperparameter
params = {
    "reg_lambda": 0.35931960489862685,
    "seed": 42,
    "max_depth": 37,
    "min_child_weight": 1.2139033434560909,
    "learning_rate": 0.11035059238513262,
    "reg_alpha": 0.05628231573565957,
    "objective": "reg:linear"
}

mlflow.xgboost.autolog()
booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
y_pred = booster.predict(valid)
rmse = root_mean_squared_error(y_val, y_pred)

2024/05/23 17:54:20 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'bc62efa19aaa4815a2a2f31733dda747', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:11.32560




[1]	validation-rmse:10.56512
[2]	validation-rmse:9.91606
[3]	validation-rmse:9.36434
[4]	validation-rmse:8.89930
[5]	validation-rmse:8.50849
[6]	validation-rmse:8.18315
[7]	validation-rmse:7.91196
[8]	validation-rmse:7.68522
[9]	validation-rmse:7.49580
[10]	validation-rmse:7.33888
[11]	validation-rmse:7.20580
[12]	validation-rmse:7.09734
[13]	validation-rmse:7.00805
[14]	validation-rmse:6.93310
[15]	validation-rmse:6.86976
[16]	validation-rmse:6.81682
[17]	validation-rmse:6.77136
[18]	validation-rmse:6.73385
[19]	validation-rmse:6.70130
[20]	validation-rmse:6.67571
[21]	validation-rmse:6.65214
[22]	validation-rmse:6.63081
[23]	validation-rmse:6.61403
[24]	validation-rmse:6.59795
[25]	validation-rmse:6.58504
[26]	validation-rmse:6.57302
[27]	validation-rmse:6.56317
[28]	validation-rmse:6.55341
[29]	validation-rmse:6.54597
[30]	validation-rmse:6.53908
[31]	validation-rmse:6.53153
[32]	validation-rmse:6.52613
[33]	validation-rmse:6.52097
[34]	validation-rmse:6.51612
[35]	validation-rmse:6



In [24]:
mlflow.xgboost.autolog(disable=True)
with mlflow.start_run():

    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    # Training model with best hyperparameter
    best_params = {
        "reg_lambda": 0.35931960489862685,
        "seed": 42,
        "max_depth": 37,
        "min_child_weight": 1.2139033434560909,
        "learning_rate": 0.11035059238513262,
        "reg_alpha": 0.05628231573565957,
        "objective": "reg:linear"
    }

    mlflow.log_params(best_params)

    booster = xgb.train(
                params=best_params,
                dtrain=train,
                num_boost_round=1000,
                evals=[(valid, 'validation')],
                early_stopping_rounds=50
            )
    y_pred = booster.predict(valid)
    rmse = root_mean_squared_error(y_val, y_pred)

    mlflow.log_metric("rmse", rmse)

    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)

    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")
    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:11.32560




[1]	validation-rmse:10.56512
[2]	validation-rmse:9.91606
[3]	validation-rmse:9.36434
[4]	validation-rmse:8.89930
[5]	validation-rmse:8.50849
[6]	validation-rmse:8.18315
[7]	validation-rmse:7.91196
[8]	validation-rmse:7.68522
[9]	validation-rmse:7.49580
[10]	validation-rmse:7.33888
[11]	validation-rmse:7.20580
[12]	validation-rmse:7.09734
[13]	validation-rmse:7.00805
[14]	validation-rmse:6.93310
[15]	validation-rmse:6.86976
[16]	validation-rmse:6.81682
[17]	validation-rmse:6.77136
[18]	validation-rmse:6.73385
[19]	validation-rmse:6.70130
[20]	validation-rmse:6.67571
[21]	validation-rmse:6.65214
[22]	validation-rmse:6.63081
[23]	validation-rmse:6.61403
[24]	validation-rmse:6.59795
[25]	validation-rmse:6.58504
[26]	validation-rmse:6.57302
[27]	validation-rmse:6.56317
[28]	validation-rmse:6.55341
[29]	validation-rmse:6.54597
[30]	validation-rmse:6.53908
[31]	validation-rmse:6.53153
[32]	validation-rmse:6.52613
[33]	validation-rmse:6.52097
[34]	validation-rmse:6.51612
[35]	validation-rmse:6



In [25]:
logged_model = 'runs:/044fa9b574e04760b0462ae1e375fb02/models_mlflow'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)



In [26]:
loaded_model

mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: 044fa9b574e04760b0462ae1e375fb02

In [27]:
xgboost_model = mlflow.xgboost.load_model(logged_model)
xgboost_model



<xgboost.core.Booster at 0x379b1a070>

In [28]:
y_pred = xgboost_model.predict(valid)
rmse = root_mean_squared_error(y_val, y_pred)
rmse

6.300452789009226

In [12]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.parquet")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

