In [112]:
import os
import pickle
import uuid

import pandas as pd

# !pip install mlflow
import mlflow

from sklearn.feature_extraction import DictVectorizer
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline

In [113]:
year = 2021
month = 2
taxi_type = 'green_tripdata_'  # file-name prefix of the green-taxi trip files

input_file = f'data/{taxi_type}{year:04d}-{month:02d}.parquet'
output_file = f'output/{taxi_type}{year:04d}-{month:02d}.parquet'

# Both the tracking server and the model run can be overridden from the
# environment; the defaults are the values this notebook was run with.
MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI', 'http://127.0.0.1:5051')
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("green-taxi-duration")

RUN_ID = os.getenv('RUN_ID', '6e6e2893453049cf88231bc93b4e8e83')




In [114]:
def generate_uuids(n):
    """Return a list of ``n`` random (version 4) UUIDs rendered as strings."""
    return [str(uuid.uuid4()) for _ in range(n)]

def read_dataframe(filename: str):
    """Load a trip parquet file and prepare it for scoring.

    Computes the trip duration in minutes from the pickup/dropoff timestamps,
    keeps only trips lasting between 1 and 60 minutes (inclusive), casts the
    pickup/dropoff location IDs to strings so they can be concatenated into a
    categorical feature, and attaches a random UUID ``ride_id`` to each row.

    Args:
        filename: Path to a parquet file containing ``lpep_pickup_datetime``,
            ``lpep_dropoff_datetime``, ``PULocationID`` and ``DOLocationID``.

    Returns:
        The filtered DataFrame (original index preserved) with added
        ``duration`` (minutes) and ``ride_id`` columns.
    """
    df = pd.read_parquet(filename)

    trip_time = df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    df['duration'] = trip_time.dt.total_seconds() / 60
    # .copy(): the boolean-mask result is written to below; taking an explicit
    # copy avoids SettingWithCopyWarning and chained-assignment ambiguity.
    df = df[(df['duration'] >= 1) & (df['duration'] <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    df['ride_id'] = generate_uuids(len(df))
    return df


def prepare_dictionaries(df: pd.DataFrame):
    """Build the per-trip feature dicts fed to the model.

    Each record has a combined ``PU_DO`` key (``"<PULocationID>_<DOLocationID>"``)
    and the numeric ``trip_distance``. The location ID columns are expected to
    already be strings (``read_dataframe`` casts them).

    The input DataFrame is not modified.

    Args:
        df: Frame with ``PULocationID``, ``DOLocationID`` and ``trip_distance``.

    Returns:
        A list with one ``{'PU_DO': ..., 'trip_distance': ...}`` dict per row.
    """
    categorical = ['PU_DO']
    numerical = ['trip_distance']
    # assign() returns a new frame, so the caller's DataFrame gains no PU_DO column.
    features = df[numerical].assign(PU_DO=df['PULocationID'] + '_' + df['DOLocationID'])
    return features[categorical + numerical].to_dict(orient='records')

def load_model(run_id):
    """Load the MLflow pyfunc model logged as the ``model`` artifact of ``run_id``.

    The model is addressed with a ``runs:/<run_id>/model`` URI. Alternatively
    the artifact store can be addressed directly, e.g.
    ``s3://mlflow-models-alexey/1/<run_id>/artifacts/model``.
    """
    return mlflow.pyfunc.load_model(f'runs:/{run_id}/model')


In [115]:
def apply_model(input_file, run_id, output_file):
    """Score a file of trips with an MLflow model and write the predictions.

    Args:
        input_file: Parquet file of raw trips (loaded via ``read_dataframe``).
        run_id: MLflow run whose ``model`` artifact is used; this same value
            is recorded in the ``model_version`` column.
        output_file: Destination parquet path. Its parent directory is
            created if it does not exist.

    Returns:
        DataFrame with ride id, pickup/dropoff timestamps, location IDs,
        actual and predicted duration, their difference and model version.
    """
    df = read_dataframe(input_file)
    dicts = prepare_dictionaries(df)

    model = load_model(run_id)
    y_pred = model.predict(dicts)

    df_result = df[[
        'ride_id',
        'lpep_pickup_datetime',
        'lpep_dropoff_datetime',
        'PULocationID',
        'DOLocationID',
    ]].copy()
    df_result['actual_duration'] = df['duration']
    df_result['predicted_duration'] = y_pred
    df_result['diff'] = df_result['actual_duration'] - df_result['predicted_duration']
    # Record the run actually used for scoring, not the module-level default.
    df_result['model_version'] = run_id

    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df_result.to_parquet(output_file)

    return df_result

In [116]:
# Create the destination directory so the cell works on a fresh checkout.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

apply_model(input_file, RUN_ID, output_file)

Unnamed: 0,ride_id,lpep_pickup_datetime,lpep_dropoff_datetime,PULocationID,DOLocationID,actual_duration,predicted_duration,diff,model_version
0,0e5537a1-7b00-4f02-8b33-572a5d1d8050,2021-02-01 00:34:03,2021-02-01 00:51:58,130,205,17.916667,21.294545,-3.377879,6e6e2893453049cf88231bc93b4e8e83
1,3415f354-9644-4c02-9b07-d0258d5f4bb0,2021-02-01 00:04:00,2021-02-01 00:10:30,152,244,6.500000,14.785001,-8.285001,6e6e2893453049cf88231bc93b4e8e83
2,e738ffcd-2bb3-4de6-91b8-69f06b8136c8,2021-02-01 00:18:51,2021-02-01 00:34:06,152,48,15.250000,21.294545,-6.044545,6e6e2893453049cf88231bc93b4e8e83
3,5ca0c961-eb6d-47aa-901d-55528080c8a5,2021-02-01 00:53:27,2021-02-01 01:11:41,152,241,18.233333,21.294545,-3.061212,6e6e2893453049cf88231bc93b4e8e83
4,0ec554d5-d7fb-46fd-950c-3fbd150d30b2,2021-02-01 00:57:46,2021-02-01 01:06:44,75,42,8.966667,15.696657,-6.729991,6e6e2893453049cf88231bc93b4e8e83
...,...,...,...,...,...,...,...,...,...
64567,106c266e-347a-4b1d-a353-22daf2c67e5d,2021-02-28 22:19:00,2021-02-28 22:29:00,129,7,10.000000,21.069957,-11.069957,6e6e2893453049cf88231bc93b4e8e83
64568,b15af018-70f6-4aa2-a2f7-dc9fa33576ea,2021-02-28 23:18:00,2021-02-28 23:27:00,116,166,9.000000,15.696657,-6.696657,6e6e2893453049cf88231bc93b4e8e83
64569,4eeff5c7-c735-45ad-8b2d-fcf91fb782f4,2021-02-28 23:44:00,2021-02-28 23:58:00,74,151,14.000000,20.771487,-6.771487,6e6e2893453049cf88231bc93b4e8e83
64570,9965832b-f905-4ab5-8e88-577c0b611b65,2021-02-28 23:07:00,2021-02-28 23:14:00,42,42,7.000000,15.113777,-8.113777,6e6e2893453049cf88231bc93b4e8e83
