In [88]:
import pandas as pd
import numpy as np
import pyspark as sp
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
import pyspark.ml.tuning as tune
import pyspark.ml.evaluation as evals

In [6]:
sc = sp.SparkContext.getOrCreate()
print(sc)
print(sc.version)

<SparkContext master=local[*] appName=pyspark-shell>
3.5.1


In [4]:
# Create my_spark
spark = SparkSession.builder.appName("pyspark-project").getOrCreate()

# print my_spark
spark

In [7]:
# create pd_temp
pd_temp = pd.DataFrame(np.random.random(10))

# Create spark_temp from pd_temp
spark_temp = spark.createDataFrame(pd_temp)

# Examine the table in the catalog
print(spark.catalog.listTables())

# add spark_temp to the catalog
spark_temp.createOrReplaceTempView('temp')

# examine the table in the catalog again
print(spark.catalog.listTables())

[]
[Table(name='temp', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


In [15]:
file_path = './data/airports/airports.csv'

# read in the airports path
airports = spark.read.csv(file_path, header = True)

airports.show(5)

+---+--------------------+----------+-----------+----+---+---+
|faa|                name|       lat|        lon| alt| tz|dst|
+---+--------------------+----------+-----------+----+---+---+
|04G|   Lansdowne Airport|41.1304722|-80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|32.4605722|-85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|41.9893408|-88.1012428| 801| -6|  A|
|06N|     Randall Airport| 41.431912|-74.3915611| 523| -5|  A|
|09J|Jekyll Island Air...|31.0744722|-81.4277778|  11| -4|  A|
+---+--------------------+----------+-----------+----+---+---+
only showing top 5 rows



In [9]:
type(airports)

pyspark.sql.dataframe.DataFrame

In [11]:
spark.catalog.listDatabases()

[Database(name='default', catalog='spark_catalog', description='default database', locationUri='file:/D:/Data_Engineer/Spark_PySpark/spark-warehouse')]

In [12]:
spark.catalog.listTables()

[Table(name='temp', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]

In [17]:
flights = spark.read.csv("./data/airports/flights_small.csv", header = True)
flights.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
+----+-----+---+--------+---------+-----

In [20]:
flights.name = flights.createOrReplaceTempView('flights')
spark.catalog.listTables()

[Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='temp', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]

In [21]:
# Create the dataframe flight
flights_df = spark.table('flights')
print(flights_df.show(5))

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
+----+-----+---+--------+---------+-----

In [22]:
# include a new columns called duration_hrs
flights = flights.withColumn("duration_hrs", flights_df.air_time / 60)

flights.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|      duration_hrs|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|               2.2|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|               6.0|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|              1.85|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|1.3833333333333333|
|2014|    3|  9|     754|  

In [23]:
flights.describe().show()

+-------+------+------------------+-----------------+------------------+------------------+------------------+------------------+-------+-------+-----------------+------+-----+------------------+-----------------+------------------+-----------------+------------------+
|summary|  year|             month|              day|          dep_time|         dep_delay|          arr_time|         arr_delay|carrier|tailnum|           flight|origin| dest|          air_time|         distance|              hour|           minute|      duration_hrs|
+-------+------+------------------+-----------------+------------------+------------------+------------------+------------------+-------+-------+-----------------+------+-----+------------------+-----------------+------------------+-----------------+------------------+
|  count| 10000|             10000|            10000|             10000|             10000|             10000|             10000|  10000|  10000|            10000| 10000|10000|             1

In [24]:
# filter flights with a SQL string
long_flights = flights.filter('distance > 1000')
long_flights.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+-----------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|     duration_hrs|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+-----------------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|              6.0|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|             2.25|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|              3.3|
|2014|    8|  3|    1120|        0|    1415|        2|     AS| N305AS|   656|   SEA| PHX|     154|    1107|  11|    20|2.566666666666667|
|2014|   11| 12|    2346|       -4

In [25]:
# filter flights with a boolean columns
long_flights2 = flights.filter(flights.distance > 1000)
long_flights2.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+-----------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|     duration_hrs|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+-----------------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|              6.0|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|             2.25|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|              3.3|
|2014|    8|  3|    1120|        0|    1415|        2|     AS| N305AS|   656|   SEA| PHX|     154|    1107|  11|    20|2.566666666666667|
|2014|   11| 12|    2346|       -4

In [27]:
# select the first set of columns as a string
selected_1 = flights.select('tailnum', 'origin', 'dest')
selected_1

DataFrame[tailnum: string, origin: string, dest: string]

In [28]:
# select the second set of columns unsinf df.columns
temp = flights.select(flights.origin, flights.dest, flights.carrier)
temp.show(5)

+------+----+-------+
|origin|dest|carrier|
+------+----+-------+
|   SEA| LAX|     VX|
|   SEA| HNL|     AS|
|   SEA| SFO|     VX|
|   PDX| SJC|     WN|
|   SEA| BUR|     AS|
+------+----+-------+
only showing top 5 rows



In [29]:
# define first filter to only keep flights from SEA to PDX
filterA = flights.origin == 'SEA'
filterB = flights.dest == 'PDX'

In [30]:
selected_2 = temp.filter(filterA).filter(filterB)
selected_2.show(5)

+------+----+-------+
|origin|dest|carrier|
+------+----+-------+
|   SEA| PDX|     OO|
|   SEA| PDX|     OO|
|   SEA| PDX|     OO|
|   SEA| PDX|     OO|
|   SEA| PDX|     OO|
+------+----+-------+
only showing top 5 rows



In [32]:
# Create a table of the average speed of each flight both ways
# Calculate average speed by dividing the distance by the air_time (converted to hours). use the .alias() me thod name
# define avg_speed
avg_speed = (flights.distance / (flights.air_time / 60)).alias('avg_speed')
speed_1 = flights.select('origin', 'dest', 'tailnum', avg_speed)

In [33]:
speed_1.show(5)

+------+----+-------+------------------+
|origin|dest|tailnum|         avg_speed|
+------+----+-------+------------------+
|   SEA| LAX| N846VA| 433.6363636363636|
|   SEA| HNL| N559AS| 446.1666666666667|
|   SEA| SFO| N847VA|367.02702702702703|
|   PDX| SJC| N360SW| 411.3253012048193|
|   SEA| BUR| N612AS| 442.6771653543307|
+------+----+-------+------------------+
only showing top 5 rows



In [34]:
# using the Spark method .selectExpr()
speed_2 = flights.selectExpr('origin', 'dest', 'tailnum', 'distance/(air_time/60) as avg_speed')
speed_2.show(5)

+------+----+-------+------------------+
|origin|dest|tailnum|         avg_speed|
+------+----+-------+------------------+
|   SEA| LAX| N846VA| 433.6363636363636|
|   SEA| HNL| N559AS| 446.1666666666667|
|   SEA| SFO| N847VA|367.02702702702703|
|   PDX| SJC| N360SW| 411.3253012048193|
|   SEA| BUR| N612AS| 442.6771653543307|
+------+----+-------+------------------+
only showing top 5 rows



In [35]:
# arr_time: string and distance: string, so to find min() and max() we need to covert this float
flights = flights.withColumn('distance', flights.distance.cast('float'))
flights = flights.withColumn('air_time', flights.air_time.cast('float'))

flights.describe('air_time', 'distance').show()

+-------+------------------+-----------------+
|summary|          air_time|         distance|
+-------+------------------+-----------------+
|  count|              9925|            10000|
|   mean|152.88423173803525|        1208.1516|
| stddev|  72.8656286392139|656.8599023464376|
|    min|              20.0|             93.0|
|    max|             409.0|           2724.0|
+-------+------------------+-----------------+



In [37]:
# find the length of the shortest (in terms of distance) flight that left PDX
flights.filter(flights.origin == 'PDX').groupBy().min('distance').show()

+-------------+
|min(distance)|
+-------------+
|        106.0|
+-------------+



In [39]:
# find the length of the longest (in terms of time) flight that left SEA
flights.filter(flights.origin == 'SEA').groupBy().max('air_time').show()

+-------------+
|max(air_time)|
+-------------+
|        409.0|
+-------------+



In [41]:
# get the average air time of Delta Airlines flights that left SEA
flights.filter(flights.carrier == 'DL').filter(flights.origin == 'SEA').groupBy().avg('air_time').show()

+------------------+
|     avg(air_time)|
+------------------+
|188.20689655172413|
+------------------+



In [42]:
# get the total number of hours all planes in this dataset spent in the air by creating a column called duration_hrs
flights.withColumn('duration_hrs', flights.air_time / 60).groupBy().sum('duration_hrs').show()

+------------------+
| sum(duration_hrs)|
+------------------+
|25289.600000000126|
+------------------+



In [43]:
# Group by tailnum column
by_plane = flights.groupBy('tailnum')

In [44]:
# use the .count() method with no arguments to count the number of flights each plane made
by_plane.count().show()

+-------+-----+
|tailnum|count|
+-------+-----+
| N442AS|   38|
| N102UW|    2|
| N36472|    4|
| N38451|    4|
| N73283|    4|
| N513UA|    2|
| N954WN|    5|
| N388DA|    3|
| N567AA|    1|
| N516UA|    2|
| N927DN|    1|
| N8322X|    1|
| N466SW|    1|
|  N6700|    1|
| N607AS|   45|
| N622SW|    4|
| N584AS|   31|
| N914WN|    4|
| N654AW|    2|
| N336NW|    1|
+-------+-----+
only showing top 20 rows



In [46]:
# group by origin column
by_origin = flights.groupBy('origin')

In [47]:
# find the .avg() of the air_time column to find average duration of flights from PDX and SEA
by_origin.avg('air_time').show()

+------+------------------+
|origin|     avg(air_time)|
+------+------------------+
|   SEA| 160.4361496051259|
|   PDX|137.11543248288737|
+------+------------------+



In [48]:
# import pyspark.sql.functions as f
# covert to dep_delay to numeric column
flights = flights.withColumn('dep_delay', flights.dep_delay.cast('float'))

# group by month and dest
by_month_dest = flights.groupBy('month', 'dest')

In [49]:
# average departure delay by moth and destination
by_month_dest.avg('dep_delay').show()

+-----+----+--------------------+
|month|dest|      avg(dep_delay)|
+-----+----+--------------------+
|   11| TUS| -2.3333333333333335|
|   11| ANC|   7.529411764705882|
|    1| BUR|               -1.45|
|    1| PDX| -5.6923076923076925|
|    6| SBA|                -2.5|
|    5| LAX|-0.15789473684210525|
|   10| DTW|                 2.6|
|    6| SIT|                -1.0|
|   10| DFW|  18.176470588235293|
|    3| FAI|                -2.2|
|   10| SEA|                -0.8|
|    2| TUS| -0.6666666666666666|
|   12| OGG|  25.181818181818183|
|    9| DFW|   4.066666666666666|
|    5| EWR|               14.25|
|    3| RDM|                -6.2|
|    8| DCA|                 2.6|
|    7| ATL|   4.675675675675675|
|    4| JFK| 0.07142857142857142|
|   10| SNA| -1.1333333333333333|
+-----+----+--------------------+
only showing top 20 rows



In [50]:
airports.show(5)

+---+--------------------+----------+-----------+----+---+---+
|faa|                name|       lat|        lon| alt| tz|dst|
+---+--------------------+----------+-----------+----+---+---+
|04G|   Lansdowne Airport|41.1304722|-80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|32.4605722|-85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|41.9893408|-88.1012428| 801| -6|  A|
|06N|     Randall Airport| 41.431912|-74.3915611| 523| -5|  A|
|09J|Jekyll Island Air...|31.0744722|-81.4277778|  11| -4|  A|
+---+--------------------+----------+-----------+----+---+---+
only showing top 5 rows



In [51]:
# rename the faa column
airports = airports.withColumnRenamed('faa', 'dest')

In [52]:
# Join the dataframes
flights_with_airports = flights.join(airports, on = 'dest', how = 'leftouter')
flights_with_airports.show(5)

+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+------------------+--------------------+---------+-----------+---+---+---+
|dest|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|air_time|distance|hour|minute|      duration_hrs|                name|      lat|        lon|alt| tz|dst|
+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+------------------+--------------------+---------+-----------+---+---+---+
| LAX|2014|   12|  8|     658|     -7.0|     935|       -5|     VX| N846VA|  1780|   SEA|   132.0|   954.0|   6|    58|               2.2|    Los Angeles Intl|33.942536|-118.408075|126| -8|  A|
| HNL|2014|    1| 22|    1040|      5.0|    1505|        5|     AS| N559AS|   851|   SEA|   360.0|  2677.0|  10|    40|               6.0|       Honolulu Intl|21.318681|-157.922428| 13|-10|  N|
| SFO|2014|    3|  9|    1443|

In [53]:
planes = spark.read.csv("./data/airports/planes.csv", header = True)
planes.show(5)

+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
|tailnum|year|                type|    manufacturer|   model|engines|seats|speed|   engine|
+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
| N102UW|1998|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N103US|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N104UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N105UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N107US|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
only showing top 5 rows



In [54]:
# rename year column on panes to avoid duplicate columns name
planes = planes.withColumnRenamed('year', 'plane_year')

In [55]:
# join the flight and plane table use key as tailnum column
model_data = flights.join(planes, on = 'tailnum', how = 'leftouter')
model_data.show(5)

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+------------------+----------+--------------------+------------+--------+-------+-----+-----+---------+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|      duration_hrs|plane_year|                type|manufacturer|   model|engines|seats|speed|   engine|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+------------------+----------+--------------------+------------+--------+-------+-----+-----+---------+
| N846VA|2014|   12|  8|     658|     -7.0|     935|       -5|     VX|  1780|   SEA| LAX|   132.0|   954.0|   6|    58|               2.2|      2011|Fixed wing multi ...|      AIRBUS|A320-214|      2|  182|   NA|Turbo-fan|
| N559AS|2014|    1| 22|    1040|      5.0|    1505|        5|     AS|   851|   SEA| HNL|   360.0|  2677.0| 

In [57]:
model_data.describe().show()

+-------+-------+------+------------------+-----------------+------------------+------------------+------------------+------------------+-------+-----------------+------+-----+------------------+-----------------+------------------+-----------------+------------------+-----------------+--------------------+--------------------+-------------+--------------------+------------------+-----------------+-------------+
|summary|tailnum|  year|             month|              day|          dep_time|         dep_delay|          arr_time|         arr_delay|carrier|           flight|origin| dest|          air_time|         distance|              hour|           minute|      duration_hrs|       plane_year|                type|        manufacturer|        model|             engines|             seats|            speed|       engine|
+-------+-------+------+------------------+-----------------+------------------+------------------+------------------+------------------+-------+-----------------+-----

In [59]:
model_data = model_data.withColumn('arr_delay', model_data.arr_delay.cast('integer'))
model_data = model_data.withColumn('air_time', model_data.air_time.cast('integer'))
model_data = model_data.withColumn('month', model_data.month.cast('integer'))
model_data = model_data.withColumn('plane_year', model_data.plane_year.cast('integer'))

In [60]:
model_data.describe('arr_delay', 'air_time', 'month', 'plane_year').show()

+-------+------------------+------------------+------------------+-----------------+
|summary|         arr_delay|          air_time|             month|       plane_year|
+-------+------------------+------------------+------------------+-----------------+
|  count|              9925|              9925|             10000|             9354|
|   mean|2.2530982367758186|152.88423173803525|            6.6438|2001.594398118452|
| stddev|31.074918600451877|  72.8656286392139|3.3191600205962097|58.92921992728455|
|    min|               -58|                20|                 1|                0|
|    max|               900|               409|                12|             2014|
+-------+------------------+------------------+------------------+-----------------+



In [61]:
# create a new column
model_data = model_data.withColumn('plane_age', model_data.year - model_data.plane_year)

In [62]:
model_data = model_data.withColumn('is_late', model_data.arr_delay > 50)

model_data = model_data.withColumn('label', model_data.is_late.cast('integer'))

model_data.filter(
    "arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL").show(5)

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+------------------+----------+--------------------+------------+--------+-------+-----+-----+---------+---------+-------+-----+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|      duration_hrs|plane_year|                type|manufacturer|   model|engines|seats|speed|   engine|plane_age|is_late|label|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+------------------+----------+--------------------+------------+--------+-------+-----+-----+---------+---------+-------+-----+
| N846VA|2014|   12|  8|     658|     -7.0|     935|       -5|     VX|  1780|   SEA| LAX|     132|   954.0|   6|    58|               2.2|      2011|Fixed wing multi ...|      AIRBUS|A320-214|      2|  182|   NA|Turbo-fan|      3.0|  false|    0|
| N559AS|201

In [78]:
# Create a StringIndexer
carr_indexer = StringIndexer(inputCol = 'carrier', outputCol = 'carrier_index')
# Create a OneHotEncoder
carr_encoder = OneHotEncoder(inputCol = 'carrier_index', outputCol = 'carr_fact')

In [79]:
# encode the dest column just like you did above
dest_indexer = StringIndexer(inputCol = 'dest', outputCol = 'dest_index')
dest_encoder = OneHotEncoder(inputCol = 'dest_index', outputCol = 'dest_fact')

In [80]:
# Assemble a Vector
vec_assembler = VectorAssembler(inputCols = ['month', 'air_time', 'carr_fact', 'dest_fact', 'plane_age'],
                                outputCol = 'features', handleInvalid = 'skip')

In [83]:
# create the pipeline
# you're finally ready to create a pipeline is class in the 'pyspark.ml module' that combines all the Estimators and Transformers that you've already created

flights_pipe = Pipeline(stages = [dest_indexer, dest_encoder, carr_indexer, carr_encoder, vec_assembler])

In [84]:
piped_data = flights_pipe.fit(model_data).transform(model_data)

In [85]:
piped_data.show(5)

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+------------------+----------+--------------------+------------+--------+-------+-----+-----+---------+---------+-------+-----+----------+---------------+-------------+--------------+--------------------+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|      duration_hrs|plane_year|                type|manufacturer|   model|engines|seats|speed|   engine|plane_age|is_late|label|dest_index|      dest_fact|carrier_index|     carr_fact|            features|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+------------------+----------+--------------------+------------+--------+-------+-----+-----+---------+---------+-------+-----+----------+---------------+-------------+--------------+--------------------+
| N846VA|2014|   12|  8|    

In [86]:
training, test = piped_data.randomSplit([0.6, 0.4])

In [87]:
lr = LogisticRegression()

In [89]:
# Create the evaluator
# The first thing you need when doing validation for model selection is a way to compare different models. Luckily, the pyspark.ml.evaluation submodule has classes for evaluating different kinds of model. Your model is a binary classification model, so you'll be using the 'BinaryClassificationEvaluator' from the 'pyspark.ml.evaluation' module. This evaluator calculates the area under the ROC. This is a metric that combines the two kinds of errors a binary classifier can make (false positives and false negatives) into a simple numbers.

evaluator = evals.BinaryClassificationEvaluator(metricName = 'areaUnderROC')

In [90]:
# Create the parameter grid
grid = tune.ParamGridBuilder()

# Add the hyperparameter
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))
grid = grid.addGrid(lr.elasticNetParam, [0, 1])

# Build the grid
grid = grid.build()

In [91]:
# Create the CrossValidator
cv = tune.CrossValidator(estimator = lr, estimatorParamMaps = grid, evaluator = evaluator)

In [92]:
# fit cross validation models
models = cv.fit(training)

In [93]:
# extract the best model
best_lr = models.bestModel

In [94]:
# use the model to predict the test set
test_results = best_lr.transform(test)

# evaluate the predicts
print(evaluator.evaluate(test_results))

0.6382809871234071
