In [29]:
import warnings

# Install a catch-all "ignore" filter *before* PySpark is imported so that any
# warnings raised while importing it are hidden as well.
# NOTE(review): a blanket filter also hides PySpark deprecation notices;
# consider narrowing it to specific categories/modules.
warnings.simplefilter("ignore")

from pyspark.sql import SparkSession

In [2]:
# Reuse the active Spark session if one exists, otherwise start one named "Linear".
spark = (
    SparkSession.builder
    .appName("Linear")
    .getOrCreate()
)

In [4]:
from pyspark.ml.regression import LinearRegression

In [5]:
# Read the LIBSVM-formatted sample file into a DataFrame used as training data.
training = spark.read.load("sample_linear_regression_data.txt", format="libsvm")

In [6]:
lr = LinearRegression(
    featuresCol='features',      # input vector column
    labelCol='label',            # target column
    predictionCol='prediction',  # output column added by the fitted model
)

# Regularization can be enabled through extra parameters (e.g. regParam,
# elasticNetParam); see ISLR for the theory behind them, and press Shift+Tab
# on LinearRegression in Jupyter for the full parameter documentation.

In [7]:
# Train the estimator defined above on the full dataset.
lrModel = lr.fit(dataset=training)

In [11]:
# One print call: with sep='\n' the middle '\n' argument yields the same three
# newlines between the two values as the original print('\n') did.
print(
    "Coefficients:{}".format(str(lrModel.coefficients)),
    '\n',
    "Intercept:{}".format(str(lrModel.intercept)),
    sep='\n',
)

Coefficients:[0.0073350710225801715,0.8313757584337543,-0.8095307954684084,2.441191686884721,0.5191713795290003,1.1534591903547016,-0.2989124112808717,-0.5128514186201779,-0.619712827067017,0.6956151804322931]


Intercept:0.14228558260358093


In [13]:
# Training summary of the fitted model; the following cells read its
# residuals, meanSquaredError and r2 attributes.
modelsummary = lrModel.summary

In [16]:
# Peek at the first 20 training residuals.
modelsummary.residuals.show(n=20)

+-------------------+
|          residuals|
+-------------------+
|-11.011130022096554|
| 0.9236590911176538|
|-4.5957401897776675|
|  -20.4201774575836|
|-10.339160314788181|
|-5.9552091439610555|
|-10.726906349283922|
|  2.122807193191233|
|  4.077122222293811|
|-17.316168071241652|
| -4.593044343959059|
|  6.380476690746936|
| 11.320566035059846|
|-20.721971774534094|
| -2.736692773777401|
| -16.66886934252847|
|  8.242186378876315|
|-1.3723486332690233|
|-0.7060332131264666|
|-1.1591135969994064|
+-------------------+
only showing top 20 rows



In [17]:
# Training-set goodness of fit, one metric per line.
print(
    "MSE:{}".format(modelsummary.meanSquaredError),
    "r2:{}".format(modelsummary.r2),
    sep="\n",
)

MSE:103.28843028724194
r2:0.027839179518600154


In [18]:
# Reload the full sample dataset so it can be split into train/test below.
all_data = spark.read.load("sample_linear_regression_data.txt", format='libsvm')

In [20]:
# Hold out ~30% of the rows for evaluation. Without a seed, randomSplit draws a
# different split on every run, so the train/test MSEs and predictions below
# would not be reproducible across Restart & Run All. Fix the seed explicitly.
SPLIT_SEED = 42
train_data, test_data = all_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [21]:
lr = LinearRegression(
    featuresCol='features',
    labelCol='label',
    predictionCol='prediction',
    maxIter=10,     # upper bound on optimizer iterations
    regParam=0.3,   # regularization strength
)

In [22]:
# Fit the regularized estimator on the training split only.
correct_model = lr.fit(dataset=train_data)

In [24]:
# Training-split metrics of the regularized model. Stored as `train_summary`
# rather than `sum`, which would shadow the builtin sum() for the remainder
# of the kernel session.
train_summary = correct_model.summary
print("MSE:{}".format(train_summary.meanSquaredError))

MSE:106.78478510901837


In [35]:
# Score the held-out split: show the first residuals, then the test MSE.
test_results = correct_model.evaluate(dataset=test_data)
test_results.residuals.show(n=20)
print("MSE: {}".format(test_results.meanSquaredError))

+-------------------+
|          residuals|
+-------------------+
|-23.111018541809074|
| -21.89957670806893|
|-19.517954671179876|
|-18.517543541334334|
| -20.90186996661434|
|-17.016155172191677|
|-15.213773043465242|
|-17.263834018788234|
|-14.668150031633038|
|-11.596429470290344|
|-11.261953168228107|
|-14.025401800571819|
| -15.93607351649755|
|-10.658038390968656|
|  -13.1279554181704|
| -9.399927536263633|
|-11.679559748307376|
|-12.848281489001227|
| -10.64396228124224|
|-11.351191527299598|
+-------------------+
only showing top 20 rows

MSE: 100.91434662975908


In [25]:
# Keep only the feature vectors to mimic new, unlabeled data.
unlabeled_data = test_data.select(test_data["features"])

In [32]:
# Append a 'prediction' column computed by the trained model.
predictions = correct_model.transform(dataset=unlabeled_data)

In [33]:
# Display the first 20 feature/prediction rows.
predictions.show(n=20)

+--------------------+--------------------+
|            features|          prediction|
+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...|  1.6786307776432683|
|(10,[0,1,2,3,4,5,...|  2.0150159337955054|
|(10,[0,1,2,3,4,5,...|    1.71432848251536|
|(10,[0,1,2,3,4,5,...|  1.4910512771247855|
|(10,[0,1,2,3,4,5,...|   4.209662945303233|
|(10,[0,1,2,3,4,5,...|  1.2840668999524303|
|(10,[0,1,2,3,4,5,...|-0.09720754595104736|
|(10,[0,1,2,3,4,5,...|  2.2073510442458018|
|(10,[0,1,2,3,4,5,...|  0.6920191004803353|
|(10,[0,1,2,3,4,5,...|  -2.270658424868424|
|(10,[0,1,2,3,4,5,...| -1.7158955571639964|
|(10,[0,1,2,3,4,5,...|    1.46682601171563|
|(10,[0,1,2,3,4,5,...|  3.4352997311424955|
|(10,[0,1,2,3,4,5,...| -1.4723148213192736|
|(10,[0,1,2,3,4,5,...|  1.2229685154952854|
|(10,[0,1,2,3,4,5,...| -2.4574228291657922|
|(10,[0,1,2,3,4,5,...| 0.06378448329175052|
|(10,[0,1,2,3,4,5,...|  1.4164791234607672|
|(10,[0,1,2,3,4,5,...| -0.3953855270115872|
|(10,[0,1,2,3,4,5,...|   0.40527