In [0]:
#Congratulations! You've been contracted by Hyundai Heavy Industries to help them build a predictive model for some ships. Hyundai Heavy Industries is one of the world's largest ship manufacturing companies and builds cruise liners.

#You've been flown to their headquarters in Ulsan, South Korea to help them give accurate estimates of how many crew members a ship will require.

#They are currently building new ships for some customers and want you to create a model and use it to predict how many crew members the ships will need.



# P.S. For a better view of this notebook, please open the following link:

# https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/6209651637822328/538998700402656/2290918915078516/latest.html

In [0]:
from pyspark.sql import SparkSession

# Name the application 'LINEAR', then create the Spark session
# (or reuse one that already exists).
session_builder = SparkSession.builder.appName('LINEAR')
spark = session_builder.getOrCreate()

In [0]:
# Data import: the first CSV row is used as the header and Spark infers
# the column types from the data.
cruise_csv_path = 'dbfs:/FileStore/shared_uploads/shhivramcss@gmail.com/cruise_ship_info.csv'
df = spark.read.csv(cruise_csv_path, header=True, inferSchema=True)

In [0]:
# Print every column together with its inferred data type
df.printSchema()

In [0]:
# List the column names of the dataset
df.columns

In [0]:
# Preview the first five rows of the dataset
df.show(n=5)

In [0]:
# Descriptive statistics of the dataset
summary_stats = df.describe()
summary_stats.show()

In [0]:
# For every column from the third one onward: render the column with the
# notebook's display() helper, then print its summary statistics as a dict.
divider = "=" * 100
for column_name in df.columns[2:]:
    single_column = df.select(column_name)

    print(divider)
    print("Histogram of ", column_name)
    display(single_column)
    print("\n")

    print("Descriptive Statistics of ", column_name)
    # describe() yields (summary, value) rows, which dict() turns into a mapping
    stats_rows = single_column.describe().collect()
    print(dict(stats_rows))
    print(divider)
  

Age
6
6
26
11
17
22
15
23
19
6


Tonnage
30.277
30.277
47.262
110.0
101.353
70.367
70.367
70.367
70.367
110.239


passengers
6.94
6.94
14.86
29.74
26.42
20.52
20.52
20.56
20.52
37.0


length
5.94
5.94
7.22
9.53
8.92
8.55
8.55
8.55
8.55
9.51


cabins
3.55
3.55
7.43
14.88
13.21
10.2
10.2
10.22
10.2
14.87


passenger_density
42.64
42.64
31.8
36.99
38.36
34.29
34.29
34.23
34.29
29.79


crew
3.55
3.55
6.7
19.1
10.0
9.2
9.2
9.2
9.2
11.5


In [0]:
# Since cruise line is categorical nominal I'm one hot encoding it
from pyspark.ml.feature import (StringIndexer, OneHotEncoder, VectorAssembler)

In [0]:
# Step 1: index the 'Cruise_line' strings into the numeric column 'cru_indexed'
Sindexer = StringIndexer(inputCol='Cruise_line', outputCol='cru_indexed')
Sindexer_fitted = Sindexer.fit(df)
Sindxed_df = Sindexer_fitted.transform(df)

# Step 2: one-hot encode 'cru_indexed' into the vector column 'One_hot_cruz'
Onehencoder = OneHotEncoder(inputCol='cru_indexed', outputCol='One_hot_cruz')
Onehencoder_fitted = Onehencoder.fit(Sindxed_df)
Onehencodered_df = Onehencoder_fitted.transform(Sindxed_df)

In [0]:
# Inspect the column names after indexing and one-hot encoding
Onehencodered_df.columns

In [0]:
# Assemble the predictor columns and the one-hot encoded cruise line
# into a single vector column named 'Features'.
feature_columns = [
    'Age',
    'Tonnage',
    'passengers',
    'length',
    'cabins',
    'passenger_density',
    'One_hot_cruz',
]
vc = VectorAssembler(inputCols=feature_columns, outputCol='Features')
op_Df = vc.transform(Onehencodered_df)

In [0]:
# Final modelling dataset: only the assembled feature vector and the 'crew' column
model_columns = ('Features', 'crew')
f_df = op_Df.select(*model_columns)

In [0]:
# Train/test split with weights 0.7 / 0.3.
# A fixed seed is passed so that the random split -- and therefore every
# metric computed from `train` and `test` afterwards -- is repeatable
# when the notebook is re-run on the same data.
train, test = f_df.randomSplit([0.7, 0.3], seed=42)

In [0]:
# Set up a linear regression that predicts 'crew' from the 'Features' vector
from pyspark.ml.regression import LinearRegression
lr_model = LinearRegression(labelCol='crew', featuresCol='Features')

# Fit the regression on the training split
lr_model_trained = lr_model.fit(train)

# Predictions for the test split
lr_model_preds = lr_model_trained.transform(test)

# Evaluation summary for the same test split
test_results = lr_model_trained.evaluate(test)

In [0]:
# Evaluating the model: report the test-set error metrics next to the mean
# and standard deviation of the actual 'crew' values, so the RMSE can be
# judged against the scale and spread of the target.

from pyspark.sql.functions import avg, stddev

# Both label statistics come from one aggregation (a single Spark job)
# instead of one collect() per statistic.
crew_stats = test.select(avg('crew'), stddev('crew')).collect()[0]
crew_mean, crew_std = crew_stats[0], crew_stats[1]

print("RMSE of the lr_model is", test_results.rootMeanSquaredError)
# These two figures describe the 'crew' label in the test split, not the
# model, so they are labelled accordingly.
print("Mean of crew in the test set is", crew_mean)
print("Std of crew in the test set is", crew_std)
print("r2 of the lr_model is", test_results.r2)
print("Adjusted r2 of the lr_model is", test_results.r2adj)

In [0]:
# Show the actual 'crew' values next to the model's 'prediction' column
print("==========================Visualizing the Model Results==================================")
actual_vs_predicted = lr_model_preds.select('crew', 'prediction')
display(actual_vs_predicted)

crew,prediction
8.58,8.10547241116974
11.85,12.337260715137406
8.48,8.118217420926504
11.76,12.375495744407695
7.65,7.444063114077119
7.2,6.742494874776758
8.22,8.293976572879094
11.6,12.72419104082519
9.3,10.987430218571856
9.2,9.154090060787972
