## Building a Machine Learning Pipeline with PySpark

model to predict flight lateness based on various features.

In [9]:
# Import necessary modules
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
import pyspark.ml.evaluation as evals
import numpy as np
import pyspark.ml.tuning as tune

In [10]:
# Creating SparkSession
spark = SparkSession.builder.getOrCreate()

# Print spark
print(spark)

<pyspark.sql.session.SparkSession object at 0x000002389E292940>


In [11]:
# Checking for existing tables in the catalog
print(spark.catalog.listTables())

[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


### Data Gathering

In [12]:
# Reading in the flight data
flight_df = spark.read.csv('flights_small.csv', header=True)

# Reading in the airport data
airports_df = spark.read.csv('airports.csv', header=True)

# Reading in the plane data
planes_df = spark.read.csv('planes.csv', header=True)

In [13]:
# Creating temporary tables to support SQL queries
flight_df.createOrReplaceTempView('flights')
airports_df.createOrReplaceTempView('airports')
planes_df.createOrReplaceTempView('planes')

# Checking for existing tables in the catalog
print(spark.catalog.listTables())

[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


### Data Assessment

In [14]:
# Query to select the first 10 rows of the flights table
query = "SELECT * FROM flights LIMIT 10"

# Get the first 10 rows of flights
flights10 = spark.sql(query)

# Show the results
flights10.show()

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
|2014|    1| 15|    1037|        7|    1

In [15]:
# Function to return the shape of a DataFrame
def dataframe_shape(df):
  """
  This function takes a Spark DataFrame and returns its shape as a tuple
  containing the number of rows and columns.

  Args:
      df: The DataFrame to get the shape of.

  Returns:
      str: the number of rows and columns.
  """
  rows = df.count()
  cols = len(df.columns)
  shape = print("Shape of DataFrame: ({}, {})".format(rows, cols))
  return shape

# Call the function with the flight data
dataframe_shape(flight_df)

Shape of DataFrame: (10000, 16)


In [16]:
# Function to count the number of duplicates in a DataFrame
def count_duplicates(df):
  """
  This function counts the number of duplicate rows in a PySpark DataFrame.

  Args:
      df (pyspark.sql.DataFrame): The DataFrame to check for duplicates.

  Returns:
      str: The number of duplicate rows found.
  """
  num_duplicates = df.count() - df.dropDuplicates().count()
  dups = print("Number of duplicate rows:", num_duplicates)
  return dups

# Checking the number of duplicate rows
count_duplicates(flight_df)

Number of duplicate rows: 0


In [17]:
# Show the first 10 rows of the airports DataFrame 
airports_df.show(10)

+---+--------------------+----------+------------+----+---+---+
|faa|                name|       lat|         lon| alt| tz|dst|
+---+--------------------+----------+------------+----+---+---+
|04G|   Lansdowne Airport|41.1304722| -80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|32.4605722| -85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|41.9893408| -88.1012428| 801| -6|  A|
|06N|     Randall Airport| 41.431912| -74.3915611| 523| -5|  A|
|09J|Jekyll Island Air...|31.0744722| -81.4277778|  11| -4|  A|
|0A9|Elizabethton Muni...|36.3712222| -82.1734167|1593| -4|  A|
|0G6|Williams County A...|41.4673056| -84.5067778| 730| -5|  A|
|0G7|Finger Lakes Regi...|42.8835647| -76.7812318| 492| -5|  A|
|0P2|Shoestring Aviati...|39.7948244| -76.6471914|1000| -5|  U|
|0S9|Jefferson County ...|48.0538086|-122.8106436| 108| -8|  A|
+---+--------------------+----------+------------+----+---+---+
only showing top 10 rows



In [18]:
# Show the shape of the airports DataFrame
dataframe_shape(airports_df)

Shape of DataFrame: (1397, 7)


In [19]:
# Check the number of duplicate rows
count_duplicates(airports_df)

Number of duplicate rows: 0


In [20]:
# Show the first 10 rows of the planes DataFrame
planes_df.show(10)

+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
|tailnum|year|                type|    manufacturer|   model|engines|seats|speed|   engine|
+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
| N102UW|1998|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N103US|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N104UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N105UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N107US|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N108UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N109UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N110UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA

In [21]:
# Show the shape of the planes DataFrame
dataframe_shape(planes_df)

Shape of DataFrame: (2628, 9)


In [22]:
# Check the number of duplicate rows
count_duplicates(planes_df)

Number of duplicate rows: 0


### Data Preparation

In [23]:
# Rename year column for ease of joining
planes_df = planes_df.withColumnRenamed("year", "plane_year")

In [24]:
# Join the DataFrames
model_data = flight_df.join(planes_df, on="tailnum", how="leftouter")

In [25]:
# Cast the columns to integers
model_data = model_data.withColumn("arr_delay", model_data.arr_delay.cast("integer"))
model_data = model_data.withColumn("air_time", model_data.air_time.cast("integer"))
model_data = model_data.withColumn("month", model_data.month.cast("integer"))
model_data = model_data.withColumn("plane_year", model_data.plane_year.cast("integer"))

In [26]:
# Create the column plane_age
model_data = model_data.withColumn("plane_age", model_data.year - model_data.plane_year)

In [27]:
# Create is_late
model_data = model_data.withColumn("is_late", model_data.arr_delay > 0)

In [28]:
# Convert is_late to an integer
model_data = model_data.withColumn("label", model_data.is_late.cast("integer"))

In [29]:
# Remove missing values
model_data = model_data.filter("arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL")

### Feature Engineering for Categorical Data

In [30]:
# Create a StringIndexer
carr_indexer = StringIndexer(inputCol="carrier", outputCol="carrier_index")

# Create a OneHotEncoder
carr_encoder = OneHotEncoder(inputCol="carrier_index", outputCol="carrier_fact")

In [31]:
# Create a StringIndexer
dest_indexer = StringIndexer(inputCol= "dest", outputCol= "dest_index")

# Create a OneHotEncoder
dest_encoder = OneHotEncoder(inputCol= "dest_index", outputCol = "dest_fact")

### Building the Machine Learning Pipeline

In [32]:
# Make a VectorAssembler
vec_assembler = VectorAssembler(inputCols= ["month", "air_time", "carrier_fact", "dest_fact", "plane_age"], outputCol= "features")

In [33]:
# Make the pipeline
flights_pipe = Pipeline(stages= [dest_indexer, 
                                 dest_encoder, 
                                 carr_indexer, 
                                 carr_encoder, 
                                 vec_assembler])

In [34]:
# Fit and transform the data
piped_data = flights_pipe.fit(model_data).transform(model_data)

### Training and Evaluation

In [35]:
# Split the data into training and test sets
training, test = piped_data.randomSplit([.6, .4])

In [36]:
# Create a LogisticRegression Estimator
lr = LogisticRegression()

In [37]:
# Create a BinaryClassificationEvaluator
evaluator = evals.BinaryClassificationEvaluator(metricName="areaUnderROC")

In [38]:
# Create the parameter grid
grid = tune.ParamGridBuilder()

# Add the hyperparameter
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))
grid = grid.addGrid(lr.elasticNetParam, [0, 1])

# Build the grid
grid = grid.build()

In [39]:
# Create the CrossValidator
cv = tune.CrossValidator(estimator=lr,
               estimatorParamMaps=grid,
               evaluator=evaluator
               )

In [40]:
# Fit cross validation models
models = cv.fit(training)

# Extract the best model
best_lr = models.bestModel

In [41]:
# Call lr.fit()
best_lr = lr.fit(training)

# Print best_lr
print(best_lr)

LogisticRegressionModel: uid=LogisticRegression_ea8b8b83740b, numClasses=2, numFeatures=81


In [42]:
# Use the model to predict the test set
test_results = best_lr.transform(test)

# Evaluate the predictions
print(evaluator.evaluate(test_results))

0.6884011858595228
