In [0]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

In [0]:
# Build the train/val/test splits for the delay-classification dataset and write them,
# together with the unsplit frame, to the ML bucket.

# Reuse the notebook's active session (getOrCreate returns the existing one if present),
# so this cell also works where `spark` is not injected by the environment.
spark = SparkSession.builder.getOrCreate()

df = spark.read.parquet("gs://flight-delay-pred-data/processed/data_*.parquet")

# Sampling key: "<Month> <DayofMonth> <AirlineName>".
df = df.withColumn("MonthDayAirline", 
                   F.concat(F.col("Month").cast("string"), 
                            F.lit(" "), 
                            F.col("DayofMonth").cast("string"), 
                            F.lit(" "), 
                            F.col("AirlineName")))

# Optional 15% subsample per key; left disabled for the full dataset.
# fractions = {mda[0]: 0.15 for mda in df.select("MonthDayAirline").distinct().collect()}
# df = df.stat.sampleBy("MonthDayAirline", fractions, seed=69)

# Remove the columns that are not written to the ML dataset.
df = df.drop("FlightDate", "IATACode", "ActualDepartureTime", "ActualArrivalTime", 
             "TaxiInTime", "TaxiOutTime", "ActualElapsedTime", "AirTime", "Cancelled",
             "CarrierDelayMinutes", "WeatherDelayMinutes", "NASDelayMinutes", 
             "SecurityDelayMinutes", "LateAircraftDelayMinutes", 
             "ScheduledDepartureTimestamp", "ScheduledArrivalTimestamp", 
             "ActualArrivalTimestamp", "ActualDepartureTimestamp", "DepartureWeatherCode",
             "ArrivalWeatherCode", "Holiday", "ArrivalDelayed", "DepartureDelayedMinutes",
             "ArrivalDelayedMinutes", "DestinationCity", "OriginCity", 
             "Year", "DepartureDelayMinutes", "ArrivalDelayMinutes")

df = df.withColumn("DepartureDelayed", F.col("DepartureDelayed").cast("int"))

# monotonically_increasing_id() is non-deterministic (it depends on partition IDs) and
# this frame feeds several separate actions below (four writes, two anti-joins).
# Persisting pins one materialisation of the rows and their Ids, so the splits are
# derived from the same data instead of from independent recomputations.
df = df.withColumn("Id", F.monotonically_increasing_id()).persist()

mdas = df.select("MonthDayAirline").distinct().collect()

# 80% of each key goes to train.
train_fractions = {mda[0]: 0.8 for mda in mdas}
train_df = df.sampleBy("MonthDayAirline", fractions=train_fractions, seed=69)

# Everything not sampled into train. Persisted because it is the output of a shuffle
# join: without a cache, the val sample and the test anti-join below would each
# recompute it and could see the rows in a different order.
test_val_df = df.join(train_df, on="Id", how="left_anti").persist()

# Half of the remainder goes to val, the other half to test.
val_fractions = {mda[0]: 0.5 for mda in mdas}
val_df = test_val_df.sampleBy("MonthDayAirline", fractions=val_fractions, seed=69)

test_df = test_val_df.join(val_df, on="Id", how="left_anti")

# TODO: For every new dataset, change dataset_name and revisit the partition counts.
dataset_name = "dataset_delayed_2014_to_2024"
output_dir = "gs://flight-delay-pred-data/ml/{}".format(dataset_name)

# (file stem, frame, number of output partitions)
outputs = [
    ("all", df, 10),
    ("train", train_df, 10),
    ("val", val_df, 5),
    ("test", test_df, 5),
]
for split_name, split_df, num_partitions in outputs:
    # Id and MonthDayAirline are helper columns for splitting only; drop them on write.
    (split_df.drop("Id", "MonthDayAirline")
             .repartition(num_partitions)
             .write.parquet("{}/{}.parquet".format(output_dir, split_name), mode="overwrite"))

# The written files are now the source of truth; release the cached frames.
test_val_df.unpersist()
df.unpersist()

In [0]:
# Read back the 2023 15% dataset splits and report the row count of each one.
train_df_2023_15pc = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2023_15pc/train.parquet"
)
n_train_2023_15pc = train_df_2023_15pc.count()
print("Number of rows in train_df_2023_15pc: {}".format(n_train_2023_15pc))

val_df_2023_15pc = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2023_15pc/val.parquet"
)
n_val_2023_15pc = val_df_2023_15pc.count()
print("Number of rows in val_df_2023_15pc: {}".format(n_val_2023_15pc))

test_df_2023_15pc = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2023_15pc/test.parquet"
)
n_test_2023_15pc = test_df_2023_15pc.count()
print("Number of rows in test_df_2023_15pc: {}".format(n_test_2023_15pc))

Number of rows in train_df_2023_15pc: 804711
Number of rows in val_df_2023_15pc: 100503
Number of rows in test_df_2023_15pc: 101124


In [0]:
# Read back the full 2023 dataset splits and report the row count of each one.
train_df_2023 = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2023/train.parquet"
)
n_train_2023 = train_df_2023.count()
print("Number of rows in train_df_2023: {}".format(n_train_2023))

val_df_2023 = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2023/val.parquet"
)
n_val_2023 = val_df_2023.count()
print("Number of rows in val_df_2023: {}".format(n_val_2023))

test_df_2023 = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2023/test.parquet"
)
n_test_2023 = test_df_2023.count()
print("Number of rows in test_df_2023: {}".format(n_test_2023))

Number of rows in train_df_2023: 5373933
Number of rows in val_df_2023: 672588
Number of rows in test_df_2023: 674206


In [0]:
# Read back the 2014-2024 dataset splits and report the row count of each one.
train_df_2014_to_2024 = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2014_to_2024/train.parquet"
)
n_train_2014_to_2024 = train_df_2014_to_2024.count()
print("Number of rows in train_df_2014_to_2024: {}".format(n_train_2014_to_2024))

val_df_2014_to_2024 = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2014_to_2024/val.parquet"
)
n_val_2014_to_2024 = val_df_2014_to_2024.count()
print("Number of rows in val_df_2014_to_2024: {}".format(n_val_2014_to_2024))

test_df_2014_to_2024 = spark.read.parquet(
    "gs://flight-delay-pred-data/ml/dataset_delayed_2014_to_2024/test.parquet"
)
n_test_2014_to_2024 = test_df_2014_to_2024.count()
print("Number of rows in test_df_2014_to_2024: {}".format(n_test_2014_to_2024))

Number of rows in train_df_2014_to_2024: 49230933
Number of rows in val_df_2014_to_2024: 6154767
Number of rows in test_df_2014_to_2024: 6155664
