In [None]:
from datetime import datetime, timedelta, timezone
from pyspark.sql import functions as F
from pyspark.sql import Row
import logging

In [None]:
# Notebook-wide logger; every class below reports progress through it.
logger = logging.getLogger("DBUCostReporter")
# Configure the root logger at INFO so the progress messages are emitted.
logging.basicConfig(level=logging.INFO)

In [None]:
# Notebook parameters as (widget name, default value, label shown in the UI).
for _widget_name, _widget_default, _widget_label in (
    ("catalog", "", "CATALOG"),
    ("schema", "", "SCHEMA"),
    ("overlap_days", "3", "Overlap days (min 2)"),
):
    dbutils.widgets.text(_widget_name, _widget_default, _widget_label)

In [None]:
# =======================================================
# DBU Cost Client
# =======================================================
class DBUCostClient:
    """Incrementally loads per-job-run DBU list cost into a target table.

    Each run re-computes a trailing date window of ``system.billing.usage``
    joined to ``system.billing.list_prices``, keeps only usage that carries a
    job run id and belongs to a cluster whose ``cluster_source`` is ``'JOB'``,
    MERGEs the aggregated rows into ``target_table`` and appends one row that
    describes the run to ``audit_table``.

    The class relies on the notebook-level globals ``spark``, ``dbutils`` and
    ``logger``.
    """

    # Value stored in (and filtered on) the audit table's ``table_name`` column.
    AUDIT_TABLE_KEY = "dbspend360_dbu_cost"
    # Look-back, in days, used when the audit table has no usable watermark.
    INITIAL_LOOKBACK_DAYS = 365
    # Temp view through which the MERGE statement reads the increment.
    STAGING_VIEW = "dbu_cost_inc"
    # Upper bound on the error text written to the audit table for failed runs.
    MAX_AUDIT_MESSAGE_CHARS = 2000

    def __init__(self, audit_table: str, target_table: str, overlap_days: int):
        """Store the table names and the re-processing overlap.

        Args:
            audit_table: Fully qualified name of the run-log table.
            target_table: Fully qualified name of the table that is merged into.
            overlap_days: Number of days, counted back from the last successful
                ``end_date`` (inclusive), that every run re-processes.
        """
        self.audit_table = audit_table
        self.target_table = target_table
        self.overlap_days = overlap_days

    def _get_date_window(self):
        """Return the inclusive ``(start_dt, end_dt)`` usage-date window to load.

        ``end_dt`` is today's UTC date. ``start_dt`` is derived from the latest
        ``end_date`` of a SUCCESS row in the audit table, moved back by
        ``overlap_days - 1`` days. When no such watermark exists the window
        starts ``INITIAL_LOOKBACK_DAYS - 1`` days before today.
        """
        today = datetime.now(timezone.utc).date()

        # A global aggregate always returns exactly one row. Its value is None
        # both when there is no SUCCESS row and when every end_date is NULL, so
        # a single Spark job covers the empty case and the NULL case alike.
        last_end_date = (
            spark.table(self.audit_table)
                 .filter(f"table_name = '{self.AUDIT_TABLE_KEY}' AND status = 'SUCCESS'")
                 .agg(F.max("end_date"))
                 .collect()[0][0]
        )

        if last_end_date is None:
            last_end_date = today - timedelta(
                days=self.INITIAL_LOOKBACK_DAYS - self.overlap_days
            )
        elif isinstance(last_end_date, datetime):
            # Defensive: a TIMESTAMP end_date column would come back as a
            # datetime, which cannot be compared with the date used for end_dt.
            last_end_date = last_end_date.date()

        start_dt = last_end_date - timedelta(days=self.overlap_days - 1)
        return start_dt, today

    def _log_run(self, start_dt, end_dt, status, row_count, message=""):
        """Append one row describing this run to the audit table.

        ``insertInto`` resolves columns by position, not by name, so the field
        order below has to match the audit table's column order.
        NOTE(review): the audit table DDL is not visible here - confirm the order.
        """
        run_log_df = spark.createDataFrame([
            Row(
                table_name=self.AUDIT_TABLE_KEY,
                start_date=start_dt,
                end_date=end_dt,
                status=status,
                row_count=int(row_count),
                message=message,
                created_at=datetime.now(timezone.utc)
            )
        ])
        run_log_df.write.mode("append").insertInto(self.audit_table)

    def _build_incremental_df(self, start_dt, end_dt):
        """Build the per (cluster, job, run, usage_date, workspace) cost rows."""
        # 1. ids of job clusters (one row per cluster_id)
        job_cluster_df = (
            spark.table("system.compute.clusters")
                 .filter("cluster_source = 'JOB'")
                 .select("cluster_id")
                 .dropDuplicates(["cluster_id"])
        )

        # 2. usage inside the inclusive usage_date window, priced via list_prices
        usage_df = (
            spark.table("system.billing.usage")
                 .alias("usage")
                 .filter((F.col("usage.usage_date") >= F.lit(start_dt)) &
                         (F.col("usage.usage_date") <= F.lit(end_dt)))
        )
        list_prices_df = spark.table("system.billing.list_prices").alias("list_prices")

        # The price row that applies is the one whose validity interval contains
        # the usage start time; an open-ended price has a NULL price_end_time.
        # NOTE(review): Databricks' sample queries additionally match
        # usage.cloud to list_prices.cloud - confirm whether that predicate is
        # needed here to rule out duplicate price rows per SKU.
        price_match = (
            (F.col("usage.sku_name") == F.col("list_prices.sku_name")) &
            (F.col("usage.usage_start_time") >= F.col("list_prices.price_start_time")) &
            (
                (F.col("usage.usage_start_time") < F.col("list_prices.price_end_time")) |
                F.col("list_prices.price_end_time").isNull()
            )
        )
        priced_df = usage_df.join(list_prices_df, on=price_match, how="left")

        # 3. keep job-run usage only and aggregate; several SKUs for the same
        #    key are collapsed into one sorted "sku1 + sku2 + ..." label
        agg_df = (
            priced_df
            .filter(F.col("usage.usage_metadata")["job_run_id"].isNotNull())
            .groupBy(
                F.col("usage.usage_metadata")["cluster_id"].alias("job_cluster_id"),
                F.col("usage.usage_metadata")["job_id"].alias("job_id"),
                F.col("usage.usage_metadata")["job_run_id"].alias("run_id"),
                F.col("usage.usage_date").alias("usage_date"),
                F.col("usage.workspace_id").alias("workspace_id")
            )
            .agg(
                F.sum(
                    F.col("usage.usage_quantity")
                    * F.col("list_prices.pricing")["default"].cast("double")
                ).alias("databricks_cost"),
                F.concat_ws(
                    " + ",
                    F.array_sort(F.collect_set(F.col("usage.sku_name")))
                ).alias("sku_name_merged")
            )
        )

        # 4. the inner join drops usage whose cluster is not a job cluster
        joined_df = (
            agg_df.join(
                job_cluster_df,
                on=(agg_df["job_cluster_id"] == job_cluster_df["cluster_id"]),
                how="inner"
            )
            .drop("job_cluster_id")
        )

        # 5. final column layout of the increment
        return joined_df.select(
            "cluster_id",
            "job_id",
            "run_id",
            "usage_date",
            "databricks_cost",
            F.lit("USD").alias("currency"),
            F.col("sku_name_merged").alias("sku_name"),
            "workspace_id",
        )

    def _merge_into_target(self, dbu_inc_df):
        """Upsert the increment into the target table via a temp view."""
        dbu_inc_df.createOrReplaceTempView(self.STAGING_VIEW)

        # sku_name is refreshed together with the cost: re-processing the
        # overlap window can add SKUs to a key that already exists, and the
        # merged label would otherwise stay at its first-seen value.
        spark.sql(f"""
            MERGE INTO {self.target_table} AS t
            USING {self.STAGING_VIEW} AS s
            ON  t.cluster_id = s.cluster_id
            AND t.job_id     = s.job_id
            AND t.run_id     = s.run_id
            AND t.usage_date = s.usage_date
            WHEN MATCHED THEN
              UPDATE SET
                t.databricks_cost = s.databricks_cost,
                t.sku_name        = s.sku_name,
                t.updated_at      = current_timestamp()
            WHEN NOT MATCHED THEN
              INSERT (cluster_id, job_id, run_id, usage_date,
                      databricks_cost, currency, sku_name, workspace_id,
                      created_at, updated_at)
              VALUES (s.cluster_id, s.job_id, s.run_id, s.usage_date,
                      s.databricks_cost, s.currency, s.sku_name, s.workspace_id,
                      current_timestamp(), current_timestamp())
        """)

    def compute_and_merge_dbu_cost(self):
        """Run one incremental load and record its outcome in the audit table.

        An invalid window is logged as FAILED and ends the notebook through
        ``dbutils.notebook.exit``. Any exception raised while computing or
        merging is logged as FAILED (best effort) and then re-raised. A
        successful run is logged as SUCCESS with the number of source rows.
        """
        start_dt, end_dt = self._get_date_window()

        if start_dt > end_dt:
            message = f"Invalid date window: start_dt={start_dt} > end_dt={end_dt}."
            logger.error(message)
            self._log_run(start_dt, end_dt, "FAILED", 0, message)
            # Kept outside the try block below so the exit call can never be
            # intercepted by the failure handler.
            dbutils.notebook.exit("FAILED: Invalid date window.")
            # Guard against falling through should exit() return control.
            return

        logger.info(f"Loading DBU cost from {start_dt} to {end_dt}")

        try:
            dbu_inc_df = self._build_incremental_df(start_dt, end_dt)

            # One count serves both as the emptiness check and as the audit
            # figure; it is the number of source rows offered to the MERGE.
            merged_row_count = dbu_inc_df.count()
            if merged_row_count == 0:
                logger.info("No DBU rows after filtering / aggregation.")
            else:
                self._merge_into_target(dbu_inc_df)
        except Exception as exc:
            logger.exception("DBU cost load failed.")
            try:
                self._log_run(
                    start_dt, end_dt, "FAILED", 0,
                    str(exc)[:self.MAX_AUDIT_MESSAGE_CHARS],
                )
            except Exception:
                # Never let an audit-write problem mask the original error.
                logger.exception("Could not record the FAILED run in the audit table.")
            raise

        # Only SUCCESS rows move the watermark read by _get_date_window.
        self._log_run(start_dt, end_dt, "SUCCESS", merged_row_count, "")
        logger.info(f"Merged {merged_row_count} rows into {self.target_table} for {start_dt} -> {end_dt}.")

In [None]:
# =======================================================
# APP
# =======================================================
class DBUCostReporterApp:
    """Reads the notebook widgets and drives a ``DBUCostClient`` run."""

    # Smallest overlap accepted; smaller requests are raised to this value.
    MIN_OVERLAP_DAYS = 2

    @staticmethod
    def _require_identifier(widget_name, raw_value):
        """Return the stripped widget value, or raise if it is empty.

        Raises:
            ValueError: if ``raw_value`` is None, empty or whitespace only.
        """
        value = (raw_value or "").strip()
        if not value:
            raise ValueError(f"Widget '{widget_name}' must not be empty.")
        return value

    @staticmethod
    def _parse_overlap_days(raw_value):
        """Parse the ``overlap_days`` widget text into an int.

        An empty value falls back to ``MIN_OVERLAP_DAYS``. The result is not
        clamped here; the caller does that so it can log a warning.

        Raises:
            ValueError: if the text is not an integer.
        """
        text = (raw_value or "").strip()
        if not text:
            return DBUCostReporterApp.MIN_OVERLAP_DAYS
        try:
            return int(text)
        except ValueError:
            raise ValueError(
                f"Widget 'overlap_days' must be an integer, got {raw_value!r}."
            ) from None

    def __init__(self):
        """Validate the widget values and build the client.

        Raises:
            ValueError: if ``catalog`` or ``schema`` is empty, or
                ``overlap_days`` is not an integer.
        """
        # Fail fast: an empty part would otherwise produce a table name such
        # as "..dbspend360_audit_log" and only break at the first table read.
        catalog = self._require_identifier("catalog", dbutils.widgets.get("catalog"))
        schema = self._require_identifier("schema", dbutils.widgets.get("schema"))
        overlap_days = self._parse_overlap_days(dbutils.widgets.get("overlap_days"))

        if overlap_days < self.MIN_OVERLAP_DAYS:
            logger.warning("overlap_days < 2; forcing to 2 for cost convergence best practice.")
            overlap_days = self.MIN_OVERLAP_DAYS

        audit_table = f"{catalog}.{schema}.dbspend360_audit_log"
        target_table = f"{catalog}.{schema}.dbspend360_dbu_cost"

        self.client = DBUCostClient(
            audit_table=audit_table,
            target_table=target_table,
            overlap_days=overlap_days,
        )

    def run(self):
        """Execute one incremental DBU cost load."""
        self.client.compute_and_merge_dbu_cost()

In [None]:
# =======================================================
# Execute
# =======================================================
# Build the app from the current widget values, then run the DBU cost load.
app = DBUCostReporterApp()
app.run()