In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('cruise').getOrCreate()

In [2]:
from pyspark.ml.regression import LinearRegression

In [3]:
df = spark.read.csv('cruise_ship_info.csv', inferSchema=True, header=True)

In [4]:
df.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [5]:
df.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [6]:
df.groupBy('Cruise_line').count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [7]:
from pyspark.ml.feature import StringIndexer

In [8]:
indexer = StringIndexer(inputCol='Cruise_line', outputCol='cruise_cat')
indexed = indexer.fit(df).transform(df)
indexed.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|cruise_cat|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|      16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|      16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|       1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|       1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|       1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|       1.0|
|    Elation|   Carnival| 15

In [9]:
from pyspark.ml.linalg import Vector
from pyspark.ml.feature import VectorAssembler

In [10]:
indexed.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'cruise_cat']

In [11]:
assembler = VectorAssembler(inputCols=['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density', 'cruise_cat'],
                            outputCol='features')

In [12]:
output = assembler.transform(indexed)

In [13]:
output.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- cruise_cat: double (nullable = false)
 |-- features: vector (nullable = true)



In [14]:
final_df = output.select(['features', 'crew'])
final_df.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [15]:
train_df, test_df = final_df.randomSplit([.7, .3])

In [16]:
lr = LinearRegression(labelCol='crew')

In [17]:
lr_model = lr.fit(train_df)

In [18]:
test_result = lr_model.evaluate(test_df)

In [19]:
test_result.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| -1.2844543459381388|
| -0.2794822895952809|
| -1.0400085205711056|
| -1.0026537815331729|
|   0.633119367804408|
|   0.593605046058391|
|  0.5948741458678644|
|-0.39337580372567693|
|-0.02821196847788876|
| 0.39604848007302706|
|  -0.278018039993821|
|  0.7924073887659411|
|-0.41493095687325976|
|  1.1632309892342185|
|-0.07652043996525926|
| -0.3447520937027928|
| -0.3492094515026505|
| -0.3063451726701878|
|  -0.391848453578449|
|  1.0648134046610753|
+--------------------+
only showing top 20 rows



In [20]:
test_result.rootMeanSquaredError

1.3663117972225576

In [21]:
test_result.r2

0.8311995418283047

In [22]:
test_result.meanSquaredError

1.866807927229535

In [23]:
test_result.meanAbsoluteError

0.8149951560282024

In [24]:
from pyspark.sql.functions import corr

In [25]:
df.select(corr('crew', 'passengers')).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [26]:
df.select(corr('crew', 'cabins')).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+



In [27]:
unlebeled_df = test_df.select('features')

In [28]:
unlebeled_df.show()

+--------------------+
|            features|
+--------------------+
|[5.0,86.0,21.04,9...|
|[5.0,122.0,28.5,1...|
|[6.0,30.276999999...|
|[6.0,90.0,20.0,9....|
|[6.0,93.0,23.94,9...|
|[6.0,110.23899999...|
|[6.0,113.0,37.82,...|
|[7.0,116.0,31.0,9...|
|[7.0,158.0,43.7,1...|
|[8.0,77.499,19.5,...|
|[8.0,110.0,29.74,...|
|[9.0,88.5,21.24,9...|
|[9.0,105.0,27.2,8...|
|[9.0,113.0,26.74,...|
|[9.0,116.0,26.0,9...|
|[10.0,58.825,15.6...|
|[10.0,81.76899999...|
|[10.0,90.09,25.01...|
|[10.0,105.0,27.2,...|
|[11.0,108.977,26....|
+--------------------+
only showing top 20 rows



In [29]:
predictions = lr_model.transform(unlebeled_df)

In [30]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[5.0,86.0,21.04,9...| 9.284454345938139|
|[5.0,122.0,28.5,1...| 6.979482289595281|
|[6.0,30.276999999...| 4.590008520571105|
|[6.0,90.0,20.0,9....|10.002653781533173|
|[6.0,93.0,23.94,9...|10.456880632195592|
|[6.0,110.23899999...|10.906394953941609|
|[6.0,113.0,37.82,...|11.405125854132136|
|[7.0,116.0,31.0,9...|12.393375803725677|
|[7.0,158.0,43.7,1...|13.628211968477888|
|[8.0,77.499,19.5,...| 8.603951519926973|
|[8.0,110.0,29.74,...| 11.87801803999382|
|[9.0,88.5,21.24,9...|  9.50759261123406|
|[9.0,105.0,27.2,8...| 11.09493095687326|
|[9.0,113.0,26.74,...|11.216769010765782|
|[9.0,116.0,26.0,9...| 11.07652043996526|
|[10.0,58.825,15.6...| 7.344752093702793|
|[10.0,81.76899999...|  8.76920945150265|
|[10.0,90.09,25.01...| 8.886345172670188|
|[10.0,105.0,27.2,...|11.071848453578449|
|[11.0,108.977,26....|10.935186595338925|
+--------------------+------------