# `vaex` @ PyData Budapest 2020

## Machine Learning Example - "Deployment"

For more details, check out
[ML impossible: Train a 1 billion sample model in 20 minutes with Vaex and Scikit-Learn on your Laptop](https://towardsdatascience.com/ml-impossible-train-a-1-billion-sample-model-in-20-minutes-with-vaex-and-scikit-learn-on-your-9e2968e6f385).

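Running this notebook requires `vaex==3.0.0`. It also needs the taxi data file opened below and the state file `taxi_ml_state.json` saved by the training step.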

In [1]:
import warnings

import vaex

# Silence every warning so the demo output stays clean (presumably aimed at
# reminders such as the "shuffle your data" hint from train_test_split).
# Remove this line when debugging, as it also hides genuinely useful warnings.
warnings.simplefilter('ignore')

### Load the data and split off the test set

In [2]:
# Location of the 2012 yellow-taxi dataset; adjust this to your environment.
DATA_PATH = '/data/taxi/yellow_taxi_2012.hdf5'
# Fraction of the rows held out as the test set.
TEST_SIZE = 0.15

df = vaex.open(DATA_PATH)

# Train / test split by *row order*, not by date: the test rows displayed below
# are a contiguous block of Jan-Feb 2012 trips rather than a random sample of
# the year. The split is therefore only "by date" to the extent that the file
# is stored roughly in pickup-time order.
df_train, df_test = df.ml.train_test_split(test_size=TEST_SIZE)

print(f'Number of samples in the training set: {len(df_train):,}')
print(f'Number of samples in the test set:       {len(df_test):,}')

# The two parts together must account for every row of the original DataFrame
assert len(df) == len(df_test) + len(df_train), 'train/test split lost or duplicated rows'

Number of samples in the training set: 151,762,675
Number of samples in the test set:       26,781,649


### Inspect the test set

In [3]:
# Preview the raw test set before any transformation is applied
# (vaex renders only the first and last few rows of the DataFrame).
df_test

#,vendor_id,pickup_datetime,dropoff_datetime,passenger_count,payment_type,trip_distance,pickup_longitude,pickup_latitude,rate_code,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,fare_amount,surcharge,mta_tax,tip_amount,tolls_amount,total_amount
0,CMT,2012-01-10 23:55:50.000000000,2012-01-11 00:03:39.000000000,1,CRD,1.7000000476837158,-73.99468994140625,40.725032806396484,1.0,0.0,-73.9759521484375,40.73078155517578,6.900000095367432,0.5,0.5,1.0,0.0,8.899999618530273
1,CMT,2012-01-11 19:18:25.000000000,2012-01-11 19:26:10.000000000,1,CSH,1.100000023841858,-73.98795318603516,40.75294876098633,1.0,0.0,-73.9945297241211,40.76103973388672,6.099999904632568,1.0,0.5,0.0,0.0,7.599999904632568
2,CMT,2012-01-11 19:19:19.000000000,2012-01-11 19:48:15.000000000,2,CRD,18.0,-73.78309631347656,40.6485481262207,2.0,0.0,-73.99613189697266,40.747623443603516,45.0,0.0,0.5,10.0600004196167,4.800000190734863,60.36000061035156
3,CMT,2012-01-11 19:19:21.000000000,2012-01-11 19:27:00.000000000,1,CRD,1.7000000476837158,-73.96751403808594,40.758453369140625,1.0,0.0,-73.95658111572266,40.779903411865234,6.900000095367432,1.0,0.5,1.0,0.0,9.399999618530273
4,CMT,2012-01-11 14:38:15.000000000,2012-01-11 14:43:51.000000000,1,CSH,1.2000000476837158,-74.01131439208984,40.711448669433594,1.0,0.0,-74.00286865234375,40.72813034057617,5.699999809265137,0.0,0.5,0.0,0.0,6.199999809265137
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
26781644,VTS,2012-02-11 23:28:00.000000000,2012-02-11 23:45:00.000000000,3,CSH,6.119999885559082,-73.98193359375,40.74324035644531,1.0,,-74.01116943359375,40.71636962890625,16.899999618530273,0.5,0.5,0.0,0.0,17.899999618530273
26781645,VTS,2012-02-11 22:46:00.000000000,2012-02-11 22:55:00.000000000,1,CSH,1.5499999523162842,-73.9814224243164,40.67967987060547,1.0,,-73.96326446533203,40.688507080078125,6.900000095367432,0.5,0.5,0.0,0.0,7.900000095367432
26781646,VTS,2012-02-11 23:22:00.000000000,2012-02-11 23:37:00.000000000,6,CSH,2.7899999618530273,-73.9787826538086,40.77758026123047,1.0,,-74.00340270996094,40.74978256225586,10.5,0.5,0.5,0.0,0.0,11.5
26781647,VTS,2012-02-11 23:26:00.000000000,2012-02-11 23:38:00.000000000,1,CRD,3.009999990463257,-74.00403594970703,40.73289108276367,1.0,,-74.00830078125,40.71181106567383,9.699999809265137,0.5,0.5,2.549999952316284,0.0,13.25


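### Apply the state to the test DataFrame

Loading the saved state re-applies, in place, the transformations it records: the engineered features (e.g. `trip_duration_min`, `PCA_0`…`PCA_3`, `standard_scaled_arc_distance`) and the model outputs `predicted_duration_min` and `pred_final`. Note that the row count drops from 26,781,649 to 21,229,617, so the state presumably also carries the filters used during training. The min/max checks below show that `pred_final` stays within [3, 25], while the raw `predicted_duration_min` ranges from about −1.3 to 62.2. `pred_final` therefore appears to be a clipped version of the raw prediction.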

In [4]:
# State file presumably saved at the end of the training step. Loading it
# modifies df_test in place (the call produces no output): the recorded
# feature transformations and model predictions become columns of df_test.
state_path = './taxi_ml_state.json'
df_test.state_load(state_path)

In [5]:
# Inspect the test set after the state was applied: engineered features
# (e.g. PCA_*, standard_scaled_arc_distance) and the prediction columns
# predicted_duration_min / pred_final now appear after the original columns.
df_test

#,vendor_id,pickup_datetime,dropoff_datetime,passenger_count,payment_type,trip_distance,pickup_longitude,pickup_latitude,rate_code,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,fare_amount,surcharge,mta_tax,tip_amount,tolls_amount,total_amount,trip_duration_min,trip_speed_mph,pickup_time,pickup_day,pickup_is_weekend,arc_distance,direction_angle,PCA_0,PCA_1,PCA_2,PCA_3,pickup_time_x,pickup_time_y,pickup_day_x,pickup_day_y,direction_angle_x,direction_angle_y,standard_scaled_arc_distance,predicted_duration_min,pred_final
0,CMT,2012-01-10 23:55:50.000000000,2012-01-11 00:03:39.000000000,1,CRD,1.7000000476837158,-73.99468994140625,40.725032806396484,1.0,0.0,-73.9759521484375,40.73078155517578,6.900000095367432,0.5,0.5,1.0,0.0,8.899999618530273,7.816666666666666,13.049040877742808,23.916666666666668,1,0,1.299300193786621,72.94400024414062,-0.029757998883724213,-0.004688636399805546,-0.01604314148426056,-0.014438532292842865,0.9997620270799091,-0.0218148850345609,0.6234898018587336,0.7818314824680298,0.2933062016963959,0.9560185670852661,0.14673005044460297,10.17210051291147,10.17210051291147
1,CMT,2012-01-11 19:18:25.000000000,2012-01-11 19:26:10.000000000,1,CSH,1.100000023841858,-73.98795318603516,40.75294876098633,1.0,0.0,-73.9945297241211,40.76103973388672,6.099999904632568,1.0,0.5,0.0,0.0,7.599999904632568,7.75,8.51612921684019,19.3,2,0,0.4798426032066345,-39.10504150390625,-0.003371396567672491,0.00664413021877408,-0.001381831243634224,0.01789921149611473,0.3338068592337709,-0.9426414910921784,-0.22252093395631434,0.9749279121818236,0.775990903377533,-0.6307440996170044,-0.7633066773414612,9.691096448206059,9.691096448206059
2,CMT,2012-01-11 19:19:21.000000000,2012-01-11 19:27:00.000000000,1,CRD,1.7000000476837158,-73.96751403808594,40.758453369140625,1.0,0.0,-73.95658111572266,40.779903411865234,6.900000095367432,1.0,0.5,1.0,0.0,9.399999618530273,7.65,13.333333707323261,19.316666666666666,2,0,0.8592353463172913,27.00758934020996,0.013282216154038906,-0.0064217280596494675,0.03550034016370773,-0.002972794696688652,0.3379167180033267,-0.9411760152563707,-0.22252093395631434,0.9749279121818236,0.8909463882446289,0.4541085362434387,-0.3419775664806366,9.76290952872819,9.76290952872819
3,CMT,2012-01-11 14:38:15.000000000,2012-01-11 14:43:51.000000000,1,CSH,1.2000000476837158,-74.01131439208984,40.711448669433594,1.0,0.0,-74.00286865234375,40.72813034057617,5.699999809265137,0.0,0.5,0.0,0.0,6.199999809265137,5.6,12.857143368039813,14.633333333333333,2,0,0.6643630266189575,26.85257339477539,-0.050594620406627655,0.00048208795487880707,-0.033315420150756836,0.006374814547598362,-0.7716245833877202,-0.6360782202777636,-0.22252093395631434,0.9749279121818236,0.8921717405319214,0.45169636607170105,-0.558390200138092,11.53348426067309,11.53348426067309
4,VTS,2012-01-09 19:14:00.000000000,2012-01-09 19:20:00.000000000,1,CSH,1.25,-73.99333190917969,40.727718353271484,1.0,,-73.9815673828125,40.7392463684082,6.099999904632568,1.0,0.5,0.0,0.0,7.599999904632568,6.0,12.5,19.233333333333334,0,0,0.842030942440033,45.581756591796875,-0.026794197037816048,-0.004166812635958195,-0.01217577699571848,-0.005045588128268719,0.31730465640509226,-0.9483236552061993,1.0,0.0,0.6998907923698425,0.714249849319458,-0.3610836863517761,9.972671716598873,9.972671716598873
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21229613,VTS,2012-02-11 23:28:00.000000000,2012-02-11 23:45:00.000000000,3,CSH,6.119999885559082,-73.98193359375,40.74324035644531,1.0,,-74.01116943359375,40.71636962890625,16.899999618530273,0.5,0.5,0.0,0.0,17.899999618530273,17.0,21.59999959609088,23.466666666666665,5,1,2.0838654041290283,-132.586181640625,-0.007537414785474539,-0.003992264624685049,-0.047707557678222656,0.0066642072051763535,0.9902680687415701,-0.13917310096006674,-0.2225209339563146,-0.9749279121818236,-0.6766984462738037,-0.7362602949142456,1.0180175304412842,13.41835160659671,13.41835160659671
21229614,VTS,2012-02-11 22:46:00.000000000,2012-02-11 22:55:00.000000000,1,CSH,1.5499999523162842,-73.9814224243164,40.67967987060547,1.0,,-73.96326446533203,40.688507080078125,6.900000095367432,0.5,0.5,0.0,0.0,7.900000095367432,9.0,10.333333015441895,22.766666666666666,5,1,1.26585853099823,64.07402038574219,-0.05811910703778267,-0.04248497262597084,-0.04395797848701477,-0.048627279698848724,0.9483236552061991,-0.3173046564050927,-0.2225209339563146,-0.9749279121818236,0.4372095465660095,0.8993596434593201,0.10959190130233765,10.303818906286105,10.303818906286105
21229615,VTS,2012-02-11 23:22:00.000000000,2012-02-11 23:37:00.000000000,6,CSH,2.7899999618530273,-73.9787826538086,40.77758026123047,1.0,,-74.00340270996094,40.74978256225586,10.5,0.5,0.5,0.0,0.0,11.5,15.0,11.15999984741211,23.366666666666667,5,1,1.781661033630371,-138.4691162109375,0.021843839436769485,0.014060418121516705,-0.0156773142516613,0.01894466206431389,0.9862856015372314,-0.16504760586067735,-0.2225209339563146,-0.9749279121818236,-0.7485985159873962,-0.6630235910415649,0.68240886926651,12.045755182438363,12.045755182438363
21229616,VTS,2012-02-11 23:26:00.000000000,2012-02-11 23:38:00.000000000,1,CRD,3.009999990463257,-74.00403594970703,40.73289108276367,1.0,,-74.00830078125,40.71181106567383,9.699999809265137,0.5,0.5,2.549999952316284,0.0,13.25,12.0,15.049999952316284,23.433333333333334,5,1,0.49788349866867065,-168.56251525878906,-0.02906632050871849,0.007502423599362373,-0.04987724870443344,0.0017344895750284195,0.9890158633619168,-0.14780941112961052,-0.2225209339563146,-0.9749279121818236,-0.9801416397094727,-0.19829867780208588,-0.7432716488838196,8.72141159892713,8.72141159892713


In [6]:
# Range of the raw model output. In the saved run the minimum is negative,
# which is not a physically meaningful trip duration.
df_test['predicted_duration_min'].minmax(progress='widget')

HBox(children=(FloatProgress(value=0.0, max=1.0), Label(value='In progress...')))

array([-1.34434498, 62.22383593])

In [7]:
# Range of the post-processed prediction. In the saved run it stays within
# [3, 25], unlike the raw predicted_duration_min above.
df_test['pred_final'].minmax(progress='widget')

HBox(children=(FloatProgress(value=0.0, max=1.0), Label(value='In progress...')))

array([ 3., 25.])

# Thank you!