<a href="https://colab.research.google.com/github/safi50/K-Means-clustering-on-Airline-Data/blob/main/K_Means_clustering_on_Airline_Data_PySpark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [58]:
# !pip install pyspark

In [59]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, mean, round
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.evaluation import ClusteringEvaluator

In [60]:
spark = SparkSession.builder.appName('Airline Kmeans').getOrCreate()

In [61]:
data = spark.read.csv('airlines1.csv', inferSchema=True, header=True)
data.take(1)

[Row(_c0=0, Year=1998, Quarter=1, Month=1, DayofMonth=2, DayOfWeek=5, FlightDate=datetime.date(1998, 1, 2), Reporting_Airline='NW', DOT_ID_Reporting_Airline=19386, IATA_CODE_Reporting_Airline='NW', Tail_Number='N297US', Flight_Number_Reporting_Airline=675, OriginAirportID=13487, OriginAirportSeqID=1348701, OriginCityMarketID=31650, Origin='MSP', OriginCityName='Minneapolis, MN', OriginState='MN', OriginStateFips=27.0, OriginStateName='Minnesota', OriginWac=63, DestAirportID=14869, DestAirportSeqID=1486902, DestCityMarketID=34614, Dest='SLC', DestCityName='Salt Lake City, UT', DestState='UT', DestStateFips=49.0, DestStateName='Utah', DestWac=87, CRSDepTime=1640, DepTime=1659.0, DepDelay=19.0, DepDelayMinutes=19.0, DepDel15=1.0, DepartureDelayGroups=1.0, DepTimeBlk='1600-1659', TaxiOut=24.0, WheelsOff=1723.0, WheelsOn=1856.0, TaxiIn=3.0, CRSArrTime=1836, ArrTime=1859.0, ArrDelay=23.0, ArrDelayMinutes=23.0, ArrDel15=1.0, ArrivalDelayGroups=1.0, ArrTimeBlk='1800-1859', Cancelled=0.0, Cance

In [62]:
data.count()

50001

In [63]:
airline_data = data.select('Origin', 'Dest', 'AirTime', 'Distance')
airline_data.cache()
airline_data.show(50)

+------+----+-------+--------+
|Origin|Dest|AirTime|Distance|
+------+----+-------+--------+
|   MSP| SLC|  153.0|   991.0|
|   MKE| MCO|  141.0|  1066.0|
|   GJT| DFW|  103.0|   773.0|
|   LAX| DTW|  220.0|  1979.0|
|   EWR| CLT|   80.0|   529.0|
|   DFW| SHV|   28.0|   190.0|
|   BOS| CLE|   94.0|   563.0|
|   ATL| CAE|   35.0|   192.0|
|   ORD| CLE|   59.0|   316.0|
|   MDW| DAL|  114.0|   793.0|
|   SAN| LAX|   NULL|   109.0|
|   ELP| DAL|   77.0|   562.0|
|   SJU| MIA|   NULL|  1045.0|
|   ABQ| LAX|   95.0|   677.0|
|   ORD| LGA|   99.0|   733.0|
|   GSO| BWI|   NULL|   278.0|
|   DTW| MBS|   24.0|    98.0|
|   SLC| SEA|  102.0|   689.0|
|   LAX| IAD|  255.0|  2288.0|
|   SMF| LAX|   NULL|   373.0|
|   MSY| ORD|  119.0|   837.0|
|   SGF| ATL|   85.0|   563.0|
|   IND| STL|   NULL|   229.0|
|   PHX| LGB|   62.0|   355.0|
|   BUR| SFO|   53.0|   326.0|
|   DCA| BOS|   63.0|   399.0|
|   DFW| ATL|  101.0|   732.0|
|   CAE| ATL|   39.0|   192.0|
|   ORD| DSM|   NULL|   299.0|
|   RDU|

In [64]:
airline_data.printSchema()

root
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- Distance: double (nullable = true)



In [65]:
airline_data = airline_data.withColumn('AirTime', airline_data['AirTime'].cast('integer'))
airline_data = airline_data.withColumn('Distance', airline_data['Distance'].cast('integer'))
airline_data.printSchema()

root
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- AirTime: integer (nullable = true)
 |-- Distance: integer (nullable = true)



In [66]:
print(airline_data.where(col('AirTime').isNull()).count())
print(airline_data.where(col('Distance').isNull()).count())

10529
0


In [67]:
mean_airtime = airline_data.select(mean(col('AirTime'))).first()[0]
mean_airtime

105.83902513173895

In [68]:
airline_data = airline_data.fillna({'AirTime': mean_airtime})
airline_data.where(col('AirTime').isNull()).count(), airline_data.show(20)

+------+----+-------+--------+
|Origin|Dest|AirTime|Distance|
+------+----+-------+--------+
|   MSP| SLC|    153|     991|
|   MKE| MCO|    141|    1066|
|   GJT| DFW|    103|     773|
|   LAX| DTW|    220|    1979|
|   EWR| CLT|     80|     529|
|   DFW| SHV|     28|     190|
|   BOS| CLE|     94|     563|
|   ATL| CAE|     35|     192|
|   ORD| CLE|     59|     316|
|   MDW| DAL|    114|     793|
|   SAN| LAX|    105|     109|
|   ELP| DAL|     77|     562|
|   SJU| MIA|    105|    1045|
|   ABQ| LAX|     95|     677|
|   ORD| LGA|     99|     733|
|   GSO| BWI|    105|     278|
|   DTW| MBS|     24|      98|
|   SLC| SEA|    102|     689|
|   LAX| IAD|    255|    2288|
|   SMF| LAX|    105|     373|
+------+----+-------+--------+
only showing top 20 rows



(0, None)

In [69]:
airline_data = airline_data.withColumn('Distance', round(airline_data.Distance * 1.60034, 0))
airline_data = airline_data.withColumn('AirTime', round(airline_data.AirTime, 0))
airline_data.show(20)

+------+----+-------+--------+
|Origin|Dest|AirTime|Distance|
+------+----+-------+--------+
|   MSP| SLC|    153|  1586.0|
|   MKE| MCO|    141|  1706.0|
|   GJT| DFW|    103|  1237.0|
|   LAX| DTW|    220|  3167.0|
|   EWR| CLT|     80|   847.0|
|   DFW| SHV|     28|   304.0|
|   BOS| CLE|     94|   901.0|
|   ATL| CAE|     35|   307.0|
|   ORD| CLE|     59|   506.0|
|   MDW| DAL|    114|  1269.0|
|   SAN| LAX|    105|   174.0|
|   ELP| DAL|     77|   899.0|
|   SJU| MIA|    105|  1672.0|
|   ABQ| LAX|     95|  1083.0|
|   ORD| LGA|     99|  1173.0|
|   GSO| BWI|    105|   445.0|
|   DTW| MBS|     24|   157.0|
|   SLC| SEA|    102|  1103.0|
|   LAX| IAD|    255|  3662.0|
|   SMF| LAX|    105|   597.0|
+------+----+-------+--------+
only showing top 20 rows



## Vectorizing Features

In [70]:
vectorizer = VectorAssembler(inputCols=['AirTime', 'Distance'], outputCol='features')
dataset = vectorizer.setHandleInvalid("keep").transform(airline_data)
dataset.show(20)

+------+----+-------+--------+--------------+
|Origin|Dest|AirTime|Distance|      features|
+------+----+-------+--------+--------------+
|   MSP| SLC|    153|  1586.0|[153.0,1586.0]|
|   MKE| MCO|    141|  1706.0|[141.0,1706.0]|
|   GJT| DFW|    103|  1237.0|[103.0,1237.0]|
|   LAX| DTW|    220|  3167.0|[220.0,3167.0]|
|   EWR| CLT|     80|   847.0|  [80.0,847.0]|
|   DFW| SHV|     28|   304.0|  [28.0,304.0]|
|   BOS| CLE|     94|   901.0|  [94.0,901.0]|
|   ATL| CAE|     35|   307.0|  [35.0,307.0]|
|   ORD| CLE|     59|   506.0|  [59.0,506.0]|
|   MDW| DAL|    114|  1269.0|[114.0,1269.0]|
|   SAN| LAX|    105|   174.0| [105.0,174.0]|
|   ELP| DAL|     77|   899.0|  [77.0,899.0]|
|   SJU| MIA|    105|  1672.0|[105.0,1672.0]|
|   ABQ| LAX|     95|  1083.0| [95.0,1083.0]|
|   ORD| LGA|     99|  1173.0| [99.0,1173.0]|
|   GSO| BWI|    105|   445.0| [105.0,445.0]|
|   DTW| MBS|     24|   157.0|  [24.0,157.0]|
|   SLC| SEA|    102|  1103.0|[102.0,1103.0]|
|   LAX| IAD|    255|  3662.0|[255

In [71]:
kmeans = KMeans().setK(5).setSeed(1)
model = kmeans.fit(dataset)

In [72]:
eval = ClusteringEvaluator()

In [73]:
preds = model.transform(dataset)
preds.show(20)

+------+----+-------+--------+--------------+----------+
|Origin|Dest|AirTime|Distance|      features|prediction|
+------+----+-------+--------+--------------+----------+
|   MSP| SLC|    153|  1586.0|[153.0,1586.0]|         4|
|   MKE| MCO|    141|  1706.0|[141.0,1706.0]|         4|
|   GJT| DFW|    103|  1237.0|[103.0,1237.0]|         1|
|   LAX| DTW|    220|  3167.0|[220.0,3167.0]|         3|
|   EWR| CLT|     80|   847.0|  [80.0,847.0]|         1|
|   DFW| SHV|     28|   304.0|  [28.0,304.0]|         2|
|   BOS| CLE|     94|   901.0|  [94.0,901.0]|         1|
|   ATL| CAE|     35|   307.0|  [35.0,307.0]|         2|
|   ORD| CLE|     59|   506.0|  [59.0,506.0]|         2|
|   MDW| DAL|    114|  1269.0|[114.0,1269.0]|         1|
|   SAN| LAX|    105|   174.0| [105.0,174.0]|         2|
|   ELP| DAL|     77|   899.0|  [77.0,899.0]|         1|
|   SJU| MIA|    105|  1672.0|[105.0,1672.0]|         4|
|   ABQ| LAX|     95|  1083.0| [95.0,1083.0]|         1|
|   ORD| LGA|     99|  1173.0| 

In [74]:
silhouette = eval.evaluate(preds)
print(f'Silhouette with squared euclidean distance = {silhouette}')

Silhouette with squared euclidean distance = 0.7419395970412536


In [75]:
centroids = model.clusterCenters()
print('Cluster Centers: ')
for centroid in centroids:
    print(centroid)

Cluster Centers: 
[ 272.83681462 3887.7924282 ]
[ 92.25634611 967.78132588]
[ 62.5674548  432.65935198]
[ 189.36312598 2594.03570627]
[ 131.22506344 1619.63458911]


In [76]:
preds.groupBy('prediction').count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|         1|13591|
|         3| 4453|
|         4|10246|
|         2|19413|
|         0| 2298|
+----------+-----+



In [77]:
preds.select('Origin', 'Dest', 'AirTime', 'Distance').where(preds.prediction == '1').show(20)

+------+----+-------+--------+
|Origin|Dest|AirTime|Distance|
+------+----+-------+--------+
|   GJT| DFW|    103|  1237.0|
|   EWR| CLT|     80|   847.0|
|   BOS| CLE|     94|   901.0|
|   MDW| DAL|    114|  1269.0|
|   ELP| DAL|     77|   899.0|
|   ABQ| LAX|     95|  1083.0|
|   ORD| LGA|     99|  1173.0|
|   SLC| SEA|    102|  1103.0|
|   SGF| ATL|     85|   901.0|
|   DFW| ATL|    101|  1171.0|
|   RDU| SRQ|    105|  1000.0|
|   ATL| DFW|    121|  1171.0|
|   ATL| ORD|     91|   970.0|
|   AUS| ELP|     80|   845.0|
|   LNK| ORD|     71|   746.0|
|   PHX| OAK|    102|  1034.0|
|   MEM| TPA|    105|  1050.0|
|   ATL| MDW|     93|   946.0|
|   ATL| DFW|    116|  1171.0|
|   MSY| FLL|     97|  1079.0|
+------+----+-------+--------+
only showing top 20 rows

