In [0]:
import numpy as np
import pandas as pd
import plotly as px
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from datetime import datetime
import time

import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.sql.functions import isnan, when, count, col

from pyspark.ml.feature import StringIndexer, VectorIndexer, VectorAssembler, StandardScaler, OneHotEncoder, SQLTransformer
from pyspark.ml.classification import LogisticRegression 
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline
from pyspark.sql import Window

from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

import mlflow
import mlflow.spark

import sys
import os
from pyspark.ml.classification import RandomForestClassificationModel


## False Positive Analysis

Since our business case is focused around precision due to a high cost of false positives, we wanted to take a closer look at the false positives in our validation set to see if we could gain further insight into which features might be making the largest contribution to the false positive rate.

In [0]:
# load the validation predictions
val_pred = spark.read.format("csv").option("header", "true").load("dbfs:/mnt/mids-w261/team20SSDK/reports/fp_results.csv") 
print(val_pred.count())

In [0]:
# create a false positive column to quickly identify false positive rows.
val_pred = val_pred.withColumn("fp", when((col("prediction") == 1) & (col("label")==0),1).otherwise(0))
print(val_pred.count())
display(val_pred)

NETWORK_CONGESTION,label,DELAYS_SO_FAR,ORIGIN_PR,ARR_HOUR_BIN,AVG_WND_SPEED_ORIGIN,AVG_WND_SPEED_DEST,MINUTES_AFTER_MIDNIGHT_ORIGIN,DEST_PR,AVG_VIS_DIS_ORIGIN,MINUTES_AFTER_MIDNIGHT_DEST,AVG_DEW_DEG_ORIGIN,CRS_ELAPSED_TIME,DEP_HOUR_BIN,QUARTER,CLASS_WEIGHTS,QUARTER_tmp,QUARTER_catVec,DEP_HOUR_BIN_tmp,DEP_HOUR_BIN_catVec,ARR_HOUR_BIN_tmp,ARR_HOUR_BIN_catVec,features,rawPrediction,probability,prediction,fp
1265.0,0.0,0.0,0.0007536514151696,2,82.0,0.0,1086,0.0010073788831143,16093.0,1135,-11.0,49.0,2,2,0.1786318007205677,1.0,"(4,[1],[1.0])",1.0,"(5,[1],[1.0])",1.0,"(5,[1],[1.0])","(25,[1,2,3,4,5,6,7,8,9,12,16,21],[1086.0,1135.0,1265.0,16093.0,0.0010073788831143871,0.0007536514151696089,-11.0,49.0,82.0,1.0,1.0,1.0])","[38.69104743528032,11.30895256471968]","[0.7738209487056064,0.2261790512943936]",0.0,0
2981.0,0.0,0.0,0.0007536514151696,3,93.0,98.0,79,0.0010073788831143,16093.0,128,72.0,49.0,3,3,0.1786318007205677,0.0,"(4,[0],[1.0])",0.0,"(5,[0],[1.0])",0.0,"(5,[0],[1.0])","(25,[1,2,3,4,5,6,7,8,9,10,11,15,20],[79.0,128.0,2981.0,16093.0,0.0010073788831143871,0.0007536514151696089,72.0,49.0,93.0,98.0,1.0,1.0,1.0])","[34.10906303174933,15.890936968250665]","[0.6821812606349866,0.31781873936501337]",0.0,0
1235.0,0.0,0.0,0.0007536514151696,2,82.0,82.0,1086,0.0010073788831143,16093.0,1140,-33.0,54.0,2,4,0.1786318007205677,2.0,"(4,[2],[1.0])",1.0,"(5,[1],[1.0])",1.0,"(5,[1],[1.0])","(25,[1,2,3,4,5,6,7,8,9,10,13,16,21],[1086.0,1140.0,1235.0,16093.0,0.0010073788831143871,0.0007536514151696089,-33.0,54.0,82.0,82.0,1.0,1.0,1.0])","[38.967426519892996,11.032573480106995]","[0.77934853039786,0.22065146960213994]",0.0,0
1876.0,0.0,0.0,0.0071298404165766,3,36.0,31.0,1425,0.0007536514151696,16093.0,88,-94.0,103.0,2,1,0.1786318007205677,3.0,"(4,[3],[1.0])",1.0,"(5,[1],[1.0])",0.0,"(5,[0],[1.0])","(25,[1,2,3,4,5,6,7,8,9,10,14,16,20],[1425.0,88.0,1876.0,16093.0,0.0007536514151696089,0.007129840416576604,-94.0,103.0,36.0,31.0,1.0,1.0,1.0])","[33.93053374712535,16.069466252874648]","[0.678610674942507,0.32138932505749296]",0.0,0
625.0,0.0,2.0,0.0071298404165766,1,46.0,51.0,997,0.0007536514151696,16093.0,1105,-239.0,108.0,1,1,0.1786318007205677,3.0,"(4,[3],[1.0])",2.0,"(5,[2],[1.0])",2.0,"(5,[2],[1.0])","(25,[0,1,2,3,4,5,6,7,8,9,10,14,17,22],[2.0,997.0,1105.0,625.0,16093.0,0.0007536514151696089,0.007129840416576604,-239.0,108.0,46.0,51.0,1.0,1.0,1.0])","[40.073693810267194,9.926306189732808]","[0.8014738762053439,0.19852612379465615]",0.0,0
1780.0,0.0,3.0,0.0071298404165766,1,21.0,103.0,1010,0.0007536514151696,16093.0,1121,-67.0,111.0,1,1,0.1786318007205677,3.0,"(4,[3],[1.0])",2.0,"(5,[2],[1.0])",2.0,"(5,[2],[1.0])","(25,[0,1,2,3,4,5,6,7,8,9,10,14,17,22],[3.0,1010.0,1121.0,1780.0,16093.0,0.0007536514151696089,0.007129840416576604,-67.0,111.0,21.0,103.0,1.0,1.0,1.0])","[39.46404719352122,10.535952806478772]","[0.7892809438704246,0.2107190561295755]",0.0,0
1693.0,0.0,0.0,0.0071298404165766,3,98.0,26.0,1359,0.0007536514151696,16093.0,25,89.0,106.0,2,3,0.1786318007205677,0.0,"(4,[0],[1.0])",1.0,"(5,[1],[1.0])",0.0,"(5,[0],[1.0])","(25,[1,2,3,4,5,6,7,8,9,10,11,16,20],[1359.0,25.0,1693.0,16093.0,0.0007536514151696089,0.007129840416576604,89.0,106.0,98.0,26.0,1.0,1.0,1.0])","[34.736475239539544,15.263524760460438]","[0.6947295047907911,0.30527049520920885]",0.0,0
1555.0,0.0,0.0,0.0071298404165766,3,0.0,62.0,1364,0.0007536514151696,16093.0,24,17.0,100.0,2,4,0.1786318007205677,2.0,"(4,[2],[1.0])",1.0,"(5,[1],[1.0])",0.0,"(5,[0],[1.0])","(25,[1,2,3,4,5,6,7,8,10,13,16,20],[1364.0,24.0,1555.0,16093.0,0.0007536514151696089,0.007129840416576604,17.0,100.0,62.0,1.0,1.0,1.0])","[35.498213941552876,14.501786058447122]","[0.7099642788310575,0.29003572116894244]",0.0,0
2588.0,1.0,0.0,0.0071298404165766,3,15.0,67.0,85,0.0006851165381782,9656.0,190,-50.0,105.0,3,1,0.8213681992794323,3.0,"(4,[3],[1.0])",0.0,"(5,[0],[1.0])",0.0,"(5,[0],[1.0])","(25,[1,2,3,4,5,6,7,8,9,10,14,15,20],[85.0,190.0,2588.0,9656.0,0.0006851165381782673,0.007129840416576604,-50.0,105.0,15.0,67.0,1.0,1.0,1.0])","[26.581341084766166,23.418658915233827]","[0.5316268216953234,0.4683731783046766]",0.0,0
582.0,0.0,0.0,0.0071298404165766,1,0.0,103.0,850,0.0006851165381782,16000.0,946,-100.0,96.0,1,1,0.1786318007205677,3.0,"(4,[3],[1.0])",2.0,"(5,[2],[1.0])",2.0,"(5,[2],[1.0])","(25,[1,2,3,4,5,6,7,8,10,14,17,22],[850.0,946.0,582.0,16000.0,0.0006851165381782673,0.007129840416576604,-100.0,96.0,103.0,1.0,1.0,1.0])","[39.39158451813084,10.608415481869168]","[0.7878316903626167,0.2121683096373833]",0.0,0


In [0]:
#Percent FP
str(100*1317864/ 7071462) + '%'


In [0]:
val_pred.registerTempTable('val_pred')

In [0]:
# compare the total rows vs. the false positive rows for values of DELAYS_SO_FAR
spark.sql('SELECT DELAYS_SO_FAR, count(FP), sum(fp) FROM val_pred group by DELAYS_SO_FAR ORDER BY DELAYS_SO_FAR').display()

DELAYS_SO_FAR,count(FP),sum(fp)
0.0,5417326,552653
1.0,1178526,573428
10.0,1,0
2.0,335739,149411
3.0,100837,33797
4.0,29040,6867
5.0,7955,1415
6.0,1705,261
7.0,284,27
8.0,47,5


In [0]:
# compare the total rows vs. the false positive rows for values of DEP_HOUR_BIN
spark.sql('SELECT DEP_HOUR_BIN, count(FP), sum(fp) FROM val_pred group by DEP_HOUR_BIN ORDER BY DEP_HOUR_BIN').display()

DEP_HOUR_BIN,count(FP),sum(fp)
0,199041,3315
1,1874179,72684
2,2134088,386947
3,2410937,712150
4,453217,142768


In [0]:
# compare the total rows vs. the false positive rows for values of ARR_HOUR_BIN
spark.sql('SELECT ARR_HOUR_BIN, count(FP), sum(fp) FROM val_pred group by ARR_HOUR_BIN ORDER BY ARR_HOUR_BIN').display()

ARR_HOUR_BIN,count(FP),sum(fp)
0,189485,65443
1,1131741,15645
2,2120239,276235
3,2539536,546065
4,1090461,414476


In [0]:
spark.sql('SELECT ROUND(MINUTES_AFTER_MIDNIGHT_ORIGIN / 60,0) as hour, sum(fp) as fp_count, count(FP) as non_fp_count FROM val_pred group by ROUND(MINUTES_AFTER_MIDNIGHT_ORIGIN / 60,0) ORDER BY hour').display()

hour,fp_count,non_fp_count
0.0,78719,189560
1.0,166939,354113
2.0,147128,290756
3.0,76682,208923
4.0,25110,117833
5.0,9661,67094
6.0,8014,52372
7.0,4387,37466
8.0,2227,24989
9.0,861,21937


In [0]:
spark.sql('SELECT ROUND(MINUTES_AFTER_MIDNIGHT_ORIGIN / 60,0) as h, count(FP) as non_fp_count, sum(fp) as fp_count FROM val_pred group by ROUND(MINUTES_AFTER_MIDNIGHT_ORIGIN / 60,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,189560,78719
1.0,354113,166939
2.0,290756,147128
3.0,208923,76682
4.0,117833,25110
5.0,67094,9661
6.0,52372,8014
7.0,37466,4387
8.0,24989,2227
9.0,21937,861


In [0]:
# compare the total rows vs. the false positive rows for values of MINUTES_AFTER_MIDNIGHT_DEST
spark.sql('SELECT ROUND(MINUTES_AFTER_MIDNIGHT_DEST / 60,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(MINUTES_AFTER_MIDNIGHT_DEST / 60,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,207636,44214
1.0,403475,88810
2.0,399333,112500
3.0,369969,144191
4.0,346735,156098
5.0,246744,96652
6.0,135366,43216
7.0,71041,21132
8.0,26478,7539
9.0,17669,4555


In [0]:
# compare the total rows vs. the false positive rows for values of NETWORK_CONGESTION
spark.sql('SELECT ROUND(NETWORK_CONGESTION / 100,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(NETWORK_CONGESTION / 100,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,109886,54452
1.0,160387,63439
2.0,106500,26899
3.0,108977,14528
4.0,170299,13636
5.0,178645,13680
6.0,249562,18707
7.0,327595,25112
8.0,488422,37189
9.0,406032,44304


In [0]:
# compare the total rows vs. the false positive rows for values of AVG_DEW_DEG_ORIGIN
spark.sql('SELECT ROUND(AVG_DEW_DEG_ORIGIN / 100,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(AVG_DEW_DEG_ORIGIN / 100,0) ORDER BY h').display()

h,count(FP),sum(fp)
-4.0,73,12
-3.0,4811,534
-2.0,117732,23618
-1.0,799260,156474
0.0,1573278,275849
1.0,2082077,395064
2.0,2404837,449810
3.0,89394,16503


In [0]:
# compare the total rows vs. the false positive rows for values of AVG_WND_SPEED_ORIGIN
spark.sql('SELECT ROUND(AVG_WND_SPEED_ORIGIN / 10,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(AVG_WND_SPEED_ORIGIN / 10,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,694193,64503
1.0,1482,138
2.0,1215839,157789
3.0,1337207,224181
4.0,1196419,239239
5.0,949470,215789
6.0,672181,164075
7.0,435289,108370
8.0,263838,66083
9.0,150877,38375


In [0]:
# compare the total rows vs. the false positive rows for values of AVG_WND_SPEED_DEST
spark.sql('SELECT ROUND(AVG_WND_SPEED_DEST / 10,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(AVG_WND_SPEED_DEST / 10,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,683332,76671
1.0,1409,198
2.0,1228071,173974
3.0,1347979,230248
4.0,1194006,234253
5.0,952018,204480
6.0,671810,154793
7.0,432053,103737
8.0,260744,63854
9.0,148219,36793


In [0]:
# compare the total rows vs. the false positive rows for values of ORIGIN PAGERANK
spark.sql('SELECT ROUND(ORIGIN_PR * 1000,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(ORIGIN_PR * 1000,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,22437,2846
1.0,631950,80432
2.0,469530,65717
3.0,285501,45605
4.0,254044,37213
5.0,192845,28497
6.0,371621,76141
7.0,275529,52260
8.0,111609,12921
9.0,143359,26454


In [0]:
# compare the total rows vs. the false positive rows for values of DESTINATION PAGERANK
spark.sql('SELECT ROUND(DEST_PR * 1000,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(DEST_PR * 1000,0) ORDER BY h').display()

h,count(FP),sum(fp)
0.0,22471,3175
1.0,632350,95124
2.0,469987,97411
3.0,286012,67407
4.0,254566,52789
5.0,193261,44926
6.0,371918,81559
7.0,275830,59478
8.0,111682,17757
9.0,143232,29592


In [0]:
# compare the total rows vs. the false positive rows for values of CRS_ELAPSED_TIME
spark.sql('SELECT ROUND(CRS_ELAPSED_TIME / 10,0) as h, count(FP), sum(fp) FROM val_pred group by ROUND(CRS_ELAPSED_TIME / 10,0) ORDER BY h').display()

h,count(FP),sum(fp)
2.0,677,62
3.0,1944,88
4.0,46466,1759
5.0,120442,10226
6.0,288961,44956
7.0,481695,83416
8.0,565257,104771
9.0,633635,124899
10.0,532845,102066
11.0,490632,88458


In [0]:
# compare the total rows vs. the false positive rows for values of QUARTER
spark.sql('SELECT QUARTER, count(FP), sum(fp) FROM val_pred group by QUARTER ORDER BY QUARTER').display()

QUARTER,count(FP),sum(fp)
1,1653399,322859
2,1806798,353654
3,1836800,349460
4,1774465,291891
