In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
import os

In [2]:
# Start (or attach to an already-running) Spark session for this notebook,
# then raise the log threshold so only ERROR-level messages reach cell output.
spark = (
    SparkSession.builder
    .appName("ICP 7")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

In [3]:
# Read car.csv (first row is the header, column types inferred from the data)
# and keep only the columns the model uses. `wheel-base` is the regression
# target; it is exposed as `label`, the label column LinearRegression reads
# by default. Resulting column order: label, length, width, height.
data = (
    spark.read.csv("car.csv", header=True, inferSchema=True, sep=",")
    .select(col("wheel-base").alias("label"), "length", "width", "height")
)

In [4]:
# Combine the numeric predictor columns into the single vector column that
# Spark ML estimators consume. The feature list is spelled out explicitly
# (it matches the non-label columns selected when loading the data) instead of
# using `data.columns[1:]`. That way it neither depends on column order nor
# picks up a `features` column left behind by a previous run of this cell.
FEATURE_COLS = ["length", "width", "height"]
assembler = VectorAssembler(inputCols=FEATURE_COLS, outputCol="features")

# VectorAssembler refuses to overwrite an existing output column, so drop any
# stale `features` column first; this keeps the cell safe to re-run.
# `DataFrame.drop` is a no-op when the column does not exist.
data = assembler.transform(data.drop("features"))

# Elastic-net regularised linear regression (reads the default "label" and
# "features" columns):
#   maxIter=10          - cap on optimiser iterations
#   regParam=0.3        - overall regularisation strength
#   elasticNetParam=0.8 - L1/L2 mix (0 = pure L2/ridge, 1 = pure L1/lasso)
lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

In [5]:
# Train the regression on the assembled data and report the learned
# parameters: one coefficient per feature column, plus the intercept.
model = lr.fit(data)
print("Coefficients: {}".format(model.coefficients))
print("Intercept: {}".format(model.intercept))

# Training-set diagnostics from the fitted model's summary: optimiser
# iteration count and objective history, per-row residuals, RMSE and R^2.
trainingSummary = model.summary
print("numIterations: {:d}".format(trainingSummary.totalIterations))
print("objectiveHistory: {}".format(trainingSummary.objectiveHistory))
trainingSummary.residuals.show()
print("RMSE: {:f}".format(trainingSummary.rootMeanSquaredError))
print("r2: {:f}".format(trainingSummary.r2))

# NOTE(review): stopping the session here means this cell cannot be re-run on
# its own (the DataFrames above belong to the stopped session); consider
# moving `spark.stop()` into a dedicated final cell.
spark.stop()

Coefficients: [0.22836801258821893,0.8223218915856468,0.580595102043434]
Intercept: -26.380531957157498
numIterations: 11
objectiveHistory: [0.5, 0.38579526656819896, 0.13000842393266873, 0.12985504772567413, 0.12963704261349218, 0.12947103310674205, 0.1294164378448031, 0.1294050846483987, 0.12940508261516015, 0.1294050824628613, 0.12940508245526855]
+--------------------+
|           residuals|
+--------------------+
|  -4.611862798093398|
|  -4.611862798093398|
|  -2.501339043881387|
|-0.11328232985025011|
| -0.6777467081673763|
|  0.3413419946315486|
|  -2.878914311626758|
|  -2.878914311626758|
| -2.9950333320354474|
| -0.8412496309870932|
|  2.3922947158520174|
|  2.3922947158520174|
|  2.3922947158520174|
|  2.3922947158520174|
| -0.6335041529149237|
| -0.6335041529149237|
| -1.3908023008371515|
|  0.4019071188106693|
|   2.084135889634638|
|   2.787341183548463|
+--------------------+
only showing top 20 rows

RMSE: 2.517190
r2: 0.824407
