# Testing On Linear Regression

In [5]:
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression

In [2]:
spark = SparkSession.builder.appName("linear_reg").getOrCreate()

### To check on vector which would affect the crew count in regression

In [14]:
cruise_final_data = spark.read.csv('cruise_ship_info.csv', inferSchema=True, header=True)
cruise_final_data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [15]:
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [19]:
cruise_indexer = StringIndexer(inputCol="Cruise_line",outputCol="cruise_index")
cruise_indexer.fit(cruise_final_data).transform(cruise_final_data).head(2)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_index=16.0),
 Row(Ship_name='Quest', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_index=16.0)]

In [20]:
assembler = VectorAssembler(inputCols=["cruise_index", "Age", "Tonnage", "passengers", "length", "cabins", "passenger_density"], outputCol="features")

### Split into 70% 30%

In [23]:
train_cruise_data, test_cruise_data = cruise_final_data.randomSplit([0.7,0.3])

In [25]:
regressionModel = LinearRegression(labelCol="crew")

### Put Into Pipeline

In [26]:
from pyspark.ml import Pipeline

In [27]:
pipeline = Pipeline(stages=[cruise_indexer, assembler, regressionModel])

In [29]:
trained_data_set = pipeline.fit(train_cruise_data)

In [30]:
result = trained_data_set.transform(test_cruise_data)

In [51]:
assembler_cruise = assembler.transform(cruise_indexer.fit(cruise_final_data).transform(cruise_final_data))
assembler_cruise = assembler_cruise.select("features","crew")
assembler_cruise.show(2)

+--------------------+----+
|            features|crew|
+--------------------+----+
|[16.0,6.0,30.2769...|3.55|
|[16.0,6.0,30.2769...|3.55|
+--------------------+----+
only showing top 2 rows



In [52]:
final_a_cruise_data, test_a_cruise_data = assembler_cruise.randomSplit([0.7,0.3])

In [53]:
lr = regressionModel.fit(final_a_cruise_data)

In [54]:
prediction = lr.evaluate(test_a_cruise_data)

In [55]:
prediction.r2

0.9057023783686344

In [64]:
lr.summary.r2

0.9383983105161378

In [62]:
prediction.rootMeanSquaredError

1.3227772669548175

In [63]:
lr.summary.rootMeanSquaredError

0.7844833168339089

In [59]:
test_a_cruise_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|               39|
|   mean|8.324871794871797|
| stddev|4.363920781635944|
|    min|             0.88|
|    max|             21.0|
+-------+-----------------+



In [61]:
lr_df = lr.transform(test_a_cruise_data)
lr_df.select("prediction","crew","features").show(5)

+------------------+-----+--------------------+
|        prediction| crew|            features|
+------------------+-----+--------------------+
| 20.41847583160257| 21.0|[0.0,4.0,220.0,54...|
| 8.890306059241052| 8.68|[0.0,12.0,90.09,2...|
|12.764234519669364|11.85|[0.0,12.0,138.0,3...|
|12.747017002603172|11.76|[0.0,13.0,138.0,3...|
| 8.234979591846816|  6.6|[0.0,15.0,78.491,...|
+------------------+-----+--------------------+
only showing top 5 rows



In [67]:
lr.summary.residuals.show()


+--------------------+
|           residuals|
+--------------------+
| -1.2757171951905146|
|-0.16819342135733883|
|-0.09152476237777485|
| -0.6071300651519085|
|-0.34474109337343606|
| -0.9486695538017482|
| -0.4275235763072427|
|  -0.931452036735557|
| -0.9697994855369814|
|  -1.024057834736002|
| -0.5677620747806227|
|-0.41043385572784086|
|  -1.006840317669809|
| -0.7273857933476462|
| -0.7206957046266229|
| -0.6805174123303797|
|  0.8686837920948394|
|-0.34126478278736805|
| 0.45112569512440714|
| -0.3961999086610568|
+--------------------+
only showing top 20 rows

