## Imports

In [0]:
from pyspark.sql import functions as f
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType, NullType, ShortType, DateType, BooleanType, BinaryType, TimestampType
from pyspark.sql import SQLContext
from pyspark.sql.functions import trim
import plotly.express as px
import urllib.request
from pyspark.sql import Window
from pyspark.sql.functions import hour
from pyspark.sql.functions import udf, concat, lit, col
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import TrainValidationSplit
from pyspark.ml.classification import GBTClassifier
from pyspark.ml import PipelineModel
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
import pyspark.sql.functions as F
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.types import FloatType
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt


sqlContext = SQLContext(sc)


## Utilities

In [0]:
# FUNCTION TO CHECK IF FILE EXISTS
# From https://forums.databricks.com/questions/20129/how-to-check-file-exists-in-databricks.html
def file_exists(path):
  """Report whether `path` can be listed with dbutils.fs.ls.

  Returns True when the listing succeeds and False when the raised error
  message mentions java.io.FileNotFoundException. Any other exception is
  re-raised unchanged.
  """
  try:
    dbutils.fs.ls(path)
  except Exception as err:
    # Only a "file not found" failure means "does not exist"; anything else is a real error
    if 'java.io.FileNotFoundException' not in str(err):
      raise
    return False
  return True
      
# FUNCTION TO COUNT NULLS/NANS
# MAX: added some type safety
def check_nulls_nans(df):
    """Count missing values per column.

    Returns a single-row DataFrame with one column per input column (same
    names), each holding the number of rows where that column is null or,
    for float/double columns only, NaN.

    `isnan` is restricted to FloatType/DoubleType columns: applied to a
    string column it implicitly casts the value to a number, so a legitimate
    string such as the IATA code 'NAN' would be counted as missing.
    TimestampType columns are skipped, as before, so the output schema is
    unchanged for existing callers.
    """
    missing_counts = []
    for field in df.schema.fields:
        if isinstance(field.dataType, TimestampType):
            continue
        column = f.col(field.name)
        is_missing = column.isNull()
        if isinstance(field.dataType, (FloatType, DoubleType)):
            is_missing = is_missing | f.isnan(column)
        # f.when yields null when the condition is false, and f.count skips nulls
        missing_counts.append(f.count(f.when(is_missing, 1)).alias(field.name))
    null_counts_df = df.select(missing_counts)
    return null_counts_df

## Environment

In [0]:
# ENVIRONMENTAL VARIABLES SETUP
# Run validations: True | False
VALIDATE=False
# Verbose: True | False
VERBOSE=False
# Don't use persisted data on disk, instead recompute everything: True | False
FORCE_COMPUTE_ALL=False
# Persist everything to disk: True | False
PERSIST_ALL=False
# Persist models to disk: True | False
PERSIST_MODELS = False
# Datasets (the value is matched by init_weather_raw and init_airlines_raw):
# 'Toy': Initial dev data
# '2015': All of 2015
# '2015_2017': 2015-2017
# '2015_2019': 2015-2019
DATASETS='2015_2019'

## User Environment

In [0]:
# Enter "MR7" to load models from disk
user_initials = "MR7"
if user_initials == "MR7":
  PERSIST_MODELS = False # ensuring that the golden models stored under MR7 folder are not overwritten!!

# Datasets

## Bronze Datasets 
Bronze datasets will be named as follows:
* airports_raw
* weather_raw
* airlines_raw

In this first section, we will establish Spark DataFrames and temporary SQL views with those names. 

There is no need to persist the bronze tables since they are just views of raw data.

### airports_raw

Source: https://openflights.org/data.html

In [0]:
# FUNCTION TO SCRAPE AIRPORTS DATA
def init_airports_raw():
  """Load the OpenFlights airports.dat file into a DataFrame.

  If the file is not yet present at the DBFS location it is downloaded to a
  local temp file and moved there; otherwise the download is skipped. The
  CSV (no header row) is then read with an explicit schema.

  Returns:
    DataFrame with the 14 airports.dat columns (Airport_ID ... Source).
  """
  # Note: the airports.dat data is very small, so we will always load it in full
  source_url = 'https://raw.githubusercontent.com/jpatokal/openflights/master/data/airports.dat'
  local_tmp_path = '/tmp/airports.dat'
  data_file_name = 'dbfs:/team7/data/airports.dat'
  if not file_exists(data_file_name):
    print(f"Importing the data from {source_url}")
    # Download to the driver's local disk first, then move the file into DBFS
    urllib.request.urlretrieve(source_url, local_tmp_path)
    dbutils.fs.mv(f"file:{local_tmp_path}", data_file_name)
  else:
    print(f"Skipping import: data already exists at {data_file_name}")
  airports_schema = StructType([
    StructField("Airport_ID", IntegerType(), True),
    StructField("Name", StringType(), True),
    StructField("City", StringType(), True),
    StructField("Country", StringType(), True),
    StructField("IATA", StringType(), True),
    StructField("ICAO", StringType(), True),
    StructField("Latitude", DoubleType(), True),
    StructField("Longitude", DoubleType(), True),
    StructField("Altitude", IntegerType(), True),
    StructField("Timezone", DoubleType(), True),
    StructField("DST", StringType(), True),
    StructField("Tz", StringType(), True),
    StructField("Type", StringType(), True),
    StructField("Source", StringType(), True)
  ])
  df = (spark.read.format("csv")
    .option("inferSchema", "false")
    .option("header", "false")
    .option("sep", ",")
    .schema(airports_schema)
    .load(data_file_name))
  return df

In [0]:
# IMPORT AIRPORTS DATA & CREATE A VIEW
airports_raw = init_airports_raw()
airports_raw.createOrReplaceTempView("airports_raw")
print(f"airports_raw: {airports_raw.count()} rows")

In [0]:
# PRINT SCHEMA
if VERBOSE:
  airports_raw.printSchema()

In [0]:
# DISPLAY DATA
if VERBOSE:
  display(airports_raw)

### Long Holidays
(Via https://docs.google.com/spreadsheets/d/1zv-Sydffde6luh_lVTcFidjpjsK7ojjXRciAUelbQr4/edit#gid=749082549)

In [0]:
import requests
from pyspark.sql.functions import lit

# FUNCTION TO SCRAPE Long-Holidays DATA
def init_holidays_raw():
    """Load the long-holidays Google Sheet into a DataFrame.

    If the CSV is not yet present at the DBFS location it is fetched over
    HTTP and written there; otherwise the download is skipped. The CSV
    (with a header row) is read with an explicit schema and a constant
    `holiday` column equal to 1 is appended.

    Returns:
        DataFrame with columns Year, Date, Weekday, Name, Type, holiday.

    Raises:
        AssertionError: if the download does not return HTTP 200.
    """
    # Grab the Google Sheets doc and save it to file
    source_url = 'https://docs.google.com/spreadsheet/ccc?key=1zv-Sydffde6luh_lVTcFidjpjsK7ojjXRciAUelbQr4&output=csv'
    data_file_name = 'dbfs:/team7/data/holidays.csv'
    if not file_exists(data_file_name):
        print(f"Importing the data from {source_url}")
        response = requests.get(source_url)
        assert response.status_code == 200, 'Wrong status code'
        # Third argument True = overwrite an existing file
        dbutils.fs.put(data_file_name, str(response.content, 'utf-8'), True)
    else:
        print(f"Skipping import: data already exists at {data_file_name}")

    holidays_schema = StructType([
        StructField("Year", IntegerType(), True),
        StructField("Date", StringType(), True),
        StructField("Weekday", StringType(), True),
        StructField("Name", StringType(), False),
        StructField("Type", StringType(), False),
        ])

    df = (spark.read.format("csv")
          .option("inferSchema", "false")
          .option("header", "true")
          .option("sep", ",")
          .schema(holidays_schema)
          .load(data_file_name))

    # NOTE(review): Date is kept as the raw string from the sheet; confirm its format
    # matches the flight-date column it is later joined against.
    return df.withColumn("holiday", lit(1))

In [0]:
# SCRAPE Long Holiday DATA
holidays_raw = init_holidays_raw()
holidays_raw.createOrReplaceTempView("holidays_raw")

In [0]:
if VALIDATE:
  display(holidays_raw)

### weather_raw
Source: https://data.nodc.noaa.gov/cgi-bin/iso?id=gov.noaa.ncdc:C00532

In [0]:
# FUNCTION TO IMPORT WEATHER DATA
def init_weather_raw():
  """Load the raw weather parquet data selected by the global DATASETS.

  'Toy' reads the 2015 file and keeps rows whose `date` is between
  2015-01-01 and 2015-04-01; the other values read the matching year(s).

  Returns:
    DataFrame of raw weather records.

  Raises:
    ValueError: if DATASETS is not one of 'Toy', '2015', '2015_2017', '2015_2019'.
  """
  print("Loading weather")
  weather_dir = "dbfs:/mnt/mids-w261/data/datasets_final_project/weather_data"
  weather_files = {
    'Toy': "weather2015a.parquet",
    '2015': "weather2015a.parquet",
    '2015_2017': "weather201[5-7]a.parquet",
    '2015_2019': "weather201[5-9]a.parquet",
  }
  if DATASETS not in weather_files:
    # Previously raised the undefined name RuntimeException, which surfaced as a NameError
    raise ValueError(f"Bad value for DATASETS: {DATASETS!r}")

  df = spark.read.option("header", "true").parquet(f"{weather_dir}/{weather_files[DATASETS]}")
  if DATASETS == 'Toy':
    # extract weather data for the 1st quarter of 2015
    df = df.filter(f.col('date').between("2015-01-01", "2015-04-01"))
  return df


In [0]:
# IMPORT WEATHER DATA & CREATE A VIEW
weather_raw = init_weather_raw()
weather_raw.createOrReplaceTempView("weather_raw")
print(f"weather_raw: {weather_raw.count()} rows")

In [0]:
# PRINT SCHEMA
if VERBOSE:
  weather_raw.printSchema()

In [0]:
# DISPLAY DATA
if VERBOSE:
  display(weather_raw)

In [0]:
# verify quarterly counts
query = f"""
    WITH quarters AS (
    select date_trunc("QUARTER", DATE) as quarter, 
           count(*) as count
    from weather_raw
    group by quarter
    )
    select date_format(quarter, "yyyy") as Year,
           date_format(quarter, "q") as Quarter, 
           count from quarters
    order by Year, Quarter
"""
if VERBOSE:
  df = spark.sql(query)
  display(df)

### airlines_raw
Source:

In [0]:
# FUNCTION TO IMPORT AIRLINES DATA
def init_airlines_raw():
  """Load the raw airlines parquet data selected by the global DATASETS.

  Returns:
    DataFrame of raw flight records.

  Raises:
    ValueError: if DATASETS is not one of 'Toy', '2015', '2015_2017', '2015_2019'.
  """
  data_root = "dbfs:/mnt/mids-w261/data/datasets_final_project"
  airlines_paths = {
    'Toy': "parquet_airlines_data_3m/*.parquet",
    '2015': "parquet_airlines_data/2015.parquet",
    '2015_2017': "parquet_airlines_data/201[5-7].parquet",
    '2015_2019': "parquet_airlines_data/201[5-9].parquet",
  }
  if DATASETS not in airlines_paths:
    # Previously raised the undefined name RuntimeException, which surfaced as a NameError
    raise ValueError(f"Bad value for DATASETS: {DATASETS!r}")

  airlines_raw = spark.read.option("header", "true").parquet(f"{data_root}/{airlines_paths[DATASETS]}")
  return airlines_raw


In [0]:
# IMPORT AIRLINES DATA & CREATE A VIEW
airlines_raw = init_airlines_raw()
airlines_raw.createOrReplaceTempView("airlines_raw")
if VALIDATE:
  print(f"airlines_raw: {airlines_raw.count()} rows")

In [0]:
# PRINT SCHEMA
if VERBOSE:
  airlines_raw.printSchema()

In [0]:
# DISPLAY DATA
if VERBOSE:
  display(airlines_raw)


In [0]:
# validate quarterly counts
query = f"""
    WITH quarters AS (
      select YEAR, QUARTER, count(*) as count
      from airlines_raw
      group by YEAR, QUARTER
    )
    select YEAR as Year,
           QUARTER as Quarter, 
           count from quarters
    order by Year, Quarter
"""
if VALIDATE:
  print(query)
  df = spark.sql(query)
  display(df)


In [0]:
# join with airlines and holidays
def join_longHolidays_toAirlines(airlinesDF, holidaysDF):
    """Flag each flight with whether its date appears in the holidays table.

    Args:
        airlinesDF: flights DataFrame; joined on its flight-date column.
        holidaysDF: holidays DataFrame with `Date` and `holiday` columns.

    Returns:
        Every row of airlinesDF plus a `holiday` column: the value from
        holidaysDF where the date matches, 0 otherwise.
    """
    holiday_dates = (holidaysDF
                     .select("holiday", "Date")
                     .withColumnRenamed("Date", "FL_Date")
                     # One row per date: a date listed more than once would otherwise
                     # multiply every flight on that day in the join below.
                     .dropDuplicates(["FL_Date"]))

    # Right join keeps all flights, including those on non-holiday dates
    flagged = holiday_dates.join(airlinesDF, on="FL_Date", how="right")

    return flagged.fillna({'holiday': 0})
  
airlines_raw = join_longHolidays_toAirlines(airlines_raw,holidays_raw)

In [0]:
if VERBOSE :
  display(airlines_raw)

## Silver Datasets
These have derived, cleaned and regularized data.
* airports_clean
* airlines_clean
* airlines_silver (airlines joined with airports)
* weather_clean
* weather_silver (weather joined with airports)
* silver_df (airlines-airports-weather)

### airports_clean
 - eliminate duplicates
 - select relevant columns
 - evaluate airport codes and time zones for completeness

In [0]:
# check for duplicates
if VALIDATE:
  unique_id = ["IATA", "ICAO", "Tz"]
  w = Window.partitionBy(unique_id)
  display(airports_raw.select('*', f.count("*")\
                      .over(w).alias('dupCount'))\
                      .where('dupCount > 1')\
                      .drop('dupCount'))
# OK => no duplicates 

In [0]:
# DROP DUPLICATES
airports_tmp = airports_raw.dropDuplicates(subset = ["IATA", "ICAO", "Tz"])
airports_tmp.count()

In [0]:
# FILTER & SELECT RELEVANT COLUMNS
airports_tmp = airports_tmp.select("IATA", "ICAO", "Name", "Tz")
airports_tmp.createOrReplaceTempView("airports_tmp")
airports_tmp.count()

In [0]:
# check for nulls/missing values
if VALIDATE:
  df = check_nulls_nans(airports_tmp)
  display(df)
  
# there appears to be 1 IATA with null/NaN

In [0]:
# Investigating IATA = null/Nan...
if VALIDATE:
  display(airports_raw.filter(f.isnan('IATA') | f.isnull('IATA')))

# This is not a US airport. Since we are not going to have flights out of this airport, we can leave it as is.
# What's interesting is that 'NAN' is not a missing-value NaN - it is in fact a legitimate IATA code! It would be interesting to handle if we were predicting delays for flights out of Fiji :-)

In [0]:
# check for outliers
if VALIDATE:
  display(airports_tmp.describe())
  
# there appear to be some \Ns in IATA, ICAO and Tz fields which are invalid values

In [0]:
# Investigating Tz = '\N'...
if VALIDATE:
  tmp = airports_tmp.select("*").where("length(Tz) = 2")
#   display(tmp)
  print(f"airports with invalid Tzs: {tmp.count()}")
  
# there seem to be 1021 airports with Tz = '\N'

In [0]:
# Investigating the impact of Tz = '\N'...
# Do we have any airport origins that would map to IATAs with Tz = '\N'? 
if VALIDATE:
  airports_with_invalid_Tz = airports_tmp.select("*").where("length(Tz) = 2")
  flight_origins = airlines_raw.select("ORIGIN").distinct()
  origins_with_invalid_Tzs = (flight_origins.withColumnRenamed("ORIGIN", "IATA")
                                 .join(airports_with_invalid_Tz, how="inner", on="IATA")
                           )
  print(f"airports with Tz='\\N': {airports_with_invalid_Tz.count()}")
  print(f"distinct airport origins: {flight_origins.count()}")
  print(f"airport origins with invalid Tz: {origins_with_invalid_Tzs.count()}")
  
  # there seem to be no airports that we have in the 3 year data that have Tz = '\N' . Therefore no impact.

In [0]:
# Investigating the impact of Tz = '\N'...
if VALIDATE:
  display(airports_tmp.select("*").where("length(Tz) = 2 and Country = 'United States'"))
  
# All such seem to be local/municipal airports
# ASSUMPTION: we will not be asked to predict out of such airports

In [0]:
# HANDLE Tz = '\N'
# In the event Tz = '\N' or null, let's default to eastern time
airports_tmp = ( airports_tmp.withColumn("Tz2", f.when(airports_tmp.Tz.isNull() | (airports_tmp.Tz == '\\N'), "America/New_York").otherwise(airports_tmp.Tz))
                      .drop('Tz')
                      .withColumnRenamed("Tz2", "Tz")
                     )
if VALIDATE:
  null_Tz_df = airports_tmp.filter(f.isnan('Tz') | f.isnull('Tz'))
  invalid_Tz_df = airports_tmp.select("*").where("length(Tz) = 2")
  print(f"Null Tzs: {null_Tz_df.count()}")
  print(f"Invalid Tzs: {invalid_Tz_df.count()}")

In [0]:
#Investigating IATAs = '\N'...
if VALIDATE:
  tmp = airports_tmp.select("*").where("length(IATA) != 3")
  print(f"airports with invalid IATA: {tmp.count()}")

In [0]:
#Investigating IATAs = '\N'...
if VALIDATE:
  tmp = airports_tmp.select("*").where("length(IATA) != 3")
  display(tmp)
  
# these seem to all be municipal/private/international airports with no official IATA designation from the FAA.
# there is no way to "create" any designations for these
# investigate further to see if there are any flights from origin airports such as these in the airlines data

In [0]:
# Investigating IATA = '\N' further...
# Checking to see if there are any origin airports that cannot be mapped to IATA's in airports.dat that fall under a similar category as above
if VALIDATE:
  origin_airports = airlines_raw.select("ORIGIN").distinct()
  airports_IATAs = airports_tmp.select("IATA").distinct()
  join_df = origin_airports.withColumnRenamed("ORIGIN", "IATA").join(airports_IATAs, how="left", on="IATA")
  display(join_df.select("IATA").where("length(IATA) != 3"))
  

# there appear to be no flights out of airports with IATA='\N', therefore no need to handle.
# ASSUMPTION: we will not be asked to predict delays for such airports


In [0]:
#Investigating ICAO = '\N'...
if VALIDATE:
  tmp = airports_tmp.select("*").where("length(ICAO) == 2")
  display(tmp)

  # in Brazil. No need to handle.

In [0]:
# DROP IATA='\N' & ICAO= '\N'
# The length filters remove any code of 2 characters or fewer (which covers the '\N' placeholder).
# The last filter additionally removes the airport whose IATA code is the literal string 'NAN'.
airports_tmp = airports_tmp.filter("length(IATA) > 2")
airports_tmp = airports_tmp.filter("length(ICAO) > 2")
airports_tmp = airports_tmp.filter("IATA != 'NAN'")
airports_tmp.count()

In [0]:
# CREATE AIRPORTS CLEAN
airports_clean = airports_tmp
airports_clean.createOrReplaceTempView("airports_clean")
print(f"airports_clean: {airports_clean.count()} rows")

In [0]:
# VALIDATE AIRPORTS CLEAN
# check for outliers
if VALIDATE:
  display(airports_clean.describe())

In [0]:
# check for nulls/missing values
if VALIDATE:
  df = check_nulls_nans(airports_clean)
  display(df)

In [0]:
# DISPLAY DATA
if VERBOSE:
  display(airports_clean)

### airlines_clean
 - eliminate duplicates
 - select relevant columns
 - evaluate and handle nulls

In [0]:
# check for duplicates
if VALIDATE:
  unique_id = ["ORIGIN", "DEST", "OP_UNIQUE_CARRIER", "OP_CARRIER_FL_NUM", "FL_DATE"]
  w = Window.partitionBy(unique_id)
  duplicate_flights = airlines_raw.select('*', f.count("*")\
                      .over(w).alias('dupCount'))\
                      .where('dupCount > 1')\
                      .drop('dupCount')
  display(duplicate_flights)
  
# OK => no duplicates  
# if any are listed, drop them

In [0]:
# ELIMINATE DUPLICATES/MULTIPLES
# Count the rows sharing each (ORIGIN, DEST, OP_UNIQUE_CARRIER, OP_CARRIER_FL_NUM, FL_DATE) key
# and keep only rows whose key occurs exactly once.
# NOTE(review): 'dupCount = 1' removes every row of a repeated key (no representative row is kept) - confirm this is intended.
unique_id = ["ORIGIN", "DEST", "OP_UNIQUE_CARRIER", "OP_CARRIER_FL_NUM", "FL_DATE"]
w = Window.partitionBy(unique_id)
airlines_tmp = airlines_raw.select('*', f.count("*")\
                      .over(w).alias('dupCount'))\
                      .where('dupCount = 1')\
                      .drop('dupCount')
airlines_tmp.createOrReplaceTempView("airlines_tmp")
airlines_tmp.count()


In [0]:
# SELECT RELEVANT COLUMNS
airline_relevant_columns = ['YEAR', 'QUARTER', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 
                            'OP_UNIQUE_CARRIER', 'OP_CARRIER_AIRLINE_ID', 'OP_CARRIER', 'TAIL_NUM', 'OP_CARRIER_FL_NUM', 
                            'ORIGIN', 'ORIGIN_CITY_NAME', 'ORIGIN_STATE_ABR',
                            'DEST', 'DEST_CITY_NAME', 'DEST_STATE_ABR',
                            'CRS_DEP_TIME', 'DEP_TIME', 'DEP_DELAY', 'DEP_DEL15', 'DEP_DELAY_GROUP', 'DEP_TIME_BLK', 
                            'TAXI_OUT', 'WHEELS_OFF', 'WHEELS_ON', 'TAXI_IN', 
                            'CRS_ARR_TIME', 'ARR_TIME', 'ARR_DELAY', 'ARR_DELAY_NEW', 'ARR_DEL15', 'ARR_DELAY_GROUP', 'ARR_TIME_BLK', 'AIR_TIME',
                            'CANCELLED', 'CANCELLATION_CODE', 'DIVERTED', 
                            'CRS_ELAPSED_TIME', 
                            'FLIGHTS', 'holiday',
                            'DISTANCE', 'DISTANCE_GROUP', 
                            'CARRIER_DELAY', 'WEATHER_DELAY', 'NAS_DELAY', 'SECURITY_DELAY', 'LATE_AIRCRAFT_DELAY']


airlines_tmp = airlines_tmp.select(airline_relevant_columns) 
airlines_tmp.count()

In [0]:
# check for outliers
if VALIDATE:
  display(airlines_tmp.describe())
  
# there appear to be no outliers in the "primary" fields we care about
# TAXI_IN, TAXI_OUT, WHEEL_ON, WHEELS_OFF have figures a few 100s of minutes - seems fishy - careful if/when using as feature.
# different counts for different fields => missing or null values

In [0]:
# VALIDATE AIRLINES DATA
# check for nulls/missing values
if VALIDATE:
  df = check_nulls_nans(airlines_tmp)
  display(df)
  
# lot of fields with nulls and nans

In [0]:
# CREATE AIRLINES CLEAN
airlines_clean = airlines_tmp
airlines_clean.createOrReplaceTempView("airlines_clean")
if VALIDATE:
  print(f"airlines_clean: {airlines_clean.count()} rows") 

In [0]:
# VALIDATE AIRLINES CLEAN
# check for outliers
if VALIDATE:
  display(airlines_clean.describe())


In [0]:
# check for nulls/missing values
if VALIDATE:
  df = check_nulls_nans(airlines_clean)
  display(df)
  
# Should be all zeroes as all nulls have been accounted for

In [0]:
# print schema
if VERBOSE:
  airlines_clean.printSchema()

In [0]:
# DISPLAY DATA
if VERBOSE:
  display(airlines_clean)

### airlines_silver (airlines join airports)
 - join airlines data with airports data, such that every airport has an associated ICAO code

In [0]:
# before the join, ensure that every ORIGIN in airlines_clean has a corresponding IATA, ICAO and Tz in airports_clean
# missing values imply that:
# 1 - there is no entry in airports.dat for ORIGIN airport, or
# 2 - the ICAO is coded incorrectly for the airport (should not be the case if airports_clean is really clean)
if VALIDATE:
  origins = airlines_clean.select("ORIGIN").distinct()
  destinations = airlines_clean.select("DEST").distinct()
  airports_iata = airports_clean.select("IATA").distinct()
  
  print(f"flight origins: {origins.count()}")
  print(f"flight destinations: {destinations.count()}")
  print(f"airports with IATA: {airports_iata.count()}")
  print(f"airports: {airports_clean.count()}")

In [0]:
# Investigating further...
# ORIGINs in airlines_clean with no matching IATA in airports_clean
if VALIDATE:
  missing = (airlines_clean.select("ORIGIN").distinct()).subtract(airports_clean.select("IATA").distinct())
  print(f"Origin airports with no matching IATAs: {missing.count()}")

In [0]:
# Investigating further...
# list of missing IATAs
if VALIDATE:
  flights = airlines_clean.join(missing, on="ORIGIN", how="inner")
  display(flights.select("ORIGIN", "ORIGIN_CITY_NAME").distinct())

In [0]:
# Investigating further...
# flights with origins that have no matching IATA
if VALIDATE:
  flights = airlines_clean.join(missing, on="ORIGIN", how="inner")
  print(f"Flights with no matching IATA in airports_clean: {flights.count()}")


In [0]:
# FUNCTION TO JOIN AIRLINES AND AIRPORTS DATA
# ensure LEFT JOIN between airlines and airport codes
# we want to make sure every airport is associated with an IATA from airports

def _local_hhmm_to_utc(df, hhmm_col, tz_col):
  """Build a UTC timestamp column from FL_DATE plus an hhmm-encoded local time column."""
  hhmm = df[hhmm_col]
  local_time = f.format_string("%s %02d:%02d",
                               df["FL_DATE"],
                               (hhmm / 100).cast(IntegerType()),
                               hhmm % 100)
  return f.to_utc_timestamp(local_time, df[tz_col])


def make_airlines_join_airports(airlinesDF, airportsDF):
  """Attach airport codes/time zones to flights and derive UTC times and hour buckets.

  Args:
    airlinesDF: flights DataFrame with ORIGIN, DEST, FL_DATE, CRS_DEP_TIME,
      CRS_ARR_TIME, DEP_TIME and ARR_TIME columns.
    airportsDF: airports DataFrame keyed by IATA, providing ICAO, Tz and Name.

  Returns:
    airlinesDF (left-joined twice, so no flight is dropped) with added columns:
    origin_icao, origin_time_zone, airport_name (origin), dest_icao,
    dest_time_zone, Name (destination airport), the four *_utc timestamps,
    CRS_DEP_HOUR / CRS_ARR_HOUR, and the *_time_bucket strings.
  """
  # ICAO and time zone for all origin airports.
  # The "Timezone" rename only takes effect if airportsDF carries that column.
  df1 = ( airlinesDF.withColumnRenamed("ORIGIN", "IATA")
        .join(airportsDF, on="IATA", how="left")
        .withColumnRenamed("IATA", "ORIGIN")
        .withColumnRenamed("ICAO", "origin_icao")
        .withColumnRenamed("Tz", "origin_time_zone")
        .withColumnRenamed("Timezone", "origin_time_zone_offset")
        .withColumnRenamed("Name", "airport_name")
     )

  # ICAO and time zone for all destination airports
  df2 = ( df1.withColumnRenamed("DEST", "IATA")
        .join(airportsDF, on="IATA", how="left")
        .withColumnRenamed("IATA", "DEST")
        .withColumnRenamed("ICAO", "dest_icao")
        .withColumnRenamed("Tz", "dest_time_zone")
        .withColumnRenamed("Timezone", "dest_time_zone_offset")
     )

  # Convert local times to UTC. Departure times are local to the origin airport and
  # arrival times are local to the destination airport, so each uses its own zone
  # (arrivals were previously converted with the origin zone).
  # NOTE(review): arrival timestamps are built from FL_DATE; a flight landing after
  # midnight would presumably need the next calendar day - confirm before relying on them.
  df3 = ( df2
        .withColumn("crs_dep_time_utc", _local_hhmm_to_utc(df2, "CRS_DEP_TIME", "origin_time_zone"))
        .withColumn("crs_arr_time_utc", _local_hhmm_to_utc(df2, "CRS_ARR_TIME", "dest_time_zone"))
        .withColumn("dep_time_utc", _local_hhmm_to_utc(df2, "DEP_TIME", "origin_time_zone"))
        .withColumn("arr_time_utc", _local_hhmm_to_utc(df2, "ARR_TIME", "dest_time_zone"))
        )

  bucket_format = "yyyy-MM-dd HH:mm"

  # add time buckets to all other time related fields
  # Max: time buckets are only used for aggregation and join - they are not features. so we don't need to add buckets for all times
  df4 = ( df3
        .withColumn("CRS_DEP_HOUR", (df3.CRS_DEP_TIME / 100).cast(IntegerType()))  # possible feature
        .withColumn("CRS_ARR_HOUR", (df3.CRS_ARR_TIME / 100).cast(IntegerType()))  # possible feature
        # MAX: sch_dep_time_bucket could be used to join aggregates about scheduled traffic
        .withColumn("sch_dep_time_bucket", f.date_format(f.date_trunc("Hour", df3.crs_dep_time_utc), bucket_format))
        .withColumn("sch_arr_time_bucket", f.date_format(f.date_trunc("Hour", df3.crs_arr_time_utc), bucket_format))
        .withColumn("act_dep_time_bucket", f.date_format(f.date_trunc("Hour", df3.dep_time_utc), bucket_format))
        .withColumn("act_arr_time_bucket", f.date_format(f.date_trunc("Hour", df3.arr_time_utc), bucket_format))
        # MAX: pred_time_bucket is used to join weather reports and aggregated previous delays
        # (hour bucket of the scheduled departure minus 2 hours)
        .withColumn("pred_time_bucket", f.date_format(f.date_trunc("Hour", df3.crs_dep_time_utc - f.expr("INTERVAL 2 HOUR")), bucket_format))
       )

  return df4
      

In [0]:
# CREATE AIRLINES_SILVER
airlines_silver = make_airlines_join_airports(airlines_clean, airports_clean)
airlines_silver.createOrReplaceTempView("airlines_silver")
if VALIDATE:
  print(f"airlines_silver: {airlines_silver.count()} rows")

In [0]:
# VALIDATE AIRLINES_SILVER
# check for duplicates
if VALIDATE:
  unique_id = ["ORIGIN", "DEST", "OP_UNIQUE_CARRIER", "OP_CARRIER_FL_NUM", "FL_DATE"]
  w = Window.partitionBy(unique_id)
  display(airlines_silver.select('*', f.count("*")\
                         .over(w).alias('dupCount'))\
                         .where('dupCount > 1')\
                         .drop('dupCount'))
# OK => no duplicates  

In [0]:
# Evaluate whether every origin airport has a valid ICAO and time zone
if VALIDATE:
  # Expressions nested inside an `if` block are not echoed as cell output,
  # so the counts are printed explicitly instead of being discarded.
  missing_icao_count = airlines_silver.where("origin_icao is null").count()
  missing_tz_count = airlines_silver.where("origin_time_zone is null").count()
  print(f"flights with null origin_icao: {missing_icao_count}")
  print(f"flights with null origin_time_zone: {missing_tz_count}")

# counts must be zero

In [0]:
# check for nulls/nans
if VALIDATE:
  display(check_nulls_nans(airlines_silver))

In [0]:
# cross-check number of records with airlines_clean
if VALIDATE:
  silver_count = airlines_silver.count()  
  clean_count = airlines_clean.count()
  print(f"silver_count {silver_count}")
  print(f"clean_count {clean_count}")  
  assert(silver_count == clean_count)

In [0]:
# check outliers
# Ensure no missing values - everything should be accounted for
if VERBOSE:
  display(airlines_silver.describe())

In [0]:
# print schema
if VERBOSE:
  airlines_silver.printSchema()

In [0]:
# display data
if VERBOSE:
  display(airlines_silver)

### weather_clean
- filter for US
- split fields where required
- convert dates to UTC (are they already in?) 
- add time bucket
- handle nulls
- compact multiple rows for the same time bucket
- interpolate missing measurements

In [0]:
# FILTER WEATHER DATA
# US only 
weather_us = weather_raw.filter((f.col('report_type') == 'FM-15') & (f.col('call_sign') != '99999'))
weather_us.createOrReplaceTempView("weather_us")
if VALIDATE:
  print(f"weather_us: {weather_us.count()} rows")

In [0]:
# verify counts of valid US stations (must equal ~2166)
if VALIDATE:
  display(weather_us.select(f.countDistinct("call_sign")))

In [0]:
# TRANSFORMATIONS
# pre-split wind and temperature fields and handle missing values
# apply scale refactoring as noted in ISD documentation
# add time bucket 
# rename columns
# collapse multiple measurements within a time bucket to a single one

def make_weather_clean(weatherDF):
  """Parse the packed weather fields and keep one report per station per hour.

  Splits the comma-separated WND, TMP, DEW, CIG and VIS strings, scales wind
  speed, temperature and dew point by 1/10, adds 0/1 "missing" flags, adds
  an hourly `time_bucket`, and keeps the latest report in each
  (CALL_SIGN, time_bucket) group.

  Args:
    weatherDF: weather DataFrame with CALL_SIGN, DATE, WND, TMP, DEW, CIG,
      VIS, LATITUDE, LONGITUDE and ELEVATION columns.

  Returns:
    DataFrame with columns CALL_SIGN, DATE, windangle, WND, TMP, DEW,
    miss_windAngle, miss_wnd, miss_tmp, miss_dew, CIG, VIS, time_bucket,
    LATITUDE, LONGITUDE, ELEVATION. Sentinel readings are still scaled into
    the value columns (e.g. a "+9999" temperature becomes 999.9); the miss_*
    flags mark those rows.
  """
  wind_split = f.split(weatherDF.WND, ',')
  tmp_split = f.split(weatherDF.TMP, ',')
  dew_split = f.split(weatherDF.DEW, ',')

  # WND items: 0 = direction angle (3 chars), 3 = speed rate (4 chars)
  wind_angle = wind_split.getItem(0)
  wind_speed = wind_split.getItem(3)
  # The speed field is 4 characters wide, so its missing sentinel is "9999".
  # The earlier comparison against "999" could never match, leaving both wind flags at 0.
  wind_speed_missing = wind_speed == "9999"

  # Add one-hot encoded flags to indicate missing entries for wind, windangle, tmp and dew
  # Wind angle is the exception, only considered missing if windspeed was also missing
  weather_tmp = ( weatherDF
                .withColumn("CALL_SIGN", trim(weatherDF.CALL_SIGN))
                .withColumn("miss_windAngle", f.when((wind_angle == "999") & wind_speed_missing, 1).otherwise(0))
                # zero wind speed => report angle 0 regardless of the angle field
                .withColumn("windangle", f.when(wind_speed == "0000", 0.)
                                          .otherwise(wind_angle.cast("float")) )
                .withColumn("miss_wnd", f.when(wind_speed_missing, 1).otherwise(0))
                .withColumn("WND", wind_speed.cast("float") / 10.)
                .withColumn("miss_tmp", f.when(tmp_split.getItem(0) == "+9999", 1).otherwise(0))
                .withColumn("TMP", tmp_split.getItem(0).cast("float") / 10.)
                .withColumn("miss_dew", f.when(dew_split.getItem(0) == "+9999", 1).otherwise(0))
                .withColumn("DEW", dew_split.getItem(0).cast("float") / 10.)
                .withColumn("CIG", f.split(weatherDF.CIG, ',').getItem(0).cast("float"))
                .withColumn("VIS", f.split(weatherDF.VIS, ',').getItem(0).cast("float"))
                .withColumn("time_bucket", f.date_format(f.date_trunc("hour", "date"), "yyyy-MM-dd HH:mm"))
               )

  # collapse multiple weather rows for a given timebucket into a single one (the latest).
  # row_number (not dense_rank) so that reports tied on DATE still yield exactly one row;
  # which of the tied rows survives is arbitrary.
  unique_id = ["CALL_SIGN", "time_bucket"]
  w = Window.partitionBy(unique_id).orderBy(f.desc('DATE'))
  weather_tmp = ( weather_tmp
                .withColumn('Rank', f.row_number().over(w))
                .filter(f.col('Rank') == 1)
                .drop('Rank')
               )
  weather_tmp = weather_tmp.select("CALL_SIGN", "DATE", "windangle", "WND", "TMP", "DEW",
                                   "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew",
                                   "CIG", "VIS", "time_bucket", "LATITUDE", "LONGITUDE", "ELEVATION" )
  return weather_tmp




In [0]:
weather_tmp = make_weather_clean(weather_us)
weather_tmp.createOrReplaceTempView("weather_tmp")
if VALIDATE:
  print(f"weather_clean: {weather_tmp.count()} rows")

In [0]:
# VALIDATE TRANSFORMATIONS
# check for existence of multiple measurements with same time bucket - there should be none
if VALIDATE:
  unique_id = ["CALL_SIGN", "time_bucket"]
  w = Window.partitionBy(unique_id)
  display(weather_tmp.select('*', f.count("*")\
                      .over(w).alias('dupCount'))\
                      .where('dupCount > 1')\
                      .drop('dupCount'))
  
  # OK => no duplicates or multiples => no need to collapse

In [0]:
# check if there are any weather stations missing weather data that we care about i.e. CALL_SIGN, time_bucket, WND, windangle, CIG, VIS, TMP, DEW
if VALIDATE:
  df_tmp = weather_tmp.select('CALL_SIGN', 'time_bucket', 'WND', 'windangle', 'CIG', 'VIS', 'TMP', 'DEW',
                              "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew", "LATITUDE", "LONGITUDE", "ELEVATION")
  df = check_nulls_nans(df_tmp)
  display(df)
  
# All zeroes mean no nulls/nans => no need to correct for missing values

In [0]:
# CREATE WEATHER_CLEAN
weather_clean = weather_tmp
weather_clean.createOrReplaceTempView("weather_clean")
if VALIDATE:
  print(f"weather_clean: {weather_clean.count()} rows")

In [0]:
# VALIDATE WEATHER CLEAN
# check for outliers
if VERBOSE:
  display(weather_clean.describe())

In [0]:
# check for nulls/nans
if VALIDATE:
  df_tmp = weather_clean.select('CALL_SIGN', 'time_bucket', 'WND', 'windangle', 'CIG', 'VIS', 'TMP', 'DEW', 
                                "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew", "LATITUDE", "LONGITUDE", "ELEVATION")
  df = check_nulls_nans(df_tmp)
  display(df)

In [0]:
# DISPLAY DATA
if VERBOSE:
  display(weather_clean)

### weather_silver (airports join weather)
- join weather data with airport data, such that every ICAO in the airports data is associated with a CALL_SIGN in weather data

In [0]:
# before joining, ensure no nulls in airports_clean
# check for nulls/nans
if VALIDATE:
  display(check_nulls_nans(airports_clean))

In [0]:
# check to make sure there are no airports with ICAO = '\N'
display(airports_clean.select("*").where("length(ICAO) = 2"))


IATA,ICAO,Name,Tz


In [0]:
# before joining, ensure no nulls in weather_clean
# check for nulls/nans
if VALIDATE:
  display(check_nulls_nans(weather_clean))

In [0]:
# FUNCTION TO JOIN AIRPORTS WITH WEATHER
# ensure RIGHT JOIN for weather.join(airports)

def make_weather_silver(weatherDF, airportsDF):
    """Join hourly weather observations onto the airports table.

    The weather station identifier (CALL_SIGN) is renamed to ICAO so it can be
    used as the join key, and the raw weather columns get lower-case names.

    Args:
        weatherDF: weather DataFrame providing CALL_SIGN, time_bucket, WND,
            windangle, CIG, VIS, TMP, DEW, the miss_* columns and
            LATITUDE / LONGITUDE / ELEVATION.
        airportsDF: airports DataFrame with an ICAO column.

    Returns:
        The RIGHT join of the selected weather columns with ``airportsDF`` on
        ICAO: every airport row is kept (with nulls in the weather columns when
        no station matches) and weather rows whose ICAO is not in
        ``airportsDF`` are dropped.
    """
    # only these columns are carried forward, to minimise the working data
    weather_columns = ["time_bucket", "ICAO", "wind", "windangle", "cig", "vis", "tmp", "dew",
                       "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew",
                       "LATITUDE", "LONGITUDE", "ELEVATION"]

    # rename CALL_SIGN to ICAO to facilitate the join; the former renames of
    # DATE/NAME (dropped by the select) and the self-renames of
    # LATITUDE/LONGITUDE/ELEVATION had no effect on the result and were removed
    weather_renamed = (weatherDF.withColumnRenamed("CALL_SIGN", "ICAO")
                                .withColumnRenamed("WND", "wind")
                                .withColumnRenamed("CIG", "cig")
                                .withColumnRenamed("VIS", "vis")
                                .withColumnRenamed("TMP", "tmp")
                                .withColumnRenamed("DEW", "dew"))

    # right join: keep every airport, eliminate orphan weather records
    df = weather_renamed.select(*weather_columns).join(airportsDF, on="ICAO", how="right")
    return df

In [0]:
# CREATE WEATHER_SILVER
# join weather_clean onto airports_clean and expose the result as the temp view "weather_silver"
weather_silver = make_weather_silver(weather_clean, airports_clean)
weather_silver.createOrReplaceTempView("weather_silver")
if VALIDATE:
  # count() triggers a Spark job, so it only runs when validating
  print(f"weather_silver: {weather_silver.count()} rows")


In [0]:
# check for duplicates
# A row is a duplicate when more than one row shares the same (ICAO, time_bucket) pair;
# the window count tags every row with the size of its group and only groups larger than 1 are shown.
if VALIDATE:
  unique_id = ["ICAO", "time_bucket"]
  w = Window.partitionBy(unique_id)
  display(weather_silver.select('*', f.count("*")\
                        .over(w).alias('dupCount'))\
                        .where('dupCount > 1')\
                        .drop('dupCount'))
# OK => no duplicates (the display above is expected to be empty)


In [0]:
# check for nulls/missing values - there should be none - should have all been handled by this point
# (per-column null/NaN counts over every non-timestamp column of weather_silver)
if VALIDATE:
  df = check_nulls_nans(weather_silver)
  display(df)  

# all weather related fields must be zeroes
# if weather related fields are non-zero, these imply airports with no weather stations
# (weather_silver is built with a right join onto the airports, so unmatched airports carry nulls)

In [0]:
# DESCRIPTIVE STATISTICS
# summary statistics of every weather_silver column; this is a full pass over the data
if VALIDATE:
  # 2.63 minutes for 2015_2017
  display(weather_silver.describe())

In [0]:
# PRINT SCHEMA
# column names and types of weather_silver (VERBOSE only)
if VERBOSE:
  weather_silver.printSchema()

In [0]:
# DISPLAY DATA
# render weather_silver in the notebook for manual inspection (VERBOSE only)
if VERBOSE:
  display(weather_silver)

## Silver (joined) Dataset

airlines_silver joined with weather_silver

In [0]:
# FUNCTION TO JOIN AIRLINES AND WEATHER 
def make_airlines_join_weather(airlinesDF, weatherDF):
  """Left-join weather onto the airline rows by origin airport and prediction hour.

  The airline keys (origin_icao, pred_time_bucket) are temporarily renamed to the
  weather-side key names (ICAO, time_bucket) for the join and renamed back afterwards.
  Being a left join, every airline row is kept.
  """
  join_columns = ["ICAO","time_bucket"]

  # give the airline keys the same names as the weather keys
  airlines_keyed = airlinesDF.withColumnRenamed("origin_icao", "ICAO")
  airlines_keyed = airlines_keyed.withColumnRenamed("pred_time_bucket", "time_bucket")

  joined = airlines_keyed.join(weatherDF, join_columns, how="left")

  # restore the airline-side key names
  joined = joined.withColumnRenamed("time_bucket", "pred_time_bucket")
  return joined.withColumnRenamed("ICAO", "origin_icao")



In [0]:
# (Madhukar) Identify common columns that need to be dropped after joining
# NOTE: columns_to_drop is only printed in this cell for information;
# the drop below removes just the weather-side Name column.
columns_to_drop = list(set(airlines_silver.columns).intersection(set(weather_silver.columns)))
print("Drop these duplicated columns", columns_to_drop)

# CREATE SILVER DATA
# airlines left-joined with weather, exposed as the temp view "silver_df"
silver_df = make_airlines_join_weather(airlines_silver, weather_silver).drop(weather_silver.Name)
silver_df.createOrReplaceTempView("silver_df")
if VALIDATE:
  # the join must neither add nor lose airline rows
  airlines_silver_count = airlines_silver.count()  
  silver_df_count = silver_df.count()
  print(f"airlines_silver_count {airlines_silver_count}")
  print(f"silver_df_count {silver_df_count}")  
  assert(airlines_silver_count == silver_df_count)

In [0]:
# Compare the row counts of the three temp views side by side (one result row per view).
if VALIDATE:
  query = f"""
select count(*) as count, "silver_df" as table from silver_df
union
select count(*), "airlines_silver" from airlines_silver
union
select count(*), "weather_silver" from weather_silver  
"""
  if VERBOSE:
    # show the SQL text before running it
    print(query)
  df = spark.sql(query)
  display(df)

In [0]:
# DESCRIPTIVE STATISTICS - all columns must have equal number of values
# (a lower count in describe() for a column means that column has missing values)
if VERBOSE:
  display(silver_df.describe())

In [0]:
# check for nulls/missing values - there should be none - should have all been handled by this point
# (per-column null/NaN counts over every non-timestamp column of silver_df)
if VALIDATE:
  df = check_nulls_nans(silver_df)
  display(df)  


In [0]:
# PRINT SCHEMA
# column names and types of silver_df (VERBOSE only)
if VERBOSE:
  silver_df.printSchema()

In [0]:
# DISPLAY DATA
# render silver_df in the notebook for manual inspection (VERBOSE only)
if VERBOSE:
  display(silver_df)

## Gold Datasets
These are the datasets that are ready for training our models
* Gold_Train_Validate (2015-2017)
* Gold_Test (2018)
* once everything has been tested and is successful, we bring in 2019 via the same pipeline

# Enhanced Feature Engineering

## 1. Departure Delay aggregates

In [0]:
def compute_airport_aggregated_delay(verbose=True):
  """
  Create a table that summarises the delays of flights that
  departed in a given hour, per origin airport.

  input: the temp view silver_df, queried through the global spark session
         and grouped by ORIGIN and act_dep_time_bucket

  Args:
    verbose: when True, print the SQL text before running it.

  Returns:
    DataFrame with one row per (ORIGIN, time_bucket) holding the hourly delay aggregates:
    prior_count_DEP_TIME, 
    prior_avg_DEP_DELAY, 
    prior_agg_DEP_DEL15, 
    prior_avg_TAXI_OUT, 
    prior_avg_CARRIER_DELAY, 
    prior_avg_WEATHER_DELAY, 
    prior_avg_NAS_DELAY,
    prior_avg_SECURITY_DELAY, 
    prior_avg_LATE_AIRCRAFT_DELAY
    (prior_DEP_OVERFLOW, the difference between planned and actual departures,
    is not calculated here)
  """
  
  table_name = 'silver_df'
  query = f"""
select ORIGIN, 
           act_dep_time_bucket as time_bucket,
           count(DEP_TIME) as prior_count_DEP_TIME, 
           avg(DEP_DELAY) as prior_avg_DEP_DELAY,            
           sum(DEP_DEL15) as prior_agg_DEP_DEL15, 
           avg(TAXI_OUT) as prior_avg_TAXI_OUT, 
           avg(CARRIER_DELAY) as prior_avg_CARRIER_DELAY, 
           avg(WEATHER_DELAY) as prior_avg_WEATHER_DELAY, 
           avg(NAS_DELAY) as prior_avg_NAS_DELAY, 
           avg(SECURITY_DELAY) as prior_avg_SECURITY_DELAY, 
           avg(LATE_AIRCRAFT_DELAY) as prior_avg_LATE_AIRCRAFT_DELAY
         from {table_name}
       group by ORIGIN, 
                act_dep_time_bucket 
    """

  if verbose:
    print(query)

  # Run the query regardless of `verbose`: the query execution and the return used to be
  # nested under `if verbose`, so verbose=False silently returned None.
  df = spark.sql(query)
  return df

  
# Build the hourly delay aggregates (called with the default verbose=True, so the SQL is printed),
# mark them for caching and expose them as the temp view "airport_aggregated_delay".
airport_aggregated_delay = compute_airport_aggregated_delay().cache()
airport_aggregated_delay.createOrReplaceTempView("airport_aggregated_delay")

In [0]:
def compute_airport_aggregates(verbose=True):
  """
  Create a table that counts the flights scheduled to depart
  in a given hour, per origin airport.

  input: the temp view silver_df, queried through the global spark session
         and grouped by ORIGIN and sch_dep_time_bucket

  Args:
    verbose: when True, print the SQL text before running it.

  Returns:
    DataFrame with one row per (ORIGIN, sch_dep_time_bucket) and the hourly aggregate
    count_CRS_DEP_TIME
  """
  
  table_name = 'silver_df'
  query = f"""
select ORIGIN, sch_dep_time_bucket,
           count(CRS_DEP_TIME) as count_CRS_DEP_TIME
         from {table_name}
       group by ORIGIN, 
                sch_dep_time_bucket 
    """

  if verbose:
    print(query)

  # Run the query regardless of `verbose`: the query execution and the return used to be
  # nested under `if verbose`, so verbose=False silently returned None.
  df = spark.sql(query)
  return df

  
# Build the scheduled-departure counts (called with the default verbose=True, so the SQL is printed),
# mark them for caching and expose them as the temp view "airport_aggregates".
airport_aggregates = compute_airport_aggregates().cache()
airport_aggregates.createOrReplaceTempView("airport_aggregates")

### Join departure aggregated delays with silver_df

In [0]:
# join departure aggregates
def join_airport_agg_and_silver(airport_aggregated_delay, silver_df):
  """Attach the hourly departure-delay aggregates to every flight row.

  The aggregates' time_bucket is renamed to pred_time_bucket, so each flight is
  matched (left join on ORIGIN and pred_time_bucket) with the aggregates of the
  hour bucket named by its pred_time_bucket. Flights without a matching bucket
  keep nulls in the prior_* columns.
  """
  join_keys = ['ORIGIN', 'pred_time_bucket']
  prior_columns = ["prior_count_DEP_TIME",
                   "prior_avg_DEP_DELAY", "prior_agg_DEP_DEL15", "prior_avg_TAXI_OUT",
                   "prior_avg_CARRIER_DELAY", "prior_avg_WEATHER_DELAY", "prior_avg_NAS_DELAY",
                   "prior_avg_SECURITY_DELAY", "prior_avg_LATE_AIRCRAFT_DELAY"]

  # aggregates, keyed the same way as the flight rows
  delays_by_hour = airport_aggregated_delay.withColumnRenamed("time_bucket", "pred_time_bucket")
  delays_by_hour = delays_by_hour.select("pred_time_bucket", "ORIGIN", *prior_columns)

  # the rename only has an effect when the flight rows still carry a time_bucket column
  flights = silver_df.withColumnRenamed("time_bucket", "pred_time_bucket")
  return flights.join(delays_by_hour, on=join_keys, how="left")

# flights (silver_df) enriched with the hourly departure-delay aggregates
gold_df_tmp_joined = join_airport_agg_and_silver(airport_aggregated_delay, silver_df)

In [0]:
# add scheduled hourly aggregate count_CRS_DEP_TIME and join it
def add_hourly_agg(airport_aggregates, gold_df_tmp_joined):
  """Attach count_CRS_DEP_TIME (scheduled departures per airport and hour) to every flight row.

  Left join on ORIGIN and sch_dep_time_bucket, so every flight row is kept.
  """
  # keep only the join keys and the aggregate itself
  scheduled_counts = airport_aggregates.select("sch_dep_time_bucket", "ORIGIN", "count_CRS_DEP_TIME")
  return gold_df_tmp_joined.join(scheduled_counts, on=['ORIGIN', 'sch_dep_time_bucket'], how="left")

# NOTE: this rebinds gold_df_tmp_joined to its own joined version, so re-running this cell
# without re-running the cell that first creates gold_df_tmp_joined joins count_CRS_DEP_TIME again.
gold_df_tmp_joined = add_hourly_agg(airport_aggregates, gold_df_tmp_joined)

In [0]:
# The left joins above must not change the number of flight rows:
# compare the joined row count against airlines_clean.
if VALIDATE:
  gold_df_tmp_count = gold_df_tmp_joined.count()  
  clean_count = airlines_clean.count()
  print(f"gold_df_tmp_count {gold_df_tmp_count}")
  print(f"clean_count {clean_count}")  
  assert(gold_df_tmp_count == clean_count)

## 2. Tracking flight delays in previous leg using TAIL_NUM

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, lead

# one window per aircraft (TAIL_NUM) and flight date, ordered by dep_time_utc,
# so that lag(..., 1) picks the previous row of the same aircraft on the same day
window_w_offset = (Window.partitionBy('TAIL_NUM', 'FL_DATE').orderBy('dep_time_utc'))

# prior_TN_DEP_DEL15 = DEP_DEL15 of the previous row in the window;
# when there is no previous row, impute 2 (so that prior_TN_DEP_DEL15 is OHE by 0, 1 and 2)
gold_df_tmp_joined = (gold_df_tmp_joined.withColumn('prior_TN_DEP_DEL15', lag(col('DEP_DEL15'), 1, 2).over(window_w_offset))).cache()


# introduce AIR_TIME_FLAG: 1 when AIR_TIME > 120, 0 otherwise
# NOTE(review): an earlier version of this comment said the flag is 1 exactly when prior_TN_DEP_DEL15 = 2,
# but the following cell sets prior_TN_DEP_DEL15 = 2 where AIR_TIME < 120 (i.e. where this flag is 0),
# and AIR_TIME == 120 is covered by neither condition - confirm the intended relationship.
gold_df_tmp_joined = ( gold_df_tmp_joined
                .withColumn("AIR_TIME_FLAG", f.when(gold_df_tmp_joined.AIR_TIME>120, 1).otherwise(0))
               )


# expose the enriched rows as the temp view "gold_df_tmp_joined"
gold_df_tmp_joined.createOrReplaceTempView("gold_df_tmp_joined")

In [0]:
# replace prior_TN_DEP_DEL15 by 2 for flights <120min duration (so that prior_TN_DEP_DEL15 is OHE by 0, 1 and 2)
# The condition is on the AIR_TIME column of the same row; rows with AIR_TIME >= 120 (or null)
# keep the value already stored in prior_TN_DEP_DEL15.
from pyspark.sql import functions as F
gold_df_tmp_joined = gold_df_tmp_joined.withColumn("prior_TN_DEP_DEL15", F.when(F.col("AIR_TIME")<120, 2).otherwise(F.col("prior_TN_DEP_DEL15")))
# re-register the view so it reflects the updated column
gold_df_tmp_joined.createOrReplaceTempView("gold_df_tmp_joined")

#### *Convert Categorical Variables to Strings, and Numericals to Doubles

In [0]:
# identify all categorical columns and convert them to string, and convert all numerical columns to doubles

# Columns that cast_features turns into strings (only the ones present in the DataFrame are touched).
# NOTE(review): the *_DELAY cause columns (CARRIER_DELAY, WEATHER_DELAY, ...) are listed here as categorical
# although they are averaged as numbers in the delay-aggregate SQL - confirm this is intended.
full_categorical_features = ['ORIGIN', 'sch_dep_time_bucket', 'ICAO', 'pred_time_bucket', 'DEST', 'YEAR', 'QUARTER', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'OP_UNIQUE_CARRIER', 'OP_CARRIER_AIRLINE_ID', 'OP_CARRIER', 'TAIL_NUM', 'OP_CARRIER_FL_NUM', 'ORIGIN_AIRPORT_ID', 'ORIGIN_AIRPORT_SEQ_ID', 'ORIGIN_CITY_MARKET_ID', 'ORIGIN_CITY_NAME', 'ORIGIN_STATE_ABR', 'ORIGIN_STATE_FIPS', 'ORIGIN_STATE_NM', 'ORIGIN_WAC', 'DEST_AIRPORT_ID', 'DEST_AIRPORT_SEQ_ID', 'DEST_CITY_MARKET_ID', 'DEST_CITY_NAME', 'DEST_STATE_ABR', 'DEST_STATE_FIPS', 'DEST_STATE_NM', 'DEST_WAC', 'CRS_DEP_TIME', 'DEP_TIME',  'DEP_DEL15', 'DEP_DELAY_GROUP', 'DEP_TIME_BLK', 'WHEELS_OFF', 'WHEELS_ON', 'CRS_ARR_TIME', 'ARR_TIME', 'ARR_DEL15', 'ARR_DELAY_GROUP', 'ARR_TIME_BLK', 'CANCELLED', 'CANCELLATION_CODE', 'DIVERTED',  'DISTANCE_GROUP', 'CARRIER_DELAY', 'WEATHER_DELAY', 'NAS_DELAY', 'SECURITY_DELAY', 'LATE_AIRCRAFT_DELAY', 'airport_name',  'origin_time_zone', 'dest_icao', 'dest_time_zone', 'crs_dep_time_utc', 'crs_arr_time_utc', 'dep_time_utc', 'arr_time_utc', 'sch_arr_time_bucket', 'act_dep_time_bucket', 'act_arr_time_bucket', 'station_name', 'IATA', 'Name', 'Timezone', 'Tz', 'ORIGIN-DEST', 'prior_TN_DEP_DEL15', 'holiday', 'AIR_TIME_FLAG'
                            ]

# Columns that cast_features turns into doubles (only the ones present in the DataFrame are touched).
full_numerical_features = ['DEP_DELAY', 'DEP_DELAY_NEW', 'TAXI_OUT', 'TAXI_IN', 'ARR_DELAY', 'ARR_DELAY_NEW', 'CRS_ELAPSED_TIME', 'ACTUAL_ELAPSED_TIME', 'AIR_TIME', 'FLIGHTS', 'DISTANCE', 'origin_time_zone_offset', 'dest_time_zone_offset', 'CRS_DEP_HOUR', 'CRS_ARR_HOUR', 'wind', 'windangle', 'cig', 'vis', 'tmp', 'dew', 'LATITUDE', 'LONGITUDE', 'ELEVATION', 'count_CRS_DEP_TIME', 'prior_count_DEP_TIME', 'prior_agg_DEP_DEL15', 'prior_avg_DEP_DELAY', 'prior_avg_TAXI_OUT', 'prior_avg_CARRIER_DELAY', 'prior_avg_WEATHER_DELAY', 'prior_avg_NAS_DELAY', 'prior_avg_SECURITY_DELAY', 'prior_avg_LATE_AIRCRAFT_DELAY', "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew", 'prior_dep_agg_w_overflow'
                          ]

from pyspark.sql.functions import expr, col, column

def cast_features(df, full_categorical_features, full_numerical_features):
  """Cast the listed columns of ``df``: categorical ones to string, numerical ones to double.

  Names that are not columns of ``df`` are skipped. The categorical casts are applied
  first, so a name present in both lists ends up as a double.
  Returns the DataFrame with the casts applied.
  """
  cast_plan = [(name, "string") for name in full_categorical_features]
  cast_plan += [(name, "double") for name in full_numerical_features]

  for name, target_type in cast_plan:
    if name not in df.columns:
      continue
    df = df.withColumn(name, col(name).cast(target_type))

  return df

# categorical columns as strings, numerical columns as doubles
gold_df_tmp_cast = cast_features(gold_df_tmp_joined, full_categorical_features, full_numerical_features)

In [0]:
# Row count and schema of the cast DataFrame (VERBOSE only).
if VERBOSE :
  # count() is a Spark action; print its result instead of computing it and throwing it away
  print(f"gold_df_tmp_cast: {gold_df_tmp_cast.count()} rows")
  gold_df_tmp_cast.printSchema()

In [0]:
# identify the features that need to be saved in gold_df
# (this list is passed to df_trim with keep_features=True; DEP_DEL15 is the label used by the model formulas)
feature_list = ['ORIGIN', 'DEST', 'ORIGIN-DEST', 
                'YEAR', 'QUARTER', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'CRS_DEP_HOUR', 
                'FL_DATE', 'TAIL_NUM', 'FLIGHTS', 'DISTANCE', 'DISTANCE_GROUP',
                'wind', 'windangle', 'cig', 'vis', 'tmp', 'dew', "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew", 
                'LATITUDE', 'LONGITUDE', 'ELEVATION', 'holiday',
                'DEP_DEL15', 
                'count_CRS_DEP_TIME', 'prior_count_DEP_TIME', 'prior_agg_DEP_DEL15', 
                'prior_avg_DEP_DELAY', 'prior_avg_TAXI_OUT', 
                'prior_avg_CARRIER_DELAY', 'prior_avg_WEATHER_DELAY', 'prior_avg_NAS_DELAY', 
                'prior_avg_SECURITY_DELAY', 'prior_avg_LATE_AIRCRAFT_DELAY', 'prior_TN_DEP_DEL15', 'AIR_TIME_FLAG'
               ]




In [0]:
def df_trim(old_df, feature_list, keep_features=True):
  """
  Return a trimmed copy of old_df.

  Input:
    old_df: original dataframe
    feature_list: list of column names
    keep_features: True keeps only the columns named in feature_list;
                   False drops the columns named in feature_list
  Output: trimmed dataframe
  """

  if keep_features:
    # Derive the columns to remove from the dataframe that was passed in.
    # (This used to read the notebook-global silver_df, which made the function depend on
    # hidden state and left in place any column of old_df that silver_df does not have.)
    wanted = set(feature_list)
    del_feature_list = [name for name in old_df.columns if name not in wanted]
  else:
    del_feature_list = list(feature_list)

  if not del_feature_list:
    # nothing to remove
    return old_df

  trimmed_df = old_df.drop(*del_feature_list) # drop all columns in del_feature_list
  return trimmed_df

# keep only the columns named in feature_list, mark the result for caching
# and expose it as the temp view "gold_df"
gold_df = df_trim(gold_df_tmp_cast, feature_list, True).cache() # this is needed to get rid of columns (like DIV1 etc) that have all nulls
gold_df.createOrReplaceTempView("gold_df")

In [0]:
if VERBOSE :
  gold_df.printSchema()

#### *Save Gold Data

In [0]:
# SAVE GOLD DATASET TO DISK
# location of the persisted gold dataset, namespaced by user initials and dataset name
gold_data_path = f"/team7/{user_initials}/{DATASETS}/gold_dataset"

def persist_gold(df):
  """Write ``df`` as parquet to ``gold_data_path``, overwriting whatever is already there."""
  parquet_writer = df.write.mode("overwrite").format("parquet")
  parquet_writer.save(gold_data_path)

def init_gold_from_disk():
  """Read the gold dataset back from parquet at ``dbfs:`` + ``gold_data_path`` and return it."""
  print("importing gold data from disk")
  reader = sqlContext.read.option("header", "true")
  return reader.parquet("dbfs:" + gold_data_path)

# Prefer the copy on disk when one exists and a full recompute is not forced;
# otherwise gold_df stays the DataFrame computed above.
# NOTE(review): persist_gold is never called in this cell, so nothing is written here - confirm where the dataset gets saved.
if not FORCE_COMPUTE_ALL and file_exists(gold_data_path):
  gold_df = init_gold_from_disk()
  print("Gold data already exists at {}!".format(gold_data_path))

# (re)register the view so it points at whichever gold_df is now in use
gold_df.createOrReplaceTempView("gold_df")

# Model Building

#### * Split Gold Data into Train-Validate, Test

In [0]:
# drop columns that do not have multiple values to avoid erroring out while indexing
if DATASETS == 'Toy':
  gold_df = gold_df.drop("YEAR").drop("QUARTER") 
elif DATASETS == '2015':
  gold_df = gold_df.drop("YEAR")

# drop every row that still has a null in any column
gold_df = gold_df.na.drop() 

if DATASETS == '2015-2019':
  # time-based split: 2015-2018 for training, 2019 held out for testing
  train = gold_df.filter((gold_df.YEAR== '2015') | (gold_df.YEAR== '2016') | (gold_df.YEAR== '2017')| (gold_df.YEAR== '2018'))
  test = gold_df.filter((gold_df.YEAR== '2019'))
else:
  # seeded 80/20 random split
  train, test = gold_df.randomSplit([.80, 0.20], seed=104)

# Register the views for both kinds of split
# (they used to be registered only in the random-split branch).
train.createOrReplaceTempView("train")
test.createOrReplaceTempView("test")

In [0]:
# summary statistics of the gold dataset after the null rows were dropped (VERBOSE only)
if VERBOSE:
  display(gold_df.describe().cache())

#### *Balance Data: Generate Class Weights

In [0]:
# Uncomment to generate class weights
# (BalancingRatio is the share of negatives, i.e. rows with DEP_DEL15 != 1, in the training set)
# dataset_size=float(train.count())
# numPositives=train.select("DEP_DEL15").where('DEP_DEL15 == 1').count()
# per_ones=(float(numPositives)/float(dataset_size))*100
# numNegatives=float(dataset_size-numPositives)
# print('The number of ones are {}'.format(numPositives))
# print('Percentage of ones are {}'.format(per_ones))
# BalancingRatio= numNegatives/dataset_size
# print('BalancingRatio = {}'.format(BalancingRatio))

# pre-calculated values here to save time
# NOTE(review): these constants are only valid for the exact train split they were computed on;
# this chain tests "2015_2017" while the split cell tests '2015-2019' - confirm the DATASETS values in use.
if DATASETS == "Toy":
  BalancingRatio = 0.7517532554108707
elif DATASETS == "2015":
  BalancingRatio = 0.7736956152970962
elif DATASETS == "2015_2017":
  BalancingRatio = 0.7794535029323947
else:
  BalancingRatio = 0.7792115713310707  
# delayed flights (DEP_DEL15 == 1) get weight BalancingRatio, every other row gets 1 - BalancingRatio
train=train.withColumn("classWeights", F.when(train.DEP_DEL15 == 1,BalancingRatio).otherwise(1-BalancingRatio))
train.select("classWeights").show(5)

### *Pipeline

In [0]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import RobustScaler
from pyspark.ml.feature import RFormula
from pyspark.ml.classification import LogisticRegression

"""
# RFormula treats strings as categorical features and one-hot encodes them
# RFormula creates two columns "label" and "features"
# "features" column is a big long string of all features
"""

# Terms shared by the formula of every dataset size, in the order in which they enter the formula.
# Keeping them in one place replaces three near-identical formula literals that had to be edited in sync.
# NOTE(review): prior_TN_DEP_DEL15 is part of feature_list but is not a term of any formula - confirm this is intended.
_RFORMULA_BASE_TERMS = [
  "DEST", "ORIGIN", "MONTH", "DAY_OF_MONTH", "DAY_OF_WEEK",
  "DISTANCE", "DISTANCE_GROUP", "CRS_DEP_HOUR",
  "wind", "windangle", "cig", "vis", "tmp", "dew",
  "LATITUDE", "LONGITUDE", "ELEVATION",
  "count_CRS_DEP_TIME", "prior_count_DEP_TIME", "prior_agg_DEP_DEL15",
  "prior_avg_DEP_DELAY", "prior_avg_TAXI_OUT",
  "prior_avg_CARRIER_DELAY", "prior_avg_WEATHER_DELAY", "prior_avg_NAS_DELAY",
  "prior_avg_SECURITY_DELAY", "prior_avg_LATE_AIRCRAFT_DELAY",
  "holiday", "miss_windAngle", "miss_wnd", "miss_tmp", "miss_dew", "AIR_TIME_FLAG",
]

def _build_num_ohe_stage(leading_terms=()):
  """Return the RFormula stage 'DEP_DEL15 ~ <leading_terms> + <shared terms>'.

  The stage writes its features to "num_ohe_features", its label to "label"
  and skips rows with invalid values.
  """
  formula = "DEP_DEL15 ~ " + " + ".join(list(leading_terms) + _RFORMULA_BASE_TERMS)
  return (RFormula(formula=formula, featuresCol="num_ohe_features", labelCol="label")
          .setHandleInvalid("skip"))

# add in QUARTER, YEAR if 2015 or larger datasets, resp.
rForm_Toy = _build_num_ohe_stage()
rForm_2015 = _build_num_ohe_stage(["QUARTER"])
rForm = _build_num_ohe_stage(["YEAR", "QUARTER"])

if DATASETS == "Toy":
  num_ohe_stage = rForm_Toy
elif DATASETS == '2015':
  num_ohe_stage = rForm_2015
else:
  num_ohe_stage = rForm


## Model Implementations

In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import TrainValidationSplit
from pyspark.ml.classification import GBTClassifier
from pyspark.ml import PipelineModel
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
import pyspark.sql.functions as F
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.types import FloatType

In [0]:
# Helper functions to save the "best model" to disk
# NOTE: this repeats the gold_data_path assignment made in the gold-dataset cell;
# the model helpers below build their own model_path and do not read it.
gold_data_path = f"/team7/{user_initials}/{DATASETS}/gold_dataset"

def persist_model(tvsFitted, model_name):
  """Save the best model of a fitted tuning result under /team7/MR/<DATASETS>/<model_name>.

  tvsFitted must expose ``bestModel``; any model already stored at that path is overwritten.
  """
  model_path = f"/team7/MR/{DATASETS}/{model_name}"
  overwriting_writer = tvsFitted.bestModel.write().overwrite()
  overwriting_writer.save(model_path)

# load model from disk
def load_model(model_name):
  """Load and return the PipelineModel stored under /team7/MR/<DATASETS>/<model_name>."""
  model_path = f"/team7/MR/{DATASETS}/{model_name}"
  return PipelineModel.load(model_path)

### Logistic Regression Pipeline

In [0]:
# logistic regression takes in "label" and "features"
# and weighs each training row by its "classWeights" column
lr = LogisticRegression().setLabelCol("label").setFeaturesCol("features").setWeightCol("classWeights")
# stages: RFormula (-> "num_ohe_features", "label"), VectorAssembler (-> "features"), classifier
lr_pipeline = Pipeline(stages=[num_ohe_stage, VectorAssembler(inputCols=["num_ohe_features"], outputCol="features"), lr])

# hyper-parameter grid: 3 elastic-net mixes x 2 regularisation strengths = 6 candidates
lr_params = (ParamGridBuilder().addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
                            .addGrid(lr.regParam, [0.1, 2.0])
                            .build())


### Gradient Boosted Trees (GBT) Pipeline

In [0]:
# gradient boosted trees on the same "label" / "features" columns (no weight column is set here)
gbt = GBTClassifier().setLabelCol("label").setFeaturesCol("features")
# stages: RFormula (-> "num_ohe_features", "label"), VectorAssembler (-> "features"), classifier
gbt_pipeline = Pipeline(stages=[num_ohe_stage, VectorAssembler(inputCols=["num_ohe_features"], outputCol="features"), gbt])


# hyper-parameter grid: 3 depths x 2 bin counts x 2 iteration counts = 12 candidates
gbt_params = (ParamGridBuilder()
             .addGrid(gbt.maxDepth, [2, 4, 6])
             .addGrid(gbt.maxBins, [20, 30])
             .addGrid(gbt.maxIter, [10, 15])
             .build())

In [0]:
def run_and_persist_best_model(model_name, params, pipeline, data, num_folds=5):
  """Tune ``pipeline`` by cross-validation on ``data`` and persist the best model.

  Args:
    model_name: name under which persist_model stores the best model.
    params: list of param maps to try (e.g. from ParamGridBuilder().build()).
    pipeline: the estimator to tune.
    data: training DataFrame.
    num_folds: number of cross-validation folds (default 5, the value that used to be hard-coded).

  Returns:
    The fitted cross-validation result, so callers can inspect it;
    previously it was discarded and the function returned None.
  """
  # candidates are compared by area under the precision-recall curve
  evaluator = (BinaryClassificationEvaluator().setMetricName("areaUnderPR")
                                            .setRawPredictionCol("rawPrediction")
                                            .setLabelCol("label"))
  crossval = CrossValidator(estimator=pipeline,
                            estimatorParamMaps=params,
                            evaluator=evaluator,
                            numFolds=num_folds)

  # train the model
  crossvalFitted = crossval.fit(data)
  persist_model(crossvalFitted, model_name)
  return crossvalFitted

In [0]:
# (re)fit and store the logistic-regression model only when PERSIST_MODELS is set;
# otherwise the cells below load a previously stored model
if PERSIST_MODELS:
  run_and_persist_best_model("lr", lr_params, lr_pipeline, train)

In [0]:
# (re)fit and store the GBT model only when PERSIST_MODELS is set;
# otherwise the cells below load a previously stored model
if PERSIST_MODELS:
  run_and_persist_best_model("gbt", gbt_params, gbt_pipeline, train)

### Logistic Regression Results

In [0]:
# evaluate the stored logistic-regression model on the test split
lr_transform_test = load_model("lr").transform(test)

# each evaluate() call below is a separate pass over lr_transform_test
areaUnderROC = BinaryClassificationEvaluator().setMetricName("areaUnderROC").evaluate(lr_transform_test)
areaUnderPR = BinaryClassificationEvaluator().setMetricName("areaUnderPR").evaluate(lr_transform_test)
f1 = MulticlassClassificationEvaluator().setMetricName("f1").evaluate(lr_transform_test)
weightedPrecision = MulticlassClassificationEvaluator().setMetricName("weightedPrecision").evaluate(lr_transform_test)
weightedRecall = MulticlassClassificationEvaluator().setMetricName("weightedRecall").evaluate(lr_transform_test)
accuracy = MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(lr_transform_test)

print("areaUnderROC = ", areaUnderROC)
print("areaUnderPR = ", areaUnderPR)
print("f1 = ", f1)
print("weightedPrecision = ", weightedPrecision)
print("weightedRecall = ", weightedRecall)
print("accuracy = ", accuracy)


### GBT Results

In [0]:
# evaluate the stored GBT model on the test split
# (the metric variables below overwrite the ones set by the logistic-regression results cell)
gbt_transform_test = load_model("gbt").transform(test)

# each evaluate() call below is a separate pass over gbt_transform_test
areaUnderROC = BinaryClassificationEvaluator().setMetricName("areaUnderROC").evaluate(gbt_transform_test)
areaUnderPR = BinaryClassificationEvaluator().setMetricName("areaUnderPR").evaluate(gbt_transform_test)
f1 = MulticlassClassificationEvaluator().setMetricName("f1").evaluate(gbt_transform_test)
weightedPrecision = MulticlassClassificationEvaluator().setMetricName("weightedPrecision").evaluate(gbt_transform_test)
weightedRecall = MulticlassClassificationEvaluator().setMetricName("weightedRecall").evaluate(gbt_transform_test)
accuracy = MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(gbt_transform_test)

print("areaUnderROC = ", areaUnderROC)
print("areaUnderPR = ", areaUnderPR)
print("f1 = ", f1)
print("weightedPrecision = ", weightedPrecision)
print("weightedRecall = ", weightedRecall)
print("accuracy = ", accuracy)


## Stacking

In [0]:
# Score the training split with both stored base models.
# NOTE(review): the base models score the same rows they were trained on, so these predictions are in-sample.
gbt_transform_train = load_model("gbt").transform(train)
lr_transform_train = load_model("lr").transform(train)
# rename each model's "prediction" column and cast it to string, for use as a term in the stacking formula
gbt_transform_train = gbt_transform_train.withColumnRenamed("prediction", "gbt_prediction").withColumn("gbt_prediction", col("gbt_prediction").cast("string"))
lr_transform_train = lr_transform_train.withColumnRenamed("prediction", "lr_prediction").withColumn("lr_prediction", col("lr_prediction").cast("string"))
gbt_transform_train.createOrReplaceTempView("gbt_transform_train")
lr_transform_train.createOrReplaceTempView("lr_transform_train")

In [0]:
# Put both base-model predictions side by side: left join of the LR rows with the GBT rows
# on the flight attributes listed in `on`, keeping only the label and the two predictions.
# NOTE(review): nothing here guarantees that the join key is unique per flight; rows sharing
# all key values would be multiplied by the join - confirm uniqueness.
stacked_train_predictions_df = lr_transform_train.select('ORIGIN', 'DEST', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'TAIL_NUM', 'DEP_DEL15', 'FLIGHTS', 'CRS_DEP_HOUR','lr_prediction').join(gbt_transform_train.select('ORIGIN', 'DEST', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'TAIL_NUM', 'DEP_DEL15', 'FLIGHTS', 'CRS_DEP_HOUR', 'gbt_prediction'), how="left", on=['ORIGIN', 'DEST', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'TAIL_NUM', 'DEP_DEL15', 'FLIGHTS', 'CRS_DEP_HOUR']).select("DEP_DEL15", "lr_prediction", "gbt_prediction")

# the label is cast to string before it is handed to the stacking formula
stacked_train_predictions_df = (stacked_train_predictions_df.withColumn("DEP_DEL15", col("DEP_DEL15").cast("string")))
stacked_train_predictions_df.createOrReplaceTempView("stacked_train_predictions_df")

In [0]:
# expected columns: DEP_DEL15, lr_prediction, gbt_prediction (all cast to string above)
stacked_train_predictions_df.printSchema()

### Build pipeline of lr and gbt predictions

In [0]:
# Meta-model: a GBT classifier that predicts DEP_DEL15 from the two base-model predictions.
# No column names are set on either stage, so both use their library defaults.
stacked_rForm = (RFormula(formula = "DEP_DEL15 ~ lr_prediction + gbt_prediction").setHandleInvalid("skip"))
meta_gbt = GBTClassifier()#.setLabelCol("label").setFeaturesCol("features")
stacked_pipeline = Pipeline(stages=[stacked_rForm, meta_gbt])

# hyper-parameter grid: 3 depths x 2 bin counts x 2 iteration counts = 12 candidates
meta_params = (ParamGridBuilder()
             .addGrid(meta_gbt.maxDepth, [2, 4, 6])
             .addGrid(meta_gbt.maxBins, [20, 30])
             .addGrid(meta_gbt.maxIter, [10, 15])
             .build())

# (re)fit and store the stacked model only when PERSIST_MODELS is set
if PERSIST_MODELS: 
  run_and_persist_best_model("stacked_best_model", meta_params, stacked_pipeline, stacked_train_predictions_df)

### pass the test dataset through the stacked_pipeline

In [0]:
# generate lr_prediction and gbt_prediction for test dataset using lr_best_model and gbt_best_model
# (this rebinds gbt_transform_test / lr_transform_test, which the results cells above also assign)

gbt_transform_test = load_model("gbt").transform(test)
lr_transform_test = load_model("lr").transform(test)
# rename each model's "prediction" column and cast it to string, mirroring the training-side preparation
gbt_transform_test = gbt_transform_test.withColumnRenamed("prediction", "gbt_prediction").withColumn("gbt_prediction", col("gbt_prediction").cast("string"))
lr_transform_test = lr_transform_test.withColumnRenamed("prediction", "lr_prediction").withColumn("lr_prediction", col("lr_prediction").cast("string"))
gbt_transform_test.createOrReplaceTempView("gbt_transform_test")
lr_transform_test.createOrReplaceTempView("lr_transform_test")

# create new_test_df with lr_test_prediction and gbt_test_prediction
# (same left join on the flight attributes as for the training split;
# NOTE(review): the join key is not guaranteed to be unique per flight - confirm)
stacked_test_predictions_df = lr_transform_test.select('ORIGIN', 'DEST', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'TAIL_NUM', 'DEP_DEL15', 'FLIGHTS', 'CRS_DEP_HOUR','lr_prediction').join(gbt_transform_test.select('ORIGIN', 'DEST', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'TAIL_NUM', 'DEP_DEL15', 'FLIGHTS', 'CRS_DEP_HOUR', 'gbt_prediction'), how="left", on=['ORIGIN', 'DEST', 'MONTH', 'DAY_OF_MONTH', 'DAY_OF_WEEK', 'FL_DATE', 'TAIL_NUM', 'DEP_DEL15', 'FLIGHTS', 'CRS_DEP_HOUR']).select("DEP_DEL15", "lr_prediction", "gbt_prediction")

stacked_test_predictions_df = stacked_test_predictions_df.withColumn("DEP_DEL15", col("DEP_DEL15").cast("string"))#.withColumnRenamed("DEP_DEL15", "label")
stacked_test_predictions_df.createOrReplaceTempView("stacked_test_predictions_df")

In [0]:
# load the model and transform the stacked_test_predictions_df
model_name = "stacked_best_model"
loaded_prediction_and_labels = load_model(model_name).transform(stacked_test_predictions_df)

# each evaluate() call below is a separate pass over loaded_prediction_and_labels;
# the metric variables overwrite the ones set by the base-model results cells
areaUnderROC = BinaryClassificationEvaluator().setMetricName("areaUnderROC").evaluate(loaded_prediction_and_labels)
areaUnderPR = BinaryClassificationEvaluator().setMetricName("areaUnderPR").evaluate(loaded_prediction_and_labels)
f1 = MulticlassClassificationEvaluator().setMetricName("f1").evaluate(loaded_prediction_and_labels)
weightedPrecision = MulticlassClassificationEvaluator().setMetricName("weightedPrecision").evaluate(loaded_prediction_and_labels)
weightedRecall = MulticlassClassificationEvaluator().setMetricName("weightedRecall").evaluate(loaded_prediction_and_labels)
accuracy = MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(loaded_prediction_and_labels)

print("areaUnderROC = ", areaUnderROC)
print("areaUnderPR = ", areaUnderPR)
print("f1 = ", f1)
print("weightedPrecision = ", weightedPrecision)
print("weightedRecall = ", weightedRecall)
print("accuracy = ", accuracy)

In [0]:
# Confusion matrix of the stacked model on the test split.
# MulticlassMetrics expects (prediction, label) pairs, in that order; the columns used to be
# selected as (label, prediction), which transposed the printed matrix.
metrics = MulticlassMetrics(
  loaded_prediction_and_labels.select("prediction", "label").rdd.map(tuple)
)
print("Confusion Matrix:", metrics.confusionMatrix().toArray())

In [0]:
# Per-class precision/recall/F1 of the stacked model on the test split.
# Label and prediction are collected together in a single pass: two separate collect() calls
# are not guaranteed to return the rows in the same order, which could misalign the two lists.
_label_prediction_rows = loaded_prediction_and_labels.select("label", "prediction").collect()
# plain values instead of Row objects, so sklearn receives 1-d inputs
y_true = [row["label"] for row in _label_prediction_rows]
y_pred = [row["prediction"] for row in _label_prediction_rows]
print(classification_report(y_true, y_pred))