In [1]:
# Report the version of the interpreter running this kernel. `!python -V` would
# query whichever `python` is first on PATH, which can differ from the kernel's env.
import platform
print(f"Python {platform.python_version()}")

Python 3.9.18


In [2]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [3]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression, Lasso, Ridge
from sklearn.metrics import mean_squared_error

In [4]:
import mlflow

# Track runs in the SQLite database one directory up, under the named experiment.
mlflow.set_tracking_uri(uri="sqlite:///../mlflow.db")
mlflow.set_experiment(experiment_name="nyc-taxi-experiment")

<Experiment: artifact_location='/home/pytholic/mlops-zoomcamp-pytholic/02-experiment-tracking/mlruns/1', creation_time=1702972736474, experiment_id='1', last_update_time=1702972736474, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [5]:
def read_dataframe(filename):
    """Load a green-taxi trip file and prepare it for modelling.

    Reads a ``.csv`` or ``.parquet`` file, adds a ``duration`` column (minutes
    between pickup and dropoff), keeps only trips lasting between 1 and 60
    minutes inclusive, and casts ``PULocationID`` / ``DOLocationID`` to strings
    so they are treated as categorical features downstream.

    Args:
        filename: Path ending in ``.csv`` or ``.parquet``. The file must contain
            ``lpep_pickup_datetime``, ``lpep_dropoff_datetime``,
            ``PULocationID`` and ``DOLocationID`` columns.

    Returns:
        A new DataFrame with the filtered trips and the added ``duration`` column.

    Raises:
        ValueError: If ``filename`` has neither a ``.csv`` nor a ``.parquet`` suffix.
    """
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)
        # CSV carries no datetime dtype, so parse the timestamp columns explicitly.
        df.lpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
        df.lpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)
    else:
        # Without this branch `df` would be unbound and fail later with UnboundLocalError.
        raise ValueError(f"Unsupported file type (expected .csv or .parquet): {filename}")

    # Duration in minutes; Timedelta / Timedelta is vectorised and exact for whole minutes.
    df['duration'] = (df.lpep_dropoff_datetime - df.lpep_pickup_datetime) / pd.Timedelta(minutes=1)

    # Filter between 1 min and 60 mins. `.copy()` detaches the result from `df`
    # so the column cast below does not trigger SettingWithCopyWarning.
    df = df[(df.duration >= 1) & (df.duration <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    return df

In [6]:
# January 2021 trips form the training set; February 2021 trips are held out for validation.
df_train, df_val = map(read_dataframe, [
    "../data/green_tripdata_2021-01.parquet",
    "../data/green_tripdata_2021-02.parquet",
])

In [7]:
df_train.shape[0], df_val.shape[0]

(73908, 61921)

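**Note:** Vectorizing the training and validation sets separately produced `X_train` and `X_val` with different numbers of columns. I therefore fitted the `DictVectorizer` on the combined data and then split the resulting matrix back into its training and validation parts. The usual cause of mismatched shapes is calling `fit_transform` on both sets. Fitting on the training set and calling `dv.transform` on the validation set also gives matching shapes, without letting validation data influence the feature vocabulary.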

In [8]:
# Stack train rows first, then validation rows. This order is what the later split relies on.
# ignore_index=True gives a fresh 0..n-1 index; otherwise the two frames' indices
# overlap and the combined frame has duplicate labels.
df_combined = pd.concat([df_train, df_val], axis=0, ignore_index=True)
assert len(df_combined) == len(df_train) + len(df_val)

In [9]:
# Route key "<PULocationID>_<DOLocationID>" built from the string-typed location IDs.
df_combined = df_combined.assign(
    PU_DO=df_combined['PULocationID'] + '_' + df_combined['DOLocationID']
)

In [10]:
categorical = ['PU_DO']        # string route key -> one-hot columns via DictVectorizer
numerical = ['trip_distance']  # numeric value passed through as its own column

# Fit on train + validation together so both splits share a single column layout.
combined_dicts = df_combined.loc[:, categorical + numerical].to_dict('records')
dv = DictVectorizer()
X_combined = dv.fit_transform(combined_dicts)

In [11]:
# Split the jointly vectorised matrix back into its parts. The training rows were
# stacked first, so the boundary is len(df_train). Deriving it avoids a hard-coded
# row count that silently mis-splits if the input data changes.
n_train = len(df_train)
X_train = X_combined[:n_train, :]
X_val = X_combined[n_train:, :]
assert X_train.shape[0] == len(df_train) and X_val.shape[0] == len(df_val)

In [12]:
target = 'duration'
# Target arrays taken from each split's frame, in the same row order as X_train / X_val.
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [14]:
import xgboost as xgb
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [15]:
# Wrap features and targets in XGBoost's DMatrix container for xgb.train.
train, val = (xgb.DMatrix(X, label=y) for X, y in ((X_train, y_train), (X_val, y_val)))

In [19]:
def objective(params):
    """Train one XGBoost model for a hyperopt trial and log it to MLflow.

    Uses the module-level ``train`` / ``val`` DMatrix objects and the ``y_val``
    targets. Each call opens its own MLflow run, tags it, logs ``params`` and the
    validation RMSE.

    Args:
        params: XGBoost parameter dict sampled from the hyperopt search space.

    Returns:
        dict: ``{"loss": rmse, "status": STATUS_OK}``, the shape ``hyperopt.fmin`` expects.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(val, "validation")],
            early_stopping_rounds=50,
        )
        # xgb.train returns the model from the LAST round, up to 50 rounds past the
        # best one. Score only the best iteration so the logged RMSE is the optimum
        # that early stopping found.
        y_pred = booster.predict(val, iteration_range=(0, booster.best_iteration + 1))
        # sqrt of MSE rather than `squared=False`, which newer scikit-learn deprecates.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {"loss": rmse, "status": STATUS_OK}


In [20]:
# Hyperopt search space: the hp.* entries are sampled per trial; "objective" and
# "seed" stay fixed across trials.
search_space = {
    "max_depth": scope.int(hp.quniform("max_depth", 4, 100, 1)),  # depth of trees
    "learning_rate": hp.loguniform("learning_rate", -3, 0),  # [exp(-3), exp(0)] ~= [0.05, 1]
    "reg_alpha": hp.loguniform("reg_alpha", -5, -1),
    "reg_lambda": hp.loguniform("reg_lambda", -6, -1),
    "min_child_weight": hp.loguniform("min_child_weight", -1, 3),
    # "reg:squarederror" is the current name of the deprecated "reg:linear" alias.
    "objective": "reg:squarederror",
    "seed": 42,
}

# Keep the Trials object so per-trial losses/params can be inspected afterwards.
trials = Trials()
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials,
    rstate=np.random.default_rng(42),  # seed the TPE sampler for a reproducible search
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:9.96892                           
[1]	validation-rmse:8.58102                           
[2]	validation-rmse:7.75115                           
[3]	validation-rmse:7.27333                           
[4]	validation-rmse:6.98708                           
[5]	validation-rmse:6.82734                           
[6]	validation-rmse:6.73527                           
[7]	validation-rmse:6.66852                           
[8]	validation-rmse:6.62437                           
[9]	validation-rmse:6.59841                           
[10]	validation-rmse:6.58231                          
[11]	validation-rmse:6.56667                          
[12]	validation-rmse:6.55591                          
[13]	validation-rmse:6.54881                          
[14]	validation-rmse:6.53942                          
[15]	validation-rmse:6.53508                          
[16]	validation-rmse:6.53235                          
[17]	validation-rmse:6.52777                          
[18]	valid




[0]	validation-rmse:8.55825                                                      
[1]	validation-rmse:7.22177                                                      
[2]	validation-rmse:6.77455                                                      
[3]	validation-rmse:6.60909                                                      
[4]	validation-rmse:6.53827                                                      
[5]	validation-rmse:6.50107                                                      
[6]	validation-rmse:6.48701                                                      
[7]	validation-rmse:6.48208                                                      
[8]	validation-rmse:6.47283                                                      
[9]	validation-rmse:6.46680                                                      
[10]	validation-rmse:6.46004                                                     
[11]	validation-rmse:6.45578                                                     
[12]	validation-




[0]	validation-rmse:11.21374                                                   
[1]	validation-rmse:10.37939                                                   
[2]	validation-rmse:9.68752                                                    
[3]	validation-rmse:9.11628                                                    
[4]	validation-rmse:8.64905                                                    
[5]	validation-rmse:8.26849                                                    
[6]	validation-rmse:7.96008                                                    
[7]	validation-rmse:7.71173                                                    
[8]	validation-rmse:7.51154                                                    
[9]	validation-rmse:7.34989                                                    
[10]	validation-rmse:7.22082                                                   
[11]	validation-rmse:7.11560                                                   
[12]	validation-rmse:7.03090            




[0]	validation-rmse:8.92417                                                      
[1]	validation-rmse:7.54685                                                      
[2]	validation-rmse:7.00879                                                      
[3]	validation-rmse:6.79461                                                      
[4]	validation-rmse:6.68979                                                      
[5]	validation-rmse:6.64199                                                      
[6]	validation-rmse:6.61149                                                      
[7]	validation-rmse:6.59719                                                      
[8]	validation-rmse:6.57993                                                      
[9]	validation-rmse:6.57173                                                      
[10]	validation-rmse:6.56524                                                     
[11]	validation-rmse:6.56269                                                     
[12]	validation-




[0]	validation-rmse:11.52793                                                     
[1]	validation-rmse:10.91908                                                     
[2]	validation-rmse:10.37635                                                     
[3]	validation-rmse:9.89974                                                      
[4]	validation-rmse:9.47337                                                      
[5]	validation-rmse:9.10617                                                      
[6]	validation-rmse:8.77470                                                      
[7]	validation-rmse:8.48903                                                      
[8]	validation-rmse:8.24144                                                      
[9]	validation-rmse:8.02162                                                      
[10]	validation-rmse:7.83214                                                     
[11]	validation-rmse:7.66750                                                     
[12]	validation-

KeyboardInterrupt: 

In [21]:
# Retrain using hard-coded best hyperparameters.
# NOTE(review): these values presumably come from an earlier search run tracked in
# MLflow (the search cell above was interrupted). Confirm the source run.
params = {
    "learning_rate": 0.20472169880371677,
    "max_depth": 17,
    "min_child_weight": 1.2402611720043835,
    # "reg:squarederror" replaces the deprecated "reg:linear" alias (same squared-error loss).
    "objective": "reg:squarederror",
    "reg_alpha": 0.28567896734700793,
    "reg_lambda": 0.004264404814393109,
    "seed": 42,
}

# Autologging creates an MLflow run for the xgb.train call below and records its
# hyperparameters, metrics and model artifact.
mlflow.xgboost.autolog()

booster = xgb.train(
    params=params,
    dtrain=train,
    num_boost_round=1000,
    evals=[(val, "validation")],
    early_stopping_rounds=50,
)

2023/12/20 00:51:53 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'dff0c7637c3a4b2a9f01e9186fd8b51f', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:10.64830
[1]	validation-rmse:9.50267


Parameters: { "max_Deoth" } are not used.



[2]	validation-rmse:8.68276
[3]	validation-rmse:8.10029
[4]	validation-rmse:7.69634
[5]	validation-rmse:7.41589
[6]	validation-rmse:7.22465
[7]	validation-rmse:7.08841
[8]	validation-rmse:6.99507
[9]	validation-rmse:6.92897
[10]	validation-rmse:6.88235
[11]	validation-rmse:6.84871
[12]	validation-rmse:6.82244
[13]	validation-rmse:6.80302
[14]	validation-rmse:6.78783
[15]	validation-rmse:6.77836
[16]	validation-rmse:6.77244
[17]	validation-rmse:6.76438
[18]	validation-rmse:6.76023
[19]	validation-rmse:6.75424
[20]	validation-rmse:6.74983
[21]	validation-rmse:6.74746
[22]	validation-rmse:6.74481
[23]	validation-rmse:6.74318
[24]	validation-rmse:6.73971
[25]	validation-rmse:6.73794
[26]	validation-rmse:6.73665
[27]	validation-rmse:6.73545
[28]	validation-rmse:6.73443
[29]	validation-rmse:6.73233
[30]	validation-rmse:6.73144
[31]	validation-rmse:6.72969
[32]	validation-rmse:6.72811
[33]	validation-rmse:6.72616
[34]	validation-rmse:6.72514
[35]	validation-rmse:6.72506
[36]	validation-rmse:6

