In [None]:
%pip install azure-identity azure-mgmt-costmanagement azure.mgmt.costmanagement
dbutils.library.restartPython()

In [None]:
from datetime import datetime, timezone, timedelta
from pyspark.sql import functions as F
from pyspark.sql import Row
import time
import json
import requests
import logging
from azure.identity import ClientSecretCredential
from azure.mgmt.costmanagement import CostManagementClient
from azure.mgmt.costmanagement.models import (
    QueryDefinition,
    QueryDataset,
    QueryTimePeriod,
    QueryAggregation,
    QueryGrouping,
    ExportType,
    TimeframeType,
)

In [None]:
# Notebook parameters, registered as text widgets: (widget name, default value).
# Registration order matches the order of the tuple below.
_WIDGET_DEFAULTS = (
    ("catalog", ""),
    ("schema", ""),
    ("overlap_days", "3"),
    ("subscription_id", ""),
    ("scope", ""),
)
for _widget_name, _widget_default in _WIDGET_DEFAULTS:
    dbutils.widgets.text(_widget_name, _widget_default)

In [None]:
# Read and validate the notebook parameters once, so a misconfigured run fails
# here with a clear message instead of later with an invalid table name.
catalog = dbutils.widgets.get("catalog").strip()
schema = dbutils.widgets.get("schema").strip()
if not catalog or not schema:
    raise ValueError("Widgets 'catalog' and 'schema' must both be set.")

# An empty/blank widget falls back to the default overlap of 3 days.
_overlap_raw = (dbutils.widgets.get("overlap_days") or "").strip() or "3"
try:
    overlap_days = int(_overlap_raw)
except ValueError:
    raise ValueError(
        f"Widget 'overlap_days' must be an integer, got {_overlap_raw!r}."
    ) from None
# A negative overlap would start the next window after the previous end date
# and silently leave days unloaded.
if overlap_days < 0:
    raise ValueError(f"Widget 'overlap_days' must be >= 0, got {overlap_days}.")

In [None]:
# Quieten verbose library loggers: Azure SDK loggers only emit WARNING and
# above, py4j only ERROR and above.
_LOGGER_LEVELS = {
    "azure": logging.WARNING,
    "azure.core": logging.WARNING,
    "azure.identity": logging.WARNING,
    "py4j": logging.ERROR,
}
for _logger_name, _logger_level in _LOGGER_LEVELS.items():
    logging.getLogger(_logger_name).setLevel(_logger_level)

In [None]:
# Fully qualified (catalog.schema.table) names of the audit log and of the
# table that receives the merged cloud cost rows.
audit_table = ".".join((catalog, schema, "dbspend360_audit_log"))
target_table = ".".join((catalog, schema, "dbspend360_cloud_cost_explorer"))

In [None]:
class AzureCostClient:
    """Queries Azure Cost Management for daily cost grouped by a resource tag.

    The requested date range is split into chunks of at most
    ``max_chunk_days`` days. Each chunk is queried through the SDK client,
    follow-up pages are fetched with raw POST requests against ``nextLink``,
    and the rows of all chunks are returned as one Spark DataFrame.
    """

    # Token scope requested for the raw REST calls that follow ``nextLink``.
    _ARM_TOKEN_SCOPE = "https://management.azure.com/.default"

    # Response headers consulted (in this order) for the wait time on HTTP 429.
    _RETRY_AFTER_HEADERS = (
        "x-ms-ratelimit-microsoft.costmanagement-qpu-retry-after",
        "x-ms-ratelimit-microsoft.costmanagement-entity-retry-after",
        "x-ms-ratelimit-microsoft.costmanagement-tenant-retry-after",
        "x-ms-ratelimit-microsoft.costmanagement-client-retry-after",
        "Retry-After",
    )

    # Wait used when a 429 response carries no usable retry header.
    DEFAULT_THROTTLE_WAIT_SEC = 30
    # Upper bound on consecutive 429 responses tolerated for a single page.
    MAX_PAGE_THROTTLE_RETRIES = 10
    # Timeout (seconds) for each raw nextLink request so a stalled connection
    # cannot hang the notebook indefinitely.
    REQUEST_TIMEOUT_SEC = 300

    def __init__(self, subscription_id, tenant_id, client_id, client_secret):
        """Create the credential and SDK client for one subscription.

        Raises:
            ValueError: if ``subscription_id`` is empty.
        """
        if not subscription_id:
            raise ValueError("subscription_id must be a non-empty string.")
        self.subscription_id = subscription_id

        self.credential = ClientSecretCredential(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret
        )

        self.client = CostManagementClient(self.credential)
        self.scope = f"/subscriptions/{self.subscription_id}"
        self.max_chunk_days = 10
        self.max_retries = 3

    # -------- Public API --------
    def group_by_job_clusterid_daily(
        self,
        start_date: datetime,
        end_date: datetime,
        tag_name: str = "clusterid",
    ):
        """Entry point: query every chunk of the range and union the results.

        Args:
            start_date: first instant of the range (converted to UTC).
            end_date: last instant of the range (converted to UTC).
            tag_name: tag key the cost is grouped by.

        Returns:
            A Spark DataFrame with the columns produced by ``_rows_to_df``,
            or ``None`` when no chunk returned any row.

        Raises:
            Whatever the last failed attempt of a chunk query raised.
        """
        start_utc, end_utc = self._to_utc(start_date, end_date)

        chunks = self._build_chunks(start_utc, end_utc, self.max_chunk_days)

        all_chunk_dfs = []
        last_index = len(chunks) - 1
        for index, (chunk_start, chunk_end) in enumerate(chunks):
            print(f"Querying chunk {chunk_start} → {chunk_end}")
            df = self._query_with_retries(chunk_start, chunk_end, tag_name)
            if df is not None and df.limit(1).count() > 0:
                all_chunk_dfs.append(df)
            # small pause between chunks to avoid bursty usage; pointless
            # after the final chunk
            if index < last_index:
                time.sleep(5)

        if not all_chunk_dfs:
            return None

        # Union all chunks
        result_df = all_chunk_dfs[0]
        for df in all_chunk_dfs[1:]:
            result_df = result_df.unionByName(df, allowMissingColumns=True)

        return result_df

    # -------- Helpers: date & chunks --------
    def _to_utc(self, start_date: datetime, end_date: datetime):
        """Convert both bounds to UTC.

        Note: ``datetime.astimezone`` interprets a naive datetime as system
        local time, so callers should pass timezone-aware values.
        """
        start_utc = start_date.astimezone(timezone.utc)
        end_utc = end_date.astimezone(timezone.utc)
        return start_utc, end_utc

    def _build_chunks(self, start_utc: datetime, end_utc: datetime, max_days: int):
        """Return list of (chunk_start, chunk_end) in UTC.

        Each chunk spans at most ``max_days`` calendar days; the next chunk
        starts one day after the previous chunk's end. An empty list is
        returned when ``start_utc`` is after ``end_utc``.

        NOTE(review): intermediate chunk ends keep the start's time of day,
        so this relies on the API treating ``to`` as an inclusive day for
        daily granularity - confirm against the Cost Management docs.

        Raises:
            ValueError: if ``max_days`` is less than 1 (the loop could not
                advance and would never terminate).
        """
        if max_days < 1:
            raise ValueError(f"max_days must be >= 1, got {max_days}.")

        chunks = []
        current = start_utc
        while current <= end_utc:
            chunk_end = min(current + timedelta(days=max_days - 1), end_utc)
            chunks.append((current, chunk_end))
            current = chunk_end + timedelta(days=1)
        return chunks

    # -------- Helpers: query construction --------
    def _build_dataset(self, tag_name: str):
        """Build the SDK dataset: daily summed cost grouped by ``tag_name``."""
        return QueryDataset(
            granularity="Daily",
            aggregation={"totalCost": QueryAggregation(name="Cost", function="Sum")},
            grouping=[QueryGrouping(type="TagKey", name=tag_name)],
        )

    def _build_query_definition(self, start_utc: datetime, end_utc: datetime, dataset):
        """Build the SDK query (actual cost, custom time period) for one chunk."""
        return QueryDefinition(
            type=ExportType.ACTUAL_COST,
            timeframe=TimeframeType.CUSTOM,
            time_period=QueryTimePeriod(from_property=start_utc, to=end_utc),
            dataset=dataset,
        )

    def _build_query_body_json(self, start_utc: datetime, end_utc: datetime, tag_name: str):
        """Return the same query as a JSON string.

        The raw nextLink POST requests send this body because they bypass
        the SDK and its model serialization.
        """
        body = {
            "type": "ActualCost",
            "timeframe": "Custom",
            "timePeriod": {
                "from": start_utc.isoformat(),
                "to": end_utc.isoformat(),
            },
            "dataset": {
                "granularity": "Daily",
                "aggregation": {
                    "totalCost": {
                        "name": "Cost",
                        "function": "Sum",
                    }
                },
                "grouping": [
                    {"type": "TagKey", "name": tag_name}
                ],
            },
        }
        return json.dumps(body)

    # -------- Core call with retries --------
    @staticmethod
    def _is_throttled(exc):
        """Return True when ``exc`` looks like an HTTP 429 response.

        Checks a ``status_code`` attribute on the exception or on its
        ``response`` attribute first, then falls back to matching the
        exception text.
        """
        if getattr(exc, "status_code", None) == 429:
            return True
        response = getattr(exc, "response", None)
        if getattr(response, "status_code", None) == 429:
            return True
        text = str(exc)
        return "429" in text or "Too many requests" in text

    def _query_with_retries(self, start_utc, end_utc, tag_name):
        """Run one chunk query, retrying failed attempts with a backoff.

        At least one attempt is always made, even if ``max_retries`` is set
        to zero or less. Throttling waits ``30 * attempt`` seconds, any other
        error waits ``2 ** attempt`` seconds.

        Returns:
            The chunk's DataFrame, or ``None`` when the chunk has no rows.

        Raises:
            The exception of the final failed attempt.
        """
        dataset = self._build_dataset(tag_name)
        query = self._build_query_definition(start_utc, end_utc, dataset)
        query_json = self._build_query_body_json(start_utc, end_utc, tag_name)

        total_attempts = max(1, self.max_retries)
        for attempt in range(1, total_attempts + 1):
            try:
                return self._execute_query(query, query_json)
            except Exception as exc:
                if attempt >= total_attempts:
                    raise

                if self._is_throttled(exc):
                    # back off more aggressively on 429
                    wait_sec = 30 * attempt
                    print(f"Rate limited (429) on main query, waiting {wait_sec}s (attempt {attempt})...")
                else:
                    wait_sec = 2 ** attempt
                    print(f"Error on main query, waiting {wait_sec}s (attempt {attempt})...")
                time.sleep(wait_sec)

    # -------- Single query + pagination --------
    def _execute_query(self, query, query_json: str):
        """Execute one query and collect the rows of every page.

        Returns:
            A DataFrame of all rows, or ``None`` when the service returned no
            result or no rows at all.
        """
        result = self.client.query.usage(self.scope, parameters=query)
        if result is None:
            # No result object at all (e.g. an empty response) - nothing to load.
            return None

        all_rows = list(result.rows or [])
        next_link = getattr(result, "next_link", None)

        while next_link:
            # Ask the credential for a token per page rather than once up
            # front: throttling waits can make pagination long-running.
            token = self.credential.get_token(self._ARM_TOKEN_SCOPE).token
            next_link, page_rows = self._fetch_next_page(next_link, token, query_json)
            all_rows.extend(page_rows or [])
            if next_link:
                time.sleep(2)

        # Decide emptiness only after pagination so rows on later pages are
        # not dropped when the first page happens to be empty.
        if not all_rows:
            return None

        return self._rows_to_df(all_rows)

    @staticmethod
    def _parse_retry_after(value, default):
        """Turn a retry header value into whole seconds to wait.

        Fractional values are rounded up, negative values clamp to 0, and
        anything that is missing or not a finite number yields ``default``.
        """
        if value is None:
            return default
        try:
            seconds = float(str(value).strip())
            if seconds <= 0:
                return 0
            whole = int(seconds)
        except (TypeError, ValueError, OverflowError):
            return default
        return whole if whole >= seconds else whole + 1

    def _fetch_next_page(self, next_link: str, token: str, query_json: str):
        """POST the query to ``next_link`` and return ``(new_next_link, rows)``.

        HTTP 429 responses are retried after the wait advertised by the
        response headers, at most ``MAX_PAGE_THROTTLE_RETRIES`` times.

        Raises:
            requests.HTTPError: for a non-2xx response, including a 429 once
                the throttle retries are exhausted.
            requests.Timeout: if a request exceeds ``REQUEST_TIMEOUT_SEC``.
        """
        throttle_count = 0
        while True:
            resp = requests.post(
                next_link,
                headers={
                    "Authorization": f"Bearer {token}",
                    "Content-Type": "application/json",
                },
                data=query_json,
                timeout=self.REQUEST_TIMEOUT_SEC,
            )

            if resp.status_code == 429 and throttle_count < self.MAX_PAGE_THROTTLE_RETRIES:
                throttle_count += 1
                # Cost Management exposes specific retry headers; take the
                # first one that carries a non-empty value.
                retry_after = next(
                    (
                        resp.headers.get(name)
                        for name in self._RETRY_AFTER_HEADERS
                        if resp.headers.get(name)
                    ),
                    None,
                )
                wait_sec = self._parse_retry_after(
                    retry_after, self.DEFAULT_THROTTLE_WAIT_SEC
                )
                print(f"429 throttled, waiting {wait_sec}s before retrying nextLink...")
                time.sleep(wait_sec)
                continue

            resp.raise_for_status()
            data = resp.json()
            props = data.get("properties") or {}
            page_rows = props.get("rows") or []
            new_next_link = props.get("nextLink")
            return new_next_link, page_rows

    # -------- Helper: convert rows to DataFrame --------
    @staticmethod
    def _coerce_cost(value):
        """Return numeric cost values as ``float``; leave anything else as is."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
        return value

    def _rows_to_df(self, rows):
        """Build a Spark DataFrame from positional result rows.

        The first element of every row is coerced to ``float`` so that a mix
        of JSON integers and floats in the cost column cannot break Spark's
        schema inference.

        NOTE(review): the column names are assigned by position; this assumes
        the service returns cost, usage date, tag key, tag value, currency in
        that order - confirm against the ``columns`` of the response.
        """
        normalized_rows = [
            [self._coerce_cost(row[0]), *row[1:]] if row else row
            for row in rows
        ]
        return spark.createDataFrame(
            normalized_rows,
            ["cost", "date_key", "tag_key", "cluster_id", "currency"],
        )


In [None]:
# =======================================================
# APP
# =======================================================
class AzureCostReporterApp:
    """Incrementally loads Azure cost per cluster id into the target table.

    The load window is derived from the latest successful run recorded in the
    audit table, widened by ``overlap_days``. The rows returned by
    ``AzureCostClient`` are aggregated per cluster, currency and day, merged
    into ``target_table``, and the outcome is appended to ``audit_table``.
    """

    # Value written to (and filtered on) the audit log's table_name column.
    AUDIT_TABLE_NAME = "dbspend360_cloud_cost_explorer"
    # Size of the look-back used when no successful run exists yet.
    INITIAL_BACKFILL_DAYS = 365
    # Cap on the error text stored in the audit message of a failed run.
    MAX_AUDIT_MESSAGE_CHARS = 1000

    def __init__(self):
        """Read the widgets and secrets and build the cost client."""
        scope = dbutils.widgets.get("scope")
        subscription_id = dbutils.widgets.get("subscription_id")
        tenant_id = dbutils.secrets.get(scope, "tenant_id")
        client_id = dbutils.secrets.get(scope, "client_id")
        client_secret = dbutils.secrets.get(scope, "client_secret")

        self.client = AzureCostClient(
            subscription_id,
            tenant_id,
            client_id,
            client_secret,
        )

    @staticmethod
    def _resolve_window(last_end_date, today, overlap):
        """Return ``(start_dt, end_dt)`` for this run as dates.

        Args:
            last_end_date: ``end_date`` of the latest successful run; a date,
                a datetime (reduced to its date), or ``None`` when there is
                no successful run yet.
            today: current UTC date; becomes the end of the window.
            overlap: number of days to re-load before ``last_end_date``.
        """
        if isinstance(last_end_date, datetime):
            # A timestamp-typed audit column would otherwise make the later
            # date comparison fail.
            last_end_date = last_end_date.date()
        if last_end_date is None:
            # First run: pick a synthetic watermark so the window starts
            # 364 days before today regardless of the overlap.
            last_end_date = today - timedelta(
                days=AzureCostReporterApp.INITIAL_BACKFILL_DAYS - overlap
            )
        start_dt = last_end_date - timedelta(days=overlap - 1)
        return start_dt, today

    def _write_audit(self, start_dt, end_dt, status, row_count, message):
        """Append one run record to the audit table.

        NOTE(review): ``insertInto`` matches columns by position, so the
        field order below presumably mirrors the audit table's column
        order - confirm against the table definition.
        """
        run_log_df = spark.createDataFrame([
            Row(
                table_name=self.AUDIT_TABLE_NAME,
                start_date=start_dt,
                end_date=end_dt,
                status=status,
                row_count=int(row_count),
                message=message,
                created_at=datetime.now(timezone.utc),
            )
        ])
        run_log_df.write.mode("append").insertInto(audit_table)

    def _merge_costs(self, spark_df):
        """Aggregate the API rows and MERGE them into ``target_table``.

        Returns:
            Number of aggregated rows offered to the MERGE (0 if nothing
            usable was returned).
        """
        if spark_df is None or spark_df.limit(1).count() == 0:
            print("No Azure cost data returned by API for the requested range.")
            return 0

        dated_df = spark_df.withColumn(
            "cost_incurred_date",
            F.to_date(F.col("date_key").cast("string"), "yyyyMMdd"),
        )

        # Rows without a cluster id or without a parseable date cannot be
        # keyed in the target table.
        inc_df = (
            dated_df
            .filter((F.col("cluster_id").isNotNull()) & (F.col("cluster_id") != ""))
            .filter(F.col("cost_incurred_date").isNotNull())
        )

        agg_df = (
            inc_df
            .groupBy("cluster_id", "currency", "cost_incurred_date")
            .agg(F.sum("cost").alias("cloud_cost"))
            .withColumn("created_at", F.current_timestamp())
            .withColumn("updated_at", F.current_timestamp())
        )

        # The grouped count is zero exactly when the filtered input is empty,
        # so one count serves as both the emptiness check and the row count.
        merged_row_count = agg_df.count()
        if merged_row_count == 0:
            print("No incremental rows after filtering by cluster_id and cost_incurred_date.")
            return 0

        agg_df.createOrReplaceTempView("cloud_cost_inc")

        # target_table is built from the catalog/schema widgets and is
        # interpolated as an identifier; it is not escaped here.
        spark.sql(f"""
        MERGE INTO {target_table} AS t
        USING cloud_cost_inc AS s
        ON  t.cluster_id = s.cluster_id
        AND t.currency = s.currency
        AND t.cost_incurred_date = s.cost_incurred_date
        WHEN MATCHED THEN
          UPDATE SET
            t.cloud_cost = s.cloud_cost,
            t.updated_at = current_timestamp()
        WHEN NOT MATCHED THEN
          INSERT (cluster_id, cloud_cost, currency, cost_incurred_date, created_at, updated_at)
          VALUES (s.cluster_id, s.cloud_cost, s.currency, s.cost_incurred_date,
                  current_timestamp(), current_timestamp())
        """)

        return merged_row_count

    def run(self):
        """Execute one incremental load and record its outcome.

        Writes a FAILED audit row and exits the notebook when the computed
        window is inverted. Writes a FAILED audit row and re-raises when the
        API query or the merge fails. Writes a SUCCESS audit row otherwise,
        which is what the next run uses as its watermark.
        """
        successful_runs = spark.table(audit_table).filter(
            f"table_name = '{self.AUDIT_TABLE_NAME}' AND status = 'SUCCESS'"
        )
        # A global aggregate always yields one row; its value is None when
        # there is no successful run (or every end_date is NULL).
        last_end_date = successful_runs.agg(F.max("end_date")).first()[0]

        # Take "today" once so both window bounds use the same day.
        today = datetime.now(timezone.utc).date()
        start_dt, end_dt = self._resolve_window(last_end_date, today, overlap_days)

        print(f"Querying Azure cost from {start_dt} to {end_dt} (overlap_days={overlap_days})")

        if start_dt > end_dt:
            message = (
                f"Invalid date window: start_dt={start_dt} > end_dt={end_dt}. "
                f"Check audit table and overlap_days={overlap_days}."
            )
            self._write_audit(start_dt, end_dt, "FAILED", 0, message)
            dbutils.notebook.exit("FAILED: Invalid date window.")
            # Defensive: never continue with an inverted window, even if the
            # exit call above returns.
            return

        try:
            spark_df = self.client.group_by_job_clusterid_daily(
                start_date=datetime.combine(start_dt, datetime.min.time(), tzinfo=timezone.utc),
                end_date=datetime.combine(end_dt, datetime.max.time(), tzinfo=timezone.utc),
                tag_name="clusterid",
            )
            merged_row_count = self._merge_costs(spark_df)
        except Exception as exc:
            # Record the failure the same way the invalid-window case does,
            # without letting an audit problem hide the original error.
            failure_message = f"{type(exc).__name__}: {exc}"[: self.MAX_AUDIT_MESSAGE_CHARS]
            try:
                self._write_audit(start_dt, end_dt, "FAILED", 0, failure_message)
            except Exception as audit_exc:
                print(f"Could not record FAILED audit entry: {audit_exc}")
            raise

        print(f"Merged {merged_row_count} rows into {target_table} for {start_dt} → {end_dt} (overlap_days={overlap_days}).")

        # SUCCESS ENTRY (REQUIRED FOR INCREMENTAL)
        self._write_audit(
            start_dt,
            end_dt,
            "SUCCESS",
            merged_row_count,
            f"overlap_days={overlap_days}",
        )

In [None]:
# =======================================================
# Execute
# =======================================================
# Build the app (its constructor reads the widgets and secrets) and run one
# incremental load.
app = AzureCostReporterApp()
app.run()