In [1]:
sc

In [80]:
trips = sqlContext.read.format("com.mongodb.spark.sql.DefaultSource").\
    option("uri","mongodb://127.0.0.1/bikeshare.trip").load()

trips.show(5)

+--------------------+-------+--------+---------------+--------------+-------------------+----+---------------+----------------+--------------------+-----------------+--------+
|                 _id|bike_id|duration|       end_date|end_station_id|   end_station_name|  id|     start_date|start_station_id|  start_station_name|subscription_type|zip_code|
+--------------------+-------+--------+---------------+--------------+-------------------+----+---------------+----------------+--------------------+-----------------+--------+
|[5a5ea615399107bd...|    661|      70|8/29/2013 14:43|            10| San Jose City Hall|4607|8/29/2013 14:42|              10|  San Jose City Hall|       Subscriber|   95138|
|[5a5ea615399107bd...|    319|      83|8/29/2013 12:04|            67|     Market at 10th|4299|8/29/2013 12:02|              66|South Van Ness at...|       Subscriber|   94103|
|[5a5ea615399107bd...|    527|     103|8/29/2013 18:56|            59|Golden Gate at Polk|4927|8/29/2013 18:54|    

In [81]:
trips = trips.drop('_id', 'bike_id', 'end_station_name', 'start_station_name', 'id', 'zip_code', 'end_date')
trips.show(5)

+--------+--------------+---------------+----------------+-----------------+
|duration|end_station_id|     start_date|start_station_id|subscription_type|
+--------+--------------+---------------+----------------+-----------------+
|      70|            10|8/29/2013 14:42|              10|       Subscriber|
|      83|            67|8/29/2013 12:02|              66|       Subscriber|
|     103|            59|8/29/2013 18:54|              59|       Subscriber|
|     109|             5|8/29/2013 13:25|               4|       Subscriber|
|     111|             8|8/29/2013 14:02|               8|       Subscriber|
+--------+--------------+---------------+----------------+-----------------+
only showing top 5 rows



In [82]:
trips.cache()
trips.count()

669959

In [83]:
from pyspark.sql.types import *
from pyspark.sql.functions import udf
from datetime import datetime

def get_day_of_week(ts):
    print(ts)
    date = datetime.strptime(ts, "%m/%d/%Y %H:%M")
    return date.isoweekday()

get_dow = udf(lambda x: get_day_of_week(x), IntegerType())

def get_hour(ts):
    print(ts)
    date = datetime.strptime(ts, "%m/%d/%Y %H:%M")
    return date.hour

get_hr = udf(lambda x: get_hour(x), IntegerType())


trips = trips.withColumn('day_of_week', get_dow(trips['start_date']))
trips = trips.withColumn('hour', get_hr(trips['start_date']))
trips = trips.drop('start_date')
trips.show(5)

+--------+--------------+----------------+-----------------+-----------+----+
|duration|end_station_id|start_station_id|subscription_type|day_of_week|hour|
+--------+--------------+----------------+-----------------+-----------+----+
|      70|            10|              10|       Subscriber|          4|  14|
|      83|            67|              66|       Subscriber|          4|  12|
|     103|            59|              59|       Subscriber|          4|  18|
|     109|             5|               4|       Subscriber|          4|  13|
|     111|             8|               8|       Subscriber|          4|  14|
+--------+--------------+----------------+-----------------+-----------+----+
only showing top 5 rows



In [84]:
trips.printSchema()

root
 |-- duration: integer (nullable = true)
 |-- end_station_id: integer (nullable = true)
 |-- start_station_id: integer (nullable = true)
 |-- subscription_type: string (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- hour: integer (nullable = true)



In [85]:
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator,ParamGridBuilder
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder

In [96]:
train, test = trips.randomSplit([0.9, 0.1])

In [97]:
train.count(), test.count()

(602773, 67186)

In [98]:
def implement_string_indexer(cols, train, test):
    for c in cols:
        si = StringIndexer(inputCol=c, outputCol=c+'_si')
        sm = si.fit(train)
        train = sm.transform(train).drop(c)
        train = train.withColumnRenamed(c + '_si', c)
        test = sm.transform(test).drop(c)
        test = test.withColumnRenamed(c + '_si', c)
        return (train,test)

cols = ['subscription_type']

train_si, test_si = implement_string_indexer(cols, train, test)

In [99]:
train_si.count(), test_si.count()

(602773, 67186)

In [88]:
train_si.printSchema()

root
 |-- duration: integer (nullable = true)
 |-- end_station_id: integer (nullable = true)
 |-- start_station_id: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- hour: integer (nullable = true)
 |-- subscription_type: double (nullable = true)



In [89]:
train_si.show(5)

+--------+--------------+----------------+-----------+----+-----------------+
|duration|end_station_id|start_station_id|day_of_week|hour|subscription_type|
+--------+--------------+----------------+-----------+----+-----------------+
|      60|            24|              24|          2|   8|              0.0|
|      60|            41|              41|          1|  14|              0.0|
|      60|            50|              50|          1|   9|              0.0|
|      60|            60|              60|          7|  21|              0.0|
|      60|            65|              65|          1|  18|              0.0|
+--------+--------------+----------------+-----------+----+-----------------+
only showing top 5 rows



In [90]:
def implement_one_hot_encoding(cols, train, test):
    for c in cols:
        ohe = OneHotEncoder(inputCol=c, outputCol=c+'_ohe')
        train = ohe.transform(train).drop(c)
        train = train.withColumnRenamed(c + '_ohe', c)
        test = ohe.transform(test).drop(c)
        test = test.withColumnRenamed(c + '_ohe', c)
        return (train,test)

train_ohe, test_ohe = implement_one_hot_encoding(cols, train_si, test_si)

train_ohe.show(5)

+--------+--------------+----------------+-----------+----+-----------------+
|duration|end_station_id|start_station_id|day_of_week|hour|subscription_type|
+--------+--------------+----------------+-----------+----+-----------------+
|      60|            24|              24|          2|   8|    (1,[0],[1.0])|
|      60|            41|              41|          1|  14|    (1,[0],[1.0])|
|      60|            50|              50|          1|   9|    (1,[0],[1.0])|
|      60|            60|              60|          7|  21|    (1,[0],[1.0])|
|      60|            65|              65|          1|  18|    (1,[0],[1.0])|
+--------+--------------+----------------+-----------+----+-----------------+
only showing top 5 rows



In [101]:
accu = sc.accumulator(0)
test_ohe.foreach(lambda x: accu.add(1))
accu

Accumulator<id=0, value=66674>

In [102]:
train_ohe.count(), test_ohe.count()

(603285, 66674)

In [103]:
train_ohe.printSchema()

root
 |-- duration: integer (nullable = true)
 |-- end_station_id: integer (nullable = true)
 |-- start_station_id: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- hour: integer (nullable = true)
 |-- subscription_type: vector (nullable = true)



In [104]:
input_cols = ['start_station_id', 'end_station_id', 'day_of_week', 'hour', 'subscription_type']

va = VectorAssembler(outputCol = 'features', inputCols = input_cols)
train_transformed = va.transform(train_ohe).select('features', 'duration').withColumnRenamed('duration', 'label')
test_transformed = va.transform(test_ohe).select('features', 'duration').withColumnRenamed('duration', 'label')

train_transformed.show(5)

+--------------------+-----+
|            features|label|
+--------------------+-----+
|[24.0,24.0,2.0,8....|   60|
|[41.0,41.0,1.0,14...|   60|
|[50.0,50.0,1.0,9....|   60|
|[60.0,60.0,7.0,21...|   60|
|[65.0,65.0,1.0,18...|   60|
+--------------------+-----+
only showing top 5 rows



In [105]:
train_transformed.cache()
test_transformed.cache()
train_transformed.count(), test_transformed.count()

(603285, 66674)