In [1]:
import pandas as pd

In [26]:
# Display the pandas version this notebook was executed with (useful for reproducing results).
pd.__version__

'2.2.3'

In [27]:
!pip install pyarrow



In [28]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [29]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso,Ridge
from sklearn.metrics import root_mean_squared_error

In [30]:
import mlflow
# Tracking store (local SQLite file) and the experiment that runs in this notebook are logged to.
TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "nyc__taxi_experiment"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='/workspaces/MLOps-zoomcamp/02_experiment_tracking/mlruns/2', creation_time=1740689925026, experiment_id='2', last_update_time=1740689925026, lifecycle_stage='active', name='nyc__taxi_experiment', tags={}>

#### Green taxi trips - 2021

In [31]:
def read_dataframe(filename):
    """Load a green-taxi trip file and prepare it for modelling.

    The file is read as CSV or parquet depending on its extension. A
    ``duration`` column (trip length in minutes) is derived from the
    ``lpep_pickup_datetime`` / ``lpep_dropoff_datetime`` columns, only trips
    lasting between 1 and 60 minutes (inclusive) are kept, and the
    ``PULocationID`` / ``DOLocationID`` columns are cast to strings.

    Args:
        filename: Path or URL ending in ``.csv`` or ``.parquet``.

    Returns:
        pandas.DataFrame: an independent, filtered copy of the loaded data
        with the extra ``duration`` column.

    Raises:
        ValueError: if ``filename`` has neither a ``.csv`` nor a ``.parquet``
            extension.
    """
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)
        # CSV carries the timestamps as text, so parse them explicitly.
        df.lpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
        df.lpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `df` below.
        raise ValueError(
            f"Unsupported file type for {filename!r}: expected a .csv or .parquet file"
        )

    # Vectorised timedelta -> minutes conversion (no per-row Python lambda).
    trip_time = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df['duration'] = trip_time.dt.total_seconds() / 60

    # .copy() detaches the filtered rows from the original frame so that the
    # column assignments below (and in later cells) do not trigger
    # SettingWithCopyWarning.
    df = df[(df.duration >= 1) & (df.duration <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    return df

In [32]:
# January 2021 is used for training, February 2021 for validation.
df_train, df_val = (
    read_dataframe(url)
    for url in (
        'https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-01.parquet',
        'https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-02.parquet',
    )
)

In [33]:
# Row counts of the training and validation frames after filtering.
tuple(len(frame) for frame in (df_train, df_val))

(73908, 61921)

In [34]:
# Preview the first five prepared training rows.
df_train.head(n=5)

Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,...,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge,duration
0,2,2021-01-01 00:15:56,2021-01-01 00:19:52,N,1.0,43,151,1.0,1.01,5.5,...,0.5,0.0,0.0,,0.3,6.8,2.0,1.0,0.0,3.933333
1,2,2021-01-01 00:25:59,2021-01-01 00:34:44,N,1.0,166,239,1.0,2.53,10.0,...,0.5,2.81,0.0,,0.3,16.86,1.0,1.0,2.75,8.75
2,2,2021-01-01 00:45:57,2021-01-01 00:51:55,N,1.0,41,42,1.0,1.12,6.0,...,0.5,1.0,0.0,,0.3,8.3,1.0,1.0,0.0,5.966667
3,2,2020-12-31 23:57:51,2021-01-01 00:04:56,N,1.0,168,75,1.0,1.99,8.0,...,0.5,0.0,0.0,,0.3,9.3,2.0,1.0,0.0,7.083333
7,2,2021-01-01 00:26:31,2021-01-01 00:28:50,N,1.0,75,75,6.0,0.45,3.5,...,0.5,0.96,0.0,,0.3,5.76,1.0,1.0,0.0,2.316667


In [35]:
# Combine pickup and dropoff zone into a single "<pickup>_<dropoff>" feature on both frames.
for frame in (df_train, df_val):
    frame['PU_DO'] = frame['PULocationID'] + '_' + frame['DOLocationID']

In [36]:
# Model inputs: the combined pickup/dropoff feature plus the trip distance.
categorical = ['PU_DO']
numerical = ['trip_distance']
feature_columns = categorical + numerical

train_dicts = df_train[feature_columns].to_dict(orient='records')
val_dicts = df_val[feature_columns].to_dict(orient='records')

# Fit the vectorizer on the training records only, then reuse it for validation.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [37]:
# Regression target: trip duration in minutes, as plain arrays.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

### Linear Regression

In [38]:
# Baseline: plain linear regression fitted on the training matrix.
lr = LinearRegression().fit(X_train, y_train)

# Validation RMSE; kept as the last expression so the notebook displays it.
y_pred = lr.predict(X_val)
root_mean_squared_error(y_true=y_val, y_pred=y_pred)

7.758715209663881

In [39]:
# Persist the fitted vectorizer together with the linear model.
# The output directory is created first so the cell also works on a fresh
# checkout where `models/` does not exist yet.
lin_reg_path = Path('models/lin_reg.bin')
lin_reg_path.parent.mkdir(parents=True, exist_ok=True)
with open(lin_reg_path, 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

### Lasso Regression

In [40]:
# Train a Lasso model inside a tracked MLflow run: the run records the
# developer tag, the alpha parameter, the validation RMSE and the pickled
# (vectorizer, model) pair.
with mlflow.start_run():
    mlflow.set_tag("developer", "slv")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)

    lasso_reg = Lasso(alpha=alpha)
    lasso_reg.fit(X_train, y_train)

    y_pred = lasso_reg.predict(X_val)
    rmse = root_mean_squared_error(y_val, y_pred)
    mlflow.log_metric("rmse", rmse)

    # Log the model trained in this run. Attaching the earlier
    # LinearRegression pickle (models/lin_reg.bin) would store an artifact
    # that does not match the logged alpha and RMSE.
    lasso_path = Path("models/lasso_reg.bin")
    lasso_path.parent.mkdir(parents=True, exist_ok=True)
    with open(lasso_path, "wb") as f_out:
        pickle.dump((dv, lasso_reg), f_out)
    mlflow.log_artifact(local_path=str(lasso_path), artifact_path="models_pickle")

### XGBoost - hyperparameter optimization

In [41]:
import xgboost as xgb
from hyperopt import fmin,tpe,hp,STATUS_OK,Trials
from hyperopt.pyll import scope


In [42]:
# Wrap the feature matrices and targets in XGBoost's DMatrix containers.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [49]:
def objective(params):
    """Train one XGBoost candidate and report its validation RMSE to hyperopt.

    Every call opens its own MLflow run in which the ``params`` dict, the
    ``rmse`` metric and the pickled ``dv`` preprocessor are logged.

    Args:
        params: Parameter dict handed to both ``mlflow.log_params`` and
            ``xgb.train``.

    Returns:
        dict: ``{'loss': <validation rmse>, 'status': STATUS_OK}``.
    """
    preprocessor_path = "models/preprocessor.b"

    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)

        # Up to 100 boosting rounds, stopping early after 50 rounds without
        # improvement on the validation set.
        model = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=100,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50,
        )

        validation_rmse = root_mean_squared_error(y_val, model.predict(valid))
        mlflow.log_metric("rmse", validation_rmse)

        # Store the fitted vectorizer next to the run so it can be reloaded with it.
        with open(preprocessor_path, "wb") as handle:
            pickle.dump(dv, handle)
        mlflow.log_artifact(preprocessor_path, artifact_path="models.xgboost")

    return {'loss': validation_rmse, 'status': STATUS_OK}
            


In [50]:
# Hyperopt search space for the XGBoost parameters.
# Note: hp.loguniform(label, low, high) samples exp(uniform(low, high)), so the
# bounds below are exponents, not the parameter values themselves.
search_space = {
    # Integer tree depth sampled in steps of 1.
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    # 'reg:linear' is the deprecated name of this objective; use the current one.
    'objective': 'reg:squarederror',
    'seed': 42
}

In [51]:
# Run 10 TPE-guided evaluations of `objective` over `search_space`.
trials = Trials()
result = fmin(
    algo=tpe.suggest,
    fn=objective,
    space=search_space,
    trials=trials,
    max_evals=10,
)

  0%|                                                | 0/10 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:10.79832                                                                
[1]	validation-rmse:9.70998                                                                 
[2]	validation-rmse:8.88571                                                                 
[3]	validation-rmse:8.26666                                                                 
[4]	validation-rmse:7.80737                                                                 
[5]	validation-rmse:7.46900                                                                 
[6]	validation-rmse:7.22087                                                                 
[7]	validation-rmse:7.04119                                                                 
[8]	validation-rmse:6.90894                                                                 
[9]	validation-rmse:6.81058                                                                 
[10]	validation-rmse:6.73715                                          

[88]	validation-rmse:6.41450                                                                
[89]	validation-rmse:6.41403                                                                
[90]	validation-rmse:6.41403                                                                
[91]	validation-rmse:6.41375                                                                
[92]	validation-rmse:6.41330                                                                
[93]	validation-rmse:6.41270                                                                
[94]	validation-rmse:6.41218                                                                
[95]	validation-rmse:6.41160                                                                
[96]	validation-rmse:6.41099                                                                
[97]	validation-rmse:6.41088                                                                
[98]	validation-rmse:6.41003                                          




[0]	validation-rmse:9.89499                                                                 
[1]	validation-rmse:8.47355                                                                 
[2]	validation-rmse:7.63559                                                                 
[3]	validation-rmse:7.15393                                                                 
[4]	validation-rmse:6.87746                                                                 
[5]	validation-rmse:6.71866                                                                 
[6]	validation-rmse:6.62308                                                                 
[7]	validation-rmse:6.56246                                                                 
[8]	validation-rmse:6.52356                                                                 
[9]	validation-rmse:6.49669                                                                 
[10]	validation-rmse:6.47814                                          

[88]	validation-rmse:6.34257                                                                
[89]	validation-rmse:6.34249                                                                
[90]	validation-rmse:6.34208                                                                
[91]	validation-rmse:6.34213                                                                
[92]	validation-rmse:6.34001                                                                
[93]	validation-rmse:6.33901                                                                
[94]	validation-rmse:6.33922                                                                
[95]	validation-rmse:6.33850                                                                
[96]	validation-rmse:6.33820                                                                
[97]	validation-rmse:6.33821                                                                
[98]	validation-rmse:6.33829                                          




[0]	validation-rmse:11.56992                                                                
[1]	validation-rmse:10.99344                                                                
[2]	validation-rmse:10.47855                                                                
[3]	validation-rmse:10.01950                                                                
[4]	validation-rmse:9.61106                                                                 
[5]	validation-rmse:9.24864                                                                 
[6]	validation-rmse:8.92872                                                                 
[7]	validation-rmse:8.64608                                                                 
[8]	validation-rmse:8.39700                                                                 
[9]	validation-rmse:8.17879                                                                 
[10]	validation-rmse:7.98642                                          

[88]	validation-rmse:6.53980                                                                
[89]	validation-rmse:6.53931                                                                
[90]	validation-rmse:6.53894                                                                
[91]	validation-rmse:6.53857                                                                
[92]	validation-rmse:6.53815                                                                
[93]	validation-rmse:6.53773                                                                
[94]	validation-rmse:6.53739                                                                
[95]	validation-rmse:6.53712                                                                
[96]	validation-rmse:6.53695                                                                
[97]	validation-rmse:6.53672                                                                
[98]	validation-rmse:6.53633                                          




[0]	validation-rmse:11.39570                                                                
[1]	validation-rmse:10.68910                                                                
[2]	validation-rmse:10.08060                                                                
[3]	validation-rmse:9.55919                                                                 
[4]	validation-rmse:9.11409                                                                 
[5]	validation-rmse:8.73612                                                                 
[6]	validation-rmse:8.41632                                                                 
[7]	validation-rmse:8.14613                                                                 
[8]	validation-rmse:7.91910                                                                 
[9]	validation-rmse:7.72818                                                                 
[10]	validation-rmse:7.56706                                          

[88]	validation-rmse:6.65674                                                                
[89]	validation-rmse:6.65644                                                                
[90]	validation-rmse:6.65601                                                                
[91]	validation-rmse:6.65520                                                                
[92]	validation-rmse:6.65470                                                                
[93]	validation-rmse:6.65454                                                                
[94]	validation-rmse:6.65429                                                                
[95]	validation-rmse:6.65384                                                                
[96]	validation-rmse:6.65330                                                                
[97]	validation-rmse:6.65290                                                                
[98]	validation-rmse:6.65251                                          




[0]	validation-rmse:8.22707                                                                 
[1]	validation-rmse:7.03387                                                                 
[2]	validation-rmse:6.69764                                                                 
[3]	validation-rmse:6.58407                                                                 
[4]	validation-rmse:6.54119                                                                 
[5]	validation-rmse:6.51993                                                                 
[6]	validation-rmse:6.50395                                                                 
[7]	validation-rmse:6.49571                                                                 
[8]	validation-rmse:6.49244                                                                 
[9]	validation-rmse:6.48705                                                                 
[10]	validation-rmse:6.47944                                          

[88]	validation-rmse:6.38196                                                                
[89]	validation-rmse:6.38136                                                                
[90]	validation-rmse:6.38162                                                                
[91]	validation-rmse:6.38238                                                                
[92]	validation-rmse:6.38227                                                                
[93]	validation-rmse:6.38245                                                                
[94]	validation-rmse:6.38129                                                                
[95]	validation-rmse:6.38178                                                                
[96]	validation-rmse:6.38131                                                                
[97]	validation-rmse:6.38198                                                                
[98]	validation-rmse:6.38118                                          




[0]	validation-rmse:11.64373                                                                
[1]	validation-rmse:11.12717                                                                
[2]	validation-rmse:10.65958                                                                
[3]	validation-rmse:10.23713                                                                
[4]	validation-rmse:9.85621                                                                 
[5]	validation-rmse:9.51301                                                                 
[6]	validation-rmse:9.20488                                                                 
[7]	validation-rmse:8.92878                                                                 
[8]	validation-rmse:8.68143                                                                 
[9]	validation-rmse:8.46071                                                                 
[10]	validation-rmse:8.26378                                          

[88]	validation-rmse:6.62231                                                                
[89]	validation-rmse:6.62208                                                                
[90]	validation-rmse:6.62174                                                                
[91]	validation-rmse:6.62139                                                                
[92]	validation-rmse:6.62083                                                                
[93]	validation-rmse:6.62055                                                                
[94]	validation-rmse:6.62028                                                                
[95]	validation-rmse:6.61997                                                                
[96]	validation-rmse:6.61961                                                                
[97]	validation-rmse:6.61932                                                                
[98]	validation-rmse:6.61908                                          




[0]	validation-rmse:8.33393                                                                 
[1]	validation-rmse:7.12772                                                                 
[2]	validation-rmse:6.78310                                                                 
[3]	validation-rmse:6.66947                                                                 
[4]	validation-rmse:6.62414                                                                 
[5]	validation-rmse:6.59980                                                                 
[6]	validation-rmse:6.58670                                                                 
[7]	validation-rmse:6.58027                                                                 
[8]	validation-rmse:6.57409                                                                 
[9]	validation-rmse:6.56596                                                                 
[10]	validation-rmse:6.56016                                          

[88]	validation-rmse:6.46783                                                                
[89]	validation-rmse:6.46759                                                                
[90]	validation-rmse:6.46729                                                                
[91]	validation-rmse:6.46708                                                                
[92]	validation-rmse:6.46669                                                                
[93]	validation-rmse:6.46672                                                                
[94]	validation-rmse:6.46621                                                                
[95]	validation-rmse:6.46601                                                                
[96]	validation-rmse:6.46525                                                                
[97]	validation-rmse:6.46462                                                                
[98]	validation-rmse:6.46453                                          




[0]	validation-rmse:8.13921                                                                 
[1]	validation-rmse:7.07602                                                                 
[2]	validation-rmse:6.80965                                                                 
[3]	validation-rmse:6.74021                                                                 
[4]	validation-rmse:6.71560                                                                 
[5]	validation-rmse:6.70034                                                                 
[6]	validation-rmse:6.69258                                                                 
[7]	validation-rmse:6.68773                                                                 
[8]	validation-rmse:6.68152                                                                 
[9]	validation-rmse:6.67612                                                                 
[10]	validation-rmse:6.67012                                          

[88]	validation-rmse:6.54321                                                                
[89]	validation-rmse:6.54208                                                                
[90]	validation-rmse:6.54170                                                                
[91]	validation-rmse:6.53887                                                                
[92]	validation-rmse:6.53755                                                                
[93]	validation-rmse:6.53768                                                                
[94]	validation-rmse:6.53686                                                                
[95]	validation-rmse:6.53714                                                                
[96]	validation-rmse:6.53527                                                                
[97]	validation-rmse:6.53448                                                                
[98]	validation-rmse:6.53268                                          




[3]	validation-rmse:6.80298                                                                 
[4]	validation-rmse:6.79236                                                                 
[5]	validation-rmse:6.76935                                                                 
[6]	validation-rmse:6.76361                                                                 
[7]	validation-rmse:6.75850                                                                 
[8]	validation-rmse:6.75076                                                                 
[9]	validation-rmse:6.74673                                                                 
[10]	validation-rmse:6.74410                                                                
[11]	validation-rmse:6.73656                                                                
[12]	validation-rmse:6.73309                                                                
[13]	validation-rmse:6.72919                                          

[91]	validation-rmse:6.61585                                                                
[92]	validation-rmse:6.61567                                                                
[93]	validation-rmse:6.61508                                                                
[94]	validation-rmse:6.61449                                                                
[95]	validation-rmse:6.61426                                                                
[96]	validation-rmse:6.61440                                                                
[97]	validation-rmse:6.61399                                                                
[98]	validation-rmse:6.61311                                                                
[99]	validation-rmse:6.61312                                                                
 90%|████████████████████▋  | 9/10 [04:01<00:15, 15.54s/trial, best loss: 6.337907152115211]




[0]	validation-rmse:7.75121                                                                 
[1]	validation-rmse:6.85662                                                                 
[2]	validation-rmse:6.66708                                                                 
[3]	validation-rmse:6.61213                                                                 
[4]	validation-rmse:6.59282                                                                 
[5]	validation-rmse:6.57541                                                                 
[6]	validation-rmse:6.56804                                                                 
[7]	validation-rmse:6.56145                                                                 
[8]	validation-rmse:6.55657                                                                 
[9]	validation-rmse:6.55288                                                                 
[10]	validation-rmse:6.54839                                          

[88]	validation-rmse:6.46464                                                                
[89]	validation-rmse:6.46371                                                                
[90]	validation-rmse:6.46321                                                                
[91]	validation-rmse:6.46334                                                                
[92]	validation-rmse:6.46378                                                                
[93]	validation-rmse:6.46383                                                                
[94]	validation-rmse:6.46393                                                                
[95]	validation-rmse:6.46399                                                                
[96]	validation-rmse:6.46317                                                                
[97]	validation-rmse:6.46289                                                                
[98]	validation-rmse:6.46392                                          

### MLflow autologging

In [None]:
from sklearn.ensemble import RandomForestRegressor,GradientBoostingRegressor,ExtraTreesRegressor
from sklearn.svm import LinearSVR

# Let MLflow capture parameters and models of scikit-learn estimators automatically.
mlflow.sklearn.autolog()

# Train each regressor with default settings in its own run and log its validation RMSE.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):
    with mlflow.start_run():
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        # scikit-learn estimators expose predict(); there is no pred() method.
        y_pred = mlmodel.predict(X_val)
        rmse = root_mean_squared_error(y_val, y_pred)
        mlflow.log_metric("rmse", rmse)

