In [32]:
!python -V

Python 3.9.19


In [33]:
import pandas as pd

In [34]:
import pickle

In [35]:
import seaborn as sns
import matplotlib.pyplot as plt

In [36]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [37]:
import mlflow


# Point MLflow at the SQLite file `mlflow.db` in the working directory.
mlflow.set_tracking_uri(uri="sqlite:///mlflow.db")
# Select the experiment all later runs are grouped under; the returned
# Experiment object is the last expression, so it is shown as the cell output.
mlflow.set_experiment(experiment_name="nyc-taxi-experiment")

<Experiment: artifact_location='/Users/son/Documents/GitHub/mlops-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1716546824066, experiment_id='1', last_update_time=1716546824066, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [38]:
def read_dataframe(filename):
    """Load a green-taxi trip CSV and prepare it for duration modelling.

    Steps performed:
      1. Read the CSV at ``filename``.
      2. Parse ``lpep_pickup_datetime`` / ``lpep_dropoff_datetime`` as datetimes.
      3. Add a ``duration`` column: drop-off minus pick-up, in minutes.
      4. Keep only trips whose duration is between 1 and 60 minutes (inclusive).
      5. Cast ``PULocationID`` and ``DOLocationID`` to strings.

    Parameters
    ----------
    filename : str or path-like
        Location of the CSV file; passed straight to ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame
        An independent (copied) frame containing the filtered rows plus the
        ``duration`` column.
    """
    df = pd.read_csv(filename)

    df['lpep_dropoff_datetime'] = pd.to_datetime(df['lpep_dropoff_datetime'])
    df['lpep_pickup_datetime'] = pd.to_datetime(df['lpep_pickup_datetime'])

    # Vectorised timedelta -> minutes; avoids a Python-level call per row.
    trip_time = df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    df['duration'] = trip_time.dt.total_seconds() / 60

    # .copy() detaches the filtered rows from the original frame so the
    # column assignment below does not write into a slice
    # (SettingWithCopyWarning on pandas < 3).
    df = df[(df['duration'] >= 1) & (df['duration'] <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    return df

In [39]:
# First path becomes the training frame, second the validation frame
# (read in that order).
df_train, df_val = map(
    read_dataframe,
    ('./data/green_tripdata_2021-01.csv', './data/green_tripdata_2021-02.csv'),
)

  df = pd.read_csv(filename)


In [40]:
# Row counts of (train, validation) after filtering, shown as the cell output.
tuple(len(frame) for frame in (df_train, df_val))

(73908, 61921)

In [41]:
# Combine pick-up and drop-off zone ids into one "<PU>_<DO>" feature,
# applied identically to both splits (added in place as a new column).
for trips in (df_train, df_val):
    trips['PU_DO'] = trips['PULocationID'] + '_' + trips['DOLocationID']

In [42]:
# Feature columns: the combined PU_DO pair is used in place of the separate
# 'PULocationID' / 'DOLocationID' columns.
categorical = ['PU_DO']
numerical = ['trip_distance']
feature_columns = categorical + numerical

# One dict per trip, in row order.
train_dicts = df_train[feature_columns].to_dict(orient='records')
val_dicts = df_val[feature_columns].to_dict(orient='records')

# Fit the vectorizer on the training records only, then reuse that fitted
# mapping to encode the validation records.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [43]:
# Regression target: the `duration` column, extracted as arrays for each split.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [44]:
# Baseline model: plain linear regression on the vectorized features.
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

# Validation RMSE (last expression, so it is shown as the cell output).
# Computed as sqrt(MSE) instead of `squared=False`: that keyword is deprecated
# in scikit-learn 1.4 and removed in 1.6, while this form works on every version.
mean_squared_error(y_val, y_pred) ** 0.5

2024/05/26 10:58:15 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '19ce86726b7c454b9fafc9b964dd5632', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current sklearn workflow


7.7587152073244585

In [46]:
# Persist the fitted vectorizer together with the current `lr` model so both
# can be restored from a single file.
fitted_pair = (dv, lr)
with open('models/lin_reg.bin', 'wb') as model_file:
    pickle.dump(fitted_pair, model_file)

In [47]:
# Train a Lasso model inside a single MLflow run, recording its parameters,
# validation RMSE and the pickled (vectorizer, model) pair.
with mlflow.start_run():

    mlflow.set_tag("developer", "sonle")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    # NOTE: this rebinds the notebook-level name `lr` (previously the
    # LinearRegression baseline) to the Lasso model.
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    # sqrt(MSE) instead of `squared=False`, which is deprecated in
    # scikit-learn 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Pickle the model fitted in *this* run and log that file. Previously the
    # run attached models/lin_reg.bin, which holds the LinearRegression saved
    # by an earlier cell, so the artifact did not match the logged alpha/rmse.
    lasso_path = "models/lasso_reg.bin"
    with open(lasso_path, 'wb') as f_out:
        pickle.dump((dv, lr), f_out)
    mlflow.log_artifact(local_path=lasso_path, artifact_path="models_pickle")



In [18]:
import xgboost as xgb

In [19]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [20]:
# Wrap the vectorized features and targets in xgboost DMatrix objects
# for the training and validation splits.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [21]:
def objective(params):
    """Train one XGBoost model for a hyperopt trial and log it to MLflow.

    Relies on the notebook-level ``train`` / ``valid`` DMatrix objects and the
    ``y_val`` target array.

    Parameters
    ----------
    params : dict
        Booster parameters for this trial; logged to MLflow as-is and passed
        to ``xgb.train``.

    Returns
    -------
    dict
        ``{'loss': <validation RMSE>, 'status': STATUS_OK}``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        # NOTE(review): whether predict() uses the best or the final boosting
        # round after early stopping depends on the xgboost version - confirm.
        y_pred = booster.predict(valid)
        # sqrt(MSE) instead of `squared=False`, which is deprecated in
        # scikit-learn 1.4 and removed in 1.6.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [22]:
# Hyperparameter ranges sampled for each trial. hp.loguniform bounds are given
# in log space, i.e. the sampled value lies in [exp(low), exp(high)].
search_space = {
    # quniform samples in steps of 1; scope.int casts the value to an integer.
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),          # ~0.05 .. 1
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),                 # ~0.007 .. 0.37
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),               # ~0.0025 .. 0.37
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),    # ~0.37 .. 20
    # 'reg:linear' is a deprecated alias; 'reg:squarederror' is its current name.
    'objective': 'reg:squarederror',
    'seed': 42
}

# Run 50 TPE-guided trials, each executing `objective` (one MLflow run per trial).
# NOTE: no `rstate` is passed, so the sampled trials differ between notebook runs.
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:10.86050                          
[1]	validation-rmse:9.81412                           
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[2]	validation-rmse:9.01685                           
[3]	validation-rmse:8.41510                           
[4]	validation-rmse:7.96585                           
[5]	validation-rmse:7.63118                           
[6]	validation-rmse:7.38517                           
[7]	validation-rmse:7.20287                           
[8]	validation-rmse:7.06826                           
[9]	validation-rmse:6.96944                           
[10]	validation-rmse:6.89447                          
[11]	validation-rmse:6.83738                          
[12]	validation-rmse:6.79351                          
[13]	validation-rmse:6.75985                          
[14]	validation-rmse:6.73290                          
[15]	validation-rmse:6.71466                          
[16]	validation-rmse:6.69851                          
[17]	validation-rmse:6.68498                          
[18]	validation-rmse:6.67468                          
[19]	validation-rmse:6.66669                          
[20]	valid




[4]	validation-rmse:8.89713                                                    
[5]	validation-rmse:8.51667                                                    
[6]	validation-rmse:8.20136                                                    
[7]	validation-rmse:7.94089                                                    
[8]	validation-rmse:7.72609                                                    
[9]	validation-rmse:7.54973                                                    
[10]	validation-rmse:7.40500                                                   
[11]	validation-rmse:7.28636                                                   
[12]	validation-rmse:7.18880                                                   
[13]	validation-rmse:7.10875                                                   
[14]	validation-rmse:7.04340                                                   
[15]	validation-rmse:6.98912                                                   
[16]	validation-rmse:6.94338            




[1]	validation-rmse:10.17965                                                   
[2]	validation-rmse:9.43948                                                    
[3]	validation-rmse:8.84295                                                    
[4]	validation-rmse:8.36784                                                    
[5]	validation-rmse:7.98983                                                    
[6]	validation-rmse:7.69204                                                    
[7]	validation-rmse:7.45895                                                    
[8]	validation-rmse:7.27621                                                    
[9]	validation-rmse:7.13312                                                    
[10]	validation-rmse:7.02019                                                   
[11]	validation-rmse:6.93124                                                   
[12]	validation-rmse:6.86162                                                   
[13]	validation-rmse:6.80650            




[5]	validation-rmse:6.86615                                                    
[6]	validation-rmse:6.78981                                                    
[7]	validation-rmse:6.74424                                                    
[8]	validation-rmse:6.71601                                                    
[9]	validation-rmse:6.69754                                                    
[10]	validation-rmse:6.68302                                                   
[11]	validation-rmse:6.67133                                                   
[12]	validation-rmse:6.66558                                                   
[13]	validation-rmse:6.66304                                                   
[14]	validation-rmse:6.65566                                                   
[15]	validation-rmse:6.65212                                                   
[16]	validation-rmse:6.64930                                                   
[17]	validation-rmse:6.64686            




[0]	validation-rmse:10.48210                                                   
[1]	validation-rmse:9.25964                                                    
[2]	validation-rmse:8.41273                                                    
[3]	validation-rmse:7.82540                                                    
[4]	validation-rmse:7.42698                                                    
[5]	validation-rmse:7.15796                                                    
[6]	validation-rmse:6.98162                                                    
[7]	validation-rmse:6.86126                                                    
[8]	validation-rmse:6.77840                                                    
[9]	validation-rmse:6.71397                                                    
[10]	validation-rmse:6.67291                                                   
[11]	validation-rmse:6.63596                                                   
[12]	validation-rmse:6.61419            




[0]	validation-rmse:10.37292                                                   
[1]	validation-rmse:9.09106                                                    
[2]	validation-rmse:8.21935                                                    
[3]	validation-rmse:7.63600                                                    
[4]	validation-rmse:7.25348                                                    
[5]	validation-rmse:7.00400                                                    
[6]	validation-rmse:6.84220                                                    
[7]	validation-rmse:6.73303                                                    
[8]	validation-rmse:6.65604                                                    
[9]	validation-rmse:6.60248                                                    
[10]	validation-rmse:6.56556                                                   
[11]	validation-rmse:6.53693                                                   
[12]	validation-rmse:6.51580            




[0]	validation-rmse:11.08024                                                   
[1]	validation-rmse:10.15974                                                   
[2]	validation-rmse:9.41414                                                    
[3]	validation-rmse:8.82370                                                    
[4]	validation-rmse:8.35391                                                    
[5]	validation-rmse:7.98250                                                    
[6]	validation-rmse:7.68994                                                    
[7]	validation-rmse:7.45966                                                    
[8]	validation-rmse:7.28508                                                    
[9]	validation-rmse:7.14502                                                    
[10]	validation-rmse:7.03324                                                   
[11]	validation-rmse:6.94620                                                   
[12]	validation-rmse:6.87648            




[0]	validation-rmse:10.83413                                                   
[1]	validation-rmse:9.76132                                                    
[2]	validation-rmse:8.93648                                                    
[3]	validation-rmse:8.31162                                                    
[4]	validation-rmse:7.84115                                                    
[5]	validation-rmse:7.49209                                                    
[6]	validation-rmse:7.23312                                                    
[7]	validation-rmse:7.04395                                                    
[8]	validation-rmse:6.90253                                                    
[9]	validation-rmse:6.79754                                                    
[10]	validation-rmse:6.71685                                                   
[11]	validation-rmse:6.65534                                                   
[12]	validation-rmse:6.60992            




[3]	validation-rmse:8.08878                                                    
[4]	validation-rmse:7.67653                                                    
[5]	validation-rmse:7.38992                                                    
[6]	validation-rmse:7.19073                                                    
[7]	validation-rmse:7.05254                                                    
[8]	validation-rmse:6.95592                                                    
[9]	validation-rmse:6.88766                                                    
[10]	validation-rmse:6.83747                                                   
[11]	validation-rmse:6.79984                                                   
[12]	validation-rmse:6.77303                                                   
[13]	validation-rmse:6.75330                                                   
[14]	validation-rmse:6.73873                                                   
[15]	validation-rmse:6.72649            




[0]	validation-rmse:6.90879                                                    
[1]	validation-rmse:6.65912                                                    
[2]	validation-rmse:6.61562                                                    
[3]	validation-rmse:6.60618                                                    
[4]	validation-rmse:6.59706                                                    
[5]	validation-rmse:6.58930                                                    
[6]	validation-rmse:6.58523                                                    
[7]	validation-rmse:6.57973                                                    
[8]	validation-rmse:6.57420                                                    
[9]	validation-rmse:6.56990                                                    
[10]	validation-rmse:6.56838                                                   
[11]	validation-rmse:6.56585                                                   
[12]	validation-rmse:6.56364            




[0]	validation-rmse:11.33983                                                    
[1]	validation-rmse:10.58602                                                    
[2]	validation-rmse:9.94030                                                     
[3]	validation-rmse:9.38765                                                     
[4]	validation-rmse:8.91767                                                     
[5]	validation-rmse:8.51826                                                     
[6]	validation-rmse:8.18413                                                     
[7]	validation-rmse:7.90239                                                     
[8]	validation-rmse:7.66606                                                     
[9]	validation-rmse:7.46771                                                     
[10]	validation-rmse:7.30153                                                    
[11]	validation-rmse:7.16316                                                    
[12]	validation-rmse:7.04855




[9]	validation-rmse:6.75192                                                     
[10]	validation-rmse:6.74836                                                    
[11]	validation-rmse:6.74634                                                    
[12]	validation-rmse:6.74066                                                    
[13]	validation-rmse:6.73807                                                    
[14]	validation-rmse:6.73552                                                    
[15]	validation-rmse:6.73401                                                    
[16]	validation-rmse:6.73206                                                    
[17]	validation-rmse:6.73046                                                    
[18]	validation-rmse:6.72795                                                    
[19]	validation-rmse:6.72669                                                    
[20]	validation-rmse:6.72617                                                    
[21]	validation-rmse:6.72576




[0]	validation-rmse:11.77666                                                    
[1]	validation-rmse:11.37176                                                    
[2]	validation-rmse:10.99573                                                    
[3]	validation-rmse:10.64593                                                    
[4]	validation-rmse:10.32185                                                    
[5]	validation-rmse:10.02234                                                    
[6]	validation-rmse:9.74500                                                     
[7]	validation-rmse:9.49040                                                     
[8]	validation-rmse:9.25241                                                     
[9]	validation-rmse:9.03534                                                     
[10]	validation-rmse:8.83663                                                    
[11]	validation-rmse:8.65032                                                    
[12]	validation-rmse:8.48146




[0]	validation-rmse:7.90413                                                     
[1]	validation-rmse:6.83961                                                     
[2]	validation-rmse:6.58326                                                     
[3]	validation-rmse:6.50988                                                     
[4]	validation-rmse:6.48011                                                     
[5]	validation-rmse:6.46865                                                     
[6]	validation-rmse:6.45936                                                     
[7]	validation-rmse:6.45228                                                     
[8]	validation-rmse:6.44272                                                     
[9]	validation-rmse:6.43794                                                     
[10]	validation-rmse:6.43163                                                    
[11]	validation-rmse:6.42573                                                    
[12]	validation-rmse:6.42250




[4]	validation-rmse:9.96713                                                     
[5]	validation-rmse:9.63442                                                     
[6]	validation-rmse:9.33388                                                     
[7]	validation-rmse:9.06289                                                     
[8]	validation-rmse:8.81849                                                     
[9]	validation-rmse:8.59889                                                     
[10]	validation-rmse:8.40190                                                    
[11]	validation-rmse:8.22512                                                    
[12]	validation-rmse:8.06715                                                    
[13]	validation-rmse:7.92514                                                    
[14]	validation-rmse:7.79852                                                    
[15]	validation-rmse:7.68553                                                    
[16]	validation-rmse:7.58482




[1]	validation-rmse:11.33202                                                    
[2]	validation-rmse:10.93993                                                    
[3]	validation-rmse:10.57766                                                    
[4]	validation-rmse:10.24309                                                    
[5]	validation-rmse:9.93519                                                     
[6]	validation-rmse:9.65139                                                     
[7]	validation-rmse:9.39062                                                     
[8]	validation-rmse:9.15133                                                     
[9]	validation-rmse:8.93154                                                     
[10]	validation-rmse:8.73002                                                    
[11]	validation-rmse:8.54580                                                    
[12]	validation-rmse:8.37737                                                    
[13]	validation-rmse:8.22311




[0]	validation-rmse:7.94894                                                     
[1]	validation-rmse:6.92887                                                     
[2]	validation-rmse:6.68622                                                     
[3]	validation-rmse:6.60084                                                     
[4]	validation-rmse:6.56763                                                     
[5]	validation-rmse:6.55064                                                     
[6]	validation-rmse:6.53655                                                     
[7]	validation-rmse:6.53067                                                     
[8]	validation-rmse:6.52631                                                     
[9]	validation-rmse:6.52208                                                     
[10]	validation-rmse:6.51850                                                    
[11]	validation-rmse:6.51537                                                    
[12]	validation-rmse:6.51252




[0]	validation-rmse:11.15691                                                    
[1]	validation-rmse:10.28391                                                    
[2]	validation-rmse:9.56712                                                     
[3]	validation-rmse:8.98665                                                     
[4]	validation-rmse:8.51233                                                     
[5]	validation-rmse:8.12774                                                     
[6]	validation-rmse:7.82144                                                     
[7]	validation-rmse:7.57985                                                     
[8]	validation-rmse:7.38608                                                     
[9]	validation-rmse:7.22834                                                     
[10]	validation-rmse:7.10584                                                    
[11]	validation-rmse:7.00659                                                    
[12]	validation-rmse:6.92551




[0]	validation-rmse:10.69762                                                    
[1]	validation-rmse:9.55429                                                     
[2]	validation-rmse:8.70242                                                     
[3]	validation-rmse:8.07706                                                     
[4]	validation-rmse:7.62633                                                     
[5]	validation-rmse:7.30038                                                     
[6]	validation-rmse:7.06941                                                     
[7]	validation-rmse:6.90265                                                     
[8]	validation-rmse:6.78084                                                     
[9]	validation-rmse:6.69494                                                     
[10]	validation-rmse:6.63291                                                    
[11]	validation-rmse:6.58698                                                    
[12]	validation-rmse:6.55091




[0]	validation-rmse:11.68177                                                    
[1]	validation-rmse:11.19461                                                    
[2]	validation-rmse:10.74871                                                    
[3]	validation-rmse:10.34109                                                    
[4]	validation-rmse:9.96904                                                     
[5]	validation-rmse:9.63010                                                     
[6]	validation-rmse:9.32161                                                     
[7]	validation-rmse:9.04208                                                     
[8]	validation-rmse:8.78787                                                     
[9]	validation-rmse:8.55760                                                     
[10]	validation-rmse:8.34912                                                    
[11]	validation-rmse:8.16146                                                    
[12]	validation-rmse:7.99182




[0]	validation-rmse:11.63727                                                    
[1]	validation-rmse:11.11429                                                    
[2]	validation-rmse:10.64370                                                    
[3]	validation-rmse:10.21301                                                    
[4]	validation-rmse:9.82845                                                     
[5]	validation-rmse:9.48159                                                     
[6]	validation-rmse:9.16535                                                     
[7]	validation-rmse:8.88526                                                     
[8]	validation-rmse:8.63163                                                     
[9]	validation-rmse:8.40556                                                     
[10]	validation-rmse:8.20116                                                    
[11]	validation-rmse:8.02484                                                    
[12]	validation-rmse:7.86119




[0]	validation-rmse:9.57114                                                     
[1]	validation-rmse:8.09359                                                     
[2]	validation-rmse:7.31460                                                     
[3]	validation-rmse:6.90768                                                     
[4]	validation-rmse:6.70048                                                     
[5]	validation-rmse:6.59043                                                     
[6]	validation-rmse:6.52646                                                     
[7]	validation-rmse:6.48690                                                     
[8]	validation-rmse:6.46029                                                     
[9]	validation-rmse:6.44631                                                     
[10]	validation-rmse:6.43477                                                    
[11]	validation-rmse:6.42452                                                    
[12]	validation-rmse:6.42005




[0]	validation-rmse:11.56112                                                    
[1]	validation-rmse:10.97603                                                    
[2]	validation-rmse:10.45265                                                    
[3]	validation-rmse:9.98491                                                     
[4]	validation-rmse:9.56891                                                     
[5]	validation-rmse:9.19995                                                     
[6]	validation-rmse:8.87272                                                     
[7]	validation-rmse:8.58354                                                     
[8]	validation-rmse:8.32893                                                     
[9]	validation-rmse:8.10457                                                     
[10]	validation-rmse:7.90688                                                    
[11]	validation-rmse:7.73447                                                    
[12]	validation-rmse:7.58265




[0]	validation-rmse:11.45096                                                    
[1]	validation-rmse:10.78027                                                    
[2]	validation-rmse:10.19170                                                    
[3]	validation-rmse:9.67747                                                     
[4]	validation-rmse:9.22919                                                     
[5]	validation-rmse:8.83950                                                     
[6]	validation-rmse:8.50199                                                     
[7]	validation-rmse:8.21102                                                     
[8]	validation-rmse:7.96126                                                     
[9]	validation-rmse:7.74611                                                     
[10]	validation-rmse:7.56140                                                    
[11]	validation-rmse:7.40307                                                    
[12]	validation-rmse:7.26871




[0]	validation-rmse:9.67442                                                     
[1]	validation-rmse:8.25230                                                     
[2]	validation-rmse:7.48913                                                     
[3]	validation-rmse:7.09231                                                     
[4]	validation-rmse:6.87197                                                     
[5]	validation-rmse:6.75559                                                     
[6]	validation-rmse:6.68406                                                     
[7]	validation-rmse:6.63966                                                     
[8]	validation-rmse:6.61421                                                     
[9]	validation-rmse:6.59545                                                     
[10]	validation-rmse:6.58264                                                    
[11]	validation-rmse:6.57228                                                    
[12]	validation-rmse:6.56661




[0]	validation-rmse:11.80358                                                    
[1]	validation-rmse:11.42056                                                    
[2]	validation-rmse:11.06275                                                    
[3]	validation-rmse:10.72862                                                    
[4]	validation-rmse:10.41669                                                    
[5]	validation-rmse:10.12599                                                    
[6]	validation-rmse:9.85570                                                     
[7]	validation-rmse:9.60397                                                     
[8]	validation-rmse:9.37035                                                     
[9]	validation-rmse:9.15324                                                     
[10]	validation-rmse:8.95199                                                    
[11]	validation-rmse:8.76528                                                    
[12]	validation-rmse:8.59245




[0]	validation-rmse:8.87585                                                     
[1]	validation-rmse:7.43375                                                     
[2]	validation-rmse:6.86523                                                     
[3]	validation-rmse:6.63389                                                     
[4]	validation-rmse:6.52922                                                     
[5]	validation-rmse:6.48265                                                     
[6]	validation-rmse:6.45470                                                     
[7]	validation-rmse:6.43723                                                     
[8]	validation-rmse:6.42660                                                     
[9]	validation-rmse:6.42124                                                     
[10]	validation-rmse:6.41453                                                    
[11]	validation-rmse:6.41021                                                    
[12]	validation-rmse:6.40914




[0]	validation-rmse:10.01783                                                    
[1]	validation-rmse:8.62135                                                     
[2]	validation-rmse:7.76200                                                     
[3]	validation-rmse:7.24988                                                     
[4]	validation-rmse:6.95053                                                     
[5]	validation-rmse:6.76830                                                     
[6]	validation-rmse:6.65563                                                     
[7]	validation-rmse:6.58516                                                     
[8]	validation-rmse:6.53958                                                     
[9]	validation-rmse:6.50754                                                     
[10]	validation-rmse:6.48599                                                    
[11]	validation-rmse:6.47159                                                    
[12]	validation-rmse:6.45823




[0]	validation-rmse:9.00136                                                     
[1]	validation-rmse:7.54428                                                     
[2]	validation-rmse:6.92944                                                     
[3]	validation-rmse:6.67330                                                     
[4]	validation-rmse:6.55873                                                     
[5]	validation-rmse:6.50232                                                     
[6]	validation-rmse:6.47075                                                     
[7]	validation-rmse:6.45035                                                     
[8]	validation-rmse:6.43747                                                     
[9]	validation-rmse:6.43033                                                     
[10]	validation-rmse:6.42454                                                    
[11]	validation-rmse:6.41889                                                    
[12]	validation-rmse:6.41898




[1]	validation-rmse:9.79273                                                     
[2]	validation-rmse:8.98684                                                     
[3]	validation-rmse:8.38072                                                     
[4]	validation-rmse:7.92849                                                     
[5]	validation-rmse:7.59454                                                     
[6]	validation-rmse:7.34806                                                     
[7]	validation-rmse:7.16699                                                     
[8]	validation-rmse:7.03402                                                     
[9]	validation-rmse:6.93027                                                     
[10]	validation-rmse:6.85238                                                    
[11]	validation-rmse:6.79610                                                    
[12]	validation-rmse:6.75273                                                    
[13]	validation-rmse:6.71696




[0]	validation-rmse:6.77085                                                     
[1]	validation-rmse:6.59303                                                     
[2]	validation-rmse:6.57521                                                     
[3]	validation-rmse:6.56132                                                     
[4]	validation-rmse:6.54920                                                     
[5]	validation-rmse:6.53805                                                     
[6]	validation-rmse:6.52680                                                     
[7]	validation-rmse:6.52026                                                     
[8]	validation-rmse:6.51343                                                     
[9]	validation-rmse:6.50591                                                     
[10]	validation-rmse:6.50089                                                    
[11]	validation-rmse:6.49676                                                    
[12]	validation-rmse:6.49178




[0]	validation-rmse:11.50612                                                    
[1]	validation-rmse:10.88348                                                    
[2]	validation-rmse:10.33185                                                    
[3]	validation-rmse:9.84555                                                     
[4]	validation-rmse:9.41555                                                     
[5]	validation-rmse:9.04732                                                     
[6]	validation-rmse:8.71776                                                     
[7]	validation-rmse:8.43108                                                     
[8]	validation-rmse:8.18832                                                     
[9]	validation-rmse:7.97292                                                     
[10]	validation-rmse:7.78326                                                    
[11]	validation-rmse:7.62077                                                    
[12]	validation-rmse:7.48013




[0]	validation-rmse:11.24850                                                    
[1]	validation-rmse:10.43128                                                    
[2]	validation-rmse:9.74551                                                     
[3]	validation-rmse:9.17191                                                     
[4]	validation-rmse:8.69310                                                     
[5]	validation-rmse:8.29778                                                     
[6]	validation-rmse:7.97171                                                     
[7]	validation-rmse:7.70425                                                     
[8]	validation-rmse:7.48444                                                     
[9]	validation-rmse:7.30439                                                     
[10]	validation-rmse:7.15860                                                    
[11]	validation-rmse:7.03843                                                    
[12]	validation-rmse:6.93992




[0]	validation-rmse:11.24942                                                    
[1]	validation-rmse:10.44134                                                    
[2]	validation-rmse:9.75893                                                     
[3]	validation-rmse:9.20024                                                     
[4]	validation-rmse:8.72685                                                     
[5]	validation-rmse:8.34060                                                     
[6]	validation-rmse:8.02036                                                     
[7]	validation-rmse:7.75679                                                     
[8]	validation-rmse:7.54262                                                     
[9]	validation-rmse:7.37106                                                     
[10]	validation-rmse:7.22613                                                    
[11]	validation-rmse:7.11270                                                    
[12]	validation-rmse:7.01771




[0]	validation-rmse:10.90541                                                    
[1]	validation-rmse:9.87197                                                     
[2]	validation-rmse:9.06801                                                     
[3]	validation-rmse:8.44731                                                     
[4]	validation-rmse:7.97013                                                     
[5]	validation-rmse:7.61100                                                     
[6]	validation-rmse:7.33797                                                     
[7]	validation-rmse:7.13370                                                     
[8]	validation-rmse:6.97918                                                     
[9]	validation-rmse:6.86071                                                     
[10]	validation-rmse:6.77446                                                    
[11]	validation-rmse:6.70512                                                    
[12]	validation-rmse:6.65134




[0]	validation-rmse:11.02469                                                    
[1]	validation-rmse:10.07235                                                    
[2]	validation-rmse:9.31151                                                     
[3]	validation-rmse:8.71574                                                     
[4]	validation-rmse:8.23713                                                     
[5]	validation-rmse:7.86655                                                     
[6]	validation-rmse:7.57620                                                     
[7]	validation-rmse:7.35561                                                     
[8]	validation-rmse:7.18620                                                     
[9]	validation-rmse:7.05653                                                     
[10]	validation-rmse:6.95353                                                    
[11]	validation-rmse:6.87506                                                    
[12]	validation-rmse:6.81191




[0]	validation-rmse:10.12194                                                    
[1]	validation-rmse:8.77730                                                     
[2]	validation-rmse:7.93911                                                     
[3]	validation-rmse:7.42441                                                     
[4]	validation-rmse:7.10966                                                     
[5]	validation-rmse:6.92095                                                     
[6]	validation-rmse:6.79860                                                     
[7]	validation-rmse:6.71993                                                     
[8]	validation-rmse:6.66862                                                     
[9]	validation-rmse:6.63518                                                     
[10]	validation-rmse:6.61326                                                    
[11]	validation-rmse:6.59664                                                    
[12]	validation-rmse:6.58390




[0]	validation-rmse:11.41660                                                    
[1]	validation-rmse:10.72521                                                    
[2]	validation-rmse:10.12289                                                    
[3]	validation-rmse:9.60543                                                     
[4]	validation-rmse:9.15780                                                     
[5]	validation-rmse:8.77854                                                     
[6]	validation-rmse:8.44654                                                     
[7]	validation-rmse:8.17026                                                     
[8]	validation-rmse:7.93502                                                     
[9]	validation-rmse:7.73323                                                     
[10]	validation-rmse:7.56423                                                    
[11]	validation-rmse:7.42314                                                    
[12]	validation-rmse:7.29800




[0]	validation-rmse:10.40637                                                    
[1]	validation-rmse:9.15188                                                     
[2]	validation-rmse:8.30274                                                     
[3]	validation-rmse:7.73302                                                     
[4]	validation-rmse:7.35785                                                     
[5]	validation-rmse:7.11881                                                     
[6]	validation-rmse:6.95445                                                     
[7]	validation-rmse:6.84524                                                     
[8]	validation-rmse:6.77230                                                     
[9]	validation-rmse:6.72119                                                     
[10]	validation-rmse:6.68128                                                    
[11]	validation-rmse:6.65654                                                    
[12]	validation-rmse:6.63783




[0]	validation-rmse:11.00649                                                    
[1]	validation-rmse:10.03418                                                    
[2]	validation-rmse:9.25704                                                     
[3]	validation-rmse:8.64079                                                     
[4]	validation-rmse:8.16090                                                     
[5]	validation-rmse:7.78268                                                     
[6]	validation-rmse:7.49424                                                     
[7]	validation-rmse:7.27001                                                     
[8]	validation-rmse:7.09850                                                     
[9]	validation-rmse:6.96351                                                     
[10]	validation-rmse:6.85955                                                    
[11]	validation-rmse:6.77783                                                    
[12]	validation-rmse:6.71330




[1]	validation-rmse:10.50627                                                    
[2]	validation-rmse:9.84429                                                     
[3]	validation-rmse:9.28861                                                     
[4]	validation-rmse:8.82276                                                     
[5]	validation-rmse:8.43685                                                     
[6]	validation-rmse:8.11665                                                     
[7]	validation-rmse:7.85031                                                     
[8]	validation-rmse:7.63148                                                     
[9]	validation-rmse:7.45151                                                     
[10]	validation-rmse:7.30215                                                    
[11]	validation-rmse:7.17939                                                    
[12]	validation-rmse:7.07935                                                    
[13]	validation-rmse:6.99692




[0]	validation-rmse:10.50835                                                    
[1]	validation-rmse:9.28029                                                     
[2]	validation-rmse:8.41425                                                     
[3]	validation-rmse:7.81556                                                     
[4]	validation-rmse:7.40549                                                     
[5]	validation-rmse:7.12902                                                     
[6]	validation-rmse:6.94281                                                     
[7]	validation-rmse:6.81298                                                     
[8]	validation-rmse:6.72568                                                     
[9]	validation-rmse:6.65984                                                     
[10]	validation-rmse:6.61441                                                    
[11]	validation-rmse:6.58025                                                    
[12]	validation-rmse:6.55502




[0]	validation-rmse:11.18156                                                    
[1]	validation-rmse:10.32465                                                    
[2]	validation-rmse:9.61668                                                     
[3]	validation-rmse:9.03470                                                     
[4]	validation-rmse:8.56130                                                     
[5]	validation-rmse:8.17760                                                     
[6]	validation-rmse:7.86300                                                     
[7]	validation-rmse:7.61516                                                     
[8]	validation-rmse:7.41827                                                     
[9]	validation-rmse:7.25604                                                     
[10]	validation-rmse:7.12374                                                    
[11]	validation-rmse:7.02461                                                    
[12]	validation-rmse:6.94295




[0]	validation-rmse:10.91202                                                    
[1]	validation-rmse:9.88660                                                     
[2]	validation-rmse:9.08744                                                     
[3]	validation-rmse:8.47042                                                     
[4]	validation-rmse:8.00013                                                     
[5]	validation-rmse:7.64122                                                     
[6]	validation-rmse:7.37116                                                     
[7]	validation-rmse:7.16866                                                     
[8]	validation-rmse:7.01674                                                     
[9]	validation-rmse:6.90252                                                     
[10]	validation-rmse:6.81403                                                    
[11]	validation-rmse:6.74678                                                    
[12]	validation-rmse:6.69637




[1]	validation-rmse:10.98430                                                    
[2]	validation-rmse:10.46576                                                    
[3]	validation-rmse:10.00536                                                    
[4]	validation-rmse:9.59389                                                     
[5]	validation-rmse:9.23263                                                     
[6]	validation-rmse:8.91173                                                     
[7]	validation-rmse:8.63108                                                     
[8]	validation-rmse:8.38214                                                     
[9]	validation-rmse:8.16446                                                     
[10]	validation-rmse:7.97422                                                    
[11]	validation-rmse:7.80670                                                    
[12]	validation-rmse:7.66020                                                    
[13]	validation-rmse:7.53103




[0]	validation-rmse:10.22568                                                    
[1]	validation-rmse:8.90826                                                     
[2]	validation-rmse:8.06325                                                     
[3]	validation-rmse:7.52726                                                     
[4]	validation-rmse:7.19947                                                     
[5]	validation-rmse:6.98754                                                     
[6]	validation-rmse:6.85831                                                     
[7]	validation-rmse:6.76729                                                     
[8]	validation-rmse:6.71128                                                     
[9]	validation-rmse:6.67079                                                     
[10]	validation-rmse:6.64278                                                    
[11]	validation-rmse:6.62294                                                    
[12]	validation-rmse:6.60922




[0]	validation-rmse:9.19521                                                     
[1]	validation-rmse:7.73984                                                     
[2]	validation-rmse:7.08275                                                     
[3]	validation-rmse:6.78703                                                     
[4]	validation-rmse:6.65019                                                     
[5]	validation-rmse:6.57854                                                     
[6]	validation-rmse:6.53653                                                     
[7]	validation-rmse:6.51334                                                     
[8]	validation-rmse:6.49987                                                     
[9]	validation-rmse:6.49058                                                     
[10]	validation-rmse:6.48573                                                    
[11]	validation-rmse:6.48095                                                    
[12]	validation-rmse:6.47830




[1]	validation-rmse:7.27347                                                     
[2]	validation-rmse:6.83811                                                     
[3]	validation-rmse:6.68274                                                     
[4]	validation-rmse:6.61032                                                     
[5]	validation-rmse:6.57560                                                     
[6]	validation-rmse:6.55483                                                     
[7]	validation-rmse:6.54048                                                     
[8]	validation-rmse:6.52442                                                     
[9]	validation-rmse:6.51991                                                     
[10]	validation-rmse:6.51634                                                    
[11]	validation-rmse:6.51324                                                    
[12]	validation-rmse:6.50823                                                    
[13]	validation-rmse:6.50495




[0]	validation-rmse:7.07377                                                     
[1]	validation-rmse:6.60945                                                     
[2]	validation-rmse:6.54259                                                     
[3]	validation-rmse:6.51982                                                     
[4]	validation-rmse:6.51007                                                     
[5]	validation-rmse:6.49757                                                     
[6]	validation-rmse:6.49047                                                     
[7]	validation-rmse:6.48084                                                     
[8]	validation-rmse:6.47296                                                     
[9]	validation-rmse:6.46205                                                     
[10]	validation-rmse:6.45563                                                    
[11]	validation-rmse:6.45205                                                    
[12]	validation-rmse:6.44519




[0]	validation-rmse:11.37488                                                    
[1]	validation-rmse:10.65072                                                    
[2]	validation-rmse:10.03186                                                    
[3]	validation-rmse:9.49350                                                     
[4]	validation-rmse:9.04472                                                     
[5]	validation-rmse:8.65943                                                     
[6]	validation-rmse:8.33383                                                     
[7]	validation-rmse:8.05216                                                     
[8]	validation-rmse:7.82265                                                     
[9]	validation-rmse:7.62243                                                     
[10]	validation-rmse:7.45504                                                    
[11]	validation-rmse:7.31705                                                    
[12]	validation-rmse:7.20519

In [49]:
# Turn XGBoost autologging off: the next cell logs its params, metric,
# preprocessor artifact and model through explicit mlflow.* calls.
mlflow.xgboost.autolog(disable=True)

In [50]:
import os

# Train one XGBoost model with a fixed set of hyperparameters inside a single
# MLflow run, then log the params, the validation RMSE, the fitted
# DictVectorizer and the booster itself.
with mlflow.start_run():

    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    # Hard-coded hyperparameters (presumably the best trial of the search
    # whose logs precede this cell -- TODO confirm against the MLflow UI).
    best_params = {
        'learning_rate': 0.09585355369315604,
        'max_depth': 30,
        'min_child_weight': 1.060597050922164,
        # 'reg:linear' is a deprecated alias; 'reg:squarederror' is the same
        # squared-error regression loss under its current name.
        'objective': 'reg:squarederror',
        'reg_alpha': 0.018060244040060163,
        'reg_lambda': 0.011658731377413597,
        'seed': 42
    }

    mlflow.log_params(best_params)

    # Up to 1000 boosting rounds; stop once validation RMSE has not improved
    # for 50 consecutive rounds.
    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    # Make sure the output directory exists so the cell also works on a
    # fresh checkout (open() does not create missing parent directories).
    os.makedirs("models", exist_ok=True)
    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)

    # Store the fitted vectorizer next to the model so both can be retrieved
    # from the same run.
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:11.44482
[1]	validation-rmse:10.77202
[2]	validation-rmse:10.18363
[3]	validation-rmse:9.67396
[4]	validation-rmse:9.23166
[5]	validation-rmse:8.84808
[6]	validation-rmse:8.51883
[7]	validation-rmse:8.23597
[8]	validation-rmse:7.99320
[9]	validation-rmse:7.78709
[10]	validation-rmse:7.61022
[11]	validation-rmse:7.45952
[12]	validation-rmse:7.33049
[13]	validation-rmse:7.22098
[14]	validation-rmse:7.12713
[15]	validation-rmse:7.04752
[16]	validation-rmse:6.98005
[17]	validation-rmse:6.92232
[18]	validation-rmse:6.87112
[19]	validation-rmse:6.82740
[20]	validation-rmse:6.78995
[21]	validation-rmse:6.75792
[22]	validation-rmse:6.72994
[23]	validation-rmse:6.70547
[24]	validation-rmse:6.68390
[25]	validation-rmse:6.66421
[26]	validation-rmse:6.64806
[27]	validation-rmse:6.63280
[28]	validation-rmse:6.61924
[29]	validation-rmse:6.60773
[30]	validation-rmse:6.59777
[31]	validation-rmse:6.58875
[32]	validation-rmse:6.58107
[33]	validation-rmse:6.57217
[34]	validation-rmse:



In [51]:
# URI of the "models_mlflow" artifact stored under the run with this ID.
logged_model = "runs:/7fdb8da0da054dad8ccc9888fafd81fb/models_mlflow"

# Load it through the generic pyfunc flavor (yields a PyFuncModel).
loaded_model = mlflow.pyfunc.load_model(model_uri=logged_model)



In [52]:
# Display the pyfunc wrapper's repr (artifact path, flavor, run id).
loaded_model

mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: 7fdb8da0da054dad8ccc9888fafd81fb

In [53]:
# Load the same logged model through the XGBoost flavor instead of pyfunc;
# the resulting object is displayed in the next cell.
xgboost_model = mlflow.xgboost.load_model(logged_model)



In [54]:
# Display the object returned by mlflow.xgboost.load_model.
xgboost_model

<xgboost.core.Booster at 0x3272bd610>

In [55]:
# Predict on the validation DMatrix. NOTE: `valid` is created inside the
# training cell above, so that cell must have run in this kernel first.
y_pred = xgboost_model.predict(valid)

In [56]:
# Peek at the first ten predictions rather than dumping the whole array.
y_pred[:10]

array([14.782765 ,  7.184751 , 15.971323 , 24.328938 ,  9.559302 ,
       17.115105 , 11.6522455,  8.688133 ,  8.962229 , 18.982166 ],
      dtype=float32)

In [25]:
# !python -m pip install scikit-learn==1.3.2

In [30]:
# import sklearn
# print('The scikit-learn version is {}.'.format(sklearn.__version__))

# import mlflow
# print('The mlflow version is {}.'.format(mlflow.__version__))

The scikit-learn version is 1.3.2.
The mlflow version is 2.13.0.


In [57]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

# Let MLflow record estimator parameters and the fitted model automatically
# for every scikit-learn fit() call below.
mlflow.sklearn.autolog()

# Fit several scikit-learn regressors with default hyperparameters, one MLflow
# run per model class, and log each one's validation RMSE for comparison.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        # Record data provenance and attach the preprocessor written by the
        # XGBoost training cell (that cell must have run first).
        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        # Fix the seed so the logged RMSE is reproducible across re-runs; all
        # four estimators accept `random_state`. 42 matches the XGBoost seed.
        mlmodel = model_class(random_state=42)
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)
        

