In [1]:
import pandas as pd
import pickle

  from pandas.core import (


In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [4]:
from mlflow.tracking import MlflowClient

MLFLOW_TRACKING_URI = "sqlite:///mlflow.db"

client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

In [5]:
import mlflow


mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("nyc-taxi-experiment")

<Experiment: artifact_location='/workspaces/mlops-zoomcamp_/02-experiment-tracking/mlruns/1', creation_time=1716819604874, experiment_id='1', last_update_time=1716819604874, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [6]:
def read_dataframe(filename):
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)

        df.tpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
        df.tpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)

    df['duration'] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)

    df = df[(df.duration >= 1) & (df.duration <= 60)]

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    
    return df

In [7]:
df_train = read_dataframe('https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-01.parquet')
df_val = read_dataframe('https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-02.parquet')

In [8]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_val['PU_DO'] = df_val['PULocationID'] + '_' + df_val['DOLocationID']

In [9]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [10]:
target = 'duration'
y_train = df_train[target].values
y_val = df_val[target].values

In [11]:
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

mean_squared_error(y_val, y_pred, squared=False)



7.758715209663881

In [12]:
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

In [14]:
with mlflow.start_run():

    mlflow.set_tag("developer", "simon")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")



In [15]:
import xgboost as xgb

In [16]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [17]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [18]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [None]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

  0%|                                                                                                                                                                | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:8.88183                                                                                                                                                                                 
[1]	validation-rmse:7.54013                                                                                                                                                                                 
[2]	validation-rmse:7.03495                                                                                                                                                                                 
[3]	validation-rmse:6.84328                                                                                                                                                                                 
[4]	validation-rmse:6.76918                                                                                                                                                         





[0]	validation-rmse:11.26005                                                                                                                                                                                
[1]	validation-rmse:10.45440                                                                                                                                                                                
[2]	validation-rmse:9.77386                                                                                                                                                                                 
[3]	validation-rmse:9.20416                                                                                                                                                                                 
[4]	validation-rmse:8.72765                                                                                                                                                         





[0]	validation-rmse:9.65934                                                                                                                                                                                 
[1]	validation-rmse:8.22886                                                                                                                                                                                 
[2]	validation-rmse:7.46551                                                                                                                                                                                 
[3]	validation-rmse:7.07103                                                                                                                                                                                 
[4]	validation-rmse:6.85901                                                                                                                                                         





[0]	validation-rmse:8.78829                                                                                                                                                                                 
[1]	validation-rmse:7.45298                                                                                                                                                                                 
[2]	validation-rmse:6.96809                                                                                                                                                                                 
[3]	validation-rmse:6.78246                                                                                                                                                                                 
[4]	validation-rmse:6.70094                                                                                                                                                         





[0]	validation-rmse:8.83815                                                                                                                                                                                 
[1]	validation-rmse:7.43185                                                                                                                                                                                 
[2]	validation-rmse:6.89032                                                                                                                                                                                 
[3]	validation-rmse:6.67070                                                                                                                                                                                 
[4]	validation-rmse:6.57381                                                                                                                                                         





[0]	validation-rmse:6.90663                                                                                                                                                                                 
[1]	validation-rmse:6.69049                                                                                                                                                                                 
[2]	validation-rmse:6.65088                                                                                                                                                                                 
[3]	validation-rmse:6.62787                                                                                                                                                                                 
[4]	validation-rmse:6.62138                                                                                                                                                         





[0]	validation-rmse:6.95255                                                                                                                                                                                 
[1]	validation-rmse:6.58040                                                                                                                                                                                 
[2]	validation-rmse:6.53476                                                                                                                                                                                 
[3]	validation-rmse:6.52812                                                                                                                                                                                 
[4]	validation-rmse:6.51422                                                                                                                                                         





[0]	validation-rmse:9.26405                                                                                                                                                                                 
[1]	validation-rmse:7.82019                                                                                                                                                                                 
[2]	validation-rmse:7.17232                                                                                                                                                                                 
[3]	validation-rmse:6.87437                                                                                                                                                                                 
[4]	validation-rmse:6.72439                                                                                                                                                         





[0]	validation-rmse:7.30162                                                                                                                                                                                 
[1]	validation-rmse:6.71972                                                                                                                                                                                 
[2]	validation-rmse:6.62387                                                                                                                                                                                 
[3]	validation-rmse:6.59814                                                                                                                                                                                 
[4]	validation-rmse:6.58743                                                                                                                                                         





[0]	validation-rmse:11.43617                                                                                                                                                                                
[1]	validation-rmse:10.75705                                                                                                                                                                                
[2]	validation-rmse:10.16590                                                                                                                                                                                
[3]	validation-rmse:9.65353                                                                                                                                                                                 
[4]	validation-rmse:9.21001                                                                                                                                                         





[0]	validation-rmse:11.53630                                                                                                                                                                                
[1]	validation-rmse:10.93173                                                                                                                                                                                
[2]	validation-rmse:10.39339                                                                                                                                                                                
[3]	validation-rmse:9.91442                                                                                                                                                                                 
[4]	validation-rmse:9.49088                                                                                                                                                         





[0]	validation-rmse:9.22931                                                                                                                                                                                 
[1]	validation-rmse:7.76947                                                                                                                                                                                 
[2]	validation-rmse:7.10815                                                                                                                                                                                 
[3]	validation-rmse:6.80448                                                                                                                                                                                 
[4]	validation-rmse:6.66199                                                                                                                                                         

In [None]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR
mlflow.set_experiment(experiment_id="0")
mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)
        