In [0]:
from pyspark.sql.types import IntegerType, FloatType, DateType
import pyspark.sql.functions as F
from pyspark.mllib.linalg import Vectors
from pyspark.ml.param import Param, Params
from pyspark.ml.feature import OneHotEncoder, VectorAssembler, StringIndexer
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from functools import reduce

In [0]:
SEED = 42
DEBUG = True # uses smaller dataset for testing
DISPLAY_LIMIT = 10

In [0]:
airline_df = spark.read.format("delta").load("dbfs:/user/airline/table")

In [0]:
airline_df.printSchema()

root
 |-- FL_DATE: string (nullable = true)
 |-- OP_CARRIER: string (nullable = true)
 |-- OP_CARRIER_FL_NUM: string (nullable = true)
 |-- ORIGIN: string (nullable = true)
 |-- DEST: string (nullable = true)
 |-- CRS_DEP_TIME: string (nullable = true)
 |-- DEP_TIME: string (nullable = true)
 |-- DEP_DELAY: string (nullable = true)
 |-- TAXI_OUT: string (nullable = true)
 |-- WHEELS_OFF: string (nullable = true)
 |-- WHEELS_ON: string (nullable = true)
 |-- TAXI_IN: string (nullable = true)
 |-- CRS_ARR_TIME: string (nullable = true)
 |-- ARR_TIME: string (nullable = true)
 |-- ARR_DELAY: string (nullable = true)
 |-- CANCELLED: string (nullable = true)
 |-- CANCELLATION_CODE: string (nullable = true)
 |-- DIVERTED: string (nullable = true)
 |-- CRS_ELAPSED_TIME: string (nullable = true)
 |-- ACTUAL_ELAPSED_TIME: string (nullable = true)
 |-- AIR_TIME: string (nullable = true)
 |-- DISTANCE: string (nullable = true)
 |-- CARRIER_DELAY: string (nullable = true)
 |-- WEATHER_DELAY: strin

In [0]:
airline_df.count()

Out[5]: 43051239

In [0]:
if DEBUG:
    airline_df = airline_df.limit(400_000)

In [0]:
# TODO: Figure out what to do with nulls in delay columns

In [0]:
# cast columns

def cast_types(airline_df):
     return airline_df.withColumn("FL_DATE", airline_df.FL_DATE.cast(DateType())) \
                   .withColumn("OP_CARRIER_FL_NUM", airline_df.OP_CARRIER_FL_NUM.cast(IntegerType())) \
                   .withColumn("CRS_DEP_TIME", airline_df.CRS_DEP_TIME.cast(IntegerType())) \
                   .withColumn("DEP_TIME", airline_df.DEP_TIME.cast(FloatType())) \
                   .withColumn("DEP_DELAY", airline_df.DEP_DELAY.cast(FloatType())) \
                   .withColumn("TAXI_OUT", airline_df.TAXI_OUT.cast(FloatType())) \
                   .withColumn("WHEELS_OFF", airline_df.WHEELS_OFF.cast(FloatType())) \
                   .withColumn("WHEELS_ON", airline_df.WHEELS_ON.cast(FloatType())) \
                   .withColumn("TAXI_IN", airline_df.TAXI_IN.cast(FloatType())) \
                   .withColumn("CRS_ARR_TIME", airline_df.CRS_ARR_TIME.cast(IntegerType())) \
                   .withColumn("ARR_TIME", airline_df.ARR_TIME.cast(FloatType())) \
                   .withColumn("ARR_DELAY", airline_df.ARR_DELAY.cast(FloatType())) \
                   .withColumn("CANCELLED", airline_df.CANCELLED.cast(FloatType())) \
                   .withColumn("DIVERTED", airline_df.DIVERTED.cast(FloatType())) \
                   .withColumn("CRS_ELAPSED_TIME", airline_df.CRS_ELAPSED_TIME.cast(FloatType())) \
                   .withColumn("ACTUAL_ELAPSED_TIME", airline_df.ACTUAL_ELAPSED_TIME.cast(FloatType())) \
                   .withColumn("AIR_TIME", airline_df.AIR_TIME.cast(FloatType())) \
                   .withColumn("DISTANCE", airline_df.DISTANCE.cast(FloatType())) \
                   .withColumn("CARRIER_DELAY", airline_df.CARRIER_DELAY.cast(FloatType())) \
                   .withColumn("WEATHER_DELAY", airline_df.WEATHER_DELAY.cast(FloatType())) \
                   .withColumn("NAS_DELAY", airline_df.NAS_DELAY.cast(FloatType())) \
                   .withColumn("SECURITY_DELAY", airline_df.SECURITY_DELAY.cast(FloatType())) \
                   .withColumn("LATE_AIRCRAFT_DELAY", airline_df.LATE_AIRCRAFT_DELAY.cast(FloatType()))
    
airline_df = cast_types(airline_df)

In [0]:
print("Count before preprocessing", airline_df.count())

Count before preprocessing 400000


We remove the following columns:
"CARRIER_DELAY", "WEATHER_DELAY", "NAS_DELAY", "SECURITY_DELAY", "LATE_AIRCRAFT_DELAY", "CANCELLATION_CODE", "AIR_TIME", "ACTUAL_ELAPSED_TIME", "ARR_DELAY", "ARR_TIME", "TAXI_IN", "WHEELS_ON", "WHEELS_OFF", "TAXI_OUT", "DEP_DELAY", "DEP_TIME", "DIVERTED"

The reason being that they either have too many null values and/or they would just give away if the flight is cancelled or not, making the prediction with ML models pointless.

For example "ACTUAL_ELAPSED_TIME", "ARR_DELAY", "ARR_TIME", "TAXI_IN", "WHEELS_ON", "WHEELS_OFF", "TAXI_OUT", "DEP_DELAY", "DEP_TIME" are not present for any flights cancelled and are only nulls for a handful (10,000-100,000) of non-cancelled columns, making it very easy to predict wether flight was cancelled without any models.

In a real-world task, it would make sense to predict which flight was going to be cancelled before the actual flight, so we also remove those columns (many of which are removed anyway due to reasons above). For example, "DIVERTED" would not be available before the flight and also, it is not available for any cancelled flights, so we ignore it.

We are left with columns 'FL_DATE', 'OP_CARRIER', 'OP_CARRIER_FL_NUM', 'ORIGIN', 'DEST', 'CRS_DEP_TIME', 'CRS_ARR_TIME', 'CRS_ELAPSED_TIME', 'DISTANCE' (and 'CANCELLED').

Since there are only 38 rows with nulls left, we won't bother imputing them and drop them, since they would have almost no effect on the result, considering the size of the dataset.

In [0]:
remove_cols = ["CARRIER_DELAY", "WEATHER_DELAY", "NAS_DELAY", "SECURITY_DELAY", "LATE_AIRCRAFT_DELAY", "CANCELLATION_CODE", "AIR_TIME", "ACTUAL_ELAPSED_TIME", "ARR_DELAY", "ARR_TIME", "TAXI_IN", "WHEELS_ON", "WHEELS_OFF", "TAXI_OUT", "DEP_DELAY", "DEP_TIME", "DIVERTED"]

def preprocessing(df):
    df = df.drop(*remove_cols)
    return df.filter(reduce(lambda a, b: a & F.col(b).isNotNull(), df.columns, F.lit(True)))


In [0]:
airline_df = preprocessing(airline_df)
print("Count after preprocessing", airline_df.count())

Count after preprocessing 400000


In [0]:
display(airline_df)

FL_DATE,OP_CARRIER,OP_CARRIER_FL_NUM,ORIGIN,DEST,CRS_DEP_TIME,CRS_ARR_TIME,CANCELLED,CRS_ELAPSED_TIME,DISTANCE
2009-01-01,XE,1204,DCA,EWR,1100,1202,0.0,62.0,199.0
2009-01-01,XE,1206,EWR,IAD,1510,1632,0.0,82.0,213.0
2009-01-01,XE,1207,EWR,DCA,1100,1210,0.0,70.0,199.0
2009-01-01,XE,1208,DCA,EWR,1240,1357,0.0,77.0,199.0
2009-01-01,XE,1209,IAD,EWR,1715,1900,0.0,105.0,213.0
2009-01-01,XE,1212,ATL,EWR,1915,2142,0.0,147.0,745.0
2009-01-01,XE,1212,CLE,ATL,1645,1842,0.0,117.0,554.0
2009-01-01,XE,1214,DCA,EWR,1915,2035,0.0,80.0,199.0
2009-01-01,XE,1215,EWR,DCA,1715,1838,0.0,83.0,199.0
2009-01-01,XE,1217,EWR,DCA,1300,1408,0.0,68.0,199.0


In [0]:
def feature_engineering(df):
    return (df
            .withColumn("FL_YEAR", F.year(F.col("FL_DATE")) - 2009)
            .withColumn("FL_MONTH", F.month(F.col("FL_DATE")))
            .withColumn("FL_DAYOFMONTH", F.dayofmonth(F.col("FL_DATE")))
            .withColumn("FL_DAYOFWEEK", F.dayofweek(F.col("FL_DATE")))
           )
    
airline_df = feature_engineering(airline_df)

In [0]:
# added only some for now until it is figured out how to handle nulls
numeric_features = ["OP_CARRIER_FL_NUM", "CRS_DEP_TIME", "CRS_ARR_TIME", "CRS_ELAPSED_TIME", "DISTANCE"]

categorical_features = ["ORIGIN", "DEST", "OP_CARRIER"]

# month, dayofmonth and dayofweek could be numeric or categorical
date_columns = ["FL_MONTH", "FL_DAYOFMONTH", "FL_DAYOFWEEK"]
numeric_features += date_columns # adding to numeric for now

target_col = "CANCELLED"

airline_df = airline_df.withColumn("label", F.col(target_col))

In [0]:
indexers = [
    StringIndexer(inputCol=col, outputCol=f"{col}_INDEXED", handleInvalid="error")
    for col in categorical_features
]

encoders = [
    OneHotEncoder(inputCols=[indexer.getOutputCol()], outputCols=[f"{indexer.getOutputCol()}_ENCODED"], handleInvalid="error")
    for indexer in indexers
]

assembler = VectorAssembler(inputCols=[encoder.getOutputCols()[0] for encoder in encoders] + numeric_features, outputCol="features")

stages = indexers + encoders + [assembler]

In [0]:
# it makes sense to split the data before fitting the pipeline
# the other scenario would not be a good ML practice

(train_df, test_df) = airline_df.randomSplit([0.7, 0.3], seed=SEED)

In [0]:
pipeline = Pipeline(stages=stages)
pipeline_model = pipeline.fit(train_df)

In [0]:
train_df = pipeline_model.transform(train_df)
test_df = pipeline_model.transform(test_df)

In [0]:
# low number of iterations for testing
lr = LogisticRegression(labelCol="label", featuresCol="features", maxIter=3, regParam=0.1)

In [0]:
lr_model = lr.fit(train_df)

In [0]:
predictions = lr_model.transform(test_df)

In [0]:
display(predictions.limit(DISPLAY_LIMIT))

FL_DATE,OP_CARRIER,OP_CARRIER_FL_NUM,ORIGIN,DEST,CRS_DEP_TIME,CRS_ARR_TIME,CANCELLED,CRS_ELAPSED_TIME,DISTANCE,label,FL_YEAR,FL_MONTH,FL_DAYOFMONTH,FL_DAYOFWEEK,ORIGIN_INDEXED,DEST_INDEXED,OP_CARRIER_INDEXED,ORIGIN_INDEXED_ENCODED,DEST_INDEXED_ENCODED,OP_CARRIER_INDEXED_ENCODED,features,rawPrediction,probability,prediction
2009-01-01,9E,2108,OKC,MSP,700,915,0.0,135.0,695.0,0.0,0,1,1,5,62.0,14.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(62), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(14), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(62, 292, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2108.0, 700.0, 915.0, 135.0, 695.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(4.024359072190183, -4.024359072190183))","Map(vectorType -> dense, length -> 2, values -> List(0.982439023309676, 0.01756097669032397))",0.0
2009-01-01,9E,2115,MSP,ALO,2245,2343,0.0,58.0,166.0,0.0,0,1,1,5,14.0,276.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(14), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(276), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(14, 554, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2115.0, 2245.0, 2343.0, 58.0, 166.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(2.392810061631693, -2.392810061631693))","Map(vectorType -> dense, length -> 2, values -> List(0.9162773890585681, 0.08372261094143185))",0.0
2009-01-01,9E,2120,STL,MSP,1610,1800,0.0,110.0,449.0,0.0,0,1,1,5,30.0,14.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(30), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(14), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(30, 292, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2120.0, 1610.0, 1800.0, 110.0, 449.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(3.8932313941736694, -3.8932313941736694))","Map(vectorType -> dense, length -> 2, values -> List(0.9800276387577971, 0.01997236124220292))",0.0
2009-01-01,9E,2122,CLE,MSP,1343,1458,0.0,135.0,622.0,0.0,0,1,1,5,34.0,14.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(34), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(14), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(34, 292, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2122.0, 1343.0, 1458.0, 135.0, 622.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(3.889539879618504, -3.889539879618504))","Map(vectorType -> dense, length -> 2, values -> List(0.9799552548388525, 0.020044745161147515))",0.0
2009-01-01,9E,2125,MSP,CLE,1015,1312,0.0,117.0,622.0,0.0,0,1,1,5,14.0,34.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(14), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(34), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(14, 312, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2125.0, 1015.0, 1312.0, 117.0, 622.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(3.9343738746144963, -3.9343738746144963))","Map(vectorType -> dense, length -> 2, values -> List(0.9808172341880695, 0.019182765811930458))",0.0
2009-01-01,9E,2126,IND,FLL,715,1005,0.0,170.0,1005.0,0.0,0,1,1,5,48.0,27.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(48), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(27), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(48, 305, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2126.0, 715.0, 1005.0, 170.0, 1005.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(4.028455312217938, -4.028455312217938))","Map(vectorType -> dense, length -> 2, values -> List(0.9825095545729328, 0.017490445427067236))",0.0
2009-01-01,9E,2127,FLL,IND,1115,1412,0.0,177.0,1005.0,0.0,0,1,1,5,28.0,47.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(28), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(47), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(28, 325, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2127.0, 1115.0, 1412.0, 177.0, 1005.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(4.089994529449093, -4.089994529449093))","Map(vectorType -> dense, length -> 2, values -> List(0.9835362666537043, 0.016463733346295695))",0.0
2009-01-01,9E,2133,MSP,PIT,1320,1628,0.0,128.0,726.0,0.0,0,1,1,5,14.0,49.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(14), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(49), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(14, 327, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2133.0, 1320.0, 1628.0, 128.0, 726.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(4.002527935697559, -4.002527935697559))","Map(vectorType -> dense, length -> 2, values -> List(0.9820583858597086, 0.017941614140291384))",0.0
2009-01-01,9E,2139,IND,SAT,1655,1835,0.0,160.0,986.0,0.0,0,1,1,5,48.0,44.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(48), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(44), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(48, 322, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2139.0, 1655.0, 1835.0, 160.0, 986.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(4.003173606434933, -4.003173606434933))","Map(vectorType -> dense, length -> 2, values -> List(0.9820697588522994, 0.017930241147700587))",0.0
2009-01-01,9E,2141,DCA,MSN,1105,1218,0.0,133.0,707.0,0.0,0,1,1,5,21.0,83.0,9.0,"Map(vectorType -> sparse, length -> 278, indices -> List(21), values -> List(1.0))","Map(vectorType -> sparse, length -> 278, indices -> List(83), values -> List(1.0))","Map(vectorType -> sparse, length -> 18, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 582, indices -> List(21, 361, 565, 574, 575, 576, 577, 578, 579, 580, 581), values -> List(1.0, 1.0, 1.0, 2141.0, 1105.0, 1218.0, 133.0, 707.0, 1.0, 1.0, 5.0))","Map(vectorType -> dense, length -> 2, values -> List(3.74683156907581, -3.74683156907581))","Map(vectorType -> dense, length -> 2, values -> List(0.9769513930749169, 0.023048606925083148))",0.0


In [0]:
# areaUnderROC
auc_evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction")
auc_evaluator.evaluate(predictions)

Out[51]: 0.7501036830680229

In [0]:
# Accuracy
predictions.filter(F.col("label") == F.col("prediction")).count() / predictions.count()

Out[52]: 0.9812231581050038