In [None]:
from datetime import datetime, timedelta, timezone
from pyspark.sql import functions as F
from pyspark.sql import Row
import logging

In [None]:
# Named logger shared by every cell in this notebook.
logger = logging.getLogger("TotalJobSpendsReporter")
# Root logging at INFO so the progress messages emitted during the run are shown.
logging.basicConfig(level=logging.INFO)

In [None]:
# Notebook parameters as (name, default value, label) triples; each one is
# registered as a text widget in the order listed.
for _widget_args in (
    ("catalog", "", "CATALOG"),
    ("schema", "", "SCHEMA"),
    ("overlap_days", "3", "Overlap days (min 2)"),
):
    dbutils.widgets.text(*_widget_args)

In [None]:
# Read the notebook parameters back from the widgets.
catalog, schema = (dbutils.widgets.get(_key) for _key in ("catalog", "schema"))
# An empty overlap widget falls back to "2" before the integer conversion.
overlap_days = int(dbutils.widgets.get("overlap_days") or "2")

# Enforce the minimum overlap of two days.
if not overlap_days >= 2:
    logger.warning("overlap_days < 2; forcing to 2 for cost convergence best practice.")
    overlap_days = 2

# Fully qualified (catalog.schema.table) names of the tables used by this notebook.
audit_table = "{}.{}.dbspend360_audit_log".format(catalog, schema)
cloud_cost_table = "{}.{}.dbspend360_cloud_cost_explorer".format(catalog, schema)
databricks_cost_table = "{}.{}.dbspend360_dbu_cost".format(catalog, schema)
total_job_spends_table = "{}.{}.dbspend360_total_job_spends".format(catalog, schema)
error_log_table = "{}.{}.dbspend360_error_log".format(catalog, schema)

In [None]:
# =======================================================
# Total Job Spends Client
# =======================================================
class TotalJobSpendsClient:
    """Incrementally builds the total-job-spends table from DBU and cloud cost data.

    Each run resolves a date window from the audit log, inner-joins DBU cost rows
    to cloud cost rows on ``cluster_id`` and date, MERGEs the result into the
    target table, records unmatched rows in the error log and appends a run
    record to the audit log.

    Relies on the notebook globals ``spark``, ``dbutils`` and ``logger``.
    """

    def __init__(
        self,
        audit_table: str,
        cloud_cost_table: str,
        databricks_cost_table: str,
        total_job_spends_table: str,
        error_log_table: str,
        overlap_days: int,
    ):
        """Store the table names and the overlap (in days) used for the date window.

        Args:
            audit_table: Table holding run records (read for the watermark, appended to).
            cloud_cost_table: Source table filtered on ``cost_incurred_date``.
            databricks_cost_table: Source table filtered on ``usage_date``.
            total_job_spends_table: MERGE target.
            error_log_table: Table receiving rows that have no join partner.
            overlap_days: Number of days re-processed before the last successful end date.
        """
        self.audit_table = audit_table
        self.cloud_cost_table = cloud_cost_table
        self.databricks_cost_table = databricks_cost_table
        self.total_job_spends_table = total_job_spends_table
        self.error_log_table = error_log_table
        self.overlap_days = overlap_days

    def _get_date_window(self):
        """Return ``(start_dt, end_dt)`` for this run.

        ``end_dt`` is today's UTC date. ``start_dt`` is the latest ``end_date`` of a
        SUCCESS audit row for this table minus ``overlap_days - 1`` days. When no
        such value exists, the last end date defaults to today minus
        ``365 - overlap_days`` days.
        """
        today = datetime.now(timezone.utc).date()

        # A global aggregate always yields exactly one row; its value is NULL/None
        # when no audit row matches (or every matching end_date is NULL). One Spark
        # job therefore covers both the "first run" and the "has watermark" cases.
        last_end_date = (
            spark.table(self.audit_table)
                 .filter("table_name = 'dbspend360_total_job_spends' AND status = 'SUCCESS'")
                 .agg(F.max("end_date"))
                 .collect()[0][0]
        )

        if last_end_date is None:
            last_end_date = today - timedelta(days=365 - self.overlap_days)

        start_dt = last_end_date - timedelta(days=self.overlap_days - 1)
        return start_dt, today

    def _log_run(self, start_dt, end_dt, status, row_count, message=""):
        """Append one run record for this table to the audit log.

        Args:
            start_dt: Start of the processed window.
            end_dt: End of the processed window.
            status: Run status; this class writes "SUCCESS" or "FAILED".
            row_count: Number of rows produced by the run (converted with ``int``).
            message: Optional free-text detail.
        """
        # NOTE(review): insertInto resolves columns by position, so the field order
        # below presumably mirrors the audit table's column order - confirm against DDL.
        run_log_df = spark.createDataFrame([
            Row(
                table_name="dbspend360_total_job_spends",
                start_date=start_dt,
                end_date=end_dt,
                status=status,
                row_count=int(row_count),
                message=message,
                created_at=datetime.now(timezone.utc)
            )
        ])
        run_log_df.write.mode("append").insertInto(self.audit_table)

    def _log_errors(self, dbu_df, cloud_df):
        """Append rows without a join partner to the error log.

        Two anti-joins on ``cluster_id`` and date are evaluated: DBU rows with no
        cloud cost row, and cloud cost rows with no DBU row. Each non-empty result
        is written to ``error_log_table`` with the full source row serialized as
        JSON in ``raw_record``.

        Args:
            dbu_df: DBU cost rows for the window.
            cloud_df: Cloud cost rows for the window.
        """
        # DBU without cloud cost
        dbu_only = (
            dbu_df.alias("d")
            .join(
                cloud_df.alias("a"),
                on=(
                    (F.col("d.cluster_id") == F.col("a.cluster_id")) &
                    (F.col("d.usage_date") == F.col("a.cost_incurred_date"))
                ),
                how="left_anti"
            )
        )

        if not dbu_only.isEmpty():
            dbu_err = (
                dbu_only
                .select(
                    F.lit("DBR_DBU").alias("source_system"),
                    F.lit("NO_MATCH_cloud_COST").alias("error_type"),
                    F.col("d.cluster_id").alias("cluster_id"),
                    "job_id",
                    "run_id",
                    "usage_date",
                    F.col("d.currency").alias("currency"),
                    F.lit("No matching cloud VM cost row for this DBU usage").alias("error_detail"),
                    F.to_json(F.struct("d.*")).alias("raw_record"),
                )
                .withColumn("created_at", F.lit(datetime.now(timezone.utc)))
            )

            dbu_err.write.mode("append").insertInto(self.error_log_table)

        # cloud without DBU cost
        cloud_only = (
            cloud_df.alias("a")
            .join(
                dbu_df.alias("d"),
                on=(
                    (F.col("a.cluster_id") == F.col("d.cluster_id")) &
                    (F.col("a.cost_incurred_date") == F.col("d.usage_date"))
                ),
                how="left_anti"
            )
        )

        if not cloud_only.isEmpty():
            cloud_err = (
                cloud_only
                .select(
                    F.lit("cloud_COST").alias("source_system"),
                    F.lit("NO_MATCH_DBR_DBU").alias("error_type"),
                    F.col("a.cluster_id").alias("cluster_id"),
                    # job_id, run_id and usage_date are written as typed NULLs for
                    # cloud-only rows; the cloud row itself is kept in raw_record.
                    F.lit(None).cast("string").alias("job_id"),
                    F.lit(None).cast("string").alias("run_id"),
                    F.lit(None).cast("date").alias("usage_date"),
                    F.col("a.currency").alias("currency"),
                    F.lit("No matching DBR DBU cost row for this cloud VM cost").alias("error_detail"),
                    F.to_json(F.struct("a.*")).alias("raw_record"),
                )
                .withColumn("created_at", F.lit(datetime.now(timezone.utc)))
            )

            cloud_err.write.mode("append").insertInto(self.error_log_table)

    def build_total_job_spends(self):
        """Run one incremental build of the total-job-spends table.

        Resolves the date window, rejects an inverted window (logged as FAILED,
        then the notebook is asked to exit), and otherwise merges the window into
        the target table. Any exception raised while processing the window is
        recorded as a FAILED audit row (best effort) and then re-raised.
        """
        start_dt, end_dt = self._get_date_window()

        if start_dt > end_dt:
            message = f"Invalid date window: start_dt={start_dt} > end_dt={end_dt}."
            logger.error(message)
            self._log_run(start_dt, end_dt, "FAILED", 0, message)
            dbutils.notebook.exit("FAILED: Invalid date window.")
            # Never continue with an inverted window, even if exit() returns.
            return

        logger.info(f"Building dbspend360_total_job_spends for {start_dt} → {end_dt}")

        try:
            self._merge_window(start_dt, end_dt)
        except Exception as exc:
            logger.exception("Build of dbspend360_total_job_spends failed.")
            try:
                self._log_run(start_dt, end_dt, "FAILED", 0, str(exc)[:1000])
            except Exception:
                # Keep the original failure as the one that propagates.
                logger.exception("Could not write the FAILED audit row.")
            raise

    def _merge_window(self, start_dt, end_dt):
        """Join, merge, error-log and audit-log the rows of one date window."""
        cloud_df = (
            spark.table(self.cloud_cost_table)
                 .alias("cc")
                 .filter(
                     (F.col("cost_incurred_date") >= F.lit(start_dt)) &
                     (F.col("cost_incurred_date") <= F.lit(end_dt))
                 )
        )
        dbu_df = (
            spark.table(self.databricks_cost_table)
                 .alias("dbu")
                 .filter(
                     (F.col("usage_date") >= F.lit(start_dt)) &
                     (F.col("usage_date") <= F.lit(end_dt))
                 )
        )

        if dbu_df.limit(1).count() == 0:
            logger.info("No DBU rows in this date window; nothing to join.")
            self._log_run(start_dt, end_dt, "SUCCESS", 0, "No DBU data in window")
            return

        # NOTE(review): assumes at most one cloud cost row per (cluster_id, date);
        # several would fan out the join and make MERGE source rows ambiguous - confirm.
        joined = dbu_df.join(
            cloud_df,
            on=(
                (dbu_df["cluster_id"] == cloud_df["cluster_id"]) &
                (dbu_df["usage_date"] == cloud_df["cost_incurred_date"])
            ),
            how="inner"
        )

        # Prefer the DBU currency; fall back to the cloud currency when it is NULL.
        joined = joined.withColumn(
            "final_currency",
            F.coalesce(F.col("dbu.currency"), F.col("cc.currency"))
        )

        joined = joined.withColumn(
            "cloud_cost",
            F.col("cc.cloud_cost")
        )

        final_df = (
            joined
            .select(
                F.col("dbu.cluster_id").alias("cluster_id"),
                "job_id",
                "run_id",
                "usage_date",
                F.col("cloud_cost").alias("cloud_cost"),
                F.col("databricks_cost"),
                F.col("final_currency").alias("currency")
            )
            .withColumn("total_cost", F.col("cloud_cost") + F.col("databricks_cost"))
            .withColumn("created_at", F.current_timestamp())
            .withColumn("updated_at", F.current_timestamp())
        )

        final_df.createOrReplaceTempView("job_spends_inc")
        row_count = final_df.count()

        # job_id / run_id are compared with the null-safe operator (<=>): with a
        # plain "=", rows whose job_id or run_id is NULL never match and would be
        # inserted again on every run that re-processes the overlap days.
        # cluster_id and usage_date cannot be NULL here because they passed the
        # inner equi-join above. currency is refreshed on match like the costs.
        spark.sql(f"""
        MERGE INTO {self.total_job_spends_table} AS t
        USING job_spends_inc AS s
        ON  t.cluster_id =   s.cluster_id
        AND t.job_id     <=> s.job_id
        AND t.run_id     <=> s.run_id
        AND t.usage_date =   s.usage_date
        WHEN MATCHED THEN
          UPDATE SET
            t.cloud_cost      = s.cloud_cost,
            t.databricks_cost = s.databricks_cost,
            t.currency        = s.currency,
            t.total_cost      = s.total_cost,
            t.updated_at      = current_timestamp()
        WHEN NOT MATCHED THEN
          INSERT (cluster_id, job_id, run_id, usage_date,
                  cloud_cost, databricks_cost, currency,
                  total_cost, created_at, updated_at)
          VALUES (s.cluster_id, s.job_id, s.run_id, s.usage_date,
                  s.cloud_cost, s.databricks_cost, s.currency,
                  s.total_cost, current_timestamp(), current_timestamp())
        """)

        # Error logging (DBU-only and cloud-only)
        self._log_errors(dbu_df, cloud_df)

        # Run log
        self._log_run(start_dt, end_dt, "SUCCESS", row_count, "")
        logger.info(f"Merged {row_count} rows into {self.total_job_spends_table} for {start_dt} → {end_dt}.")

In [None]:
# =======================================================
# APP
# =======================================================
class TotalJobSpendsApp:
    """Notebook entry point: reads the widgets and wires up a TotalJobSpendsClient."""

    def __init__(self):
        """Read ``catalog``, ``schema`` and ``overlap_days`` from the widgets.

        Raises:
            ValueError: If ``catalog`` or ``schema`` is empty, or if
                ``overlap_days`` is not an integer.
        """
        catalog = (dbutils.widgets.get("catalog") or "").strip()
        schema = (dbutils.widgets.get("schema") or "").strip()

        # Both widgets default to "", which would produce names such as
        # "..dbspend360_audit_log" and only fail later inside Spark.
        if not catalog or not schema:
            raise ValueError(
                "Widgets 'catalog' and 'schema' must both be set "
                f"(got catalog={catalog!r}, schema={schema!r})."
            )

        # An empty overlap widget falls back to "2".
        overlap_days = int(dbutils.widgets.get("overlap_days") or "2")

        if overlap_days < 2:
            logger.warning("overlap_days < 2; forcing to 2 for cost convergence best practice.")
            overlap_days = 2

        audit_table            = f"{catalog}.{schema}.dbspend360_audit_log"
        cloud_cost_table       = f"{catalog}.{schema}.dbspend360_cloud_cost_explorer"
        databricks_cost_table  = f"{catalog}.{schema}.dbspend360_dbu_cost"
        total_job_spends_table = f"{catalog}.{schema}.dbspend360_total_job_spends"
        error_log_table        = f"{catalog}.{schema}.dbspend360_error_log"

        self.client = TotalJobSpendsClient(
            audit_table=audit_table,
            cloud_cost_table=cloud_cost_table,
            databricks_cost_table=databricks_cost_table,
            total_job_spends_table=total_job_spends_table,
            error_log_table=error_log_table,
            overlap_days=overlap_days,
        )

    def run(self):
        """Execute one incremental build through the client."""
        self.client.build_total_job_spends()

In [None]:
# =======================================================
# Execute
# =======================================================
# Build the app from the current widget values, then run one incremental build.
app = TotalJobSpendsApp()
app.run()