In [1]:
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.types import StructType,StructField, StringType, IntegerType, DoubleType
from pyspark.sql.functions import col,sum,avg,max
from pyspark.sql import Row
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.regression import LinearRegressionWithSGD
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.regression import GBTRegressor

import numpy as np
import datetime
import os
# Export the interpreter path under both PySpark interpreter variables so the
# same Python binary is named for each of them.
for _pyspark_var in ('PYSPARK_PYTHON', 'PYSPARK_DRIVER_PYTHON'):
    os.environ[_pyspark_var] = "/usr/local/bin/python3"

def process_record(record):
    """Turn one raw sales record into a Row with the date split into parts.

    Args:
        record: mapping-like row exposing 'date' (a "dd.mm.yyyy" string),
            'shop_id', 'date_block_num', 'item_id', 'item_price' and
            'item_cnt_day'.

    Returns:
        A Row whose fields are, in order: year, month, day, shop_id,
        date_block_num, item_id, item_price, item_cnt_day. The last five are
        copied from ``record`` unchanged.

    Raises:
        ValueError: if 'date' does not match the "%d.%m.%Y" format.
    """
    parsed = datetime.datetime.strptime(record['date'], "%d.%m.%Y")
    fields = {'year': parsed.year, 'month': parsed.month, 'day': parsed.day}
    # Pass-through columns, listed in the order they appear in the output Row.
    for key in ('shop_id', 'date_block_num', 'item_id', 'item_price', 'item_cnt_day'):
        fields[key] = record[key]
    return Row(**fields)

In [2]:
# Build (or reuse) a single local Spark session and take the context from it.
# Calling SparkContext(...) directly raises "Cannot run multiple SparkContexts
# at once" when this cell is executed a second time in the same kernel, whereas
# getOrCreate() hands back the session that is already running.
# Master 'local' and app name 'linear' are the values the explicit
# SparkContext('local', 'linear') call used; the separate appName previously
# given to the builder did not rename that already-created context.
spark = SparkSession.builder.master('local').appName('linear').getOrCreate()
sc = spark.sparkContext


In [3]:
# Explicit schema for sales_train.csv; every column is declared nullable.
structureSchema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in [
        ("date", StringType()),
        ("date_block_num", IntegerType()),
        ("shop_id", IntegerType()),
        ("item_id", IntegerType()),
        ("item_price", DoubleType()),
        ("item_cnt_day", DoubleType()),
    ]
])

# Read the CSV with that schema, treating the first line as a header row.
csv_reader = spark.read.format('csv').schema(structureSchema).option('header', 'true')
raw_sales = csv_reader.load('./sales_train.csv')

# process_record runs per row on the RDD and replaces the date string with
# separate year / month / day fields.
transformed = raw_sales.rdd.map(process_record)

# Back to a DataFrame, keeping only rows with a strictly positive daily count.
train_df = transformed.toDF().where(col('item_cnt_day') > 0)
train_df

DataFrame[year: bigint, month: bigint, day: bigint, shop_id: bigint, date_block_num: bigint, item_id: bigint, item_price: double, item_cnt_day: double]

In [4]:
# Collapse the per-row records into one row per (date_block_num, shop, item):
# mean price and total count within the group. `avg` and `sum` here are the
# pyspark.sql.functions imported at the top, not Python built-ins.
group_keys = ["date_block_num", "shop_id", "item_id"]
mean_price = avg("item_price").alias("avg_item_price")
total_count = sum("item_cnt_day").alias("item_cnt_month")

train_df = train_df.groupBy(*group_keys).agg(mean_price, total_count)
train_df.show(3)

+--------------+-------+-------+--------------+--------------+
|date_block_num|shop_id|item_id|avg_item_price|item_cnt_month|
+--------------+-------+-------+--------------+--------------+
|             0|     25|   3773|         299.0|           1.0|
|             0|     24|  17544|         999.0|           1.0|
|             0|     24|  10836|         149.0|           2.0|
+--------------+-------+-------+--------------+--------------+
only showing top 3 rows



In [5]:
# Pack the four numeric input columns into a single 'features' vector column,
# then keep only that vector and the label column.
feature_columns = ['date_block_num', 'shop_id', 'item_id', 'avg_item_price']
vectorAssembler = VectorAssembler(outputCol='features', inputCols=feature_columns)

assembled = vectorAssembler.transform(train_df)
train_examples = assembled.select('features', 'item_cnt_month')
train_examples.show(10)

+--------------------+--------------+
|            features|item_cnt_month|
+--------------------+--------------+
|[0.0,25.0,3773.0,...|           1.0|
|[0.0,24.0,17544.0...|           1.0|
|[0.0,24.0,10836.0...|           2.0|
|[0.0,25.0,14862.0...|           1.0|
|[0.0,25.0,17489.0...|           1.0|
|[0.0,25.0,16122.0...|           2.0|
|[0.0,25.0,15592.0...|           1.0|
|[0.0,25.0,8459.0,...|           1.0|
|[0.0,25.0,9768.0,...|           3.0|
|[0.0,19.0,21809.0...|           1.0|
+--------------------+--------------+
only showing top 10 rows



In [6]:
# 70/30 train/test split. Without a seed, randomSplit draws a different split
# on every execution, so the fitted model and its reported RMSE change from
# run to run. A fixed seed keeps the split repeatable for the same input data.
splits = train_examples.randomSplit([0.7, 0.3], seed=42)
train_df, test_df = splits

In [8]:
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.evaluation import RegressionEvaluator

# Fit a decision-tree regressor on the training split; the label is the
# monthly item count and the inputs are the assembled feature vectors.
dt = DecisionTreeRegressor(labelCol='item_cnt_month', featuresCol='features')
dt_model = dt.fit(train_df)

# Predict on the held-out split and score the predictions with RMSE.
dt_predictions = dt_model.transform(test_df)
dt_evaluator = RegressionEvaluator(
    metricName="rmse",
    predictionCol="prediction",
    labelCol="item_cnt_month",
)
rmse = dt_evaluator.evaluate(dt_predictions)

rmse_message = "Root Mean Squared Error (RMSE) on test data = %g"
print(rmse_message % rmse)

Root Mean Squared Error (RMSE) on test data = 6.6306
