In [1]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from pyspark.sql.functions import lit, isnull, when, count, col, regexp_extract, concat_ws, to_date, expr, quarter, when, date_add, year, month, day, dayofweek, broadcast, avg, min, max, like

# Spark session settings, grouped by concern.
spark_configs = {
    # Cluster endpoint and JVM memory sizes
    'spark.master': 'spark://spark-iceberg:7077',
    'spark.driver.memory': '1G',
    'spark.executor.memory': '1G',
    # Iceberg catalog `prod`: REST catalog service, S3FileIO against the
    # endpoint at http://minio:9000
    'spark.sql.catalog.prod': 'org.apache.iceberg.spark.SparkCatalog',
    'spark.sql.catalog.prod.type': 'rest',
    'spark.sql.catalog.prod.uri': 'http://rest:8181',
    'spark.sql.catalog.prod.io-impl': 'org.apache.iceberg.aws.s3.S3FileIO',
    'spark.sql.catalog.prod.s3.endpoint': 'http://minio:9000',
    'spark.sql.catalog.prod.warehouse': 's3://warehouse',
    # Unqualified table names resolve against `prod`
    'spark.sql.defaultCatalog': 'prod',
}

# Create the session, or reuse one that already exists. When a session is
# reused, Spark warns that only runtime SQL configurations take effect.
spark = SparkSession.builder.appName('Agg Fact Testing').config(map=spark_configs).getOrCreate()

25/01/04 19:38:01 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [2]:
def generate_dim_dates_df(spark: SparkSession) -> DataFrame:
    """Build a date dimension covering 365 days starting at 2015-01-01.

    Args:
        spark: Active SparkSession used to create the DataFrames.

    Returns:
        A DataFrame sorted by ``date`` with the columns ``date``, ``year``,
        ``month``, ``day``, ``day_of_week``, ``quarter``, ``holiday_name``
        (null on non-holidays) and ``is_holiday``.
    """
    # One row per day: offset the start date by the generated id, then derive
    # the calendar attributes from the resulting date.
    calendar_df = (
        spark.range(365)
        .withColumn('date', expr('date_add("2015-01-01", CAST(id AS INT))'))
        .withColumn('year', year('date'))
        .withColumn('month', month('date'))
        .withColumn('day', day('date'))
        .withColumn('day_of_week', dayofweek('date'))
        .withColumn('quarter', quarter('date'))
        .drop('id')
    )

    # U.S. federal holidays for 2015 as (ISO date string, name) pairs
    holiday_rows = [
        ("2015-01-01", "New Year's Day"),
        ("2015-01-19", "Martin Luther King Jr. Day"),
        ("2015-02-16", "Presidents' Day"),
        ("2015-05-25", "Memorial Day"),
        ("2015-07-04", "Independence Day"),
        ("2015-09-07", "Labor Day"),
        ("2015-10-12", "Columbus Day"),
        ("2015-11-11", "Veterans Day"),
        ("2015-11-26", "Thanksgiving Day"),
        ("2015-12-25", "Christmas Day"),
    ]

    # Holiday lookup with the date string parsed into a DATE column
    holiday_lookup_df = (
        spark.createDataFrame(holiday_rows, ['holiday_date', 'holiday_name'])
        .withColumn('holiday_date', to_date('holiday_date'))
    )

    # A day is a holiday exactly when the left join found a holiday name
    holiday_flag = when(col('holiday_name').isNotNull(), lit(True)).otherwise(lit(False))

    enriched_df = (
        calendar_df
        .join(
            broadcast(holiday_lookup_df),
            calendar_df.date == holiday_lookup_df.holiday_date,
            'left',
        )
        .withColumn('is_holiday', holiday_flag)
        .drop('holiday_date')
        .sort('date')
    )

    # Put `date` first and keep the remaining columns in their current order
    trailing_columns = [col(name) for name in enriched_df.columns if name != 'date']
    return enriched_df.select('date', *trailing_columns)

# Build the date dimension DataFrame once so other cells can reuse it.
dates_df = generate_dim_dates_df(spark)

In [5]:
# Load the flights fact table and attach the date-dimension columns with an
# inner join on the shared `date` column (dates_df is broadcast).
flights_df = (
    spark.table('prod.db.fact_flights')
    .join(broadcast(dates_df), on=['date'])
)

# Expose the joined result to Spark SQL as the `flights` view
flights_df.createOrReplaceTempView('flights')

# Preview the first rows
flights_df.show(10)

+----------+-------+-------------+-----------+--------------+-------------------+-------------------+--------------+---------------+--------+----------+--------------+------------+--------+--------+---------+-------+-----------------+------------+-------------+--------+---------+-------------------+----------------+--------------+-------------+-------------------+-------------+----------+----+-----+---+-----------+-------+------------+----------+
|      date|airline|flight_number|tail_number|origin_airport|destination_airport|scheduled_departure|departure_time|departure_delay|taxi_out|wheels_off|scheduled_time|elapsed_time|air_time|distance|wheels_on|taxi_in|scheduled_arrival|arrival_time|arrival_delay|diverted|cancelled|cancellation_reason|air_system_delay|security_delay|airline_delay|late_aircraft_delay|weather_delay|is_delayed|year|month|day|day_of_week|quarter|holiday_name|is_holiday|
+----------+-------+-------------+-----------+--------------+-------------------+-----------------

In [8]:
def agg_delay_metrics_rollup(
    input_df: DataFrame,
    non_time_columns: list[str],
    time_columns: list[str]
) -> DataFrame:
    """Aggregate delay metrics with a ROLLUP over the given columns.

    The rollup is taken over ``non_time_columns`` followed by
    ``time_columns``, in that order.

    Args:
        input_df: DataFrame that has ``is_delayed`` and ``departure_delay``
            columns plus every column named in the two lists.
        non_time_columns: Non-time grouping columns (rolled up first).
        time_columns: Time grouping columns (rolled up last).

    Returns:
        One row per rollup group with ``total_flights``, ``delayed_flights``,
        ``delay_rate`` and ``avg_delay_time``, plus the ``time_agg_level``
        and ``agg_level`` labels produced by ``get_aggregation_level``.
    """
    # Aggregate expressions shared by several output columns
    flight_count = count('*')
    delayed_count = count(when(col('is_delayed') == True, 1))

    rolled_up_df = (
        input_df
        .rollup(*(non_time_columns + time_columns))
        .agg(
            flight_count.alias('total_flights'),
            delayed_count.alias('delayed_flights'),
            (delayed_count / flight_count).alias('delay_rate'),
            avg('departure_delay').alias('avg_delay_time'),
        )
    )

    # Label every row with the level it was aggregated at, separately for
    # the time columns and the non-time columns.
    return (
        rolled_up_df
        .withColumn('time_agg_level', get_aggregation_level(time_columns))
        .withColumn('agg_level', get_aggregation_level(non_time_columns))
    )

def get_aggregation_level(columns: list[str]):
    """Build a Column that labels each row with its non-null grouping columns.

    The label is the names of the columns in ``columns`` whose value is not
    null on that row, joined with ``_`` in list order. When every listed
    column is null the label is the literal ``'all'``.

    Note: the label is derived purely from null checks, so a value that is
    genuinely null in the data looks the same as a rolled-up column.

    Args:
        columns: Names of the grouping columns to inspect.

    Returns:
        A Column expression to pass to ``withColumn`` / ``select``.
    """
    every_column_null = lit(True)
    present_names = []
    for name in columns:
        # Yields the column's own name when it has a value, null otherwise
        present_names.append(when(col(name).isNotNull(), lit(name)))
        every_column_null = every_column_null & col(name).isNull()

    joined_names = concat_ws('_', *present_names)
    return when(every_column_null, 'all').otherwise(joined_names)
    

In [13]:
# Work on a roughly 10% sample of the joined flights. The seed is fixed so
# that re-running the cell draws the same sample and the row count below is
# comparable between runs.
temp_df = flights_df.sample(fraction=0.1, seed=42)

# Confirm redundant aggs for cube with cancelled and cancellation_reason
time_columns = ['year', 'month']
non_time_columns = ['airline', 'origin_airport']
# non_time_columns = ['airline', 'origin_airport', 'cancelled', 'cancellation_reason']


# agg_df = agg_delay_metrics_cube_by_cols(temp_df, time_columns, non_time_columns)
agg_df = agg_delay_metrics_rollup(temp_df, non_time_columns, time_columns)

# agg_df.filter(col('agg_level') == 'cancelled_cancellation_reason')).show(100)
# agg_df \
#     .filter(
#         (col('year').isNull()) &
#         (col('month').isNull()) &
#         (col('airline').isNull()) &
#         (col('origin_airport').isNull()) &
#         (col('cancellation_reason') == lit('A'))
#     ) \
#     .show()

# Number of rows (rollup groups) in the aggregated result
agg_df.count()
# agg_df.show()

                                                                                

17742

In [None]:
# Delay metrics over the `flights` view for every combination produced by
# CUBE(airline, (year, month), day_of_week); (year, month) is treated as a
# single unit. `delay_rate` refers to the two aliases defined just above it
# in the same SELECT list.
grouping_set_sql = """
SELECT
    airline,
    year,
    month,
    day_of_week,
    COUNT(*) AS total_flights,
    COUNT(CASE WHEN is_delayed = 1 THEN 1 END) AS delayed_flights,
    delayed_flights / total_flights AS delay_rate,
    AVG(departure_delay) AS avg_delay_time
FROM flights
GROUP BY
    GROUPING SETS (
        
            --airline,
            --day_of_week,
            --ROLLUP(year, month),
            --GROUPING SETS(airline, ROLLUP(year, month))
            CUBE(airline, (year, month), day_of_week)
        
    )
"""


# Keep the aggregated result under its own name. `temp_df` holds the flight
# sample and is used again by another cell, so it must not be rebound to an
# aggregate here.
grouping_sets_df = spark.sql(grouping_set_sql)

In [24]:
import itertools

def build_grouping_sets(grouping_columns, required_column="cancelled"):
    """Render the column subsets containing ``required_column`` as SQL text.

    Subsets are enumerated by increasing size, preserving the order of
    ``grouping_columns`` inside each subset.

    Args:
        grouping_columns: Column names to combine.
        required_column: Only subsets containing this column are kept.
            Defaults to ``"cancelled"``. Pass ``None`` to keep every subset,
            including the empty one.

    Returns:
        A string such as ``"((cancelled), (cancelled, year))"``. When no
        subset qualifies the result is ``"()"``.
    """
    grouping_sets = [
        subset
        for size in range(len(grouping_columns) + 1)
        for subset in itertools.combinations(grouping_columns, size)
        if required_column is None or required_column in subset
    ]

    # Join the names explicitly instead of post-processing repr() output,
    # which depended on how Python quotes strings and one-element tuples.
    rendered_sets = (
        "(" + ", ".join(map(str, subset)) + ")"
        for subset in grouping_sets
    )
    return "(" + ", ".join(rendered_sets) + ")"

# Render the grouping sets for these three columns and print the text.
grouping_cols = ['cancelled', 'year', 'month']
grouping_sets = build_grouping_sets(grouping_cols)
print(grouping_sets)

((cancelled), (cancelled, year), (cancelled, month), (cancelled, year, month))


In [None]:
# Scratch note, candidate grouping sets: (month), (day_of_week), (month, airport), (airline), (airline, month), (airline)

In [None]:
# Candidate grouping sets, each written as a tuple of column names.
# One-element sets need the trailing comma to be tuples.
grouping_sets = (
    # Time only
    ('year',), ('year', 'month'),

    # Airline and origin airport, each refined by year and then month
    ('airline',), ('airline', 'year'), ('airline', 'year', 'month'),
    ('origin_airport',), ('origin_airport', 'year'), ('origin_airport', 'year', 'month'),

    # Cancellation flag and reason, each refined by year and then month
    ('cancelled',), ('cancelled', 'year'), ('cancelled', 'year', 'month'),
    ('cancellation_reason',), ('cancellation_reason', 'year'), ('cancellation_reason', 'year', 'month'),

    # Day of week on its own
    ('day_of_week',),
)

In [51]:
# Register temp_df under its own view name. Re-using the name `flights`
# would replace the view that was registered from the full joined table.
temp_df.createOrReplaceTempView('flights_sample')

# Row counts for explicit grouping sets. GROUPING_ID records which of
# (year, month, airline) were aggregated away on each output row.
query = """
SELECT 
    year,
    month,
    airline,
    COUNT(*) AS count,
    GROUPING_ID(year, month, airline) AS grouping_id
FROM flights_sample
GROUP BY GROUPING SETS (
    (year),
    (year, month),
    
    (airline), 
    (airline, year), 
    (airline, year, month)
)
"""

cols = ['year', 'month', 'airline']

agg_df = spark.sql(query)
agg_level = get_aggregation_level(cols)
agg_df = agg_df.withColumn('agg_level', agg_level)

# agg_df.show()
# grouping_id 3 is binary 011: month and airline are aggregated away, so
# this shows the year-only rows. Compare against an integer literal because
# grouping_id is numeric.
agg_df.filter('grouping_id = 3').show()



+----+-----+-------+------+-----------+---------+
|year|month|airline| count|grouping_id|agg_level|
+----+-----+-------+------+-----------+---------+
|2015| NULL|   NULL|582343|          3|     year|
+----+-----+-------+------+-----------+---------+



                                                                                