In [0]:
from datetime import datetime, timedelta
from pyspark.sql.functions import col, when, expr, trunc, date_format, weekofyear
from pyspark.sql.functions import last_day, dayofweek, date_sub
from pyspark.sql.types import DateType
from pyspark.sql import functions as F
import json

def last_week_apportion_spark(df, date_col_name, kpi_col_list, work_days):
    """Split weekly KPI values whose week straddles a month end between the two months.

    Each row is treated as one week identified by the date in ``date_col_name``.
    NOTE(review): the Daily->Weekly path in this file labels weeks by the Sunday
    on or before each date; confirm the convention of the incoming weekly data,
    because ``day_diff`` below counts calendar days from that date.

    For every row, ``day_diff`` is the number of calendar days from the week date
    through the month's last working day (inclusive). When ``day_diff`` is smaller
    than ``work_days``, the share ``(work_days - max(day_diff, 0)) / work_days``
    of every KPI is removed from the row. It is re-emitted as an extra row dated
    one month later (``add_months(date, 1)``), so that a later monthly
    aggregation attributes it to the following month.

    Args:
        df: Spark DataFrame with one row per (entity, week).
        date_col_name: Name of the week date column.
        kpi_col_list: Numeric columns to apportion. All other columns are copied
            unchanged onto the carried-over rows.
        work_days: Working days per week. ``5`` moves a month end that falls on
            a Saturday/Sunday back to the preceding Friday. Any other value uses
            the calendar last day of the month.

    Returns:
        ``df`` with apportioned KPI values, plus one carried-over row for each
        input row whose week spills into the next month. Rows that do not spill
        produce no extra row, so non-sum aggregations (count, average, min,
        product, ...) are not distorted by zero-valued duplicates.

    Raises:
        ValueError: If ``work_days`` is not positive.
    """
    if work_days <= 0:
        raise ValueError(f"work_days must be a positive integer, got {work_days!r}")

    # Temporary column names chosen so they cannot clobber real input columns
    # (e.g. a user column called "year" or "month").
    last_day_col = "__apportion_last_day_of_month"
    last_work_col = "__apportion_last_working_date"
    day_diff_col = "__apportion_day_diff"
    carry_col = "__apportion_carry_fraction"

    df = df.withColumn(last_day_col, F.last_day(F.col(date_col_name)))

    if work_days == 5:
        # dayofweek: 1 = Sunday ... 7 = Saturday. A weekend month end is pulled
        # back to the preceding Friday.
        month_end_dow = F.dayofweek(F.col(last_day_col))
        last_working = (
            F.when(month_end_dow == 7, F.date_sub(F.col(last_day_col), 1))  # Saturday
             .when(month_end_dow == 1, F.date_sub(F.col(last_day_col), 2))  # Sunday
             .otherwise(F.col(last_day_col))
        )
    else:
        last_working = F.col(last_day_col)
    df = df.withColumn(last_work_col, last_working)

    # Calendar days from the week date through the last working day, inclusive.
    df = df.withColumn(
        day_diff_col, F.datediff(F.col(last_work_col), F.col(date_col_name)) + 1
    )

    # Share of the week that belongs to the next month. day_diff <= 0 means the
    # week starts after the month's last working day, so the whole week moves.
    # Clamping prevents a share above 1 and a negative remainder.
    spills = F.col(day_diff_col) < work_days
    df = df.withColumn(
        carry_col,
        F.when(
            spills,
            (work_days - F.greatest(F.col(day_diff_col), F.lit(0))) / work_days,
        ).otherwise(F.lit(0.0)),
    )

    # Only rows that actually spill over produce a next-month row.
    carried_rows = df.filter(spills)
    for kpi in kpi_col_list:
        carried_rows = carried_rows.withColumn(kpi, F.col(kpi) * F.col(carry_col))
        df = df.withColumn(kpi, F.col(kpi) - F.col(kpi) * F.col(carry_col))

    carried_rows = carried_rows.withColumn(
        date_col_name, F.add_months(F.col(date_col_name), 1)
    )

    combined_df = df.unionByName(carried_rows)
    return combined_df.drop(last_day_col, last_work_col, day_diff_col, carry_col)


def modify_granularity_spark(
    df,
    geo_column,
    date_column,
    granularity_level_df,
    granularity_level_user_input,
    work_days,
    numerical_config_dict,
    categorical_config_dict
):
    """Re-aggregate ``df`` from its current time granularity to a coarser one.

    Supported transitions are Daily -> Weekly, Daily -> Monthly and
    Weekly -> Monthly. When both granularities are equal, the geo, date and
    configured columns are selected without aggregation.

    Args:
        df: Input Spark DataFrame.
        geo_column: Column used (with the period column) as the group-by key.
        date_column: Date column of ``df``.
        granularity_level_df: Current granularity ("Daily", "Weekly", ...).
        granularity_level_user_input: Requested granularity.
        work_days: Working days per week. Only used for Weekly -> Monthly
            apportioning of weeks that straddle a month end.
        numerical_config_dict: ``{column: op}`` with op in "sum", "average",
            "min", "max", "product". Columns with any other op are omitted.
        categorical_config_dict: ``{column: op}`` with op "count" (output
            column ``<column>_count``) or "distinct count" (output column
            ``<column>_count_distinct``). Other ops (e.g. pivot) are ignored.

    Returns:
        Tuple ``(result_df, date_col_name)``. ``date_col_name`` is
        ``date_column`` when nothing changes, otherwise "week_date" or
        "month_date" (both DateType).

    Raises:
        ValueError: For an unsupported granularity transition.
    """
    simple_numeric_ops = {
        "sum": F.sum,
        "average": F.avg,
        "min": F.min,
        "max": F.max,
    }

    def get_agg_exprs(numerical_dict, categorical_dict):
        """Build the aggregation expressions for the configured columns."""
        agg_exprs = []

        for col_name, operation in numerical_dict.items():
            if operation in simple_numeric_ops:
                agg_exprs.append(simple_numeric_ops[operation](F.col(col_name)).alias(col_name))
            elif operation == "product":
                # Backtick-quote (escaping embedded backticks) so column names
                # with spaces or special characters survive the SQL expression.
                quoted = "`" + col_name.replace("`", "``") + "`"
                agg_exprs.append(
                    F.expr(f"aggregate(collect_list({quoted}), 1D, (acc, x) -> acc * x)").alias(col_name)
                )
            # Any other operation is skipped, so that column is not in the output.

        for col_name, operation in categorical_dict.items():
            if operation == "count":
                agg_exprs.append(F.count(F.col(col_name)).alias(f"{col_name}_count"))
            elif operation == "distinct count":
                agg_exprs.append(F.countDistinct(F.col(col_name)).alias(f"{col_name}_count_distinct"))
            # Pivot is ignored for now

        return agg_exprs

    def aggregate_by(frame, period_col):
        """Group by (geo, period), apply the configured aggregations and sort."""
        return (
            frame.groupBy(geo_column, period_col)
            .agg(*get_agg_exprs(numerical_config_dict, categorical_config_dict))
            .orderBy(geo_column, period_col)
        )

    if granularity_level_df == granularity_level_user_input:
        selected_cols = [geo_column, date_column] + list(numerical_config_dict.keys()) + list(categorical_config_dict.keys())
        return df.select(*selected_cols), date_column

    if granularity_level_df == "Daily" and granularity_level_user_input == "Weekly":
        # dayofweek: 1 = Sunday, so this yields the Sunday on or before each date.
        df = df.withColumn("week_date", F.date_sub(F.col(date_column), F.dayofweek(F.col(date_column)) - 1))
        return aggregate_by(df, "week_date"), "week_date"

    if granularity_level_df == "Daily" and granularity_level_user_input == "Monthly":
        df = df.withColumn("month_date", F.trunc(F.col(date_column), "month"))
        return aggregate_by(df, "month_date"), "month_date"

    if granularity_level_df == "Weekly" and granularity_level_user_input == "Monthly":
        # Move the part of a month-end week that belongs to the next month
        # before grouping. NOTE(review): carried-over rows also count towards
        # categorical "count" aggregations.
        df = last_week_apportion_spark(df, date_column, list(numerical_config_dict.keys()), work_days)
        df = df.withColumn("month_date", F.trunc(F.col(date_column), "month"))
        return aggregate_by(df, "month_date"), "month_date"

    raise ValueError("Unsupported granularity transformation")


In [0]:
# Set Spark config
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# NOTE(review): Databricks documents that widgets removed in a cell cannot be
# re-created in the same cell; confirm this works when run interactively.
dbutils.widgets.removeAll()


# Get widget inputs
dbutils.widgets.text("input_path", "")
dbutils.widgets.text("output_path", "")
dbutils.widgets.text("geo_col", "")
dbutils.widgets.text("date_col", "")
dbutils.widgets.text("granularity_level_df", "")
dbutils.widgets.text("granularity_level_user_input", "")
dbutils.widgets.text("work_days", "")
dbutils.widgets.text("numerical_config_dict", "{}")
dbutils.widgets.text("categorical_config_dict", "{}")
# Secret scope / secret name that hold the mmixstorage account key.
dbutils.widgets.text("storage_key_secret_scope", "")
dbutils.widgets.text("storage_key_secret_name", "")

input_path = dbutils.widgets.get("input_path")
output_path = dbutils.widgets.get("output_path")
geo_col = dbutils.widgets.get("geo_col")
date_col = dbutils.widgets.get("date_col")
granularity_level_df = dbutils.widgets.get("granularity_level_df")
granularity_level_user_input = dbutils.widgets.get("granularity_level_user_input")
work_days = int(dbutils.widgets.get("work_days"))
numerical_config_dict = json.loads(dbutils.widgets.get("numerical_config_dict"))
categorical_config_dict = json.loads(dbutils.widgets.get("categorical_config_dict"))

# The storage account key must never be committed to source control; read it
# from a Databricks secret scope. If no scope is given, the key is expected to
# be configured at the cluster level instead.
# NOTE(review): the key that was previously hard-coded here is exposed and
# should be rotated.
storage_key_secret_scope = dbutils.widgets.get("storage_key_secret_scope")
storage_key_secret_name = dbutils.widgets.get("storage_key_secret_name")
if storage_key_secret_scope and storage_key_secret_name:
    spark.conf.set(
        "fs.azure.account.key.mmixstorage.blob.core.windows.net",
        dbutils.secrets.get(scope=storage_key_secret_scope, key=storage_key_secret_name),
    )

base_blob_url = "wasbs://pre-processing@mmixstorage.blob.core.windows.net"

# HTTPS -> WASBS path conversion for the pre-processing container.
def convert_to_wasbs(url):
    """Return ``url`` with the container's HTTPS root swapped for the wasbs:// root.

    Every occurrence of the HTTPS root is replaced by ``base_blob_url`` followed
    by "/". Strings that do not contain the HTTPS root come back unchanged.
    """
    https_root = "https://mmixstorage.blob.core.windows.net/pre-processing/"
    wasbs_root = "{}/".format(base_blob_url)
    return url.replace(https_root, wasbs_root)

# Translate the widget-supplied HTTPS URLs into wasbs:// paths (input first, then output).
input_path, output_path = (convert_to_wasbs(url) for url in (input_path, output_path))

# Load the CSV using its header row and Spark's type inference.
df = spark.read.options(header=True, inferSchema=True).csv(input_path)

# NOTE: if the date column is not inferred as a date/timestamp (e.g. MM/dd/yyyy
# strings), parse it first, e.g. F.to_date(F.col(date_col), "MM/dd/yyyy").
df = df.where(F.col(date_col).isNotNull())
df.printSchema()

# Re-aggregate to the requested granularity.
modified_granular_df, new_date_col = modify_granularity_spark(
    df=df,
    geo_column=geo_col,
    date_column=date_col,
    granularity_level_df=granularity_level_df,
    granularity_level_user_input=granularity_level_user_input,
    work_days=work_days,
    numerical_config_dict=numerical_config_dict,
    categorical_config_dict=categorical_config_dict,
)

# Show the result, then write it as a single CSV part file with a header.
display(modified_granular_df)
modified_granular_df.coalesce(1).write.csv(output_path, mode="overwrite", header=True)