# Testing On Linear Regression

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression

In [2]:
spark = SparkSession.builder.appName("linear_reg").getOrCreate()

### To check on vector which would affect the crew count in regression

In [3]:
cruise_final_data = spark.read.csv('cruise_ship_info.csv', inferSchema=True, header=True)
cruise_final_data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [4]:
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [5]:
cruise_indexer = StringIndexer(inputCol="Cruise_line",outputCol="cruise_index")
cruise_indexer.fit(cruise_final_data).transform(cruise_final_data).head(2)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_index=16.0),
 Row(Ship_name='Quest', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_index=16.0)]

In [6]:
assembler = VectorAssembler(inputCols=["cruise_index", "Age", "Tonnage", "passengers", "length", "cabins", "passenger_density"], outputCol="features")

### Split into 70% 30%

In [7]:
train_cruise_data, test_cruise_data = cruise_final_data.randomSplit([0.7,0.3])

In [8]:
regressionModel = LinearRegression(labelCol="crew")

### Put Into Pipeline Failed (only applicable to Logistics Regression

In [9]:
from pyspark.ml import Pipeline

In [10]:
pipeline = Pipeline(stages=[cruise_indexer, assembler, regressionModel])

In [11]:
trained_data_set = pipeline.fit(train_cruise_data)

In [12]:
result = trained_data_set.transform(test_cruise_data)

In [13]:
assembler_cruise = assembler.transform(cruise_indexer.fit(cruise_final_data).transform(cruise_final_data))
assembler_cruise = assembler_cruise.select("features","crew")
assembler_cruise.show(2)

+--------------------+----+
|            features|crew|
+--------------------+----+
|[16.0,6.0,30.2769...|3.55|
|[16.0,6.0,30.2769...|3.55|
+--------------------+----+
only showing top 2 rows



In [14]:
final_a_cruise_data, test_a_cruise_data = assembler_cruise.randomSplit([0.7,0.3])

In [15]:
lr = regressionModel.fit(final_a_cruise_data)

In [16]:
prediction = lr.evaluate(test_a_cruise_data)

In [17]:
prediction.r2

0.9242617637893116

In [18]:
lr.summary.r2

0.9269497118404524

In [19]:
prediction.rootMeanSquaredError

0.9227912088550504

In [23]:
lr.summary.meanSquaredError

0.9254895757831877

In [24]:
lr.summary.rootMeanSquaredError

0.962023687745363

In [20]:
test_a_cruise_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                56|
|   mean|7.5492857142857135|
| stddev| 3.383440345480998|
|    min|              0.59|
|    max|              13.6|
+-------+------------------+



In [21]:
lr_df = lr.transform(test_a_cruise_data)
lr_df.select("prediction","crew","features").show(5)

+------------------+-----+--------------------+
|        prediction| crew|            features|
+------------------+-----+--------------------+
|13.685294827952042| 13.6|[0.0,7.0,158.0,43...|
|12.887506386454934|11.85|[0.0,11.0,138.0,3...|
|12.856104571510597|11.76|[0.0,14.0,138.0,3...|
|  8.17754236006678|  6.6|[0.0,15.0,78.491,...|
| 9.091287855761921| 8.22|[0.0,22.0,73.941,...|
+------------------+-----+--------------------+
only showing top 5 rows



In [22]:
lr.summary.residuals.show()




+--------------------+
|           residuals|
+--------------------+
| 0.09228840061231836|
| -1.4070413373311847|
| -0.1525626942094931|
|  -0.538863649928734|
|-0.22961948539018984|
| -1.0479736581030465|
|-0.31915221374207725|
|-0.10868494209396573|
|  -1.027039114806822|
| -1.1065718431587097|
| -1.0696098694706606|
| -0.5170750884186681|
| -0.3763125596568875|
| -1.0591425978225484|
| -0.7599074167663646|
| -0.8873035183760756|
|  0.8320891736562359|
| -0.4732739045211929|
| -0.6037824322362777|
|  0.7360348414297064|
+--------------------+
only showing top 20 rows

