# Start Spark session

In [1]:
from pyspark.sql import SparkSession
# Build the session for this notebook under the application name
# 'pyspark-test'; getOrCreate() hands back an already-running session
# when there is one instead of starting a second.
spark = (
    SparkSession.builder
    .appName('pyspark-test')
    .getOrCreate()
)
spark

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/05/17 17:15:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/05/17 17:15:11 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/05/17 17:15:11 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


# Get data

In [2]:
# Load the CSV, taking column names from its first line (header=True) and
# letting Spark infer each column's type (inferSchema=True).
csv_path = '../data/test1.csv'
data = spark.read.csv(csv_path, header=True, inferSchema=True)
data.show()

+------+---+----------+------+
|  name|age|experience|salary|
+------+---+----------+------+
| krish| 31|        10| 30000|
|   sud| 30|         8| 25000|
| sunny| 29|         4| 20000|
|  paul| 24|         3| 20000|
|harsha| 21|         1| 15000|
|  shub| 23|         2| 18000|
+------+---+----------+------+



In [3]:
# Print each column's name, type and nullability as inferred while reading the CSV.
data.printSchema()

root
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- experience: integer (nullable = true)
 |-- salary: integer (nullable = true)



In [4]:
# define features
from pyspark.ml.feature import VectorAssembler

# Pack the numeric predictors into one vector column ("X"), in the order
# given by inputCols, so it can be passed as the model's features column.
feature_assembler = VectorAssembler(
    inputCols=["age", "experience"],
    outputCol="X"
)

# `data` is overwritten with its own transformed version, so running this
# cell a second time used to fail: the output column already existed.
# Dropping it first keeps the cell re-runnable; DataFrame.drop is a no-op
# when the column is absent, so the first run behaves exactly as before.
data = feature_assembler.transform(
    data.drop(feature_assembler.getOutputCol())
)
data.show()

+------+---+----------+------+-----------+
|  name|age|experience|salary|          X|
+------+---+----------+------+-----------+
| krish| 31|        10| 30000|[31.0,10.0]|
|   sud| 30|         8| 25000| [30.0,8.0]|
| sunny| 29|         4| 20000| [29.0,4.0]|
|  paul| 24|         3| 20000| [24.0,3.0]|
|harsha| 21|         1| 15000| [21.0,1.0]|
|  shub| 23|         2| 18000| [23.0,2.0]|
+------+---+----------+------+-----------+



In [5]:
# Keep only what the regression needs: the feature vector (X) and the
# target column (salary).
model_columns = ['X', 'salary']
data_final = data.select(*model_columns)
data_final.show()

+-----------+------+
|          X|salary|
+-----------+------+
|[31.0,10.0]| 30000|
| [30.0,8.0]| 25000|
| [29.0,4.0]| 20000|
| [24.0,3.0]| 20000|
| [21.0,1.0]| 15000|
| [23.0,2.0]| 18000|
+-----------+------+



In [8]:
# train test split
# A fixed seed makes the split repeatable, so the fitted coefficients and
# the error metrics computed further down do not change from run to run.
# NOTE: rows are assigned to a side at random using the given weights, so on
# a data set this small the split sizes are only roughly 75% / 25%.
data_train, data_test = data_final.randomSplit([0.75, 0.25], seed=42)

# show() prints the table itself and returns None, so it is called directly;
# wrapping it in print() only added a stray "None" line after each table.
data_train.show()
data_test.show()

+-----------+------+
|          X|salary|
+-----------+------+
| [21.0,1.0]| 15000|
| [24.0,3.0]| 20000|
| [29.0,4.0]| 20000|
| [30.0,8.0]| 25000|
|[31.0,10.0]| 30000|
+-----------+------+

None
+----------+------+
|         X|salary|
+----------+------+
|[23.0,2.0]| 18000|
+----------+------+

None


In [10]:
# train model
from pyspark.ml.regression import LinearRegression

# Configure the estimator under its own name, then fit it on the training
# split. `model` ends up bound to the fitted model that later cells inspect.
linear_regression = LinearRegression(featuresCol='X', labelCol='salary')
model = linear_regression.fit(data_train)

22/05/17 17:19:25 WARN Instrumentation: [43576af3] regParam is zero, which might cause numerical instability and overfitting.


In [11]:
# Display the fitted model's repr (its uid and the number of features it was trained on).
model


LinearRegressionModel: uid=LinearRegression_8fcf8aa23018, numFeatures=2

In [12]:
# Fitted weights, one per entry of the X vector, i.e. in the assembler's
# inputCols order: age first, then experience.
model.coefficients

DenseVector([-90.5483, 1608.7819])

In [13]:
# Fitted intercept (constant term) of the regression.
model.intercept

16079.136690647425

In [14]:
# prediction
# Score the held-out split. The returned summary carries the per-row
# predictions shown here as well as the error metrics read in later cells.
results = model.evaluate(data_test)
predictions = results.predictions
predictions.show()



+----------+------+-----------------+
|         X|salary|       prediction|
+----------+------+-----------------+
|[23.0,2.0]| 18000|17214.09079632846|
+----------+------+-----------------+



In [15]:
# check MAE: mean absolute error of the predictions on data_test.
# NOTE(review): the held-out split of this tiny data set may contain only a
# row or two, so treat the value as illustrative rather than a reliable estimate.
results.meanAbsoluteError

785.909203671541

In [16]:
# check MSE: mean squared error of the predictions on data_test
# (the same evaluation summary as the MAE, with errors squared before averaging).
results.meanSquaredError

617653.2764156357