# Linear Regression using MLlib

In [1]:
# Locate the local Spark installation so that pyspark can be imported;
# findspark.init() has to run before the `import pyspark` line below.
import findspark

findspark.init()

import pyspark  # deliberately after findspark.init()

# Show which Spark home findspark resolved (this is the cell's displayed output).
sparkHome = findspark.find()
sparkHome

'c:\\spark'

In [2]:
from __future__ import print_function

from pyspark.ml.regression import LinearRegression

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

In [3]:
# Create (or reuse) the SparkSession for this notebook.
# The warehouse-dir override below is only needed on Windows, so it is applied
# only there; on other platforms Spark's default warehouse location is kept
# instead of a nonsensical "C:/temp" path.
import sys

sessionBuilder = SparkSession.builder.appName("LinearRegression")
if sys.platform.startswith("win"):
    sessionBuilder = sessionBuilder.config("spark.sql.warehouse.dir", "file:///C:/temp")
spark = sessionBuilder.getOrCreate()

In [4]:
# Load up our data and convert it to the format MLlib expects:
# a "label" column and a "features" vector column.
inputLines = spark.sparkContext.textFile("regression.txt")


def parseLine(line):
    """Parse one "label,feature" text line into a (label, DenseVector) pair.

    Only the first two comma-separated fields are used; both must parse as floats.
    """
    fields = line.split(",")
    return (float(fields[0]), Vectors.dense(float(fields[1])))


# Skip blank lines: without this, a single empty line would reach float("")
# and fail the whole Spark job.
data = inputLines.filter(lambda line: line.strip()).map(parseLine)

# Note, there are lots of cases where you can avoid going from an RDD to a DataFrame.
# Perhaps you're importing data from a real database. Or you are using structured streaming to get your data.

# Convert this RDD to a DataFrame
colNames = ["label", "features"]
df = data.toDF(colNames)

In [7]:
# Split the data roughly 50/50 into training and testing sets. randomSplit samples
# rows randomly, so without a fixed seed the split, the fitted model, and every
# printed prediction would change on each Restart & Run All.
SPLIT_SEED = 42
trainTest = df.randomSplit([0.5, 0.5], seed=SPLIT_SEED)
trainingDF, testDF = trainTest

In [8]:
# Configure a regularised linear regression:
#   maxIter         - cap on optimiser iterations
#   regParam        - overall regularisation strength
#   elasticNetParam - L1/L2 mix (0 = pure L2/ridge, 1 = pure L1/lasso)
lir = LinearRegression(
    maxIter=10,
    regParam=0.3,
    elasticNetParam=0.8,
)

# Fit it on the training split only; testDF is held out for evaluation.
model = lir.fit(trainingDF)

In [9]:
# Score every row of the held-out test set with the fitted model; the result
# carries a "prediction" column next to the original "label".
scoredTestDF = model.transform(testDF)

# Cache the scored rows so later cells that query them don't recompute the scoring.
fullPredictions = scoredTestDF.cache()

In [10]:
# Pull the model's predictions and the known correct labels out as two
# single-value RDDs, in that order.
predictions, labels = (
    fullPredictions.select(columnName).rdd.map(lambda row: row[0])
    for columnName in ("prediction", "label")
)

In [11]:
# Pair each prediction with its known label. Selecting both columns in a single
# query keeps each pair from the same row by construction. Zipping two separately
# derived RDDs relies on RDD.zip's assumption of identical partitioning and
# per-partition counts, and costs two passes over the data.
predictionAndLabel = [
    tuple(row)  # plain (prediction, label) tuple, same printed form as before
    for row in fullPredictions.select("prediction", "label").collect()
]

# Print out the predicted and actual values for each point
for prediction in predictionAndLabel:
    print(prediction)

(-2.6671619945647635, -3.74)
(-1.8292638852698753, -2.58)
(-1.87186887387809, -2.36)
(-1.6730455937064217, -2.29)
(-1.5594322907511826, -2.27)
(-1.5594322907511826, -2.17)
(-1.3535081791448118, -2.12)
(-1.452919819230646, -2.07)
(-1.438718156361241, -2.0)
(-1.3038023591018948, -1.91)
(-1.3464073477101093, -1.88)
(-1.310903190536597, -1.8)
(-1.1688865618425481, -1.66)
(-1.1617857304078458, -1.6)
(-1.1759873932772507, -1.58)
(-1.197289887581358, -1.53)
(-0.9842649445402848, -1.48)
(-1.0410715960179042, -1.47)
(-1.005567438844392, -1.46)
(-1.1333824046690362, -1.42)
(-1.0552732588873093, -1.33)
(-0.8138449901074263, -1.3)
(-1.033970764583202, -1.3)
(-1.0481724274526067, -1.29)
(-0.8493491472809385, -1.27)
(-0.856449978715641, -1.25)
(-0.9416599559320703, -1.25)
(-0.7783408329339142, -1.24)
(-0.856449978715641, -1.22)
(-0.8919541358891532, -1.16)
(-0.8209458215421287, -1.14)
(-0.7854416643686166, -1.12)
(-0.7144333500215921, -1.11)
(-0.8777524730197482, -1.11)
(-0.7712400014992117, -1.1)
(

In [12]:
# Stop the SparkSession now that the analysis is finished; run this as the last
# cell, since nothing that uses `spark` can run after it.
spark.stop()